authenticate: support for per-route client id and client secret (#3030)

* implement dynamic provider support

* authenticate: support per-route client id and secret
This commit is contained in:
Caleb Doxsey 2022-02-16 12:31:55 -07:00 committed by GitHub
parent 99ffaf233d
commit f9b95a276b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 557 additions and 183 deletions

View file

@ -8,10 +8,7 @@ import (
"fmt" "fmt"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/internal/identity/oauth"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/cryptutil"
) )
@ -28,15 +25,6 @@ func ValidateOptions(o *config.Options) error {
if _, err := cryptutil.NewAEADCipherFromBase64(o.CookieSecret); err != nil { if _, err := cryptutil.NewAEADCipherFromBase64(o.CookieSecret); err != nil {
return fmt.Errorf("authenticate: 'COOKIE_SECRET' invalid %w", err) return fmt.Errorf("authenticate: 'COOKIE_SECRET' invalid %w", err)
} }
if o.Provider == "" {
return errors.New("authenticate: 'IDP_PROVIDER' is required")
}
if o.ClientID == "" {
return errors.New("authenticate: 'IDP_CLIENT_ID' is required")
}
if o.ClientSecret == "" {
return errors.New("authenticate: 'IDP_CLIENT_SECRET' is required")
}
if o.AuthenticateCallbackPath == "" { if o.AuthenticateCallbackPath == "" {
return errors.New("authenticate: 'AUTHENTICATE_CALLBACK_PATH' is required") return errors.New("authenticate: 'AUTHENTICATE_CALLBACK_PATH' is required")
} }
@ -45,17 +33,17 @@ func ValidateOptions(o *config.Options) error {
// Authenticate contains data required to run the authenticate service. // Authenticate contains data required to run the authenticate service.
type Authenticate struct { type Authenticate struct {
options *config.AtomicOptions cfg *authenticateConfig
provider *identity.AtomicAuthenticator options *config.AtomicOptions
state *atomicAuthenticateState state *atomicAuthenticateState
} }
// New validates and creates a new authenticate service from a set of Options. // New validates and creates a new authenticate service from a set of Options.
func New(cfg *config.Config) (*Authenticate, error) { func New(cfg *config.Config, options ...Option) (*Authenticate, error) {
a := &Authenticate{ a := &Authenticate{
options: config.NewAtomicOptions(), cfg: getAuthenticateConfig(options...),
provider: identity.NewAtomicAuthenticator(), options: config.NewAtomicOptions(),
state: newAtomicAuthenticateState(newAuthenticateState()), state: newAtomicAuthenticateState(newAuthenticateState()),
} }
state, err := newAuthenticateStateFromConfig(cfg) state, err := newAuthenticateStateFromConfig(cfg)
@ -64,11 +52,6 @@ func New(cfg *config.Config) (*Authenticate, error) {
} }
a.state.Store(state) a.state.Store(state)
err = a.updateProvider(cfg)
if err != nil {
return nil, err
}
return a, nil return a, nil
} }
@ -84,36 +67,4 @@ func (a *Authenticate) OnConfigChange(ctx context.Context, cfg *config.Config) {
} else { } else {
a.state.Store(state) a.state.Store(state)
} }
if err := a.updateProvider(cfg); err != nil {
log.Error(ctx).Err(err).Msg("authenticate: failed to update identity provider")
}
}
func (a *Authenticate) updateProvider(cfg *config.Config) error {
u, err := cfg.Options.GetAuthenticateURL()
if err != nil {
return err
}
redirectURL, _ := urlutil.DeepCopy(u)
redirectURL.Path = cfg.Options.AuthenticateCallbackPath
// configure our identity provider
provider, err := identity.NewAuthenticator(
oauth.Options{
RedirectURL: redirectURL,
ProviderName: cfg.Options.Provider,
ProviderURL: cfg.Options.ProviderURL,
ClientID: cfg.Options.ClientID,
ClientSecret: cfg.Options.ClientSecret,
Scopes: cfg.Options.Scopes,
ServiceAccount: cfg.Options.ServiceAccount,
AuthCodeOptions: cfg.Options.RequestParams,
})
if err != nil {
return err
}
a.provider.Store(provider)
return nil
} }

View file

