From 505ff5cc5c6329fd3d7072af81acf8ab2c86a2be Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Wed, 24 Jun 2020 22:06:17 +0700 Subject: [PATCH] internal/sessions: handle claims "ver" field generally (#990) "ver" field is not specified by RFC 7519, so in practice, most providers return it as string, but okta returns it as number, which cause okta authenticate broken. To fix it, we handle "ver" field more generally, to allow both string and number in json payload. --- authenticate/handlers.go | 2 +- authorize/grpc.go | 4 ++-- internal/sessions/state.go | 33 ++++++++++++++++++++++++++++++++- internal/sessions/state_test.go | 27 +++++++++++++++++++++++++++ 4 files changed, 62 insertions(+), 4 deletions(-) diff --git a/authenticate/handlers.go b/authenticate/handlers.go index 9da49987e..16c6247d1 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -571,7 +571,7 @@ func (a *Authenticate) saveSessionToDataBroker(ctx context.Context, sessionState if err != nil { return fmt.Errorf("authenticate: error saving session: %w", err) } - sessionState.Version = res.GetServerVersion() + sessionState.Version = sessions.Version(res.GetServerVersion()) return nil } diff --git a/authorize/grpc.go b/authorize/grpc.go index 348c8a6d7..72991bb6b 100644 --- a/authorize/grpc.go +++ b/authorize/grpc.go @@ -48,10 +48,10 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v2.CheckRe // only accept sessions whose databroker server versions match if sessionState != nil { a.dataBrokerDataLock.RLock() - if a.dataBrokerSessionServerVersion != sessionState.Version { + if a.dataBrokerSessionServerVersion != sessionState.Version.String() { log.Warn(). Str("server_version", a.dataBrokerSessionServerVersion). - Str("session_version", sessionState.Version). + Str("session_version", sessionState.Version.String()). Msg("clearing session due to invalid version") sessionState = nil } diff --git a/internal/sessions/state.go b/internal/sessions/state.go index f14610d50..64d31d00b 100644 --- a/internal/sessions/state.go +++ b/internal/sessions/state.go @@ -3,6 +3,7 @@ package sessions import ( "encoding/json" "errors" + "fmt" "strings" "time" @@ -17,6 +18,34 @@ var ErrMissingID = errors.New("invalid session: missing id") // timeNow is time.Now but pulled out as a variable for tests. var timeNow = time.Now +// Version represents "ver" field in JWT public claims. +// +// The field is not specified by RFC 7519, so providers can +// return either string or number (like okta). +type Version string + +// String implements fmt.Stringer interface. +func (v *Version) String() string { + return string(*v) +} + +// UnmarshalJSON implements json.Unmarshaler interface. +func (v *Version) UnmarshalJSON(b []byte) error { + var tmp interface{} + if err := json.Unmarshal(b, &tmp); err != nil { + return err + } + switch val := tmp.(type) { + case string: + *v = Version(val) + case float64: + *v = Version(fmt.Sprintf("%g", val)) + default: + return errors.New("invalid type for Version") + } + return nil +} + // State is our object that keeps track of a user's session state type State struct { // Public claim values (as specified in RFC 7519). @@ -27,7 +56,9 @@ type State struct { NotBefore *jwt.NumericDate `json:"nbf,omitempty"` IssuedAt *jwt.NumericDate `json:"iat,omitempty"` ID string `json:"jti,omitempty"` - Version string `json:"ver,omitempty"` + + // "ver" field is not standard, but is supported by most providers. + Version Version `json:"ver,omitempty"` // Azure returns OID which should be used instead of subject. OID string `json:"oid,omitempty"` diff --git a/internal/sessions/state_test.go b/internal/sessions/state_test.go index 3cb807a41..006b186eb 100644 --- a/internal/sessions/state_test.go +++ b/internal/sessions/state_test.go @@ -122,3 +122,30 @@ func TestState_UnmarshalJSON(t *testing.T) { }) } } + +func TestVersion_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + jsonStr string + wantVersion string + wantErr bool + }{ + {"Version is string", `"1"`, "1", false}, + {"Version is integer", `1`, "1", false}, + {"Version is float", `1.1`, "1.1", false}, + {"Invalid version", `[1]`, "", true}, + } + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + var v Version + if err := v.UnmarshalJSON([]byte(tc.jsonStr)); (err != nil) != tc.wantErr { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tc.wantErr) + } + if !tc.wantErr && v.String() != tc.wantVersion { + t.Errorf("mismatch version, want: %s, got: %s", tc.wantVersion, v.String()) + } + }) + } +}