mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-01 11:26:29 +02:00
145 lines
3.9 KiB
Go
145 lines
3.9 KiB
Go
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
|
|
|
import (
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
|
)
|
|
|
|
// ErrLifetimeExpired is the sentinel error for a user lifetime deadline that
// has passed. Compare against it with errors.Is.
var ErrLifetimeExpired = errors.New("user lifetime expired")
|
|
|
|
// SessionState holds everything tracked for a single user session: the raw
// tokens, the refresh deadline, the user's identity, and optional
// impersonation overrides.
type SessionState struct {
	// AccessToken, RefreshToken and IDToken hold raw token strings.
	// IDToken is decoded by IssuedAt.
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token"`
	IDToken      string `json:"id_token"`

	// RefreshDeadline is the instant checked by RefreshPeriodExpired.
	RefreshDeadline time.Time `json:"refresh_deadline"`

	// Email, User and Groups describe the session's own identity. Email and
	// Groups are what RequestEmail / RequestGroups return when no
	// impersonation override is set.
	Email  string   `json:"email"`
	User   string   `json:"user"`
	Groups []string `json:"groups"`

	// ImpersonateEmail and ImpersonateGroups, when non-empty, take precedence
	// over Email and Groups in RequestEmail / RequestGroups.
	//
	// NOTE(review): these fields have no json tags, so encoding/json uses the
	// Go field names ("ImpersonateEmail", "ImpersonateGroups") as keys.
	// Adding tags now would change the serialized form of existing sessions.
	ImpersonateEmail  string
	ImpersonateGroups []string
}
|
|
|
|
// RefreshPeriodExpired returns true if the refresh period has expired
|
|
func (s *SessionState) RefreshPeriodExpired() bool {
|
|
return isExpired(s.RefreshDeadline)
|
|
}
|
|
|
|
// Impersonating returns if the request is impersonating.
|
|
func (s *SessionState) Impersonating() bool {
|
|
return s.ImpersonateEmail != "" || len(s.ImpersonateGroups) != 0
|
|
}
|
|
|
|
// RequestEmail is the email to make the request as.
|
|
func (s *SessionState) RequestEmail() string {
|
|
if s.ImpersonateEmail != "" {
|
|
return s.ImpersonateEmail
|
|
}
|
|
return s.Email
|
|
}
|
|
|
|
// RequestGroups returns the groups of the Groups making the request; uses
|
|
// impersonating user if set.
|
|
func (s *SessionState) RequestGroups() string {
|
|
if len(s.ImpersonateGroups) != 0 {
|
|
return strings.Join(s.ImpersonateGroups, ",")
|
|
}
|
|
return strings.Join(s.Groups, ",")
|
|
}
|
|
|
|
type idToken struct {
|
|
Issuer string `json:"iss"`
|
|
Subject string `json:"sub"`
|
|
Expiry jsonTime `json:"exp"`
|
|
IssuedAt jsonTime `json:"iat"`
|
|
Nonce string `json:"nonce"`
|
|
AtHash string `json:"at_hash"`
|
|
}
|
|
|
|
// IssuedAt parses the IDToken's issue date and returns a valid go time.Time.
|
|
func (s *SessionState) IssuedAt() (time.Time, error) {
|
|
payload, err := parseJWT(s.IDToken)
|
|
if err != nil {
|
|
return time.Time{}, fmt.Errorf("internal/sessions: malformed jwt: %v", err)
|
|
}
|
|
var token idToken
|
|
if err := json.Unmarshal(payload, &token); err != nil {
|
|
return time.Time{}, fmt.Errorf("internal/sessions: failed to unmarshal claims: %v", err)
|
|
}
|
|
return time.Time(token.IssuedAt), nil
|
|
}
|
|
|
|
// isExpired reports whether t lies strictly before the current time.
func isExpired(t time.Time) bool {
	now := time.Now()
	return now.After(t)
}
|
|
|
|
// MarshalSession marshals the session state as JSON, encrypts the JSON using the
|
|
// given cipher, and base64-encodes the result
|
|
func MarshalSession(s *SessionState, c cryptutil.Cipher) (string, error) {
|
|
v, err := c.Marshal(s)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return v, nil
|
|
}
|
|
|
|
// UnmarshalSession takes the marshaled string, base64-decodes into a byte slice, decrypts the
|
|
// byte slice using the passed cipher, and unmarshals the resulting JSON into a session state struct
|
|
func UnmarshalSession(value string, c cryptutil.Cipher) (*SessionState, error) {
|
|
s := &SessionState{}
|
|
err := c.Unmarshal(value, s)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return s, nil
|
|
}
|
|
|
|
// ExtendDeadline returns the current time plus ttl, rounded down to a whole
// second.
func ExtendDeadline(ttl time.Duration) time.Time {
	deadline := time.Now().Add(ttl)
	return deadline.Truncate(time.Second)
}
|
|
|
|
// parseJWT returns the base64url-decoded payload (the second segment) of a
// compact-serialized JWT of the form header.payload.signature.
//
// Only the payload is decoded. The header is ignored and the signature is NOT
// verified, so the result must not be treated as authenticated data.
// An error is returned when p does not have exactly three dot-separated
// segments, or when the payload is not valid unpadded base64url.
func parseJWT(p string) ([]byte, error) {
	parts := strings.Split(p, ".")
	// A JWS compact token always has exactly three segments. Requiring that
	// keeps the check consistent with the error message and rejects
	// two-segment strings and five-segment JWE tokens up front.
	if len(parts) != 3 {
		return nil, fmt.Errorf("internal/sessions: malformed jwt, expected 3 parts got %d", len(parts))
	}
	payload, err := base64.RawURLEncoding.DecodeString(parts[1])
	if err != nil {
		return nil, fmt.Errorf("internal/sessions: malformed jwt payload: %w", err)
	}
	return payload, nil
}
|
|
|
|
// jsonTime is a time.Time that unmarshals from a JSON number of seconds since
// the Unix epoch. It is used for the Expiry and IssuedAt fields of idToken.
type jsonTime time.Time

// UnmarshalJSON decodes a JSON number of Unix seconds into j. Both integer and
// floating-point literals are accepted; fractional seconds are truncated
// toward zero.
//
// A JSON null leaves j unchanged. This follows the encoding/json convention
// that Unmarshalers treat null as a no-op, instead of failing to parse an
// empty number.
func (j *jsonTime) UnmarshalJSON(b []byte) error {
	if string(b) == "null" {
		return nil
	}
	var n json.Number
	if err := json.Unmarshal(b, &n); err != nil {
		return err
	}
	unix, err := n.Int64()
	if err != nil {
		// Not representable as an integer literal (e.g. "1.5e9" or
		// "1500000000.25"); fall back to float parsing and truncate.
		f, ferr := n.Float64()
		if ferr != nil {
			return ferr
		}
		unix = int64(f)
	}
	*j = jsonTime(time.Unix(unix, 0))
	return nil
}
|