authenticate: implement hpke-based login flow (#3779)

* urlutil: add time validation functions

* authenticate: implement hpke-based login flow

* fix import cycle

* fix tests

* log error

* fix callback url

* add idp param

* fix test

* fix test
This commit is contained in:
Caleb Doxsey 2022-12-05 15:31:07 -07:00 committed by GitHub
parent 8d1235a5cc
commit 57217af7dd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
25 changed files with 656 additions and 661 deletions

View file

@ -13,8 +13,6 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/rs/cors" "github.com/rs/cors"
"golang.org/x/oauth2"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/csrf" "github.com/pomerium/csrf"
"github.com/pomerium/datasource/pkg/directory" "github.com/pomerium/datasource/pkg/directory"
@ -33,6 +31,7 @@ import (
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user" "github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/hpke"
"github.com/pomerium/pomerium/pkg/webauthnutil" "github.com/pomerium/pomerium/pkg/webauthnutil"
) )
@ -95,7 +94,7 @@ func (a *Authenticate) mountDashboard(r *mux.Router) {
sr.Use(a.RetrieveSession) sr.Use(a.RetrieveSession)
sr.Use(a.VerifySession) sr.Use(a.VerifySession)
sr.Path("/").Handler(a.requireValidSignatureOnRedirect(a.userInfo)) sr.Path("/").Handler(a.requireValidSignatureOnRedirect(a.userInfo))
sr.Path("/sign_in").Handler(a.requireValidSignature(a.SignIn)) sr.Path("/sign_in").Handler(httputil.HandlerFunc(a.SignIn))
sr.Path("/sign_out").Handler(httputil.HandlerFunc(a.SignOut)) sr.Path("/sign_out").Handler(httputil.HandlerFunc(a.SignOut))
sr.Path("/webauthn").Handler(a.webauthn) sr.Path("/webauthn").Handler(a.webauthn)
sr.Path("/device-enrolled").Handler(httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { sr.Path("/device-enrolled").Handler(httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
@ -149,15 +148,12 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
return a.reauthenticateOrFail(w, r, err) return a.reauthenticateOrFail(w, r, err)
} }
if state.dataBrokerClient == nil { _, err = loadIdentityProfile(r, state.cookieCipher)
return errors.New("authenticate: databroker client cannot be nil") if err != nil {
}
if _, err = session.Get(ctx, state.dataBrokerClient, sessionState.ID); err != nil {
log.FromRequest(r).Info(). log.FromRequest(r).Info().
Err(err). Err(err).
Str("idp_id", idp.GetId()). Str("idp_id", idp.GetId()).
Str("id", sessionState.ID). Msg("authenticate: identity profile load error")
Msg("authenticate: session not found in databroker")
return a.reauthenticateOrFail(w, r, err) return a.reauthenticateOrFail(w, r, err)
} }
@ -180,25 +176,18 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
state := a.state.Load() state := a.state.Load()
options := a.options.Load() options := a.options.Load()
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
if err := r.ParseForm(); err != nil {
return httputil.NewError(http.StatusBadRequest, err)
}
proxyPublicKey, requestParams, err := hpke.DecryptURLValues(state.hpkePrivateKey, r.Form)
if err != nil { if err != nil {
return err return err
} }
redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI)) idp, err := options.GetIdentityProviderForID(requestParams.Get(urlutil.QueryIdentityProviderID))
if err != nil { if err != nil {
return httputil.NewError(http.StatusBadRequest, err) return err
}
jwtAudience := []string{state.redirectURL.Host, redirectURL.Host}
// if the callback is explicitly set, set it and add an additional audience
if callbackStr := r.FormValue(urlutil.QueryCallbackURI); callbackStr != "" {
callbackURL, err := urlutil.ParseAndValidateURL(callbackStr)
if err != nil {
return httputil.NewError(http.StatusBadRequest, err)
}
jwtAudience = append(jwtAudience, callbackURL.Host)
} }
s, err := a.getSessionFromCtx(ctx) s, err := a.getSessionFromCtx(ctx)
@ -212,33 +201,22 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
s = sessions.NewState(idp.GetId()) s = sessions.NewState(idp.GetId())
} }
newSession := s.WithNewIssuer(state.redirectURL.Host, jwtAudience)
// re-persist the session, useful when session was evicted from session // re-persist the session, useful when session was evicted from session
if err := state.sessionStore.SaveSession(w, r, s); err != nil { if err := state.sessionStore.SaveSession(w, r, s); err != nil {
return httputil.NewError(http.StatusBadRequest, err) return httputil.NewError(http.StatusBadRequest, err)
} }
// sign the route session, as a JWT profile, err := loadIdentityProfile(r, state.cookieCipher)
signedJWT, err := state.sharedEncoder.Marshal(newSession)
if err != nil { if err != nil {
return httputil.NewError(http.StatusBadRequest, err) return httputil.NewError(http.StatusBadRequest, err)
} }
// encrypt our route-scoped JWT to avoid accidental logging of queryparams redirectTo, err := handlers.BuildCallbackURL(state.hpkePrivateKey, proxyPublicKey, requestParams, profile)
encryptedJWT := cryptutil.Encrypt(a.state.Load().sharedCipher, signedJWT, nil)
// base64 our encrypted payload for URL-friendlyness
encodedJWT := base64.URLEncoding.EncodeToString(encryptedJWT)
callbackURL, err := urlutil.GetCallbackURL(r, encodedJWT)
if err != nil { if err != nil {
return httputil.NewError(http.StatusBadRequest, err) return httputil.NewError(http.StatusInternalServerError, err)
} }
// build our hmac-d redirect URL with our session, pointing back to the httputil.Redirect(w, r, redirectTo, http.StatusFound)
// proxy's callback URL which is responsible for setting our new route-session
uri := urlutil.NewSignedURL(state.sharedKey, callbackURL)
httputil.Redirect(w, r, uri.String(), http.StatusFound)
return nil return nil
} }
@ -460,10 +438,11 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
} }
// save the session and access token to the databroker // save the session and access token to the databroker
err = a.saveSessionToDataBroker(ctx, r, &newState, claims, accessToken) profile, err := a.buildIdentityProfile(ctx, r, &newState, claims, accessToken)
if err != nil { if err != nil {
return nil, httputil.NewError(http.StatusInternalServerError, err) return nil, httputil.NewError(http.StatusInternalServerError, err)
} }
storeIdentityProfile(w, state.cookieCipher, profile)
// ... and the user state to local storage. // ... and the user state to local storage.
if err := state.sessionStore.SaveSession(w, r, &newState); err != nil { if err := state.sessionStore.SaveSession(w, r, &newState); err != nil {
@ -542,11 +521,14 @@ func (a *Authenticate) getUserInfoData(r *http.Request) (handlers.UserInfoData,
} }
creationOptions, requestOptions, _ := a.webauthn.GetOptions(r) creationOptions, requestOptions, _ := a.webauthn.GetOptions(r)
profile, _ := loadIdentityProfile(r, state.cookieCipher)
data := handlers.UserInfoData{ data := handlers.UserInfoData{
CSRFToken: csrf.Token(r), CSRFToken: csrf.Token(r),
IsImpersonated: isImpersonated, IsImpersonated: isImpersonated,
Session: pbSession, Session: pbSession,
User: pbUser, User: pbUser,
Profile: profile,
WebAuthnCreationOptions: creationOptions, WebAuthnCreationOptions: creationOptions,
WebAuthnRequestOptions: requestOptions, WebAuthnRequestOptions: requestOptions,
@ -582,73 +564,6 @@ func (a *Authenticate) fillEnterpriseUserInfoData(
} }
} }
func (a *Authenticate) saveSessionToDataBroker(
ctx context.Context,
r *http.Request,
sessionState *sessions.State,
claims identity.SessionClaims,
accessToken *oauth2.Token,
) error {
state := a.state.Load()
options := a.options.Load()
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return err
}
authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
if err != nil {
return err
}
sessionExpiry := timestamppb.New(time.Now().Add(options.CookieExpire))
idTokenIssuedAt := timestamppb.New(sessionState.IssuedAt.Time())
s := &session.Session{
Id: sessionState.ID,
UserId: sessionState.UserID(authenticator.Name()),
IssuedAt: timestamppb.Now(),
AccessedAt: timestamppb.Now(),
ExpiresAt: sessionExpiry,
IdToken: &session.IDToken{
Issuer: sessionState.Issuer, // todo(bdd): the issuer is not authN but the downstream IdP from the claims
Subject: sessionState.Subject,
ExpiresAt: sessionExpiry,
IssuedAt: idTokenIssuedAt,
},
OauthToken: manager.ToOAuthToken(accessToken),
Audience: sessionState.Audience,
}
s.SetRawIDToken(claims.RawIDToken)
s.AddClaims(claims.Flatten())
var managerUser manager.User
managerUser.User, _ = user.Get(ctx, state.dataBrokerClient, s.GetUserId())
if managerUser.User == nil {
// if no user exists yet, create a new one
managerUser.User = &user.User{
Id: s.GetUserId(),
}
}
err = authenticator.UpdateUserInfo(ctx, accessToken, &managerUser)
if err != nil {
return fmt.Errorf("authenticate: error retrieving user info: %w", err)
}
_, err = databroker.Put(ctx, state.dataBrokerClient, managerUser.User)
if err != nil {
return fmt.Errorf("authenticate: error saving user: %w", err)
}
res, err := session.Put(ctx, state.dataBrokerClient, s)
if err != nil {
return fmt.Errorf("authenticate: error saving session: %w", err)
}
sessionState.DatabrokerServerVersion = res.GetServerVersion()
sessionState.DatabrokerRecordVersion = res.GetRecord().GetVersion()
return nil
}
// revokeSession always clears the local session and tries to revoke the associated session stored in the // revokeSession always clears the local session and tries to revoke the associated session stored in the
// databroker. If successful, it returns the original `id_token` of the session, if failed, returns // databroker. If successful, it returns the original `id_token` of the session, if failed, returns
// and empty string. // and empty string.