@ -57,11 +57,10 @@ func TestOptions_Validate(t *testing.T) {
{"invalid cookie secret", invalidCookieSecret, true}, {"invalid cookie secret", invalidCookieSecret, true},
{"short cookie secret", shortCookieLength, true}, {"short cookie secret", shortCookieLength, true},
{"no shared secret", badSharedKey, true}, {"no shared secret", badSharedKey, true},
{"no client id", emptyClientID, true},
{"no client secret", emptyClientSecret, true},
{"empty callback path", badCallbackPath, true}, {"empty callback path", badCallbackPath, true},
} }
for _, tt := range tests { for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := ValidateOptions(tt.o); (err != nil) != tt.wantErr { if err := ValidateOptions(tt.o); (err != nil) != tt.wantErr {
t.Errorf("Options.Validate() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Options.Validate() error = %v, wantErr %v", err, tt.wantErr)
@ -107,13 +106,12 @@ func TestNew(t *testing.T) {
{"good", good, false}, {"good", good, false},
{"empty opts", &config.Options{}, true}, {"empty opts", &config.Options{}, true},
{"fails to validate", badRedirectURL, true}, {"fails to validate", badRedirectURL, true},
{"bad provider", badProvider, true},
{"empty provider url", emptyProviderURL, true},
{"good signing key", goodSigningKey, false}, {"good signing key", goodSigningKey, false},
{"bad signing key", badSigningKey, true}, {"bad signing key", badSigningKey, true},
{"bad public signing key", badSigninKeyPublic, true}, {"bad public signing key", badSigninKeyPublic, true},
} }
for _, tt := range tests { for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
_, err := New(&config.Config{Options: tt.opts}) _, err := New(&config.Config{Options: tt.opts})
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {

29
authenticate/config.go Normal file
View file

@ -0,0 +1,29 @@
package authenticate
import (
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/identity"
)
type authenticateConfig struct {
getIdentityProvider func(options *config.Options, idpID string) (identity.Authenticator, error)
}
// An Option customizes the Authenticate config.
type Option func(*authenticateConfig)
func getAuthenticateConfig(options ...Option) *authenticateConfig {
cfg := new(authenticateConfig)
WithGetIdentityProvider(defaultGetIdentityProvider)(cfg)
for _, option := range options {
option(cfg)
}
return cfg
}
// WithGetIdentityProvider sets the getIdentityProvider function in the config.
func WithGetIdentityProvider(getIdentityProvider func(options *config.Options, idpID string) (identity.Authenticator, error)) Option {
return func(cfg *authenticateConfig) {
cfg.getIdentityProvider = getIdentityProvider
}
}

View file

@ -161,10 +161,22 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
defer span.End() defer span.End()
state := a.state.Load() state := a.state.Load()
idpID := r.FormValue(urlutil.QueryIdentityProviderID)
sessionState, err := a.getSessionFromCtx(ctx) sessionState, err := a.getSessionFromCtx(ctx)
if err != nil { if err != nil {
log.FromRequest(r).Info().Err(err).Msg("authenticate: session load error") log.FromRequest(r).Info().
Err(err).
Str("idp_id", idpID).
Msg("authenticate: session load error")
return a.reauthenticateOrFail(w, r, err)
}
if sessionState.IdentityProviderID != idpID {
log.FromRequest(r).Info().
Str("idp_id", idpID).
Str("id", sessionState.ID).
Msg("authenticate: session not associated with identity provider")
return a.reauthenticateOrFail(w, r, err) return a.reauthenticateOrFail(w, r, err)
} }
@ -172,7 +184,11 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
return errors.New("authenticate: databroker client cannot be nil") return errors.New("authenticate: databroker client cannot be nil")
} }
if _, err = session.Get(ctx, state.dataBrokerClient, sessionState.ID); err != nil { if _, err = session.Get(ctx, state.dataBrokerClient, sessionState.ID); err != nil {
log.FromRequest(r).Info().Err(err).Str("id", sessionState.ID).Msg("authenticate: session not found in databroker") log.FromRequest(r).Info().
Err(err).
Str("idp_id", idpID).
Str("id", sessionState.ID).
Msg("authenticate: session not found in databroker")
return a.reauthenticateOrFail(w, r, err) return a.reauthenticateOrFail(w, r, err)
} }
@ -222,7 +238,12 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
return err return err
} }
newSession := sessions.NewSession(s, state.redirectURL.Host, jwtAudience) // start over if this is a different identity provider
if s == nil || s.IdentityProviderID != r.FormValue(urlutil.QueryIdentityProviderID) {
s = sessions.NewState(urlutil.QueryIdentityProviderID)
}
newSession := s.WithNewIssuer(state.redirectURL.Host, jwtAudience)
// re-persist the session, useful when session was evicted from session // re-persist the session, useful when session was evicted from session
if err := state.sessionStore.SaveSession(w, r, s); err != nil { if err := state.sessionStore.SaveSession(w, r, s); err != nil {
@ -258,6 +279,13 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error {
ctx, span := trace.StartSpan(r.Context(), "authenticate.SignOut") ctx, span := trace.StartSpan(r.Context(), "authenticate.SignOut")
defer span.End() defer span.End()
options := a.options.Load()
idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return err
}
rawIDToken := a.revokeSession(ctx, w, r) rawIDToken := a.revokeSession(ctx, w, r)
redirectString := "" redirectString := ""
@ -272,7 +300,7 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error {
redirectString = uri redirectString = uri
} }
endSessionURL, err := a.provider.Load().LogOut() endSessionURL, err := idp.LogOut()
if err == nil && redirectString != "" { if err == nil && redirectString != "" {
params := url.Values{} params := url.Values{}
params.Add("id_token_hint", rawIDToken) params.Add("id_token_hint", rawIDToken)
@ -300,12 +328,20 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error {
// https://tools.ietf.org/html/rfc6749#section-4.2.1 // https://tools.ietf.org/html/rfc6749#section-4.2.1
// https://developer.mozilla.org/en-US/docs/Web/API/XMLHttpRequest // https://developer.mozilla.org/en-US/docs/Web/API/XMLHttpRequest
func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Request, err error) error { func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Request, err error) error {
state := a.state.Load()
// If request AJAX/XHR request, return a 401 instead because the redirect // If request AJAX/XHR request, return a 401 instead because the redirect
// will almost certainly violate their CORs policy // will almost certainly violate their CORs policy
if reqType := r.Header.Get("X-Requested-With"); strings.EqualFold(reqType, "XmlHttpRequest") { if reqType := r.Header.Get("X-Requested-With"); strings.EqualFold(reqType, "XmlHttpRequest") {
return httputil.NewError(http.StatusUnauthorized, err) return httputil.NewError(http.StatusUnauthorized, err)
} }
options := a.options.Load()
state := a.state.Load()
idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return err
}
state.sessionStore.ClearSession(w, r) state.sessionStore.ClearSession(w, r)
redirectURL := state.redirectURL.ResolveReference(r.URL) redirectURL := state.redirectURL.ResolveReference(r.URL)
nonce := csrf.Token(r) nonce := csrf.Token(r)
@ -314,7 +350,7 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque
enc := cryptutil.Encrypt(state.cookieCipher, []byte(redirectURL.String()), b) enc := cryptutil.Encrypt(state.cookieCipher, []byte(redirectURL.String()), b)
b = append(b, enc...) b = append(b, enc...)
encodedState := base64.URLEncoding.EncodeToString(b) encodedState := base64.URLEncoding.EncodeToString(b)
signinURL, err := a.provider.Load().GetSignInURL(encodedState) signinURL, err := idp.GetSignInURL(encodedState)
if err != nil { if err != nil {
return httputil.NewError(http.StatusInternalServerError, return httputil.NewError(http.StatusInternalServerError,
fmt.Errorf("failed to get sign in url: %w", err)) fmt.Errorf("failed to get sign in url: %w", err))
@ -349,6 +385,7 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
ctx, span := trace.StartSpan(r.Context(), "authenticate.getOAuthCallback") ctx, span := trace.StartSpan(r.Context(), "authenticate.getOAuthCallback")
defer span.End() defer span.End()
options := a.options.Load()
state := a.state.Load() state := a.state.Load()
// Error Authentication Response: rfc6749#section-4.1.2.1 & OIDC#3.1.2.6 // Error Authentication Response: rfc6749#section-4.1.2.1 & OIDC#3.1.2.6
@ -357,21 +394,13 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
if idpError := r.FormValue("error"); idpError != "" { if idpError := r.FormValue("error"); idpError != "" {
return nil, httputil.NewError(a.statusForErrorCode(idpError), fmt.Errorf("identity provider: %v", idpError)) return nil, httputil.NewError(a.statusForErrorCode(idpError), fmt.Errorf("identity provider: %v", idpError))
} }
// fail if no session redemption code is returned // fail if no session redemption code is returned
code := r.FormValue("code") code := r.FormValue("code")
if code == "" { if code == "" {
return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("identity provider returned empty code")) return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("identity provider returned empty code"))
} }
// Successful Authentication Response: rfc6749#section-4.1.2 & OIDC#3.1.2.5
//
// Exchange the supplied Authorization Code for a valid user session.
var claims identity.SessionClaims
accessToken, err := a.provider.Load().Authenticate(ctx, code, &claims)
if err != nil {
return nil, fmt.Errorf("error redeeming authenticate code: %w", err)
}
// state includes a csrf nonce (validated by middleware) and redirect uri // state includes a csrf nonce (validated by middleware) and redirect uri
bytes, err := base64.URLEncoding.DecodeString(r.FormValue("state")) bytes, err := base64.URLEncoding.DecodeString(r.FormValue("state"))
if err != nil { if err != nil {
@ -403,24 +432,35 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
if err != nil { if err != nil {
return nil, httputil.NewError(http.StatusBadRequest, err) return nil, httputil.NewError(http.StatusBadRequest, err)
} }
idpID := redirectURL.Query().Get(urlutil.QueryIdentityProviderID)
s := sessions.State{ID: uuid.New().String()} idp, err := a.cfg.getIdentityProvider(options, idpID)
if err != nil {
return nil, err
}
// Successful Authentication Response: rfc6749#section-4.1.2 & OIDC#3.1.2.5
//
// Exchange the supplied Authorization Code for a valid user session.
var claims identity.SessionClaims
accessToken, err := idp.Authenticate(ctx, code, &claims)
if err != nil {
return nil, fmt.Errorf("error redeeming authenticate code: %w", err)
}
s := sessions.NewState(idpID)
err = claims.Claims.Claims(&s) err = claims.Claims.Claims(&s)
if err != nil { if err != nil {
return nil, fmt.Errorf("error unmarshaling session state: %w", err) return nil, fmt.Errorf("error unmarshaling session state: %w", err)
} }
newState := sessions.NewSession( newState := s.WithNewIssuer(state.redirectURL.Hostname(), []string{state.redirectURL.Hostname()})
&s,
state.redirectURL.Hostname(),
[]string{state.redirectURL.Hostname()})
if nextRedirectURL, err := urlutil.ParseAndValidateURL(redirectURL.Query().Get(urlutil.QueryRedirectURI)); err == nil { if nextRedirectURL, err := urlutil.ParseAndValidateURL(redirectURL.Query().Get(urlutil.QueryRedirectURI)); err == nil {
newState.Audience = append(newState.Audience, nextRedirectURL.Hostname()) newState.Audience = append(newState.Audience, nextRedirectURL.Hostname())
} }
// save the session and access token to the databroker // save the session and access token to the databroker
err = a.saveSessionToDataBroker(ctx, &newState, claims, accessToken) err = a.saveSessionToDataBroker(ctx, r, &newState, claims, accessToken)
if err != nil { if err != nil {
return nil, httputil.NewError(http.StatusInternalServerError, err) return nil, httputil.NewError(http.StatusInternalServerError, err)
} }
@ -522,6 +562,7 @@ func (a *Authenticate) userInfo(w http.ResponseWriter, r *http.Request) error {
func (a *Authenticate) saveSessionToDataBroker( func (a *Authenticate) saveSessionToDataBroker(
ctx context.Context, ctx context.Context,
r *http.Request,
sessionState *sessions.State, sessionState *sessions.State,
claims identity.SessionClaims, claims identity.SessionClaims,
accessToken *oauth2.Token, accessToken *oauth2.Token,
@ -529,12 +570,17 @@ func (a *Authenticate) saveSessionToDataBroker(
state := a.state.Load() state := a.state.Load()
options := a.options.Load() options := a.options.Load()
idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return err
}
sessionExpiry := timestamppb.New(time.Now().Add(options.CookieExpire)) sessionExpiry := timestamppb.New(time.Now().Add(options.CookieExpire))
idTokenIssuedAt := timestamppb.New(sessionState.IssuedAt.Time()) idTokenIssuedAt := timestamppb.New(sessionState.IssuedAt.Time())
s := &session.Session{ s := &session.Session{
Id: sessionState.ID, Id: sessionState.ID,
UserId: sessionState.UserID(a.provider.Load().Name()), UserId: sessionState.UserID(idp.Name()),
IssuedAt: timestamppb.Now(), IssuedAt: timestamppb.Now(),
ExpiresAt: sessionExpiry, ExpiresAt: sessionExpiry,
IdToken: &session.IDToken{ IdToken: &session.IDToken{
@ -557,7 +603,7 @@ func (a *Authenticate) saveSessionToDataBroker(
Id: s.GetUserId(), Id: s.GetUserId(),
} }
} }
err := a.provider.Load().UpdateUserInfo(ctx, accessToken, &managerUser) err = idp.UpdateUserInfo(ctx, accessToken, &managerUser)
if err != nil { if err != nil {
return fmt.Errorf("authenticate: error retrieving user info: %w", err) return fmt.Errorf("authenticate: error retrieving user info: %w", err)
} }
@ -588,10 +634,17 @@ func (a *Authenticate) saveSessionToDataBroker(
// databroker. If successful, it returns the original `id_token` of the session, if failed, returns // databroker. If successful, it returns the original `id_token` of the session, if failed, returns
// and empty string. // and empty string.
func (a *Authenticate) revokeSession(ctx context.Context, w http.ResponseWriter, r *http.Request) string { func (a *Authenticate) revokeSession(ctx context.Context, w http.ResponseWriter, r *http.Request) string {
options := a.options.Load()
state := a.state.Load() state := a.state.Load()
// clear the user's local session no matter what // clear the user's local session no matter what
defer state.sessionStore.ClearSession(w, r) defer state.sessionStore.ClearSession(w, r)
idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return ""
}
var rawIDToken string var rawIDToken string
sessionState, err := a.getSessionFromCtx(ctx) sessionState, err := a.getSessionFromCtx(ctx)
if err != nil { if err != nil {
@ -600,7 +653,7 @@ func (a *Authenticate) revokeSession(ctx context.Context, w http.ResponseWriter,
if s, _ := session.Get(ctx, state.dataBrokerClient, sessionState.ID); s != nil && s.OauthToken != nil { if s, _ := session.Get(ctx, state.dataBrokerClient, sessionState.ID); s != nil && s.OauthToken != nil {
rawIDToken = s.GetIdToken().GetRaw() rawIDToken = s.GetIdToken().GetRaw()
if err := a.provider.Load().Revoke(ctx, manager.FromOAuthToken(s.OauthToken)); err != nil { if err := idp.Revoke(ctx, manager.FromOAuthToken(s.OauthToken)); err != nil {
log.Ctx(ctx).Warn().Err(err).Msg("authenticate: failed to revoke access token") log.Ctx(ctx).Warn().Err(err).Msg("authenticate: failed to revoke access token")
} }
} }

View file

@ -138,6 +138,7 @@ func TestAuthenticate_SignIn(t *testing.T) {
{"good additional audience", "https", "corp.example.example", map[string]string{urlutil.QueryForwardAuth: "x.y.z", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, {"good additional audience", "https", "corp.example.example", map[string]string{urlutil.QueryForwardAuth: "x.y.z", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
} }
for _, tt := range tests { for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
defer ctrl.Finish() defer ctrl.Finish()
@ -145,6 +146,9 @@ func TestAuthenticate_SignIn(t *testing.T) {
sharedCipher, _ := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key()) sharedCipher, _ := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
a := &Authenticate{ a := &Authenticate{
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) {
return tt.provider, nil
})),
state: newAtomicAuthenticateState(&authenticateState{ state: newAtomicAuthenticateState(&authenticateState{
sharedCipher: sharedCipher, sharedCipher: sharedCipher,
sessionStore: tt.session, sessionStore: tt.session,
@ -173,11 +177,9 @@ func TestAuthenticate_SignIn(t *testing.T) {
directoryClient: new(mockDirectoryServiceClient), directoryClient: new(mockDirectoryServiceClient),
}), }),
options: config.NewAtomicOptions(), options: config.NewAtomicOptions(),
provider: identity.NewAtomicAuthenticator(),
} }
a.options.Store(&config.Options{SharedKey: base64.StdEncoding.EncodeToString(cryptutil.NewKey())}) a.options.Store(&config.Options{SharedKey: base64.StdEncoding.EncodeToString(cryptutil.NewKey())})
a.provider.Store(tt.provider)
uri := &url.URL{Scheme: tt.scheme, Host: tt.host} uri := &url.URL{Scheme: tt.scheme, Host: tt.host}
queryString := uri.Query() queryString := uri.Query()
@ -233,10 +235,14 @@ func TestAuthenticate_SignOut(t *testing.T) {
{"no redirect uri", http.MethodPost, nil, "", "", "sig", "ts", identity.MockProvider{LogOutResponse: (*uriParseHelper("https://microsoft.com"))}, &mstore.Store{Encrypted: true, Session: &sessions.State{}}, http.StatusOK, "{\"Status\":200,\"Error\":\"OK: user logged out\"}\n"}, {"no redirect uri", http.MethodPost, nil, "", "", "sig", "ts", identity.MockProvider{LogOutResponse: (*uriParseHelper("https://microsoft.com"))}, &mstore.Store{Encrypted: true, Session: &sessions.State{}}, http.StatusOK, "{\"Status\":200,\"Error\":\"OK: user logged out\"}\n"},
} }
for _, tt := range tests { for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
defer ctrl.Finish() defer ctrl.Finish()
a := &Authenticate{ a := &Authenticate{
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) {
return tt.provider, nil
})),
state: newAtomicAuthenticateState(&authenticateState{ state: newAtomicAuthenticateState(&authenticateState{
sessionStore: tt.sessionStore, sessionStore: tt.sessionStore,
encryptedEncoder: mock.Encoder{}, encryptedEncoder: mock.Encoder{},
@ -265,15 +271,13 @@ func TestAuthenticate_SignOut(t *testing.T) {
}, },
directoryClient: new(mockDirectoryServiceClient), directoryClient: new(mockDirectoryServiceClient),
}), }),
options: config.NewAtomicOptions(), options: config.NewAtomicOptions(),
provider: identity.NewAtomicAuthenticator(),
} }
if tt.signoutRedirectURL != "" { if tt.signoutRedirectURL != "" {
opts := a.options.Load() opts := a.options.Load()
opts.SignOutRedirectURLString = tt.signoutRedirectURL opts.SignOutRedirectURLString = tt.signoutRedirectURL
a.options.Store(opts) a.options.Store(opts)
} }
a.provider.Store(tt.provider)
u, _ := url.Parse("/sign_out") u, _ := url.Parse("/sign_out")
params, _ := url.ParseQuery(u.RawQuery) params, _ := url.ParseQuery(u.RawQuery)
params.Add("sig", tt.sig) params.Add("sig", tt.sig)
@ -345,6 +349,7 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
{"bad hmac", http.MethodGet, time.Now().Unix(), base64.URLEncoding.EncodeToString([]byte("malformed_state")), "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest}, {"bad hmac", http.MethodGet, time.Now().Unix(), base64.URLEncoding.EncodeToString([]byte("malformed_state")), "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest},
} }
for _, tt := range tests { for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
defer ctrl.Finish() defer ctrl.Finish()
@ -358,6 +363,9 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
} }
authURL, _ := url.Parse(tt.authenticateURL) authURL, _ := url.Parse(tt.authenticateURL)
a := &Authenticate{ a := &Authenticate{
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) {
return tt.provider, nil
})),
state: newAtomicAuthenticateState(&authenticateState{ state: newAtomicAuthenticateState(&authenticateState{
dataBrokerClient: mockDataBrokerServiceClient{ dataBrokerClient: mockDataBrokerServiceClient{
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) { get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
@ -373,10 +381,8 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
cookieCipher: aead, cookieCipher: aead,
encryptedEncoder: signer, encryptedEncoder: signer,
}), }),
options: config.NewAtomicOptions(), options: config.NewAtomicOptions(),
provider: identity.NewAtomicAuthenticator(),
} }
a.provider.Store(tt.provider)
u, _ := url.Parse("/oauthGet") u, _ := url.Parse("/oauthGet")
params, _ := url.ParseQuery(u.RawQuery) params, _ := url.ParseQuery(u.RawQuery)
params.Add("error", tt.paramErr) params.Add("error", tt.paramErr)
@ -478,6 +484,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
}, },
} }
for _, tt := range tests { for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
defer ctrl.Finish() defer ctrl.Finish()
@ -490,7 +497,10 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
a := Authenticate{ a := &Authenticate{
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) {
return tt.provider, nil
})),
state: newAtomicAuthenticateState(&authenticateState{ state: newAtomicAuthenticateState(&authenticateState{
cookieSecret: cryptutil.NewKey(), cookieSecret: cryptutil.NewKey(),
redirectURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"), redirectURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"),
@ -519,10 +529,8 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
}, },
directoryClient: new(mockDirectoryServiceClient), directoryClient: new(mockDirectoryServiceClient),
}), }),
options: config.NewAtomicOptions(), options: config.NewAtomicOptions(),
provider: identity.NewAtomicAuthenticator(),
} }
a.provider.Store(tt.provider)
r := httptest.NewRequest("GET", "/", nil) r := httptest.NewRequest("GET", "/", nil)
state, err := tt.session.LoadSession(r) state, err := tt.session.LoadSession(r)
if err != nil { if err != nil {

33
authenticate/identity.go Normal file
View file

@ -0,0 +1,33 @@
package authenticate
import (
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/internal/identity/oauth"
"github.com/pomerium/pomerium/internal/urlutil"
)
func defaultGetIdentityProvider(options *config.Options, idpID string) (identity.Authenticator, error) {
authenticateURL, err := options.GetAuthenticateURL()
if err != nil {
return nil, err
}
redirectURL, err := urlutil.DeepCopy(authenticateURL)
if err != nil {
return nil, err
}
redirectURL.Path = options.AuthenticateCallbackPath
idp := options.GetIdentityProviderForID(idpID)
return identity.NewAuthenticator(oauth.Options{
RedirectURL: redirectURL,
ProviderName: idp.GetType(),
ProviderURL: idp.GetUrl(),
ClientID: idp.GetClientId(),
ClientSecret: idp.GetClientSecret(),
Scopes: idp.GetScopes(),
ServiceAccount: idp.GetServiceAccount(),
AuthCodeOptions: idp.GetRequestParams(),
})
}

View file

@ -38,6 +38,7 @@ func (a *Authorize) handleResultAllowed(
func (a *Authorize) handleResultDenied( func (a *Authorize) handleResultDenied(
ctx context.Context, ctx context.Context,
in *envoy_service_auth_v3.CheckRequest, in *envoy_service_auth_v3.CheckRequest,
request *evaluator.Request,
result *evaluator.Result, result *evaluator.Result,
isForwardAuthVerify bool, isForwardAuthVerify bool,
reasons criteria.Reasons, reasons criteria.Reasons,
@ -49,7 +50,7 @@ func (a *Authorize) handleResultDenied(
case reasons.Has(criteria.ReasonUserUnauthenticated): case reasons.Has(criteria.ReasonUserUnauthenticated):
// when the user is unauthenticated it means they haven't // when the user is unauthenticated it means they haven't
// logged in yet, so redirect to authenticate // logged in yet, so redirect to authenticate
return a.requireLoginResponse(ctx, in, isForwardAuthVerify) return a.requireLoginResponse(ctx, in, request, isForwardAuthVerify)
case reasons.Has(criteria.ReasonDeviceUnauthenticated): case reasons.Has(criteria.ReasonDeviceUnauthenticated):
// when the user's device is unauthenticated it means they haven't // when the user's device is unauthenticated it means they haven't
// registered a webauthn device yet, so redirect to the webauthn flow // registered a webauthn device yet, so redirect to the webauthn flow
@ -141,6 +142,7 @@ func (a *Authorize) deniedResponse(
func (a *Authorize) requireLoginResponse( func (a *Authorize) requireLoginResponse(
ctx context.Context, ctx context.Context,
in *envoy_service_auth_v3.CheckRequest, in *envoy_service_auth_v3.CheckRequest,
request *evaluator.Request,
isForwardAuthVerify bool, isForwardAuthVerify bool,
) (*envoy_service_auth_v3.CheckResponse, error) { ) (*envoy_service_auth_v3.CheckResponse, error) {
opts := a.currentOptions.Load() opts := a.currentOptions.Load()
@ -164,6 +166,7 @@ func (a *Authorize) requireLoginResponse(
checkRequestURL.Scheme = "https" checkRequestURL.Scheme = "https"
q.Set(urlutil.QueryRedirectURI, checkRequestURL.String()) q.Set(urlutil.QueryRedirectURI, checkRequestURL.String())
q.Set(urlutil.QueryIdentityProviderID, opts.GetIdentityProviderForPolicy(request.Policy).GetId())
signinURL.RawQuery = q.Encode() signinURL.RawQuery = q.Encode()
redirectTo := urlutil.NewSignedURL(state.sharedKey, signinURL).String() redirectTo := urlutil.NewSignedURL(state.sharedKey, signinURL).String()

View file

@ -172,38 +172,46 @@ func TestRequireLogin(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
t.Run("accept empty", func(t *testing.T) { t.Run("accept empty", func(t *testing.T) {
res, err := a.requireLoginResponse(context.Background(), &envoy_service_auth_v3.CheckRequest{}, res, err := a.requireLoginResponse(context.Background(),
&envoy_service_auth_v3.CheckRequest{},
&evaluator.Request{},
false) false)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, http.StatusFound, int(res.GetDeniedResponse().GetStatus().GetCode())) assert.Equal(t, http.StatusFound, int(res.GetDeniedResponse().GetStatus().GetCode()))
}) })
t.Run("accept html", func(t *testing.T) { t.Run("accept html", func(t *testing.T) {
res, err := a.requireLoginResponse(context.Background(), &envoy_service_auth_v3.CheckRequest{ res, err := a.requireLoginResponse(context.Background(),
Attributes: &envoy_service_auth_v3.AttributeContext{ &envoy_service_auth_v3.CheckRequest{
Request: &envoy_service_auth_v3.AttributeContext_Request{ Attributes: &envoy_service_auth_v3.AttributeContext{
Http: &envoy_service_auth_v3.AttributeContext_HttpRequest{ Request: &envoy_service_auth_v3.AttributeContext_Request{
Headers: map[string]string{ Http: &envoy_service_auth_v3.AttributeContext_HttpRequest{
"accept": "*/*", Headers: map[string]string{
"accept": "*/*",
},
}, },
}, },
}, },
}, },
}, false) &evaluator.Request{},
false)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, http.StatusFound, int(res.GetDeniedResponse().GetStatus().GetCode())) assert.Equal(t, http.StatusFound, int(res.GetDeniedResponse().GetStatus().GetCode()))
}) })
t.Run("accept json", func(t *testing.T) { t.Run("accept json", func(t *testing.T) {
res, err := a.requireLoginResponse(context.Background(), &envoy_service_auth_v3.CheckRequest{ res, err := a.requireLoginResponse(context.Background(),
Attributes: &envoy_service_auth_v3.AttributeContext{ &envoy_service_auth_v3.CheckRequest{
Request: &envoy_service_auth_v3.AttributeContext_Request{ Attributes: &envoy_service_auth_v3.AttributeContext{
Http: &envoy_service_auth_v3.AttributeContext_HttpRequest{ Request: &envoy_service_auth_v3.AttributeContext_Request{
Headers: map[string]string{ Http: &envoy_service_auth_v3.AttributeContext_HttpRequest{
"accept": "application/json", Headers: map[string]string{
"accept": "application/json",
},
}, },
}, },
}, },
}, },
}, false) &evaluator.Request{},
false)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, int(res.GetDeniedResponse().GetStatus().GetCode())) assert.Equal(t, http.StatusUnauthorized, int(res.GetDeniedResponse().GetStatus().GetCode()))
}) })

View file

@ -76,7 +76,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
// if there's a deny, the result is denied using the deny reasons. // if there's a deny, the result is denied using the deny reasons.
if res.Deny.Value { if res.Deny.Value {
return a.handleResultDenied(ctx, in, res, isForwardAuthVerify, res.Deny.Reasons) return a.handleResultDenied(ctx, in, req, res, isForwardAuthVerify, res.Deny.Reasons)
} }
// if there's an allow, the result is allowed. // if there's an allow, the result is allowed.
@ -85,7 +85,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
} }
// otherwise, the result is denied using the allow reasons. // otherwise, the result is denied using the allow reasons.
return a.handleResultDenied(ctx, in, res, isForwardAuthVerify, res.Allow.Reasons) return a.handleResultDenied(ctx, in, req, res, isForwardAuthVerify, res.Allow.Reasons)
} }
func getForwardAuthURL(r *http.Request) *url.URL { func getForwardAuthURL(r *http.Request) *url.URL {

42
config/identity.go Normal file
View file

@ -0,0 +1,42 @@
package config
import (
"github.com/pomerium/pomerium/pkg/grpc/identity"
)
// GetIdentityProviderForID returns the identity provider associated with the given IDP id.
// If none is found the default provider is returned.
func (o *Options) GetIdentityProviderForID(idpID string) *identity.Provider {
for _, policy := range o.GetAllPolicies() {
idp := o.GetIdentityProviderForPolicy(&policy) //nolint
if idp.GetId() == idpID {
return idp
}
}
return o.GetIdentityProviderForPolicy(nil)
}
// GetIdentityProviderForPolicy gets the identity provider associated with the given policy.
// If policy is nil, or changes none of the default settings, the default provider is returned.
func (o *Options) GetIdentityProviderForPolicy(policy *Policy) *identity.Provider {
idp := &identity.Provider{
ClientId: o.ClientID,
ClientSecret: o.ClientSecret,
Type: o.Provider,
Scopes: o.Scopes,
ServiceAccount: o.ServiceAccount,
Url: o.ProviderURL,
RequestParams: o.RequestParams,
}
if policy != nil {
if policy.IDPClientID != "" {
idp.ClientId = policy.IDPClientID
}
if policy.IDPClientSecret != "" {
idp.ClientSecret = policy.IDPClientSecret
}
}
idp.Id = idp.Hash()
return idp
}

View file

@ -162,6 +162,11 @@ type Policy struct {
// SetResponseHeaders sets response headers. // SetResponseHeaders sets response headers.
SetResponseHeaders map[string]string `mapstructure:"set_response_headers" yaml:"set_response_headers,omitempty"` SetResponseHeaders map[string]string `mapstructure:"set_response_headers" yaml:"set_response_headers,omitempty"`
// IDPClientID is the client id used for the identity provider.
IDPClientID string `mapstructure:"idp_client_id" yaml:"idp_client_id,omitempty"`
// IDPClientSecret is the client secret used for the identity provider.
IDPClientSecret string `mapstructure:"idp_client_secret" yaml:"idp_client_secret,omitempty"`
Policy *PPLPolicy `mapstructure:"policy" yaml:"policy,omitempty" json:"policy,omitempty"` Policy *PPLPolicy `mapstructure:"policy" yaml:"policy,omitempty" json:"policy,omitempty"`
} }

View file

@ -27,7 +27,11 @@ func NewError(status int, err error) error {
// Error implements the `error` interface. // Error implements the `error` interface.
func (e *HTTPError) Error() string { func (e *HTTPError) Error() string {
return StatusText(e.Status) + ": " + e.Err.Error() str := StatusText(e.Status)
if e.Err != nil {
str += ": " + e.Err.Error()
}
return str
} }
// Unwrap implements the `error` Unwrap interface. // Unwrap implements the `error` Unwrap interface.

View file

@ -3,10 +3,10 @@ package sessions
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"time" "time"
"github.com/go-jose/go-jose/v3/jwt" "github.com/go-jose/go-jose/v3/jwt"
"github.com/google/uuid"
) )
// ErrMissingID is the error for a session state that has no ID set. // ErrMissingID is the error for a session state that has no ID set.
@ -15,34 +15,6 @@ var ErrMissingID = errors.New("invalid session: missing id")
// timeNow is time.Now but pulled out as a variable for tests. // timeNow is time.Now but pulled out as a variable for tests.
var timeNow = time.Now var timeNow = time.Now
// Version represents "ver" field in JWT public claims.
//
// The field is not specified by RFC 7519, so providers can
// return either string or number (like okta).
type Version string
// String implements fmt.Stringer interface.
func (v *Version) String() string {
return string(*v)
}
// UnmarshalJSON implements json.Unmarshaler interface.
func (v *Version) UnmarshalJSON(b []byte) error {
var tmp interface{}
if err := json.Unmarshal(b, &tmp); err != nil {
return err
}
switch val := tmp.(type) {
case string:
*v = Version(val)
case float64:
*v = Version(fmt.Sprintf("%g", val))
default:
return errors.New("invalid type for Version")
}
return nil
}
// State is our object that keeps track of a user's session state // State is our object that keeps track of a user's session state
type State struct { type State struct {
// Public claim values (as specified in RFC 7519). // Public claim values (as specified in RFC 7519).
@ -61,12 +33,26 @@ type State struct {
// DatabrokerRecordVersion tracks the last referenced databroker record version // DatabrokerRecordVersion tracks the last referenced databroker record version
// for the saved session. // for the saved session.
DatabrokerRecordVersion uint64 `json:"databroker_record_version,omitempty"` DatabrokerRecordVersion uint64 `json:"databroker_record_version,omitempty"`
// IdentityProviderID is the identity provider for the session.
IdentityProviderID string `json:"idp_id,omitempty"`
} }
// NewSession updates issuer, audience, and issuance timestamps but keeps // NewState creates a new State.
// parent expiry. func NewState(idpID string) *State {
func NewSession(s *State, issuer string, audience []string) State { return &State{
newState := *s IssuedAt: jwt.NewNumericDate(timeNow()),
ID: uuid.NewString(),
IdentityProviderID: idpID,
}
}
// WithNewIssuer creates a new State from an existing State.
func (s *State) WithNewIssuer(issuer string, audience []string) State {
newState := State{}
if s != nil {
newState = *s
}
newState.IssuedAt = jwt.NewNumericDate(timeNow()) newState.IssuedAt = jwt.NewNumericDate(timeNow())
newState.Audience = audience newState.Audience = audience
newState.Issuer = issuer newState.Issuer = issuer

View file

@ -18,31 +18,31 @@ func TestState_UnmarshalJSON(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
in *State in *State
want State want *State
wantErr bool wantErr bool
}{ }{
{ {
"good", "good",
&State{ID: "xyz"}, &State{ID: "xyz"},
State{ID: "xyz", IssuedAt: jwt.NewNumericDate(fixedTime)}, &State{ID: "xyz", IssuedAt: jwt.NewNumericDate(fixedTime)},
false, false,
}, },
{ {
"with user", "with user",
&State{ID: "xyz"}, &State{ID: "xyz"},
State{ID: "xyz", IssuedAt: jwt.NewNumericDate(fixedTime)}, &State{ID: "xyz", IssuedAt: jwt.NewNumericDate(fixedTime)},
false, false,
}, },
{ {
"without", "without",
&State{ID: "xyz", Subject: "user"}, &State{ID: "xyz", Subject: "user"},
State{ID: "xyz", Subject: "user", IssuedAt: jwt.NewNumericDate(fixedTime)}, &State{ID: "xyz", Subject: "user", IssuedAt: jwt.NewNumericDate(fixedTime)},
false, false,
}, },
{ {
"missing id", "missing id",
&State{}, &State{},
State{IssuedAt: jwt.NewNumericDate(fixedTime)}, &State{IssuedAt: jwt.NewNumericDate(fixedTime)},
true, true,
}, },
} }
@ -53,7 +53,8 @@ func TestState_UnmarshalJSON(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
s := NewSession(&State{}, "", nil) s := NewState("")
s.ID = ""
if err := s.UnmarshalJSON(data); (err != nil) != tt.wantErr { if err := s.UnmarshalJSON(data); (err != nil) != tt.wantErr {
t.Errorf("State.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("State.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
} }
@ -63,30 +64,3 @@ func TestState_UnmarshalJSON(t *testing.T) {
}) })
} }
} }
func TestVersion_UnmarshalJSON(t *testing.T) {
tests := []struct {
name string
jsonStr string
wantVersion string
wantErr bool
}{
{"Version is string", `"1"`, "1", false},
{"Version is integer", `1`, "1", false},
{"Version is float", `1.1`, "1.1", false},
{"Invalid version", `[1]`, "", true},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
var v Version
if err := v.UnmarshalJSON([]byte(tc.jsonStr)); (err != nil) != tc.wantErr {
t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tc.wantErr)
}
if !tc.wantErr && v.String() != tc.wantVersion {
t.Errorf("mismatch version, want: %s, got: %s", tc.wantVersion, v.String())
}
})
}
}

View file

@ -8,6 +8,7 @@ const (
QueryDeviceCredentialID = "pomerium_device_credential_id" QueryDeviceCredentialID = "pomerium_device_credential_id"
QueryDeviceType = "pomerium_device_type" QueryDeviceType = "pomerium_device_type"
QueryEnrollmentToken = "pomerium_enrollment_token" //nolint QueryEnrollmentToken = "pomerium_enrollment_token" //nolint
QueryIdentityProviderID = "pomerium_idp_id"
QueryIsProgrammatic = "pomerium_programmatic" QueryIsProgrammatic = "pomerium_programmatic"
QueryForwardAuth = "pomerium_forward_auth" QueryForwardAuth = "pomerium_forward_auth"
QueryPomeriumJWT = "pomerium_jwt" QueryPomeriumJWT = "pomerium_jwt"

View file

@ -0,0 +1,27 @@
// Package identity contains protobuf types for identity management.
package identity
import (
"crypto/sha256"
"google.golang.org/protobuf/proto"
"github.com/pomerium/pomerium/pkg/encoding/base58"
)
// Clone clones the Provider.
func (x *Provider) Clone() *Provider {
return proto.Clone(x).(*Provider)
}
// Hash computes a sha256 hash of the provider's fields. It excludes the Id field.
func (x *Provider) Hash() string {
tmp := x.Clone()
tmp.Id = ""
bs, _ := proto.MarshalOptions{
AllowPartial: true,
Deterministic: true,
}.Marshal(tmp)
h := sha256.Sum256(bs)
return base58.Encode(h[:])
}

View file

@ -0,0 +1,232 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.27.1
// protoc v3.14.0
// source: identity.proto
package identity
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type Provider struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"`
ClientId string `protobuf:"bytes,2,opt,name=client_id,json=clientId,proto3" json:"client_id,omitempty"`
ClientSecret string `protobuf:"bytes,3,opt,name=client_secret,json=clientSecret,proto3" json:"client_secret,omitempty"`
Type string `protobuf:"bytes,4,opt,name=type,proto3" json:"type,omitempty"`
Scopes []string `protobuf:"bytes,5,rep,name=scopes,proto3" json:"scopes,omitempty"`
ServiceAccount string `protobuf:"bytes,6,opt,name=service_account,json=serviceAccount,proto3" json:"service_account,omitempty"`
Url string `protobuf:"bytes,7,opt,name=url,proto3" json:"url,omitempty"`
RequestParams map[string]string `protobuf:"bytes,8,rep,name=request_params,json=requestParams,proto3" json:"request_params,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"`
RedirectUrl string `protobuf:"bytes,9,opt,name=redirect_url,json=redirectUrl,proto3" json:"redirect_url,omitempty"`
}
func (x *Provider) Reset() {
*x = Provider{}
if protoimpl.UnsafeEnabled {
mi := &file_identity_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *Provider) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Provider) ProtoMessage() {}
func (x *Provider) ProtoReflect() protoreflect.Message {
mi := &file_identity_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Provider.ProtoReflect.Descriptor instead.
func (*Provider) Descriptor() ([]byte, []int) {
return file_identity_proto_rawDescGZIP(), []int{0}
}
func (x *Provider) GetId() string {
if x != nil {
return x.Id
}
return ""
}
func (x *Provider) GetClientId() string {
if x != nil {
return x.ClientId
}
return ""
}
func (x *Provider) GetClientSecret() string {
if x != nil {
return x.ClientSecret
}
return ""
}
func (x *Provider) GetType() string {
if x != nil {
return x.Type
}
return ""
}
func (x *Provider) GetScopes() []string {
if x != nil {
return x.Scopes
}
return nil
}
func (x *Provider) GetServiceAccount() string {
if x != nil {
return x.ServiceAccount
}
return ""
}
func (x *Provider) GetUrl() string {
if x != nil {
return x.Url
}
return ""
}
func (x *Provider) GetRequestParams() map[string]string {
if x != nil {
return x.RequestParams
}
return nil
}
func (x *Provider) GetRedirectUrl() string {
if x != nil {
return x.RedirectUrl
}
return ""
}
var File_identity_proto protoreflect.FileDescriptor
var file_identity_proto_rawDesc = []byte{
0x0a, 0x0e, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f,
0x12, 0x11, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2e, 0x69, 0x64, 0x65, 0x6e, 0x74,
0x69, 0x74, 0x79, 0x22, 0xff, 0x02, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72,
0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64,
0x12, 0x1b, 0x0a, 0x09, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20,
0x01, 0x28, 0x09, 0x52, 0x08, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x23, 0x0a,
0x0d, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x73, 0x65, 0x63, 0x72, 0x65, 0x74, 0x18, 0x03,
0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72,
0x65, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09,
0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x63, 0x6f, 0x70, 0x65, 0x73,
0x18, 0x05, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x73, 0x63, 0x6f, 0x70, 0x65, 0x73, 0x12, 0x27,
0x0a, 0x0f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e,
0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65,
0x41, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x6c, 0x18, 0x07,
0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x6c, 0x12, 0x55, 0x0a, 0x0e, 0x72, 0x65, 0x71,
0x75, 0x65, 0x73, 0x74, 0x5f, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28,
0x0b, 0x32, 0x2e, 0x2e, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2e, 0x69, 0x64, 0x65,
0x6e, 0x74, 0x69, 0x74, 0x79, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x2e, 0x52,
0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x45, 0x6e, 0x74, 0x72,
0x79, 0x52, 0x0d, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73,
0x12, 0x21, 0x0a, 0x0c, 0x72, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x5f, 0x75, 0x72, 0x6c,
0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x72, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74,
0x55, 0x72, 0x6c, 0x1a, 0x40, 0x0a, 0x12, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x50, 0x61,
0x72, 0x61, 0x6d, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79,
0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76,
0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75,
0x65, 0x3a, 0x02, 0x38, 0x01, 0x42, 0x30, 0x5a, 0x2e, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e,
0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6f, 0x6d,
0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x69,
0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
file_identity_proto_rawDescOnce sync.Once
file_identity_proto_rawDescData = file_identity_proto_rawDesc
)
func file_identity_proto_rawDescGZIP() []byte {
file_identity_proto_rawDescOnce.Do(func() {
file_identity_proto_rawDescData = protoimpl.X.CompressGZIP(file_identity_proto_rawDescData)
})
return file_identity_proto_rawDescData
}
var file_identity_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_identity_proto_goTypes = []interface{}{
(*Provider)(nil), // 0: pomerium.identity.Provider
nil, // 1: pomerium.identity.Provider.RequestParamsEntry
}
var file_identity_proto_depIdxs = []int32{
1, // 0: pomerium.identity.Provider.request_params:type_name -> pomerium.identity.Provider.RequestParamsEntry
1, // [1:1] is the sub-list for method output_type
1, // [1:1] is the sub-list for method input_type
1, // [1:1] is the sub-list for extension type_name
1, // [1:1] is the sub-list for extension extendee
0, // [0:1] is the sub-list for field type_name
}
func init() { file_identity_proto_init() }
func file_identity_proto_init() {
if File_identity_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_identity_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*Provider); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_identity_proto_rawDesc,
NumEnums: 0,
NumMessages: 2,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_identity_proto_goTypes,
DependencyIndexes: file_identity_proto_depIdxs,
MessageInfos: file_identity_proto_msgTypes,
}.Build()
File_identity_proto = out.File
file_identity_proto_rawDesc = nil
file_identity_proto_goTypes = nil
file_identity_proto_depIdxs = nil
}

View file

@ -0,0 +1,15 @@
syntax = "proto3";
package pomerium.identity;
option go_package = "github.com/pomerium/pomerium/pkg/grpc/identity";
message Provider {
string id = 1;
string client_id = 2;
string client_secret = 3;
string type = 4;
repeated string scopes = 5;
string service_account = 6;
string url = 7;
map<string, string> request_params = 8;
}

View file

@ -96,6 +96,11 @@ _import_paths=$(join_by , "${_imports[@]}")
--go_out="$_import_paths,plugins=grpc,paths=source_relative:./directory/." \ --go_out="$_import_paths,plugins=grpc,paths=source_relative:./directory/." \
./directory/directory.proto ./directory/directory.proto
../../scripts/protoc -I ./identity/ \
--go_out="$_import_paths,plugins=grpc,paths=source_relative:./identity/." \
./identity/identity.proto
../../scripts/protoc -I ./registry/ \ ../../scripts/protoc -I ./registry/ \
--go_out="$_import_paths,plugins=grpc,paths=source_relative:./registry/." \ --go_out="$_import_paths,plugins=grpc,paths=source_relative:./registry/." \
--validate_out="lang=go,paths=source_relative:./registry" \ --validate_out="lang=go,paths=source_relative:./registry" \