diff --git a/authenticate/handlers.go b/authenticate/handlers.go index b283486c9..24b2e62c2 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -86,15 +86,12 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler { return httputil.NewError(http.StatusBadRequest, err) } - if err := s.Verify(r.Host); errors.Is(err, sessions.ErrExpired) { + if s.IsExpired() { ctx, err = a.refresh(w, r, &s) if err != nil { log.FromRequest(r).Info().Err(err).Msg("authenticate: verify session, refresh") return a.reauthenticateOrFail(w, r, err) } - } else if err != nil { - log.FromRequest(r).Info().Err(err).Msg("authenticate: verify session") - return a.reauthenticateOrFail(w, r, err) } next.ServeHTTP(w, r.WithContext(ctx)) return nil @@ -164,9 +161,7 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error { if err := a.encryptedEncoder.Unmarshal([]byte(jwt), &s); err != nil { return httputil.NewError(http.StatusBadRequest, err) } - if err := s.Verify(r.Host); err != nil && !errors.Is(err, sessions.ErrExpired) { - return httputil.NewError(http.StatusBadRequest, err) - } + // user impersonation if impersonate := r.FormValue(urlutil.QueryImpersonateAction); impersonate != "" { s.SetImpersonation(r.FormValue(urlutil.QueryImpersonateEmail), r.FormValue(urlutil.QueryImpersonateGroups)) @@ -376,10 +371,6 @@ func (a *Authenticate) RefreshAPI(w http.ResponseWriter, r *http.Request) error if err := a.encryptedEncoder.Unmarshal([]byte(jwt), &s); err != nil { return httputil.NewError(http.StatusBadRequest, err) } - err = s.Verify(r.Host) - if err != nil && !errors.Is(err, sessions.ErrExpired) { - return httputil.NewError(http.StatusBadRequest, err) - } newSession, err := a.provider.Refresh(r.Context(), &s) if err != nil { return err @@ -425,9 +416,7 @@ func (a *Authenticate) Refresh(w http.ResponseWriter, r *http.Request) error { if err := a.encryptedEncoder.Unmarshal([]byte(jwt), &s); err != nil { return httputil.NewError(http.StatusBadRequest, err) } - if err := s.Verify(r.Host); err != nil && !errors.Is(err, sessions.ErrExpired) { - return httputil.NewError(http.StatusBadRequest, err) - } + aud := strings.Split(r.FormValue(urlutil.QueryAudience), ",") routeSession := s.NewSession(r.Host, aud) routeSession.AccessTokenID = s.AccessTokenID diff --git a/internal/sessions/state.go b/internal/sessions/state.go index dfab6cf42..457fcf925 100644 --- a/internal/sessions/state.go +++ b/internal/sessions/state.go @@ -114,25 +114,18 @@ func (s State) RouteSession() *State { return &s } -// Verify returns an error if the users's session state is not valid. -func (s *State) Verify(audience string) error { +// IsExpired returns true if the users's session is expired. +func (s *State) IsExpired() bool { if s.Expiry != nil && timeNow().After(s.Expiry.Time()) { - return ErrExpired + return true } - // if we have an associated access token, check if that token has expired as well if s.AccessToken != nil && timeNow().After(s.AccessToken.Expiry) { - return ErrExpired + return true } - if len(s.Audience) != 0 { - if !s.Audience.Contains(audience) { - return ErrInvalidAudience - } - - } - return nil + return false } // Impersonating returns if the request is impersonating. diff --git a/internal/sessions/state_test.go b/internal/sessions/state_test.go index baffba552..d9db5e4df 100644 --- a/internal/sessions/state_test.go +++ b/internal/sessions/state_test.go @@ -48,7 +48,7 @@ func TestState_Impersonating(t *testing.T) { } } -func TestState_Verify(t *testing.T) { +func TestState_IsExpired(t *testing.T) { t.Parallel() tests := []struct { name string @@ -63,7 +63,6 @@ func TestState_Verify(t *testing.T) { }{ {"good", []string{"a", "b", "c"}, jwt.NewNumericDate(time.Now().Add(time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), &oauth2.Token{Expiry: time.Now().Add(time.Hour)}, "a", false}, {"bad expiry", []string{"a", "b", "c"}, jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), &oauth2.Token{Expiry: time.Now().Add(time.Hour)}, "a", true}, - {"bad audience", []string{"x", "y", "z"}, jwt.NewNumericDate(time.Now().Add(time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), &oauth2.Token{Expiry: time.Now().Add(time.Hour)}, "a", true}, {"bad access token expiry", []string{"a", "b", "c"}, jwt.NewNumericDate(time.Now().Add(time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), &oauth2.Token{Expiry: time.Now().Add(-time.Hour)}, "a", true}, } for _, tt := range tests { @@ -75,8 +74,8 @@ func TestState_Verify(t *testing.T) { IssuedAt: tt.IssuedAt, AccessToken: tt.AccessToken, } - if err := s.Verify(tt.audience); (err != nil) != tt.wantErr { - t.Errorf("State.Verify() error = %v, wantErr %v", err, tt.wantErr) + if exp := s.IsExpired(); exp != tt.wantErr { + t.Errorf("State.IsExpired() error = %v, wantErr %v", exp, tt.wantErr) } }) }