From 57217af7dde6d2e00707b26a9457200a5be0293e Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Mon, 5 Dec 2022 15:31:07 -0700 Subject: [PATCH] authenticate: implement hpke-based login flow (#3779) * urlutil: add time validation functions * authenticate: implement hpke-based login flow * fix import cycle * fix tests * log error * fix callback url * add idp param * fix test * fix test --- authenticate/handlers.go | 127 ++-------- authenticate/handlers_test.go | 98 -------- authenticate/identity_profile.go | 104 ++++++++ authenticate/state.go | 6 +- authorize/check_response.go | 37 +-- authorize/check_response_test.go | 23 +- authorize/state.go | 14 ++ integration/authorization_test.go | 6 +- integration/benchmark_test.go | 4 +- integration/control_plane_test.go | 6 +- integration/main_test.go | 48 +++- integration/policy_test.go | 18 +- internal/handlers/sign_in.go | 89 +++++++ internal/handlers/userinfo.go | 5 + internal/sessions/state.go | 2 +- pkg/hpke/jwks_test.go | 7 +- proxy/handlers.go | 175 +++++++++++--- proxy/handlers_test.go | 341 --------------------------- proxy/identity_profile.go | 79 +++++++ proxy/proxy_test.go | 11 +- proxy/state.go | 24 +- ui/src/components/ClaimRow.tsx | 31 +++ ui/src/components/SessionDetails.tsx | 50 ++-- ui/src/components/UserInfoPage.tsx | 4 +- ui/src/types/index.ts | 8 + 25 files changed, 656 insertions(+), 661 deletions(-) create mode 100644 authenticate/identity_profile.go create mode 100644 internal/handlers/sign_in.go create mode 100644 proxy/identity_profile.go create mode 100644 ui/src/components/ClaimRow.tsx diff --git a/authenticate/handlers.go b/authenticate/handlers.go index b9a5385ef..74d2bed26 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -13,8 +13,6 @@ import ( "github.com/google/uuid" "github.com/gorilla/mux" "github.com/rs/cors" - "golang.org/x/oauth2" - "google.golang.org/protobuf/types/known/timestamppb" "github.com/pomerium/csrf" "github.com/pomerium/datasource/pkg/directory" @@ -33,6 +31,7 @@ import ( "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpc/user" + "github.com/pomerium/pomerium/pkg/hpke" "github.com/pomerium/pomerium/pkg/webauthnutil" ) @@ -95,7 +94,7 @@ func (a *Authenticate) mountDashboard(r *mux.Router) { sr.Use(a.RetrieveSession) sr.Use(a.VerifySession) sr.Path("/").Handler(a.requireValidSignatureOnRedirect(a.userInfo)) - sr.Path("/sign_in").Handler(a.requireValidSignature(a.SignIn)) + sr.Path("/sign_in").Handler(httputil.HandlerFunc(a.SignIn)) sr.Path("/sign_out").Handler(httputil.HandlerFunc(a.SignOut)) sr.Path("/webauthn").Handler(a.webauthn) sr.Path("/device-enrolled").Handler(httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { @@ -149,15 +148,12 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler { return a.reauthenticateOrFail(w, r, err) } - if state.dataBrokerClient == nil { - return errors.New("authenticate: databroker client cannot be nil") - } - if _, err = session.Get(ctx, state.dataBrokerClient, sessionState.ID); err != nil { + _, err = loadIdentityProfile(r, state.cookieCipher) + if err != nil { log.FromRequest(r).Info(). Err(err). Str("idp_id", idp.GetId()). - Str("id", sessionState.ID). - Msg("authenticate: session not found in databroker") + Msg("authenticate: identity profile load error") return a.reauthenticateOrFail(w, r, err) } @@ -180,25 +176,18 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error { state := a.state.Load() options := a.options.Load() - idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID)) + + if err := r.ParseForm(); err != nil { + return httputil.NewError(http.StatusBadRequest, err) + } + proxyPublicKey, requestParams, err := hpke.DecryptURLValues(state.hpkePrivateKey, r.Form) if err != nil { return err } - redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI)) + idp, err := options.GetIdentityProviderForID(requestParams.Get(urlutil.QueryIdentityProviderID)) if err != nil { - return httputil.NewError(http.StatusBadRequest, err) - } - - jwtAudience := []string{state.redirectURL.Host, redirectURL.Host} - - // if the callback is explicitly set, set it and add an additional audience - if callbackStr := r.FormValue(urlutil.QueryCallbackURI); callbackStr != "" { - callbackURL, err := urlutil.ParseAndValidateURL(callbackStr) - if err != nil { - return httputil.NewError(http.StatusBadRequest, err) - } - jwtAudience = append(jwtAudience, callbackURL.Host) + return err } s, err := a.getSessionFromCtx(ctx) @@ -212,33 +201,22 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error { s = sessions.NewState(idp.GetId()) } - newSession := s.WithNewIssuer(state.redirectURL.Host, jwtAudience) - // re-persist the session, useful when session was evicted from session if err := state.sessionStore.SaveSession(w, r, s); err != nil { return httputil.NewError(http.StatusBadRequest, err) } - // sign the route session, as a JWT - signedJWT, err := state.sharedEncoder.Marshal(newSession) + profile, err := loadIdentityProfile(r, state.cookieCipher) if err != nil { return httputil.NewError(http.StatusBadRequest, err) } - // encrypt our route-scoped JWT to avoid accidental logging of queryparams - encryptedJWT := cryptutil.Encrypt(a.state.Load().sharedCipher, signedJWT, nil) - // base64 our encrypted payload for URL-friendlyness - encodedJWT := base64.URLEncoding.EncodeToString(encryptedJWT) - - callbackURL, err := urlutil.GetCallbackURL(r, encodedJWT) + redirectTo, err := handlers.BuildCallbackURL(state.hpkePrivateKey, proxyPublicKey, requestParams, profile) if err != nil { - return httputil.NewError(http.StatusBadRequest, err) + return httputil.NewError(http.StatusInternalServerError, err) } - // build our hmac-d redirect URL with our session, pointing back to the - // proxy's callback URL which is responsible for setting our new route-session - uri := urlutil.NewSignedURL(state.sharedKey, callbackURL) - httputil.Redirect(w, r, uri.String(), http.StatusFound) + httputil.Redirect(w, r, redirectTo, http.StatusFound) return nil } @@ -460,10 +438,11 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) } // save the session and access token to the databroker - err = a.saveSessionToDataBroker(ctx, r, &newState, claims, accessToken) + profile, err := a.buildIdentityProfile(ctx, r, &newState, claims, accessToken) if err != nil { return nil, httputil.NewError(http.StatusInternalServerError, err) } + storeIdentityProfile(w, state.cookieCipher, profile) // ... and the user state to local storage. if err := state.sessionStore.SaveSession(w, r, &newState); err != nil { @@ -542,11 +521,14 @@ func (a *Authenticate) getUserInfoData(r *http.Request) (handlers.UserInfoData, } creationOptions, requestOptions, _ := a.webauthn.GetOptions(r) + profile, _ := loadIdentityProfile(r, state.cookieCipher) + data := handlers.UserInfoData{ CSRFToken: csrf.Token(r), IsImpersonated: isImpersonated, Session: pbSession, User: pbUser, + Profile: profile, WebAuthnCreationOptions: creationOptions, WebAuthnRequestOptions: requestOptions, @@ -582,73 +564,6 @@ func (a *Authenticate) fillEnterpriseUserInfoData( } } -func (a *Authenticate) saveSessionToDataBroker( - ctx context.Context, - r *http.Request, - sessionState *sessions.State, - claims identity.SessionClaims, - accessToken *oauth2.Token, -) error { - state := a.state.Load() - options := a.options.Load() - idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID)) - if err != nil { - return err - } - - authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId()) - if err != nil { - return err - } - - sessionExpiry := timestamppb.New(time.Now().Add(options.CookieExpire)) - idTokenIssuedAt := timestamppb.New(sessionState.IssuedAt.Time()) - - s := &session.Session{ - Id: sessionState.ID, - UserId: sessionState.UserID(authenticator.Name()), - IssuedAt: timestamppb.Now(), - AccessedAt: timestamppb.Now(), - ExpiresAt: sessionExpiry, - IdToken: &session.IDToken{ - Issuer: sessionState.Issuer, // todo(bdd): the issuer is not authN but the downstream IdP from the claims - Subject: sessionState.Subject, - ExpiresAt: sessionExpiry, - IssuedAt: idTokenIssuedAt, - }, - OauthToken: manager.ToOAuthToken(accessToken), - Audience: sessionState.Audience, - } - s.SetRawIDToken(claims.RawIDToken) - s.AddClaims(claims.Flatten()) - - var managerUser manager.User - managerUser.User, _ = user.Get(ctx, state.dataBrokerClient, s.GetUserId()) - if managerUser.User == nil { - // if no user exists yet, create a new one - managerUser.User = &user.User{ - Id: s.GetUserId(), - } - } - err = authenticator.UpdateUserInfo(ctx, accessToken, &managerUser) - if err != nil { - return fmt.Errorf("authenticate: error retrieving user info: %w", err) - } - _, err = databroker.Put(ctx, state.dataBrokerClient, managerUser.User) - if err != nil { - return fmt.Errorf("authenticate: error saving user: %w", err) - } - - res, err := session.Put(ctx, state.dataBrokerClient, s) - if err != nil { - return fmt.Errorf("authenticate: error saving session: %w", err) - } - sessionState.DatabrokerServerVersion = res.GetServerVersion() - sessionState.DatabrokerRecordVersion = res.GetRecord().GetVersion() - - return nil -} - // revokeSession always clears the local session and tries to revoke the associated session stored in the // databroker. If successful, it returns the original `id_token` of the session, if failed, returns // and empty string. diff --git a/authenticate/handlers_test.go b/authenticate/handlers_test.go index 5d32ea136..e0f587304 100644 --- a/authenticate/handlers_test.go +++ b/authenticate/handlers_test.go @@ -23,7 +23,6 @@ import ( "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/atomicutil" - "github.com/pomerium/pomerium/internal/encoding" "github.com/pomerium/pomerium/internal/encoding/jws" "github.com/pomerium/pomerium/internal/encoding/mock" "github.com/pomerium/pomerium/internal/handlers/webauthn" @@ -108,87 +107,6 @@ func TestAuthenticate_Handler(t *testing.T) { } } -func TestAuthenticate_SignIn(t *testing.T) { - t.Parallel() - tests := []struct { - name string - - scheme string - host string - qp map[string]string - - session sessions.SessionStore - provider identity.MockProvider - encoder encoding.MarshalUnmarshaler - wantCode int - }{ - {"good", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, - {"good alternate port", "https", "corp.example.example:8443", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, - {"session not valid", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, - {"bad redirect uri query", "", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "^^^"}, &mstore.Store{Session: &sessions.State{}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest}, - {"bad marshal", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest}, - {"session error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{LoadError: errors.New("error")}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest}, - {"good with different programmatic redirect", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &mstore.Store{Session: &sessions.State{}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, - {"encrypted encoder error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &mstore.Store{Session: &sessions.State{}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest}, - {"good with callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "https://some.example/", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, - {"bad callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "^", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest}, - {"good programmatic request", "https", "corp.example.example", map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, - } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - sharedCipher, _ := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key()) - - a := &Authenticate{ - cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) { - return tt.provider, nil - })), - state: atomicutil.NewValue(&authenticateState{ - sharedCipher: sharedCipher, - sessionStore: tt.session, - redirectURL: uriParseHelper("https://some.example"), - sharedEncoder: tt.encoder, - dataBrokerClient: mockDataBrokerServiceClient{ - get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) { - return &databroker.GetResponse{ - Record: databroker.NewRecord(&session.Session{ - Id: "SESSION_ID", - }), - }, nil - }, - }, - }), - - options: config.NewAtomicOptions(), - } - a.options.Store(&config.Options{SharedKey: base64.StdEncoding.EncodeToString(cryptutil.NewKey())}) - uri := &url.URL{Scheme: tt.scheme, Host: tt.host} - - queryString := uri.Query() - for k, v := range tt.qp { - queryString.Set(k, v) - } - uri.RawQuery = queryString.Encode() - r := httptest.NewRequest(http.MethodGet, uri.String(), nil) - r.Header.Set("Accept", "application/json") - state, err := tt.session.LoadSession(r) - ctx := r.Context() - ctx = sessions.NewContext(ctx, state, err) - r = r.WithContext(ctx) - - w := httptest.NewRecorder() - httputil.HandlerFunc(a.SignIn).ServeHTTP(w, r) - if status := w.Code; status != tt.wantCode { - t.Errorf("handler returned wrong status code: got %v want %v %s", status, tt.wantCode, uri) - t.Errorf("\n%+v", w.Body) - } - }) - } -} - func uriParseHelper(s string) *url.URL { uri, _ := url.Parse(s) return uri @@ -475,14 +393,6 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) { wantStatus int }{ - { - "good", - nil, - &mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}}, - nil, - identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, - http.StatusOK, - }, { "invalid session", nil, @@ -491,14 +401,6 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) { identity.MockProvider{}, http.StatusFound, }, - { - "good refresh expired", - nil, - &mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}}, - nil, - identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, - http.StatusOK, - }, { "expired,refresh error", nil, diff --git a/authenticate/identity_profile.go b/authenticate/identity_profile.go new file mode 100644 index 000000000..fcd9e171f --- /dev/null +++ b/authenticate/identity_profile.go @@ -0,0 +1,104 @@ +package authenticate + +import ( + "context" + "crypto/cipher" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + + "golang.org/x/oauth2" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/types/known/structpb" + + "github.com/pomerium/pomerium/internal/httputil" + "github.com/pomerium/pomerium/internal/identity" + "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/internal/sessions" + "github.com/pomerium/pomerium/internal/urlutil" + "github.com/pomerium/pomerium/pkg/cryptutil" + identitypb "github.com/pomerium/pomerium/pkg/grpc/identity" +) + +var cookieChunker = httputil.NewCookieChunker() + +func (a *Authenticate) buildIdentityProfile( + ctx context.Context, + r *http.Request, + sessionState *sessions.State, + claims identity.SessionClaims, + oauthToken *oauth2.Token, +) (*identitypb.Profile, error) { + options := a.options.Load() + idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID)) + if err != nil { + return nil, fmt.Errorf("authenticate: error getting identity provider for id: %w", err) + } + + authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId()) + if err != nil { + return nil, fmt.Errorf("authenticate: error getting identity provider authenticator: %w", err) + } + + err = authenticator.UpdateUserInfo(ctx, oauthToken, &claims) + if err != nil { + return nil, fmt.Errorf("authenticate: error retrieving user info: %w", err) + } + + rawIDToken := []byte(claims.RawIDToken) + rawOAuthToken, err := json.Marshal(oauthToken) + if err != nil { + return nil, fmt.Errorf("authenticate: error marshaling oauth token: %w", err) + } + rawClaims, err := structpb.NewStruct(claims.Claims) + if err != nil { + return nil, fmt.Errorf("authenticate: error creating claims struct: %w", err) + } + + return &identitypb.Profile{ + ProviderId: idp.GetId(), + IdToken: rawIDToken, + OauthToken: rawOAuthToken, + Claims: rawClaims, + }, nil +} + +func loadIdentityProfile(r *http.Request, aead cipher.AEAD) (*identitypb.Profile, error) { + cookie, err := cookieChunker.LoadCookie(r, urlutil.QueryIdentityProfile) + if err != nil { + return nil, fmt.Errorf("authenticate: error loading identity profile cookie: %w", err) + } + + encrypted, err := base64.RawURLEncoding.DecodeString(cookie.Value) + if err != nil { + return nil, fmt.Errorf("authenticate: error decoding identity profile cookie: %w", err) + } + + decrypted, err := cryptutil.Decrypt(aead, encrypted, nil) + if err != nil { + return nil, fmt.Errorf("authenticate: error decrypting identity profile cookie: %w", err) + } + + var profile identitypb.Profile + err = protojson.Unmarshal(decrypted, &profile) + if err != nil { + return nil, fmt.Errorf("authenticate: error unmarshaling identity profile cookie: %w", err) + } + return &profile, nil +} + +func storeIdentityProfile(w http.ResponseWriter, aead cipher.AEAD, profile *identitypb.Profile) { + decrypted, err := protojson.Marshal(profile) + if err != nil { + // this shouldn't happen + panic(fmt.Errorf("error marshaling message: %w", err)) + } + encrypted := cryptutil.Encrypt(aead, decrypted, nil) + err = cookieChunker.SetCookie(w, &http.Cookie{ + Name: urlutil.QueryIdentityProfile, + Value: base64.RawURLEncoding.EncodeToString(encrypted), + Path: "/", + }) + log.Error(context.Background()).Err(err).Send() +} diff --git a/authenticate/state.go b/authenticate/state.go index d9b77ef8f..59094590c 100644 --- a/authenticate/state.go +++ b/authenticate/state.go @@ -18,6 +18,7 @@ import ( "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/grpc" "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/hpke" ) var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn) @@ -39,7 +40,8 @@ type authenticateState struct { sessionStore sessions.SessionStore // sessionLoaders are a collection of session loaders to attempt to pull // a user's session state from - sessionLoader sessions.SessionLoader + sessionLoader sessions.SessionLoader + hpkePrivateKey *hpke.PrivateKey jwk *jose.JSONWebKeySet @@ -137,6 +139,8 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err return nil, err } + state.hpkePrivateKey = hpke.DerivePrivateKey(sharedKey) + dataBrokerConn, err := outboundGRPCConnection.Get(context.Background(), &grpc.OutboundOptions{ OutboundPort: cfg.OutboundPort, InstallationID: cfg.Options.InstallationID, diff --git a/authorize/check_response.go b/authorize/check_response.go index 251554f5e..587ccf840 100644 --- a/authorize/check_response.go +++ b/authorize/check_response.go @@ -18,6 +18,7 @@ import ( "google.golang.org/grpc/codes" "github.com/pomerium/pomerium/authorize/evaluator" + "github.com/pomerium/pomerium/internal/handlers" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/telemetry/requestid" @@ -174,34 +175,42 @@ func (a *Authorize) requireLoginResponse( in *envoy_service_auth_v3.CheckRequest, request *evaluator.Request, ) (*envoy_service_auth_v3.CheckResponse, error) { - opts := a.currentOptions.Load() + options := a.currentOptions.Load() state := a.state.Load() - authenticateURL, err := opts.GetAuthenticateURL() - if err != nil { - return nil, err - } if !a.shouldRedirect(in) { return a.deniedResponse(ctx, in, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), nil) } - signinURL := authenticateURL.ResolveReference(&url.URL{ - Path: "/.pomerium/sign_in", - }) - q := signinURL.Query() + authenticateURL, err := options.GetAuthenticateURL() + if err != nil { + return nil, err + } + + idp, err := options.GetIdentityProviderForPolicy(request.Policy) + if err != nil { + return nil, err + } + + authenticateHPKEPublicKey, err := state.authenticateKeyFetcher.FetchPublicKey(ctx) + if err != nil { + return nil, err + } // always assume https scheme checkRequestURL := getCheckRequestURL(in) checkRequestURL.Scheme = "https" - q.Set(urlutil.QueryRedirectURI, checkRequestURL.String()) - idp, err := opts.GetIdentityProviderForPolicy(request.Policy) + redirectTo, err := handlers.BuildSignInURL( + state.hpkePrivateKey, + authenticateHPKEPublicKey, + authenticateURL, + &checkRequestURL, + idp.GetId(), + ) if err != nil { return nil, err } - q.Set(urlutil.QueryIdentityProviderID, idp.GetId()) - signinURL.RawQuery = q.Encode() - redirectTo := urlutil.NewSignedURL(state.sharedKey, signinURL).String() return a.deniedResponse(ctx, in, http.StatusFound, "Login", map[string]string{ "Location": redirectTo, diff --git a/authorize/check_response_test.go b/authorize/check_response_test.go index 9844ef4fb..0b9b0fa85 100644 --- a/authorize/check_response_test.go +++ b/authorize/check_response_test.go @@ -3,6 +3,7 @@ package authorize import ( "context" "net/http" + "net/http/httptest" "net/url" "testing" @@ -19,15 +20,23 @@ import ( "github.com/pomerium/pomerium/authorize/internal/store" "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/atomicutil" + "github.com/pomerium/pomerium/internal/handlers" "github.com/pomerium/pomerium/internal/testutil" "github.com/pomerium/pomerium/pkg/policy/criteria" ) func TestAuthorize_handleResult(t *testing.T) { opt := config.NewDefaultOptions() - opt.AuthenticateURLString = "https://authenticate.example.com" opt.DataBrokerURLString = "https://databroker.example.com" opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM=" + + htpkePrivateKey, err := opt.GetHPKEPrivateKey() + require.NoError(t, err) + + authnSrv := httptest.NewServer(handlers.JWKSHandler(opt.SigningKey, htpkePrivateKey.PublicKey())) + t.Cleanup(authnSrv.Close) + opt.AuthenticateURLString = authnSrv.URL + a, err := New(&config.Config{Options: opt}) require.NoError(t, err) @@ -179,10 +188,20 @@ func mustParseWeightedURLs(t *testing.T, urls ...string) []config.WeightedURL { } func TestRequireLogin(t *testing.T) { + t.Parallel() + opt := config.NewDefaultOptions() - opt.AuthenticateURLString = "https://authenticate.example.com" opt.DataBrokerURLString = "https://databroker.example.com" opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM=" + opt.SigningKey = "LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSUJlMFRxbXJkSXBZWE03c3pSRERWYndXOS83RWJHVWhTdFFJalhsVHNXM1BvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFb0xaRDI2bEdYREhRQmhhZkdlbEVmRDdlNmYzaURjWVJPVjdUbFlIdHF1Y1BFL2hId2dmYQpNY3FBUEZsRmpueUpySXJhYTFlQ2xZRTJ6UktTQk5kNXBRPT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQo=" + + htpkePrivateKey, err := opt.GetHPKEPrivateKey() + require.NoError(t, err) + + authnSrv := httptest.NewServer(handlers.JWKSHandler(opt.SigningKey, htpkePrivateKey.PublicKey())) + t.Cleanup(authnSrv.Close) + opt.AuthenticateURLString = authnSrv.URL + a, err := New(&config.Config{Options: opt}) require.NoError(t, err) diff --git a/authorize/state.go b/authorize/state.go index 4bf58a040..46c9384f2 100644 --- a/authorize/state.go +++ b/authorize/state.go @@ -3,6 +3,7 @@ package authorize import ( "context" "fmt" + "net/url" googlegrpc "google.golang.org/grpc" @@ -11,6 +12,7 @@ import ( "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/pkg/grpc" "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/hpke" "github.com/pomerium/pomerium/pkg/protoutil" ) @@ -23,6 +25,8 @@ type authorizeState struct { dataBrokerClient databroker.DataBrokerServiceClient auditEncryptor *protoutil.Encryptor sessionStore *config.SessionStore + hpkePrivateKey *hpke.PrivateKey + authenticateKeyFetcher hpke.KeyFetcher } func newAuthorizeStateFromConfig(cfg *config.Config, store *store.Store) (*authorizeState, error) { @@ -74,5 +78,15 @@ func newAuthorizeStateFromConfig(cfg *config.Config, store *store.Store) (*autho return nil, fmt.Errorf("authorize: invalid session store: %w", err) } + authenticateURL, err := cfg.Options.GetAuthenticateURL() + if err != nil { + return nil, fmt.Errorf("authorize: invalid authenticate service url: %w", err) + } + + state.hpkePrivateKey = hpke.DerivePrivateKey(sharedKey) + state.authenticateKeyFetcher = hpke.NewKeyFetcher(authenticateURL.ResolveReference(&url.URL{ + Path: "/.well-known/pomerium/jwks.json", + }).String()) + return state, nil } diff --git a/integration/authorization_test.go b/integration/authorization_test.go index 6b0bc3185..10306a149 100644 --- a/integration/authorization_test.go +++ b/integration/authorization_test.go @@ -25,7 +25,7 @@ func TestAuthorization(t *testing.T) { } t.Run("public", func(t *testing.T) { - client := getClient() + client := getClient(t) req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io", nil) if err != nil { @@ -43,7 +43,7 @@ func TestAuthorization(t *testing.T) { t.Run("domains", func(t *testing.T) { t.Run("allowed", func(t *testing.T) { - client := getClient() + client := getClient(t) res, err := flows.Authenticate(ctx, client, mustParseURL("https://httpdetails.localhost.pomerium.io/by-domain"), withAPI, flows.WithEmail("user1@dogs.test")) if assert.NoError(t, err) { @@ -51,7 +51,7 @@ func TestAuthorization(t *testing.T) { } }) t.Run("not allowed", func(t *testing.T) { - client := getClient() + client := getClient(t) res, err := flows.Authenticate(ctx, client, mustParseURL("https://httpdetails.localhost.pomerium.io/by-domain"), withAPI, flows.WithEmail("user1@cats.test")) if assert.NoError(t, err) { diff --git a/integration/benchmark_test.go b/integration/benchmark_test.go index a1ae68a61..f2b17c89b 100644 --- a/integration/benchmark_test.go +++ b/integration/benchmark_test.go @@ -12,7 +12,7 @@ import ( func BenchmarkLoggedInUserAccess(b *testing.B) { ctx := context.Background() - client := getClient() + client := getClient(b) res, err := flows.Authenticate(ctx, client, mustParseURL("https://httpdetails.localhost.pomerium.io/by-domain"), flows.WithEmail("user1@dogs.test")) require.NoError(b, err) @@ -30,7 +30,7 @@ func BenchmarkLoggedInUserAccess(b *testing.B) { func BenchmarkLoggedOutUserAccess(b *testing.B) { ctx := context.Background() - client := getClient() + client := getClient(b) b.ResetTimer() for i := 0; i < b.N; i++ { diff --git a/integration/control_plane_test.go b/integration/control_plane_test.go index 36118f2e9..f34445b33 100644 --- a/integration/control_plane_test.go +++ b/integration/control_plane_test.go @@ -21,7 +21,7 @@ func TestDashboard(t *testing.T) { t.Fatal(err) } - res, err := getClient().Do(req) + res, err := getClient(t).Do(req) if !assert.NoError(t, err, "unexpected http error") { return } @@ -37,7 +37,7 @@ func TestDashboard(t *testing.T) { t.Fatal(err) } - res, err := getClient().Do(req) + res, err := getClient(t).Do(req) if !assert.NoError(t, err, "unexpected http error") { return } @@ -69,7 +69,7 @@ func TestHealth(t *testing.T) { t.Fatal(err) } - res, err := getClient().Do(req) + res, err := getClient(t).Do(req) if !assert.NoError(t, err, "unexpected http error") { return } diff --git a/integration/main_test.go b/integration/main_test.go index 3f8eb87ef..f79d73306 100644 --- a/integration/main_test.go +++ b/integration/main_test.go @@ -50,10 +50,21 @@ func TestMain(m *testing.M) { os.Exit(status) } -func getClient() *http.Client { - jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) - if err != nil { - panic(err) +type loggingRoundTripper struct { + t testing.TB + transport http.RoundTripper +} + +func (l loggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + if l.t != nil { + l.t.Logf("%s %s", req.Method, req.URL.String()) + } + return l.transport.RoundTrip(req) +} + +func getTransport(t testing.TB) http.RoundTripper { + if t != nil { + t.Helper() } rootCAs, err := x509.SystemCertPool() @@ -66,23 +77,36 @@ func getClient() *http.Client { panic(err) } _ = rootCAs.AppendCertsFromPEM(bs) + transport := &http.Transport{ + DisableKeepAlives: true, + TLSClientConfig: &tls.Config{ + RootCAs: rootCAs, + }, + } + return loggingRoundTripper{t, transport} +} + +func getClient(t testing.TB) *http.Client { + if t != nil { + t.Helper() + } + + jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) + if err != nil { + panic(err) + } return &http.Client{ CheckRedirect: func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse }, - Transport: &http.Transport{ - DisableKeepAlives: true, - TLSClientConfig: &tls.Config{ - RootCAs: rootCAs, - }, - }, - Jar: jar, + Transport: getTransport(t), + Jar: jar, } } func waitForHealthy(ctx context.Context) error { - client := getClient() + client := getClient(nil) check := func(endpoint string) error { reqCtx, clearTimeout := context.WithTimeout(ctx, time.Second) defer clearTimeout() diff --git a/integration/policy_test.go b/integration/policy_test.go index b9d1a9bf2..3e31139f2 100644 --- a/integration/policy_test.go +++ b/integration/policy_test.go @@ -31,7 +31,7 @@ func TestQueryStringParams(t *testing.T) { t.Fatal(err) } - res, err := getClient().Do(req) + res, err := getClient(t).Do(req) if !assert.NoError(t, err, "unexpected http error") { return } @@ -65,7 +65,7 @@ func TestCORS(t *testing.T) { req.Header.Set("Access-Control-Request-Method", "GET") req.Header.Set("Origin", "https://httpdetails.localhost.pomerium.io") - res, err := getClient().Do(req) + res, err := getClient(t).Do(req) if !assert.NoError(t, err, "unexpected http error") { return } @@ -81,7 +81,7 @@ func TestCORS(t *testing.T) { req.Header.Set("Access-Control-Request-Method", "GET") req.Header.Set("Origin", "https://httpdetails.localhost.pomerium.io") - res, err := getClient().Do(req) + res, err := getClient(t).Do(req) if !assert.NoError(t, err, "unexpected http error") { return } @@ -102,7 +102,7 @@ func TestPreserveHostHeader(t *testing.T) { t.Fatal(err) } - res, err := getClient().Do(req) + res, err := getClient(t).Do(req) if !assert.NoError(t, err, "unexpected http error") { return } @@ -127,7 +127,7 @@ func TestPreserveHostHeader(t *testing.T) { t.Fatal(err) } - res, err := getClient().Do(req) + res, err := getClient(t).Do(req) if !assert.NoError(t, err, "unexpected http error") { return } @@ -158,7 +158,7 @@ func TestSetRequestHeaders(t *testing.T) { t.Fatal(err) } - res, err := getClient().Do(req) + res, err := getClient(t).Do(req) if !assert.NoError(t, err, "unexpected http error") { return } @@ -187,7 +187,7 @@ func TestRemoveRequestHeaders(t *testing.T) { } req.Header.Add("X-Custom-Request-Header-To-Remove", "foo") - res, err := getClient().Do(req) + res, err := getClient(t).Do(req) if !assert.NoError(t, err, "unexpected http error") { return } @@ -250,7 +250,7 @@ func TestGoogleCloudRun(t *testing.T) { t.Fatal(err) } - res, err := getClient().Do(req) + res, err := getClient(t).Do(req) if !assert.NoError(t, err, "unexpected http error") { return } @@ -274,7 +274,7 @@ func TestLoadBalancer(t *testing.T) { defer clearTimeout() getDistribution := func(t *testing.T, path string) map[string]float64 { - client := getClient() + client := getClient(t) distribution := map[string]float64{} res, err := flows.Authenticate(ctx, client, diff --git a/internal/handlers/sign_in.go b/internal/handlers/sign_in.go new file mode 100644 index 000000000..077d63452 --- /dev/null +++ b/internal/handlers/sign_in.go @@ -0,0 +1,89 @@ +package handlers + +import ( + "fmt" + "net/url" + "time" + + "google.golang.org/protobuf/encoding/protojson" + + "github.com/pomerium/pomerium/internal/urlutil" + "github.com/pomerium/pomerium/pkg/grpc/identity" + "github.com/pomerium/pomerium/pkg/hpke" +) + +const signInExpiry = time.Minute * 5 + +// BuildCallbackURL builds the callback URL using an HPKE encrypted query string. +func BuildCallbackURL( + authenticatePrivateKey *hpke.PrivateKey, + proxyPublicKey *hpke.PublicKey, + requestParams url.Values, + profile *identity.Profile, +) (string, error) { + redirectURL, err := urlutil.ParseAndValidateURL(requestParams.Get(urlutil.QueryRedirectURI)) + if err != nil { + return "", fmt.Errorf("invalid %s: %w", urlutil.QueryRedirectURI, err) + } + + var callbackURL *url.URL + if requestParams.Has(urlutil.QueryCallbackURI) { + callbackURL, err = urlutil.ParseAndValidateURL(requestParams.Get(urlutil.QueryCallbackURI)) + if err != nil { + return "", fmt.Errorf("invalid %s: %w", urlutil.QueryCallbackURI, err) + } + } else { + callbackURL, err = urlutil.DeepCopy(redirectURL) + if err != nil { + return "", fmt.Errorf("error copying %s: %w", urlutil.QueryRedirectURI, err) + } + callbackURL.Path = "/.pomerium/callback/" + callbackURL.RawQuery = "" + } + + callbackParams := callbackURL.Query() + if requestParams.Has(urlutil.QueryIsProgrammatic) { + callbackParams.Set(urlutil.QueryIsProgrammatic, "true") + } + callbackParams.Set(urlutil.QueryRedirectURI, redirectURL.String()) + + rawProfile, err := protojson.Marshal(profile) + if err != nil { + return "", fmt.Errorf("error marshaling identity profile: %w", err) + } + callbackParams.Set(urlutil.QueryIdentityProfile, string(rawProfile)) + + urlutil.BuildTimeParameters(callbackParams, signInExpiry) + + callbackParams, err = hpke.EncryptURLValues(authenticatePrivateKey, proxyPublicKey, callbackParams) + if err != nil { + return "", fmt.Errorf("error encrypting callback params: %w", err) + } + callbackURL.RawQuery = callbackParams.Encode() + + return callbackURL.String(), nil +} + +// BuildSignInURL buidls the sign in URL using an HPKE encrypted query string. +func BuildSignInURL( + senderPrivateKey *hpke.PrivateKey, + authenticatePublicKey *hpke.PublicKey, + authenticateURL *url.URL, + redirectURL *url.URL, + idpID string, +) (string, error) { + signInURL := *authenticateURL + signInURL.Path = "/.pomerium/sign_in" + + q := signInURL.Query() + q.Set(urlutil.QueryRedirectURI, redirectURL.String()) + q.Set(urlutil.QueryIdentityProviderID, idpID) + urlutil.BuildTimeParameters(q, signInExpiry) + q, err := hpke.EncryptURLValues(senderPrivateKey, authenticatePublicKey, q) + if err != nil { + return "", err + } + signInURL.RawQuery = q.Encode() + + return signInURL.String(), nil +} diff --git a/internal/handlers/userinfo.go b/internal/handlers/userinfo.go index e92a0767e..ea3737d1a 100644 --- a/internal/handlers/userinfo.go +++ b/internal/handlers/userinfo.go @@ -8,6 +8,7 @@ import ( "github.com/pomerium/datasource/pkg/directory" "github.com/pomerium/pomerium/internal/httputil" + "github.com/pomerium/pomerium/pkg/grpc/identity" "github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpc/user" "github.com/pomerium/pomerium/ui" @@ -20,6 +21,7 @@ type UserInfoData struct { IsImpersonated bool Session *session.Session User *user.User + Profile *identity.Profile IsEnterprise bool DirectoryUser *directory.User @@ -43,6 +45,9 @@ func (data UserInfoData) ToJSON() map[string]any { if bs, err := protojson.Marshal(data.User); err == nil { m["user"] = json.RawMessage(bs) } + if bs, err := protojson.Marshal(data.Profile); err == nil { + m["profile"] = json.RawMessage(bs) + } m["isEnterprise"] = data.IsEnterprise if data.DirectoryUser != nil { m["directoryUser"] = data.DirectoryUser diff --git a/internal/sessions/state.go b/internal/sessions/state.go index 82a084c40..cd39599a7 100644 --- a/internal/sessions/state.go +++ b/internal/sessions/state.go @@ -60,7 +60,7 @@ func (s *State) WithNewIssuer(issuer string, audience []string) State { } // UserID returns the corresponding user ID for a session. -func (s *State) UserID(provider string) string { +func (s *State) UserID() string { if s.OID != "" { return s.OID } diff --git a/pkg/hpke/jwks_test.go b/pkg/hpke/jwks_test.go index 7a85f917d..82dac60d3 100644 --- a/pkg/hpke/jwks_test.go +++ b/pkg/hpke/jwks_test.go @@ -1,4 +1,4 @@ -package hpke +package hpke_test import ( "context" @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/require" "github.com/pomerium/pomerium/internal/handlers" + "github.com/pomerium/pomerium/pkg/hpke" ) func TestFetchPublicKeyFromJWKS(t *testing.T) { @@ -19,7 +20,7 @@ func TestFetchPublicKeyFromJWKS(t *testing.T) { ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) t.Cleanup(clearTimeout) - hpkePrivateKey, err := GeneratePrivateKey() + hpkePrivateKey, err := hpke.GeneratePrivateKey() require.NoError(t, err) srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -27,7 +28,7 @@ func TestFetchPublicKeyFromJWKS(t *testing.T) { })) t.Cleanup(srv.Close) - publicKey, err := FetchPublicKeyFromJWKS(ctx, http.DefaultClient, srv.URL) + publicKey, err := hpke.FetchPublicKeyFromJWKS(ctx, http.DefaultClient, srv.URL) assert.NoError(t, err) assert.Equal(t, hpkePrivateKey.PublicKey().String(), publicKey.String()) } diff --git a/proxy/handlers.go b/proxy/handlers.go index 5c8ef42b2..990fa0036 100644 --- a/proxy/handlers.go +++ b/proxy/handlers.go @@ -1,7 +1,7 @@ package proxy import ( - "encoding/base64" + "context" "errors" "fmt" "io" @@ -9,12 +9,17 @@ import ( "net/url" "github.com/gorilla/mux" + "google.golang.org/protobuf/encoding/protojson" "github.com/pomerium/pomerium/internal/handlers" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/middleware" "github.com/pomerium/pomerium/internal/urlutil" - "github.com/pomerium/pomerium/pkg/cryptutil" + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/grpc/identity" + "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/grpc/user" + "github.com/pomerium/pomerium/pkg/hpke" ) // registerDashboardHandlers returns the proxy service's ServeMux @@ -32,9 +37,6 @@ func (p *Proxy) registerDashboardHandlers(r *mux.Router) *mux.Router { // called following authenticate auth flow to grab a new or existing session // the route specific cookie is returned in a signed query params c := r.PathPrefix(dashboardPath + "/callback").Subrouter() - c.Use(func(h http.Handler) http.Handler { - return middleware.ValidateSignature(p.state.Load().sharedKey)(h) - }) c.Path("/").Handler(httputil.HandlerFunc(p.Callback)).Methods(http.MethodGet) // Programmatic API handlers and middleware @@ -105,55 +107,96 @@ func (p *Proxy) deviceEnrolled(w http.ResponseWriter, r *http.Request) error { // Callback handles the result of a successful call to the authenticate service // and is responsible setting per-route sessions. func (p *Proxy) Callback(w http.ResponseWriter, r *http.Request) error { - redirectURLString := r.FormValue(urlutil.QueryRedirectURI) - encryptedSession := r.FormValue(urlutil.QuerySessionEncrypted) + state := p.state.Load() + options := p.currentOptions.Load() - redirectURL, err := urlutil.ParseAndValidateURL(redirectURLString) + if err := r.ParseForm(); err != nil { + return httputil.NewError(http.StatusBadRequest, err) + } + + // decrypt the URL values + senderPublicKey, values, err := hpke.DecryptURLValues(state.hpkePrivateKey, r.Form) + if err != nil { + return httputil.NewError(http.StatusBadRequest, fmt.Errorf("invalid encrypted query string: %w", err)) + } + + // confirm this request came from the authenticate service + err = p.validateSenderPublicKey(r.Context(), senderPublicKey) + if err != nil { + return err + } + + // validate that the request has not expired + err = urlutil.ValidateTimeParameters(values) if err != nil { return httputil.NewError(http.StatusBadRequest, err) } - rawJWT, err := p.saveCallbackSession(w, r, encryptedSession) + profile, err := getProfileFromValues(values) if err != nil { - return httputil.NewError(http.StatusBadRequest, err) + return err + } + + ss := newSessionStateFromProfile(profile) + s, err := session.Get(r.Context(), state.dataBrokerClient, ss.ID) + if err != nil { + s = &session.Session{Id: ss.ID} + } + populateSessionFromProfile(s, profile, ss, options.CookieExpire) + u, err := user.Get(r.Context(), state.dataBrokerClient, ss.UserID()) + if err != nil { + u = &user.User{Id: ss.UserID()} + } + populateUserFromProfile(u, profile, ss) + + redirectURI, err := getRedirectURIFromValues(values) + if err != nil { + return err + } + + // save the records + res, err := state.dataBrokerClient.Put(r.Context(), &databroker.PutRequest{ + Records: []*databroker.Record{ + databroker.NewRecord(s), + databroker.NewRecord(u), + }, + }) + if err != nil { + return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("proxy: error saving databroker records: %w", err)) + } + ss.DatabrokerServerVersion = res.GetServerVersion() + for _, record := range res.GetRecords() { + if record.GetVersion() > ss.DatabrokerRecordVersion { + ss.DatabrokerRecordVersion = record.GetVersion() + } + } + + // save the session state + rawJWT, err := state.encoder.Marshal(ss) + if err != nil { + return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("proxy: error marshaling session state: %w", err)) + } + if err = state.sessionStore.SaveSession(w, r, rawJWT); err != nil { + return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("proxy: error saving session state: %w", err)) } // if programmatic, encode the session jwt as a query param - if isProgrammatic := r.FormValue(urlutil.QueryIsProgrammatic); isProgrammatic == "true" { - q := redirectURL.Query() + if isProgrammatic := values.Get(urlutil.QueryIsProgrammatic); isProgrammatic == "true" { + q := redirectURI.Query() q.Set(urlutil.QueryPomeriumJWT, string(rawJWT)) - redirectURL.RawQuery = q.Encode() + redirectURI.RawQuery = q.Encode() } - httputil.Redirect(w, r, redirectURL.String(), http.StatusFound) + + // redirect + httputil.Redirect(w, r, redirectURI.String(), http.StatusFound) return nil } -// saveCallbackSession takes an encrypted per-route session token, decrypts -// it using the shared service key, then stores it the local session store. -func (p *Proxy) saveCallbackSession(w http.ResponseWriter, r *http.Request, enctoken string) ([]byte, error) { - state := p.state.Load() - - // 1. extract the base64 encoded and encrypted JWT from query params - encryptedJWT, err := base64.URLEncoding.DecodeString(enctoken) - if err != nil { - return nil, fmt.Errorf("proxy: malfromed callback token: %w", err) - } - // 2. decrypt the JWT using the cipher using the _shared_ secret key - rawJWT, err := cryptutil.Decrypt(state.sharedCipher, encryptedJWT, nil) - if err != nil { - return nil, fmt.Errorf("proxy: callback token decrypt error: %w", err) - } - // 3. Save the decrypted JWT to the session store directly as a string, without resigning - if err = state.sessionStore.SaveSession(w, r, rawJWT); err != nil { - return nil, fmt.Errorf("proxy: callback session save failure: %w", err) - } - return rawJWT, nil -} - // ProgrammaticLogin returns a signed url that can be used to login // using the authenticate service. func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) error { state := p.state.Load() + options := p.currentOptions.Load() redirectURI, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI)) if err != nil { @@ -164,19 +207,32 @@ func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) error return httputil.NewError(http.StatusBadRequest, errors.New("invalid redirect uri")) } + idp, err := options.GetIdentityProviderForRequestURL(urlutil.GetAbsoluteURL(r).String()) + if err != nil { + return httputil.NewError(http.StatusInternalServerError, err) + } + + hpkeAuthenticateKey, err := state.authenticateKeyFetcher.FetchPublicKey(r.Context()) + if err != nil { + return httputil.NewError(http.StatusInternalServerError, err) + } + signinURL := *state.authenticateSigninURL callbackURI := urlutil.GetAbsoluteURL(r) callbackURI.Path = dashboardPath + "/callback/" q := signinURL.Query() q.Set(urlutil.QueryCallbackURI, callbackURI.String()) - q.Set(urlutil.QueryRedirectURI, redirectURI.String()) q.Set(urlutil.QueryIsProgrammatic, "true") signinURL.RawQuery = q.Encode() - response := urlutil.NewSignedURL(state.sharedKey, &signinURL).String() + + rawURL, err := handlers.BuildSignInURL(state.hpkePrivateKey, hpkeAuthenticateKey, &signinURL, redirectURI, idp.GetId()) + if err != nil { + return httputil.NewError(http.StatusInternalServerError, err) + } w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.WriteHeader(http.StatusOK) - _, _ = io.WriteString(w, response) + _, _ = io.WriteString(w, rawURL) return nil } @@ -191,3 +247,44 @@ func (p *Proxy) jwtAssertion(w http.ResponseWriter, r *http.Request) error { _, _ = io.WriteString(w, assertionJWT) return nil } + +func (p *Proxy) validateSenderPublicKey(ctx context.Context, senderPublicKey *hpke.PublicKey) error { + state := p.state.Load() + + authenticatePublicKey, err := state.authenticateKeyFetcher.FetchPublicKey(ctx) + if err != nil { + return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("hpke: error retrieving authenticate service public key: %w", err)) + } + + if !authenticatePublicKey.Equals(senderPublicKey) { + return httputil.NewError(http.StatusBadRequest, fmt.Errorf("hpke: invalid authenticate service public key")) + } + + return nil +} + +func getProfileFromValues(values url.Values) (*identity.Profile, error) { + rawProfile := values.Get(urlutil.QueryIdentityProfile) + if rawProfile == "" { + return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("missing %s", urlutil.QueryIdentityProfile)) + } + + var profile identity.Profile + err := protojson.Unmarshal([]byte(rawProfile), &profile) + if err != nil { + return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("invalid %s: %w", urlutil.QueryIdentityProfile, err)) + } + return &profile, nil +} + +func getRedirectURIFromValues(values url.Values) (*url.URL, error) { + rawRedirectURI := values.Get(urlutil.QueryRedirectURI) + if rawRedirectURI == "" { + return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("missing %s", urlutil.QueryRedirectURI)) + } + redirectURI, err := urlutil.ParseAndValidateURL(rawRedirectURI) + if err != nil { + return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("invalid %s: %w", urlutil.QueryRedirectURI, err)) + } + return redirectURI, nil +} diff --git a/proxy/handlers_test.go b/proxy/handlers_test.go index 36c17d44c..741d720b8 100644 --- a/proxy/handlers_test.go +++ b/proxy/handlers_test.go @@ -2,30 +2,20 @@ package proxy import ( "bytes" - "context" - "errors" "net/http" "net/http/httptest" "net/url" - "strings" "testing" "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/atomicutil" - "github.com/pomerium/pomerium/internal/encoding" - "github.com/pomerium/pomerium/internal/encoding/mock" "github.com/pomerium/pomerium/internal/httputil" - "github.com/pomerium/pomerium/internal/sessions" - mstore "github.com/pomerium/pomerium/internal/sessions/mock" "github.com/pomerium/pomerium/internal/urlutil" - "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/assert" ) -const goodEncryptionString = "KBEjQ9rnCxaAX-GOqetGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg=" - func TestProxy_RobotsTxt(t *testing.T) { proxy := Proxy{} req := httptest.NewRequest(http.MethodGet, "/robots.txt", nil) @@ -40,29 +30,6 @@ func TestProxy_RobotsTxt(t *testing.T) { } } -func TestProxy_Signout(t *testing.T) { - opts := testOptions(t) - err := ValidateOptions(opts) - if err != nil { - t.Fatal(err) - } - proxy, err := New(&config.Config{Options: opts}) - if err != nil { - t.Fatal(err) - } - req := httptest.NewRequest(http.MethodGet, "/.pomerium/sign_out", nil) - rr := httptest.NewRecorder() - proxy.SignOut(rr, req) - if status := rr.Code; status != http.StatusFound { - t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusFound) - } - body := rr.Body.String() - want := proxy.state.Load().authenticateURL.String() - if !strings.Contains(body, want) { - t.Errorf("handler returned unexpected body: got %v want %s ", body, want) - } -} - func TestProxy_SignOut(t *testing.T) { t.Parallel() tests := []struct { @@ -104,165 +71,6 @@ func TestProxy_SignOut(t *testing.T) { } } -func TestProxy_Callback(t *testing.T) { - t.Parallel() - opts := testOptions(t) - tests := []struct { - name string - options *config.Options - - method string - - scheme string - host string - path string - - headers map[string]string - qp map[string]string - - cipher encoding.MarshalUnmarshaler - sessionStore sessions.SessionStore - wantStatus int - wantBody string - }{ - { - "good", - opts, - http.MethodGet, - "http", - "example.com", - "/", - nil, - map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, - &mock.Encoder{MarshalResponse: []byte("x")}, - &mstore.Store{Session: &sessions.State{}}, - http.StatusFound, - "", - }, - { - "good programmatic", - opts, - http.MethodGet, - "http", - "example.com", - "/", - nil, - map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, - &mock.Encoder{MarshalResponse: []byte("x")}, - &mstore.Store{Session: &sessions.State{}}, - http.StatusFound, - "", - }, - { - "bad decrypt", - opts, - http.MethodGet, - "http", - "example.com", - "/", - nil, - map[string]string{urlutil.QuerySessionEncrypted: "KBEjQ9rnCxaAX-GOqexGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="}, - &mock.Encoder{MarshalResponse: []byte("x")}, - &mstore.Store{Session: &sessions.State{}}, - http.StatusBadRequest, - "", - }, - { - "bad save session", - opts, - http.MethodGet, - "http", - "example.com", - "/", - nil, - map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString}, - &mock.Encoder{MarshalResponse: []byte("x")}, - &mstore.Store{SaveError: errors.New("hi")}, - http.StatusBadRequest, - "", - }, - { - "bad base64", - opts, - http.MethodGet, - "http", - "example.com", - "/", - nil, - map[string]string{urlutil.QuerySessionEncrypted: "^"}, - &mock.Encoder{MarshalResponse: []byte("x")}, - &mstore.Store{Session: &sessions.State{}}, - http.StatusBadRequest, - "", - }, - { - "malformed redirect", - opts, - http.MethodGet, - "http", - "example.com", - "/", - nil, - nil, - &mock.Encoder{}, - &mstore.Store{Session: &sessions.State{}}, - http.StatusBadRequest, - "", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - p, err := New(&config.Config{Options: tt.options}) - if err != nil { - t.Fatal(err) - } - p.OnConfigChange(context.Background(), &config.Config{Options: tt.options}) - state := p.state.Load() - state.encoder = tt.cipher - state.sessionStore = tt.sessionStore - redirectURI := &url.URL{Scheme: tt.scheme, Host: tt.host, Path: tt.path} - queryString := redirectURI.Query() - for k, v := range tt.qp { - queryString.Set(k, v) - } - redirectURI.RawQuery = queryString.Encode() - - uri := &url.URL{Path: "/"} - if tt.qp != nil { - qu := uri.Query() - for k, v := range tt.qp { - qu.Set(k, v) - } - qu.Set(urlutil.QueryRedirectURI, redirectURI.String()) - uri.RawQuery = qu.Encode() - } - - r := httptest.NewRequest(tt.method, uri.String(), nil) - - r.Header.Set("Accept", "application/json") - if len(tt.headers) != 0 { - for k, v := range tt.headers { - r.Header.Set(k, v) - } - } - - w := httptest.NewRecorder() - httputil.HandlerFunc(p.Callback).ServeHTTP(w, r) - if status := w.Code; status != tt.wantStatus { - t.Errorf("status code: got %v want %v", status, tt.wantStatus) - t.Errorf("\n%+v", w.Body.String()) - } - - if tt.wantBody != "" { - body := w.Body.String() - if diff := cmp.Diff(body, tt.wantBody); diff != "" { - t.Errorf("wrong body\n%s", diff) - } - } - }) - } -} - func TestProxy_ProgrammaticLogin(t *testing.T) { t.Parallel() opts := testOptions(t) @@ -360,155 +168,6 @@ func TestProxy_ProgrammaticLogin(t *testing.T) { } } -func TestProxy_ProgrammaticCallback(t *testing.T) { - t.Parallel() - opts := testOptions(t) - tests := []struct { - name string - options *config.Options - - method string - - redirectURI string - - headers map[string]string - qp map[string]string - - cipher encoding.MarshalUnmarshaler - sessionStore sessions.SessionStore - wantStatus int - wantBody string - }{ - { - "good", - opts, - http.MethodGet, - "http://pomerium.io/", - nil, - map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, - &mock.Encoder{MarshalResponse: []byte("x")}, - &mstore.Store{Session: &sessions.State{}}, - http.StatusFound, - "", - }, - { - "good programmatic", - opts, - http.MethodGet, - "http://pomerium.io/", - nil, - map[string]string{ - urlutil.QueryIsProgrammatic: "true", - urlutil.QueryCallbackURI: "ok", - urlutil.QuerySessionEncrypted: goodEncryptionString, - }, - &mock.Encoder{MarshalResponse: []byte("x")}, - &mstore.Store{Session: &sessions.State{}}, - http.StatusFound, - "", - }, - { - "bad decrypt", - opts, - http.MethodGet, - "http://pomerium.io/", - nil, - map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString + cryptutil.NewBase64Key()}, - &mock.Encoder{MarshalResponse: []byte("x")}, - &mstore.Store{Session: &sessions.State{}}, - http.StatusBadRequest, - "", - }, - { - "bad save session", - opts, - http.MethodGet, - "http://pomerium.io/", - nil, - map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString}, - &mock.Encoder{MarshalResponse: []byte("x")}, - &mstore.Store{SaveError: errors.New("hi")}, - http.StatusBadRequest, - "", - }, - { - "bad base64", - opts, - http.MethodGet, - "http://pomerium.io/", - nil, - map[string]string{urlutil.QuerySessionEncrypted: "^"}, - &mock.Encoder{MarshalResponse: []byte("x")}, - &mstore.Store{Session: &sessions.State{}}, - http.StatusBadRequest, - "", - }, - { - "malformed redirect", - opts, - http.MethodGet, - "http://pomerium.io/", - nil, - nil, - &mock.Encoder{}, - &mstore.Store{Session: &sessions.State{}}, - http.StatusBadRequest, - "", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - p, err := New(&config.Config{Options: tt.options}) - if err != nil { - t.Fatal(err) - } - p.OnConfigChange(context.Background(), &config.Config{Options: tt.options}) - state := p.state.Load() - state.encoder = tt.cipher - state.sessionStore = tt.sessionStore - redirectURI, _ := url.Parse(tt.redirectURI) - queryString := redirectURI.Query() - for k, v := range tt.qp { - queryString.Set(k, v) - } - redirectURI.RawQuery = queryString.Encode() - - uri := &url.URL{Path: "/"} - if tt.qp != nil { - qu := uri.Query() - for k, v := range tt.qp { - qu.Set(k, v) - } - qu.Set(urlutil.QueryRedirectURI, redirectURI.String()) - uri.RawQuery = qu.Encode() - } - - r := httptest.NewRequest(tt.method, uri.String(), nil) - - r.Header.Set("Accept", "application/json") - if len(tt.headers) != 0 { - for k, v := range tt.headers { - r.Header.Set(k, v) - } - } - - w := httptest.NewRecorder() - httputil.HandlerFunc(p.Callback).ServeHTTP(w, r) - if status := w.Code; status != tt.wantStatus { - t.Errorf("status code: got %v want %v", status, tt.wantStatus) - t.Errorf("\n%+v", w.Body.String()) - } - - if tt.wantBody != "" { - body := w.Body.String() - if diff := cmp.Diff(body, tt.wantBody); diff != "" { - t.Errorf("wrong body\n%s", diff) - } - } - }) - } -} - func TestProxy_jwt(t *testing.T) { // without upstream headers being set req, _ := http.NewRequest("GET", "https://www.example.com/.pomerium/jwt", nil) diff --git a/proxy/identity_profile.go b/proxy/identity_profile.go new file mode 100644 index 000000000..1bdb37eca --- /dev/null +++ b/proxy/identity_profile.go @@ -0,0 +1,79 @@ +package proxy + +import ( + "encoding/json" + "fmt" + "time" + + "golang.org/x/oauth2" + "google.golang.org/protobuf/types/known/structpb" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/pomerium/pomerium/internal/identity" + "github.com/pomerium/pomerium/internal/identity/manager" + "github.com/pomerium/pomerium/internal/sessions" + identitypb "github.com/pomerium/pomerium/pkg/grpc/identity" + "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/grpc/user" +) + +func newSessionStateFromProfile(p *identitypb.Profile) *sessions.State { + claims := p.GetClaims().AsMap() + + ss := sessions.NewState(p.GetProviderId()) + + // set the subject + if v, ok := claims["sub"]; ok { + ss.Subject = fmt.Sprint(v) + } else if v, ok := claims["user"]; ok { + ss.Subject = fmt.Sprint(v) + } + + // set the oid + if v, ok := claims["oid"]; ok { + ss.OID = fmt.Sprint(v) + } + + return ss +} + +func populateSessionFromProfile(s *session.Session, p *identitypb.Profile, ss *sessions.State, cookieExpire time.Duration) { + claims := p.GetClaims().AsMap() + oauthToken := new(oauth2.Token) + _ = json.Unmarshal(p.GetOauthToken(), oauthToken) + + s.UserId = ss.UserID() + s.IssuedAt = timestamppb.Now() + s.AccessedAt = timestamppb.Now() + s.ExpiresAt = timestamppb.New(time.Now().Add(cookieExpire)) + s.IdToken = &session.IDToken{ + Issuer: ss.Issuer, + Subject: ss.Subject, + ExpiresAt: timestamppb.New(time.Now().Add(cookieExpire)), + IssuedAt: timestamppb.Now(), + Raw: string(p.GetIdToken()), + } + s.OauthToken = manager.ToOAuthToken(oauthToken) + if s.Claims == nil { + s.Claims = make(map[string]*structpb.ListValue) + } + for k, vs := range identity.Claims(claims).Flatten().ToPB() { + s.Claims[k] = vs + } +} + +func populateUserFromProfile(u *user.User, p *identitypb.Profile, ss *sessions.State) { + claims := p.GetClaims().AsMap() + if v, ok := claims["name"]; ok { + u.Name = fmt.Sprint(v) + } + if v, ok := claims["email"]; ok { + u.Email = fmt.Sprint(v) + } + if u.Claims == nil { + u.Claims = make(map[string]*structpb.ListValue) + } + for k, vs := range identity.Claims(claims).Flatten().ToPB() { + u.Claims[k] = vs + } +} diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index a8f5eadf0..5c2e5d86e 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -9,13 +9,15 @@ import ( "time" "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/handlers" "github.com/stretchr/testify/require" ) func testOptions(t *testing.T) *config.Options { + t.Helper() + opts := config.NewDefaultOptions() - opts.AuthenticateURLString = "https://authenticate.example" to, err := config.ParseWeightedUrls("https://example.example") require.NoError(t, err) @@ -28,6 +30,13 @@ func testOptions(t *testing.T) *config.Options { opts.SharedKey = "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=" opts.CookieSecret = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=" + htpkePrivateKey, err := opts.GetHPKEPrivateKey() + require.NoError(t, err) + + authnSrv := httptest.NewServer(handlers.JWKSHandler(opts.SigningKey, htpkePrivateKey.PublicKey())) + t.Cleanup(authnSrv.Close) + opts.AuthenticateURLString = authnSrv.URL + require.NoError(t, opts.Validate()) return opts diff --git a/proxy/state.go b/proxy/state.go index 73893722b..709ea2db8 100644 --- a/proxy/state.go +++ b/proxy/state.go @@ -13,6 +13,7 @@ import ( "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/grpc" "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/hpke" ) var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn) @@ -26,10 +27,12 @@ type proxyState struct { authenticateSigninURL *url.URL authenticateRefreshURL *url.URL - encoder encoding.MarshalUnmarshaler - cookieSecret []byte - sessionStore sessions.SessionStore - jwtClaimHeaders config.JWTClaimHeaders + encoder encoding.MarshalUnmarshaler + cookieSecret []byte + sessionStore sessions.SessionStore + jwtClaimHeaders config.JWTClaimHeaders + hpkePrivateKey *hpke.PrivateKey + authenticateKeyFetcher hpke.KeyFetcher dataBrokerClient databroker.DataBrokerServiceClient @@ -44,11 +47,24 @@ func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) { state := new(proxyState) + authenticateURL, err := cfg.Options.GetAuthenticateURL() + if err != nil { + return nil, err + } + state.sharedKey, err = cfg.Options.GetSharedKey() if err != nil { return nil, err } + state.hpkePrivateKey, err = cfg.Options.GetHPKEPrivateKey() + if err != nil { + return nil, err + } + state.authenticateKeyFetcher = hpke.NewKeyFetcher(authenticateURL.ResolveReference(&url.URL{ + Path: "/.well-known/pomerium/jwks.json", + }).String()) + state.sharedCipher, err = cryptutil.NewAEADCipher(state.sharedKey) if err != nil { return nil, err diff --git a/ui/src/components/ClaimRow.tsx b/ui/src/components/ClaimRow.tsx new file mode 100644 index 000000000..0af5a55f1 --- /dev/null +++ b/ui/src/components/ClaimRow.tsx @@ -0,0 +1,31 @@ +import TableCell from "@mui/material/TableCell"; +import TableRow from "@mui/material/TableRow"; +import { isArray, startCase } from "lodash"; +import React, { FC } from "react"; + +import ClaimValue from "./ClaimValue"; + +export type ClaimRowProps = { + claimKey: string; + claimValue: unknown; +}; +export const ClaimRow: FC = ({ claimKey, claimValue }) => { + return ( + + {startCase(claimKey)} + + {isArray(claimValue) ? ( + claimValue?.map((v, i) => ( + + {i > 0 ?
: <>} + +
+ )) + ) : ( + + )} +
+
+ ); +}; +export default ClaimRow; diff --git a/ui/src/components/SessionDetails.tsx b/ui/src/components/SessionDetails.tsx index 910d3510f..154951869 100644 --- a/ui/src/components/SessionDetails.tsx +++ b/ui/src/components/SessionDetails.tsx @@ -1,6 +1,3 @@ -import { Session } from "../types"; -import IDField from "./IDField"; -import Section from "./Section"; import Stack from "@mui/material/Stack"; import Table from "@mui/material/Table"; import TableBody from "@mui/material/TableBody"; @@ -8,13 +5,20 @@ import TableCell from "@mui/material/TableCell"; import TableContainer from "@mui/material/TableContainer"; import TableRow from "@mui/material/TableRow"; import React, { FC } from "react"; -import ClaimValue from "./ClaimValue"; -import {startCase} from "lodash"; + +import { Profile, Session } from "../types"; +import ClaimRow from "./ClaimRow"; +import IDField from "./IDField"; +import Section from "./Section"; export type SessionDetailsProps = { session: Session; + profile: Profile; }; -export const SessionDetails: FC = ({ session }) => { +export const SessionDetails: FC = ({ + session, + profile, +}) => { return (
@@ -22,7 +26,9 @@ export const SessionDetails: FC = ({ session }) => { - Session ID + + Session ID + @@ -30,26 +36,28 @@ export const SessionDetails: FC = ({ session }) => { User ID - + Expires At {session?.expiresAt || ""} - {Object.entries(session?.claims || {}).map( - ([key, values]) => ( - - {startCase(key)} - - {values?.map((v, i) => ( - - {i > 0 ?
: <>} - -
- ))} -
-
+ {Object.entries(session?.claims || {}).map(([key, values]) => ( + + ))} + {Object.entries(profile?.claims || {}).map(([key, value]) => ( + ))}
diff --git a/ui/src/components/UserInfoPage.tsx b/ui/src/components/UserInfoPage.tsx index d07b56a05..0146ca30a 100644 --- a/ui/src/components/UserInfoPage.tsx +++ b/ui/src/components/UserInfoPage.tsx @@ -81,7 +81,9 @@ const UserInfoPage: FC = ({ data }) => { marginLeft: mdUp ? "256px" : "0px", }} > - {subpage === "User" && } + {subpage === "User" && ( + + )} {subpage === "Groups Info" && ( ; +}; + export type Session = { audience: string[]; claims: Claims; @@ -108,6 +115,7 @@ export type UserInfoData = { isEnterprise?: boolean; session?: Session; user?: User; + profile?: Profile; webAuthnCreationOptions?: WebAuthnCreationOptions; webAuthnRequestOptions?: WebAuthnRequestOptions; webAuthnUrl?: string;