View file

@ -23,7 +23,6 @@ import (
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/atomicutil" "github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/encoding/jws" "github.com/pomerium/pomerium/internal/encoding/jws"
"github.com/pomerium/pomerium/internal/encoding/mock" "github.com/pomerium/pomerium/internal/encoding/mock"
"github.com/pomerium/pomerium/internal/handlers/webauthn" "github.com/pomerium/pomerium/internal/handlers/webauthn"
@ -108,87 +107,6 @@ func TestAuthenticate_Handler(t *testing.T) {
} }
} }
func TestAuthenticate_SignIn(t *testing.T) {
t.Parallel()
tests := []struct {
name string
scheme string
host string
qp map[string]string
session sessions.SessionStore
provider identity.MockProvider
encoder encoding.MarshalUnmarshaler
wantCode int
}{
{"good", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"good alternate port", "https", "corp.example.example:8443", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"session not valid", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"bad redirect uri query", "", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "^^^"}, &mstore.Store{Session: &sessions.State{}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
{"bad marshal", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
{"session error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{LoadError: errors.New("error")}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
{"good with different programmatic redirect", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &mstore.Store{Session: &sessions.State{}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"encrypted encoder error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &mstore.Store{Session: &sessions.State{}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
{"good with callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "https://some.example/", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"bad callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "^", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
{"good programmatic request", "https", "corp.example.example", map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
sharedCipher, _ := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
a := &Authenticate{
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) {
return tt.provider, nil
})),
state: atomicutil.NewValue(&authenticateState{
sharedCipher: sharedCipher,
sessionStore: tt.session,
redirectURL: uriParseHelper("https://some.example"),
sharedEncoder: tt.encoder,
dataBrokerClient: mockDataBrokerServiceClient{
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
return &databroker.GetResponse{
Record: databroker.NewRecord(&session.Session{
Id: "SESSION_ID",
}),
}, nil
},
},
}),
options: config.NewAtomicOptions(),
}
a.options.Store(&config.Options{SharedKey: base64.StdEncoding.EncodeToString(cryptutil.NewKey())})
uri := &url.URL{Scheme: tt.scheme, Host: tt.host}
queryString := uri.Query()
for k, v := range tt.qp {
queryString.Set(k, v)
}
uri.RawQuery = queryString.Encode()
r := httptest.NewRequest(http.MethodGet, uri.String(), nil)
r.Header.Set("Accept", "application/json")
state, err := tt.session.LoadSession(r)
ctx := r.Context()
ctx = sessions.NewContext(ctx, state, err)
r = r.WithContext(ctx)
w := httptest.NewRecorder()
httputil.HandlerFunc(a.SignIn).ServeHTTP(w, r)
if status := w.Code; status != tt.wantCode {
t.Errorf("handler returned wrong status code: got %v want %v %s", status, tt.wantCode, uri)
t.Errorf("\n%+v", w.Body)
}
})
}
}
func uriParseHelper(s string) *url.URL { func uriParseHelper(s string) *url.URL {
uri, _ := url.Parse(s) uri, _ := url.Parse(s)
return uri return uri
@ -475,14 +393,6 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
wantStatus int wantStatus int
}{ }{
{
"good",
nil,
&mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
nil,
identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
http.StatusOK,
},
{ {
"invalid session", "invalid session",
nil, nil,
@ -491,14 +401,6 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
identity.MockProvider{}, identity.MockProvider{},
http.StatusFound, http.StatusFound,
}, },
{
"good refresh expired",
nil,
&mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
nil,
identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
http.StatusOK,
},
{ {
"expired,refresh error", "expired,refresh error",
nil, nil,

View file

@ -0,0 +1,104 @@
package authenticate
import (
"context"
"crypto/cipher"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"golang.org/x/oauth2"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/structpb"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
identitypb "github.com/pomerium/pomerium/pkg/grpc/identity"
)
var cookieChunker = httputil.NewCookieChunker()
func (a *Authenticate) buildIdentityProfile(
ctx context.Context,
r *http.Request,
sessionState *sessions.State,
claims identity.SessionClaims,
oauthToken *oauth2.Token,
) (*identitypb.Profile, error) {
options := a.options.Load()
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return nil, fmt.Errorf("authenticate: error getting identity provider for id: %w", err)
}
authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
if err != nil {
return nil, fmt.Errorf("authenticate: error getting identity provider authenticator: %w", err)
}
err = authenticator.UpdateUserInfo(ctx, oauthToken, &claims)
if err != nil {
return nil, fmt.Errorf("authenticate: error retrieving user info: %w", err)
}
rawIDToken := []byte(claims.RawIDToken)
rawOAuthToken, err := json.Marshal(oauthToken)
if err != nil {
return nil, fmt.Errorf("authenticate: error marshaling oauth token: %w", err)
}
rawClaims, err := structpb.NewStruct(claims.Claims)
if err != nil {
return nil, fmt.Errorf("authenticate: error creating claims struct: %w", err)
}
return &identitypb.Profile{
ProviderId: idp.GetId(),
IdToken: rawIDToken,
OauthToken: rawOAuthToken,
Claims: rawClaims,
}, nil
}
func loadIdentityProfile(r *http.Request, aead cipher.AEAD) (*identitypb.Profile, error) {
cookie, err := cookieChunker.LoadCookie(r, urlutil.QueryIdentityProfile)
if err != nil {
return nil, fmt.Errorf("authenticate: error loading identity profile cookie: %w", err)
}
encrypted, err := base64.RawURLEncoding.DecodeString(cookie.Value)
if err != nil {
return nil, fmt.Errorf("authenticate: error decoding identity profile cookie: %w", err)
}
decrypted, err := cryptutil.Decrypt(aead, encrypted, nil)
if err != nil {
return nil, fmt.Errorf("authenticate: error decrypting identity profile cookie: %w", err)
}
var profile identitypb.Profile
err = protojson.Unmarshal(decrypted, &profile)
if err != nil {
return nil, fmt.Errorf("authenticate: error unmarshaling identity profile cookie: %w", err)
}
return &profile, nil
}
func storeIdentityProfile(w http.ResponseWriter, aead cipher.AEAD, profile *identitypb.Profile) {
decrypted, err := protojson.Marshal(profile)
if err != nil {
// this shouldn't happen
panic(fmt.Errorf("error marshaling message: %w", err))
}
encrypted := cryptutil.Encrypt(aead, decrypted, nil)
err = cookieChunker.SetCookie(w, &http.Cookie{
Name: urlutil.QueryIdentityProfile,
Value: base64.RawURLEncoding.EncodeToString(encrypted),
Path: "/",
})
log.Error(context.Background()).Err(err).Send()
}

View file

@ -18,6 +18,7 @@ import (
"github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc" "github.com/pomerium/pomerium/pkg/grpc"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/hpke"
) )
var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn) var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn)
@ -39,7 +40,8 @@ type authenticateState struct {
sessionStore sessions.SessionStore sessionStore sessions.SessionStore
// sessionLoaders are a collection of session loaders to attempt to pull // sessionLoaders are a collection of session loaders to attempt to pull
// a user's session state from // a user's session state from
sessionLoader sessions.SessionLoader sessionLoader sessions.SessionLoader
hpkePrivateKey *hpke.PrivateKey
jwk *jose.JSONWebKeySet jwk *jose.JSONWebKeySet
@ -137,6 +139,8 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err
return nil, err return nil, err
} }
state.hpkePrivateKey = hpke.DerivePrivateKey(sharedKey)
dataBrokerConn, err := outboundGRPCConnection.Get(context.Background(), &grpc.OutboundOptions{ dataBrokerConn, err := outboundGRPCConnection.Get(context.Background(), &grpc.OutboundOptions{
OutboundPort: cfg.OutboundPort, OutboundPort: cfg.OutboundPort,
InstallationID: cfg.Options.InstallationID, InstallationID: cfg.Options.InstallationID,

View file

@ -18,6 +18,7 @@ import (
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"github.com/pomerium/pomerium/authorize/evaluator" "github.com/pomerium/pomerium/authorize/evaluator"
"github.com/pomerium/pomerium/internal/handlers"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/requestid" "github.com/pomerium/pomerium/internal/telemetry/requestid"
@ -174,34 +175,42 @@ func (a *Authorize) requireLoginResponse(
in *envoy_service_auth_v3.CheckRequest, in *envoy_service_auth_v3.CheckRequest,
request *evaluator.Request, request *evaluator.Request,
) (*envoy_service_auth_v3.CheckResponse, error) { ) (*envoy_service_auth_v3.CheckResponse, error) {
opts := a.currentOptions.Load() options := a.currentOptions.Load()
state := a.state.Load() state := a.state.Load()
authenticateURL, err := opts.GetAuthenticateURL()
if err != nil {
return nil, err
}
if !a.shouldRedirect(in) { if !a.shouldRedirect(in) {
return a.deniedResponse(ctx, in, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), nil) return a.deniedResponse(ctx, in, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), nil)
} }
signinURL := authenticateURL.ResolveReference(&url.URL{ authenticateURL, err := options.GetAuthenticateURL()
Path: "/.pomerium/sign_in", if err != nil {
}) return nil, err
q := signinURL.Query() }
idp, err := options.GetIdentityProviderForPolicy(request.Policy)
if err != nil {
return nil, err
}
authenticateHPKEPublicKey, err := state.authenticateKeyFetcher.FetchPublicKey(ctx)
if err != nil {
return nil, err
}
// always assume https scheme // always assume https scheme
checkRequestURL := getCheckRequestURL(in) checkRequestURL := getCheckRequestURL(in)
checkRequestURL.Scheme = "https" checkRequestURL.Scheme = "https"
q.Set(urlutil.QueryRedirectURI, checkRequestURL.String()) redirectTo, err := handlers.BuildSignInURL(
idp, err := opts.GetIdentityProviderForPolicy(request.Policy) state.hpkePrivateKey,
authenticateHPKEPublicKey,
authenticateURL,
&checkRequestURL,
idp.GetId(),
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
q.Set(urlutil.QueryIdentityProviderID, idp.GetId())
signinURL.RawQuery = q.Encode()
redirectTo := urlutil.NewSignedURL(state.sharedKey, signinURL).String()
return a.deniedResponse(ctx, in, http.StatusFound, "Login", map[string]string{ return a.deniedResponse(ctx, in, http.StatusFound, "Login", map[string]string{
"Location": redirectTo, "Location": redirectTo,

View file

@ -3,6 +3,7 @@ package authorize
import ( import (
"context" "context"
"net/http" "net/http"
"net/http/httptest"
"net/url" "net/url"
"testing" "testing"
@ -19,15 +20,23 @@ import (
"github.com/pomerium/pomerium/authorize/internal/store" "github.com/pomerium/pomerium/authorize/internal/store"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/atomicutil" "github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/internal/handlers"
"github.com/pomerium/pomerium/internal/testutil" "github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/policy/criteria" "github.com/pomerium/pomerium/pkg/policy/criteria"
) )
func TestAuthorize_handleResult(t *testing.T) { func TestAuthorize_handleResult(t *testing.T) {
opt := config.NewDefaultOptions() opt := config.NewDefaultOptions()
opt.AuthenticateURLString = "https://authenticate.example.com"
opt.DataBrokerURLString = "https://databroker.example.com" opt.DataBrokerURLString = "https://databroker.example.com"
opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM=" opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM="
htpkePrivateKey, err := opt.GetHPKEPrivateKey()
require.NoError(t, err)
authnSrv := httptest.NewServer(handlers.JWKSHandler(opt.SigningKey, htpkePrivateKey.PublicKey()))
t.Cleanup(authnSrv.Close)
opt.AuthenticateURLString = authnSrv.URL
a, err := New(&config.Config{Options: opt}) a, err := New(&config.Config{Options: opt})
require.NoError(t, err) require.NoError(t, err)
@ -179,10 +188,20 @@ func mustParseWeightedURLs(t *testing.T, urls ...string) []config.WeightedURL {
} }
func TestRequireLogin(t *testing.T) { func TestRequireLogin(t *testing.T) {
t.Parallel()
opt := config.NewDefaultOptions() opt := config.NewDefaultOptions()
opt.AuthenticateURLString = "https://authenticate.example.com"
opt.DataBrokerURLString = "https://databroker.example.com" opt.DataBrokerURLString = "https://databroker.example.com"
opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM=" opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM="
opt.SigningKey = "LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSUJlMFRxbXJkSXBZWE03c3pSRERWYndXOS83RWJHVWhTdFFJalhsVHNXM1BvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFb0xaRDI2bEdYREhRQmhhZkdlbEVmRDdlNmYzaURjWVJPVjdUbFlIdHF1Y1BFL2hId2dmYQpNY3FBUEZsRmpueUpySXJhYTFlQ2xZRTJ6UktTQk5kNXBRPT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQo="
htpkePrivateKey, err := opt.GetHPKEPrivateKey()
require.NoError(t, err)
authnSrv := httptest.NewServer(handlers.JWKSHandler(opt.SigningKey, htpkePrivateKey.PublicKey()))
t.Cleanup(authnSrv.Close)
opt.AuthenticateURLString = authnSrv.URL
a, err := New(&config.Config{Options: opt}) a, err := New(&config.Config{Options: opt})
require.NoError(t, err) require.NoError(t, err)

View file

@ -3,6 +3,7 @@ package authorize
import ( import (
"context" "context"
"fmt" "fmt"
"net/url"
googlegrpc "google.golang.org/grpc" googlegrpc "google.golang.org/grpc"
@ -11,6 +12,7 @@ import (
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/pkg/grpc" "github.com/pomerium/pomerium/pkg/grpc"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/hpke"
"github.com/pomerium/pomerium/pkg/protoutil" "github.com/pomerium/pomerium/pkg/protoutil"
) )
@ -23,6 +25,8 @@ type authorizeState struct {
dataBrokerClient databroker.DataBrokerServiceClient dataBrokerClient databroker.DataBrokerServiceClient
auditEncryptor *protoutil.Encryptor auditEncryptor *protoutil.Encryptor
sessionStore *config.SessionStore sessionStore *config.SessionStore
hpkePrivateKey *hpke.PrivateKey
authenticateKeyFetcher hpke.KeyFetcher
} }
func newAuthorizeStateFromConfig(cfg *config.Config, store *store.Store) (*authorizeState, error) { func newAuthorizeStateFromConfig(cfg *config.Config, store *store.Store) (*authorizeState, error) {
@ -74,5 +78,15 @@ func newAuthorizeStateFromConfig(cfg *config.Config, store *store.Store) (*autho
return nil, fmt.Errorf("authorize: invalid session store: %w", err) return nil, fmt.Errorf("authorize: invalid session store: %w", err)
} }
authenticateURL, err := cfg.Options.GetAuthenticateURL()
if err != nil {
return nil, fmt.Errorf("authorize: invalid authenticate service url: %w", err)
}
state.hpkePrivateKey = hpke.DerivePrivateKey(sharedKey)
state.authenticateKeyFetcher = hpke.NewKeyFetcher(authenticateURL.ResolveReference(&url.URL{
Path: "/.well-known/pomerium/jwks.json",
}).String())
return state, nil return state, nil
} }

View file

@ -25,7 +25,7 @@ func TestAuthorization(t *testing.T) {
} }
t.Run("public", func(t *testing.T) { t.Run("public", func(t *testing.T) {
client := getClient() client := getClient(t)
req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io", nil) req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io", nil)
if err != nil { if err != nil {
@ -43,7 +43,7 @@ func TestAuthorization(t *testing.T) {
t.Run("domains", func(t *testing.T) { t.Run("domains", func(t *testing.T) {
t.Run("allowed", func(t *testing.T) { t.Run("allowed", func(t *testing.T) {
client := getClient() client := getClient(t)
res, err := flows.Authenticate(ctx, client, mustParseURL("https://httpdetails.localhost.pomerium.io/by-domain"), res, err := flows.Authenticate(ctx, client, mustParseURL("https://httpdetails.localhost.pomerium.io/by-domain"),
withAPI, flows.WithEmail("user1@dogs.test")) withAPI, flows.WithEmail("user1@dogs.test"))
if assert.NoError(t, err) { if assert.NoError(t, err) {
@ -51,7 +51,7 @@ func TestAuthorization(t *testing.T) {
} }
}) })
t.Run("not allowed", func(t *testing.T) { t.Run("not allowed", func(t *testing.T) {
client := getClient() client := getClient(t)
res, err := flows.Authenticate(ctx, client, mustParseURL("https://httpdetails.localhost.pomerium.io/by-domain"), res, err := flows.Authenticate(ctx, client, mustParseURL("https://httpdetails.localhost.pomerium.io/by-domain"),
withAPI, flows.WithEmail("user1@cats.test")) withAPI, flows.WithEmail("user1@cats.test"))
if assert.NoError(t, err) { if assert.NoError(t, err) {

View file

@ -12,7 +12,7 @@ import (
func BenchmarkLoggedInUserAccess(b *testing.B) { func BenchmarkLoggedInUserAccess(b *testing.B) {
ctx := context.Background() ctx := context.Background()
client := getClient() client := getClient(b)
res, err := flows.Authenticate(ctx, client, mustParseURL("https://httpdetails.localhost.pomerium.io/by-domain"), res, err := flows.Authenticate(ctx, client, mustParseURL("https://httpdetails.localhost.pomerium.io/by-domain"),
flows.WithEmail("user1@dogs.test")) flows.WithEmail("user1@dogs.test"))
require.NoError(b, err) require.NoError(b, err)
@ -30,7 +30,7 @@ func BenchmarkLoggedInUserAccess(b *testing.B) {
func BenchmarkLoggedOutUserAccess(b *testing.B) { func BenchmarkLoggedOutUserAccess(b *testing.B) {
ctx := context.Background() ctx := context.Background()
client := getClient() client := getClient(b)
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {

View file

@ -21,7 +21,7 @@ func TestDashboard(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
res, err := getClient().Do(req) res, err := getClient(t).Do(req)
if !assert.NoError(t, err, "unexpected http error") { if !assert.NoError(t, err, "unexpected http error") {
return return
} }
@ -37,7 +37,7 @@ func TestDashboard(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
res, err := getClient().Do(req) res, err := getClient(t).Do(req)
if !assert.NoError(t, err, "unexpected http error") { if !assert.NoError(t, err, "unexpected http error") {
return return
} }
@ -69,7 +69,7 @@ func TestHealth(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
res, err := getClient().Do(req) res, err := getClient(t).Do(req)
if !assert.NoError(t, err, "unexpected http error") { if !assert.NoError(t, err, "unexpected http error") {
return return
} }

View file

@ -50,10 +50,21 @@ func TestMain(m *testing.M) {
os.Exit(status) os.Exit(status)
} }
func getClient() *http.Client { type loggingRoundTripper struct {
jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) t testing.TB
if err != nil { transport http.RoundTripper
panic(err) }
func (l loggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
if l.t != nil {
l.t.Logf("%s %s", req.Method, req.URL.String())
}
return l.transport.RoundTrip(req)
}
func getTransport(t testing.TB) http.RoundTripper {
if t != nil {
t.Helper()
} }
rootCAs, err := x509.SystemCertPool() rootCAs, err := x509.SystemCertPool()
@ -66,23 +77,36 @@ func getClient() *http.Client {
panic(err) panic(err)
} }
_ = rootCAs.AppendCertsFromPEM(bs) _ = rootCAs.AppendCertsFromPEM(bs)
transport := &http.Transport{
DisableKeepAlives: true,
TLSClientConfig: &tls.Config{
RootCAs: rootCAs,
},
}
return loggingRoundTripper{t, transport}
}
func getClient(t testing.TB) *http.Client {
if t != nil {
t.Helper()
}
jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
if err != nil {
panic(err)
}
return &http.Client{ return &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error { CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse return http.ErrUseLastResponse
}, },
Transport: &http.Transport{ Transport: getTransport(t),
DisableKeepAlives: true, Jar: jar,
TLSClientConfig: &tls.Config{
RootCAs: rootCAs,
},
},
Jar: jar,
} }
} }
func waitForHealthy(ctx context.Context) error { func waitForHealthy(ctx context.Context) error {
client := getClient() client := getClient(nil)
check := func(endpoint string) error { check := func(endpoint string) error {
reqCtx, clearTimeout := context.WithTimeout(ctx, time.Second) reqCtx, clearTimeout := context.WithTimeout(ctx, time.Second)
defer clearTimeout() defer clearTimeout()

View file

@ -31,7 +31,7 @@ func TestQueryStringParams(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
res, err := getClient().Do(req) res, err := getClient(t).Do(req)
if !assert.NoError(t, err, "unexpected http error") { if !assert.NoError(t, err, "unexpected http error") {
return return
} }
@ -65,7 +65,7 @@ func TestCORS(t *testing.T) {
req.Header.Set("Access-Control-Request-Method", "GET") req.Header.Set("Access-Control-Request-Method", "GET")
req.Header.Set("Origin", "https://httpdetails.localhost.pomerium.io") req.Header.Set("Origin", "https://httpdetails.localhost.pomerium.io")
res, err := getClient().Do(req) res, err := getClient(t).Do(req)
if !assert.NoError(t, err, "unexpected http error") { if !assert.NoError(t, err, "unexpected http error") {
return return
} }
@ -81,7 +81,7 @@ func TestCORS(t *testing.T) {
req.Header.Set("Access-Control-Request-Method", "GET") req.Header.Set("Access-Control-Request-Method", "GET")
req.Header.Set("Origin", "https://httpdetails.localhost.pomerium.io") req.Header.Set("Origin", "https://httpdetails.localhost.pomerium.io")
res, err := getClient().Do(req) res, err := getClient(t).Do(req)
if !assert.NoError(t, err, "unexpected http error") { if !assert.NoError(t, err, "unexpected http error") {
return return
} }
@ -102,7 +102,7 @@ func TestPreserveHostHeader(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
res, err := getClient().Do(req) res, err := getClient(t).Do(req)
if !assert.NoError(t, err, "unexpected http error") { if !assert.NoError(t, err, "unexpected http error") {
return return
} }
@ -127,7 +127,7 @@ func TestPreserveHostHeader(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
res, err := getClient().Do(req) res, err := getClient(t).Do(req)
if !assert.NoError(t, err, "unexpected http error") { if !assert.NoError(t, err, "unexpected http error") {
return return
} }
@ -158,7 +158,7 @@ func TestSetRequestHeaders(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
res, err := getClient().Do(req) res, err := getClient(t).Do(req)
if !assert.NoError(t, err, "unexpected http error") { if !assert.NoError(t, err, "unexpected http error") {
return return
} }
@ -187,7 +187,7 @@ func TestRemoveRequestHeaders(t *testing.T) {
} }
req.Header.Add("X-Custom-Request-Header-To-Remove", "foo") req.Header.Add("X-Custom-Request-Header-To-Remove", "foo")
res, err := getClient().Do(req) res, err := getClient(t).Do(req)
if !assert.NoError(t, err, "unexpected http error") { if !assert.NoError(t, err, "unexpected http error") {
return return
} }
@ -250,7 +250,7 @@ func TestGoogleCloudRun(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
res, err := getClient().Do(req) res, err := getClient(t).Do(req)
if !assert.NoError(t, err, "unexpected http error") { if !assert.NoError(t, err, "unexpected http error") {
return return
} }
@ -274,7 +274,7 @@ func TestLoadBalancer(t *testing.T) {
defer clearTimeout() defer clearTimeout()
getDistribution := func(t *testing.T, path string) map[string]float64 { getDistribution := func(t *testing.T, path string) map[string]float64 {
client := getClient() client := getClient(t)
distribution := map[string]float64{} distribution := map[string]float64{}
res, err := flows.Authenticate(ctx, client, res, err := flows.Authenticate(ctx, client,

View file

@ -0,0 +1,89 @@
package handlers
import (
"fmt"
"net/url"
"time"
"google.golang.org/protobuf/encoding/protojson"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/grpc/identity"
"github.com/pomerium/pomerium/pkg/hpke"
)
const signInExpiry = time.Minute * 5
// BuildCallbackURL builds the callback URL using an HPKE encrypted query string.
func BuildCallbackURL(
authenticatePrivateKey *hpke.PrivateKey,
proxyPublicKey *hpke.PublicKey,
requestParams url.Values,
profile *identity.Profile,
) (string, error) {
redirectURL, err := urlutil.ParseAndValidateURL(requestParams.Get(urlutil.QueryRedirectURI))
if err != nil {
return "", fmt.Errorf("invalid %s: %w", urlutil.QueryRedirectURI, err)
}
var callbackURL *url.URL
if requestParams.Has(urlutil.QueryCallbackURI) {
callbackURL, err = urlutil.ParseAndValidateURL(requestParams.Get(urlutil.QueryCallbackURI))
if err != nil {
return "", fmt.Errorf("invalid %s: %w", urlutil.QueryCallbackURI, err)
}
} else {
callbackURL, err = urlutil.DeepCopy(redirectURL)
if err != nil {
return "", fmt.Errorf("error copying %s: %w", urlutil.QueryRedirectURI, err)
}
callbackURL.Path = "/.pomerium/callback/"
callbackURL.RawQuery = ""
}
callbackParams := callbackURL.Query()
if requestParams.Has(urlutil.QueryIsProgrammatic) {
callbackParams.Set(urlutil.QueryIsProgrammatic, "true")
}
callbackParams.Set(urlutil.QueryRedirectURI, redirectURL.String())
rawProfile, err := protojson.Marshal(profile)
if err != nil {
return "", fmt.Errorf("error marshaling identity profile: %w", err)
}
callbackParams.Set(urlutil.QueryIdentityProfile, string(rawProfile))
urlutil.BuildTimeParameters(callbackParams, signInExpiry)
callbackParams, err = hpke.EncryptURLValues(authenticatePrivateKey, proxyPublicKey, callbackParams)
if err != nil {
return "", fmt.Errorf("error encrypting callback params: %w", err)
}
callbackURL.RawQuery = callbackParams.Encode()
return callbackURL.String(), nil
}
// BuildSignInURL buidls the sign in URL using an HPKE encrypted query string.
func BuildSignInURL(
senderPrivateKey *hpke.PrivateKey,
authenticatePublicKey *hpke.PublicKey,
authenticateURL *url.URL,
redirectURL *url.URL,
idpID string,
) (string, error) {
signInURL := *authenticateURL
signInURL.Path = "/.pomerium/sign_in"
q := signInURL.Query()
q.Set(urlutil.QueryRedirectURI, redirectURL.String())
q.Set(urlutil.QueryIdentityProviderID, idpID)
urlutil.BuildTimeParameters(q, signInExpiry)
q, err := hpke.EncryptURLValues(senderPrivateKey, authenticatePublicKey, q)
if err != nil {
return "", err
}
signInURL.RawQuery = q.Encode()
return signInURL.String(), nil
}

View file

@ -8,6 +8,7 @@ import (
"github.com/pomerium/datasource/pkg/directory" "github.com/pomerium/datasource/pkg/directory"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/pkg/grpc/identity"
"github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user" "github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/ui" "github.com/pomerium/pomerium/ui"
@ -20,6 +21,7 @@ type UserInfoData struct {
IsImpersonated bool IsImpersonated bool
Session *session.Session Session *session.Session
User *user.User User *user.User
Profile *identity.Profile
IsEnterprise bool IsEnterprise bool
DirectoryUser *directory.User DirectoryUser *directory.User
@ -43,6 +45,9 @@ func (data UserInfoData) ToJSON() map[string]any {
if bs, err := protojson.Marshal(data.User); err == nil { if bs, err := protojson.Marshal(data.User); err == nil {
m["user"] = json.RawMessage(bs) m["user"] = json.RawMessage(bs)
} }
if bs, err := protojson.Marshal(data.Profile); err == nil {
m["profile"] = json.RawMessage(bs)
}
m["isEnterprise"] = data.IsEnterprise m["isEnterprise"] = data.IsEnterprise
if data.DirectoryUser != nil { if data.DirectoryUser != nil {
m["directoryUser"] = data.DirectoryUser m["directoryUser"] = data.DirectoryUser

View file

@ -60,7 +60,7 @@ func (s *State) WithNewIssuer(issuer string, audience []string) State {
} }
// UserID returns the corresponding user ID for a session. // UserID returns the corresponding user ID for a session.
func (s *State) UserID(provider string) string { func (s *State) UserID() string {
if s.OID != "" { if s.OID != "" {
return s.OID return s.OID
} }

View file

@ -1,4 +1,4 @@
package hpke package hpke_test
import ( import (
"context" "context"
@ -11,6 +11,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/internal/handlers" "github.com/pomerium/pomerium/internal/handlers"
"github.com/pomerium/pomerium/pkg/hpke"
) )
func TestFetchPublicKeyFromJWKS(t *testing.T) { func TestFetchPublicKeyFromJWKS(t *testing.T) {
@ -19,7 +20,7 @@ func TestFetchPublicKeyFromJWKS(t *testing.T) {
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
t.Cleanup(clearTimeout) t.Cleanup(clearTimeout)
hpkePrivateKey, err := GeneratePrivateKey() hpkePrivateKey, err := hpke.GeneratePrivateKey()
require.NoError(t, err) require.NoError(t, err)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -27,7 +28,7 @@ func TestFetchPublicKeyFromJWKS(t *testing.T) {
})) }))
t.Cleanup(srv.Close) t.Cleanup(srv.Close)
publicKey, err := FetchPublicKeyFromJWKS(ctx, http.DefaultClient, srv.URL) publicKey, err := hpke.FetchPublicKeyFromJWKS(ctx, http.DefaultClient, srv.URL)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, hpkePrivateKey.PublicKey().String(), publicKey.String()) assert.Equal(t, hpkePrivateKey.PublicKey().String(), publicKey.String())
} }

View file

@ -1,7 +1,7 @@
package proxy package proxy
import ( import (
"encoding/base64" "context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -9,12 +9,17 @@ import (
"net/url" "net/url"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"google.golang.org/protobuf/encoding/protojson"
"github.com/pomerium/pomerium/internal/handlers" "github.com/pomerium/pomerium/internal/handlers"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/middleware" "github.com/pomerium/pomerium/internal/middleware"
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/identity"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/hpke"
) )
// registerDashboardHandlers returns the proxy service's ServeMux // registerDashboardHandlers returns the proxy service's ServeMux
@ -32,9 +37,6 @@ func (p *Proxy) registerDashboardHandlers(r *mux.Router) *mux.Router {
// called following authenticate auth flow to grab a new or existing session // called following authenticate auth flow to grab a new or existing session
// the route specific cookie is returned in a signed query params // the route specific cookie is returned in a signed query params
c := r.PathPrefix(dashboardPath + "/callback").Subrouter() c := r.PathPrefix(dashboardPath + "/callback").Subrouter()
c.Use(func(h http.Handler) http.Handler {
return middleware.ValidateSignature(p.state.Load().sharedKey)(h)
})
c.Path("/").Handler(httputil.HandlerFunc(p.Callback)).Methods(http.MethodGet) c.Path("/").Handler(httputil.HandlerFunc(p.Callback)).Methods(http.MethodGet)
// Programmatic API handlers and middleware // Programmatic API handlers and middleware
@ -105,55 +107,96 @@ func (p *Proxy) deviceEnrolled(w http.ResponseWriter, r *http.Request) error {
// Callback handles the result of a successful call to the authenticate service // Callback handles the result of a successful call to the authenticate service
// and is responsible setting per-route sessions. // and is responsible setting per-route sessions.
func (p *Proxy) Callback(w http.ResponseWriter, r *http.Request) error { func (p *Proxy) Callback(w http.ResponseWriter, r *http.Request) error {
redirectURLString := r.FormValue(urlutil.QueryRedirectURI) state := p.state.Load()
encryptedSession := r.FormValue(urlutil.QuerySessionEncrypted) options := p.currentOptions.Load()
redirectURL, err := urlutil.ParseAndValidateURL(redirectURLString) if err := r.ParseForm(); err != nil {
return httputil.NewError(http.StatusBadRequest, err)
}
// decrypt the URL values
senderPublicKey, values, err := hpke.DecryptURLValues(state.hpkePrivateKey, r.Form)
if err != nil {
return httputil.NewError(http.StatusBadRequest, fmt.Errorf("invalid encrypted query string: %w", err))
}
// confirm this request came from the authenticate service
err = p.validateSenderPublicKey(r.Context(), senderPublicKey)
if err != nil {
return err
}
// validate that the request has not expired
err = urlutil.ValidateTimeParameters(values)
if err != nil { if err != nil {
return httputil.NewError(http.StatusBadRequest, err) return httputil.NewError(http.StatusBadRequest, err)
} }
rawJWT, err := p.saveCallbackSession(w, r, encryptedSession) profile, err := getProfileFromValues(values)
if err != nil { if err != nil {
return httputil.NewError(http.StatusBadRequest, err) return err
}
ss := newSessionStateFromProfile(profile)
s, err := session.Get(r.Context(), state.dataBrokerClient, ss.ID)
if err != nil {
s = &session.Session{Id: ss.ID}
}
populateSessionFromProfile(s, profile, ss, options.CookieExpire)
u, err := user.Get(r.Context(), state.dataBrokerClient, ss.UserID())
if err != nil {
u = &user.User{Id: ss.UserID()}
}
populateUserFromProfile(u, profile, ss)
redirectURI, err := getRedirectURIFromValues(values)
if err != nil {
return err
}
// save the records
res, err := state.dataBrokerClient.Put(r.Context(), &databroker.PutRequest{
Records: []*databroker.Record{
databroker.NewRecord(s),
databroker.NewRecord(u),
},
})
if err != nil {
return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("proxy: error saving databroker records: %w", err))
}
ss.DatabrokerServerVersion = res.GetServerVersion()
for _, record := range res.GetRecords() {
if record.GetVersion() > ss.DatabrokerRecordVersion {
ss.DatabrokerRecordVersion = record.GetVersion()
}
}
// save the session state
rawJWT, err := state.encoder.Marshal(ss)
if err != nil {
return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("proxy: error marshaling session state: %w", err))
}
if err = state.sessionStore.SaveSession(w, r, rawJWT); err != nil {
return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("proxy: error saving session state: %w", err))
} }
// if programmatic, encode the session jwt as a query param // if programmatic, encode the session jwt as a query param
if isProgrammatic := r.FormValue(urlutil.QueryIsProgrammatic); isProgrammatic == "true" { if isProgrammatic := values.Get(urlutil.QueryIsProgrammatic); isProgrammatic == "true" {
q := redirectURL.Query() q := redirectURI.Query()
q.Set(urlutil.QueryPomeriumJWT, string(rawJWT)) q.Set(urlutil.QueryPomeriumJWT, string(rawJWT))
redirectURL.RawQuery = q.Encode() redirectURI.RawQuery = q.Encode()
} }
httputil.Redirect(w, r, redirectURL.String(), http.StatusFound)
// redirect
httputil.Redirect(w, r, redirectURI.String(), http.StatusFound)
return nil return nil
} }
// saveCallbackSession takes an encrypted per-route session token, decrypts
// it using the shared service key, then stores it the local session store.
func (p *Proxy) saveCallbackSession(w http.ResponseWriter, r *http.Request, enctoken string) ([]byte, error) {
state := p.state.Load()
// 1. extract the base64 encoded and encrypted JWT from query params
encryptedJWT, err := base64.URLEncoding.DecodeString(enctoken)
if err != nil {
return nil, fmt.Errorf("proxy: malfromed callback token: %w", err)
}
// 2. decrypt the JWT using the cipher using the _shared_ secret key
rawJWT, err := cryptutil.Decrypt(state.sharedCipher, encryptedJWT, nil)
if err != nil {
return nil, fmt.Errorf("proxy: callback token decrypt error: %w", err)
}
// 3. Save the decrypted JWT to the session store directly as a string, without resigning
if err = state.sessionStore.SaveSession(w, r, rawJWT); err != nil {
return nil, fmt.Errorf("proxy: callback session save failure: %w", err)
}
return rawJWT, nil
}
// ProgrammaticLogin returns a signed url that can be used to login // ProgrammaticLogin returns a signed url that can be used to login
// using the authenticate service. // using the authenticate service.
func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) error { func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) error {
state := p.state.Load() state := p.state.Load()
options := p.currentOptions.Load()
redirectURI, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI)) redirectURI, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI))
if err != nil { if err != nil {
@ -164,19 +207,32 @@ func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) error
return httputil.NewError(http.StatusBadRequest, errors.New("invalid redirect uri")) return httputil.NewError(http.StatusBadRequest, errors.New("invalid redirect uri"))
} }
idp, err := options.GetIdentityProviderForRequestURL(urlutil.GetAbsoluteURL(r).String())
if err != nil {
return httputil.NewError(http.StatusInternalServerError, err)
}
hpkeAuthenticateKey, err := state.authenticateKeyFetcher.FetchPublicKey(r.Context())
if err != nil {
return httputil.NewError(http.StatusInternalServerError, err)
}
signinURL := *state.authenticateSigninURL signinURL := *state.authenticateSigninURL
callbackURI := urlutil.GetAbsoluteURL(r) callbackURI := urlutil.GetAbsoluteURL(r)
callbackURI.Path = dashboardPath + "/callback/" callbackURI.Path = dashboardPath + "/callback/"
q := signinURL.Query() q := signinURL.Query()
q.Set(urlutil.QueryCallbackURI, callbackURI.String()) q.Set(urlutil.QueryCallbackURI, callbackURI.String())
q.Set(urlutil.QueryRedirectURI, redirectURI.String())
q.Set(urlutil.QueryIsProgrammatic, "true") q.Set(urlutil.QueryIsProgrammatic, "true")
signinURL.RawQuery = q.Encode() signinURL.RawQuery = q.Encode()
response := urlutil.NewSignedURL(state.sharedKey, &signinURL).String()
rawURL, err := handlers.BuildSignInURL(state.hpkePrivateKey, hpkeAuthenticateKey, &signinURL, redirectURI, idp.GetId())
if err != nil {
return httputil.NewError(http.StatusInternalServerError, err)
}
w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
_, _ = io.WriteString(w, response) _, _ = io.WriteString(w, rawURL)
return nil return nil
} }
@ -191,3 +247,44 @@ func (p *Proxy) jwtAssertion(w http.ResponseWriter, r *http.Request) error {
_, _ = io.WriteString(w, assertionJWT) _, _ = io.WriteString(w, assertionJWT)
return nil return nil
} }
func (p *Proxy) validateSenderPublicKey(ctx context.Context, senderPublicKey *hpke.PublicKey) error {
state := p.state.Load()
authenticatePublicKey, err := state.authenticateKeyFetcher.FetchPublicKey(ctx)
if err != nil {
return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("hpke: error retrieving authenticate service public key: %w", err))
}
if !authenticatePublicKey.Equals(senderPublicKey) {
return httputil.NewError(http.StatusBadRequest, fmt.Errorf("hpke: invalid authenticate service public key"))
}
return nil
}
func getProfileFromValues(values url.Values) (*identity.Profile, error) {
rawProfile := values.Get(urlutil.QueryIdentityProfile)
if rawProfile == "" {
return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("missing %s", urlutil.QueryIdentityProfile))
}
var profile identity.Profile
err := protojson.Unmarshal([]byte(rawProfile), &profile)
if err != nil {
return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("invalid %s: %w", urlutil.QueryIdentityProfile, err))
}
return &profile, nil
}
func getRedirectURIFromValues(values url.Values) (*url.URL, error) {
rawRedirectURI := values.Get(urlutil.QueryRedirectURI)
if rawRedirectURI == "" {
return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("missing %s", urlutil.QueryRedirectURI))
}
redirectURI, err := urlutil.ParseAndValidateURL(rawRedirectURI)
if err != nil {
return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("invalid %s: %w", urlutil.QueryRedirectURI, err))
}
return redirectURI, nil
}

View file

@ -2,30 +2,20 @@ package proxy
import ( import (
"bytes" "bytes"
"context"
"errors"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"strings"
"testing" "testing"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/atomicutil" "github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/encoding/mock"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/sessions"
mstore "github.com/pomerium/pomerium/internal/sessions/mock"
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
const goodEncryptionString = "KBEjQ9rnCxaAX-GOqetGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="
func TestProxy_RobotsTxt(t *testing.T) { func TestProxy_RobotsTxt(t *testing.T) {
proxy := Proxy{} proxy := Proxy{}
req := httptest.NewRequest(http.MethodGet, "/robots.txt", nil) req := httptest.NewRequest(http.MethodGet, "/robots.txt", nil)
@ -40,29 +30,6 @@ func TestProxy_RobotsTxt(t *testing.T) {
} }
} }
func TestProxy_Signout(t *testing.T) {
opts := testOptions(t)
err := ValidateOptions(opts)
if err != nil {
t.Fatal(err)
}
proxy, err := New(&config.Config{Options: opts})
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest(http.MethodGet, "/.pomerium/sign_out", nil)
rr := httptest.NewRecorder()
proxy.SignOut(rr, req)
if status := rr.Code; status != http.StatusFound {
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusFound)
}
body := rr.Body.String()
want := proxy.state.Load().authenticateURL.String()
if !strings.Contains(body, want) {
t.Errorf("handler returned unexpected body: got %v want %s ", body, want)
}
}
func TestProxy_SignOut(t *testing.T) { func TestProxy_SignOut(t *testing.T) {
t.Parallel() t.Parallel()
tests := []struct { tests := []struct {
@ -104,165 +71,6 @@ func TestProxy_SignOut(t *testing.T) {
} }
} }
func TestProxy_Callback(t *testing.T) {
t.Parallel()
opts := testOptions(t)
tests := []struct {
name string
options *config.Options
method string
scheme string
host string
path string
headers map[string]string
qp map[string]string
cipher encoding.MarshalUnmarshaler
sessionStore sessions.SessionStore
wantStatus int
wantBody string
}{
{
"good",
opts,
http.MethodGet,
"http",
"example.com",
"/",
nil,
map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString},
&mock.Encoder{MarshalResponse: []byte("x")},
&mstore.Store{Session: &sessions.State{}},
http.StatusFound,
"",
},
{
"good programmatic",
opts,
http.MethodGet,
"http",
"example.com",
"/",
nil,
map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString},
&mock.Encoder{MarshalResponse: []byte("x")},
&mstore.Store{Session: &sessions.State{}},
http.StatusFound,
"",
},
{
"bad decrypt",
opts,
http.MethodGet,
"http",
"example.com",
"/",
nil,
map[string]string{urlutil.QuerySessionEncrypted: "KBEjQ9rnCxaAX-GOqexGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="},
&mock.Encoder{MarshalResponse: []byte("x")},
&mstore.Store{Session: &sessions.State{}},
http.StatusBadRequest,
"",
},
{
"bad save session",
opts,
http.MethodGet,
"http",
"example.com",
"/",
nil,
map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString},
&mock.Encoder{MarshalResponse: []byte("x")},
&mstore.Store{SaveError: errors.New("hi")},
http.StatusBadRequest,
"",
},
{
"bad base64",
opts,
http.MethodGet,
"http",
"example.com",
"/",
nil,
map[string]string{urlutil.QuerySessionEncrypted: "^"},
&mock.Encoder{MarshalResponse: []byte("x")},
&mstore.Store{Session: &sessions.State{}},
http.StatusBadRequest,
"",
},
{
"malformed redirect",
opts,
http.MethodGet,
"http",
"example.com",
"/",
nil,
nil,
&mock.Encoder{},
&mstore.Store{Session: &sessions.State{}},
http.StatusBadRequest,
"",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p, err := New(&config.Config{Options: tt.options})
if err != nil {
t.Fatal(err)
}
p.OnConfigChange(context.Background(), &config.Config{Options: tt.options})
state := p.state.Load()
state.encoder = tt.cipher
state.sessionStore = tt.sessionStore
redirectURI := &url.URL{Scheme: tt.scheme, Host: tt.host, Path: tt.path}
queryString := redirectURI.Query()
for k, v := range tt.qp {
queryString.Set(k, v)
}
redirectURI.RawQuery = queryString.Encode()
uri := &url.URL{Path: "/"}
if tt.qp != nil {
qu := uri.Query()
for k, v := range tt.qp {
qu.Set(k, v)
}
qu.Set(urlutil.QueryRedirectURI, redirectURI.String())
uri.RawQuery = qu.Encode()
}
r := httptest.NewRequest(tt.method, uri.String(), nil)
r.Header.Set("Accept", "application/json")
if len(tt.headers) != 0 {
for k, v := range tt.headers {
r.Header.Set(k, v)
}
}
w := httptest.NewRecorder()
httputil.HandlerFunc(p.Callback).ServeHTTP(w, r)
if status := w.Code; status != tt.wantStatus {
t.Errorf("status code: got %v want %v", status, tt.wantStatus)
t.Errorf("\n%+v", w.Body.String())
}
if tt.wantBody != "" {
body := w.Body.String()
if diff := cmp.Diff(body, tt.wantBody); diff != "" {
t.Errorf("wrong body\n%s", diff)
}
}
})
}
}
func TestProxy_ProgrammaticLogin(t *testing.T) { func TestProxy_ProgrammaticLogin(t *testing.T) {
t.Parallel() t.Parallel()
opts := testOptions(t) opts := testOptions(t)
@ -360,155 +168,6 @@ func TestProxy_ProgrammaticLogin(t *testing.T) {
} }
} }
func TestProxy_ProgrammaticCallback(t *testing.T) {
t.Parallel()
opts := testOptions(t)
tests := []struct {
name string
options *config.Options
method string
redirectURI string
headers map[string]string
qp map[string]string
cipher encoding.MarshalUnmarshaler
sessionStore sessions.SessionStore
wantStatus int
wantBody string
}{
{
"good",
opts,
http.MethodGet,
"http://pomerium.io/",
nil,
map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString},
&mock.Encoder{MarshalResponse: []byte("x")},
&mstore.Store{Session: &sessions.State{}},
http.StatusFound,
"",
},
{
"good programmatic",
opts,
http.MethodGet,
"http://pomerium.io/",
nil,
map[string]string{
urlutil.QueryIsProgrammatic: "true",
urlutil.QueryCallbackURI: "ok",
urlutil.QuerySessionEncrypted: goodEncryptionString,
},
&mock.Encoder{MarshalResponse: []byte("x")},
&mstore.Store{Session: &sessions.State{}},
http.StatusFound,
"",
},
{
"bad decrypt",
opts,
http.MethodGet,
"http://pomerium.io/",
nil,
map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString + cryptutil.NewBase64Key()},
&mock.Encoder{MarshalResponse: []byte("x")},
&mstore.Store{Session: &sessions.State{}},
http.StatusBadRequest,
"",
},
{
"bad save session",
opts,
http.MethodGet,
"http://pomerium.io/",
nil,
map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString},
&mock.Encoder{MarshalResponse: []byte("x")},
&mstore.Store{SaveError: errors.New("hi")},
http.StatusBadRequest,
"",
},
{
"bad base64",
opts,
http.MethodGet,
"http://pomerium.io/",
nil,
map[string]string{urlutil.QuerySessionEncrypted: "^"},
&mock.Encoder{MarshalResponse: []byte("x")},
&mstore.Store{Session: &sessions.State{}},
http.StatusBadRequest,
"",
},
{
"malformed redirect",
opts,
http.MethodGet,
"http://pomerium.io/",
nil,
nil,
&mock.Encoder{},
&mstore.Store{Session: &sessions.State{}},
http.StatusBadRequest,
"",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p, err := New(&config.Config{Options: tt.options})
if err != nil {
t.Fatal(err)
}
p.OnConfigChange(context.Background(), &config.Config{Options: tt.options})
state := p.state.Load()
state.encoder = tt.cipher
state.sessionStore = tt.sessionStore
redirectURI, _ := url.Parse(tt.redirectURI)
queryString := redirectURI.Query()
for k, v := range tt.qp {
queryString.Set(k, v)
}
redirectURI.RawQuery = queryString.Encode()
uri := &url.URL{Path: "/"}
if tt.qp != nil {
qu := uri.Query()
for k, v := range tt.qp {
qu.Set(k, v)
}
qu.Set(urlutil.QueryRedirectURI, redirectURI.String())
uri.RawQuery = qu.Encode()
}
r := httptest.NewRequest(tt.method, uri.String(), nil)
r.Header.Set("Accept", "application/json")
if len(tt.headers) != 0 {
for k, v := range tt.headers {
r.Header.Set(k, v)
}
}
w := httptest.NewRecorder()
httputil.HandlerFunc(p.Callback).ServeHTTP(w, r)
if status := w.Code; status != tt.wantStatus {
t.Errorf("status code: got %v want %v", status, tt.wantStatus)
t.Errorf("\n%+v", w.Body.String())
}
if tt.wantBody != "" {
body := w.Body.String()
if diff := cmp.Diff(body, tt.wantBody); diff != "" {
t.Errorf("wrong body\n%s", diff)
}
}
})
}
}
func TestProxy_jwt(t *testing.T) { func TestProxy_jwt(t *testing.T) {
// without upstream headers being set // without upstream headers being set
req, _ := http.NewRequest("GET", "https://www.example.com/.pomerium/jwt", nil) req, _ := http.NewRequest("GET", "https://www.example.com/.pomerium/jwt", nil)

79
proxy/identity_profile.go Normal file
View file

@ -0,0 +1,79 @@
package proxy
import (
"encoding/json"
"fmt"
"time"
"golang.org/x/oauth2"
"google.golang.org/protobuf/types/known/structpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/internal/identity/manager"
"github.com/pomerium/pomerium/internal/sessions"
identitypb "github.com/pomerium/pomerium/pkg/grpc/identity"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
)
func newSessionStateFromProfile(p *identitypb.Profile) *sessions.State {
claims := p.GetClaims().AsMap()
ss := sessions.NewState(p.GetProviderId())
// set the subject
if v, ok := claims["sub"]; ok {
ss.Subject = fmt.Sprint(v)
} else if v, ok := claims["user"]; ok {
ss.Subject = fmt.Sprint(v)
}
// set the oid
if v, ok := claims["oid"]; ok {
ss.OID = fmt.Sprint(v)
}
return ss
}
func populateSessionFromProfile(s *session.Session, p *identitypb.Profile, ss *sessions.State, cookieExpire time.Duration) {
claims := p.GetClaims().AsMap()
oauthToken := new(oauth2.Token)
_ = json.Unmarshal(p.GetOauthToken(), oauthToken)
s.UserId = ss.UserID()
s.IssuedAt = timestamppb.Now()
s.AccessedAt = timestamppb.Now()
s.ExpiresAt = timestamppb.New(time.Now().Add(cookieExpire))
s.IdToken = &session.IDToken{
Issuer: ss.Issuer,
Subject: ss.Subject,
ExpiresAt: timestamppb.New(time.Now().Add(cookieExpire)),
IssuedAt: timestamppb.Now(),
Raw: string(p.GetIdToken()),
}
s.OauthToken = manager.ToOAuthToken(oauthToken)
if s.Claims == nil {
s.Claims = make(map[string]*structpb.ListValue)
}
for k, vs := range identity.Claims(claims).Flatten().ToPB() {
s.Claims[k] = vs
}
}
func populateUserFromProfile(u *user.User, p *identitypb.Profile, ss *sessions.State) {
claims := p.GetClaims().AsMap()
if v, ok := claims["name"]; ok {
u.Name = fmt.Sprint(v)
}
if v, ok := claims["email"]; ok {
u.Email = fmt.Sprint(v)
}
if u.Claims == nil {
u.Claims = make(map[string]*structpb.ListValue)
}
for k, vs := range identity.Claims(claims).Flatten().ToPB() {
u.Claims[k] = vs
}
}

View file

@ -9,13 +9,15 @@ import (
"time" "time"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/handlers"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func testOptions(t *testing.T) *config.Options { func testOptions(t *testing.T) *config.Options {
t.Helper()
opts := config.NewDefaultOptions() opts := config.NewDefaultOptions()
opts.AuthenticateURLString = "https://authenticate.example"
to, err := config.ParseWeightedUrls("https://example.example") to, err := config.ParseWeightedUrls("https://example.example")
require.NoError(t, err) require.NoError(t, err)
@ -28,6 +30,13 @@ func testOptions(t *testing.T) *config.Options {
opts.SharedKey = "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=" opts.SharedKey = "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="
opts.CookieSecret = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=" opts.CookieSecret = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw="
htpkePrivateKey, err := opts.GetHPKEPrivateKey()
require.NoError(t, err)
authnSrv := httptest.NewServer(handlers.JWKSHandler(opts.SigningKey, htpkePrivateKey.PublicKey()))
t.Cleanup(authnSrv.Close)
opts.AuthenticateURLString = authnSrv.URL
require.NoError(t, opts.Validate()) require.NoError(t, opts.Validate())
return opts return opts

View file

@ -13,6 +13,7 @@ import (
"github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc" "github.com/pomerium/pomerium/pkg/grpc"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/hpke"
) )
var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn) var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn)
@ -26,10 +27,12 @@ type proxyState struct {
authenticateSigninURL *url.URL authenticateSigninURL *url.URL
authenticateRefreshURL *url.URL authenticateRefreshURL *url.URL
encoder encoding.MarshalUnmarshaler encoder encoding.MarshalUnmarshaler
cookieSecret []byte cookieSecret []byte
sessionStore sessions.SessionStore sessionStore sessions.SessionStore
jwtClaimHeaders config.JWTClaimHeaders jwtClaimHeaders config.JWTClaimHeaders
hpkePrivateKey *hpke.PrivateKey
authenticateKeyFetcher hpke.KeyFetcher
dataBrokerClient databroker.DataBrokerServiceClient dataBrokerClient databroker.DataBrokerServiceClient
@ -44,11 +47,24 @@ func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) {
state := new(proxyState) state := new(proxyState)
authenticateURL, err := cfg.Options.GetAuthenticateURL()
if err != nil {
return nil, err
}
state.sharedKey, err = cfg.Options.GetSharedKey() state.sharedKey, err = cfg.Options.GetSharedKey()
if err != nil { if err != nil {
return nil, err return nil, err
} }
state.hpkePrivateKey, err = cfg.Options.GetHPKEPrivateKey()
if err != nil {
return nil, err
}
state.authenticateKeyFetcher = hpke.NewKeyFetcher(authenticateURL.ResolveReference(&url.URL{
Path: "/.well-known/pomerium/jwks.json",
}).String())
state.sharedCipher, err = cryptutil.NewAEADCipher(state.sharedKey) state.sharedCipher, err = cryptutil.NewAEADCipher(state.sharedKey)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -0,0 +1,31 @@
import TableCell from "@mui/material/TableCell";
import TableRow from "@mui/material/TableRow";
import { isArray, startCase } from "lodash";
import React, { FC } from "react";
import ClaimValue from "./ClaimValue";
export type ClaimRowProps = {
claimKey: string;
claimValue: unknown;
};
export const ClaimRow: FC<ClaimRowProps> = ({ claimKey, claimValue }) => {
return (
<TableRow>
<TableCell variant="head">{startCase(claimKey)}</TableCell>
<TableCell align="left">
{isArray(claimValue) ? (
claimValue?.map((v, i) => (
<React.Fragment key={`${v}`}>
{i > 0 ? <br /> : <></>}
<ClaimValue claimKey={claimKey} claimValue={v} />
</React.Fragment>
))
) : (
<ClaimValue claimKey={claimKey} claimValue={claimValue} />
)}
</TableCell>
</TableRow>
);
};
export default ClaimRow;

View file

@ -1,6 +1,3 @@
import { Session } from "../types";
import IDField from "./IDField";
import Section from "./Section";
import Stack from "@mui/material/Stack"; import Stack from "@mui/material/Stack";
import Table from "@mui/material/Table"; import Table from "@mui/material/Table";
import TableBody from "@mui/material/TableBody"; import TableBody from "@mui/material/TableBody";
@ -8,13 +5,20 @@ import TableCell from "@mui/material/TableCell";
import TableContainer from "@mui/material/TableContainer"; import TableContainer from "@mui/material/TableContainer";
import TableRow from "@mui/material/TableRow"; import TableRow from "@mui/material/TableRow";
import React, { FC } from "react"; import React, { FC } from "react";
import ClaimValue from "./ClaimValue";
import {startCase} from "lodash"; import { Profile, Session } from "../types";
import ClaimRow from "./ClaimRow";
import IDField from "./IDField";
import Section from "./Section";
export type SessionDetailsProps = { export type SessionDetailsProps = {
session: Session; session: Session;
profile: Profile;
}; };
export const SessionDetails: FC<SessionDetailsProps> = ({ session }) => { export const SessionDetails: FC<SessionDetailsProps> = ({
session,
profile,
}) => {
return ( return (
<Section title="User Details"> <Section title="User Details">
<Stack spacing={3}> <Stack spacing={3}>
@ -22,7 +26,9 @@ export const SessionDetails: FC<SessionDetailsProps> = ({ session }) => {
<Table size="small"> <Table size="small">
<TableBody> <TableBody>
<TableRow> <TableRow>
<TableCell width={'18%'} variant="head">Session ID</TableCell> <TableCell width={"18%"} variant="head">
Session ID
</TableCell>
<TableCell align="left"> <TableCell align="left">
<IDField value={session?.id} /> <IDField value={session?.id} />
</TableCell> </TableCell>
@ -30,26 +36,28 @@ export const SessionDetails: FC<SessionDetailsProps> = ({ session }) => {
<TableRow> <TableRow>
<TableCell variant="head">User ID</TableCell> <TableCell variant="head">User ID</TableCell>
<TableCell align="left"> <TableCell align="left">
<IDField value={session?.userId} /> <IDField
value={session?.userId || `${profile?.claims?.sub}`}
/>
</TableCell> </TableCell>
</TableRow> </TableRow>
<TableRow> <TableRow>
<TableCell variant="head">Expires At</TableCell> <TableCell variant="head">Expires At</TableCell>
<TableCell align="left">{session?.expiresAt || ""}</TableCell> <TableCell align="left">{session?.expiresAt || ""}</TableCell>
</TableRow> </TableRow>
{Object.entries(session?.claims || {}).map( {Object.entries(session?.claims || {}).map(([key, values]) => (
([key, values]) => ( <ClaimRow
<TableRow key={key}> key={`session/${key}`}
<TableCell variant="head">{startCase(key)}</TableCell> claimKey={key}
<TableCell align="left"> claimValue={values}
{values?.map((v, i) => ( />
<React.Fragment key={`${v}`}> ))}
{i > 0 ? <br /> : <></>} {Object.entries(profile?.claims || {}).map(([key, value]) => (
<ClaimValue claimKey={key} claimValue={v} /> <ClaimRow
</React.Fragment> key={`profile/${key}`}
))} claimKey={key}
</TableCell> claimValue={value}
</TableRow> />
))} ))}
</TableBody> </TableBody>
</Table> </Table>

View file

@ -81,7 +81,9 @@ const UserInfoPage: FC<UserInfoPageProps> = ({ data }) => {
marginLeft: mdUp ? "256px" : "0px", marginLeft: mdUp ? "256px" : "0px",
}} }}
> >
{subpage === "User" && <SessionDetails session={data?.session} />} {subpage === "User" && (
<SessionDetails session={data?.session} profile={data?.profile} />
)}
{subpage === "Groups Info" && ( {subpage === "Groups Info" && (
<GroupDetails <GroupDetails

View file

@ -13,6 +13,13 @@ export type Group = {
name: string; name: string;
}; };
export type Profile = {
providerId: string;
idToken: string;
oauthToken: string;
claims: Record<string, unknown>;
};
export type Session = { export type Session = {
audience: string[]; audience: string[];
claims: Claims; claims: Claims;
@ -108,6 +115,7 @@ export type UserInfoData = {
isEnterprise?: boolean; isEnterprise?: boolean;
session?: Session; session?: Session;
user?: User; user?: User;
profile?: Profile;
webAuthnCreationOptions?: WebAuthnCreationOptions; webAuthnCreationOptions?: WebAuthnCreationOptions;
webAuthnRequestOptions?: WebAuthnRequestOptions; webAuthnRequestOptions?: WebAuthnRequestOptions;
webAuthnUrl?: string; webAuthnUrl?: string;