mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-03 00:40:25 +02:00
authenticateflow: move stateless flow logic (#4820)
Consolidate all logic specific to the stateless authenticate flow into a a new Stateless type in a new package internal/authenticateflow. This is in preparation for adding a new Stateful type implementing the older stateful authenticate flow (from Pomerium v0.20 and previous). This change is intended as a pure refactoring of existing logic, with no changes in functionality.
This commit is contained in:
parent
3b2bdd059a
commit
b7896b3153
18 changed files with 823 additions and 461 deletions
31
internal/authenticateflow/authenticateflow.go
Normal file
31
internal/authenticateflow/authenticateflow.go
Normal file
|
@ -0,0 +1,31 @@
|
|||
// Package authenticateflow implements the core authentication flow. This
|
||||
// includes creating and parsing sign-in redirect URLs, storing and retrieving
|
||||
// session data, and handling authentication callback URLs.
|
||||
package authenticateflow
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/identity"
|
||||
"github.com/pomerium/pomerium/pkg/grpc"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
)
|
||||
|
||||
var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn)
|
||||
|
||||
func populateUserFromClaims(u *user.User, claims map[string]interface{}) {
|
||||
if v, ok := claims["name"]; ok {
|
||||
u.Name = fmt.Sprint(v)
|
||||
}
|
||||
if v, ok := claims["email"]; ok {
|
||||
u.Email = fmt.Sprint(v)
|
||||
}
|
||||
if u.Claims == nil {
|
||||
u.Claims = make(map[string]*structpb.ListValue)
|
||||
}
|
||||
for k, vs := range identity.Claims(claims).Flatten().ToPB() {
|
||||
u.Claims[k] = vs
|
||||
}
|
||||
}
|
87
internal/authenticateflow/events.go
Normal file
87
internal/authenticateflow/events.go
Normal file
|
@ -0,0 +1,87 @@
|
|||
package authenticateflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
identitypb "github.com/pomerium/pomerium/pkg/grpc/identity"
|
||||
"github.com/pomerium/pomerium/pkg/hpke"
|
||||
)
|
||||
|
||||
// AuthEventKind is the type of an authentication event
|
||||
type AuthEventKind string
|
||||
|
||||
const (
|
||||
// AuthEventSignInRequest is an authentication event for a sign in request before IdP redirect
|
||||
AuthEventSignInRequest AuthEventKind = "sign_in_request"
|
||||
// AuthEventSignInComplete is an authentication event for a sign in request after IdP redirect
|
||||
AuthEventSignInComplete AuthEventKind = "sign_in_complete"
|
||||
)
|
||||
|
||||
// AuthEvent is a log event for an authentication event
|
||||
type AuthEvent struct {
|
||||
// Event is the type of authentication event
|
||||
Event AuthEventKind
|
||||
// IP is the IP address of the client
|
||||
IP string
|
||||
// Version is the version of the Pomerium client
|
||||
Version string
|
||||
// RequestUUID is the UUID of the request
|
||||
RequestUUID string
|
||||
// PubKey is the public key of the client
|
||||
PubKey string
|
||||
// UID is the IdP user ID of the user
|
||||
UID *string
|
||||
// Email is the email of the user
|
||||
Email *string
|
||||
// Domain is the domain of the request (for sign in complete events)
|
||||
Domain *string
|
||||
}
|
||||
|
||||
// AuthEventFn is a function that handles an authentication event
|
||||
type AuthEventFn func(context.Context, AuthEvent)
|
||||
|
||||
// TODO: move into stateless.go; this is here for now just so that Git will
|
||||
// track the file history as a rename from authenticate/events.go.
|
||||
func (s *Stateless) logAuthenticateEvent(r *http.Request, profile *identitypb.Profile) {
|
||||
if s.authEventFn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
pub, params, err := hpke.DecryptURLValues(s.hpkePrivateKey, r.Form)
|
||||
if err != nil {
|
||||
log.Warn(ctx).Err(err).Msg("log authenticate event: failed to decrypt request params")
|
||||
}
|
||||
|
||||
evt := AuthEvent{
|
||||
IP: httputil.GetClientIP(r),
|
||||
Version: params.Get(urlutil.QueryVersion),
|
||||
RequestUUID: params.Get(urlutil.QueryRequestUUID),
|
||||
PubKey: pub.String(),
|
||||
}
|
||||
|
||||
if uid := getUserClaim(profile, "sub"); uid != nil {
|
||||
evt.UID = uid
|
||||
}
|
||||
if email := getUserClaim(profile, "email"); email != nil {
|
||||
evt.Email = email
|
||||
}
|
||||
|
||||
if evt.UID != nil {
|
||||
evt.Event = AuthEventSignInComplete
|
||||
} else {
|
||||
evt.Event = AuthEventSignInRequest
|
||||
}
|
||||
|
||||
if redirectURL, err := url.Parse(params.Get(urlutil.QueryRedirectURI)); err == nil {
|
||||
domain := redirectURL.Hostname()
|
||||
evt.Domain = &domain
|
||||
}
|
||||
|
||||
s.authEventFn(ctx, evt)
|
||||
}
|
172
internal/authenticateflow/identityprofile.go
Normal file
172
internal/authenticateflow/identityprofile.go
Normal file
|
@ -0,0 +1,172 @@
|
|||
package authenticateflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/cipher"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/identity"
|
||||
"github.com/pomerium/pomerium/internal/identity/manager"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
identitypb "github.com/pomerium/pomerium/pkg/grpc/identity"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
)
|
||||
|
||||
// An "identity profile" is an alternative to a session, used in the stateless
|
||||
// authenticate flow. An identity profile contains an IdP ID (to distinguish
|
||||
// between different IdP's or between different clients of the same IdP), a
|
||||
// user ID token, and an OAuth2 token.
|
||||
|
||||
var cookieChunker = httputil.NewCookieChunker()
|
||||
|
||||
// buildIdentityProfile populates an identity profile.
|
||||
func buildIdentityProfile(
|
||||
idpID string,
|
||||
claims identity.SessionClaims,
|
||||
oauthToken *oauth2.Token,
|
||||
) (*identitypb.Profile, error) {
|
||||
rawIDToken := []byte(claims.RawIDToken)
|
||||
rawOAuthToken, err := json.Marshal(oauthToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authenticate: error marshaling oauth token: %w", err)
|
||||
}
|
||||
rawClaims, err := structpb.NewStruct(claims.Claims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authenticate: error creating claims struct: %w", err)
|
||||
}
|
||||
|
||||
return &identitypb.Profile{
|
||||
ProviderId: idpID,
|
||||
IdToken: rawIDToken,
|
||||
OauthToken: rawOAuthToken,
|
||||
Claims: rawClaims,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// loadIdentityProfile loads an identity profile from a chunked set of cookies.
|
||||
func loadIdentityProfile(r *http.Request, aead cipher.AEAD) (*identitypb.Profile, error) {
|
||||
cookie, err := cookieChunker.LoadCookie(r, urlutil.QueryIdentityProfile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authenticate: error loading identity profile cookie: %w", err)
|
||||
}
|
||||
|
||||
encrypted, err := base64.RawURLEncoding.DecodeString(cookie.Value)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authenticate: error decoding identity profile cookie: %w", err)
|
||||
}
|
||||
|
||||
decrypted, err := cryptutil.Decrypt(aead, encrypted, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authenticate: error decrypting identity profile cookie: %w", err)
|
||||
}
|
||||
|
||||
var profile identitypb.Profile
|
||||
err = protojson.Unmarshal(decrypted, &profile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authenticate: error unmarshaling identity profile cookie: %w", err)
|
||||
}
|
||||
return &profile, nil
|
||||
}
|
||||
|
||||
// storeIdentityProfile writes the identity profile to a chunked set of cookies.
|
||||
func storeIdentityProfile(
|
||||
w http.ResponseWriter,
|
||||
cookie *http.Cookie,
|
||||
aead cipher.AEAD,
|
||||
profile *identitypb.Profile,
|
||||
) error {
|
||||
decrypted, err := protojson.Marshal(profile)
|
||||
if err != nil {
|
||||
// this shouldn't happen
|
||||
panic(fmt.Errorf("error marshaling message: %w", err))
|
||||
}
|
||||
encrypted := cryptutil.Encrypt(aead, decrypted, nil)
|
||||
cookie.Name = urlutil.QueryIdentityProfile
|
||||
cookie.Value = base64.RawURLEncoding.EncodeToString(encrypted)
|
||||
cookie.Path = "/"
|
||||
return cookieChunker.SetCookie(w, cookie)
|
||||
}
|
||||
|
||||
// validateIdentityProfile checks expirations timestamps for the ID token and
|
||||
// OAuth2 token, and makes a user info request to the IdP in order to determine
|
||||
// whether the OAuth2 token is still valid.
|
||||
func validateIdentityProfile(
|
||||
ctx context.Context,
|
||||
authenticator identity.Authenticator,
|
||||
profile *identitypb.Profile,
|
||||
) error {
|
||||
oauthToken := new(oauth2.Token)
|
||||
err := json.Unmarshal(profile.GetOauthToken(), oauthToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid oauth token in profile: %w", err)
|
||||
}
|
||||
|
||||
if !oauthToken.Valid() {
|
||||
return fmt.Errorf("invalid oauth token in profile")
|
||||
}
|
||||
|
||||
var claims identity.SessionClaims
|
||||
err = authenticator.UpdateUserInfo(ctx, oauthToken, &claims)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error updating user info from oauth token: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func newSessionStateFromProfile(p *identitypb.Profile) *sessions.State {
|
||||
claims := p.GetClaims().AsMap()
|
||||
|
||||
ss := sessions.NewState(p.GetProviderId())
|
||||
|
||||
// set the subject
|
||||
if v, ok := claims["sub"]; ok {
|
||||
ss.Subject = fmt.Sprint(v)
|
||||
} else if v, ok := claims["user"]; ok {
|
||||
ss.Subject = fmt.Sprint(v)
|
||||
}
|
||||
|
||||
// set the oid
|
||||
if v, ok := claims["oid"]; ok {
|
||||
ss.OID = fmt.Sprint(v)
|
||||
}
|
||||
|
||||
return ss
|
||||
}
|
||||
|
||||
func populateSessionFromProfile(s *session.Session, p *identitypb.Profile, ss *sessions.State, cookieExpire time.Duration) {
|
||||
claims := p.GetClaims().AsMap()
|
||||
oauthToken := new(oauth2.Token)
|
||||
_ = json.Unmarshal(p.GetOauthToken(), oauthToken)
|
||||
|
||||
s.UserId = ss.UserID()
|
||||
s.IssuedAt = timestamppb.Now()
|
||||
s.AccessedAt = timestamppb.Now()
|
||||
s.ExpiresAt = timestamppb.New(time.Now().Add(cookieExpire))
|
||||
s.IdToken = &session.IDToken{
|
||||
Issuer: ss.Issuer,
|
||||
Subject: ss.Subject,
|
||||
ExpiresAt: timestamppb.New(time.Now().Add(cookieExpire)),
|
||||
IssuedAt: timestamppb.Now(),
|
||||
Raw: string(p.GetIdToken()),
|
||||
}
|
||||
s.OauthToken = manager.ToOAuthToken(oauthToken)
|
||||
if s.Claims == nil {
|
||||
s.Claims = make(map[string]*structpb.ListValue)
|
||||
}
|
||||
for k, vs := range identity.Claims(claims).Flatten().ToPB() {
|
||||
s.Claims[k] = vs
|
||||
}
|
||||
}
|
36
internal/authenticateflow/request.go
Normal file
36
internal/authenticateflow/request.go
Normal file
|
@ -0,0 +1,36 @@
|
|||
package authenticateflow
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/middleware"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
)
|
||||
|
||||
type signatureVerifier struct {
|
||||
options *config.Options
|
||||
sharedKey []byte
|
||||
}
|
||||
|
||||
// VerifyAuthenticateSignature checks that the provided request has a valid
|
||||
// signature (for the authenticate service).
|
||||
func (v signatureVerifier) VerifyAuthenticateSignature(r *http.Request) error {
|
||||
return middleware.ValidateRequestURL(GetExternalAuthenticateRequest(r, v.options), v.sharedKey)
|
||||
}
|
||||
|
||||
// GetExternalAuthenticateRequest canonicalizes an authenticate request URL
|
||||
// based on the provided configuration options.
|
||||
func GetExternalAuthenticateRequest(r *http.Request, options *config.Options) *http.Request {
|
||||
externalURL, err := options.GetAuthenticateURL()
|
||||
if err != nil {
|
||||
return r
|
||||
}
|
||||
|
||||
internalURL, err := options.GetInternalAuthenticateURL()
|
||||
if err != nil {
|
||||
return r
|
||||
}
|
||||
|
||||
return urlutil.GetExternalRequest(internalURL, externalURL, r)
|
||||
}
|
58
internal/authenticateflow/request_test.go
Normal file
58
internal/authenticateflow/request_test.go
Normal file
|
@ -0,0 +1,58 @@
|
|||
package authenticateflow
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
)
|
||||
|
||||
func TestVerifyAuthenticateSignature(t *testing.T) {
|
||||
options := &config.Options{
|
||||
AuthenticateURLString: "https://authenticate.example.com",
|
||||
AuthenticateInternalURLString: "https://authenticate.internal",
|
||||
}
|
||||
key := []byte("SHARED KEY--(must be 32 bytes)--")
|
||||
v := signatureVerifier{options, key}
|
||||
|
||||
t.Run("Valid", func(t *testing.T) {
|
||||
u := mustParseURL("https://example.com/")
|
||||
r := &http.Request{Host: "example.com", URL: urlutil.NewSignedURL(key, u).Sign()}
|
||||
err := v.VerifyAuthenticateSignature(r)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
t.Run("NoSignature", func(t *testing.T) {
|
||||
r := &http.Request{Host: "example.com", URL: mustParseURL("https://example.com/")}
|
||||
err := v.VerifyAuthenticateSignature(r)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
t.Run("DifferentKey", func(t *testing.T) {
|
||||
zeros := make([]byte, 32)
|
||||
u := mustParseURL("https://example.com/")
|
||||
r := &http.Request{Host: "example.com", URL: urlutil.NewSignedURL(zeros, u).Sign()}
|
||||
err := v.VerifyAuthenticateSignature(r)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
t.Run("InternalDomain", func(t *testing.T) {
|
||||
// A request with the internal authenticate service URL should first be
|
||||
// canonicalized to use the external authenticate service URL before
|
||||
// validating the request signature.
|
||||
u := urlutil.NewSignedURL(key, mustParseURL("https://authenticate.example.com/")).Sign()
|
||||
u.Host = "authenticate.internal"
|
||||
r := &http.Request{Host: "authenticate.internal", URL: u}
|
||||
err := v.VerifyAuthenticateSignature(r)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func mustParseURL(rawurl string) *url.URL {
|
||||
u, err := url.Parse(rawurl)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return u
|
||||
}
|
449
internal/authenticateflow/stateless.go
Normal file
449
internal/authenticateflow/stateless.go
Normal file
|
@ -0,0 +1,449 @@
|
|||
package authenticateflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/cipher"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"golang.org/x/oauth2"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/encoding/jws"
|
||||
"github.com/pomerium/pomerium/internal/handlers"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/identity"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
identitypb "github.com/pomerium/pomerium/pkg/grpc/identity"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
"github.com/pomerium/pomerium/pkg/hpke"
|
||||
)
|
||||
|
||||
// Stateless implements the stateless authentication flow. In this flow, the
|
||||
// authenticate service has no direct access to the databroker and instead
|
||||
// stores profile information in a cookie.
|
||||
type Stateless struct {
|
||||
signatureVerifier
|
||||
|
||||
// sharedEncoder is the encoder to use to serialize data to be consumed
|
||||
// by other services
|
||||
sharedEncoder encoding.MarshalUnmarshaler
|
||||
// cookieCipher is the cipher to use to encrypt/decrypt session data
|
||||
cookieCipher cipher.AEAD
|
||||
|
||||
sessionStore sessions.SessionStore
|
||||
|
||||
hpkePrivateKey *hpke.PrivateKey
|
||||
authenticateKeyFetcher hpke.KeyFetcher
|
||||
|
||||
jwk *jose.JSONWebKeySet
|
||||
|
||||
authenticateURL *url.URL
|
||||
|
||||
options *config.Options
|
||||
|
||||
dataBrokerClient databroker.DataBrokerServiceClient
|
||||
|
||||
getIdentityProvider func(options *config.Options, idpID string) (identity.Authenticator, error)
|
||||
profileTrimFn func(*identitypb.Profile)
|
||||
authEventFn AuthEventFn
|
||||
}
|
||||
|
||||
// NewStateless initializes the authentication flow for the given
|
||||
// configuration, session store, and additional options.
|
||||
func NewStateless(
|
||||
cfg *config.Config,
|
||||
sessionStore sessions.SessionStore,
|
||||
getIdentityProvider func(options *config.Options, idpID string) (identity.Authenticator, error),
|
||||
profileTrimFn func(*identitypb.Profile),
|
||||
authEventFn AuthEventFn,
|
||||
) (*Stateless, error) {
|
||||
s := &Stateless{
|
||||
options: cfg.Options,
|
||||
sessionStore: sessionStore,
|
||||
getIdentityProvider: getIdentityProvider,
|
||||
profileTrimFn: profileTrimFn,
|
||||
authEventFn: authEventFn,
|
||||
}
|
||||
|
||||
var err error
|
||||
s.authenticateURL, err = cfg.Options.GetAuthenticateURL()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// shared cipher to encrypt data before passing data between services
|
||||
sharedKey, err := cfg.Options.GetSharedKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// shared state encoder setup
|
||||
s.sharedEncoder, err = jws.NewHS256Signer(sharedKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// private state encoder setup, used to encrypt oauth2 tokens
|
||||
cookieSecret, err := cfg.Options.GetCookieSecret()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.cookieCipher, err = cryptutil.NewAEADCipher(cookieSecret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.jwk = new(jose.JSONWebKeySet)
|
||||
signingKey, err := cfg.Options.GetSigningKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(signingKey) > 0 {
|
||||
ks, err := cryptutil.PublicJWKsFromBytes(signingKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authenticate: failed to convert jwks: %w", err)
|
||||
}
|
||||
for _, k := range ks {
|
||||
s.jwk.Keys = append(s.jwk.Keys, *k)
|
||||
}
|
||||
}
|
||||
|
||||
s.signatureVerifier = signatureVerifier{cfg.Options, sharedKey}
|
||||
|
||||
s.hpkePrivateKey = hpke.DerivePrivateKey(sharedKey)
|
||||
|
||||
s.authenticateKeyFetcher, err = cfg.GetAuthenticateKeyFetcher()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authorize: get authenticate JWKS key fetcher: %w", err)
|
||||
}
|
||||
|
||||
dataBrokerConn, err := outboundGRPCConnection.Get(context.Background(), &grpc.OutboundOptions{
|
||||
OutboundPort: cfg.OutboundPort,
|
||||
InstallationID: cfg.Options.InstallationID,
|
||||
ServiceName: cfg.Options.Services,
|
||||
SignedJWTKey: sharedKey,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.dataBrokerClient = databroker.NewDataBrokerServiceClient(dataBrokerConn)
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// VerifySession checks that an existing session is still valid.
|
||||
func (s *Stateless) VerifySession(ctx context.Context, r *http.Request, _ *sessions.State) error {
|
||||
profile, err := loadIdentityProfile(r, s.cookieCipher)
|
||||
if err != nil {
|
||||
return fmt.Errorf("identity profile load error: %w", err)
|
||||
}
|
||||
|
||||
authenticator, err := s.getIdentityProvider(s.options, profile.GetProviderId())
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't get identity provider: %w", err)
|
||||
}
|
||||
|
||||
if err := validateIdentityProfile(ctx, authenticator, profile); err != nil {
|
||||
return fmt.Errorf("invalid identity profile: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SignIn redirects to a route callback URL, if the provided request and
|
||||
// session state are valid.
|
||||
func (s *Stateless) SignIn(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
sessionState *sessions.State,
|
||||
) error {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
}
|
||||
proxyPublicKey, requestParams, err := hpke.DecryptURLValues(s.hpkePrivateKey, r.Form)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
idpID := requestParams.Get(urlutil.QueryIdentityProviderID)
|
||||
|
||||
// start over if this is a different identity provider
|
||||
if sessionState == nil || sessionState.IdentityProviderID != idpID {
|
||||
sessionState = sessions.NewState(idpID)
|
||||
}
|
||||
|
||||
// re-persist the session, useful when session was evicted from session store
|
||||
if err := s.sessionStore.SaveSession(w, r, sessionState); err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
}
|
||||
|
||||
profile, err := loadIdentityProfile(r, s.cookieCipher)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
}
|
||||
|
||||
if s.profileTrimFn != nil {
|
||||
s.profileTrimFn(profile)
|
||||
}
|
||||
|
||||
s.logAuthenticateEvent(r, profile)
|
||||
|
||||
encryptURLValues := hpke.EncryptURLValuesV1
|
||||
if hpke.IsEncryptedURLV2(r.Form) {
|
||||
encryptURLValues = hpke.EncryptURLValuesV2
|
||||
}
|
||||
|
||||
redirectTo, err := urlutil.CallbackURL(s.hpkePrivateKey, proxyPublicKey, requestParams, profile, encryptURLValues)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
httputil.Redirect(w, r, redirectTo, http.StatusFound)
|
||||
return nil
|
||||
}
|
||||
|
||||
// PersistSession stores session data in a cookie.
|
||||
func (s *Stateless) PersistSession(
|
||||
ctx context.Context,
|
||||
w http.ResponseWriter,
|
||||
sessionState *sessions.State,
|
||||
claims identity.SessionClaims,
|
||||
accessToken *oauth2.Token,
|
||||
) error {
|
||||
idpID := sessionState.IdentityProviderID
|
||||
profile, err := buildIdentityProfile(idpID, claims, accessToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = storeIdentityProfile(w, s.options.NewCookie(), s.cookieCipher, profile)
|
||||
if err != nil {
|
||||
log.Error(ctx).Err(err).Msg("failed to store identity profile")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserInfoData returns user info data associated with the given request (if
|
||||
// any).
|
||||
func (s *Stateless) GetUserInfoData(r *http.Request, _ *sessions.State) handlers.UserInfoData {
|
||||
profile, _ := loadIdentityProfile(r, s.cookieCipher)
|
||||
return handlers.UserInfoData{
|
||||
Profile: profile,
|
||||
}
|
||||
}
|
||||
|
||||
// RevokeSession revokes the session associated with the provided request,
|
||||
// returning the ID token from the revoked session.
|
||||
func (s *Stateless) RevokeSession(
|
||||
ctx context.Context, r *http.Request, authenticator identity.Authenticator, _ *sessions.State,
|
||||
) string {
|
||||
profile, err := loadIdentityProfile(r, s.cookieCipher)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
oauthToken := new(oauth2.Token)
|
||||
_ = json.Unmarshal(profile.GetOauthToken(), oauthToken)
|
||||
if err := authenticator.Revoke(ctx, oauthToken); err != nil {
|
||||
log.Ctx(ctx).Warn().Err(err).Msg("authenticate: failed to revoke access token")
|
||||
}
|
||||
|
||||
return string(profile.GetIdToken())
|
||||
}
|
||||
|
||||
// GetIdentityProviderIDForURLValues returns the identity provider ID
|
||||
// associated with the given URL values.
|
||||
func (s *Stateless) GetIdentityProviderIDForURLValues(vs url.Values) string {
|
||||
idpID := ""
|
||||
if _, requestParams, err := hpke.DecryptURLValues(s.hpkePrivateKey, vs); err == nil {
|
||||
if idpID == "" {
|
||||
idpID = requestParams.Get(urlutil.QueryIdentityProviderID)
|
||||
}
|
||||
}
|
||||
if idpID == "" {
|
||||
idpID = vs.Get(urlutil.QueryIdentityProviderID)
|
||||
}
|
||||
return idpID
|
||||
}
|
||||
|
||||
// LogAuthenticateEvent logs an authenticate service event.
|
||||
func (s *Stateless) LogAuthenticateEvent(r *http.Request) {
|
||||
s.logAuthenticateEvent(r, nil)
|
||||
}
|
||||
|
||||
func getUserClaim(profile *identitypb.Profile, field string) *string {
|
||||
if profile == nil {
|
||||
return nil
|
||||
}
|
||||
if profile.Claims == nil {
|
||||
return nil
|
||||
}
|
||||
val, ok := profile.Claims.Fields[field]
|
||||
if !ok || val == nil {
|
||||
return nil
|
||||
}
|
||||
txt := val.GetStringValue()
|
||||
return &txt
|
||||
}
|
||||
|
||||
// AuthenticateSignInURL returns a URL to redirect the user to the authenticate
|
||||
// domain.
|
||||
func (s *Stateless) AuthenticateSignInURL(
|
||||
ctx context.Context, queryParams url.Values, redirectURL *url.URL, idpID string,
|
||||
) (string, error) {
|
||||
authenticateHPKEPublicKey, err := s.authenticateKeyFetcher.FetchPublicKey(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
authenticateURLWithParams := *s.authenticateURL
|
||||
q := authenticateURLWithParams.Query()
|
||||
for k, v := range queryParams {
|
||||
q[k] = v
|
||||
}
|
||||
authenticateURLWithParams.RawQuery = q.Encode()
|
||||
|
||||
return urlutil.SignInURL(
|
||||
s.hpkePrivateKey,
|
||||
authenticateHPKEPublicKey,
|
||||
&authenticateURLWithParams,
|
||||
redirectURL,
|
||||
idpID,
|
||||
)
|
||||
}
|
||||
|
||||
// Callback handles a redirect to a route domain once signed in.
|
||||
func (s *Stateless) Callback(w http.ResponseWriter, r *http.Request) error {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
}
|
||||
|
||||
// decrypt the URL values
|
||||
senderPublicKey, values, err := hpke.DecryptURLValues(s.hpkePrivateKey, r.Form)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, fmt.Errorf("invalid encrypted query string: %w", err))
|
||||
}
|
||||
|
||||
// confirm this request came from the authenticate service
|
||||
err = s.validateSenderPublicKey(r.Context(), senderPublicKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// validate that the request has not expired
|
||||
err = urlutil.ValidateTimeParameters(values)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
}
|
||||
|
||||
profile, err := getProfileFromValues(values)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ss := newSessionStateFromProfile(profile)
|
||||
sess, err := session.Get(r.Context(), s.dataBrokerClient, ss.ID)
|
||||
if err != nil {
|
||||
sess = &session.Session{Id: ss.ID}
|
||||
}
|
||||
populateSessionFromProfile(sess, profile, ss, s.options.CookieExpire)
|
||||
u, err := user.Get(r.Context(), s.dataBrokerClient, ss.UserID())
|
||||
if err != nil {
|
||||
u = &user.User{Id: ss.UserID()}
|
||||
}
|
||||
populateUserFromClaims(u, profile.GetClaims().AsMap())
|
||||
|
||||
redirectURI, err := getRedirectURIFromValues(values)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// save the records
|
||||
res, err := s.dataBrokerClient.Put(r.Context(), &databroker.PutRequest{
|
||||
Records: []*databroker.Record{
|
||||
databroker.NewRecord(sess),
|
||||
databroker.NewRecord(u),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("proxy: error saving databroker records: %w", err))
|
||||
}
|
||||
ss.DatabrokerServerVersion = res.GetServerVersion()
|
||||
for _, record := range res.GetRecords() {
|
||||
if record.GetVersion() > ss.DatabrokerRecordVersion {
|
||||
ss.DatabrokerRecordVersion = record.GetVersion()
|
||||
}
|
||||
}
|
||||
|
||||
// save the session state
|
||||
rawJWT, err := s.sharedEncoder.Marshal(ss)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("proxy: error marshaling session state: %w", err))
|
||||
}
|
||||
if err = s.sessionStore.SaveSession(w, r, rawJWT); err != nil {
|
||||
return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("proxy: error saving session state: %w", err))
|
||||
}
|
||||
|
||||
// if programmatic, encode the session jwt as a query param
|
||||
if isProgrammatic := values.Get(urlutil.QueryIsProgrammatic); isProgrammatic == "true" {
|
||||
q := redirectURI.Query()
|
||||
q.Set(urlutil.QueryPomeriumJWT, string(rawJWT))
|
||||
redirectURI.RawQuery = q.Encode()
|
||||
}
|
||||
|
||||
// redirect
|
||||
httputil.Redirect(w, r, redirectURI.String(), http.StatusFound)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Stateless) validateSenderPublicKey(ctx context.Context, senderPublicKey *hpke.PublicKey) error {
|
||||
authenticatePublicKey, err := s.authenticateKeyFetcher.FetchPublicKey(ctx)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("hpke: error retrieving authenticate service public key: %w", err))
|
||||
}
|
||||
|
||||
if !authenticatePublicKey.Equals(senderPublicKey) {
|
||||
return httputil.NewError(http.StatusBadRequest, fmt.Errorf("hpke: invalid authenticate service public key"))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getProfileFromValues(values url.Values) (*identitypb.Profile, error) {
|
||||
rawProfile := values.Get(urlutil.QueryIdentityProfile)
|
||||
if rawProfile == "" {
|
||||
return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("missing %s", urlutil.QueryIdentityProfile))
|
||||
}
|
||||
|
||||
var profile identitypb.Profile
|
||||
err := protojson.Unmarshal([]byte(rawProfile), &profile)
|
||||
if err != nil {
|
||||
return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("invalid %s: %w", urlutil.QueryIdentityProfile, err))
|
||||
}
|
||||
return &profile, nil
|
||||
}
|
||||
|
||||
func getRedirectURIFromValues(values url.Values) (*url.URL, error) {
|
||||
rawRedirectURI := values.Get(urlutil.QueryRedirectURI)
|
||||
if rawRedirectURI == "" {
|
||||
return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("missing %s", urlutil.QueryRedirectURI))
|
||||
}
|
||||
redirectURI, err := urlutil.ParseAndValidateURL(rawRedirectURI)
|
||||
if err != nil {
|
||||
return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("invalid %s: %w", urlutil.QueryRedirectURI, err))
|
||||
}
|
||||
return redirectURI, nil
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue