mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-01 18:33:19 +02:00
core/authorize: check for expired tokens (#4547)
core/authorize: check for expired tokens (#4543) * core/authorize: check for expired tokens * Update pkg/grpc/session/session.go * lint * fix zero timestamps * fix --------- Co-authored-by: Caleb Doxsey <cdoxsey@pomerium.com> Co-authored-by: Denis Mishin <dmishin@pomerium.com>
This commit is contained in:
parent
b904242e25
commit
57aead4eda
7 changed files with 127 additions and 1 deletions
|
@ -13,6 +13,7 @@ import (
|
||||||
|
|
||||||
type sessionOrServiceAccount interface {
|
type sessionOrServiceAccount interface {
|
||||||
GetUserId() string
|
GetUserId() string
|
||||||
|
Validate() error
|
||||||
}
|
}
|
||||||
|
|
||||||
func getDataBrokerRecord(
|
func getDataBrokerRecord(
|
||||||
|
@ -77,6 +78,9 @@ func (a *Authorize) getDataBrokerSessionOrServiceAccount(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
s = msg.(sessionOrServiceAccount)
|
s = msg.(sessionOrServiceAccount)
|
||||||
|
if err := s.Validate(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
if _, ok := s.(*session.Session); ok {
|
if _, ok := s.(*session.Session); ok {
|
||||||
a.accessTracker.TrackSessionAccess(sessionID)
|
a.accessTracker.TrackSessionAccess(sessionID)
|
||||||
|
|
|
@ -7,7 +7,10 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||||
"github.com/pomerium/pomerium/pkg/storage"
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
|
@ -54,3 +57,20 @@ func Test_getDataBrokerRecord(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAuthorize_getDataBrokerSessionOrServiceAccount(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||||
|
t.Cleanup(clearTimeout)
|
||||||
|
|
||||||
|
opt := config.NewDefaultOptions()
|
||||||
|
a, err := New(&config.Config{Options: opt})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
s1 := &session.Session{Id: "s1", ExpiresAt: timestamppb.New(time.Now().Add(-time.Second))}
|
||||||
|
sq := storage.NewStaticQuerier(s1)
|
||||||
|
qctx := storage.WithQuerier(ctx, sq)
|
||||||
|
_, err = a.getDataBrokerSessionOrServiceAccount(qctx, "s1", 0)
|
||||||
|
assert.ErrorIs(t, err, session.ErrSessionExpired)
|
||||||
|
}
|
||||||
|
|
|
@ -55,7 +55,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
|
||||||
if sessionState != nil {
|
if sessionState != nil {
|
||||||
s, err = a.getDataBrokerSessionOrServiceAccount(ctx, sessionState.ID, sessionState.DatabrokerRecordVersion)
|
s, err = a.getDataBrokerSessionOrServiceAccount(ctx, sessionState.ID, sessionState.DatabrokerRecordVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn(ctx).Err(err).Msg("clearing session due to missing session or service account")
|
log.Warn(ctx).Err(err).Msg("clearing session due to missing or invalid session or service account")
|
||||||
sessionState = nil
|
sessionState = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,6 +4,7 @@ package session
|
||||||
import (
|
import (
|
||||||
context "context"
|
context "context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
"google.golang.org/protobuf/types/known/structpb"
|
"google.golang.org/protobuf/types/known/structpb"
|
||||||
|
@ -86,3 +87,22 @@ func (x *Session) RemoveDeviceCredentialID(deviceCredentialID string) {
|
||||||
return el.GetId() != deviceCredentialID
|
return el.GetId() != deviceCredentialID
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ErrSessionExpired indicates the session has expired
|
||||||
|
var ErrSessionExpired = fmt.Errorf("session has expired")
|
||||||
|
|
||||||
|
// Validate returns an error if the session is not valid.
|
||||||
|
func (x *Session) Validate() error {
|
||||||
|
now := time.Now()
|
||||||
|
for name, expiresAt := range map[string]*timestamppb.Timestamp{
|
||||||
|
"session": x.GetExpiresAt(),
|
||||||
|
"access_token": x.GetOauthToken().GetExpiresAt(),
|
||||||
|
"id_token": x.GetIdToken().GetExpiresAt(),
|
||||||
|
} {
|
||||||
|
if expiresAt.AsTime().Year() > 1970 && now.After(expiresAt.AsTime()) {
|
||||||
|
return fmt.Errorf("%w: %s expired at %s", ErrSessionExpired, name, expiresAt.AsTime())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
32
pkg/grpc/session/session_test.go
Normal file
32
pkg/grpc/session/session_test.go
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSession_Validate(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
t0 := timestamppb.New(time.Now().Add(-time.Second))
|
||||||
|
for _, tc := range []struct {
|
||||||
|
name string
|
||||||
|
session *Session
|
||||||
|
expect error
|
||||||
|
}{
|
||||||
|
{"valid", &Session{}, nil},
|
||||||
|
{"expired", &Session{ExpiresAt: t0}, ErrSessionExpired},
|
||||||
|
{"expired id token", &Session{IdToken: &IDToken{ExpiresAt: t0}}, ErrSessionExpired},
|
||||||
|
{"expired oauth token", &Session{OauthToken: &OAuthToken{ExpiresAt: t0}}, ErrSessionExpired},
|
||||||
|
} {
|
||||||
|
tc := tc
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
assert.ErrorIs(t, tc.session.Validate(), tc.expect)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -3,8 +3,11 @@ package user
|
||||||
|
|
||||||
import (
|
import (
|
||||||
context "context"
|
context "context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
"google.golang.org/protobuf/types/known/structpb"
|
"google.golang.org/protobuf/types/known/structpb"
|
||||||
|
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/identity"
|
"github.com/pomerium/pomerium/internal/identity"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
@ -28,6 +31,23 @@ func PutServiceAccount(ctx context.Context, client databroker.DataBrokerServiceC
|
||||||
return databroker.Put(ctx, client, serviceAccount)
|
return databroker.Put(ctx, client, serviceAccount)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ErrServiceAccountExpired indicates the service account has expired.
|
||||||
|
var ErrServiceAccountExpired = fmt.Errorf("service account has expired")
|
||||||
|
|
||||||
|
// Validate returns an error if the service account is not valid.
|
||||||
|
func (x *ServiceAccount) Validate() error {
|
||||||
|
now := time.Now()
|
||||||
|
for _, expiresAt := range []*timestamppb.Timestamp{
|
||||||
|
x.GetExpiresAt(),
|
||||||
|
} {
|
||||||
|
if expiresAt.AsTime().Year() > 1970 && now.After(expiresAt.AsTime()) {
|
||||||
|
return ErrServiceAccountExpired
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// AddClaims adds the flattened claims to the user.
|
// AddClaims adds the flattened claims to the user.
|
||||||
func (x *User) AddClaims(claims identity.FlattenedClaims) {
|
func (x *User) AddClaims(claims identity.FlattenedClaims) {
|
||||||
if x.Claims == nil {
|
if x.Claims == nil {
|
||||||
|
|
30
pkg/grpc/user/user_test.go
Normal file
30
pkg/grpc/user/user_test.go
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
package user
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestServiceAccount_Validate(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
t0 := timestamppb.New(time.Now().Add(-time.Second))
|
||||||
|
for _, tc := range []struct {
|
||||||
|
name string
|
||||||
|
serviceAccount *ServiceAccount
|
||||||
|
expect error
|
||||||
|
}{
|
||||||
|
{"valid", &ServiceAccount{}, nil},
|
||||||
|
{"expired", &ServiceAccount{ExpiresAt: t0}, ErrServiceAccountExpired},
|
||||||
|
} {
|
||||||
|
tc := tc
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
assert.ErrorIs(t, tc.serviceAccount.Validate(), tc.expect)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue