diff --git a/authenticate/authenticate.go b/authenticate/authenticate.go index 5a7163c9f..7ae42dfd1 100644 --- a/authenticate/authenticate.go +++ b/authenticate/authenticate.go @@ -8,10 +8,7 @@ import ( "fmt" "github.com/pomerium/pomerium/config" - "github.com/pomerium/pomerium/internal/identity" - "github.com/pomerium/pomerium/internal/identity/oauth" "github.com/pomerium/pomerium/internal/log" - "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/pkg/cryptutil" ) @@ -28,15 +25,6 @@ func ValidateOptions(o *config.Options) error { if _, err := cryptutil.NewAEADCipherFromBase64(o.CookieSecret); err != nil { return fmt.Errorf("authenticate: 'COOKIE_SECRET' invalid %w", err) } - if o.Provider == "" { - return errors.New("authenticate: 'IDP_PROVIDER' is required") - } - if o.ClientID == "" { - return errors.New("authenticate: 'IDP_CLIENT_ID' is required") - } - if o.ClientSecret == "" { - return errors.New("authenticate: 'IDP_CLIENT_SECRET' is required") - } if o.AuthenticateCallbackPath == "" { return errors.New("authenticate: 'AUTHENTICATE_CALLBACK_PATH' is required") } @@ -45,17 +33,17 @@ func ValidateOptions(o *config.Options) error { // Authenticate contains data required to run the authenticate service. type Authenticate struct { - options *config.AtomicOptions - provider *identity.AtomicAuthenticator - state *atomicAuthenticateState + cfg *authenticateConfig + options *config.AtomicOptions + state *atomicAuthenticateState } // New validates and creates a new authenticate service from a set of Options. -func New(cfg *config.Config) (*Authenticate, error) { +func New(cfg *config.Config, options ...Option) (*Authenticate, error) { a := &Authenticate{ - options: config.NewAtomicOptions(), - provider: identity.NewAtomicAuthenticator(), - state: newAtomicAuthenticateState(newAuthenticateState()), + cfg: getAuthenticateConfig(options...), + options: config.NewAtomicOptions(), + state: newAtomicAuthenticateState(newAuthenticateState()), } state, err := newAuthenticateStateFromConfig(cfg) @@ -64,11 +52,6 @@ func New(cfg *config.Config) (*Authenticate, error) { } a.state.Store(state) - err = a.updateProvider(cfg) - if err != nil { - return nil, err - } - return a, nil } @@ -84,36 +67,4 @@ func (a *Authenticate) OnConfigChange(ctx context.Context, cfg *config.Config) { } else { a.state.Store(state) } - if err := a.updateProvider(cfg); err != nil { - log.Error(ctx).Err(err).Msg("authenticate: failed to update identity provider") - } -} - -func (a *Authenticate) updateProvider(cfg *config.Config) error { - u, err := cfg.Options.GetAuthenticateURL() - if err != nil { - return err - } - - redirectURL, _ := urlutil.DeepCopy(u) - redirectURL.Path = cfg.Options.AuthenticateCallbackPath - - // configure our identity provider - provider, err := identity.NewAuthenticator( - oauth.Options{ - RedirectURL: redirectURL, - ProviderName: cfg.Options.Provider, - ProviderURL: cfg.Options.ProviderURL, - ClientID: cfg.Options.ClientID, - ClientSecret: cfg.Options.ClientSecret, - Scopes: cfg.Options.Scopes, - ServiceAccount: cfg.Options.ServiceAccount, - AuthCodeOptions: cfg.Options.RequestParams, - }) - if err != nil { - return err - } - a.provider.Store(provider) - - return nil } diff --git a/authenticate/authenticate_test.go b/authenticate/authenticate_test.go index 2fd5e9622..255d6725a 100644 --- a/authenticate/authenticate_test.go +++ b/authenticate/authenticate_test.go @@ -57,11 +57,10 @@ func TestOptions_Validate(t *testing.T) { {"invalid cookie secret", invalidCookieSecret, true}, {"short cookie secret", shortCookieLength, true}, {"no shared secret", badSharedKey, true}, - {"no client id", emptyClientID, true}, - {"no client secret", emptyClientSecret, true}, {"empty callback path", badCallbackPath, true}, } for _, tt := range tests { + tt := tt t.Run(tt.name, func(t *testing.T) { if err := ValidateOptions(tt.o); (err != nil) != tt.wantErr { t.Errorf("Options.Validate() error = %v, wantErr %v", err, tt.wantErr) @@ -107,13 +106,12 @@ func TestNew(t *testing.T) { {"good", good, false}, {"empty opts", &config.Options{}, true}, {"fails to validate", badRedirectURL, true}, - {"bad provider", badProvider, true}, - {"empty provider url", emptyProviderURL, true}, {"good signing key", goodSigningKey, false}, {"bad signing key", badSigningKey, true}, {"bad public signing key", badSigninKeyPublic, true}, } for _, tt := range tests { + tt := tt t.Run(tt.name, func(t *testing.T) { _, err := New(&config.Config{Options: tt.opts}) if (err != nil) != tt.wantErr { diff --git a/authenticate/config.go b/authenticate/config.go new file mode 100644 index 000000000..73f0af1d6 --- /dev/null +++ b/authenticate/config.go @@ -0,0 +1,29 @@ +package authenticate + +import ( + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/identity" +) + +type authenticateConfig struct { + getIdentityProvider func(options *config.Options, idpID string) (identity.Authenticator, error) +} + +// An Option customizes the Authenticate config. +type Option func(*authenticateConfig) + +func getAuthenticateConfig(options ...Option) *authenticateConfig { + cfg := new(authenticateConfig) + WithGetIdentityProvider(defaultGetIdentityProvider)(cfg) + for _, option := range options { + option(cfg) + } + return cfg +} + +// WithGetIdentityProvider sets the getIdentityProvider function in the config. +func WithGetIdentityProvider(getIdentityProvider func(options *config.Options, idpID string) (identity.Authenticator, error)) Option { + return func(cfg *authenticateConfig) { + cfg.getIdentityProvider = getIdentityProvider + } +} diff --git a/authenticate/handlers.go b/authenticate/handlers.go index 5f148a3f8..07648647d 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -161,10 +161,22 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler { defer span.End() state := a.state.Load() + idpID := r.FormValue(urlutil.QueryIdentityProviderID) sessionState, err := a.getSessionFromCtx(ctx) if err != nil { - log.FromRequest(r).Info().Err(err).Msg("authenticate: session load error") + log.FromRequest(r).Info(). + Err(err). + Str("idp_id", idpID). + Msg("authenticate: session load error") + return a.reauthenticateOrFail(w, r, err) + } + + if sessionState.IdentityProviderID != idpID { + log.FromRequest(r).Info(). + Str("idp_id", idpID). + Str("id", sessionState.ID). + Msg("authenticate: session not associated with identity provider") return a.reauthenticateOrFail(w, r, err) } @@ -172,7 +184,11 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler { return errors.New("authenticate: databroker client cannot be nil") } if _, err = session.Get(ctx, state.dataBrokerClient, sessionState.ID); err != nil { - log.FromRequest(r).Info().Err(err).Str("id", sessionState.ID).Msg("authenticate: session not found in databroker") + log.FromRequest(r).Info(). + Err(err). + Str("idp_id", idpID). + Str("id", sessionState.ID). + Msg("authenticate: session not found in databroker") return a.reauthenticateOrFail(w, r, err) } @@ -222,7 +238,12 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error { return err } - newSession := sessions.NewSession(s, state.redirectURL.Host, jwtAudience) + // start over if this is a different identity provider + if s == nil || s.IdentityProviderID != r.FormValue(urlutil.QueryIdentityProviderID) { + s = sessions.NewState(urlutil.QueryIdentityProviderID) + } + + newSession := s.WithNewIssuer(state.redirectURL.Host, jwtAudience) // re-persist the session, useful when session was evicted from session if err := state.sessionStore.SaveSession(w, r, s); err != nil { @@ -258,6 +279,13 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error { ctx, span := trace.StartSpan(r.Context(), "authenticate.SignOut") defer span.End() + options := a.options.Load() + + idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID)) + if err != nil { + return err + } + rawIDToken := a.revokeSession(ctx, w, r) redirectString := "" @@ -272,7 +300,7 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error { redirectString = uri } - endSessionURL, err := a.provider.Load().LogOut() + endSessionURL, err := idp.LogOut() if err == nil && redirectString != "" { params := url.Values{} params.Add("id_token_hint", rawIDToken) @@ -300,12 +328,20 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error { // https://tools.ietf.org/html/rfc6749#section-4.2.1 // https://developer.mozilla.org/en-US/docs/Web/API/XMLHttpRequest func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Request, err error) error { - state := a.state.Load() // If request AJAX/XHR request, return a 401 instead because the redirect // will almost certainly violate their CORs policy if reqType := r.Header.Get("X-Requested-With"); strings.EqualFold(reqType, "XmlHttpRequest") { return httputil.NewError(http.StatusUnauthorized, err) } + + options := a.options.Load() + state := a.state.Load() + + idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID)) + if err != nil { + return err + } + state.sessionStore.ClearSession(w, r) redirectURL := state.redirectURL.ResolveReference(r.URL) nonce := csrf.Token(r) @@ -314,7 +350,7 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque enc := cryptutil.Encrypt(state.cookieCipher, []byte(redirectURL.String()), b) b = append(b, enc...) encodedState := base64.URLEncoding.EncodeToString(b) - signinURL, err := a.provider.Load().GetSignInURL(encodedState) + signinURL, err := idp.GetSignInURL(encodedState) if err != nil { return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("failed to get sign in url: %w", err)) @@ -349,6 +385,7 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) ctx, span := trace.StartSpan(r.Context(), "authenticate.getOAuthCallback") defer span.End() + options := a.options.Load() state := a.state.Load() // Error Authentication Response: rfc6749#section-4.1.2.1 & OIDC#3.1.2.6 @@ -357,21 +394,13 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) if idpError := r.FormValue("error"); idpError != "" { return nil, httputil.NewError(a.statusForErrorCode(idpError), fmt.Errorf("identity provider: %v", idpError)) } + // fail if no session redemption code is returned code := r.FormValue("code") if code == "" { return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("identity provider returned empty code")) } - // Successful Authentication Response: rfc6749#section-4.1.2 & OIDC#3.1.2.5 - // - // Exchange the supplied Authorization Code for a valid user session. - var claims identity.SessionClaims - accessToken, err := a.provider.Load().Authenticate(ctx, code, &claims) - if err != nil { - return nil, fmt.Errorf("error redeeming authenticate code: %w", err) - } - // state includes a csrf nonce (validated by middleware) and redirect uri bytes, err := base64.URLEncoding.DecodeString(r.FormValue("state")) if err != nil { @@ -403,24 +432,35 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) if err != nil { return nil, httputil.NewError(http.StatusBadRequest, err) } + idpID := redirectURL.Query().Get(urlutil.QueryIdentityProviderID) - s := sessions.State{ID: uuid.New().String()} + idp, err := a.cfg.getIdentityProvider(options, idpID) + if err != nil { + return nil, err + } + + // Successful Authentication Response: rfc6749#section-4.1.2 & OIDC#3.1.2.5 + // + // Exchange the supplied Authorization Code for a valid user session. + var claims identity.SessionClaims + accessToken, err := idp.Authenticate(ctx, code, &claims) + if err != nil { + return nil, fmt.Errorf("error redeeming authenticate code: %w", err) + } + + s := sessions.NewState(idpID) err = claims.Claims.Claims(&s) if err != nil { return nil, fmt.Errorf("error unmarshaling session state: %w", err) } - newState := sessions.NewSession( - &s, - state.redirectURL.Hostname(), - []string{state.redirectURL.Hostname()}) - + newState := s.WithNewIssuer(state.redirectURL.Hostname(), []string{state.redirectURL.Hostname()}) if nextRedirectURL, err := urlutil.ParseAndValidateURL(redirectURL.Query().Get(urlutil.QueryRedirectURI)); err == nil { newState.Audience = append(newState.Audience, nextRedirectURL.Hostname()) } // save the session and access token to the databroker - err = a.saveSessionToDataBroker(ctx, &newState, claims, accessToken) + err = a.saveSessionToDataBroker(ctx, r, &newState, claims, accessToken) if err != nil { return nil, httputil.NewError(http.StatusInternalServerError, err) } @@ -522,6 +562,7 @@ func (a *Authenticate) userInfo(w http.ResponseWriter, r *http.Request) error { func (a *Authenticate) saveSessionToDataBroker( ctx context.Context, + r *http.Request, sessionState *sessions.State, claims identity.SessionClaims, accessToken *oauth2.Token, @@ -529,12 +570,17 @@ func (a *Authenticate) saveSessionToDataBroker( state := a.state.Load() options := a.options.Load() + idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID)) + if err != nil { + return err + } + sessionExpiry := timestamppb.New(time.Now().Add(options.CookieExpire)) idTokenIssuedAt := timestamppb.New(sessionState.IssuedAt.Time()) s := &session.Session{ Id: sessionState.ID, - UserId: sessionState.UserID(a.provider.Load().Name()), + UserId: sessionState.UserID(idp.Name()), IssuedAt: timestamppb.Now(), ExpiresAt: sessionExpiry, IdToken: &session.IDToken{ @@ -557,7 +603,7 @@ func (a *Authenticate) saveSessionToDataBroker( Id: s.GetUserId(), } } - err := a.provider.Load().UpdateUserInfo(ctx, accessToken, &managerUser) + err = idp.UpdateUserInfo(ctx, accessToken, &managerUser) if err != nil { return fmt.Errorf("authenticate: error retrieving user info: %w", err) } @@ -588,10 +634,17 @@ func (a *Authenticate) saveSessionToDataBroker( // databroker. If successful, it returns the original `id_token` of the session, if failed, returns // and empty string. func (a *Authenticate) revokeSession(ctx context.Context, w http.ResponseWriter, r *http.Request) string { + options := a.options.Load() state := a.state.Load() + // clear the user's local session no matter what defer state.sessionStore.ClearSession(w, r) + idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID)) + if err != nil { + return "" + } + var rawIDToken string sessionState, err := a.getSessionFromCtx(ctx) if err != nil { @@ -600,7 +653,7 @@ func (a *Authenticate) revokeSession(ctx context.Context, w http.ResponseWriter, if s, _ := session.Get(ctx, state.dataBrokerClient, sessionState.ID); s != nil && s.OauthToken != nil { rawIDToken = s.GetIdToken().GetRaw() - if err := a.provider.Load().Revoke(ctx, manager.FromOAuthToken(s.OauthToken)); err != nil { + if err := idp.Revoke(ctx, manager.FromOAuthToken(s.OauthToken)); err != nil { log.Ctx(ctx).Warn().Err(err).Msg("authenticate: failed to revoke access token") } } diff --git a/authenticate/handlers_test.go b/authenticate/handlers_test.go index 1fe697164..b3312fc7e 100644 --- a/authenticate/handlers_test.go +++ b/authenticate/handlers_test.go @@ -138,6 +138,7 @@ func TestAuthenticate_SignIn(t *testing.T) { {"good additional audience", "https", "corp.example.example", map[string]string{urlutil.QueryForwardAuth: "x.y.z", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, } for _, tt := range tests { + tt := tt t.Run(tt.name, func(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -145,6 +146,9 @@ func TestAuthenticate_SignIn(t *testing.T) { sharedCipher, _ := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key()) a := &Authenticate{ + cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) { + return tt.provider, nil + })), state: newAtomicAuthenticateState(&authenticateState{ sharedCipher: sharedCipher, sessionStore: tt.session, @@ -173,11 +177,9 @@ func TestAuthenticate_SignIn(t *testing.T) { directoryClient: new(mockDirectoryServiceClient), }), - options: config.NewAtomicOptions(), - provider: identity.NewAtomicAuthenticator(), + options: config.NewAtomicOptions(), } a.options.Store(&config.Options{SharedKey: base64.StdEncoding.EncodeToString(cryptutil.NewKey())}) - a.provider.Store(tt.provider) uri := &url.URL{Scheme: tt.scheme, Host: tt.host} queryString := uri.Query() @@ -233,10 +235,14 @@ func TestAuthenticate_SignOut(t *testing.T) { {"no redirect uri", http.MethodPost, nil, "", "", "sig", "ts", identity.MockProvider{LogOutResponse: (*uriParseHelper("https://microsoft.com"))}, &mstore.Store{Encrypted: true, Session: &sessions.State{}}, http.StatusOK, "{\"Status\":200,\"Error\":\"OK: user logged out\"}\n"}, } for _, tt := range tests { + tt := tt t.Run(tt.name, func(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() a := &Authenticate{ + cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) { + return tt.provider, nil + })), state: newAtomicAuthenticateState(&authenticateState{ sessionStore: tt.sessionStore, encryptedEncoder: mock.Encoder{}, @@ -265,15 +271,13 @@ func TestAuthenticate_SignOut(t *testing.T) { }, directoryClient: new(mockDirectoryServiceClient), }), - options: config.NewAtomicOptions(), - provider: identity.NewAtomicAuthenticator(), + options: config.NewAtomicOptions(), } if tt.signoutRedirectURL != "" { opts := a.options.Load() opts.SignOutRedirectURLString = tt.signoutRedirectURL a.options.Store(opts) } - a.provider.Store(tt.provider) u, _ := url.Parse("/sign_out") params, _ := url.ParseQuery(u.RawQuery) params.Add("sig", tt.sig) @@ -345,6 +349,7 @@ func TestAuthenticate_OAuthCallback(t *testing.T) { {"bad hmac", http.MethodGet, time.Now().Unix(), base64.URLEncoding.EncodeToString([]byte("malformed_state")), "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest}, } for _, tt := range tests { + tt := tt t.Run(tt.name, func(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -358,6 +363,9 @@ func TestAuthenticate_OAuthCallback(t *testing.T) { } authURL, _ := url.Parse(tt.authenticateURL) a := &Authenticate{ + cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) { + return tt.provider, nil + })), state: newAtomicAuthenticateState(&authenticateState{ dataBrokerClient: mockDataBrokerServiceClient{ get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) { @@ -373,10 +381,8 @@ func TestAuthenticate_OAuthCallback(t *testing.T) { cookieCipher: aead, encryptedEncoder: signer, }), - options: config.NewAtomicOptions(), - provider: identity.NewAtomicAuthenticator(), + options: config.NewAtomicOptions(), } - a.provider.Store(tt.provider) u, _ := url.Parse("/oauthGet") params, _ := url.ParseQuery(u.RawQuery) params.Add("error", tt.paramErr) @@ -478,6 +484,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) { }, } for _, tt := range tests { + tt := tt t.Run(tt.name, func(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -490,7 +497,10 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) { if err != nil { t.Fatal(err) } - a := Authenticate{ + a := &Authenticate{ + cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) { + return tt.provider, nil + })), state: newAtomicAuthenticateState(&authenticateState{ cookieSecret: cryptutil.NewKey(), redirectURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"), @@ -519,10 +529,8 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) { }, directoryClient: new(mockDirectoryServiceClient), }), - options: config.NewAtomicOptions(), - provider: identity.NewAtomicAuthenticator(), + options: config.NewAtomicOptions(), } - a.provider.Store(tt.provider) r := httptest.NewRequest("GET", "/", nil) state, err := tt.session.LoadSession(r) if err != nil { diff --git a/authenticate/identity.go b/authenticate/identity.go new file mode 100644 index 000000000..b42059579 --- /dev/null +++ b/authenticate/identity.go @@ -0,0 +1,33 @@ +package authenticate + +import ( + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/identity" + "github.com/pomerium/pomerium/internal/identity/oauth" + "github.com/pomerium/pomerium/internal/urlutil" +) + +func defaultGetIdentityProvider(options *config.Options, idpID string) (identity.Authenticator, error) { + authenticateURL, err := options.GetAuthenticateURL() + if err != nil { + return nil, err + } + + redirectURL, err := urlutil.DeepCopy(authenticateURL) + if err != nil { + return nil, err + } + redirectURL.Path = options.AuthenticateCallbackPath + + idp := options.GetIdentityProviderForID(idpID) + return identity.NewAuthenticator(oauth.Options{ + RedirectURL: redirectURL, + ProviderName: idp.GetType(), + ProviderURL: idp.GetUrl(), + ClientID: idp.GetClientId(), + ClientSecret: idp.GetClientSecret(), + Scopes: idp.GetScopes(), + ServiceAccount: idp.GetServiceAccount(), + AuthCodeOptions: idp.GetRequestParams(), + }) +} diff --git a/authorize/check_response.go b/authorize/check_response.go index 45ae564e8..af8364ee2 100644 --- a/authorize/check_response.go +++ b/authorize/check_response.go @@ -38,6 +38,7 @@ func (a *Authorize) handleResultAllowed( func (a *Authorize) handleResultDenied( ctx context.Context, in *envoy_service_auth_v3.CheckRequest, + request *evaluator.Request, result *evaluator.Result, isForwardAuthVerify bool, reasons criteria.Reasons, @@ -49,7 +50,7 @@ func (a *Authorize) handleResultDenied( case reasons.Has(criteria.ReasonUserUnauthenticated): // when the user is unauthenticated it means they haven't // logged in yet, so redirect to authenticate - return a.requireLoginResponse(ctx, in, isForwardAuthVerify) + return a.requireLoginResponse(ctx, in, request, isForwardAuthVerify) case reasons.Has(criteria.ReasonDeviceUnauthenticated): // when the user's device is unauthenticated it means they haven't // registered a webauthn device yet, so redirect to the webauthn flow @@ -141,6 +142,7 @@ func (a *Authorize) deniedResponse( func (a *Authorize) requireLoginResponse( ctx context.Context, in *envoy_service_auth_v3.CheckRequest, + request *evaluator.Request, isForwardAuthVerify bool, ) (*envoy_service_auth_v3.CheckResponse, error) { opts := a.currentOptions.Load() @@ -164,6 +166,7 @@ func (a *Authorize) requireLoginResponse( checkRequestURL.Scheme = "https" q.Set(urlutil.QueryRedirectURI, checkRequestURL.String()) + q.Set(urlutil.QueryIdentityProviderID, opts.GetIdentityProviderForPolicy(request.Policy).GetId()) signinURL.RawQuery = q.Encode() redirectTo := urlutil.NewSignedURL(state.sharedKey, signinURL).String() diff --git a/authorize/check_response_test.go b/authorize/check_response_test.go index 7fad6b05d..cc9c77d04 100644 --- a/authorize/check_response_test.go +++ b/authorize/check_response_test.go @@ -172,38 +172,46 @@ func TestRequireLogin(t *testing.T) { require.NoError(t, err) t.Run("accept empty", func(t *testing.T) { - res, err := a.requireLoginResponse(context.Background(), &envoy_service_auth_v3.CheckRequest{}, + res, err := a.requireLoginResponse(context.Background(), + &envoy_service_auth_v3.CheckRequest{}, + &evaluator.Request{}, false) require.NoError(t, err) assert.Equal(t, http.StatusFound, int(res.GetDeniedResponse().GetStatus().GetCode())) }) t.Run("accept html", func(t *testing.T) { - res, err := a.requireLoginResponse(context.Background(), &envoy_service_auth_v3.CheckRequest{ - Attributes: &envoy_service_auth_v3.AttributeContext{ - Request: &envoy_service_auth_v3.AttributeContext_Request{ - Http: &envoy_service_auth_v3.AttributeContext_HttpRequest{ - Headers: map[string]string{ - "accept": "*/*", + res, err := a.requireLoginResponse(context.Background(), + &envoy_service_auth_v3.CheckRequest{ + Attributes: &envoy_service_auth_v3.AttributeContext{ + Request: &envoy_service_auth_v3.AttributeContext_Request{ + Http: &envoy_service_auth_v3.AttributeContext_HttpRequest{ + Headers: map[string]string{ + "accept": "*/*", + }, }, }, }, }, - }, false) + &evaluator.Request{}, + false) require.NoError(t, err) assert.Equal(t, http.StatusFound, int(res.GetDeniedResponse().GetStatus().GetCode())) }) t.Run("accept json", func(t *testing.T) { - res, err := a.requireLoginResponse(context.Background(), &envoy_service_auth_v3.CheckRequest{ - Attributes: &envoy_service_auth_v3.AttributeContext{ - Request: &envoy_service_auth_v3.AttributeContext_Request{ - Http: &envoy_service_auth_v3.AttributeContext_HttpRequest{ - Headers: map[string]string{ - "accept": "application/json", + res, err := a.requireLoginResponse(context.Background(), + &envoy_service_auth_v3.CheckRequest{ + Attributes: &envoy_service_auth_v3.AttributeContext{ + Request: &envoy_service_auth_v3.AttributeContext_Request{ + Http: &envoy_service_auth_v3.AttributeContext_HttpRequest{ + Headers: map[string]string{ + "accept": "application/json", + }, }, }, }, }, - }, false) + &evaluator.Request{}, + false) require.NoError(t, err) assert.Equal(t, http.StatusUnauthorized, int(res.GetDeniedResponse().GetStatus().GetCode())) }) diff --git a/authorize/grpc.go b/authorize/grpc.go index 804d37956..1e3557235 100644 --- a/authorize/grpc.go +++ b/authorize/grpc.go @@ -76,7 +76,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe // if there's a deny, the result is denied using the deny reasons. if res.Deny.Value { - return a.handleResultDenied(ctx, in, res, isForwardAuthVerify, res.Deny.Reasons) + return a.handleResultDenied(ctx, in, req, res, isForwardAuthVerify, res.Deny.Reasons) } // if there's an allow, the result is allowed. @@ -85,7 +85,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe } // otherwise, the result is denied using the allow reasons. - return a.handleResultDenied(ctx, in, res, isForwardAuthVerify, res.Allow.Reasons) + return a.handleResultDenied(ctx, in, req, res, isForwardAuthVerify, res.Allow.Reasons) } func getForwardAuthURL(r *http.Request) *url.URL { diff --git a/config/identity.go b/config/identity.go new file mode 100644 index 000000000..69169f69f --- /dev/null +++ b/config/identity.go @@ -0,0 +1,42 @@ +package config + +import ( + "github.com/pomerium/pomerium/pkg/grpc/identity" +) + +// GetIdentityProviderForID returns the identity provider associated with the given IDP id. +// If none is found the default provider is returned. +func (o *Options) GetIdentityProviderForID(idpID string) *identity.Provider { + for _, policy := range o.GetAllPolicies() { + idp := o.GetIdentityProviderForPolicy(&policy) //nolint + if idp.GetId() == idpID { + return idp + } + } + + return o.GetIdentityProviderForPolicy(nil) +} + +// GetIdentityProviderForPolicy gets the identity provider associated with the given policy. +// If policy is nil, or changes none of the default settings, the default provider is returned. +func (o *Options) GetIdentityProviderForPolicy(policy *Policy) *identity.Provider { + idp := &identity.Provider{ + ClientId: o.ClientID, + ClientSecret: o.ClientSecret, + Type: o.Provider, + Scopes: o.Scopes, + ServiceAccount: o.ServiceAccount, + Url: o.ProviderURL, + RequestParams: o.RequestParams, + } + if policy != nil { + if policy.IDPClientID != "" { + idp.ClientId = policy.IDPClientID + } + if policy.IDPClientSecret != "" { + idp.ClientSecret = policy.IDPClientSecret + } + } + idp.Id = idp.Hash() + return idp +} diff --git a/config/policy.go b/config/policy.go index da461e20e..4c8ae5b1f 100644 --- a/config/policy.go +++ b/config/policy.go @@ -162,6 +162,11 @@ type Policy struct { // SetResponseHeaders sets response headers. SetResponseHeaders map[string]string `mapstructure:"set_response_headers" yaml:"set_response_headers,omitempty"` + // IDPClientID is the client id used for the identity provider. + IDPClientID string `mapstructure:"idp_client_id" yaml:"idp_client_id,omitempty"` + // IDPClientSecret is the client secret used for the identity provider. + IDPClientSecret string `mapstructure:"idp_client_secret" yaml:"idp_client_secret,omitempty"` + Policy *PPLPolicy `mapstructure:"policy" yaml:"policy,omitempty" json:"policy,omitempty"` } diff --git a/internal/httputil/errors.go b/internal/httputil/errors.go index 408acac90..a3737a5cb 100644 --- a/internal/httputil/errors.go +++ b/internal/httputil/errors.go @@ -27,7 +27,11 @@ func NewError(status int, err error) error { // Error implements the `error` interface. func (e *HTTPError) Error() string { - return StatusText(e.Status) + ": " + e.Err.Error() + str := StatusText(e.Status) + if e.Err != nil { + str += ": " + e.Err.Error() + } + return str } // Unwrap implements the `error` Unwrap interface. diff --git a/internal/sessions/state.go b/internal/sessions/state.go index 456a758a3..82a084c40 100644 --- a/internal/sessions/state.go +++ b/internal/sessions/state.go @@ -3,10 +3,10 @@ package sessions import ( "encoding/json" "errors" - "fmt" "time" "github.com/go-jose/go-jose/v3/jwt" + "github.com/google/uuid" ) // ErrMissingID is the error for a session state that has no ID set. @@ -15,34 +15,6 @@ var ErrMissingID = errors.New("invalid session: missing id") // timeNow is time.Now but pulled out as a variable for tests. var timeNow = time.Now -// Version represents "ver" field in JWT public claims. -// -// The field is not specified by RFC 7519, so providers can -// return either string or number (like okta). -type Version string - -// String implements fmt.Stringer interface. -func (v *Version) String() string { - return string(*v) -} - -// UnmarshalJSON implements json.Unmarshaler interface. -func (v *Version) UnmarshalJSON(b []byte) error { - var tmp interface{} - if err := json.Unmarshal(b, &tmp); err != nil { - return err - } - switch val := tmp.(type) { - case string: - *v = Version(val) - case float64: - *v = Version(fmt.Sprintf("%g", val)) - default: - return errors.New("invalid type for Version") - } - return nil -} - // State is our object that keeps track of a user's session state type State struct { // Public claim values (as specified in RFC 7519). @@ -61,12 +33,26 @@ type State struct { // DatabrokerRecordVersion tracks the last referenced databroker record version // for the saved session. DatabrokerRecordVersion uint64 `json:"databroker_record_version,omitempty"` + + // IdentityProviderID is the identity provider for the session. + IdentityProviderID string `json:"idp_id,omitempty"` } -// NewSession updates issuer, audience, and issuance timestamps but keeps -// parent expiry. -func NewSession(s *State, issuer string, audience []string) State { - newState := *s +// NewState creates a new State. +func NewState(idpID string) *State { + return &State{ + IssuedAt: jwt.NewNumericDate(timeNow()), + ID: uuid.NewString(), + IdentityProviderID: idpID, + } +} + +// WithNewIssuer creates a new State from an existing State. +func (s *State) WithNewIssuer(issuer string, audience []string) State { + newState := State{} + if s != nil { + newState = *s + } newState.IssuedAt = jwt.NewNumericDate(timeNow()) newState.Audience = audience newState.Issuer = issuer diff --git a/internal/sessions/state_test.go b/internal/sessions/state_test.go index 2b63ac0a0..e51890628 100644 --- a/internal/sessions/state_test.go +++ b/internal/sessions/state_test.go @@ -18,31 +18,31 @@ func TestState_UnmarshalJSON(t *testing.T) { tests := []struct { name string in *State - want State + want *State wantErr bool }{ { "good", &State{ID: "xyz"}, - State{ID: "xyz", IssuedAt: jwt.NewNumericDate(fixedTime)}, + &State{ID: "xyz", IssuedAt: jwt.NewNumericDate(fixedTime)}, false, }, { "with user", &State{ID: "xyz"}, - State{ID: "xyz", IssuedAt: jwt.NewNumericDate(fixedTime)}, + &State{ID: "xyz", IssuedAt: jwt.NewNumericDate(fixedTime)}, false, }, { "without", &State{ID: "xyz", Subject: "user"}, - State{ID: "xyz", Subject: "user", IssuedAt: jwt.NewNumericDate(fixedTime)}, + &State{ID: "xyz", Subject: "user", IssuedAt: jwt.NewNumericDate(fixedTime)}, false, }, { "missing id", &State{}, - State{IssuedAt: jwt.NewNumericDate(fixedTime)}, + &State{IssuedAt: jwt.NewNumericDate(fixedTime)}, true, }, } @@ -53,7 +53,8 @@ func TestState_UnmarshalJSON(t *testing.T) { t.Fatal(err) } - s := NewSession(&State{}, "", nil) + s := NewState("") + s.ID = "" if err := s.UnmarshalJSON(data); (err != nil) != tt.wantErr { t.Errorf("State.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) } @@ -63,30 +64,3 @@ func TestState_UnmarshalJSON(t *testing.T) { }) } } - -func TestVersion_UnmarshalJSON(t *testing.T) { - tests := []struct { - name string - jsonStr string - wantVersion string - wantErr bool - }{ - {"Version is string", `"1"`, "1", false}, - {"Version is integer", `1`, "1", false}, - {"Version is float", `1.1`, "1.1", false}, - {"Invalid version", `[1]`, "", true}, - } - for _, tc := range tests { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - var v Version - if err := v.UnmarshalJSON([]byte(tc.jsonStr)); (err != nil) != tc.wantErr { - t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tc.wantErr) - } - if !tc.wantErr && v.String() != tc.wantVersion { - t.Errorf("mismatch version, want: %s, got: %s", tc.wantVersion, v.String()) - } - }) - } -} diff --git a/internal/urlutil/query_params.go b/internal/urlutil/query_params.go index 72aed1d6c..8b773e434 100644 --- a/internal/urlutil/query_params.go +++ b/internal/urlutil/query_params.go @@ -8,6 +8,7 @@ const ( QueryDeviceCredentialID = "pomerium_device_credential_id" QueryDeviceType = "pomerium_device_type" QueryEnrollmentToken = "pomerium_enrollment_token" //nolint + QueryIdentityProviderID = "pomerium_idp_id" QueryIsProgrammatic = "pomerium_programmatic" QueryForwardAuth = "pomerium_forward_auth" QueryPomeriumJWT = "pomerium_jwt" diff --git a/pkg/grpc/identity/identity.go b/pkg/grpc/identity/identity.go new file mode 100644 index 000000000..20bc9c475 --- /dev/null +++ b/pkg/grpc/identity/identity.go @@ -0,0 +1,27 @@ +// Package identity contains protobuf types for identity management. +package identity + +import ( + "crypto/sha256" + + "google.golang.org/protobuf/proto" + + "github.com/pomerium/pomerium/pkg/encoding/base58" +) + +// Clone clones the Provider. +func (x *Provider) Clone() *Provider { + return proto.Clone(x).(*Provider) +} + +// Hash computes a sha256 hash of the provider's fields. It excludes the Id field. +func (x *Provider) Hash() string { + tmp := x.Clone() + tmp.Id = "" + bs, _ := proto.MarshalOptions{ + AllowPartial: true, + Deterministic: true, + }.Marshal(tmp) + h := sha256.Sum256(bs) + return base58.Encode(h[:]) +} diff --git a/pkg/grpc/identity/identity.pb.go b/pkg/grpc/identity/identity.pb.go new file mode 100644 index 000000000..961a87c1c --- /dev/null +++ b/pkg/grpc/identity/identity.pb.go @@ -0,0 +1,232 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.27.1 +// protoc v3.14.0 +// source: identity.proto + +package identity + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type Provider struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` + ClientId string `protobuf:"bytes,2,opt,name=client_id,json=clientId,proto3" json:"client_id,omitempty"` + ClientSecret string `protobuf:"bytes,3,opt,name=client_secret,json=clientSecret,proto3" json:"client_secret,omitempty"` + Type string `protobuf:"bytes,4,opt,name=type,proto3" json:"type,omitempty"` + Scopes []string `protobuf:"bytes,5,rep,name=scopes,proto3" json:"scopes,omitempty"` + ServiceAccount string `protobuf:"bytes,6,opt,name=service_account,json=serviceAccount,proto3" json:"service_account,omitempty"` + Url string `protobuf:"bytes,7,opt,name=url,proto3" json:"url,omitempty"` + RequestParams map[string]string `protobuf:"bytes,8,rep,name=request_params,json=requestParams,proto3" json:"request_params,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` + RedirectUrl string `protobuf:"bytes,9,opt,name=redirect_url,json=redirectUrl,proto3" json:"redirect_url,omitempty"` +} + +func (x *Provider) Reset() { + *x = Provider{} + if protoimpl.UnsafeEnabled { + mi := &file_identity_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Provider) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Provider) ProtoMessage() {} + +func (x *Provider) ProtoReflect() protoreflect.Message { + mi := &file_identity_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Provider.ProtoReflect.Descriptor instead. +func (*Provider) Descriptor() ([]byte, []int) { + return file_identity_proto_rawDescGZIP(), []int{0} +} + +func (x *Provider) GetId() string { + if x != nil { + return x.Id + } + return "" +} + +func (x *Provider) GetClientId() string { + if x != nil { + return x.ClientId + } + return "" +} + +func (x *Provider) GetClientSecret() string { + if x != nil { + return x.ClientSecret + } + return "" +} + +func (x *Provider) GetType() string { + if x != nil { + return x.Type + } + return "" +} + +func (x *Provider) GetScopes() []string { + if x != nil { + return x.Scopes + } + return nil +} + +func (x *Provider) GetServiceAccount() string { + if x != nil { + return x.ServiceAccount + } + return "" +} + +func (x *Provider) GetUrl() string { + if x != nil { + return x.Url + } + return "" +} + +func (x *Provider) GetRequestParams() map[string]string { + if x != nil { + return x.RequestParams + } + return nil +} + +func (x *Provider) GetRedirectUrl() string { + if x != nil { + return x.RedirectUrl + } + return "" +} + +var File_identity_proto protoreflect.FileDescriptor + +var file_identity_proto_rawDesc = []byte{ + 0x0a, 0x0e, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x12, 0x11, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2e, 0x69, 0x64, 0x65, 0x6e, 0x74, + 0x69, 0x74, 0x79, 0x22, 0xff, 0x02, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, + 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, + 0x12, 0x1b, 0x0a, 0x09, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x08, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x23, 0x0a, + 0x0d, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x73, 0x65, 0x63, 0x72, 0x65, 0x74, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, + 0x65, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x63, 0x6f, 0x70, 0x65, 0x73, + 0x18, 0x05, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x73, 0x63, 0x6f, 0x70, 0x65, 0x73, 0x12, 0x27, + 0x0a, 0x0f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, + 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, + 0x41, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x6c, 0x18, 0x07, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x6c, 0x12, 0x55, 0x0a, 0x0e, 0x72, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x5f, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, + 0x0b, 0x32, 0x2e, 0x2e, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2e, 0x69, 0x64, 0x65, + 0x6e, 0x74, 0x69, 0x74, 0x79, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x2e, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x45, 0x6e, 0x74, 0x72, + 0x79, 0x52, 0x0d, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, + 0x12, 0x21, 0x0a, 0x0c, 0x72, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x5f, 0x75, 0x72, 0x6c, + 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x72, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, + 0x55, 0x72, 0x6c, 0x1a, 0x40, 0x0a, 0x12, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x50, 0x61, + 0x72, 0x61, 0x6d, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, + 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, + 0x65, 0x3a, 0x02, 0x38, 0x01, 0x42, 0x30, 0x5a, 0x2e, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, + 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, + 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x69, + 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_identity_proto_rawDescOnce sync.Once + file_identity_proto_rawDescData = file_identity_proto_rawDesc +) + +func file_identity_proto_rawDescGZIP() []byte { + file_identity_proto_rawDescOnce.Do(func() { + file_identity_proto_rawDescData = protoimpl.X.CompressGZIP(file_identity_proto_rawDescData) + }) + return file_identity_proto_rawDescData +} + +var file_identity_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_identity_proto_goTypes = []interface{}{ + (*Provider)(nil), // 0: pomerium.identity.Provider + nil, // 1: pomerium.identity.Provider.RequestParamsEntry +} +var file_identity_proto_depIdxs = []int32{ + 1, // 0: pomerium.identity.Provider.request_params:type_name -> pomerium.identity.Provider.RequestParamsEntry + 1, // [1:1] is the sub-list for method output_type + 1, // [1:1] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 1, // [1:1] is the sub-list for extension extendee + 0, // [0:1] is the sub-list for field type_name +} + +func init() { file_identity_proto_init() } +func file_identity_proto_init() { + if File_identity_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_identity_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Provider); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_identity_proto_rawDesc, + NumEnums: 0, + NumMessages: 2, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_identity_proto_goTypes, + DependencyIndexes: file_identity_proto_depIdxs, + MessageInfos: file_identity_proto_msgTypes, + }.Build() + File_identity_proto = out.File + file_identity_proto_rawDesc = nil + file_identity_proto_goTypes = nil + file_identity_proto_depIdxs = nil +} diff --git a/pkg/grpc/identity/identity.proto b/pkg/grpc/identity/identity.proto new file mode 100644 index 000000000..5a5f446d8 --- /dev/null +++ b/pkg/grpc/identity/identity.proto @@ -0,0 +1,15 @@ +syntax = "proto3"; + +package pomerium.identity; +option go_package = "github.com/pomerium/pomerium/pkg/grpc/identity"; + +message Provider { + string id = 1; + string client_id = 2; + string client_secret = 3; + string type = 4; + repeated string scopes = 5; + string service_account = 6; + string url = 7; + map request_params = 8; +} diff --git a/pkg/grpc/protoc.bash b/pkg/grpc/protoc.bash index 776f0cc4e..76dca1e84 100755 --- a/pkg/grpc/protoc.bash +++ b/pkg/grpc/protoc.bash @@ -96,6 +96,11 @@ _import_paths=$(join_by , "${_imports[@]}") --go_out="$_import_paths,plugins=grpc,paths=source_relative:./directory/." \ ./directory/directory.proto + +../../scripts/protoc -I ./identity/ \ + --go_out="$_import_paths,plugins=grpc,paths=source_relative:./identity/." \ + ./identity/identity.proto + ../../scripts/protoc -I ./registry/ \ --go_out="$_import_paths,plugins=grpc,paths=source_relative:./registry/." \ --validate_out="lang=go,paths=source_relative:./registry" \