diff --git a/internal/cmd/pomerium/pomerium.go b/internal/cmd/pomerium/pomerium.go index d74a7dcbb..607ea74ec 100644 --- a/internal/cmd/pomerium/pomerium.go +++ b/internal/cmd/pomerium/pomerium.go @@ -177,7 +177,7 @@ func setupProxy(src config.Source, cfg *config.Config, controlPlane *controlplan return nil } - svc, err := proxy.New(cfg.Options) + svc, err := proxy.New(cfg) if err != nil { return fmt.Errorf("error creating proxy service: %w", err) } diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go index 2a43e3cde..37074cc9e 100644 --- a/pkg/grpc/client.go +++ b/pkg/grpc/client.go @@ -11,8 +11,10 @@ import ( "net" "net/url" "strconv" + "sync" "time" + "github.com/google/go-cmp/cmp" "google.golang.org/grpc" "google.golang.org/grpc/balancer/roundrobin" "google.golang.org/grpc/credentials" @@ -144,3 +146,46 @@ func grpcTimeoutInterceptor(timeout time.Duration) grpc.UnaryClientInterceptor { return invoker(ctx, method, req, reply, cc, opts...) } } + +type grpcClientConnRecord struct { + conn *grpc.ClientConn + opts *Options +} + +var grpcClientConns = struct { + sync.Mutex + m map[string]grpcClientConnRecord +}{ + m: make(map[string]grpcClientConnRecord), +} + +// GetGRPCClientConn returns a gRPC client connection for the given name. If a connection for that name has already been +// established the existing connection will be returned. If any options change for that connection, the existing +// connection will be closed and a new one established. +func GetGRPCClientConn(name string, opts *Options) (*grpc.ClientConn, error) { + grpcClientConns.Lock() + defer grpcClientConns.Unlock() + + current, ok := grpcClientConns.m[name] + if ok { + if cmp.Equal(current.opts, opts) { + return current.conn, nil + } + + err := current.conn.Close() + if err != nil { + log.Error().Err(err).Msg("grpc: failed to close existing connection") + } + } + + cc, err := NewGRPCClientConn(opts) + if err != nil { + return nil, err + } + + grpcClientConns.m[name] = grpcClientConnRecord{ + conn: cc, + opts: opts, + } + return cc, nil +} diff --git a/pkg/grpc/client_test.go b/pkg/grpc/client_test.go index dedbed5de..4847cd986 100644 --- a/pkg/grpc/client_test.go +++ b/pkg/grpc/client_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" "google.golang.org/grpc" ) @@ -77,3 +78,39 @@ func TestNewGRPC(t *testing.T) { }) } } + +func TestGetGRPC(t *testing.T) { + cc1, err := GetGRPCClientConn("example", &Options{ + Addr: mustParseURL("https://localhost.example"), + }) + if !assert.NoError(t, err) { + return + } + + cc2, err := GetGRPCClientConn("example", &Options{ + Addr: mustParseURL("https://localhost.example"), + }) + if !assert.NoError(t, err) { + return + } + + assert.Equal(t, cc1, cc2, "GetGRPCClientConn should return the same connection when there are no changes") + + cc3, err := GetGRPCClientConn("example", &Options{ + Addr: mustParseURL("http://localhost.example"), + WithInsecure: true, + }) + if !assert.NoError(t, err) { + return + } + + assert.NotEqual(t, cc1, cc3, "GetGRPCClientConn should return a new connection when there are changes") +} + +func mustParseURL(rawurl string) *url.URL { + u, err := url.Parse(rawurl) + if err != nil { + panic(err) + } + return u +} diff --git a/proxy/forward_auth.go b/proxy/forward_auth.go index f66854082..597ef840c 100644 --- a/proxy/forward_auth.go +++ b/proxy/forward_auth.go @@ -18,7 +18,9 @@ import ( func (p *Proxy) registerFwdAuthHandlers() http.Handler { r := httputil.NewRouter() r.StrictSlash(true) - r.Use(sessions.RetrieveSession(p.sessionStore)) + r.Use(func(h http.Handler) http.Handler { + return sessions.RetrieveSession(p.state.Load().sessionStore)(h) + }) r.Use(p.jwtClaimMiddleware(true)) // NGNIX's forward-auth capabilities are split across two settings: @@ -96,6 +98,8 @@ func (p *Proxy) forwardedURIHeaderCallback(w http.ResponseWriter, r *http.Reques // provider. If the user is unauthorized, a `401` error is returned. func (p *Proxy) Verify(verifyOnly bool) http.Handler { return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { + state := p.state.Load() + var err error if status := r.FormValue("auth_status"); status == fmt.Sprint(http.StatusForbidden) { return httputil.NewError(http.StatusForbidden, errors.New(http.StatusText(http.StatusForbidden))) @@ -120,7 +124,7 @@ func (p *Proxy) Verify(verifyOnly bool) http.Handler { unAuthenticated := ar.statusCode == http.StatusUnauthorized if unAuthenticated { - p.sessionStore.ClearSession(w, r) + state.sessionStore.ClearSession(w, r) } _, err = sessions.FromContext(r.Context()) @@ -141,6 +145,8 @@ func (p *Proxy) Verify(verifyOnly bool) http.Handler { // forwardAuthRedirectToSignInWithURI redirects request to authenticate signin url, // with all necessary information extracted from given input uri. func (p *Proxy) forwardAuthRedirectToSignInWithURI(w http.ResponseWriter, r *http.Request, uri *url.URL) { + state := p.state.Load() + // Traefik set the uri in the header, we must set it in redirect uri if present. Otherwise, request like // https://example.com/foo will be redirected to https://example.com after authentication. if xfu := r.Header.Get(httputil.HeaderForwardedURI); xfu != "/" { @@ -148,13 +154,13 @@ func (p *Proxy) forwardAuthRedirectToSignInWithURI(w http.ResponseWriter, r *htt } // redirect to authenticate - authN := *p.authenticateSigninURL + authN := *state.authenticateSigninURL q := authN.Query() q.Set(urlutil.QueryCallbackURI, uri.String()) q.Set(urlutil.QueryRedirectURI, uri.String()) // final destination q.Set(urlutil.QueryForwardAuth, urlutil.StripPort(r.Host)) // add fwd auth to trusted audience authN.RawQuery = q.Encode() - httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &authN).String(), http.StatusFound) + httputil.Redirect(w, r, urlutil.NewSignedURL(state.sharedKey, &authN).String(), http.StatusFound) } func getURIStringFromRequest(r *http.Request) (*url.URL, error) { diff --git a/proxy/forward_auth_test.go b/proxy/forward_auth_test.go index 86e0ae974..1465bcc63 100644 --- a/proxy/forward_auth_test.go +++ b/proxy/forward_auth_test.go @@ -83,18 +83,19 @@ func TestProxy_ForwardAuth(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - p, err := New(tt.options) + p, err := New(&config.Config{Options: tt.options}) if err != nil { t.Fatal(err) } - p.authzClient = tt.authorizer - p.sessionStore = tt.sessionStore + p.OnConfigChange(&config.Config{Options: tt.options}) + state := p.state.Load() + state.authzClient = tt.authorizer + state.sessionStore = tt.sessionStore signer, err := jws.NewHS256Signer(nil, "mock") if err != nil { t.Fatal(err) } - p.encoder = signer - p.OnConfigChange(&config.Config{Options: tt.options}) + state.encoder = signer uri, err := url.Parse(tt.requestURI) if err != nil { t.Fatal(err) @@ -110,10 +111,10 @@ func TestProxy_ForwardAuth(t *testing.T) { uri.RawQuery = queryString.Encode() r := httptest.NewRequest(tt.method, uri.String(), nil) - state, _ := tt.sessionStore.LoadSession(r) + ss, _ := tt.sessionStore.LoadSession(r) ctx := r.Context() - ctx = sessions.NewContext(ctx, state, tt.ctxError) + ctx = sessions.NewContext(ctx, ss, tt.ctxError) r = r.WithContext(ctx) r.Header.Set("Accept", "application/json") if len(tt.headers) != 0 { diff --git a/proxy/handlers.go b/proxy/handlers.go index a116e858e..efdfe75d1 100644 --- a/proxy/handlers.go +++ b/proxy/handlers.go @@ -21,14 +21,17 @@ func (p *Proxy) registerDashboardHandlers(r *mux.Router) *mux.Router { h := r.PathPrefix(dashboardPath).Subrouter() h.Use(middleware.SetHeaders(httputil.HeadersContentSecurityPolicy)) // 1. Retrieve the user session and add it to the request context - h.Use(sessions.RetrieveSession(p.sessionStore)) + h.Use(func(h http.Handler) http.Handler { + return sessions.RetrieveSession(p.state.Load().sessionStore)(h) + }) // 2. AuthN - Verify the user is authenticated. Set email, group, & id headers h.Use(p.AuthenticateSession) // 3. Enforce CSRF protections for any non-idempotent http method h.Use(func(h http.Handler) http.Handler { opts := p.currentOptions.Load() + state := p.state.Load() return csrf.Protect( - p.cookieSecret, + state.cookieSecret, csrf.Secure(opts.CookieSecure), csrf.CookieName(fmt.Sprintf("%s_csrf", opts.CookieName)), csrf.ErrorHandler(httputil.HandlerFunc(httputil.CSRFFailureHandler)), @@ -42,7 +45,9 @@ func (p *Proxy) registerDashboardHandlers(r *mux.Router) *mux.Router { // callback used to set route-scoped session and redirect back to destination // only accept signed requests (hmac) from other trusted pomerium services c := r.PathPrefix(dashboardPath + "/callback").Subrouter() - c.Use(middleware.ValidateSignature(p.SharedKey)) + h.Use(func(h http.Handler) http.Handler { + return middleware.ValidateSignature(p.state.Load().sharedKey)(h) + }) c.Path("/"). Handler(httputil.HandlerFunc(p.ProgrammaticCallback)). @@ -71,28 +76,32 @@ func (p *Proxy) RobotsTxt(w http.ResponseWriter, _ *http.Request) { // of the authenticate service to revoke the remote session and clear // the local session state. func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) { + state := p.state.Load() + redirectURL := &url.URL{Scheme: "https", Host: r.Host, Path: "/"} if uri, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI)); err == nil && uri.String() != "" { redirectURL = uri } - signoutURL := *p.authenticateSignoutURL + signoutURL := *state.authenticateSignoutURL q := signoutURL.Query() q.Set(urlutil.QueryRedirectURI, redirectURL.String()) signoutURL.RawQuery = q.Encode() - p.sessionStore.ClearSession(w, r) - httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &signoutURL).String(), http.StatusFound) + state.sessionStore.ClearSession(w, r) + httputil.Redirect(w, r, urlutil.NewSignedURL(state.sharedKey, &signoutURL).String(), http.StatusFound) } // UserDashboard redirects to the authenticate dasbhoard. func (p *Proxy) UserDashboard(w http.ResponseWriter, r *http.Request) { + state := p.state.Load() + redirectURL := urlutil.GetAbsoluteURL(r).String() if ref := r.Header.Get(httputil.HeaderReferrer); ref != "" { redirectURL = ref } - url := p.authenticateDashboardURL.ResolveReference(&url.URL{ + url := state.authenticateDashboardURL.ResolveReference(&url.URL{ RawQuery: url.Values{ urlutil.QueryRedirectURI: {redirectURL}, }.Encode(), @@ -116,18 +125,20 @@ func (p *Proxy) Callback(w http.ResponseWriter, r *http.Request) error { // saveCallbackSession takes an encrypted per-route session token, and decrypts // it using the shared service key, then stores it the local session store. func (p *Proxy) saveCallbackSession(w http.ResponseWriter, r *http.Request, enctoken string) ([]byte, error) { + state := p.state.Load() + // 1. extract the base64 encoded and encrypted JWT from query params encryptedJWT, err := base64.URLEncoding.DecodeString(enctoken) if err != nil { return nil, fmt.Errorf("proxy: malfromed callback token: %w", err) } // 2. decrypt the JWT using the cipher using the _shared_ secret key - rawJWT, err := cryptutil.Decrypt(p.sharedCipher, encryptedJWT, nil) + rawJWT, err := cryptutil.Decrypt(state.sharedCipher, encryptedJWT, nil) if err != nil { return nil, fmt.Errorf("proxy: callback token decrypt error: %w", err) } // 3. Save the decrypted JWT to the session store directly as a string, without resigning - if err = p.sessionStore.SaveSession(w, r, rawJWT); err != nil { + if err = state.sessionStore.SaveSession(w, r, rawJWT); err != nil { return nil, fmt.Errorf("proxy: callback session save failure: %w", err) } return rawJWT, nil @@ -136,11 +147,13 @@ func (p *Proxy) saveCallbackSession(w http.ResponseWriter, r *http.Request, enct // ProgrammaticLogin returns a signed url that can be used to login // using the authenticate service. func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) error { + state := p.state.Load() + redirectURI, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI)) if err != nil { return httputil.NewError(http.StatusBadRequest, err) } - signinURL := *p.authenticateSigninURL + signinURL := *state.authenticateSigninURL callbackURI := urlutil.GetAbsoluteURL(r) callbackURI.Path = dashboardPath + "/callback/" q := signinURL.Query() @@ -148,7 +161,7 @@ func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) error q.Set(urlutil.QueryRedirectURI, redirectURI.String()) q.Set(urlutil.QueryIsProgrammatic, "true") signinURL.RawQuery = q.Encode() - response := urlutil.NewSignedURL(p.SharedKey, &signinURL).String() + response := urlutil.NewSignedURL(state.sharedKey, &signinURL).String() w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.WriteHeader(http.StatusOK) diff --git a/proxy/handlers_test.go b/proxy/handlers_test.go index d58df1bf4..1871c3a3c 100644 --- a/proxy/handlers_test.go +++ b/proxy/handlers_test.go @@ -47,7 +47,7 @@ func TestProxy_Signout(t *testing.T) { if err != nil { t.Fatal(err) } - proxy, err := New(opts) + proxy, err := New(&config.Config{Options: opts}) if err != nil { t.Fatal(err) } @@ -58,7 +58,7 @@ func TestProxy_Signout(t *testing.T) { t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusFound) } body := rr.Body.String() - want := (proxy.authenticateURL.String()) + want := proxy.state.Load().authenticateURL.String() if !strings.Contains(body, want) { t.Errorf("handler returned unexpected body: got %v want %s ", body, want) } @@ -79,7 +79,7 @@ func TestProxy_SignOut(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { opts := testOptions(t) - p, err := New(opts) + p, err := New(&config.Config{Options: opts}) if err != nil { t.Fatal(err) } @@ -221,13 +221,14 @@ func TestProxy_Callback(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - p, err := New(tt.options) + p, err := New(&config.Config{Options: tt.options}) if err != nil { t.Fatal(err) } - p.encoder = tt.cipher - p.sessionStore = tt.sessionStore p.OnConfigChange(&config.Config{Options: tt.options}) + state := p.state.Load() + state.encoder = tt.cipher + state.sessionStore = tt.sessionStore redirectURI := &url.URL{Scheme: tt.scheme, Host: tt.host, Path: tt.path} queryString := redirectURI.Query() for k, v := range tt.qp { @@ -297,7 +298,7 @@ func TestProxy_ProgrammaticLogin(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - p, err := New(tt.options) + p, err := New(&config.Config{Options: tt.options}) if err != nil { t.Fatal(err) } @@ -428,13 +429,14 @@ func TestProxy_ProgrammaticCallback(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - p, err := New(tt.options) + p, err := New(&config.Config{Options: tt.options}) if err != nil { t.Fatal(err) } - p.encoder = tt.cipher - p.sessionStore = tt.sessionStore p.OnConfigChange(&config.Config{Options: tt.options}) + state := p.state.Load() + state.encoder = tt.cipher + state.sessionStore = tt.sessionStore redirectURI, _ := url.Parse(tt.redirectURI) queryString := redirectURI.Query() for k, v := range tt.qp { diff --git a/proxy/middleware.go b/proxy/middleware.go index 892a4477c..7c3bc8545 100644 --- a/proxy/middleware.go +++ b/proxy/middleware.go @@ -40,17 +40,21 @@ func (p *Proxy) AuthenticateSession(next http.Handler) http.Handler { } func (p *Proxy) redirectToSignin(w http.ResponseWriter, r *http.Request) error { - signinURL := *p.authenticateSigninURL + state := p.state.Load() + + signinURL := *state.authenticateSigninURL q := signinURL.Query() q.Set(urlutil.QueryRedirectURI, urlutil.GetAbsoluteURL(r).String()) signinURL.RawQuery = q.Encode() log.FromRequest(r).Debug().Str("url", signinURL.String()).Msg("proxy: redirectToSignin") - httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &signinURL).String(), http.StatusFound) - p.sessionStore.ClearSession(w, r) + httputil.Redirect(w, r, urlutil.NewSignedURL(state.sharedKey, &signinURL).String(), http.StatusFound) + state.sessionStore.ClearSession(w, r) return nil } func (p *Proxy) isAuthorized(w http.ResponseWriter, r *http.Request) (*authorizeResponse, error) { + state := p.state.Load() + tm, err := ptypes.TimestampProto(time.Now()) if err != nil { return nil, httputil.NewError(http.StatusInternalServerError, fmt.Errorf("error creating protobuf timestamp from current time: %w", err)) @@ -72,7 +76,7 @@ func (p *Proxy) isAuthorized(w http.ResponseWriter, r *http.Request) (*authorize httpAttrs.Path += "?" + r.URL.RawQuery } - res, err := p.authzClient.Check(r.Context(), &envoy_service_auth_v2.CheckRequest{ + res, err := state.authzClient.Check(r.Context(), &envoy_service_auth_v2.CheckRequest{ Attributes: &envoy_service_auth_v2.AttributeContext{ Request: &envoy_service_auth_v2.AttributeContext_Request{ Time: tm, @@ -118,12 +122,12 @@ func SetResponseHeaders(headers map[string]string) func(next http.Handler) http. // // if returnJWTInfo is set to true, it will also return JWT claim information in the response func (p *Proxy) jwtClaimMiddleware(returnJWTInfo bool) mux.MiddlewareFunc { - return func(next http.Handler) http.Handler { - return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { defer next.ServeHTTP(w, r) + state := p.state.Load() + jwt, err := sessions.FromContext(r.Context()) if err != nil { log.Error().Err(err).Msg("proxy: could not locate session from context") @@ -147,7 +151,7 @@ func (p *Proxy) jwtClaimMiddleware(returnJWTInfo bool) mux.MiddlewareFunc { } // set headers for any claims specified by config - for _, claimName := range p.jwtClaimHeaders { + for _, claimName := range state.jwtClaimHeaders { if _, ok := formattedJWTClaims[claimName]; ok { headerName := fmt.Sprintf("x-pomerium-claim-%s", claimName) @@ -165,10 +169,12 @@ func (p *Proxy) jwtClaimMiddleware(returnJWTInfo bool) mux.MiddlewareFunc { // getFormatJWTClaims reformats jwtClaims into something resembling map[string]string func (p *Proxy) getFormatedJWTClaims(jwt []byte) (map[string]string, error) { + state := p.state.Load() + formattedJWTClaims := make(map[string]string) var jwtClaims map[string]interface{} - if err := p.encoder.Unmarshal(jwt, &jwtClaims); err != nil { + if err := state.encoder.Unmarshal(jwt, &jwtClaims); err != nil { return formattedJWTClaims, err } diff --git a/proxy/middleware_test.go b/proxy/middleware_test.go index afbda08b4..623236e5b 100644 --- a/proxy/middleware_test.go +++ b/proxy/middleware_test.go @@ -57,13 +57,15 @@ func TestProxy_AuthenticateSession(t *testing.T) { } a := Proxy{ - SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=", - cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="), - authenticateURL: uriParseHelper("https://authenticate.corp.example"), - authenticateSigninURL: uriParseHelper("https://authenticate.corp.example/sign_in"), - authenticateRefreshURL: uriParseHelper(rURL), - sessionStore: tt.session, - encoder: tt.encoder, + state: newAtomicProxyState(&proxyState{ + sharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=", + cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="), + authenticateURL: uriParseHelper("https://authenticate.corp.example"), + authenticateSigninURL: uriParseHelper("https://authenticate.corp.example/sign_in"), + authenticateRefreshURL: uriParseHelper(rURL), + sessionStore: tt.session, + encoder: tt.encoder, + }), } r := httptest.NewRequest(http.MethodGet, "/", nil) state, _ := tt.session.LoadSession(r) @@ -95,10 +97,12 @@ func Test_jwtClaimMiddleware(t *testing.T) { } a := Proxy{ - SharedKey: sharedKey, - cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="), - encoder: encoder, - jwtClaimHeaders: claimHeaders, + state: newAtomicProxyState(&proxyState{ + sharedKey: sharedKey, + cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="), + encoder: encoder, + jwtClaimHeaders: claimHeaders, + }), } handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/proxy/proxy.go b/proxy/proxy.go index add53e860..fa164f404 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -5,32 +5,20 @@ package proxy import ( - "crypto/cipher" - "encoding/base64" "fmt" "html/template" "net/http" - "net/url" "sync/atomic" - "time" - envoy_service_auth_v2 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v2" "github.com/gorilla/mux" "github.com/pomerium/pomerium/config" - "github.com/pomerium/pomerium/internal/encoding" - "github.com/pomerium/pomerium/internal/encoding/jws" "github.com/pomerium/pomerium/internal/frontend" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/log" - "github.com/pomerium/pomerium/internal/sessions" - "github.com/pomerium/pomerium/internal/sessions/cookie" - "github.com/pomerium/pomerium/internal/sessions/header" - "github.com/pomerium/pomerium/internal/sessions/queryparam" "github.com/pomerium/pomerium/internal/telemetry/metrics" "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/pkg/cryptutil" - "github.com/pomerium/pomerium/pkg/grpc" ) const ( @@ -64,102 +52,29 @@ func ValidateOptions(o *config.Options) error { // Proxy stores all the information associated with proxying a request. type Proxy struct { - // SharedKey used to mutually authenticate service communication - SharedKey string - sharedCipher cipher.AEAD - - authorizeURL *url.URL - authenticateURL *url.URL - authenticateDashboardURL *url.URL - authenticateSigninURL *url.URL - authenticateSignoutURL *url.URL - authenticateRefreshURL *url.URL - - encoder encoding.Unmarshaler - cookieSecret []byte - refreshCooldown time.Duration - sessionStore sessions.SessionStore - sessionLoaders []sessions.SessionLoader - templates *template.Template - jwtClaimHeaders []string - authzClient envoy_service_auth_v2.AuthorizationClient - + templates *template.Template + state *atomicProxyState currentOptions *config.AtomicOptions currentRouter atomic.Value } // New takes a Proxy service from options and a validation function. // Function returns an error if options fail to validate. -func New(opts *config.Options) (*Proxy, error) { - if err := ValidateOptions(opts); err != nil { - return nil, err - } - - sharedCipher, _ := cryptutil.NewAEADCipherFromBase64(opts.SharedKey) - decodedCookieSecret, _ := base64.StdEncoding.DecodeString(opts.CookieSecret) - - // used to load and verify JWT tokens signed by the authenticate service - encoder, err := jws.NewHS256Signer([]byte(opts.SharedKey), opts.GetAuthenticateURL().Host) +func New(cfg *config.Config) (*Proxy, error) { + state, err := newProxyStateFromConfig(cfg) if err != nil { return nil, err } p := &Proxy{ - SharedKey: opts.SharedKey, - sharedCipher: sharedCipher, - encoder: encoder, - - cookieSecret: decodedCookieSecret, - refreshCooldown: opts.RefreshCooldown, - templates: template.Must(frontend.NewTemplates()), - jwtClaimHeaders: opts.JWTClaimsHeaders, - currentOptions: config.NewAtomicOptions(), + templates: template.Must(frontend.NewTemplates()), + state: newAtomicProxyState(state), + currentOptions: config.NewAtomicOptions(), } p.currentRouter.Store(httputil.NewRouter()) - // errors checked in ValidateOptions - p.authorizeURL, _ = urlutil.DeepCopy(opts.AuthorizeURL) - p.authenticateURL, _ = urlutil.DeepCopy(opts.AuthenticateURL) - p.authenticateDashboardURL = p.authenticateURL.ResolveReference(&url.URL{Path: dashboardPath}) - p.authenticateSigninURL = p.authenticateURL.ResolveReference(&url.URL{Path: signinURL}) - p.authenticateSignoutURL = p.authenticateURL.ResolveReference(&url.URL{Path: signoutURL}) - p.authenticateRefreshURL = p.authenticateURL.ResolveReference(&url.URL{Path: refreshURL}) - - cookieStore, err := cookie.NewStore(func() cookie.Options { - opts := p.currentOptions.Load() - return cookie.Options{ - Name: opts.CookieName, - Domain: opts.CookieDomain, - Secure: opts.CookieSecure, - HTTPOnly: opts.CookieHTTPOnly, - Expire: opts.CookieExpire, - } - }, encoder) - if err != nil { - return nil, err - } - p.sessionStore = cookieStore - p.sessionLoaders = []sessions.SessionLoader{ - cookieStore, - header.NewStore(encoder, httputil.AuthorizationTypePomerium), - queryparam.NewStore(encoder, "pomerium_session")} - - authzConn, err := grpc.NewGRPCClientConn(&grpc.Options{ - Addr: p.authorizeURL, - OverrideCertificateName: opts.OverrideCertificateName, - CA: opts.CA, - CAFile: opts.CAFile, - RequestTimeout: opts.GRPCClientTimeout, - ClientDNSRoundRobin: opts.GRPCClientDNSRoundRobin, - WithInsecure: opts.GRPCInsecure, - ServiceName: opts.Services, - }) - if err != nil { - return nil, err - } - p.authzClient = envoy_service_auth_v2.NewAuthorizationClient(authzConn) metrics.AddPolicyCountCallback("pomerium-proxy", func() int64 { - return int64(len(opts.Policies)) + return int64(len(p.currentOptions.Load().Policies)) }) return p, nil @@ -174,6 +89,11 @@ func (p *Proxy) OnConfigChange(cfg *config.Config) { log.Info().Str("checksum", fmt.Sprintf("%x", cfg.Options.Checksum())).Msg("proxy: updating options") p.currentOptions.Store(cfg.Options) p.setHandlers(cfg.Options) + if state, err := newProxyStateFromConfig(cfg); err != nil { + log.Error().Err(err).Msg("proxy: failed to update proxy state from configuration settings") + } else { + p.state.Store(state) + } } func (p *Proxy) setHandlers(opts *config.Options) { diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 57109a963..898d92ba3 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -103,7 +103,7 @@ func TestNew(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := New(tt.opts) + got, err := New(&config.Config{Options: tt.opts}) if (err != nil) != tt.wantErr { t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) return @@ -192,7 +192,7 @@ func Test_UpdateOptions(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - p, err := New(tt.originalOptions) + p, err := New(&config.Config{Options: tt.originalOptions}) if err != nil { t.Fatal(err) } diff --git a/proxy/state.go b/proxy/state.go new file mode 100644 index 000000000..745831dba --- /dev/null +++ b/proxy/state.go @@ -0,0 +1,124 @@ +package proxy + +import ( + "crypto/cipher" + "encoding/base64" + "net/url" + "sync/atomic" + "time" + + envoy_service_auth_v2 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v2" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/encoding" + "github.com/pomerium/pomerium/internal/encoding/jws" + "github.com/pomerium/pomerium/internal/httputil" + "github.com/pomerium/pomerium/internal/sessions" + "github.com/pomerium/pomerium/internal/sessions/cookie" + "github.com/pomerium/pomerium/internal/sessions/header" + "github.com/pomerium/pomerium/internal/sessions/queryparam" + "github.com/pomerium/pomerium/internal/urlutil" + "github.com/pomerium/pomerium/pkg/cryptutil" + "github.com/pomerium/pomerium/pkg/grpc" +) + +type proxyState struct { + sharedKey string + sharedCipher cipher.AEAD + + authorizeURL *url.URL + authenticateURL *url.URL + authenticateDashboardURL *url.URL + authenticateSigninURL *url.URL + authenticateSignoutURL *url.URL + authenticateRefreshURL *url.URL + + encoder encoding.MarshalUnmarshaler + cookieSecret []byte + refreshCooldown time.Duration + sessionStore sessions.SessionStore + sessionLoaders []sessions.SessionLoader + jwtClaimHeaders []string + authzClient envoy_service_auth_v2.AuthorizationClient +} + +func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) { + err := ValidateOptions(cfg.Options) + if err != nil { + return nil, err + } + + state := new(proxyState) + state.sharedKey = cfg.Options.SharedKey + state.sharedCipher, _ = cryptutil.NewAEADCipherFromBase64(cfg.Options.SharedKey) + state.cookieSecret, _ = base64.StdEncoding.DecodeString(cfg.Options.CookieSecret) + + // used to load and verify JWT tokens signed by the authenticate service + state.encoder, err = jws.NewHS256Signer([]byte(cfg.Options.SharedKey), cfg.Options.GetAuthenticateURL().Host) + if err != nil { + return nil, err + } + + state.refreshCooldown = cfg.Options.RefreshCooldown + state.jwtClaimHeaders = cfg.Options.JWTClaimsHeaders + + // errors checked in ValidateOptions + state.authorizeURL, _ = urlutil.DeepCopy(cfg.Options.AuthorizeURL) + state.authenticateURL, _ = urlutil.DeepCopy(cfg.Options.AuthenticateURL) + state.authenticateDashboardURL = state.authenticateURL.ResolveReference(&url.URL{Path: dashboardPath}) + state.authenticateSigninURL = state.authenticateURL.ResolveReference(&url.URL{Path: signinURL}) + state.authenticateSignoutURL = state.authenticateURL.ResolveReference(&url.URL{Path: signoutURL}) + state.authenticateRefreshURL = state.authenticateURL.ResolveReference(&url.URL{Path: refreshURL}) + + state.sessionStore, err = cookie.NewStore(func() cookie.Options { + return cookie.Options{ + Name: cfg.Options.CookieName, + Domain: cfg.Options.CookieDomain, + Secure: cfg.Options.CookieSecure, + HTTPOnly: cfg.Options.CookieHTTPOnly, + Expire: cfg.Options.CookieExpire, + } + }, state.encoder) + if err != nil { + return nil, err + } + state.sessionLoaders = []sessions.SessionLoader{ + state.sessionStore, + header.NewStore(state.encoder, httputil.AuthorizationTypePomerium), + queryparam.NewStore(state.encoder, "pomerium_session")} + + authzConn, err := grpc.GetGRPCClientConn("authorize", &grpc.Options{ + Addr: state.authorizeURL, + OverrideCertificateName: cfg.Options.OverrideCertificateName, + CA: cfg.Options.CA, + CAFile: cfg.Options.CAFile, + RequestTimeout: cfg.Options.GRPCClientTimeout, + ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin, + WithInsecure: cfg.Options.GRPCInsecure, + ServiceName: cfg.Options.Services, + }) + if err != nil { + return nil, err + } + state.authzClient = envoy_service_auth_v2.NewAuthorizationClient(authzConn) + + return state, nil +} + +type atomicProxyState struct { + value atomic.Value +} + +func newAtomicProxyState(state *proxyState) *atomicProxyState { + aps := new(atomicProxyState) + aps.Store(state) + return aps +} + +func (aps *atomicProxyState) Load() *proxyState { + return aps.value.Load().(*proxyState) +} + +func (aps *atomicProxyState) Store(state *proxyState) { + aps.value.Store(state) +}