authenticate: move databroker connection to state (#1292)

* authenticate: move databroker connection to state

* re-use err

* just return

* remove nil checks
This commit is contained in:
Caleb Doxsey 2020-08-18 09:33:43 -06:00 committed by GitHub
parent a1378c81f8
commit 882b6b54ee
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 151 additions and 142 deletions

View file

@ -14,8 +14,6 @@ import (
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
) )
// ValidateOptions checks that configuration are complete and valid. // ValidateOptions checks that configuration are complete and valid.
@ -50,9 +48,6 @@ func ValidateOptions(o *config.Options) error {
// Authenticate contains data required to run the authenticate service. // Authenticate contains data required to run the authenticate service.
type Authenticate struct { type Authenticate struct {
// dataBrokerClient is used to retrieve sessions
dataBrokerClient databroker.DataBrokerServiceClient
templates *template.Template templates *template.Template
options *config.AtomicOptions options *config.AtomicOptions
@ -62,47 +57,24 @@ type Authenticate struct {
// New validates and creates a new authenticate service from a set of Options. // New validates and creates a new authenticate service from a set of Options.
func New(cfg *config.Config) (*Authenticate, error) { func New(cfg *config.Config) (*Authenticate, error) {
if err := ValidateOptions(cfg.Options); err != nil {
return nil, err
}
dataBrokerConn, err := grpc.NewGRPCClientConn(
&grpc.Options{
Addr: cfg.Options.DataBrokerURL,
OverrideCertificateName: cfg.Options.OverrideCertificateName,
CA: cfg.Options.CA,
CAFile: cfg.Options.CAFile,
RequestTimeout: cfg.Options.GRPCClientTimeout,
ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin,
WithInsecure: cfg.Options.GRPCInsecure,
ServiceName: cfg.Options.Services,
})
if err != nil {
return nil, err
}
dataBrokerClient := databroker.NewDataBrokerServiceClient(dataBrokerConn)
a := &Authenticate{ a := &Authenticate{
// grpc client for cache
dataBrokerClient: dataBrokerClient,
templates: template.Must(frontend.NewTemplates()), templates: template.Must(frontend.NewTemplates()),
options: config.NewAtomicOptions(), options: config.NewAtomicOptions(),
provider: identity.NewAtomicAuthenticator(), provider: identity.NewAtomicAuthenticator(),
state: newAtomicAuthenticateState(newAuthenticateState()), state: newAtomicAuthenticateState(newAuthenticateState()),
} }
err = a.updateProvider(cfg)
if err != nil {
return nil, err
}
state, err := newAuthenticateStateFromConfig(cfg) state, err := newAuthenticateStateFromConfig(cfg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
a.state.Store(state) a.state.Store(state)
err = a.updateProvider(cfg)
if err != nil {
return nil, err
}
return a, nil return a, nil
} }
@ -114,14 +86,14 @@ func (a *Authenticate) OnConfigChange(cfg *config.Config) {
log.Info().Str("checksum", fmt.Sprintf("%x", cfg.Options.Checksum())).Msg("authenticate: updating options") log.Info().Str("checksum", fmt.Sprintf("%x", cfg.Options.Checksum())).Msg("authenticate: updating options")
a.options.Store(cfg.Options) a.options.Store(cfg.Options)
if err := a.updateProvider(cfg); err != nil {
log.Error().Err(err).Msg("authenticate: failed to update identity provider")
}
if state, err := newAuthenticateStateFromConfig(cfg); err != nil { if state, err := newAuthenticateStateFromConfig(cfg); err != nil {
log.Error().Err(err).Msg("authenticate: failed to update state") log.Error().Err(err).Msg("authenticate: failed to update state")
} else { } else {
a.state.Store(state) a.state.Store(state)
} }
if err := a.updateProvider(cfg); err != nil {
log.Error().Err(err).Msg("authenticate: failed to update identity provider")
}
} }
func (a *Authenticate) updateProvider(cfg *config.Config) error { func (a *Authenticate) updateProvider(cfg *config.Config) error {

View file

@ -144,14 +144,17 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
ctx, span := trace.StartSpan(r.Context(), "authenticate.VerifySession") ctx, span := trace.StartSpan(r.Context(), "authenticate.VerifySession")
defer span.End() defer span.End()
state := a.state.Load()
sessionState, err := a.getSessionFromCtx(ctx) sessionState, err := a.getSessionFromCtx(ctx)
if err != nil { if err != nil {
log.FromRequest(r).Info().Err(err).Msg("authenticate: session load error") log.FromRequest(r).Info().Err(err).Msg("authenticate: session load error")
return a.reauthenticateOrFail(w, r, err) return a.reauthenticateOrFail(w, r, err)
} }
if a.dataBrokerClient != nil { if state.dataBrokerClient != nil {
_, err = session.Get(ctx, a.dataBrokerClient, sessionState.ID) _, err = session.Get(ctx, state.dataBrokerClient, sessionState.ID)
if err != nil { if err != nil {
log.FromRequest(r).Info().Err(err).Str("id", sessionState.ID).Msg("authenticate: session not found in databroker") log.FromRequest(r).Info().Err(err).Str("id", sessionState.ID).Msg("authenticate: session not found in databroker")
return a.reauthenticateOrFail(w, r, err) return a.reauthenticateOrFail(w, r, err)
@ -232,7 +235,7 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
if r.FormValue(urlutil.QueryIsProgrammatic) == "true" { if r.FormValue(urlutil.QueryIsProgrammatic) == "true" {
newSession.Programmatic = true newSession.Programmatic = true
pbSession, err := session.Get(ctx, a.dataBrokerClient, s.ID) pbSession, err := session.Get(ctx, state.dataBrokerClient, s.ID)
if err != nil { if err != nil {
return httputil.NewError(http.StatusBadRequest, err) return httputil.NewError(http.StatusBadRequest, err)
} }
@ -278,7 +281,7 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error {
sessionState, err := a.getSessionFromCtx(ctx) sessionState, err := a.getSessionFromCtx(ctx)
if err == nil { if err == nil {
if s, _ := session.Get(ctx, a.dataBrokerClient, sessionState.ID); s != nil && s.OauthToken != nil { if s, _ := session.Get(ctx, state.dataBrokerClient, sessionState.ID); s != nil && s.OauthToken != nil {
if err := a.provider.Load().Revoke(ctx, manager.FromOAuthToken(s.OauthToken)); err != nil { if err := a.provider.Load().Revoke(ctx, manager.FromOAuthToken(s.OauthToken)); err != nil {
log.Warn().Err(err).Msg("failed to revoke access token") log.Warn().Err(err).Msg("failed to revoke access token")
} }
@ -474,11 +477,8 @@ func (a *Authenticate) getSessionFromCtx(ctx context.Context) (*sessions.State,
} }
func (a *Authenticate) deleteSession(ctx context.Context, sessionID string) error { func (a *Authenticate) deleteSession(ctx context.Context, sessionID string) error {
if a.dataBrokerClient == nil { state := a.state.Load()
return nil return session.Delete(ctx, state.dataBrokerClient, sessionID)
}
err := session.Delete(ctx, a.dataBrokerClient, sessionID)
return err
} }
func (a *Authenticate) isAdmin(user string) bool { func (a *Authenticate) isAdmin(user string) bool {
@ -489,24 +489,26 @@ func (a *Authenticate) isAdmin(user string) bool {
// Dashboard renders the /.pomerium/ user dashboard. // Dashboard renders the /.pomerium/ user dashboard.
func (a *Authenticate) Dashboard(w http.ResponseWriter, r *http.Request) error { func (a *Authenticate) Dashboard(w http.ResponseWriter, r *http.Request) error {
state := a.state.Load()
s, err := a.getSessionFromCtx(r.Context()) s, err := a.getSessionFromCtx(r.Context())
if err != nil { if err != nil {
s.ID = uuid.New().String() s.ID = uuid.New().String()
} }
pbSession, err := session.Get(r.Context(), a.dataBrokerClient, s.ID) pbSession, err := session.Get(r.Context(), state.dataBrokerClient, s.ID)
if err != nil { if err != nil {
pbSession = &session.Session{ pbSession = &session.Session{
Id: s.ID, Id: s.ID,
} }
} }
pbUser, err := user.Get(r.Context(), a.dataBrokerClient, pbSession.GetUserId()) pbUser, err := user.Get(r.Context(), state.dataBrokerClient, pbSession.GetUserId())
if err != nil { if err != nil {
pbUser = &user.User{ pbUser = &user.User{
Id: pbSession.GetUserId(), Id: pbSession.GetUserId(),
} }
} }
pbDirectoryUser, err := directory.GetUser(r.Context(), a.dataBrokerClient, pbSession.GetUserId()) pbDirectoryUser, err := directory.GetUser(r.Context(), state.dataBrokerClient, pbSession.GetUserId())
if err != nil { if err != nil {
pbDirectoryUser = &directory.User{ pbDirectoryUser = &directory.User{
Id: pbSession.GetUserId(), Id: pbSession.GetUserId(),
@ -514,7 +516,7 @@ func (a *Authenticate) Dashboard(w http.ResponseWriter, r *http.Request) error {
} }
var groups []*directory.Group var groups []*directory.Group
for _, groupID := range pbDirectoryUser.GetGroupIds() { for _, groupID := range pbDirectoryUser.GetGroupIds() {
pbDirectoryGroup, err := directory.GetGroup(r.Context(), a.dataBrokerClient, groupID) pbDirectoryGroup, err := directory.GetGroup(r.Context(), state.dataBrokerClient, groupID)
if err != nil { if err != nil {
pbDirectoryGroup = &directory.Group{ pbDirectoryGroup = &directory.Group{
Id: groupID, Id: groupID,
@ -556,10 +558,7 @@ func (a *Authenticate) Dashboard(w http.ResponseWriter, r *http.Request) error {
} }
func (a *Authenticate) saveSessionToDataBroker(ctx context.Context, sessionState *sessions.State, accessToken *oauth2.Token) error { func (a *Authenticate) saveSessionToDataBroker(ctx context.Context, sessionState *sessions.State, accessToken *oauth2.Token) error {
if a.dataBrokerClient == nil { state := a.state.Load()
return nil
}
options := a.options.Load() options := a.options.Load()
sessionExpiry, _ := ptypes.TimestampProto(time.Now().Add(options.CookieExpire)) sessionExpiry, _ := ptypes.TimestampProto(time.Now().Add(options.CookieExpire))
@ -580,7 +579,7 @@ func (a *Authenticate) saveSessionToDataBroker(ctx context.Context, sessionState
} }
// if no user exists yet, create a new one // if no user exists yet, create a new one
currentUser, _ := user.Get(ctx, a.dataBrokerClient, s.GetUserId()) currentUser, _ := user.Get(ctx, state.dataBrokerClient, s.GetUserId())
if currentUser == nil { if currentUser == nil {
mu := manager.User{ mu := manager.User{
User: &user.User{ User: &user.User{
@ -591,13 +590,13 @@ func (a *Authenticate) saveSessionToDataBroker(ctx context.Context, sessionState
if err != nil { if err != nil {
return fmt.Errorf("authenticate: error retrieving user info: %w", err) return fmt.Errorf("authenticate: error retrieving user info: %w", err)
} }
_, err = user.Set(ctx, a.dataBrokerClient, mu.User) _, err = user.Set(ctx, state.dataBrokerClient, mu.User)
if err != nil { if err != nil {
return fmt.Errorf("authenticate: error saving user: %w", err) return fmt.Errorf("authenticate: error saving user: %w", err)
} }
} }
res, err := session.Set(ctx, a.dataBrokerClient, s) res, err := session.Set(ctx, state.dataBrokerClient, s)
if err != nil { if err != nil {
return fmt.Errorf("authenticate: error saving session: %w", err) return fmt.Errorf("authenticate: error saving session: %w", err)
} }

View file

@ -152,7 +152,6 @@ func TestAuthenticate_SignIn(t *testing.T) {
redirectURL: uriParseHelper("https://some.example"), redirectURL: uriParseHelper("https://some.example"),
sharedEncoder: tt.encoder, sharedEncoder: tt.encoder,
encryptedEncoder: tt.encoder, encryptedEncoder: tt.encoder,
}),
dataBrokerClient: mockDataBrokerServiceClient{ dataBrokerClient: mockDataBrokerServiceClient{
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) { get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
data, err := ptypes.MarshalAny(&session.Session{ data, err := ptypes.MarshalAny(&session.Session{
@ -172,6 +171,8 @@ func TestAuthenticate_SignIn(t *testing.T) {
}, nil }, nil
}, },
}, },
}),
options: config.NewAtomicOptions(), options: config.NewAtomicOptions(),
provider: identity.NewAtomicAuthenticator(), provider: identity.NewAtomicAuthenticator(),
} }
@ -237,8 +238,6 @@ func TestAuthenticate_SignOut(t *testing.T) {
sessionStore: tt.sessionStore, sessionStore: tt.sessionStore,
encryptedEncoder: mock.Encoder{}, encryptedEncoder: mock.Encoder{},
sharedEncoder: mock.Encoder{}, sharedEncoder: mock.Encoder{},
}),
templates: template.Must(frontend.NewTemplates()),
dataBrokerClient: mockDataBrokerServiceClient{ dataBrokerClient: mockDataBrokerServiceClient{
delete: func(ctx context.Context, in *databroker.DeleteRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) { delete: func(ctx context.Context, in *databroker.DeleteRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) {
return nil, nil return nil, nil
@ -261,6 +260,8 @@ func TestAuthenticate_SignOut(t *testing.T) {
}, nil }, nil
}, },
}, },
}),
templates: template.Must(frontend.NewTemplates()),
options: config.NewAtomicOptions(), options: config.NewAtomicOptions(),
provider: identity.NewAtomicAuthenticator(), provider: identity.NewAtomicAuthenticator(),
} }
@ -347,6 +348,14 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
authURL, _ := url.Parse(tt.authenticateURL) authURL, _ := url.Parse(tt.authenticateURL)
a := &Authenticate{ a := &Authenticate{
state: newAtomicAuthenticateState(&authenticateState{ state: newAtomicAuthenticateState(&authenticateState{
dataBrokerClient: mockDataBrokerServiceClient{
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
return nil, fmt.Errorf("not implemented")
},
set: func(ctx context.Context, in *databroker.SetRequest, opts ...grpc.CallOption) (*databroker.SetResponse, error) {
return &databroker.SetResponse{Record: &databroker.Record{Data: in.Data}}, nil
},
},
redirectURL: authURL, redirectURL: authURL,
sessionStore: tt.session, sessionStore: tt.session,
cookieCipher: aead, cookieCipher: aead,
@ -477,7 +486,6 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
cookieCipher: aead, cookieCipher: aead,
encryptedEncoder: signer, encryptedEncoder: signer,
sharedEncoder: signer, sharedEncoder: signer,
}),
dataBrokerClient: mockDataBrokerServiceClient{ dataBrokerClient: mockDataBrokerServiceClient{
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) { get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
data, err := ptypes.MarshalAny(&session.Session{ data, err := ptypes.MarshalAny(&session.Session{
@ -497,6 +505,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
}, nil }, nil
}, },
}, },
}),
options: config.NewAtomicOptions(), options: config.NewAtomicOptions(),
provider: identity.NewAtomicAuthenticator(), provider: identity.NewAtomicAuthenticator(),
} }
@ -593,8 +602,6 @@ func TestAuthenticate_Dashboard(t *testing.T) {
sessionStore: tt.sessionStore, sessionStore: tt.sessionStore,
encryptedEncoder: signer, encryptedEncoder: signer,
sharedEncoder: signer, sharedEncoder: signer,
}),
templates: template.Must(frontend.NewTemplates()),
dataBrokerClient: mockDataBrokerServiceClient{ dataBrokerClient: mockDataBrokerServiceClient{
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) { get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
data, err := ptypes.MarshalAny(&session.Session{ data, err := ptypes.MarshalAny(&session.Session{
@ -616,6 +623,8 @@ func TestAuthenticate_Dashboard(t *testing.T) {
}, nil }, nil
}, },
}, },
}),
templates: template.Must(frontend.NewTemplates()),
} }
u, _ := url.Parse("/") u, _ := url.Parse("/")
r := httptest.NewRequest(tt.method, u.String(), nil) r := httptest.NewRequest(tt.method, u.String(), nil)
@ -646,6 +655,7 @@ type mockDataBrokerServiceClient struct {
delete func(ctx context.Context, in *databroker.DeleteRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) delete func(ctx context.Context, in *databroker.DeleteRequest, opts ...grpc.CallOption) (*emptypb.Empty, error)
get func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) get func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error)
set func(ctx context.Context, in *databroker.SetRequest, opts ...grpc.CallOption) (*databroker.SetResponse, error)
} }
func (m mockDataBrokerServiceClient) Delete(ctx context.Context, in *databroker.DeleteRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) { func (m mockDataBrokerServiceClient) Delete(ctx context.Context, in *databroker.DeleteRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) {
@ -655,3 +665,7 @@ func (m mockDataBrokerServiceClient) Delete(ctx context.Context, in *databroker.
func (m mockDataBrokerServiceClient) Get(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) { func (m mockDataBrokerServiceClient) Get(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
return m.get(ctx, in, opts...) return m.get(ctx, in, opts...)
} }
func (m mockDataBrokerServiceClient) Set(ctx context.Context, in *databroker.SetRequest, opts ...grpc.CallOption) (*databroker.SetResponse, error) {
return m.set(ctx, in, opts...)
}

View file

@ -20,6 +20,8 @@ import (
"github.com/pomerium/pomerium/internal/sessions/queryparam" "github.com/pomerium/pomerium/internal/sessions/queryparam"
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
) )
type authenticateState struct { type authenticateState struct {
@ -42,6 +44,8 @@ type authenticateState struct {
sessionLoaders []sessions.SessionLoader sessionLoaders []sessions.SessionLoader
jwk *jose.JSONWebKeySet jwk *jose.JSONWebKeySet
dataBrokerClient databroker.DataBrokerServiceClient
} }
func newAuthenticateState() *authenticateState { func newAuthenticateState() *authenticateState {
@ -52,6 +56,11 @@ func newAuthenticateState() *authenticateState {
} }
func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, error) { func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, error) {
err := ValidateOptions(cfg.Options)
if err != nil {
return nil, err
}
state := &authenticateState{} state := &authenticateState{}
state.redirectURL, _ = urlutil.DeepCopy(cfg.Options.AuthenticateURL) state.redirectURL, _ = urlutil.DeepCopy(cfg.Options.AuthenticateURL)
@ -63,7 +72,6 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err
} }
// shared state encoder setup // shared state encoder setup
var err error
state.sharedEncoder, err = jws.NewHS256Signer([]byte(cfg.Options.SharedKey), cfg.Options.GetAuthenticateURL().Host) state.sharedEncoder, err = jws.NewHS256Signer([]byte(cfg.Options.SharedKey), cfg.Options.GetAuthenticateURL().Host)
if err != nil { if err != nil {
return nil, err return nil, err
@ -106,6 +114,22 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err
state.jwk.Keys = append(state.jwk.Keys, *jwk) state.jwk.Keys = append(state.jwk.Keys, *jwk)
} }
dataBrokerConn, err := grpc.GetGRPCClientConn("databroker", &grpc.Options{
Addr: cfg.Options.DataBrokerURL,
OverrideCertificateName: cfg.Options.OverrideCertificateName,
CA: cfg.Options.CA,
CAFile: cfg.Options.CAFile,
RequestTimeout: cfg.Options.GRPCClientTimeout,
ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin,
WithInsecure: cfg.Options.GRPCInsecure,
ServiceName: cfg.Options.Services,
})
if err != nil {
return nil, err
}
state.dataBrokerClient = databroker.NewDataBrokerServiceClient(dataBrokerConn)
return state, nil return state, nil
} }