mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-10 07:37:33 +02:00
idp: delete sessions on refresh error, handle zero times in oauth/id tokens for refresh (#961)
This commit is contained in:
parent
452c9be06d
commit
b3ccdfe00f
5 changed files with 82 additions and 41 deletions
|
@ -17,6 +17,7 @@ import (
|
|||
"github.com/pomerium/csrf"
|
||||
"github.com/rs/cors"
|
||||
"golang.org/x/oauth2"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/pomerium/pomerium/internal/grpc/databroker"
|
||||
|
@ -491,7 +492,10 @@ func (a *Authenticate) saveSessionToDataBroker(ctx context.Context, sessionState
|
|||
}
|
||||
|
||||
sessionExpiry, _ := ptypes.TimestampProto(time.Now().Add(time.Hour))
|
||||
idTokenExpiry, _ := ptypes.TimestampProto(sessionState.Expiry.Time())
|
||||
var idTokenExpiry *timestamppb.Timestamp
|
||||
if sessionState.Expiry != nil {
|
||||
idTokenExpiry, _ = ptypes.TimestampProto(sessionState.Expiry.Time())
|
||||
}
|
||||
idTokenIssuedAt, _ := ptypes.TimestampProto(sessionState.IssuedAt.Time())
|
||||
oauthTokenExpiry, _ := ptypes.TimestampProto(accessToken.Expiry)
|
||||
|
||||
|
|
22
cache/session.go
vendored
22
cache/session.go
vendored
|
@ -51,14 +51,32 @@ func (srv *SessionServer) Add(ctx context.Context, req *session.AddRequest) (*se
|
|||
Str("session_id", req.GetSession().GetId()).
|
||||
Msg("add")
|
||||
|
||||
data, err := ptypes.MarshalAny(req.GetSession())
|
||||
s := req.GetSession()
|
||||
|
||||
data, err := ptypes.MarshalAny(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := srv.dataBrokerClient.Set(ctx, &databroker.SetRequest{
|
||||
Type: data.GetTypeUrl(),
|
||||
Id: req.GetSession().GetId(),
|
||||
Id: s.GetId(),
|
||||
Data: data,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.Version = res.GetServerVersion()
|
||||
|
||||
data, err = ptypes.MarshalAny(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err = srv.dataBrokerClient.Set(ctx, &databroker.SetRequest{
|
||||
Type: data.GetTypeUrl(),
|
||||
Id: s.GetId(),
|
||||
Data: data,
|
||||
})
|
||||
if err != nil {
|
||||
|
|
|
@ -74,26 +74,22 @@ type Session struct {
|
|||
func (s Session) NextRefresh() time.Time {
|
||||
var tm time.Time
|
||||
|
||||
expiry, err := ptypes.Timestamp(s.GetOauthToken().GetExpiresAt())
|
||||
if err == nil {
|
||||
expiry = expiry.Add(-s.gracePeriod)
|
||||
if tm.IsZero() || expiry.Before(tm) {
|
||||
tm = expiry
|
||||
if s.GetOauthToken().GetExpiresAt() != nil {
|
||||
expiry, err := ptypes.Timestamp(s.GetOauthToken().GetExpiresAt())
|
||||
if err == nil && !expiry.IsZero() {
|
||||
expiry = expiry.Add(-s.gracePeriod)
|
||||
if tm.IsZero() || expiry.Before(tm) {
|
||||
tm = expiry
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
expiry, err = ptypes.Timestamp(s.GetIdToken().GetExpiresAt())
|
||||
if err == nil {
|
||||
expiry = expiry.Add(-s.gracePeriod)
|
||||
if tm.IsZero() || expiry.Before(tm) {
|
||||
tm = expiry
|
||||
}
|
||||
}
|
||||
|
||||
expiry, err = ptypes.Timestamp(s.GetExpiresAt())
|
||||
if err == nil {
|
||||
if tm.IsZero() || expiry.Before(tm) {
|
||||
tm = expiry
|
||||
if s.GetExpiresAt() != nil {
|
||||
expiry, err := ptypes.Timestamp(s.GetExpiresAt())
|
||||
if err == nil && !expiry.IsZero() {
|
||||
if tm.IsZero() || expiry.Before(tm) {
|
||||
tm = expiry
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -48,17 +48,10 @@ func TestSession_NextRefresh(t *testing.T) {
|
|||
}
|
||||
assert.Equal(t, tm2.Add(-time.Second*10), s.NextRefresh())
|
||||
|
||||
tm3 := time.Date(2020, 6, 5, 12, 30, 0, 0, time.UTC)
|
||||
tm3 := time.Date(2020, 6, 5, 12, 15, 0, 0, time.UTC)
|
||||
pbtm3, _ := ptypes.TimestampProto(tm3)
|
||||
s.IdToken = &session.IDToken{
|
||||
ExpiresAt: pbtm3,
|
||||
}
|
||||
assert.Equal(t, tm3.Add(-time.Second*10), s.NextRefresh())
|
||||
|
||||
tm4 := time.Date(2020, 6, 5, 12, 15, 0, 0, time.UTC)
|
||||
pbtm4, _ := ptypes.TimestampProto(tm4)
|
||||
s.ExpiresAt = pbtm4
|
||||
assert.Equal(t, tm4, s.NextRefresh())
|
||||
s.ExpiresAt = pbtm3
|
||||
assert.Equal(t, tm3, s.NextRefresh())
|
||||
}
|
||||
|
||||
func TestSession_UnmarshalJSON(t *testing.T) {
|
||||
|
|
|
@ -3,6 +3,7 @@ package manager
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
|
@ -262,15 +263,7 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
|
|||
Str("user_id", userID).
|
||||
Str("session_id", sessionID).
|
||||
Msg("deleting expired session")
|
||||
s.DeletedAt, _ = ptypes.TimestampProto(time.Now())
|
||||
_, err = mgr.sessionClient.Add(ctx, &session.AddRequest{Session: s.Session})
|
||||
if err != nil {
|
||||
mgr.log.Error().Err(err).
|
||||
Str("user_id", s.GetUserId()).
|
||||
Str("session_id", s.GetId()).
|
||||
Msg("failed to delete session")
|
||||
return
|
||||
}
|
||||
mgr.deleteSession(ctx, s.Session)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -283,12 +276,19 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
|
|||
}
|
||||
|
||||
newToken, err := mgr.authenticator.Refresh(ctx, fromOAuthToken(s.OauthToken), &s)
|
||||
if err != nil {
|
||||
if isTemporaryError(err) {
|
||||
mgr.log.Error().Err(err).
|
||||
Str("user_id", s.GetUserId()).
|
||||
Str("session_id", s.GetId()).
|
||||
Msg("failed to refresh oauth2 token")
|
||||
return
|
||||
} else if err != nil {
|
||||
mgr.log.Error().Err(err).
|
||||
Str("user_id", s.GetUserId()).
|
||||
Str("session_id", s.GetId()).
|
||||
Msg("failed to refresh oauth2 token, deleting session")
|
||||
mgr.deleteSession(ctx, s.Session)
|
||||
return
|
||||
}
|
||||
s.OauthToken = toOAuthToken(newToken)
|
||||
|
||||
|
@ -328,11 +328,18 @@ func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
|
|||
}
|
||||
|
||||
err := mgr.authenticator.UpdateUserInfo(ctx, fromOAuthToken(s.OauthToken), &u)
|
||||
if err != nil {
|
||||
if isTemporaryError(err) {
|
||||
mgr.log.Error().Err(err).
|
||||
Str("user_id", s.GetUserId()).
|
||||
Str("session_id", s.GetId()).
|
||||
Msg("failed to update user info")
|
||||
return
|
||||
} else if err != nil {
|
||||
mgr.log.Error().Err(err).
|
||||
Str("user_id", s.GetUserId()).
|
||||
Str("session_id", s.GetId()).
|
||||
Msg("failed to update user info, deleting session")
|
||||
mgr.deleteSession(ctx, s.Session)
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -551,3 +558,26 @@ func (mgr *Manager) createUser(ctx context.Context, pbSession *session.Session)
|
|||
Msg("failed to create user")
|
||||
}
|
||||
}
|
||||
|
||||
func (mgr *Manager) deleteSession(ctx context.Context, pbSession *session.Session) {
|
||||
pbSession.DeletedAt = ptypes.TimestampNow()
|
||||
_, err := mgr.sessionClient.Add(ctx, &session.AddRequest{Session: pbSession})
|
||||
if err != nil {
|
||||
mgr.log.Error().Err(err).
|
||||
Str("session_id", pbSession.GetId()).
|
||||
Msg("failed to delete session")
|
||||
}
|
||||
}
|
||||
|
||||
func isTemporaryError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
|
||||
return true
|
||||
}
|
||||
if e, ok := err.(interface{ Temporary() bool }); ok && e.Temporary() {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue