diff --git a/authenticate/authenticate.go b/authenticate/authenticate.go index 7881e3d20..5e6cf714f 100644 --- a/authenticate/authenticate.go +++ b/authenticate/authenticate.go @@ -14,8 +14,6 @@ import ( "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/pkg/cryptutil" - "github.com/pomerium/pomerium/pkg/grpc" - "github.com/pomerium/pomerium/pkg/grpc/databroker" ) // ValidateOptions checks that configuration are complete and valid. @@ -50,9 +48,6 @@ func ValidateOptions(o *config.Options) error { // Authenticate contains data required to run the authenticate service. type Authenticate struct { - // dataBrokerClient is used to retrieve sessions - dataBrokerClient databroker.DataBrokerServiceClient - templates *template.Template options *config.AtomicOptions @@ -62,39 +57,11 @@ type Authenticate struct { // New validates and creates a new authenticate service from a set of Options. func New(cfg *config.Config) (*Authenticate, error) { - if err := ValidateOptions(cfg.Options); err != nil { - return nil, err - } - - dataBrokerConn, err := grpc.NewGRPCClientConn( - &grpc.Options{ - Addr: cfg.Options.DataBrokerURL, - OverrideCertificateName: cfg.Options.OverrideCertificateName, - CA: cfg.Options.CA, - CAFile: cfg.Options.CAFile, - RequestTimeout: cfg.Options.GRPCClientTimeout, - ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin, - WithInsecure: cfg.Options.GRPCInsecure, - ServiceName: cfg.Options.Services, - }) - if err != nil { - return nil, err - } - - dataBrokerClient := databroker.NewDataBrokerServiceClient(dataBrokerConn) - a := &Authenticate{ - // grpc client for cache - dataBrokerClient: dataBrokerClient, - templates: template.Must(frontend.NewTemplates()), - options: config.NewAtomicOptions(), - provider: identity.NewAtomicAuthenticator(), - state: newAtomicAuthenticateState(newAuthenticateState()), - } - - err = a.updateProvider(cfg) - if err != nil { - return nil, err + templates: template.Must(frontend.NewTemplates()), + options: config.NewAtomicOptions(), + provider: identity.NewAtomicAuthenticator(), + state: newAtomicAuthenticateState(newAuthenticateState()), } state, err := newAuthenticateStateFromConfig(cfg) @@ -103,6 +70,11 @@ func New(cfg *config.Config) (*Authenticate, error) { } a.state.Store(state) + err = a.updateProvider(cfg) + if err != nil { + return nil, err + } + return a, nil } @@ -114,14 +86,14 @@ func (a *Authenticate) OnConfigChange(cfg *config.Config) { log.Info().Str("checksum", fmt.Sprintf("%x", cfg.Options.Checksum())).Msg("authenticate: updating options") a.options.Store(cfg.Options) - if err := a.updateProvider(cfg); err != nil { - log.Error().Err(err).Msg("authenticate: failed to update identity provider") - } if state, err := newAuthenticateStateFromConfig(cfg); err != nil { log.Error().Err(err).Msg("authenticate: failed to update state") } else { a.state.Store(state) } + if err := a.updateProvider(cfg); err != nil { + log.Error().Err(err).Msg("authenticate: failed to update identity provider") + } } func (a *Authenticate) updateProvider(cfg *config.Config) error { diff --git a/authenticate/handlers.go b/authenticate/handlers.go index 31714e5ac..1cf4694d7 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -144,14 +144,17 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler { return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { ctx, span := trace.StartSpan(r.Context(), "authenticate.VerifySession") defer span.End() + + state := a.state.Load() + sessionState, err := a.getSessionFromCtx(ctx) if err != nil { log.FromRequest(r).Info().Err(err).Msg("authenticate: session load error") return a.reauthenticateOrFail(w, r, err) } - if a.dataBrokerClient != nil { - _, err = session.Get(ctx, a.dataBrokerClient, sessionState.ID) + if state.dataBrokerClient != nil { + _, err = session.Get(ctx, state.dataBrokerClient, sessionState.ID) if err != nil { log.FromRequest(r).Info().Err(err).Str("id", sessionState.ID).Msg("authenticate: session not found in databroker") return a.reauthenticateOrFail(w, r, err) @@ -232,7 +235,7 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error { if r.FormValue(urlutil.QueryIsProgrammatic) == "true" { newSession.Programmatic = true - pbSession, err := session.Get(ctx, a.dataBrokerClient, s.ID) + pbSession, err := session.Get(ctx, state.dataBrokerClient, s.ID) if err != nil { return httputil.NewError(http.StatusBadRequest, err) } @@ -278,7 +281,7 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error { sessionState, err := a.getSessionFromCtx(ctx) if err == nil { - if s, _ := session.Get(ctx, a.dataBrokerClient, sessionState.ID); s != nil && s.OauthToken != nil { + if s, _ := session.Get(ctx, state.dataBrokerClient, sessionState.ID); s != nil && s.OauthToken != nil { if err := a.provider.Load().Revoke(ctx, manager.FromOAuthToken(s.OauthToken)); err != nil { log.Warn().Err(err).Msg("failed to revoke access token") } @@ -474,11 +477,8 @@ func (a *Authenticate) getSessionFromCtx(ctx context.Context) (*sessions.State, } func (a *Authenticate) deleteSession(ctx context.Context, sessionID string) error { - if a.dataBrokerClient == nil { - return nil - } - err := session.Delete(ctx, a.dataBrokerClient, sessionID) - return err + state := a.state.Load() + return session.Delete(ctx, state.dataBrokerClient, sessionID) } func (a *Authenticate) isAdmin(user string) bool { @@ -489,24 +489,26 @@ func (a *Authenticate) isAdmin(user string) bool { // Dashboard renders the /.pomerium/ user dashboard. func (a *Authenticate) Dashboard(w http.ResponseWriter, r *http.Request) error { + state := a.state.Load() + s, err := a.getSessionFromCtx(r.Context()) if err != nil { s.ID = uuid.New().String() } - pbSession, err := session.Get(r.Context(), a.dataBrokerClient, s.ID) + pbSession, err := session.Get(r.Context(), state.dataBrokerClient, s.ID) if err != nil { pbSession = &session.Session{ Id: s.ID, } } - pbUser, err := user.Get(r.Context(), a.dataBrokerClient, pbSession.GetUserId()) + pbUser, err := user.Get(r.Context(), state.dataBrokerClient, pbSession.GetUserId()) if err != nil { pbUser = &user.User{ Id: pbSession.GetUserId(), } } - pbDirectoryUser, err := directory.GetUser(r.Context(), a.dataBrokerClient, pbSession.GetUserId()) + pbDirectoryUser, err := directory.GetUser(r.Context(), state.dataBrokerClient, pbSession.GetUserId()) if err != nil { pbDirectoryUser = &directory.User{ Id: pbSession.GetUserId(), @@ -514,7 +516,7 @@ func (a *Authenticate) Dashboard(w http.ResponseWriter, r *http.Request) error { } var groups []*directory.Group for _, groupID := range pbDirectoryUser.GetGroupIds() { - pbDirectoryGroup, err := directory.GetGroup(r.Context(), a.dataBrokerClient, groupID) + pbDirectoryGroup, err := directory.GetGroup(r.Context(), state.dataBrokerClient, groupID) if err != nil { pbDirectoryGroup = &directory.Group{ Id: groupID, @@ -556,10 +558,7 @@ func (a *Authenticate) Dashboard(w http.ResponseWriter, r *http.Request) error { } func (a *Authenticate) saveSessionToDataBroker(ctx context.Context, sessionState *sessions.State, accessToken *oauth2.Token) error { - if a.dataBrokerClient == nil { - return nil - } - + state := a.state.Load() options := a.options.Load() sessionExpiry, _ := ptypes.TimestampProto(time.Now().Add(options.CookieExpire)) @@ -580,7 +579,7 @@ func (a *Authenticate) saveSessionToDataBroker(ctx context.Context, sessionState } // if no user exists yet, create a new one - currentUser, _ := user.Get(ctx, a.dataBrokerClient, s.GetUserId()) + currentUser, _ := user.Get(ctx, state.dataBrokerClient, s.GetUserId()) if currentUser == nil { mu := manager.User{ User: &user.User{ @@ -591,13 +590,13 @@ func (a *Authenticate) saveSessionToDataBroker(ctx context.Context, sessionState if err != nil { return fmt.Errorf("authenticate: error retrieving user info: %w", err) } - _, err = user.Set(ctx, a.dataBrokerClient, mu.User) + _, err = user.Set(ctx, state.dataBrokerClient, mu.User) if err != nil { return fmt.Errorf("authenticate: error saving user: %w", err) } } - res, err := session.Set(ctx, a.dataBrokerClient, s) + res, err := session.Set(ctx, state.dataBrokerClient, s) if err != nil { return fmt.Errorf("authenticate: error saving session: %w", err) } diff --git a/authenticate/handlers_test.go b/authenticate/handlers_test.go index 629cf59a0..d08a057aa 100644 --- a/authenticate/handlers_test.go +++ b/authenticate/handlers_test.go @@ -152,26 +152,27 @@ func TestAuthenticate_SignIn(t *testing.T) { redirectURL: uriParseHelper("https://some.example"), sharedEncoder: tt.encoder, encryptedEncoder: tt.encoder, - }), - dataBrokerClient: mockDataBrokerServiceClient{ - get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) { - data, err := ptypes.MarshalAny(&session.Session{ - Id: "SESSION_ID", - }) - if err != nil { - return nil, err - } + dataBrokerClient: mockDataBrokerServiceClient{ + get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) { + data, err := ptypes.MarshalAny(&session.Session{ + Id: "SESSION_ID", + }) + if err != nil { + return nil, err + } - return &databroker.GetResponse{ - Record: &databroker.Record{ - Version: "0001", - Type: data.GetTypeUrl(), - Id: "SESSION_ID", - Data: data, - }, - }, nil + return &databroker.GetResponse{ + Record: &databroker.Record{ + Version: "0001", + Type: data.GetTypeUrl(), + Id: "SESSION_ID", + Data: data, + }, + }, nil + }, }, - }, + }), + options: config.NewAtomicOptions(), provider: identity.NewAtomicAuthenticator(), } @@ -237,32 +238,32 @@ func TestAuthenticate_SignOut(t *testing.T) { sessionStore: tt.sessionStore, encryptedEncoder: mock.Encoder{}, sharedEncoder: mock.Encoder{}, + dataBrokerClient: mockDataBrokerServiceClient{ + delete: func(ctx context.Context, in *databroker.DeleteRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) { + return nil, nil + }, + get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) { + data, err := ptypes.MarshalAny(&session.Session{ + Id: "SESSION_ID", + }) + if err != nil { + return nil, err + } + + return &databroker.GetResponse{ + Record: &databroker.Record{ + Version: "0001", + Type: data.GetTypeUrl(), + Id: "SESSION_ID", + Data: data, + }, + }, nil + }, + }, }), templates: template.Must(frontend.NewTemplates()), - dataBrokerClient: mockDataBrokerServiceClient{ - delete: func(ctx context.Context, in *databroker.DeleteRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) { - return nil, nil - }, - get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) { - data, err := ptypes.MarshalAny(&session.Session{ - Id: "SESSION_ID", - }) - if err != nil { - return nil, err - } - - return &databroker.GetResponse{ - Record: &databroker.Record{ - Version: "0001", - Type: data.GetTypeUrl(), - Id: "SESSION_ID", - Data: data, - }, - }, nil - }, - }, - options: config.NewAtomicOptions(), - provider: identity.NewAtomicAuthenticator(), + options: config.NewAtomicOptions(), + provider: identity.NewAtomicAuthenticator(), } a.provider.Store(tt.provider) u, _ := url.Parse("/sign_out") @@ -347,6 +348,14 @@ func TestAuthenticate_OAuthCallback(t *testing.T) { authURL, _ := url.Parse(tt.authenticateURL) a := &Authenticate{ state: newAtomicAuthenticateState(&authenticateState{ + dataBrokerClient: mockDataBrokerServiceClient{ + get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) { + return nil, fmt.Errorf("not implemented") + }, + set: func(ctx context.Context, in *databroker.SetRequest, opts ...grpc.CallOption) (*databroker.SetResponse, error) { + return &databroker.SetResponse{Record: &databroker.Record{Data: in.Data}}, nil + }, + }, redirectURL: authURL, sessionStore: tt.session, cookieCipher: aead, @@ -477,26 +486,26 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) { cookieCipher: aead, encryptedEncoder: signer, sharedEncoder: signer, - }), - dataBrokerClient: mockDataBrokerServiceClient{ - get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) { - data, err := ptypes.MarshalAny(&session.Session{ - Id: "SESSION_ID", - }) - if err != nil { - return nil, err - } + dataBrokerClient: mockDataBrokerServiceClient{ + get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) { + data, err := ptypes.MarshalAny(&session.Session{ + Id: "SESSION_ID", + }) + if err != nil { + return nil, err + } - return &databroker.GetResponse{ - Record: &databroker.Record{ - Version: "0001", - Type: data.GetTypeUrl(), - Id: "SESSION_ID", - Data: data, - }, - }, nil + return &databroker.GetResponse{ + Record: &databroker.Record{ + Version: "0001", + Type: data.GetTypeUrl(), + Id: "SESSION_ID", + Data: data, + }, + }, nil + }, }, - }, + }), options: config.NewAtomicOptions(), provider: identity.NewAtomicAuthenticator(), } @@ -593,29 +602,29 @@ func TestAuthenticate_Dashboard(t *testing.T) { sessionStore: tt.sessionStore, encryptedEncoder: signer, sharedEncoder: signer, + dataBrokerClient: mockDataBrokerServiceClient{ + get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) { + data, err := ptypes.MarshalAny(&session.Session{ + Id: "SESSION_ID", + UserId: "USER_ID", + IdToken: &session.IDToken{IssuedAt: pbNow}, + }) + if err != nil { + return nil, err + } + + return &databroker.GetResponse{ + Record: &databroker.Record{ + Version: "0001", + Type: data.GetTypeUrl(), + Id: "SESSION_ID", + Data: data, + }, + }, nil + }, + }, }), templates: template.Must(frontend.NewTemplates()), - dataBrokerClient: mockDataBrokerServiceClient{ - get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) { - data, err := ptypes.MarshalAny(&session.Session{ - Id: "SESSION_ID", - UserId: "USER_ID", - IdToken: &session.IDToken{IssuedAt: pbNow}, - }) - if err != nil { - return nil, err - } - - return &databroker.GetResponse{ - Record: &databroker.Record{ - Version: "0001", - Type: data.GetTypeUrl(), - Id: "SESSION_ID", - Data: data, - }, - }, nil - }, - }, } u, _ := url.Parse("/") r := httptest.NewRequest(tt.method, u.String(), nil) @@ -646,6 +655,7 @@ type mockDataBrokerServiceClient struct { delete func(ctx context.Context, in *databroker.DeleteRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) get func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) + set func(ctx context.Context, in *databroker.SetRequest, opts ...grpc.CallOption) (*databroker.SetResponse, error) } func (m mockDataBrokerServiceClient) Delete(ctx context.Context, in *databroker.DeleteRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) { @@ -655,3 +665,7 @@ func (m mockDataBrokerServiceClient) Delete(ctx context.Context, in *databroker. func (m mockDataBrokerServiceClient) Get(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) { return m.get(ctx, in, opts...) } + +func (m mockDataBrokerServiceClient) Set(ctx context.Context, in *databroker.SetRequest, opts ...grpc.CallOption) (*databroker.SetResponse, error) { + return m.set(ctx, in, opts...) +} diff --git a/authenticate/state.go b/authenticate/state.go index ff2633979..00eae4d43 100644 --- a/authenticate/state.go +++ b/authenticate/state.go @@ -20,6 +20,8 @@ import ( "github.com/pomerium/pomerium/internal/sessions/queryparam" "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/pkg/cryptutil" + "github.com/pomerium/pomerium/pkg/grpc" + "github.com/pomerium/pomerium/pkg/grpc/databroker" ) type authenticateState struct { @@ -42,6 +44,8 @@ type authenticateState struct { sessionLoaders []sessions.SessionLoader jwk *jose.JSONWebKeySet + + dataBrokerClient databroker.DataBrokerServiceClient } func newAuthenticateState() *authenticateState { @@ -52,6 +56,11 @@ func newAuthenticateState() *authenticateState { } func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, error) { + err := ValidateOptions(cfg.Options) + if err != nil { + return nil, err + } + state := &authenticateState{} state.redirectURL, _ = urlutil.DeepCopy(cfg.Options.AuthenticateURL) @@ -63,7 +72,6 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err } // shared state encoder setup - var err error state.sharedEncoder, err = jws.NewHS256Signer([]byte(cfg.Options.SharedKey), cfg.Options.GetAuthenticateURL().Host) if err != nil { return nil, err @@ -106,6 +114,22 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err state.jwk.Keys = append(state.jwk.Keys, *jwk) } + dataBrokerConn, err := grpc.GetGRPCClientConn("databroker", &grpc.Options{ + Addr: cfg.Options.DataBrokerURL, + OverrideCertificateName: cfg.Options.OverrideCertificateName, + CA: cfg.Options.CA, + CAFile: cfg.Options.CAFile, + RequestTimeout: cfg.Options.GRPCClientTimeout, + ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin, + WithInsecure: cfg.Options.GRPCInsecure, + ServiceName: cfg.Options.Services, + }) + if err != nil { + return nil, err + } + + state.dataBrokerClient = databroker.NewDataBrokerServiceClient(dataBrokerConn) + return state, nil }