mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-22 21:47:16 +02:00
proxy: move properties to atomically updated state (#1280)
* authenticate: remove cookie options * authenticate: remove shared key field * authenticate: remove shared cipher property * authenticate: move properties to separate state struct * proxy: allow local state to be updated on configuration changes * fix test * return new connection * use warn, collapse to single line * address concerns, fix tests
This commit is contained in:
parent
23eea09ed0
commit
d9a224a5e8
12 changed files with 305 additions and 147 deletions
|
@ -177,7 +177,7 @@ func setupProxy(src config.Source, cfg *config.Config, controlPlane *controlplan
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
svc, err := proxy.New(cfg.Options)
|
svc, err := proxy.New(cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error creating proxy service: %w", err)
|
return fmt.Errorf("error creating proxy service: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,8 +11,10 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/balancer/roundrobin"
|
"google.golang.org/grpc/balancer/roundrobin"
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
|
@ -144,3 +146,46 @@ func grpcTimeoutInterceptor(timeout time.Duration) grpc.UnaryClientInterceptor {
|
||||||
return invoker(ctx, method, req, reply, cc, opts...)
|
return invoker(ctx, method, req, reply, cc, opts...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type grpcClientConnRecord struct {
|
||||||
|
conn *grpc.ClientConn
|
||||||
|
opts *Options
|
||||||
|
}
|
||||||
|
|
||||||
|
var grpcClientConns = struct {
|
||||||
|
sync.Mutex
|
||||||
|
m map[string]grpcClientConnRecord
|
||||||
|
}{
|
||||||
|
m: make(map[string]grpcClientConnRecord),
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGRPCClientConn returns a gRPC client connection for the given name. If a connection for that name has already been
|
||||||
|
// established the existing connection will be returned. If any options change for that connection, the existing
|
||||||
|
// connection will be closed and a new one established.
|
||||||
|
func GetGRPCClientConn(name string, opts *Options) (*grpc.ClientConn, error) {
|
||||||
|
grpcClientConns.Lock()
|
||||||
|
defer grpcClientConns.Unlock()
|
||||||
|
|
||||||
|
current, ok := grpcClientConns.m[name]
|
||||||
|
if ok {
|
||||||
|
if cmp.Equal(current.opts, opts) {
|
||||||
|
return current.conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := current.conn.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("grpc: failed to close existing connection")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cc, err := NewGRPCClientConn(opts)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
grpcClientConns.m[name] = grpcClientConnRecord{
|
||||||
|
conn: cc,
|
||||||
|
opts: opts,
|
||||||
|
}
|
||||||
|
return cc, nil
|
||||||
|
}
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -77,3 +78,39 @@ func TestNewGRPC(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetGRPC(t *testing.T) {
|
||||||
|
cc1, err := GetGRPCClientConn("example", &Options{
|
||||||
|
Addr: mustParseURL("https://localhost.example"),
|
||||||
|
})
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cc2, err := GetGRPCClientConn("example", &Options{
|
||||||
|
Addr: mustParseURL("https://localhost.example"),
|
||||||
|
})
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, cc1, cc2, "GetGRPCClientConn should return the same connection when there are no changes")
|
||||||
|
|
||||||
|
cc3, err := GetGRPCClientConn("example", &Options{
|
||||||
|
Addr: mustParseURL("http://localhost.example"),
|
||||||
|
WithInsecure: true,
|
||||||
|
})
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NotEqual(t, cc1, cc3, "GetGRPCClientConn should return a new connection when there are changes")
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustParseURL(rawurl string) *url.URL {
|
||||||
|
u, err := url.Parse(rawurl)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
|
@ -18,7 +18,9 @@ import (
|
||||||
func (p *Proxy) registerFwdAuthHandlers() http.Handler {
|
func (p *Proxy) registerFwdAuthHandlers() http.Handler {
|
||||||
r := httputil.NewRouter()
|
r := httputil.NewRouter()
|
||||||
r.StrictSlash(true)
|
r.StrictSlash(true)
|
||||||
r.Use(sessions.RetrieveSession(p.sessionStore))
|
r.Use(func(h http.Handler) http.Handler {
|
||||||
|
return sessions.RetrieveSession(p.state.Load().sessionStore)(h)
|
||||||
|
})
|
||||||
r.Use(p.jwtClaimMiddleware(true))
|
r.Use(p.jwtClaimMiddleware(true))
|
||||||
|
|
||||||
// NGNIX's forward-auth capabilities are split across two settings:
|
// NGNIX's forward-auth capabilities are split across two settings:
|
||||||
|
@ -96,6 +98,8 @@ func (p *Proxy) forwardedURIHeaderCallback(w http.ResponseWriter, r *http.Reques
|
||||||
// provider. If the user is unauthorized, a `401` error is returned.
|
// provider. If the user is unauthorized, a `401` error is returned.
|
||||||
func (p *Proxy) Verify(verifyOnly bool) http.Handler {
|
func (p *Proxy) Verify(verifyOnly bool) http.Handler {
|
||||||
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||||
|
state := p.state.Load()
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
if status := r.FormValue("auth_status"); status == fmt.Sprint(http.StatusForbidden) {
|
if status := r.FormValue("auth_status"); status == fmt.Sprint(http.StatusForbidden) {
|
||||||
return httputil.NewError(http.StatusForbidden, errors.New(http.StatusText(http.StatusForbidden)))
|
return httputil.NewError(http.StatusForbidden, errors.New(http.StatusText(http.StatusForbidden)))
|
||||||
|
@ -120,7 +124,7 @@ func (p *Proxy) Verify(verifyOnly bool) http.Handler {
|
||||||
|
|
||||||
unAuthenticated := ar.statusCode == http.StatusUnauthorized
|
unAuthenticated := ar.statusCode == http.StatusUnauthorized
|
||||||
if unAuthenticated {
|
if unAuthenticated {
|
||||||
p.sessionStore.ClearSession(w, r)
|
state.sessionStore.ClearSession(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = sessions.FromContext(r.Context())
|
_, err = sessions.FromContext(r.Context())
|
||||||
|
@ -141,6 +145,8 @@ func (p *Proxy) Verify(verifyOnly bool) http.Handler {
|
||||||
// forwardAuthRedirectToSignInWithURI redirects request to authenticate signin url,
|
// forwardAuthRedirectToSignInWithURI redirects request to authenticate signin url,
|
||||||
// with all necessary information extracted from given input uri.
|
// with all necessary information extracted from given input uri.
|
||||||
func (p *Proxy) forwardAuthRedirectToSignInWithURI(w http.ResponseWriter, r *http.Request, uri *url.URL) {
|
func (p *Proxy) forwardAuthRedirectToSignInWithURI(w http.ResponseWriter, r *http.Request, uri *url.URL) {
|
||||||
|
state := p.state.Load()
|
||||||
|
|
||||||
// Traefik set the uri in the header, we must set it in redirect uri if present. Otherwise, request like
|
// Traefik set the uri in the header, we must set it in redirect uri if present. Otherwise, request like
|
||||||
// https://example.com/foo will be redirected to https://example.com after authentication.
|
// https://example.com/foo will be redirected to https://example.com after authentication.
|
||||||
if xfu := r.Header.Get(httputil.HeaderForwardedURI); xfu != "/" {
|
if xfu := r.Header.Get(httputil.HeaderForwardedURI); xfu != "/" {
|
||||||
|
@ -148,13 +154,13 @@ func (p *Proxy) forwardAuthRedirectToSignInWithURI(w http.ResponseWriter, r *htt
|
||||||
}
|
}
|
||||||
|
|
||||||
// redirect to authenticate
|
// redirect to authenticate
|
||||||
authN := *p.authenticateSigninURL
|
authN := *state.authenticateSigninURL
|
||||||
q := authN.Query()
|
q := authN.Query()
|
||||||
q.Set(urlutil.QueryCallbackURI, uri.String())
|
q.Set(urlutil.QueryCallbackURI, uri.String())
|
||||||
q.Set(urlutil.QueryRedirectURI, uri.String()) // final destination
|
q.Set(urlutil.QueryRedirectURI, uri.String()) // final destination
|
||||||
q.Set(urlutil.QueryForwardAuth, urlutil.StripPort(r.Host)) // add fwd auth to trusted audience
|
q.Set(urlutil.QueryForwardAuth, urlutil.StripPort(r.Host)) // add fwd auth to trusted audience
|
||||||
authN.RawQuery = q.Encode()
|
authN.RawQuery = q.Encode()
|
||||||
httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &authN).String(), http.StatusFound)
|
httputil.Redirect(w, r, urlutil.NewSignedURL(state.sharedKey, &authN).String(), http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getURIStringFromRequest(r *http.Request) (*url.URL, error) {
|
func getURIStringFromRequest(r *http.Request) (*url.URL, error) {
|
||||||
|
|
|
@ -83,18 +83,19 @@ func TestProxy_ForwardAuth(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
p, err := New(tt.options)
|
p, err := New(&config.Config{Options: tt.options})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
p.authzClient = tt.authorizer
|
p.OnConfigChange(&config.Config{Options: tt.options})
|
||||||
p.sessionStore = tt.sessionStore
|
state := p.state.Load()
|
||||||
|
state.authzClient = tt.authorizer
|
||||||
|
state.sessionStore = tt.sessionStore
|
||||||
signer, err := jws.NewHS256Signer(nil, "mock")
|
signer, err := jws.NewHS256Signer(nil, "mock")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
p.encoder = signer
|
state.encoder = signer
|
||||||
p.OnConfigChange(&config.Config{Options: tt.options})
|
|
||||||
uri, err := url.Parse(tt.requestURI)
|
uri, err := url.Parse(tt.requestURI)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -110,10 +111,10 @@ func TestProxy_ForwardAuth(t *testing.T) {
|
||||||
uri.RawQuery = queryString.Encode()
|
uri.RawQuery = queryString.Encode()
|
||||||
|
|
||||||
r := httptest.NewRequest(tt.method, uri.String(), nil)
|
r := httptest.NewRequest(tt.method, uri.String(), nil)
|
||||||
state, _ := tt.sessionStore.LoadSession(r)
|
ss, _ := tt.sessionStore.LoadSession(r)
|
||||||
|
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
ctx = sessions.NewContext(ctx, state, tt.ctxError)
|
ctx = sessions.NewContext(ctx, ss, tt.ctxError)
|
||||||
r = r.WithContext(ctx)
|
r = r.WithContext(ctx)
|
||||||
r.Header.Set("Accept", "application/json")
|
r.Header.Set("Accept", "application/json")
|
||||||
if len(tt.headers) != 0 {
|
if len(tt.headers) != 0 {
|
||||||
|
|
|
@ -21,14 +21,17 @@ func (p *Proxy) registerDashboardHandlers(r *mux.Router) *mux.Router {
|
||||||
h := r.PathPrefix(dashboardPath).Subrouter()
|
h := r.PathPrefix(dashboardPath).Subrouter()
|
||||||
h.Use(middleware.SetHeaders(httputil.HeadersContentSecurityPolicy))
|
h.Use(middleware.SetHeaders(httputil.HeadersContentSecurityPolicy))
|
||||||
// 1. Retrieve the user session and add it to the request context
|
// 1. Retrieve the user session and add it to the request context
|
||||||
h.Use(sessions.RetrieveSession(p.sessionStore))
|
h.Use(func(h http.Handler) http.Handler {
|
||||||
|
return sessions.RetrieveSession(p.state.Load().sessionStore)(h)
|
||||||
|
})
|
||||||
// 2. AuthN - Verify the user is authenticated. Set email, group, & id headers
|
// 2. AuthN - Verify the user is authenticated. Set email, group, & id headers
|
||||||
h.Use(p.AuthenticateSession)
|
h.Use(p.AuthenticateSession)
|
||||||
// 3. Enforce CSRF protections for any non-idempotent http method
|
// 3. Enforce CSRF protections for any non-idempotent http method
|
||||||
h.Use(func(h http.Handler) http.Handler {
|
h.Use(func(h http.Handler) http.Handler {
|
||||||
opts := p.currentOptions.Load()
|
opts := p.currentOptions.Load()
|
||||||
|
state := p.state.Load()
|
||||||
return csrf.Protect(
|
return csrf.Protect(
|
||||||
p.cookieSecret,
|
state.cookieSecret,
|
||||||
csrf.Secure(opts.CookieSecure),
|
csrf.Secure(opts.CookieSecure),
|
||||||
csrf.CookieName(fmt.Sprintf("%s_csrf", opts.CookieName)),
|
csrf.CookieName(fmt.Sprintf("%s_csrf", opts.CookieName)),
|
||||||
csrf.ErrorHandler(httputil.HandlerFunc(httputil.CSRFFailureHandler)),
|
csrf.ErrorHandler(httputil.HandlerFunc(httputil.CSRFFailureHandler)),
|
||||||
|
@ -42,7 +45,9 @@ func (p *Proxy) registerDashboardHandlers(r *mux.Router) *mux.Router {
|
||||||
// callback used to set route-scoped session and redirect back to destination
|
// callback used to set route-scoped session and redirect back to destination
|
||||||
// only accept signed requests (hmac) from other trusted pomerium services
|
// only accept signed requests (hmac) from other trusted pomerium services
|
||||||
c := r.PathPrefix(dashboardPath + "/callback").Subrouter()
|
c := r.PathPrefix(dashboardPath + "/callback").Subrouter()
|
||||||
c.Use(middleware.ValidateSignature(p.SharedKey))
|
h.Use(func(h http.Handler) http.Handler {
|
||||||
|
return middleware.ValidateSignature(p.state.Load().sharedKey)(h)
|
||||||
|
})
|
||||||
|
|
||||||
c.Path("/").
|
c.Path("/").
|
||||||
Handler(httputil.HandlerFunc(p.ProgrammaticCallback)).
|
Handler(httputil.HandlerFunc(p.ProgrammaticCallback)).
|
||||||
|
@ -71,28 +76,32 @@ func (p *Proxy) RobotsTxt(w http.ResponseWriter, _ *http.Request) {
|
||||||
// of the authenticate service to revoke the remote session and clear
|
// of the authenticate service to revoke the remote session and clear
|
||||||
// the local session state.
|
// the local session state.
|
||||||
func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) {
|
func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) {
|
||||||
|
state := p.state.Load()
|
||||||
|
|
||||||
redirectURL := &url.URL{Scheme: "https", Host: r.Host, Path: "/"}
|
redirectURL := &url.URL{Scheme: "https", Host: r.Host, Path: "/"}
|
||||||
if uri, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI)); err == nil && uri.String() != "" {
|
if uri, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI)); err == nil && uri.String() != "" {
|
||||||
redirectURL = uri
|
redirectURL = uri
|
||||||
}
|
}
|
||||||
|
|
||||||
signoutURL := *p.authenticateSignoutURL
|
signoutURL := *state.authenticateSignoutURL
|
||||||
q := signoutURL.Query()
|
q := signoutURL.Query()
|
||||||
q.Set(urlutil.QueryRedirectURI, redirectURL.String())
|
q.Set(urlutil.QueryRedirectURI, redirectURL.String())
|
||||||
signoutURL.RawQuery = q.Encode()
|
signoutURL.RawQuery = q.Encode()
|
||||||
|
|
||||||
p.sessionStore.ClearSession(w, r)
|
state.sessionStore.ClearSession(w, r)
|
||||||
httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &signoutURL).String(), http.StatusFound)
|
httputil.Redirect(w, r, urlutil.NewSignedURL(state.sharedKey, &signoutURL).String(), http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UserDashboard redirects to the authenticate dasbhoard.
|
// UserDashboard redirects to the authenticate dasbhoard.
|
||||||
func (p *Proxy) UserDashboard(w http.ResponseWriter, r *http.Request) {
|
func (p *Proxy) UserDashboard(w http.ResponseWriter, r *http.Request) {
|
||||||
|
state := p.state.Load()
|
||||||
|
|
||||||
redirectURL := urlutil.GetAbsoluteURL(r).String()
|
redirectURL := urlutil.GetAbsoluteURL(r).String()
|
||||||
if ref := r.Header.Get(httputil.HeaderReferrer); ref != "" {
|
if ref := r.Header.Get(httputil.HeaderReferrer); ref != "" {
|
||||||
redirectURL = ref
|
redirectURL = ref
|
||||||
}
|
}
|
||||||
|
|
||||||
url := p.authenticateDashboardURL.ResolveReference(&url.URL{
|
url := state.authenticateDashboardURL.ResolveReference(&url.URL{
|
||||||
RawQuery: url.Values{
|
RawQuery: url.Values{
|
||||||
urlutil.QueryRedirectURI: {redirectURL},
|
urlutil.QueryRedirectURI: {redirectURL},
|
||||||
}.Encode(),
|
}.Encode(),
|
||||||
|
@ -116,18 +125,20 @@ func (p *Proxy) Callback(w http.ResponseWriter, r *http.Request) error {
|
||||||
// saveCallbackSession takes an encrypted per-route session token, and decrypts
|
// saveCallbackSession takes an encrypted per-route session token, and decrypts
|
||||||
// it using the shared service key, then stores it the local session store.
|
// it using the shared service key, then stores it the local session store.
|
||||||
func (p *Proxy) saveCallbackSession(w http.ResponseWriter, r *http.Request, enctoken string) ([]byte, error) {
|
func (p *Proxy) saveCallbackSession(w http.ResponseWriter, r *http.Request, enctoken string) ([]byte, error) {
|
||||||
|
state := p.state.Load()
|
||||||
|
|
||||||
// 1. extract the base64 encoded and encrypted JWT from query params
|
// 1. extract the base64 encoded and encrypted JWT from query params
|
||||||
encryptedJWT, err := base64.URLEncoding.DecodeString(enctoken)
|
encryptedJWT, err := base64.URLEncoding.DecodeString(enctoken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("proxy: malfromed callback token: %w", err)
|
return nil, fmt.Errorf("proxy: malfromed callback token: %w", err)
|
||||||
}
|
}
|
||||||
// 2. decrypt the JWT using the cipher using the _shared_ secret key
|
// 2. decrypt the JWT using the cipher using the _shared_ secret key
|
||||||
rawJWT, err := cryptutil.Decrypt(p.sharedCipher, encryptedJWT, nil)
|
rawJWT, err := cryptutil.Decrypt(state.sharedCipher, encryptedJWT, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("proxy: callback token decrypt error: %w", err)
|
return nil, fmt.Errorf("proxy: callback token decrypt error: %w", err)
|
||||||
}
|
}
|
||||||
// 3. Save the decrypted JWT to the session store directly as a string, without resigning
|
// 3. Save the decrypted JWT to the session store directly as a string, without resigning
|
||||||
if err = p.sessionStore.SaveSession(w, r, rawJWT); err != nil {
|
if err = state.sessionStore.SaveSession(w, r, rawJWT); err != nil {
|
||||||
return nil, fmt.Errorf("proxy: callback session save failure: %w", err)
|
return nil, fmt.Errorf("proxy: callback session save failure: %w", err)
|
||||||
}
|
}
|
||||||
return rawJWT, nil
|
return rawJWT, nil
|
||||||
|
@ -136,11 +147,13 @@ func (p *Proxy) saveCallbackSession(w http.ResponseWriter, r *http.Request, enct
|
||||||
// ProgrammaticLogin returns a signed url that can be used to login
|
// ProgrammaticLogin returns a signed url that can be used to login
|
||||||
// using the authenticate service.
|
// using the authenticate service.
|
||||||
func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) error {
|
func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) error {
|
||||||
|
state := p.state.Load()
|
||||||
|
|
||||||
redirectURI, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI))
|
redirectURI, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httputil.NewError(http.StatusBadRequest, err)
|
return httputil.NewError(http.StatusBadRequest, err)
|
||||||
}
|
}
|
||||||
signinURL := *p.authenticateSigninURL
|
signinURL := *state.authenticateSigninURL
|
||||||
callbackURI := urlutil.GetAbsoluteURL(r)
|
callbackURI := urlutil.GetAbsoluteURL(r)
|
||||||
callbackURI.Path = dashboardPath + "/callback/"
|
callbackURI.Path = dashboardPath + "/callback/"
|
||||||
q := signinURL.Query()
|
q := signinURL.Query()
|
||||||
|
@ -148,7 +161,7 @@ func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) error
|
||||||
q.Set(urlutil.QueryRedirectURI, redirectURI.String())
|
q.Set(urlutil.QueryRedirectURI, redirectURI.String())
|
||||||
q.Set(urlutil.QueryIsProgrammatic, "true")
|
q.Set(urlutil.QueryIsProgrammatic, "true")
|
||||||
signinURL.RawQuery = q.Encode()
|
signinURL.RawQuery = q.Encode()
|
||||||
response := urlutil.NewSignedURL(p.SharedKey, &signinURL).String()
|
response := urlutil.NewSignedURL(state.sharedKey, &signinURL).String()
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
|
|
|
@ -47,7 +47,7 @@ func TestProxy_Signout(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
proxy, err := New(opts)
|
proxy, err := New(&config.Config{Options: opts})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -58,7 +58,7 @@ func TestProxy_Signout(t *testing.T) {
|
||||||
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusFound)
|
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusFound)
|
||||||
}
|
}
|
||||||
body := rr.Body.String()
|
body := rr.Body.String()
|
||||||
want := (proxy.authenticateURL.String())
|
want := proxy.state.Load().authenticateURL.String()
|
||||||
if !strings.Contains(body, want) {
|
if !strings.Contains(body, want) {
|
||||||
t.Errorf("handler returned unexpected body: got %v want %s ", body, want)
|
t.Errorf("handler returned unexpected body: got %v want %s ", body, want)
|
||||||
}
|
}
|
||||||
|
@ -79,7 +79,7 @@ func TestProxy_SignOut(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
opts := testOptions(t)
|
opts := testOptions(t)
|
||||||
p, err := New(opts)
|
p, err := New(&config.Config{Options: opts})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -221,13 +221,14 @@ func TestProxy_Callback(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
p, err := New(tt.options)
|
p, err := New(&config.Config{Options: tt.options})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
p.encoder = tt.cipher
|
|
||||||
p.sessionStore = tt.sessionStore
|
|
||||||
p.OnConfigChange(&config.Config{Options: tt.options})
|
p.OnConfigChange(&config.Config{Options: tt.options})
|
||||||
|
state := p.state.Load()
|
||||||
|
state.encoder = tt.cipher
|
||||||
|
state.sessionStore = tt.sessionStore
|
||||||
redirectURI := &url.URL{Scheme: tt.scheme, Host: tt.host, Path: tt.path}
|
redirectURI := &url.URL{Scheme: tt.scheme, Host: tt.host, Path: tt.path}
|
||||||
queryString := redirectURI.Query()
|
queryString := redirectURI.Query()
|
||||||
for k, v := range tt.qp {
|
for k, v := range tt.qp {
|
||||||
|
@ -297,7 +298,7 @@ func TestProxy_ProgrammaticLogin(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
p, err := New(tt.options)
|
p, err := New(&config.Config{Options: tt.options})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -428,13 +429,14 @@ func TestProxy_ProgrammaticCallback(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
p, err := New(tt.options)
|
p, err := New(&config.Config{Options: tt.options})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
p.encoder = tt.cipher
|
|
||||||
p.sessionStore = tt.sessionStore
|
|
||||||
p.OnConfigChange(&config.Config{Options: tt.options})
|
p.OnConfigChange(&config.Config{Options: tt.options})
|
||||||
|
state := p.state.Load()
|
||||||
|
state.encoder = tt.cipher
|
||||||
|
state.sessionStore = tt.sessionStore
|
||||||
redirectURI, _ := url.Parse(tt.redirectURI)
|
redirectURI, _ := url.Parse(tt.redirectURI)
|
||||||
queryString := redirectURI.Query()
|
queryString := redirectURI.Query()
|
||||||
for k, v := range tt.qp {
|
for k, v := range tt.qp {
|
||||||
|
|
|
@ -40,17 +40,21 @@ func (p *Proxy) AuthenticateSession(next http.Handler) http.Handler {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Proxy) redirectToSignin(w http.ResponseWriter, r *http.Request) error {
|
func (p *Proxy) redirectToSignin(w http.ResponseWriter, r *http.Request) error {
|
||||||
signinURL := *p.authenticateSigninURL
|
state := p.state.Load()
|
||||||
|
|
||||||
|
signinURL := *state.authenticateSigninURL
|
||||||
q := signinURL.Query()
|
q := signinURL.Query()
|
||||||
q.Set(urlutil.QueryRedirectURI, urlutil.GetAbsoluteURL(r).String())
|
q.Set(urlutil.QueryRedirectURI, urlutil.GetAbsoluteURL(r).String())
|
||||||
signinURL.RawQuery = q.Encode()
|
signinURL.RawQuery = q.Encode()
|
||||||
log.FromRequest(r).Debug().Str("url", signinURL.String()).Msg("proxy: redirectToSignin")
|
log.FromRequest(r).Debug().Str("url", signinURL.String()).Msg("proxy: redirectToSignin")
|
||||||
httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &signinURL).String(), http.StatusFound)
|
httputil.Redirect(w, r, urlutil.NewSignedURL(state.sharedKey, &signinURL).String(), http.StatusFound)
|
||||||
p.sessionStore.ClearSession(w, r)
|
state.sessionStore.ClearSession(w, r)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Proxy) isAuthorized(w http.ResponseWriter, r *http.Request) (*authorizeResponse, error) {
|
func (p *Proxy) isAuthorized(w http.ResponseWriter, r *http.Request) (*authorizeResponse, error) {
|
||||||
|
state := p.state.Load()
|
||||||
|
|
||||||
tm, err := ptypes.TimestampProto(time.Now())
|
tm, err := ptypes.TimestampProto(time.Now())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, httputil.NewError(http.StatusInternalServerError, fmt.Errorf("error creating protobuf timestamp from current time: %w", err))
|
return nil, httputil.NewError(http.StatusInternalServerError, fmt.Errorf("error creating protobuf timestamp from current time: %w", err))
|
||||||
|
@ -72,7 +76,7 @@ func (p *Proxy) isAuthorized(w http.ResponseWriter, r *http.Request) (*authorize
|
||||||
httpAttrs.Path += "?" + r.URL.RawQuery
|
httpAttrs.Path += "?" + r.URL.RawQuery
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err := p.authzClient.Check(r.Context(), &envoy_service_auth_v2.CheckRequest{
|
res, err := state.authzClient.Check(r.Context(), &envoy_service_auth_v2.CheckRequest{
|
||||||
Attributes: &envoy_service_auth_v2.AttributeContext{
|
Attributes: &envoy_service_auth_v2.AttributeContext{
|
||||||
Request: &envoy_service_auth_v2.AttributeContext_Request{
|
Request: &envoy_service_auth_v2.AttributeContext_Request{
|
||||||
Time: tm,
|
Time: tm,
|
||||||
|
@ -118,12 +122,12 @@ func SetResponseHeaders(headers map[string]string) func(next http.Handler) http.
|
||||||
//
|
//
|
||||||
// if returnJWTInfo is set to true, it will also return JWT claim information in the response
|
// if returnJWTInfo is set to true, it will also return JWT claim information in the response
|
||||||
func (p *Proxy) jwtClaimMiddleware(returnJWTInfo bool) mux.MiddlewareFunc {
|
func (p *Proxy) jwtClaimMiddleware(returnJWTInfo bool) mux.MiddlewareFunc {
|
||||||
|
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
|
|
||||||
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||||
defer next.ServeHTTP(w, r)
|
defer next.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
state := p.state.Load()
|
||||||
|
|
||||||
jwt, err := sessions.FromContext(r.Context())
|
jwt, err := sessions.FromContext(r.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("proxy: could not locate session from context")
|
log.Error().Err(err).Msg("proxy: could not locate session from context")
|
||||||
|
@ -147,7 +151,7 @@ func (p *Proxy) jwtClaimMiddleware(returnJWTInfo bool) mux.MiddlewareFunc {
|
||||||
}
|
}
|
||||||
|
|
||||||
// set headers for any claims specified by config
|
// set headers for any claims specified by config
|
||||||
for _, claimName := range p.jwtClaimHeaders {
|
for _, claimName := range state.jwtClaimHeaders {
|
||||||
if _, ok := formattedJWTClaims[claimName]; ok {
|
if _, ok := formattedJWTClaims[claimName]; ok {
|
||||||
|
|
||||||
headerName := fmt.Sprintf("x-pomerium-claim-%s", claimName)
|
headerName := fmt.Sprintf("x-pomerium-claim-%s", claimName)
|
||||||
|
@ -165,10 +169,12 @@ func (p *Proxy) jwtClaimMiddleware(returnJWTInfo bool) mux.MiddlewareFunc {
|
||||||
|
|
||||||
// getFormatJWTClaims reformats jwtClaims into something resembling map[string]string
|
// getFormatJWTClaims reformats jwtClaims into something resembling map[string]string
|
||||||
func (p *Proxy) getFormatedJWTClaims(jwt []byte) (map[string]string, error) {
|
func (p *Proxy) getFormatedJWTClaims(jwt []byte) (map[string]string, error) {
|
||||||
|
state := p.state.Load()
|
||||||
|
|
||||||
formattedJWTClaims := make(map[string]string)
|
formattedJWTClaims := make(map[string]string)
|
||||||
|
|
||||||
var jwtClaims map[string]interface{}
|
var jwtClaims map[string]interface{}
|
||||||
if err := p.encoder.Unmarshal(jwt, &jwtClaims); err != nil {
|
if err := state.encoder.Unmarshal(jwt, &jwtClaims); err != nil {
|
||||||
return formattedJWTClaims, err
|
return formattedJWTClaims, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -57,13 +57,15 @@ func TestProxy_AuthenticateSession(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
a := Proxy{
|
a := Proxy{
|
||||||
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
|
state: newAtomicProxyState(&proxyState{
|
||||||
cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="),
|
sharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
|
||||||
authenticateURL: uriParseHelper("https://authenticate.corp.example"),
|
cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="),
|
||||||
authenticateSigninURL: uriParseHelper("https://authenticate.corp.example/sign_in"),
|
authenticateURL: uriParseHelper("https://authenticate.corp.example"),
|
||||||
authenticateRefreshURL: uriParseHelper(rURL),
|
authenticateSigninURL: uriParseHelper("https://authenticate.corp.example/sign_in"),
|
||||||
sessionStore: tt.session,
|
authenticateRefreshURL: uriParseHelper(rURL),
|
||||||
encoder: tt.encoder,
|
sessionStore: tt.session,
|
||||||
|
encoder: tt.encoder,
|
||||||
|
}),
|
||||||
}
|
}
|
||||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
state, _ := tt.session.LoadSession(r)
|
state, _ := tt.session.LoadSession(r)
|
||||||
|
@ -95,10 +97,12 @@ func Test_jwtClaimMiddleware(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
a := Proxy{
|
a := Proxy{
|
||||||
SharedKey: sharedKey,
|
state: newAtomicProxyState(&proxyState{
|
||||||
cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="),
|
sharedKey: sharedKey,
|
||||||
encoder: encoder,
|
cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="),
|
||||||
jwtClaimHeaders: claimHeaders,
|
encoder: encoder,
|
||||||
|
jwtClaimHeaders: claimHeaders,
|
||||||
|
}),
|
||||||
}
|
}
|
||||||
|
|
||||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
106
proxy/proxy.go
106
proxy/proxy.go
|
@ -5,32 +5,20 @@
|
||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/cipher"
|
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"html/template"
|
"html/template"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
|
||||||
|
|
||||||
envoy_service_auth_v2 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v2"
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/encoding"
|
|
||||||
"github.com/pomerium/pomerium/internal/encoding/jws"
|
|
||||||
"github.com/pomerium/pomerium/internal/frontend"
|
"github.com/pomerium/pomerium/internal/frontend"
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
|
||||||
"github.com/pomerium/pomerium/internal/sessions/cookie"
|
|
||||||
"github.com/pomerium/pomerium/internal/sessions/header"
|
|
||||||
"github.com/pomerium/pomerium/internal/sessions/queryparam"
|
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -64,102 +52,29 @@ func ValidateOptions(o *config.Options) error {
|
||||||
|
|
||||||
// Proxy stores all the information associated with proxying a request.
|
// Proxy stores all the information associated with proxying a request.
|
||||||
type Proxy struct {
|
type Proxy struct {
|
||||||
// SharedKey used to mutually authenticate service communication
|
templates *template.Template
|
||||||
SharedKey string
|
state *atomicProxyState
|
||||||
sharedCipher cipher.AEAD
|
|
||||||
|
|
||||||
authorizeURL *url.URL
|
|
||||||
authenticateURL *url.URL
|
|
||||||
authenticateDashboardURL *url.URL
|
|
||||||
authenticateSigninURL *url.URL
|
|
||||||
authenticateSignoutURL *url.URL
|
|
||||||
authenticateRefreshURL *url.URL
|
|
||||||
|
|
||||||
encoder encoding.Unmarshaler
|
|
||||||
cookieSecret []byte
|
|
||||||
refreshCooldown time.Duration
|
|
||||||
sessionStore sessions.SessionStore
|
|
||||||
sessionLoaders []sessions.SessionLoader
|
|
||||||
templates *template.Template
|
|
||||||
jwtClaimHeaders []string
|
|
||||||
authzClient envoy_service_auth_v2.AuthorizationClient
|
|
||||||
|
|
||||||
currentOptions *config.AtomicOptions
|
currentOptions *config.AtomicOptions
|
||||||
currentRouter atomic.Value
|
currentRouter atomic.Value
|
||||||
}
|
}
|
||||||
|
|
||||||
// New takes a Proxy service from options and a validation function.
|
// New takes a Proxy service from options and a validation function.
|
||||||
// Function returns an error if options fail to validate.
|
// Function returns an error if options fail to validate.
|
||||||
func New(opts *config.Options) (*Proxy, error) {
|
func New(cfg *config.Config) (*Proxy, error) {
|
||||||
if err := ValidateOptions(opts); err != nil {
|
state, err := newProxyStateFromConfig(cfg)
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
sharedCipher, _ := cryptutil.NewAEADCipherFromBase64(opts.SharedKey)
|
|
||||||
decodedCookieSecret, _ := base64.StdEncoding.DecodeString(opts.CookieSecret)
|
|
||||||
|
|
||||||
// used to load and verify JWT tokens signed by the authenticate service
|
|
||||||
encoder, err := jws.NewHS256Signer([]byte(opts.SharedKey), opts.GetAuthenticateURL().Host)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
p := &Proxy{
|
p := &Proxy{
|
||||||
SharedKey: opts.SharedKey,
|
templates: template.Must(frontend.NewTemplates()),
|
||||||
sharedCipher: sharedCipher,
|
state: newAtomicProxyState(state),
|
||||||
encoder: encoder,
|
currentOptions: config.NewAtomicOptions(),
|
||||||
|
|
||||||
cookieSecret: decodedCookieSecret,
|
|
||||||
refreshCooldown: opts.RefreshCooldown,
|
|
||||||
templates: template.Must(frontend.NewTemplates()),
|
|
||||||
jwtClaimHeaders: opts.JWTClaimsHeaders,
|
|
||||||
currentOptions: config.NewAtomicOptions(),
|
|
||||||
}
|
}
|
||||||
p.currentRouter.Store(httputil.NewRouter())
|
p.currentRouter.Store(httputil.NewRouter())
|
||||||
// errors checked in ValidateOptions
|
|
||||||
p.authorizeURL, _ = urlutil.DeepCopy(opts.AuthorizeURL)
|
|
||||||
p.authenticateURL, _ = urlutil.DeepCopy(opts.AuthenticateURL)
|
|
||||||
p.authenticateDashboardURL = p.authenticateURL.ResolveReference(&url.URL{Path: dashboardPath})
|
|
||||||
p.authenticateSigninURL = p.authenticateURL.ResolveReference(&url.URL{Path: signinURL})
|
|
||||||
p.authenticateSignoutURL = p.authenticateURL.ResolveReference(&url.URL{Path: signoutURL})
|
|
||||||
p.authenticateRefreshURL = p.authenticateURL.ResolveReference(&url.URL{Path: refreshURL})
|
|
||||||
|
|
||||||
cookieStore, err := cookie.NewStore(func() cookie.Options {
|
|
||||||
opts := p.currentOptions.Load()
|
|
||||||
return cookie.Options{
|
|
||||||
Name: opts.CookieName,
|
|
||||||
Domain: opts.CookieDomain,
|
|
||||||
Secure: opts.CookieSecure,
|
|
||||||
HTTPOnly: opts.CookieHTTPOnly,
|
|
||||||
Expire: opts.CookieExpire,
|
|
||||||
}
|
|
||||||
}, encoder)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
p.sessionStore = cookieStore
|
|
||||||
p.sessionLoaders = []sessions.SessionLoader{
|
|
||||||
cookieStore,
|
|
||||||
header.NewStore(encoder, httputil.AuthorizationTypePomerium),
|
|
||||||
queryparam.NewStore(encoder, "pomerium_session")}
|
|
||||||
|
|
||||||
authzConn, err := grpc.NewGRPCClientConn(&grpc.Options{
|
|
||||||
Addr: p.authorizeURL,
|
|
||||||
OverrideCertificateName: opts.OverrideCertificateName,
|
|
||||||
CA: opts.CA,
|
|
||||||
CAFile: opts.CAFile,
|
|
||||||
RequestTimeout: opts.GRPCClientTimeout,
|
|
||||||
ClientDNSRoundRobin: opts.GRPCClientDNSRoundRobin,
|
|
||||||
WithInsecure: opts.GRPCInsecure,
|
|
||||||
ServiceName: opts.Services,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
p.authzClient = envoy_service_auth_v2.NewAuthorizationClient(authzConn)
|
|
||||||
|
|
||||||
metrics.AddPolicyCountCallback("pomerium-proxy", func() int64 {
|
metrics.AddPolicyCountCallback("pomerium-proxy", func() int64 {
|
||||||
return int64(len(opts.Policies))
|
return int64(len(p.currentOptions.Load().Policies))
|
||||||
})
|
})
|
||||||
|
|
||||||
return p, nil
|
return p, nil
|
||||||
|
@ -174,6 +89,11 @@ func (p *Proxy) OnConfigChange(cfg *config.Config) {
|
||||||
log.Info().Str("checksum", fmt.Sprintf("%x", cfg.Options.Checksum())).Msg("proxy: updating options")
|
log.Info().Str("checksum", fmt.Sprintf("%x", cfg.Options.Checksum())).Msg("proxy: updating options")
|
||||||
p.currentOptions.Store(cfg.Options)
|
p.currentOptions.Store(cfg.Options)
|
||||||
p.setHandlers(cfg.Options)
|
p.setHandlers(cfg.Options)
|
||||||
|
if state, err := newProxyStateFromConfig(cfg); err != nil {
|
||||||
|
log.Error().Err(err).Msg("proxy: failed to update proxy state from configuration settings")
|
||||||
|
} else {
|
||||||
|
p.state.Store(state)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Proxy) setHandlers(opts *config.Options) {
|
func (p *Proxy) setHandlers(opts *config.Options) {
|
||||||
|
|
|
@ -103,7 +103,7 @@ func TestNew(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got, err := New(tt.opts)
|
got, err := New(&config.Config{Options: tt.opts})
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
|
@ -192,7 +192,7 @@ func Test_UpdateOptions(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
p, err := New(tt.originalOptions)
|
p, err := New(&config.Config{Options: tt.originalOptions})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
124
proxy/state.go
Normal file
124
proxy/state.go
Normal file
|
@ -0,0 +1,124 @@
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/cipher"
|
||||||
|
"encoding/base64"
|
||||||
|
"net/url"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
envoy_service_auth_v2 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v2"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/internal/encoding"
|
||||||
|
"github.com/pomerium/pomerium/internal/encoding/jws"
|
||||||
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions/cookie"
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions/header"
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions/queryparam"
|
||||||
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc"
|
||||||
|
)
|
||||||
|
|
||||||
|
type proxyState struct {
|
||||||
|
sharedKey string
|
||||||
|
sharedCipher cipher.AEAD
|
||||||
|
|
||||||
|
authorizeURL *url.URL
|
||||||
|
authenticateURL *url.URL
|
||||||
|
authenticateDashboardURL *url.URL
|
||||||
|
authenticateSigninURL *url.URL
|
||||||
|
authenticateSignoutURL *url.URL
|
||||||
|
authenticateRefreshURL *url.URL
|
||||||
|
|
||||||
|
encoder encoding.MarshalUnmarshaler
|
||||||
|
cookieSecret []byte
|
||||||
|
refreshCooldown time.Duration
|
||||||
|
sessionStore sessions.SessionStore
|
||||||
|
sessionLoaders []sessions.SessionLoader
|
||||||
|
jwtClaimHeaders []string
|
||||||
|
authzClient envoy_service_auth_v2.AuthorizationClient
|
||||||
|
}
|
||||||
|
|
||||||
|
func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) {
|
||||||
|
err := ValidateOptions(cfg.Options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
state := new(proxyState)
|
||||||
|
state.sharedKey = cfg.Options.SharedKey
|
||||||
|
state.sharedCipher, _ = cryptutil.NewAEADCipherFromBase64(cfg.Options.SharedKey)
|
||||||
|
state.cookieSecret, _ = base64.StdEncoding.DecodeString(cfg.Options.CookieSecret)
|
||||||
|
|
||||||
|
// used to load and verify JWT tokens signed by the authenticate service
|
||||||
|
state.encoder, err = jws.NewHS256Signer([]byte(cfg.Options.SharedKey), cfg.Options.GetAuthenticateURL().Host)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
state.refreshCooldown = cfg.Options.RefreshCooldown
|
||||||
|
state.jwtClaimHeaders = cfg.Options.JWTClaimsHeaders
|
||||||
|
|
||||||
|
// errors checked in ValidateOptions
|
||||||
|
state.authorizeURL, _ = urlutil.DeepCopy(cfg.Options.AuthorizeURL)
|
||||||
|
state.authenticateURL, _ = urlutil.DeepCopy(cfg.Options.AuthenticateURL)
|
||||||
|
state.authenticateDashboardURL = state.authenticateURL.ResolveReference(&url.URL{Path: dashboardPath})
|
||||||
|
state.authenticateSigninURL = state.authenticateURL.ResolveReference(&url.URL{Path: signinURL})
|
||||||
|
state.authenticateSignoutURL = state.authenticateURL.ResolveReference(&url.URL{Path: signoutURL})
|
||||||
|
state.authenticateRefreshURL = state.authenticateURL.ResolveReference(&url.URL{Path: refreshURL})
|
||||||
|
|
||||||
|
state.sessionStore, err = cookie.NewStore(func() cookie.Options {
|
||||||
|
return cookie.Options{
|
||||||
|
Name: cfg.Options.CookieName,
|
||||||
|
Domain: cfg.Options.CookieDomain,
|
||||||
|
Secure: cfg.Options.CookieSecure,
|
||||||
|
HTTPOnly: cfg.Options.CookieHTTPOnly,
|
||||||
|
Expire: cfg.Options.CookieExpire,
|
||||||
|
}
|
||||||
|
}, state.encoder)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
state.sessionLoaders = []sessions.SessionLoader{
|
||||||
|
state.sessionStore,
|
||||||
|
header.NewStore(state.encoder, httputil.AuthorizationTypePomerium),
|
||||||
|
queryparam.NewStore(state.encoder, "pomerium_session")}
|
||||||
|
|
||||||
|
authzConn, err := grpc.GetGRPCClientConn("authorize", &grpc.Options{
|
||||||
|
Addr: state.authorizeURL,
|
||||||
|
OverrideCertificateName: cfg.Options.OverrideCertificateName,
|
||||||
|
CA: cfg.Options.CA,
|
||||||
|
CAFile: cfg.Options.CAFile,
|
||||||
|
RequestTimeout: cfg.Options.GRPCClientTimeout,
|
||||||
|
ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin,
|
||||||
|
WithInsecure: cfg.Options.GRPCInsecure,
|
||||||
|
ServiceName: cfg.Options.Services,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
state.authzClient = envoy_service_auth_v2.NewAuthorizationClient(authzConn)
|
||||||
|
|
||||||
|
return state, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type atomicProxyState struct {
|
||||||
|
value atomic.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAtomicProxyState(state *proxyState) *atomicProxyState {
|
||||||
|
aps := new(atomicProxyState)
|
||||||
|
aps.Store(state)
|
||||||
|
return aps
|
||||||
|
}
|
||||||
|
|
||||||
|
func (aps *atomicProxyState) Load() *proxyState {
|
||||||
|
return aps.value.Load().(*proxyState)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (aps *atomicProxyState) Store(state *proxyState) {
|
||||||
|
aps.value.Store(state)
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue