mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-28 16:37:24 +02:00
authenticate: move databroker connection to state (#1292)
* authenticate: move databroker connection to state * re-use err * just return * remove nil checks
This commit is contained in:
parent
a1378c81f8
commit
882b6b54ee
4 changed files with 151 additions and 142 deletions
|
@ -144,14 +144,17 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
|
|||
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
ctx, span := trace.StartSpan(r.Context(), "authenticate.VerifySession")
|
||||
defer span.End()
|
||||
|
||||
state := a.state.Load()
|
||||
|
||||
sessionState, err := a.getSessionFromCtx(ctx)
|
||||
if err != nil {
|
||||
log.FromRequest(r).Info().Err(err).Msg("authenticate: session load error")
|
||||
return a.reauthenticateOrFail(w, r, err)
|
||||
}
|
||||
|
||||
if a.dataBrokerClient != nil {
|
||||
_, err = session.Get(ctx, a.dataBrokerClient, sessionState.ID)
|
||||
if state.dataBrokerClient != nil {
|
||||
_, err = session.Get(ctx, state.dataBrokerClient, sessionState.ID)
|
||||
if err != nil {
|
||||
log.FromRequest(r).Info().Err(err).Str("id", sessionState.ID).Msg("authenticate: session not found in databroker")
|
||||
return a.reauthenticateOrFail(w, r, err)
|
||||
|
@ -232,7 +235,7 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
|
|||
if r.FormValue(urlutil.QueryIsProgrammatic) == "true" {
|
||||
newSession.Programmatic = true
|
||||
|
||||
pbSession, err := session.Get(ctx, a.dataBrokerClient, s.ID)
|
||||
pbSession, err := session.Get(ctx, state.dataBrokerClient, s.ID)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
}
|
||||
|
@ -278,7 +281,7 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error {
|
|||
|
||||
sessionState, err := a.getSessionFromCtx(ctx)
|
||||
if err == nil {
|
||||
if s, _ := session.Get(ctx, a.dataBrokerClient, sessionState.ID); s != nil && s.OauthToken != nil {
|
||||
if s, _ := session.Get(ctx, state.dataBrokerClient, sessionState.ID); s != nil && s.OauthToken != nil {
|
||||
if err := a.provider.Load().Revoke(ctx, manager.FromOAuthToken(s.OauthToken)); err != nil {
|
||||
log.Warn().Err(err).Msg("failed to revoke access token")
|
||||
}
|
||||
|
@ -474,11 +477,8 @@ func (a *Authenticate) getSessionFromCtx(ctx context.Context) (*sessions.State,
|
|||
}
|
||||
|
||||
func (a *Authenticate) deleteSession(ctx context.Context, sessionID string) error {
|
||||
if a.dataBrokerClient == nil {
|
||||
return nil
|
||||
}
|
||||
err := session.Delete(ctx, a.dataBrokerClient, sessionID)
|
||||
return err
|
||||
state := a.state.Load()
|
||||
return session.Delete(ctx, state.dataBrokerClient, sessionID)
|
||||
}
|
||||
|
||||
func (a *Authenticate) isAdmin(user string) bool {
|
||||
|
@ -489,24 +489,26 @@ func (a *Authenticate) isAdmin(user string) bool {
|
|||
|
||||
// Dashboard renders the /.pomerium/ user dashboard.
|
||||
func (a *Authenticate) Dashboard(w http.ResponseWriter, r *http.Request) error {
|
||||
state := a.state.Load()
|
||||
|
||||
s, err := a.getSessionFromCtx(r.Context())
|
||||
if err != nil {
|
||||
s.ID = uuid.New().String()
|
||||
}
|
||||
|
||||
pbSession, err := session.Get(r.Context(), a.dataBrokerClient, s.ID)
|
||||
pbSession, err := session.Get(r.Context(), state.dataBrokerClient, s.ID)
|
||||
if err != nil {
|
||||
pbSession = &session.Session{
|
||||
Id: s.ID,
|
||||
}
|
||||
}
|
||||
pbUser, err := user.Get(r.Context(), a.dataBrokerClient, pbSession.GetUserId())
|
||||
pbUser, err := user.Get(r.Context(), state.dataBrokerClient, pbSession.GetUserId())
|
||||
if err != nil {
|
||||
pbUser = &user.User{
|
||||
Id: pbSession.GetUserId(),
|
||||
}
|
||||
}
|
||||
pbDirectoryUser, err := directory.GetUser(r.Context(), a.dataBrokerClient, pbSession.GetUserId())
|
||||
pbDirectoryUser, err := directory.GetUser(r.Context(), state.dataBrokerClient, pbSession.GetUserId())
|
||||
if err != nil {
|
||||
pbDirectoryUser = &directory.User{
|
||||
Id: pbSession.GetUserId(),
|
||||
|
@ -514,7 +516,7 @@ func (a *Authenticate) Dashboard(w http.ResponseWriter, r *http.Request) error {
|
|||
}
|
||||
var groups []*directory.Group
|
||||
for _, groupID := range pbDirectoryUser.GetGroupIds() {
|
||||
pbDirectoryGroup, err := directory.GetGroup(r.Context(), a.dataBrokerClient, groupID)
|
||||
pbDirectoryGroup, err := directory.GetGroup(r.Context(), state.dataBrokerClient, groupID)
|
||||
if err != nil {
|
||||
pbDirectoryGroup = &directory.Group{
|
||||
Id: groupID,
|
||||
|
@ -556,10 +558,7 @@ func (a *Authenticate) Dashboard(w http.ResponseWriter, r *http.Request) error {
|
|||
}
|
||||
|
||||
func (a *Authenticate) saveSessionToDataBroker(ctx context.Context, sessionState *sessions.State, accessToken *oauth2.Token) error {
|
||||
if a.dataBrokerClient == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
state := a.state.Load()
|
||||
options := a.options.Load()
|
||||
|
||||
sessionExpiry, _ := ptypes.TimestampProto(time.Now().Add(options.CookieExpire))
|
||||
|
@ -580,7 +579,7 @@ func (a *Authenticate) saveSessionToDataBroker(ctx context.Context, sessionState
|
|||
}
|
||||
|
||||
// if no user exists yet, create a new one
|
||||
currentUser, _ := user.Get(ctx, a.dataBrokerClient, s.GetUserId())
|
||||
currentUser, _ := user.Get(ctx, state.dataBrokerClient, s.GetUserId())
|
||||
if currentUser == nil {
|
||||
mu := manager.User{
|
||||
User: &user.User{
|
||||
|
@ -591,13 +590,13 @@ func (a *Authenticate) saveSessionToDataBroker(ctx context.Context, sessionState
|
|||
if err != nil {
|
||||
return fmt.Errorf("authenticate: error retrieving user info: %w", err)
|
||||
}
|
||||
_, err = user.Set(ctx, a.dataBrokerClient, mu.User)
|
||||
_, err = user.Set(ctx, state.dataBrokerClient, mu.User)
|
||||
if err != nil {
|
||||
return fmt.Errorf("authenticate: error saving user: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
res, err := session.Set(ctx, a.dataBrokerClient, s)
|
||||
res, err := session.Set(ctx, state.dataBrokerClient, s)
|
||||
if err != nil {
|
||||
return fmt.Errorf("authenticate: error saving session: %w", err)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue