pomerium/config/session.go

460 lines
14 KiB
Go

package config
import (
"context"
"errors"
"fmt"
"net/http"
"strings"
"time"
"github.com/google/uuid"
"golang.org/x/sync/singleflight"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/encoding/jws"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/jwtutil"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/sessions/cookie"
"github.com/pomerium/pomerium/internal/sessions/header"
"github.com/pomerium/pomerium/internal/sessions/queryparam"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/authenticateapi"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
identitypb "github.com/pomerium/pomerium/pkg/grpc/identity"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/grpcutil"
"github.com/pomerium/pomerium/pkg/identity"
"github.com/pomerium/pomerium/pkg/protoutil"
"github.com/pomerium/pomerium/pkg/storage"
)
// A SessionStore saves and loads sessions based on the options.
type SessionStore struct {
store sessions.SessionStore
loader sessions.SessionLoader
options *Options
encoder encoding.MarshalUnmarshaler
}
var _ sessions.SessionStore = (*SessionStore)(nil)
// NewSessionStore creates a new SessionStore from the Options.
func NewSessionStore(options *Options) (*SessionStore, error) {
store := &SessionStore{
options: options,
}
sharedKey, err := options.GetSharedKey()
if err != nil {
return nil, fmt.Errorf("config/sessions: shared_key is required: %w", err)
}
store.encoder, err = jws.NewHS256Signer(sharedKey)
if err != nil {
return nil, fmt.Errorf("config/sessions: invalid session encoder: %w", err)
}
store.store, err = cookie.NewStore(func() cookie.Options {
return cookie.Options{
Name: options.CookieName,
Domain: options.CookieDomain,
Secure: true,
HTTPOnly: options.CookieHTTPOnly,
Expire: options.CookieExpire,
SameSite: options.GetCookieSameSite(),
}
}, store.encoder)
if err != nil {
return nil, err
}
headerStore := header.NewStore(store.encoder)
queryParamStore := queryparam.NewStore(store.encoder, urlutil.QuerySession)
store.loader = sessions.MultiSessionLoader(store.store, headerStore, queryParamStore)
return store, nil
}
// ClearSession clears the session.
func (store *SessionStore) ClearSession(w http.ResponseWriter, r *http.Request) {
store.store.ClearSession(w, r)
}
// LoadSession loads the session.
func (store *SessionStore) LoadSession(r *http.Request) (string, error) {
return store.loader.LoadSession(r)
}
// LoadSessionState loads the session state from a request.
func (store *SessionStore) LoadSessionState(r *http.Request) (*sessions.State, error) {
rawJWT, err := store.loader.LoadSession(r)
if err != nil {
return nil, err
}
var state sessions.State
err = store.encoder.Unmarshal([]byte(rawJWT), &state)
if err != nil {
return nil, err
}
return &state, nil
}
// LoadSessionStateAndCheckIDP loads the session state from a request and checks that the idp id matches.
func (store *SessionStore) LoadSessionStateAndCheckIDP(r *http.Request) (*sessions.State, error) {
state, err := store.LoadSessionState(r)
if err != nil {
return nil, err
}
// confirm that the identity provider id matches the state
if state.IdentityProviderID != "" {
idp, err := store.options.GetIdentityProviderForRequestURL(urlutil.GetAbsoluteURL(r).String())
if err != nil {
return nil, err
}
if idp.GetId() != state.IdentityProviderID {
return nil, fmt.Errorf("unexpected session state identity provider id: %s != %s",
idp.GetId(), state.IdentityProviderID)
}
}
return state, nil
}
// SaveSession saves the session.
func (store *SessionStore) SaveSession(w http.ResponseWriter, r *http.Request, v any) error {
return store.store.SaveSession(w, r, v)
}
type IncomingIDPTokenSessionCreator interface {
CreateSession(ctx context.Context, cfg *Config, policy *Policy, r *http.Request) (*session.Session, error)
}
type incomingIDPTokenSessionCreator struct {
timeNow func() time.Time
getRecord func(ctx context.Context, recordType, recordID string) (*databroker.Record, error)
putRecords func(ctx context.Context, records []*databroker.Record) error
singleflight singleflight.Group
}
func NewIncomingIDPTokenSessionCreator(
getRecord func(ctx context.Context, recordType, recordID string) (*databroker.Record, error),
putRecords func(ctx context.Context, records []*databroker.Record) error,
) IncomingIDPTokenSessionCreator {
return &incomingIDPTokenSessionCreator{timeNow: time.Now, getRecord: getRecord, putRecords: putRecords}
}
// CreateSession attempts to create a session for incoming idp access and
// identity tokens. If no access or identity token is passed ErrNoSessionFound will be returned.
// If the tokens are not valid an error will be returned.
func (c *incomingIDPTokenSessionCreator) CreateSession(
ctx context.Context,
cfg *Config,
policy *Policy,
r *http.Request,
) (session *session.Session, err error) {
if rawAccessToken, ok := cfg.GetIncomingIDPAccessTokenForPolicy(policy, r); ok {
return c.createSessionAccessToken(ctx, cfg, policy, rawAccessToken)
}
if rawIdentityToken, ok := cfg.GetIncomingIDPIdentityTokenForPolicy(policy, r); ok {
return c.createSessionForIdentityToken(ctx, cfg, policy, rawIdentityToken)
}
return nil, sessions.ErrNoSessionFound
}
func (c *incomingIDPTokenSessionCreator) createSessionAccessToken(
ctx context.Context,
cfg *Config,
policy *Policy,
rawAccessToken string,
) (*session.Session, error) {
idp, err := cfg.Options.GetIdentityProviderForPolicy(policy)
if err != nil {
return nil, fmt.Errorf("error getting identity provider to verify access token: %w", err)
}
sessionID := getAccessTokenSessionID(idp, rawAccessToken)
res, err, _ := c.singleflight.Do(sessionID, func() (any, error) {
s, err := c.getSession(ctx, sessionID)
if err == nil {
return s, nil
} else if !storage.IsNotFound(err) {
return nil, err
}
authenticateURL, transport, err := cfg.resolveAuthenticateURL()
if err != nil {
return nil, fmt.Errorf("error resolving authenticate url to verify access token: %w", err)
}
res, err := authenticateapi.New(authenticateURL, transport).VerifyAccessToken(ctx, &authenticateapi.VerifyAccessTokenRequest{
AccessToken: rawAccessToken,
IdentityProviderID: idp.GetId(),
})
if err != nil {
return nil, fmt.Errorf("error verifying access token: %w", err)
} else if !res.Valid {
return nil, fmt.Errorf("%w: invalid access token", sessions.ErrInvalidSession)
}
s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims)
s.OauthToken = &session.OAuthToken{
TokenType: "Bearer",
AccessToken: rawAccessToken,
ExpiresAt: s.ExpiresAt,
}
u, err := c.getUser(ctx, s.GetUserId())
if storage.IsNotFound(err) {
u = &user.User{Id: s.GetUserId()}
} else if err != nil {
return nil, fmt.Errorf("error retrieving existing user: %w", err)
}
c.fillUserFromIDPClaims(u, res.Claims)
err = c.putSessionAndUser(ctx, s, u)
if err != nil {
return nil, fmt.Errorf("error saving session and user: %w", err)
}
return s, nil
})
if err != nil {
return nil, err
}
return res.(*session.Session), nil
}
func (c *incomingIDPTokenSessionCreator) createSessionForIdentityToken(
ctx context.Context,
cfg *Config,
policy *Policy,
rawIdentityToken string,
) (*session.Session, error) {
idp, err := cfg.Options.GetIdentityProviderForPolicy(policy)
if err != nil {
return nil, fmt.Errorf("error getting identity provider to verify identity token: %w", err)
}
sessionID := getIdentityTokenSessionID(idp, rawIdentityToken)
res, err, _ := c.singleflight.Do(sessionID, func() (any, error) {
s, err := c.getSession(ctx, sessionID)
if err == nil {
return s, nil
} else if !storage.IsNotFound(err) {
return nil, err
}
authenticateURL, transport, err := cfg.resolveAuthenticateURL()
if err != nil {
return nil, fmt.Errorf("error resolving authenticate url to verify identity token: %w", err)
}
res, err := authenticateapi.New(authenticateURL, transport).VerifyIdentityToken(ctx, &authenticateapi.VerifyIdentityTokenRequest{
IdentityToken: rawIdentityToken,
IdentityProviderID: idp.GetId(),
})
if err != nil {
return nil, fmt.Errorf("error verifying identity token: %w", err)
} else if !res.Valid {
return nil, fmt.Errorf("%w: invalid identity token", sessions.ErrInvalidSession)
}
s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims)
s.SetRawIDToken(rawIdentityToken)
u, err := c.getUser(ctx, s.GetUserId())
if errors.Is(err, storage.ErrNotFound) {
u = &user.User{Id: s.GetUserId()}
} else if err != nil {
return nil, fmt.Errorf("error retrieving existing user: %w", err)
}
c.fillUserFromIDPClaims(u, res.Claims)
err = c.putSessionAndUser(ctx, s, u)
if err != nil {
return nil, fmt.Errorf("error saving session and user: %w", err)
}
return s, nil
})
if err != nil {
return nil, err
}
return res.(*session.Session), nil
}
func (c *incomingIDPTokenSessionCreator) newSessionFromIDPClaims(
cfg *Config,
sessionID string,
claims jwtutil.Claims,
) *session.Session {
now := c.timeNow()
s := new(session.Session)
s.Id = sessionID
if userID, ok := claims.GetUserID(); ok {
s.UserId = userID
}
if issuedAt, ok := claims.GetIssuedAt(); ok {
s.IssuedAt = timestamppb.New(issuedAt)
} else {
s.IssuedAt = timestamppb.New(now)
}
if expiresAt, ok := claims.GetExpirationTime(); ok {
s.ExpiresAt = timestamppb.New(expiresAt)
} else {
s.ExpiresAt = timestamppb.New(now.Add(cfg.Options.CookieExpire))
}
s.AccessedAt = timestamppb.New(now)
s.AddClaims(identity.Claims(claims).Flatten())
if aud, ok := claims.GetAudience(); ok {
s.Audience = aud
}
s.RefreshDisabled = true
return s
}
func (c *incomingIDPTokenSessionCreator) fillUserFromIDPClaims(
u *user.User,
claims jwtutil.Claims,
) {
if userID, ok := claims.GetUserID(); ok {
u.Id = userID
}
if name, ok := claims.GetString("name"); ok {
u.Name = name
}
if email, ok := claims.GetString("email"); ok {
u.Email = email
}
u.AddClaims(identity.Claims(claims).Flatten())
}
func (c *incomingIDPTokenSessionCreator) getSession(ctx context.Context, sessionID string) (*session.Session, error) {
record, err := c.getRecord(ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID)
if err != nil {
return nil, err
}
msg, err := record.GetData().UnmarshalNew()
if err != nil {
return nil, storage.ErrNotFound
}
s, ok := msg.(*session.Session)
if !ok {
return nil, storage.ErrNotFound
}
return s, nil
}
func (c *incomingIDPTokenSessionCreator) getUser(ctx context.Context, userID string) (*user.User, error) {
record, err := c.getRecord(ctx, grpcutil.GetTypeURL(new(user.User)), userID)
if err != nil {
return nil, err
}
msg, err := record.GetData().UnmarshalNew()
if err != nil {
return nil, storage.ErrNotFound
}
u, ok := msg.(*user.User)
if !ok {
return nil, storage.ErrNotFound
}
return u, nil
}
func (c *incomingIDPTokenSessionCreator) putSessionAndUser(ctx context.Context, s *session.Session, u *user.User) error {
var records []*databroker.Record
if id := s.GetId(); id != "" {
records = append(records, &databroker.Record{
Type: grpcutil.GetTypeURL(s),
Id: id,
Data: protoutil.NewAny(s),
})
}
if id := u.GetId(); id != "" {
records = append(records, &databroker.Record{
Type: grpcutil.GetTypeURL(u),
Id: id,
Data: protoutil.NewAny(u),
})
}
return c.putRecords(ctx, records)
}
// GetIncomingIDPAccessTokenForPolicy returns the raw idp access token from a request if there is one.
func (cfg *Config) GetIncomingIDPAccessTokenForPolicy(policy *Policy, r *http.Request) (rawAccessToken string, ok bool) {
bearerTokenFormat := BearerTokenFormatUnknown
if cfg.Options != nil && cfg.Options.BearerTokenFormat != nil {
bearerTokenFormat = *cfg.Options.BearerTokenFormat
}
if policy != nil && policy.BearerTokenFormat != nil {
bearerTokenFormat = *policy.BearerTokenFormat
}
if auth := r.Header.Get(httputil.HeaderAuthorization); auth != "" {
prefix := "Bearer "
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) &&
bearerTokenFormat == BearerTokenFormatIDPAccessToken {
return auth[len(prefix):], true
}
}
return "", false
}
// GetIncomingIDPAccessTokenForPolicy returns the raw idp identity token from a request if there is one.
func (cfg *Config) GetIncomingIDPIdentityTokenForPolicy(policy *Policy, r *http.Request) (rawIdentityToken string, ok bool) {
bearerTokenFormat := BearerTokenFormatDefault
if cfg.Options != nil && cfg.Options.BearerTokenFormat != nil {
bearerTokenFormat = *cfg.Options.BearerTokenFormat
}
if policy != nil && policy.BearerTokenFormat != nil {
bearerTokenFormat = *policy.BearerTokenFormat
}
if auth := r.Header.Get(httputil.HeaderAuthorization); auth != "" {
prefix := "Bearer "
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) &&
bearerTokenFormat == BearerTokenFormatIDPIdentityToken {
return auth[len(prefix):], true
}
}
return "", false
}
var accessTokenUUIDNamespace = uuid.MustParse("0194f6f8-e760-76a0-8917-e28ac927a34d")
func getAccessTokenSessionID(idp *identitypb.Provider, rawAccessToken string) string {
namespace := accessTokenUUIDNamespace
// make the session ID per-idp settings
if idp != nil {
namespace = uuid.NewSHA1(namespace, []byte(idp.GetId()))
}
return uuid.NewSHA1(namespace, []byte(rawAccessToken)).String()
}
var identityTokenUUIDNamespace = uuid.MustParse("0194f6f9-aec0-704e-bb4a-51054f17ad17")
func getIdentityTokenSessionID(idp *identitypb.Provider, rawIdentityToken string) string {
namespace := identityTokenUUIDNamespace
// make the session ID per-idp settings
if idp != nil {
namespace = uuid.NewSHA1(namespace, []byte(idp.GetId()))
}
return uuid.NewSHA1(namespace, []byte(rawIdentityToken)).String()
}