internal/sessions: handle claims "ver" field generally (#990)

"ver" field is not specified by RFC 7519, so in practice, most providers
return it as string, but okta returns it as number, which cause okta
authenticate broken.

To fix it, we handle "ver" field more generally, to allow both string and
number in json payload.
This commit is contained in:
Cuong Manh Le 2020-06-24 22:06:17 +07:00 committed by GitHub
parent 1e3c381e1e
commit 505ff5cc5c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 62 additions and 4 deletions

View file

@ -571,7 +571,7 @@ func (a *Authenticate) saveSessionToDataBroker(ctx context.Context, sessionState
if err != nil { if err != nil {
return fmt.Errorf("authenticate: error saving session: %w", err) return fmt.Errorf("authenticate: error saving session: %w", err)
} }
sessionState.Version = res.GetServerVersion() sessionState.Version = sessions.Version(res.GetServerVersion())
return nil return nil
} }

View file

@ -48,10 +48,10 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v2.CheckRe
// only accept sessions whose databroker server versions match // only accept sessions whose databroker server versions match
if sessionState != nil { if sessionState != nil {
a.dataBrokerDataLock.RLock() a.dataBrokerDataLock.RLock()
if a.dataBrokerSessionServerVersion != sessionState.Version { if a.dataBrokerSessionServerVersion != sessionState.Version.String() {
log.Warn(). log.Warn().
Str("server_version", a.dataBrokerSessionServerVersion). Str("server_version", a.dataBrokerSessionServerVersion).
Str("session_version", sessionState.Version). Str("session_version", sessionState.Version.String()).
Msg("clearing session due to invalid version") Msg("clearing session due to invalid version")
sessionState = nil sessionState = nil
} }

View file

@ -3,6 +3,7 @@ package sessions
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"strings" "strings"
"time" "time"
@ -17,6 +18,34 @@ var ErrMissingID = errors.New("invalid session: missing id")
// timeNow is time.Now but pulled out as a variable for tests. // timeNow is time.Now but pulled out as a variable for tests.
var timeNow = time.Now var timeNow = time.Now
// Version represents "ver" field in JWT public claims.
//
// The field is not specified by RFC 7519, so providers can
// return either string or number (like okta).
type Version string
// String implements fmt.Stringer interface.
func (v *Version) String() string {
return string(*v)
}
// UnmarshalJSON implements json.Unmarshaler interface.
func (v *Version) UnmarshalJSON(b []byte) error {
var tmp interface{}
if err := json.Unmarshal(b, &tmp); err != nil {
return err
}
switch val := tmp.(type) {
case string:
*v = Version(val)
case float64:
*v = Version(fmt.Sprintf("%g", val))
default:
return errors.New("invalid type for Version")
}
return nil
}
// State is our object that keeps track of a user's session state // State is our object that keeps track of a user's session state
type State struct { type State struct {
// Public claim values (as specified in RFC 7519). // Public claim values (as specified in RFC 7519).
@ -27,7 +56,9 @@ type State struct {
NotBefore *jwt.NumericDate `json:"nbf,omitempty"` NotBefore *jwt.NumericDate `json:"nbf,omitempty"`
IssuedAt *jwt.NumericDate `json:"iat,omitempty"` IssuedAt *jwt.NumericDate `json:"iat,omitempty"`
ID string `json:"jti,omitempty"` ID string `json:"jti,omitempty"`
Version string `json:"ver,omitempty"`
// "ver" field is not standard, but is supported by most providers.
Version Version `json:"ver,omitempty"`
// Azure returns OID which should be used instead of subject. // Azure returns OID which should be used instead of subject.
OID string `json:"oid,omitempty"` OID string `json:"oid,omitempty"`

View file

@ -122,3 +122,30 @@ func TestState_UnmarshalJSON(t *testing.T) {
}) })
} }
} }
func TestVersion_UnmarshalJSON(t *testing.T) {
tests := []struct {
name string
jsonStr string
wantVersion string
wantErr bool
}{
{"Version is string", `"1"`, "1", false},
{"Version is integer", `1`, "1", false},
{"Version is float", `1.1`, "1.1", false},
{"Invalid version", `[1]`, "", true},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
var v Version
if err := v.UnmarshalJSON([]byte(tc.jsonStr)); (err != nil) != tc.wantErr {
t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tc.wantErr)
}
if !tc.wantErr && v.String() != tc.wantVersion {
t.Errorf("mismatch version, want: %s, got: %s", tc.wantVersion, v.String())
}
})
}
}