mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 10:26:29 +02:00
The "ver" field is not specified by RFC 7519, so in practice most providers return it as a string, but Okta returns it as a number, which broke Okta authentication. To fix this, we handle the "ver" field more generally, allowing either a string or a number in the JSON payload.
133 lines
3.5 KiB
Go
133 lines
3.5 KiB
Go
package sessions
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"gopkg.in/square/go-jose.v2/jwt"
|
|
|
|
"github.com/pomerium/pomerium/internal/grpc/databroker"
|
|
)
|
|
|
|
var (
	// ErrMissingID is the error reported for a session state whose ID
	// ("jti" claim) is empty.
	ErrMissingID = errors.New("invalid session: missing id")

	// timeNow is the clock used by this package. It is a variable rather
	// than a direct call to time.Now so tests can substitute a fixed clock.
	timeNow = time.Now
)
|
|
|
|
// Version holds the value of the "ver" claim of a JWT.
//
// RFC 7519 does not register a "ver" claim, so identity providers disagree on
// its JSON type: most send a string, while some (Okta, for example) send a
// number. Version therefore accepts either form when decoded from JSON.
type Version string

// String implements fmt.Stringer, returning the version text unchanged.
func (v *Version) String() string {
	text := string(*v)
	return text
}
|
|
|
|
// UnmarshalJSON implements json.Unmarshaler interface.
|
|
func (v *Version) UnmarshalJSON(b []byte) error {
|
|
var tmp interface{}
|
|
if err := json.Unmarshal(b, &tmp); err != nil {
|
|
return err
|
|
}
|
|
switch val := tmp.(type) {
|
|
case string:
|
|
*v = Version(val)
|
|
case float64:
|
|
*v = Version(fmt.Sprintf("%g", val))
|
|
default:
|
|
return errors.New("invalid type for Version")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// State is our object that keeps track of a user's session state
|
|
type State struct {
|
|
// Public claim values (as specified in RFC 7519).
|
|
Issuer string `json:"iss,omitempty"`
|
|
Subject string `json:"sub,omitempty"`
|
|
Audience jwt.Audience `json:"aud,omitempty"`
|
|
Expiry *jwt.NumericDate `json:"exp,omitempty"`
|
|
NotBefore *jwt.NumericDate `json:"nbf,omitempty"`
|
|
IssuedAt *jwt.NumericDate `json:"iat,omitempty"`
|
|
ID string `json:"jti,omitempty"`
|
|
|
|
// "ver" field is not standard, but is supported by most providers.
|
|
Version Version `json:"ver,omitempty"`
|
|
|
|
// Azure returns OID which should be used instead of subject.
|
|
OID string `json:"oid,omitempty"`
|
|
|
|
// Impersonate-able fields
|
|
ImpersonateEmail string `json:"impersonate_email,omitempty"`
|
|
ImpersonateGroups []string `json:"impersonate_groups,omitempty"`
|
|
|
|
// Programmatic whether this state is used for machine-to-machine
|
|
// programatic access.
|
|
Programmatic bool `json:"programatic"`
|
|
}
|
|
|
|
// NewSession updates issuer, audience, and issuance timestamps but keeps
|
|
// parent expiry.
|
|
func NewSession(s *State, issuer string, audience []string) State {
|
|
newState := *s
|
|
newState.IssuedAt = jwt.NewNumericDate(timeNow())
|
|
newState.NotBefore = newState.IssuedAt
|
|
newState.Audience = audience
|
|
newState.Issuer = issuer
|
|
return newState
|
|
}
|
|
|
|
// IsExpired returns true if the users's session is expired.
|
|
func (s *State) IsExpired() bool {
|
|
return s.Expiry != nil && timeNow().After(s.Expiry.Time())
|
|
}
|
|
|
|
// Impersonating returns if the request is impersonating.
|
|
func (s *State) Impersonating() bool {
|
|
return s.ImpersonateEmail != "" || len(s.ImpersonateGroups) != 0
|
|
}
|
|
|
|
// UserID returns the corresponding user ID for a session.
|
|
func (s *State) UserID(provider string) string {
|
|
if s.OID != "" {
|
|
return databroker.GetUserID(provider, s.OID)
|
|
}
|
|
return databroker.GetUserID(provider, s.Subject)
|
|
}
|
|
|
|
// SetImpersonation sets impersonation user and groups.
|
|
func (s *State) SetImpersonation(email, groups string) {
|
|
s.ImpersonateEmail = email
|
|
if groups == "" {
|
|
s.ImpersonateGroups = nil
|
|
} else {
|
|
s.ImpersonateGroups = strings.Split(groups, ",")
|
|
}
|
|
}
|
|
|
|
// UnmarshalJSON returns a State struct from JSON. Additionally munges
|
|
// a user's session by using by setting `user` claim to `sub` if empty.
|
|
func (s *State) UnmarshalJSON(data []byte) error {
|
|
type StateAlias State
|
|
a := &struct {
|
|
*StateAlias
|
|
}{
|
|
StateAlias: (*StateAlias)(s),
|
|
}
|
|
|
|
if err := json.Unmarshal(data, &a); err != nil {
|
|
return err
|
|
}
|
|
|
|
if s.ID == "" {
|
|
return ErrMissingID
|
|
}
|
|
|
|
return nil
|
|
}
|