package sessions // import "github.com/pomerium/pomerium/internal/sessions" import ( "encoding/base64" "encoding/json" "errors" "fmt" "strings" "time" "github.com/pomerium/pomerium/internal/cryptutil" ) const MaxCookieSize = 4096 var ( // ErrLifetimeExpired is an error for the lifetime deadline expiring ErrLifetimeExpired = errors.New("user lifetime expired") ) // SessionState is our object that keeps track of a user's session state type SessionState struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` IDToken string `json:"id_token"` RefreshDeadline time.Time `json:"refresh_deadline"` Email string `json:"email"` User string `json:"user"` Groups []string `json:"groups"` ImpersonateEmail string ImpersonateGroups []string } // RefreshPeriodExpired returns true if the refresh period has expired func (s *SessionState) RefreshPeriodExpired() bool { return isExpired(s.RefreshDeadline) } // Impersonating returns if the request is impersonating. func (s *SessionState) Impersonating() bool { return s.ImpersonateEmail != "" || len(s.ImpersonateGroups) != 0 } // RequestEmail is the email to make the request as. func (s *SessionState) RequestEmail() string { if s.ImpersonateEmail != "" { return s.ImpersonateEmail } return s.Email } // RequestGroups returns the groups of the Groups making the request; uses // impersonating user if set. func (s *SessionState) RequestGroups() string { if len(s.ImpersonateGroups) != 0 { return strings.Join(s.ImpersonateGroups, ",") } return strings.Join(s.Groups, ",") } type idToken struct { Issuer string `json:"iss"` Subject string `json:"sub"` Expiry jsonTime `json:"exp"` IssuedAt jsonTime `json:"iat"` Nonce string `json:"nonce"` AtHash string `json:"at_hash"` } // IssuedAt parses the IDToken's issue date and returns a valid go time.Time. func (s *SessionState) IssuedAt() (time.Time, error) { payload, err := parseJWT(s.IDToken) if err != nil { return time.Time{}, fmt.Errorf("internal/sessions: malformed jwt: %v", err) } var token idToken if err := json.Unmarshal(payload, &token); err != nil { return time.Time{}, fmt.Errorf("internal/sessions: failed to unmarshal claims: %v", err) } return time.Time(token.IssuedAt), nil } func isExpired(t time.Time) bool { return t.Before(time.Now()) } // MarshalSession marshals the session state as JSON, encrypts the JSON using the // given cipher, and base64-encodes the result func MarshalSession(s *SessionState, c cryptutil.Cipher) (string, error) { v, err := c.Marshal(s) if err != nil { return "", err } if len(v) >= MaxCookieSize { return "", fmt.Errorf("session too large, got %d bytes", len(v)) } return v, nil } // UnmarshalSession takes the marshaled string, base64-decodes into a byte slice, decrypts the // byte slice using the passed cipher, and unmarshals the resulting JSON into a session state struct func UnmarshalSession(value string, c cryptutil.Cipher) (*SessionState, error) { s := &SessionState{} err := c.Unmarshal(value, s) if err != nil { return nil, err } return s, nil } // ExtendDeadline returns the time extended by a given duration, truncated by second func ExtendDeadline(ttl time.Duration) time.Time { return time.Now().Add(ttl).Truncate(time.Second) } func parseJWT(p string) ([]byte, error) { parts := strings.Split(p, ".") if len(parts) < 2 { return nil, fmt.Errorf("internal/sessions: malformed jwt, expected 3 parts got %d", len(parts)) } payload, err := base64.RawURLEncoding.DecodeString(parts[1]) if err != nil { return nil, fmt.Errorf("internal/sessions: malformed jwt payload: %v", err) } return payload, nil } type jsonTime time.Time func (j *jsonTime) UnmarshalJSON(b []byte) error { var n json.Number if err := json.Unmarshal(b, &n); err != nil { return err } var unix int64 if t, err := n.Int64(); err == nil { unix = t } else { f, err := n.Float64() if err != nil { return err } unix = int64(f) } *j = jsonTime(time.Unix(unix, 0)) return nil }