diff --git a/authenticate/handlers.go b/authenticate/handlers.go index 9da49987e..16c6247d1 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -571,7 +571,7 @@ func (a *Authenticate) saveSessionToDataBroker(ctx context.Context, sessionState if err != nil { return fmt.Errorf("authenticate: error saving session: %w", err) } - sessionState.Version = res.GetServerVersion() + sessionState.Version = sessions.Version(res.GetServerVersion()) return nil } diff --git a/authorize/grpc.go b/authorize/grpc.go index 348c8a6d7..72991bb6b 100644 --- a/authorize/grpc.go +++ b/authorize/grpc.go @@ -48,10 +48,10 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v2.CheckRe // only accept sessions whose databroker server versions match if sessionState != nil { a.dataBrokerDataLock.RLock() - if a.dataBrokerSessionServerVersion != sessionState.Version { + if a.dataBrokerSessionServerVersion != sessionState.Version.String() { log.Warn(). Str("server_version", a.dataBrokerSessionServerVersion). - Str("session_version", sessionState.Version). + Str("session_version", sessionState.Version.String()). Msg("clearing session due to invalid version") sessionState = nil } diff --git a/internal/sessions/state.go b/internal/sessions/state.go index f14610d50..64d31d00b 100644 --- a/internal/sessions/state.go +++ b/internal/sessions/state.go @@ -3,6 +3,7 @@ package sessions import ( "encoding/json" "errors" + "fmt" "strings" "time" @@ -17,6 +18,34 @@ var ErrMissingID = errors.New("invalid session: missing id") // timeNow is time.Now but pulled out as a variable for tests. var timeNow = time.Now +// Version represents "ver" field in JWT public claims. +// +// The field is not specified by RFC 7519, so providers can +// return either string or number (like okta). +type Version string + +// String implements fmt.Stringer interface. +func (v *Version) String() string { + return string(*v) +} + +// UnmarshalJSON implements json.Unmarshaler interface. +func (v *Version) UnmarshalJSON(b []byte) error { + var tmp interface{} + if err := json.Unmarshal(b, &tmp); err != nil { + return err + } + switch val := tmp.(type) { + case string: + *v = Version(val) + case float64: + *v = Version(fmt.Sprintf("%g", val)) + default: + return errors.New("invalid type for Version") + } + return nil +} + // State is our object that keeps track of a user's session state type State struct { // Public claim values (as specified in RFC 7519). @@ -27,7 +56,9 @@ type State struct { NotBefore *jwt.NumericDate `json:"nbf,omitempty"` IssuedAt *jwt.NumericDate `json:"iat,omitempty"` ID string `json:"jti,omitempty"` - Version string `json:"ver,omitempty"` + + // "ver" field is not standard, but is supported by most providers. + Version Version `json:"ver,omitempty"` // Azure returns OID which should be used instead of subject. OID string `json:"oid,omitempty"` diff --git a/internal/sessions/state_test.go b/internal/sessions/state_test.go index 3cb807a41..006b186eb 100644 --- a/internal/sessions/state_test.go +++ b/internal/sessions/state_test.go @@ -122,3 +122,30 @@ func TestState_UnmarshalJSON(t *testing.T) { }) } } + +func TestVersion_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + jsonStr string + wantVersion string + wantErr bool + }{ + {"Version is string", `"1"`, "1", false}, + {"Version is integer", `1`, "1", false}, + {"Version is float", `1.1`, "1.1", false}, + {"Invalid version", `[1]`, "", true}, + } + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + var v Version + if err := v.UnmarshalJSON([]byte(tc.jsonStr)); (err != nil) != tc.wantErr { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tc.wantErr) + } + if !tc.wantErr && v.String() != tc.wantVersion { + t.Errorf("mismatch version, want: %s, got: %s", tc.wantVersion, v.String()) + } + }) + } +}