diff --git a/authorize/databroker.go b/authorize/databroker.go index 4da2f4b5c..a65db4e27 100644 --- a/authorize/databroker.go +++ b/authorize/databroker.go @@ -13,6 +13,7 @@ import ( type sessionOrServiceAccount interface { GetUserId() string + Validate() error } func getDataBrokerRecord( @@ -77,6 +78,9 @@ func (a *Authorize) getDataBrokerSessionOrServiceAccount( return nil, err } s = msg.(sessionOrServiceAccount) + if err := s.Validate(); err != nil { + return nil, err + } if _, ok := s.(*session.Session); ok { a.accessTracker.TrackSessionAccess(sessionID) diff --git a/authorize/databroker_test.go b/authorize/databroker_test.go index d90803d9e..a9e441c21 100644 --- a/authorize/databroker_test.go +++ b/authorize/databroker_test.go @@ -7,7 +7,10 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/timestamppb" + "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpcutil" "github.com/pomerium/pomerium/pkg/storage" @@ -54,3 +57,20 @@ func Test_getDataBrokerRecord(t *testing.T) { }) } } + +func TestAuthorize_getDataBrokerSessionOrServiceAccount(t *testing.T) { + t.Parallel() + + ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) + t.Cleanup(clearTimeout) + + opt := config.NewDefaultOptions() + a, err := New(&config.Config{Options: opt}) + require.NoError(t, err) + + s1 := &session.Session{Id: "s1", ExpiresAt: timestamppb.New(time.Now().Add(-time.Second))} + sq := storage.NewStaticQuerier(s1) + qctx := storage.WithQuerier(ctx, sq) + _, err = a.getDataBrokerSessionOrServiceAccount(qctx, "s1", 0) + assert.ErrorIs(t, err, session.ErrSessionExpired) +} diff --git a/authorize/grpc.go b/authorize/grpc.go index aca5f567d..6d9994802 100644 --- a/authorize/grpc.go +++ b/authorize/grpc.go @@ -55,7 +55,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe if sessionState != nil { s, err = a.getDataBrokerSessionOrServiceAccount(ctx, sessionState.ID, sessionState.DatabrokerRecordVersion) if err != nil { - log.Warn(ctx).Err(err).Msg("clearing session due to missing session or service account") + log.Warn(ctx).Err(err).Msg("clearing session due to missing or invalid session or service account") sessionState = nil } } diff --git a/pkg/grpc/session/session.go b/pkg/grpc/session/session.go index adeab5f39..edea20f14 100644 --- a/pkg/grpc/session/session.go +++ b/pkg/grpc/session/session.go @@ -4,6 +4,7 @@ package session import ( context "context" "fmt" + "time" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/structpb" @@ -86,3 +87,22 @@ func (x *Session) RemoveDeviceCredentialID(deviceCredentialID string) { return el.GetId() != deviceCredentialID }) } + +// ErrSessionExpired indicates the session has expired +var ErrSessionExpired = fmt.Errorf("session has expired") + +// Validate returns an error if the session is not valid. +func (x *Session) Validate() error { + now := time.Now() + for name, expiresAt := range map[string]*timestamppb.Timestamp{ + "session": x.GetExpiresAt(), + "access_token": x.GetOauthToken().GetExpiresAt(), + "id_token": x.GetIdToken().GetExpiresAt(), + } { + if expiresAt.AsTime().Year() > 1970 && now.After(expiresAt.AsTime()) { + return fmt.Errorf("%w: %s expired at %s", ErrSessionExpired, name, expiresAt.AsTime()) + } + } + + return nil +} diff --git a/pkg/grpc/session/session_test.go b/pkg/grpc/session/session_test.go new file mode 100644 index 000000000..b2ddb1e46 --- /dev/null +++ b/pkg/grpc/session/session_test.go @@ -0,0 +1,32 @@ +package session + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/types/known/timestamppb" +) + +func TestSession_Validate(t *testing.T) { + t.Parallel() + + t0 := timestamppb.New(time.Now().Add(-time.Second)) + for _, tc := range []struct { + name string + session *Session + expect error + }{ + {"valid", &Session{}, nil}, + {"expired", &Session{ExpiresAt: t0}, ErrSessionExpired}, + {"expired id token", &Session{IdToken: &IDToken{ExpiresAt: t0}}, ErrSessionExpired}, + {"expired oauth token", &Session{OauthToken: &OAuthToken{ExpiresAt: t0}}, ErrSessionExpired}, + } { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + assert.ErrorIs(t, tc.session.Validate(), tc.expect) + }) + } +} diff --git a/pkg/grpc/user/user.go b/pkg/grpc/user/user.go index b3fb600fe..6c9e41dce 100644 --- a/pkg/grpc/user/user.go +++ b/pkg/grpc/user/user.go @@ -3,8 +3,11 @@ package user import ( context "context" + "fmt" + "time" "google.golang.org/protobuf/types/known/structpb" + timestamppb "google.golang.org/protobuf/types/known/timestamppb" "github.com/pomerium/pomerium/internal/identity" "github.com/pomerium/pomerium/pkg/grpc/databroker" @@ -28,6 +31,23 @@ func PutServiceAccount(ctx context.Context, client databroker.DataBrokerServiceC return databroker.Put(ctx, client, serviceAccount) } +// ErrServiceAccountExpired indicates the service account has expired. +var ErrServiceAccountExpired = fmt.Errorf("service account has expired") + +// Validate returns an error if the service account is not valid. +func (x *ServiceAccount) Validate() error { + now := time.Now() + for _, expiresAt := range []*timestamppb.Timestamp{ + x.GetExpiresAt(), + } { + if expiresAt.AsTime().Year() > 1970 && now.After(expiresAt.AsTime()) { + return ErrServiceAccountExpired + } + } + + return nil +} + // AddClaims adds the flattened claims to the user. func (x *User) AddClaims(claims identity.FlattenedClaims) { if x.Claims == nil { diff --git a/pkg/grpc/user/user_test.go b/pkg/grpc/user/user_test.go new file mode 100644 index 000000000..e6e63fc55 --- /dev/null +++ b/pkg/grpc/user/user_test.go @@ -0,0 +1,30 @@ +package user + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/types/known/timestamppb" +) + +func TestServiceAccount_Validate(t *testing.T) { + t.Parallel() + + t0 := timestamppb.New(time.Now().Add(-time.Second)) + for _, tc := range []struct { + name string + serviceAccount *ServiceAccount + expect error + }{ + {"valid", &ServiceAccount{}, nil}, + {"expired", &ServiceAccount{ExpiresAt: t0}, ErrServiceAccountExpired}, + } { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + assert.ErrorIs(t, tc.serviceAccount.Validate(), tc.expect) + }) + } +}