pomerium/internal/sessions/state.go
2021-06-10 09:35:44 -06:00

111 lines
2.9 KiB
Go

package sessions
import (
"encoding/json"
"errors"
"fmt"
"time"
"github.com/go-jose/go-jose/v3/jwt"
)
var (
	// ErrMissingID is returned when a session state is decoded without an ID
	// (the "jti" claim).
	ErrMissingID = errors.New("invalid session: missing id")

	// timeNow is the package clock. It is a variable rather than a direct
	// time.Now call so tests can swap in a fixed time.
	timeNow = time.Now
)
// Version represents "ver" field in JWT public claims.
//
// The field is not specified by RFC 7519, so providers can
// return either string or number (like okta).
type Version string

// String implements fmt.Stringer interface.
//
// It uses a pointer receiver, matching UnmarshalJSON, so it is *Version
// (not Version) that satisfies fmt.Stringer.
func (v *Version) String() string {
	return string(*v)
}

// UnmarshalJSON implements json.Unmarshaler interface.
//
// How each JSON type is handled:
//   - A string is stored verbatim.
//   - A number is stored as decimal text. Integral values have no exponent
//     (1 -> "1", 1000000 -> "1000000"); other values use %g (1.5 -> "1.5").
//   - null leaves v unchanged, following the encoding/json convention for
//     Unmarshalers, so a `"ver": null` claim does not fail the whole decode.
//   - Any other type (bool, array, object) yields an error.
func (v *Version) UnmarshalJSON(b []byte) error {
	var tmp interface{}
	if err := json.Unmarshal(b, &tmp); err != nil {
		return err
	}
	switch val := tmp.(type) {
	case nil:
		// JSON null: no-op.
	case string:
		*v = Version(val)
	case float64:
		// %g alone switches to exponent form for large values ("1e+06"),
		// which is not how a numeric version should read. The bounds keep the
		// int64 conversion in range, because out-of-range float-to-int
		// conversions are implementation-specific in Go.
		if val >= -(1<<63) && val < 1<<63 && val == float64(int64(val)) {
			*v = Version(fmt.Sprintf("%d", int64(val)))
		} else {
			*v = Version(fmt.Sprintf("%g", val))
		}
	default:
		return errors.New("invalid type for Version")
	}
	return nil
}
// State is our object that keeps track of a user's session state.
//
// The field order and JSON tags define how a session is encoded, so change
// them with care.
type State struct {
	// Public claim values (as specified in RFC 7519).
	Issuer    string           `json:"iss,omitempty"`
	Subject   string           `json:"sub,omitempty"`
	Audience  jwt.Audience     `json:"aud,omitempty"`
	Expiry    *jwt.NumericDate `json:"exp,omitempty"`
	NotBefore *jwt.NumericDate `json:"nbf,omitempty"`
	IssuedAt  *jwt.NumericDate `json:"iat,omitempty"`
	// ID is the "jti" claim. State.UnmarshalJSON rejects input where it is
	// empty with ErrMissingID.
	ID string `json:"jti,omitempty"`
	// "ver" field is not standard, but is supported by most providers.
	Version Version `json:"ver,omitempty"`
	// Azure returns OID which should be used instead of subject; UserID
	// prefers it over Subject when it is non-empty.
	OID string `json:"oid,omitempty"`
	// Programmatic indicates whether this state is used for machine-to-machine
	// programmatic access. Its tag has no omitempty, so it is always encoded.
	Programmatic bool `json:"programmatic"`
}
// NewSession returns a copy of s with a new issuer and audience, and with
// both IssuedAt and NotBefore set to the current time. Every other field,
// including Expiry, is carried over from the parent state unchanged.
//
// The copy is shallow: pointer and slice fields not overwritten here still
// refer to the parent's values, and IssuedAt and NotBefore share one pointer.
func NewSession(s *State, issuer string, audience []string) State {
	issued := jwt.NewNumericDate(timeNow())

	child := *s
	child.Issuer = issuer
	child.Audience = audience
	child.IssuedAt = issued
	child.NotBefore = issued
	return child
}
// IsExpired reports whether the user's session is past its expiry time.
// A state with no Expiry never expires.
func (s *State) IsExpired() bool {
	if s.Expiry == nil {
		return false
	}
	return timeNow().After(s.Expiry.Time())
}
// UserID returns the corresponding user ID for a session: the OID claim
// when it is set (as Azure provides), otherwise the subject.
//
// The provider argument is currently not consulted.
func (s *State) UserID(provider string) string {
	switch {
	case s.OID != "":
		return s.OID
	default:
		return s.Subject
	}
}
// UnmarshalJSON decodes a State from JSON and returns ErrMissingID if the
// decoded state has no ID (the "jti" claim). In that case the other decoded
// fields have still been written into s.
//
// Decoding goes through a method-less alias of State, so json.Unmarshal
// applies the default struct decoding instead of recursing into this method.
func (s *State) UnmarshalJSON(data []byte) error {
	type stateAlias State
	if err := json.Unmarshal(data, (*stateAlias)(s)); err != nil {
		return err
	}
	if s.ID == "" {
		return ErrMissingID
	}
	return nil
}