mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 10:26:29 +02:00
internal/sessions: handle claims "ver" field generally (#990)
"ver" field is not specified by RFC 7519, so in practice, most providers return it as string, but okta returns it as number, which cause okta authenticate broken. To fix it, we handle "ver" field more generally, to allow both string and number in json payload.
This commit is contained in:
parent
1e3c381e1e
commit
505ff5cc5c
4 changed files with 62 additions and 4 deletions
|
@ -571,7 +571,7 @@ func (a *Authenticate) saveSessionToDataBroker(ctx context.Context, sessionState
|
|||
if err != nil {
|
||||
return fmt.Errorf("authenticate: error saving session: %w", err)
|
||||
}
|
||||
sessionState.Version = res.GetServerVersion()
|
||||
sessionState.Version = sessions.Version(res.GetServerVersion())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -48,10 +48,10 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v2.CheckRe
|
|||
// only accept sessions whose databroker server versions match
|
||||
if sessionState != nil {
|
||||
a.dataBrokerDataLock.RLock()
|
||||
if a.dataBrokerSessionServerVersion != sessionState.Version {
|
||||
if a.dataBrokerSessionServerVersion != sessionState.Version.String() {
|
||||
log.Warn().
|
||||
Str("server_version", a.dataBrokerSessionServerVersion).
|
||||
Str("session_version", sessionState.Version).
|
||||
Str("session_version", sessionState.Version.String()).
|
||||
Msg("clearing session due to invalid version")
|
||||
sessionState = nil
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package sessions
|
|||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -17,6 +18,34 @@ var ErrMissingID = errors.New("invalid session: missing id")
|
|||
// timeNow is time.Now but pulled out as a variable for tests.
|
||||
var timeNow = time.Now
|
||||
|
||||
// Version represents "ver" field in JWT public claims.
|
||||
//
|
||||
// The field is not specified by RFC 7519, so providers can
|
||||
// return either string or number (like okta).
|
||||
type Version string
|
||||
|
||||
// String implements fmt.Stringer interface.
|
||||
func (v *Version) String() string {
|
||||
return string(*v)
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler interface.
|
||||
func (v *Version) UnmarshalJSON(b []byte) error {
|
||||
var tmp interface{}
|
||||
if err := json.Unmarshal(b, &tmp); err != nil {
|
||||
return err
|
||||
}
|
||||
switch val := tmp.(type) {
|
||||
case string:
|
||||
*v = Version(val)
|
||||
case float64:
|
||||
*v = Version(fmt.Sprintf("%g", val))
|
||||
default:
|
||||
return errors.New("invalid type for Version")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// State is our object that keeps track of a user's session state
|
||||
type State struct {
|
||||
// Public claim values (as specified in RFC 7519).
|
||||
|
@ -27,7 +56,9 @@ type State struct {
|
|||
NotBefore *jwt.NumericDate `json:"nbf,omitempty"`
|
||||
IssuedAt *jwt.NumericDate `json:"iat,omitempty"`
|
||||
ID string `json:"jti,omitempty"`
|
||||
Version string `json:"ver,omitempty"`
|
||||
|
||||
// "ver" field is not standard, but is supported by most providers.
|
||||
Version Version `json:"ver,omitempty"`
|
||||
|
||||
// Azure returns OID which should be used instead of subject.
|
||||
OID string `json:"oid,omitempty"`
|
||||
|
|
|
@ -122,3 +122,30 @@ func TestState_UnmarshalJSON(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVersion_UnmarshalJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
jsonStr string
|
||||
wantVersion string
|
||||
wantErr bool
|
||||
}{
|
||||
{"Version is string", `"1"`, "1", false},
|
||||
{"Version is integer", `1`, "1", false},
|
||||
{"Version is float", `1.1`, "1.1", false},
|
||||
{"Invalid version", `[1]`, "", true},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var v Version
|
||||
if err := v.UnmarshalJSON([]byte(tc.jsonStr)); (err != nil) != tc.wantErr {
|
||||
t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tc.wantErr)
|
||||
}
|
||||
if !tc.wantErr && v.String() != tc.wantVersion {
|
||||
t.Errorf("mismatch version, want: %s, got: %s", tc.wantVersion, v.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue