package sessions // import "github.com/pomerium/pomerium/internal/sessions" import ( "encoding/base64" "encoding/json" "fmt" "strings" "time" "github.com/pomerium/pomerium/internal/cryptutil" ) // State is our object that keeps track of a user's session state type State struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` IDToken string `json:"id_token"` RefreshDeadline time.Time `json:"refresh_deadline"` Email string `json:"email"` User string `json:"user"` Groups []string `json:"groups"` ImpersonateEmail string ImpersonateGroups []string } // Valid returns an error if the users's session state is not valid. func (s *State) Valid() error { if s.Expired() { return ErrExpired } return nil } // ForceRefresh sets the refresh deadline to now. func (s *State) ForceRefresh() { s.RefreshDeadline = time.Now().Truncate(time.Second) } // Expired returns true if the refresh period has expired func (s *State) Expired() bool { return s.RefreshDeadline.Before(time.Now()) } // Impersonating returns if the request is impersonating. func (s *State) Impersonating() bool { return s.ImpersonateEmail != "" || len(s.ImpersonateGroups) != 0 } // RequestEmail is the email to make the request as. func (s *State) RequestEmail() string { if s.ImpersonateEmail != "" { return s.ImpersonateEmail } return s.Email } // RequestGroups returns the groups of the Groups making the request; uses // impersonating user if set. func (s *State) RequestGroups() string { if len(s.ImpersonateGroups) != 0 { return strings.Join(s.ImpersonateGroups, ",") } return strings.Join(s.Groups, ",") } type idToken struct { Issuer string `json:"iss"` Subject string `json:"sub"` Expiry jsonTime `json:"exp"` IssuedAt jsonTime `json:"iat"` Nonce string `json:"nonce"` AtHash string `json:"at_hash"` } // IssuedAt parses the IDToken's issue date and returns a valid go time.Time. func (s *State) IssuedAt() (time.Time, error) { payload, err := parseJWT(s.IDToken) if err != nil { return time.Time{}, fmt.Errorf("internal/sessions: malformed jwt: %v", err) } var token idToken if err := json.Unmarshal(payload, &token); err != nil { return time.Time{}, fmt.Errorf("internal/sessions: failed to unmarshal claims: %v", err) } return time.Time(token.IssuedAt), nil } // MarshalSession marshals the session state as JSON, encrypts the JSON using the // given cipher, and base64-encodes the result func MarshalSession(s *State, c cryptutil.Cipher) (string, error) { v, err := c.Marshal(s) if err != nil { return "", err } return v, nil } // UnmarshalSession takes the marshaled string, base64-decodes into a byte slice, decrypts the // byte slice using the passed cipher, and unmarshals the resulting JSON into a session state struct func UnmarshalSession(value string, c cryptutil.Cipher) (*State, error) { s := &State{} err := c.Unmarshal(value, s) if err != nil { return nil, err } return s, nil } func parseJWT(p string) ([]byte, error) { parts := strings.Split(p, ".") if len(parts) < 2 { return nil, fmt.Errorf("internal/sessions: malformed jwt, expected 3 parts got %d", len(parts)) } payload, err := base64.RawURLEncoding.DecodeString(parts[1]) if err != nil { return nil, fmt.Errorf("internal/sessions: malformed jwt payload: %v", err) } return payload, nil } type jsonTime time.Time func (j *jsonTime) UnmarshalJSON(b []byte) error { var n json.Number if err := json.Unmarshal(b, &n); err != nil { return err } var unix int64 if t, err := n.Int64(); err == nil { unix = t } else { f, err := n.Float64() if err != nil { return err } unix = int64(f) } *j = jsonTime(time.Unix(unix, 0)) return nil }