diff --git a/Makefile b/Makefile index e9a0c6e6c..37814be26 100644 --- a/Makefile +++ b/Makefile @@ -38,6 +38,12 @@ GETENVOY_VERSION = v0.1.8 all: clean build-deps test lint spellcheck build ## Runs a clean, build, fmt, lint, test, and vet. +.PHONY: generate-mocks +generate-mocks: ## Generate mocks + @echo "==> $@" + @go run github.com/golang/mock/mockgen -destination authorize/evaluator/mock_evaluator/mock.go github.com/pomerium/pomerium/authorize/evaluator Evaluator + @go run github.com/golang/mock/mockgen -destination internal/grpc/cache/mock/mock_cacher.go github.com/pomerium/pomerium/internal/grpc/cache Cacher + .PHONY: build-deps build-deps: ## Install build dependencies @echo "==> $@" diff --git a/authenticate/authenticate.go b/authenticate/authenticate.go index b2c999ac6..5be243912 100644 --- a/authenticate/authenticate.go +++ b/authenticate/authenticate.go @@ -17,12 +17,12 @@ import ( "github.com/pomerium/pomerium/internal/encoding/jws" "github.com/pomerium/pomerium/internal/frontend" "github.com/pomerium/pomerium/internal/grpc" + "github.com/pomerium/pomerium/internal/grpc/cache" "github.com/pomerium/pomerium/internal/grpc/cache/client" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/identity" "github.com/pomerium/pomerium/internal/identity/oauth" "github.com/pomerium/pomerium/internal/sessions" - "github.com/pomerium/pomerium/internal/sessions/cache" "github.com/pomerium/pomerium/internal/sessions/cookie" "github.com/pomerium/pomerium/internal/sessions/header" "github.com/pomerium/pomerium/internal/sessions/queryparam" @@ -93,7 +93,7 @@ type Authenticate struct { provider identity.Authenticator // cacheClient is the interface for setting and getting sessions from a cache - cacheClient client.Cacher + cacheClient cache.Cacher templates *template.Template } @@ -106,12 +106,12 @@ func New(opts config.Options) (*Authenticate, error) { // shared state encoder setup sharedCipher, _ := cryptutil.NewAEADCipherFromBase64(opts.SharedKey) - signedEncoder, err := jws.NewHS256Signer([]byte(opts.SharedKey), opts.AuthenticateURL.Host) + sharedEncoder, err := jws.NewHS256Signer([]byte(opts.SharedKey), opts.AuthenticateURL.Host) if err != nil { return nil, err } - // private state encoder setup + // private state encoder setup, used to encrypt oauth2 tokens decodedCookieSecret, _ := base64.StdEncoding.DecodeString(opts.CookieSecret) cookieCipher, _ := cryptutil.NewAEADCipher(decodedCookieSecret) encryptedEncoder := ecjson.New(cookieCipher) @@ -124,7 +124,7 @@ func New(opts config.Options) (*Authenticate, error) { Expire: opts.CookieExpire, } - cookieStore, err := cookie.NewStore(cookieOptions, encryptedEncoder) + cookieStore, err := cookie.NewStore(cookieOptions, sharedEncoder) if err != nil { return nil, err } @@ -145,13 +145,7 @@ func New(opts config.Options) (*Authenticate, error) { cacheClient := client.New(cacheConn) - cacheStore := cache.NewStore(&cache.Options{ - Cache: cacheClient, - Encoder: encryptedEncoder, - QueryParam: urlutil.QueryAccessTokenID, - WrappedStore: cookieStore}) - - qpStore := queryparam.NewStore(encryptedEncoder, "pomerium_programmatic_token") + qpStore := queryparam.NewStore(encryptedEncoder, urlutil.QueryProgrammaticToken) headerStore := header.NewStore(encryptedEncoder, httputil.AuthorizationTypePomerium) redirectURL, _ := urlutil.DeepCopy(opts.AuthenticateURL) @@ -177,14 +171,14 @@ func New(opts config.Options) (*Authenticate, error) { // shared state sharedKey: opts.SharedKey, sharedCipher: sharedCipher, - sharedEncoder: signedEncoder, + sharedEncoder: sharedEncoder, // private state cookieSecret: decodedCookieSecret, cookieCipher: cookieCipher, cookieOptions: cookieOptions, - sessionStore: cacheStore, + sessionStore: cookieStore, encryptedEncoder: encryptedEncoder, - sessionLoaders: []sessions.SessionLoader{cacheStore, qpStore, headerStore, cookieStore}, + sessionLoaders: []sessions.SessionLoader{qpStore, headerStore, cookieStore}, // IdP provider: provider, // grpc client for cache diff --git a/authenticate/authenticate_test.go b/authenticate/authenticate_test.go index 6083dab92..9f575966d 100644 --- a/authenticate/authenticate_test.go +++ b/authenticate/authenticate_test.go @@ -91,6 +91,10 @@ func TestNew(t *testing.T) { badGRPCConn.CacheURL = nil badGRPCConn.CookieName = "D" + emptyProviderURL := newTestOptions(t) + emptyProviderURL.Provider = "oidc" + emptyProviderURL.ProviderURL = "" + tests := []struct { name string opts *config.Options @@ -103,6 +107,7 @@ func TestNew(t *testing.T) { {"bad cookie name", badCookieName, true}, {"bad provider", badProvider, true}, {"bad cache url", badGRPCConn, true}, + {"empty provider url", emptyProviderURL, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/authenticate/handlers.go b/authenticate/handlers.go index 9ef8497dd..f45f92211 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -13,6 +13,7 @@ import ( "github.com/pomerium/csrf" "github.com/pomerium/pomerium/internal/cryptutil" + "github.com/pomerium/pomerium/internal/hashutil" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/identity/oidc" "github.com/pomerium/pomerium/internal/log" @@ -23,6 +24,7 @@ import ( "github.com/gorilla/mux" "github.com/rs/cors" + "golang.org/x/oauth2" ) // Handler returns the authenticate service's handler chain. @@ -80,19 +82,15 @@ func (a *Authenticate) Mount(r *mux.Router) { // session state is attached to the users's request context. func (a *Authenticate) VerifySession(next http.Handler) http.Handler { return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { - ctx := r.Context() - jwt, err := sessions.FromContext(ctx) + ctx, span := trace.StartSpan(r.Context(), "authenticate.VerifySession") + defer span.End() + s, err := a.getSessionFromCtx(ctx) if err != nil { log.FromRequest(r).Info().Err(err).Msg("authenticate: session load error") return a.reauthenticateOrFail(w, r, err) } - var s sessions.State - if err := a.encryptedEncoder.Unmarshal([]byte(jwt), &s); err != nil { - return httputil.NewError(http.StatusBadRequest, err) - } - if s.IsExpired() { - ctx, err = a.refresh(w, r, &s) + ctx, err = a.refresh(w, r, s) if err != nil { log.FromRequest(r).Info().Err(err).Msg("authenticate: verify session, refresh") return a.reauthenticateOrFail(w, r, err) @@ -106,18 +104,34 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler { func (a *Authenticate) refresh(w http.ResponseWriter, r *http.Request, s *sessions.State) (context.Context, error) { ctx, span := trace.StartSpan(r.Context(), "authenticate.VerifySession/refresh") defer span.End() - newSession, err := a.provider.Refresh(ctx, s) + accessToken, err := a.getAccessToken(ctx, s) + if err != nil { + return nil, err + } + + // we are going to keep the same audiences for the refreshed token + // otherwise this will be rewritten to be the ClientID of the provider + oldAudience := s.Audience + + newAccessToken, err := a.provider.Refresh(ctx, accessToken, s) if err != nil { return nil, fmt.Errorf("authenticate: refresh failed: %w", err) } - if err := a.sessionStore.SaveSession(w, r, newSession); err != nil { - return nil, fmt.Errorf("authenticate: refresh save failed: %w", err) - } - newSession = newSession.NewSession(s.Issuer, s.Audience) - encSession, err := a.encryptedEncoder.Marshal(newSession) + + newSession := sessions.NewSession(s, a.RedirectURL.Hostname(), oldAudience, newAccessToken) + + encSession, err := a.sharedEncoder.Marshal(newSession) if err != nil { return nil, err } + + if err := a.sessionStore.SaveSession(w, r, newSession); err != nil { + return nil, fmt.Errorf("authenticate: error saving new session: %w", err) + } + + if err := a.setAccessToken(ctx, newAccessToken); err != nil { + return nil, fmt.Errorf("authenticate: error saving refreshed access token: %w", err) + } // return the new session and add it to the current request context return sessions.NewContext(ctx, string(encSession), err), nil } @@ -129,8 +143,11 @@ func (a *Authenticate) RobotsTxt(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, "User-agent: *\nDisallow: /") } -// SignIn handles to authenticating a user. +// SignIn handles authenticating a user. func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error { + ctx, span := trace.StartSpan(r.Context(), "authenticate.SignOut") + defer span.End() + redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI)) if err != nil { return httputil.NewError(http.StatusBadRequest, err) @@ -158,32 +175,30 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error { jwtAudience = append(jwtAudience, fwdAuth) } - jwt, err := sessions.FromContext(r.Context()) + s, err := a.getSessionFromCtx(ctx) if err != nil { - return httputil.NewError(http.StatusBadRequest, err) + return err } - var s sessions.State - if err := a.encryptedEncoder.Unmarshal([]byte(jwt), &s); err != nil { - return httputil.NewError(http.StatusBadRequest, err) + accessToken, err := a.getAccessToken(ctx, s) + if err != nil { + return err } - // user impersonation if impersonate := r.FormValue(urlutil.QueryImpersonateAction); impersonate != "" { s.SetImpersonation(r.FormValue(urlutil.QueryImpersonateEmail), r.FormValue(urlutil.QueryImpersonateGroups)) } + newSession := sessions.NewSession(s, a.RedirectURL.Host, jwtAudience, accessToken) // re-persist the session, useful when session was evicted from session - if err := a.sessionStore.SaveSession(w, r, &s); err != nil { + if err := a.sessionStore.SaveSession(w, r, s); err != nil { return httputil.NewError(http.StatusBadRequest, err) } - newSession := s.NewSession(a.RedirectURL.Host, jwtAudience) - callbackParams := callbackURL.Query() if r.FormValue(urlutil.QueryIsProgrammatic) == "true" { newSession.Programmatic = true - encSession, err := a.encryptedEncoder.Marshal(newSession) + encSession, err := a.encryptedEncoder.Marshal(accessToken) if err != nil { return httputil.NewError(http.StatusBadRequest, err) } @@ -192,7 +207,7 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error { } // sign the route session, as a JWT - signedJWT, err := a.sharedEncoder.Marshal(newSession.RouteSession()) + signedJWT, err := a.sharedEncoder.Marshal(newSession) if err != nil { return httputil.NewError(http.StatusBadRequest, err) } @@ -217,27 +232,12 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error { // SignOut signs the user out and attempts to revoke the user's identity session // Handles both GET and POST. func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error { - // no matter what happens, we want to clear the local session store + ctx, span := trace.StartSpan(r.Context(), "authenticate.SignOut") + defer span.End() + + // no matter what happens, we want to clear the session store a.sessionStore.ClearSession(w, r) - - jwt, err := sessions.FromContext(r.Context()) - if err != nil { - return httputil.NewError(http.StatusBadRequest, err) - } - var s sessions.State - if err := a.encryptedEncoder.Unmarshal([]byte(jwt), &s); err != nil { - return httputil.NewError(http.StatusBadRequest, err) - } - redirectString := r.FormValue(urlutil.QueryRedirectURI) - - // first, try to revoke the session if implemented - err = a.provider.Revoke(r.Context(), s.AccessToken) - if err != nil && !errors.Is(err, oidc.ErrRevokeNotImplemented) { - return httputil.NewError(http.StatusBadRequest, err) - } - - // next, try to build a logout url if implemented endSessionURL, err := a.provider.LogOut() if err == nil { params := url.Values{} @@ -245,14 +245,29 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error { endSessionURL.RawQuery = params.Encode() redirectString = endSessionURL.String() } else if !errors.Is(err, oidc.ErrSignoutNotImplemented) { - return httputil.NewError(http.StatusBadRequest, err) + log.Warn().Err(err).Msg("authenticate.SignOut: failed getting session") } - redirectURL, err := urlutil.ParseAndValidateURL(redirectString) + httputil.Redirect(w, r, redirectString, http.StatusFound) + + s, err := a.getSessionFromCtx(ctx) if err != nil { - return httputil.NewError(http.StatusBadRequest, err) + log.Warn().Err(err).Msg("authenticate.SignOut: failed getting session") + return nil + } + + accessToken, err := a.getAccessToken(ctx, s) + if err != nil { + log.Warn().Err(err).Msg("authenticate.SignOut: failed getting access token") + return nil + } + + // first, try to revoke the session if implemented + err = a.provider.Revoke(ctx, accessToken) + if err != nil && !errors.Is(err, oidc.ErrRevokeNotImplemented) { + log.Warn().Err(err).Msg("authenticate.SignOut: failed revoking token") + return nil } - httputil.Redirect(w, r, redirectURL.String(), http.StatusFound) return nil } @@ -267,7 +282,8 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error { // https://tools.ietf.org/html/rfc6749#section-4.2.1 // https://developer.mozilla.org/en-US/docs/Web/API/XMLHttpRequest func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Request, err error) error { - // If request AJAX/XHR request, return a 401 instead . + // If request AJAX/XHR request, return a 401 instead because the redirect + // will almost certainly violate their CORs policy if reqType := r.Header.Get("X-Requested-With"); strings.EqualFold(reqType, "XmlHttpRequest") { return httputil.NewError(http.StatusUnauthorized, err) } @@ -290,7 +306,7 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) error { redirect, err := a.getOAuthCallback(w, r) if err != nil { - return fmt.Errorf("oauth callback : %w", err) + return fmt.Errorf("authenticate.OAuthCallback: %w", err) } httputil.Redirect(w, r, redirect.String(), http.StatusFound) return nil @@ -306,6 +322,9 @@ func (a *Authenticate) statusForErrorCode(errorCode string) int { } func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (*url.URL, error) { + ctx, span := trace.StartSpan(r.Context(), "authenticate.getOAuthCallback") + defer span.End() + // Error Authentication Response: rfc6749#section-4.1.2.1 & OIDC#3.1.2.6 // // first, check if the identity provider returned an error @@ -321,14 +340,22 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) // Successful Authentication Response: rfc6749#section-4.1.2 & OIDC#3.1.2.5 // // Exchange the supplied Authorization Code for a valid user session. - session, err := a.provider.Authenticate(r.Context(), code) + var s sessions.State + accessToken, err := a.provider.Authenticate(ctx, code, &s) if err != nil { return nil, fmt.Errorf("error redeeming authenticate code: %w", err) } + + newState := sessions.NewSession( + &s, + a.RedirectURL.Hostname(), + []string{a.RedirectURL.Hostname()}, + accessToken) + // state includes a csrf nonce (validated by middleware) and redirect uri bytes, err := base64.URLEncoding.DecodeString(r.FormValue("state")) if err != nil { - return nil, httputil.NewError(http.StatusBadRequest, err) + return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("bad bytes: %w", err)) } // split state into concat'd components @@ -357,8 +384,13 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) return nil, httputil.NewError(http.StatusBadRequest, err) } - // OK. Looks good so let's persist our user session - if err := a.sessionStore.SaveSession(w, r, session); err != nil { + // Ok -- We've got a valid session here. Let's now persist the access + // token to cache ... + if err := a.setAccessToken(ctx, accessToken); err != nil { + return nil, fmt.Errorf("failed saving access token: %w", err) + } + // ... and the user state to local storage. + if err := a.sessionStore.SaveSession(w, r, &newState); err != nil { return nil, fmt.Errorf("failed saving new session: %w", err) } return redirectURL, nil @@ -368,26 +400,32 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) // tokens and state with the identity provider. If successful, a new signed JWT // and refresh token (`refresh_token`) are returned as JSON func (a *Authenticate) RefreshAPI(w http.ResponseWriter, r *http.Request) error { - jwt, err := sessions.FromContext(r.Context()) - if err != nil { - return httputil.NewError(http.StatusBadRequest, err) - } - var s sessions.State - if err := a.encryptedEncoder.Unmarshal([]byte(jwt), &s); err != nil { - return httputil.NewError(http.StatusBadRequest, err) - } - newSession, err := a.provider.Refresh(r.Context(), &s) - if err != nil { - return err - } - newSession = newSession.NewSession(s.Issuer, s.Audience) + ctx, span := trace.StartSpan(r.Context(), "authenticate.RefreshAPI") + defer span.End() - encSession, err := a.encryptedEncoder.Marshal(newSession) + s, err := a.getSessionFromCtx(ctx) if err != nil { return err } - signedJWT, err := a.sharedEncoder.Marshal(newSession.RouteSession()) + accessToken, err := a.getAccessToken(ctx, s) + if err != nil { + return err + } + + newAccessToken, err := a.provider.Refresh(ctx, accessToken, s) + if err != nil { + return err + } + + routeNewSession := sessions.NewSession(s, a.RedirectURL.Hostname(), s.Audience, newAccessToken) + + encSession, err := a.encryptedEncoder.Marshal(accessToken) + if err != nil { + return err + } + + signedJWT, err := a.sharedEncoder.Marshal(routeNewSession) if err != nil { return err } @@ -410,28 +448,79 @@ func (a *Authenticate) RefreshAPI(w http.ResponseWriter, r *http.Request) error // Refresh is called by the proxy service to handle backend session refresh. // // NOTE: The actual refresh is handled as part of the "VerifySession" -// middleware. This handler is responsible for creating a new route scoped -// session and returning it. +// middleware. This handler is simply responsible for returning that jwt. func (a *Authenticate) Refresh(w http.ResponseWriter, r *http.Request) error { - jwt, err := sessions.FromContext(r.Context()) + ctx, span := trace.StartSpan(r.Context(), "authenticate.Refresh") + defer span.End() + jwt, err := sessions.FromContext(ctx) if err != nil { - return httputil.NewError(http.StatusBadRequest, err) + return fmt.Errorf("authenticate.Refresh: %w", err) } - var s sessions.State - if err := a.encryptedEncoder.Unmarshal([]byte(jwt), &s); err != nil { - return httputil.NewError(http.StatusBadRequest, err) + w.Header().Set("Content-Type", "application/jwt") // RFC 7519 : 10.3.1 + fmt.Fprint(w, jwt) + return nil +} + +// getAccessToken gets an associated oauth2 access token from a session state +func (a *Authenticate) getAccessToken(ctx context.Context, s *sessions.State) (*oauth2.Token, error) { + ctx, span := trace.StartSpan(ctx, "authenticate.getAccessToken") + defer span.End() + + var accessToken oauth2.Token + tokenBytes, err := a.cacheClient.Get(ctx, s.AccessTokenHash) + if err != nil { + return nil, err + } + if err := a.encryptedEncoder.Unmarshal(tokenBytes, &accessToken); err != nil { + return nil, err + } + if accessToken.Valid() { + return &accessToken, nil // this token is still valid, use it! + } + tokenBytes, err = a.cacheClient.Get(ctx, a.timestampedHash(accessToken.RefreshToken)) + if err == nil { + // we found another possibly newer access token associated with the + // existing refresh token so let's try that. + if err := a.encryptedEncoder.Unmarshal(tokenBytes, &accessToken); err != nil { + return nil, err + } } - aud := strings.Split(r.FormValue(urlutil.QueryAudience), ",") - routeSession := s.NewSession(r.Host, aud) - routeSession.AccessTokenID = s.AccessTokenID + return &accessToken, nil +} - signedJWT, err := a.sharedEncoder.Marshal(routeSession.RouteSession()) +func (a *Authenticate) setAccessToken(ctx context.Context, accessToken *oauth2.Token) error { + encToken, err := a.encryptedEncoder.Marshal(accessToken) if err != nil { return err } + // set this specific access token + key := fmt.Sprintf("%x", hashutil.Hash(accessToken)) + if err := a.cacheClient.Set(ctx, key, encToken); err != nil { + return fmt.Errorf("authenticate: setAccessToken failed key: %s :%w", key, err) + } + + // set this as the "latest" token for this access token + key = a.timestampedHash(accessToken.RefreshToken) + if err := a.cacheClient.Set(ctx, key, encToken); err != nil { + return fmt.Errorf("authenticate: setAccessToken failed key: %s :%w", key, err) + } - w.Header().Set("Content-Type", "application/jwt") // RFC 7519 : 10.3.1 - w.Write(signedJWT) return nil } + +func (a *Authenticate) timestampedHash(s string) string { + return fmt.Sprintf("%x-%v", hashutil.Hash(s), time.Now().Truncate(time.Minute).Unix()) +} + +func (a *Authenticate) getSessionFromCtx(ctx context.Context) (*sessions.State, error) { + jwt, err := sessions.FromContext(ctx) + if err != nil { + return nil, httputil.NewError(http.StatusBadRequest, err) + } + var s sessions.State + if err := a.sharedEncoder.Unmarshal([]byte(jwt), &s); err != nil { + return nil, httputil.NewError(http.StatusBadRequest, err) + } + return &s, nil +} diff --git a/authenticate/handlers_test.go b/authenticate/handlers_test.go index 0e5de8cda..5ffc27f0d 100644 --- a/authenticate/handlers_test.go +++ b/authenticate/handlers_test.go @@ -11,20 +11,24 @@ import ( "testing" "time" + "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/cryptutil" "github.com/pomerium/pomerium/internal/encoding" "github.com/pomerium/pomerium/internal/encoding/jws" "github.com/pomerium/pomerium/internal/encoding/mock" "github.com/pomerium/pomerium/internal/frontend" + mock_cache "github.com/pomerium/pomerium/internal/grpc/cache/mock" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/identity" "github.com/pomerium/pomerium/internal/identity/oidc" "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/sessions/cookie" mstore "github.com/pomerium/pomerium/internal/sessions/mock" - - "github.com/google/go-cmp/cmp" "github.com/pomerium/pomerium/internal/urlutil" + + "github.com/golang/mock/gomock" + "github.com/google/go-cmp/cmp" + "github.com/gorilla/mux" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/oauth2" "gopkg.in/square/go-jose.v2/jwt" @@ -115,23 +119,28 @@ func TestAuthenticate_SignIn(t *testing.T) { encoder encoding.MarshalUnmarshaler wantCode int }{ - {"good", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, - {"good alternate port", "https", "corp.example.example:8443", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, - {"session not valid", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(-10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, - {"bad redirect uri query", "", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "^^^"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest}, - {"bad marshal", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest}, + {"good", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, + {"good alternate port", "https", "corp.example.example:8443", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, + {"session not valid", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, + {"bad redirect uri query", "", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "^^^"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest}, + {"bad marshal", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest}, {"session error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{LoadError: errors.New("error")}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest}, - {"good with different programmatic redirect", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, - {"encrypted encoder error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest}, - {"good with callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "https://some.example/", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, - {"bad callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "^", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest}, - {"good programmatic request", "https", "corp.example.example", map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, - {"good additional audience", "https", "corp.example.example", map[string]string{urlutil.QueryForwardAuth: "x.y.z", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, - {"good user impersonate", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, - {"bad user impersonate save failure", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest}, + {"good with different programmatic redirect", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, + {"encrypted encoder error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest}, + {"good with callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "https://some.example/", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, + {"bad callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "^", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest}, + {"good programmatic request", "https", "corp.example.example", map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, + {"good additional audience", "https", "corp.example.example", map[string]string{urlutil.QueryForwardAuth: "x.y.z", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, + {"good user impersonate", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, + {"bad user impersonate save failure", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mc := mock_cache.NewMockCacher(ctrl) + mc.EXPECT().Get(gomock.Any(), gomock.Any()).Return([]byte("hi"), nil).AnyTimes() + a := &Authenticate{ sessionStore: tt.session, provider: tt.provider, @@ -144,6 +153,7 @@ func TestAuthenticate_SignIn(t *testing.T) { Name: "cookie", Domain: "foo", }, + cacheClient: mc, } uri := &url.URL{Scheme: tt.scheme, Host: tt.host} @@ -176,6 +186,7 @@ func uriParseHelper(s string) *url.URL { func TestAuthenticate_SignOut(t *testing.T) { t.Parallel() + tests := []struct { name string method string @@ -190,19 +201,24 @@ func TestAuthenticate_SignOut(t *testing.T) { wantCode int wantBody string }{ - {"good post", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{LogOutResponse: (*uriParseHelper("https://microsoft.com"))}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusFound, ""}, - {"failed revoke", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: OH NO\"}\n"}, - {"load session error", http.MethodPost, errors.New("error"), "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: error\"}\n"}, - {"bad redirect uri", http.MethodPost, nil, "corp.pomerium.io/", "sig", "ts", identity.MockProvider{LogOutError: oidc.ErrSignoutNotImplemented}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: corp.pomerium.io/ url does contain a valid scheme\"}\n"}, + {"good post", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{LogOutResponse: (*uriParseHelper("https://microsoft.com"))}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io"}}, http.StatusFound, ""}, + {"failed revoke", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io"}}, http.StatusFound, ""}, + {"load session error", http.MethodPost, errors.New("error"), "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io"}}, http.StatusFound, ""}, + {"bad redirect uri", http.MethodPost, nil, "corp.pomerium.io/", "sig", "ts", identity.MockProvider{LogOutError: oidc.ErrSignoutNotImplemented}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io"}}, http.StatusFound, ""}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mc := mock_cache.NewMockCacher(ctrl) + mc.EXPECT().Get(gomock.Any(), gomock.Any()).Return([]byte("hi"), nil).AnyTimes() a := &Authenticate{ sessionStore: tt.sessionStore, provider: tt.provider, encryptedEncoder: mock.Encoder{}, templates: template.Must(frontend.NewTemplates()), + sharedEncoder: mock.Encoder{}, + cacheClient: mc, } u, _ := url.Parse("/sign_out") params, _ := url.ParseQuery(u.RawQuery) @@ -256,40 +272,50 @@ func TestAuthenticate_OAuthCallback(t *testing.T) { want string wantCode int }{ - {"good", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusFound}, - {"failed authenticate", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateError: errors.New("error")}, "", http.StatusInternalServerError}, - {"failed save session", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{SaveError: errors.New("error")}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusInternalServerError}, - {"provider returned error", http.MethodGet, time.Now().Unix(), "", "", "", "idp error", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest}, - {"provider returned error imply 401", http.MethodGet, time.Now().Unix(), "", "", "", "access_denied", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusUnauthorized}, - {"empty code", http.MethodGet, time.Now().Unix(), "", "", "", "", "", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest}, - {"invalid redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest}, - {"bad redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "http://^^^", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest}, - {"bad timing - too soon", http.MethodGet, time.Now().Add(1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest}, - {"bad timing - expired", http.MethodGet, time.Now().Add(-1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest}, - {"bad base64", http.MethodGet, time.Now().Unix(), "", "", "^", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest}, - {"too many seperators", http.MethodGet, time.Now().Unix(), "", "", "|ok|now|what", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest}, - {"bad hmac", http.MethodGet, time.Now().Unix(), "", "NOTMAC", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest}, - {"bad hmac", http.MethodGet, time.Now().Unix(), base64.URLEncoding.EncodeToString([]byte("malformed_state")), "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest}, + {"good", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusFound}, + {"failed authenticate", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}, AuthenticateError: errors.New("error")}, "", http.StatusInternalServerError}, + {"failed save session", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{SaveError: errors.New("error")}, identity.MockProvider{}, "", http.StatusInternalServerError}, + {"provider returned error", http.MethodGet, time.Now().Unix(), "", "", "", "idp error", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "", http.StatusBadRequest}, + {"provider returned error imply 401", http.MethodGet, time.Now().Unix(), "", "", "", "access_denied", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "", http.StatusUnauthorized}, + {"empty code", http.MethodGet, time.Now().Unix(), "", "", "", "", "", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "", http.StatusBadRequest}, + {"invalid redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "", http.StatusBadRequest}, + {"bad redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "http://^^^", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest}, + {"bad timing - too soon", http.MethodGet, time.Now().Add(1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest}, + {"bad timing - expired", http.MethodGet, time.Now().Add(-1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest}, + {"bad base64", http.MethodGet, time.Now().Unix(), "", "", "^", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest}, + {"too many seperators", http.MethodGet, time.Now().Unix(), "", "", "|ok|now|what", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest}, + {"bad hmac", http.MethodGet, time.Now().Unix(), "", "NOTMAC", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest}, + {"bad hmac", http.MethodGet, time.Now().Unix(), base64.URLEncoding.EncodeToString([]byte("malformed_state")), "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mc := mock_cache.NewMockCacher(ctrl) + mc.EXPECT().Get(gomock.Any(), gomock.Any()).Return([]byte("hi"), nil).AnyTimes() + mc.EXPECT().Set(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() aead, err := chacha20poly1305.NewX(cryptutil.NewKey()) if err != nil { t.Fatal(err) } + signer, err := jws.NewHS256Signer(nil, "mock") + if err != nil { + t.Fatal(err) + } authURL, _ := url.Parse(tt.authenticateURL) a := &Authenticate{ - RedirectURL: authURL, - sessionStore: tt.session, - provider: tt.provider, - cookieCipher: aead, + RedirectURL: authURL, + sessionStore: tt.session, + provider: tt.provider, + cookieCipher: aead, + cacheClient: mc, + encryptedEncoder: signer, } u, _ := url.Parse("/oauthGet") params, _ := url.ParseQuery(u.RawQuery) params.Add("error", tt.paramErr) params.Add("code", tt.code) nonce := cryptutil.NewBase64Key() // mock csrf - // (nonce|timestamp|redirect_url|encrypt(redirect_url),mac(nonce,ts)) b := []byte(fmt.Sprintf("%s|%d|%s", nonce, tt.ts, tt.extraMac)) @@ -336,15 +362,21 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) { wantStatus int }{ - {"good", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusOK}, + {"good", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, http.StatusOK}, {"invalid session", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound}, - {"good refresh expired", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusOK}, + {"good refresh expired", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, http.StatusOK}, {"expired,refresh error", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusFound}, - {"expired,save error", nil, &mstore.Store{SaveError: errors.New("error"), Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusFound}, + {"expired,save error", nil, &mstore.Store{SaveError: errors.New("error"), Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, http.StatusFound}, {"expired XHR,refresh error", map[string]string{"X-Requested-With": "XmlHttpRequest"}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusUnauthorized}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mc := mock_cache.NewMockCacher(ctrl) + mc.EXPECT().Get(gomock.Any(), gomock.Any()).Return([]byte("hi"), nil).AnyTimes() + mc.EXPECT().Set(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + aead, err := chacha20poly1305.NewX(cryptutil.NewKey()) if err != nil { t.Fatal(err) @@ -361,6 +393,8 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) { provider: tt.provider, cookieCipher: aead, encryptedEncoder: signer, + cacheClient: mc, + sharedEncoder: mock.Encoder{}, } r := httptest.NewRequest("GET", "/", nil) state, err := tt.session.LoadSession(r) @@ -402,14 +436,20 @@ func TestAuthenticate_RefreshAPI(t *testing.T) { wantStatus int }{ - {"good", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusOK}, + {"good", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusOK}, {"refresh error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshError: errors.New("error")}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError}, - {"session is not refreshable error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("session error"), identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusBadRequest}, - {"secret encoder failed", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalError: errors.New("error")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError}, - {"shared encoder failed", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalError: errors.New("error")}, http.StatusInternalServerError}, + {"session is not refreshable error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("session error"), identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusBadRequest}, + {"secret encoder failed", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, mock.Encoder{MarshalError: errors.New("error")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError}, + {"shared encoder failed", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalError: errors.New("error")}, http.StatusInternalServerError}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mc := mock_cache.NewMockCacher(ctrl) + mc.EXPECT().Get(gomock.Any(), gomock.Any()).Return([]byte("hi"), nil).AnyTimes() + mc.EXPECT().Set(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + aead, err := chacha20poly1305.NewX(cryptutil.NewKey()) if err != nil { t.Fatal(err) @@ -423,6 +463,7 @@ func TestAuthenticate_RefreshAPI(t *testing.T) { sessionStore: tt.session, provider: tt.provider, cookieCipher: aead, + cacheClient: mc, } r := httptest.NewRequest("GET", "/", nil) state, _ := tt.session.LoadSession(r) @@ -441,53 +482,111 @@ func TestAuthenticate_RefreshAPI(t *testing.T) { }) } } + func TestAuthenticate_Refresh(t *testing.T) { t.Parallel() tests := []struct { name string - session sessions.SessionStore - ctxError error + session *sessions.State + at *oauth2.Token provider identity.Authenticator secretEncoder encoding.MarshalUnmarshaler - sharedEncoder encoding.MarshalUnmarshaler wantStatus int }{ - {"good", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusOK}, - {"bad session", &mstore.Store{}, errors.New("err"), identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusBadRequest}, - {"encoder error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalError: errors.New("err")}, http.StatusInternalServerError}, + {"good", + &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, + &oauth2.Token{AccessToken: "mock", Expiry: time.Now().Add(10 * time.Minute)}, + identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, + mock.Encoder{MarshalResponse: []byte("ok")}, + 200}, + {"session and oauth2 expired", + &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, + &oauth2.Token{AccessToken: "mock", Expiry: time.Now().Add(-10 * time.Minute)}, + identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, + mock.Encoder{MarshalResponse: []byte("ok")}, + 200}, + {"session expired", + &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, + &oauth2.Token{AccessToken: "mock", Expiry: time.Now().Add(10 * time.Minute)}, + identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, + mock.Encoder{MarshalResponse: []byte("ok")}, + 200}, + {"failed refresh", + &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, + &oauth2.Token{AccessToken: "mock", Expiry: time.Now().Add(10 * time.Minute)}, + identity.MockProvider{RefreshError: errors.New("oh no")}, + mock.Encoder{MarshalResponse: []byte("ok")}, + 302}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - aead, err := chacha20poly1305.NewX(cryptutil.NewKey()) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mc := mock_cache.NewMockCacher(ctrl) + // just enough is stubbed out here so we can use our own mock provider + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Header().Set("Content-Type", "application/json") + out := fmt.Sprintf(`{"issuer":"http://%s"}`, r.Host) + fmt.Fprintln(w, out) + })) + defer ts.Close() + rURL := ts.URL + a, err := New(config.Options{ + SharedKey: cryptutil.NewBase64Key(), + CookieSecret: cryptutil.NewBase64Key(), + AuthenticateURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"), + Provider: "oidc", + ClientID: "mock", + ClientSecret: "mock", + ProviderURL: rURL, + AuthenticateCallbackPath: "mock", + CookieName: "pomerium", + Addr: ":0", + CacheURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"), + }) if err != nil { t.Fatal(err) } - a := Authenticate{ - sharedKey: cryptutil.NewBase64Key(), - cookieSecret: cryptutil.NewKey(), - RedirectURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"), - encryptedEncoder: tt.secretEncoder, - sharedEncoder: tt.sharedEncoder, - sessionStore: tt.session, - provider: tt.provider, - cookieCipher: aead, - } - r := httptest.NewRequest("GET", "/", nil) - state, _ := tt.session.LoadSession(r) - ctx := r.Context() - ctx = sessions.NewContext(ctx, state, tt.ctxError) - r = r.WithContext(ctx) + a.cacheClient = mc + a.provider = tt.provider + u, _ := url.Parse("/oauthGet") + params, _ := url.ParseQuery(u.RawQuery) + destination := urlutil.NewSignedURL(a.sharedKey, + &url.URL{ + Scheme: "https", + Host: "example.com", + Path: "/.pomerium/refresh"}) + + u.RawQuery = params.Encode() + + r := httptest.NewRequest(http.MethodGet, destination.String(), nil) + + jwt, err := a.sharedEncoder.Marshal(tt.session) + if err != nil { + t.Fatal(err) + } + rawToken, err := a.encryptedEncoder.Marshal(tt.at) + if err != nil { + t.Fatal(err) + } + mc.EXPECT().Get(gomock.Any(), gomock.Any()).Return(rawToken, nil).AnyTimes() + mc.EXPECT().Set(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + a.cacheClient = mc + + r.Header.Set("Authorization", fmt.Sprintf("Pomerium %s", jwt)) r.Header.Set("Accept", "application/json") w := httptest.NewRecorder() - httputil.HandlerFunc(a.Refresh).ServeHTTP(w, r) + router := mux.NewRouter() + a.Mount(router) + router.ServeHTTP(w, r) if status := w.Code; status != tt.wantStatus { - t.Errorf("VerifySession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String()) - + t.Errorf("Refresh() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String()) } }) } diff --git a/authorize/evaluator/evaluator.go b/authorize/evaluator/evaluator.go index 4a3122fec..6a58f9d9f 100644 --- a/authorize/evaluator/evaluator.go +++ b/authorize/evaluator/evaluator.go @@ -1,5 +1,3 @@ -//go:generate mockgen -destination mock_evaluator/mock.go github.com/pomerium/pomerium/authorize/evaluator Evaluator - // Package evaluator defines a Evaluator interfaces that can be implemented by // a policy evaluator framework. package evaluator diff --git a/authorize/grpc.go b/authorize/grpc.go index 0897eaa33..3da564860 100644 --- a/authorize/grpc.go +++ b/authorize/grpc.go @@ -218,10 +218,6 @@ func (a *Authorize) refreshSession(ctx context.Context, rawSession []byte) (newS // 1 - build a signed url to call refresh on authenticate service refreshURI := options.AuthenticateURL.ResolveReference(&url.URL{Path: "/.pomerium/refresh"}) - q := refreshURI.Query() - q.Set(urlutil.QueryAccessTokenID, state.AccessTokenID) // hash value points to parent token - q.Set(urlutil.QueryAudience, strings.Join(state.Audience, ",")) // request's audience, this route - refreshURI.RawQuery = q.Encode() signedRefreshURL := urlutil.NewSignedURL(options.SharedKey, refreshURI).String() // 2 - http call to authenticate service @@ -229,6 +225,7 @@ func (a *Authorize) refreshSession(ctx context.Context, rawSession []byte) (newS if err != nil { return nil, fmt.Errorf("authorize: refresh request: %w", err) } + req.Header.Set("Authorization", fmt.Sprintf("Pomerium %s", rawSession)) req.Header.Set("X-Requested-With", "XmlHttpRequest") req.Header.Set("Accept", "application/json") diff --git a/cache/grpc_test.go b/cache/grpc_test.go index 7b969578f..c7165db4e 100644 --- a/cache/grpc_test.go +++ b/cache/grpc_test.go @@ -11,6 +11,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/cryptutil" "github.com/pomerium/pomerium/internal/grpc/cache" @@ -48,11 +49,8 @@ func TestCache_Get_and_Set(t *testing.T) { &cache.SetReply{}, &cache.GetRequest{Key: "key"}, &cache.GetReply{ - Exists: true, - Value: []byte("hello"), - XXX_NoUnkeyedLiteral: struct{}{}, - XXX_unrecognized: nil, - XXX_sizecache: 0, + Exists: true, + Value: []byte("hello"), }, false, false, @@ -63,11 +61,8 @@ func TestCache_Get_and_Set(t *testing.T) { &cache.SetReply{}, &cache.GetRequest{Key: "no-such-key"}, &cache.GetReply{ - Exists: false, - Value: nil, - XXX_NoUnkeyedLiteral: struct{}{}, - XXX_unrecognized: nil, - XXX_sizecache: 0, + Exists: false, + Value: nil, }, false, false, @@ -78,11 +73,8 @@ func TestCache_Get_and_Set(t *testing.T) { nil, &cache.GetRequest{Key: hugeKey}, &cache.GetReply{ - Exists: false, - Value: nil, - XXX_NoUnkeyedLiteral: struct{}{}, - XXX_unrecognized: nil, - XXX_sizecache: 0, + Exists: false, + Value: nil, }, true, false, @@ -96,7 +88,11 @@ func TestCache_Get_and_Set(t *testing.T) { t.Errorf("Cache.Set() error = %v, wantSetError %v", err, tt.wantSetError) return } - if diff := cmp.Diff(setGot, tt.SetReply); diff != "" { + cmpOpts := []cmp.Option{ + cmpopts.IgnoreUnexported(cache.SetReply{}, cache.GetReply{}), + } + + if diff := cmp.Diff(setGot, tt.SetReply, cmpOpts...); diff != "" { t.Errorf("Cache.Set() = %v", diff) } getGot, err := c.Get(tt.ctx, tt.GetRequest) @@ -104,7 +100,7 @@ func TestCache_Get_and_Set(t *testing.T) { t.Errorf("Cache.Get() error = %v, wantGetError %v", err, tt.wantGetError) return } - if diff := cmp.Diff(getGot, tt.GetReply); diff != "" { + if diff := cmp.Diff(getGot, tt.GetReply, cmpOpts...); diff != "" { t.Errorf("Cache.Get() = %v", diff) } }) diff --git a/internal/controlplane/xds_listeners.go b/internal/controlplane/xds_listeners.go index 04eae8316..c8ee1a619 100644 --- a/internal/controlplane/xds_listeners.go +++ b/internal/controlplane/xds_listeners.go @@ -5,6 +5,7 @@ import ( "crypto/x509" "encoding/pem" "sort" + "time" envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" envoy_config_listener_v3 "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" @@ -151,6 +152,7 @@ func (srv *Server) buildMainHTTPConnectionManagerFilter(options *config.Options, }, Services: &envoy_extensions_filters_http_ext_authz_v3.ExtAuthz_GrpcService{ GrpcService: &envoy_config_core_v3.GrpcService{ + Timeout: ptypes.DurationProto(time.Second * 30), TargetSpecifier: &envoy_config_core_v3.GrpcService_EnvoyGrpc_{ EnvoyGrpc: &envoy_config_core_v3.GrpcService_EnvoyGrpc{ ClusterName: "pomerium-authz", diff --git a/internal/grpc/cache/cache.go b/internal/grpc/cache/cache.go new file mode 100644 index 000000000..6f55a1db4 --- /dev/null +++ b/internal/grpc/cache/cache.go @@ -0,0 +1,14 @@ +// Package cache defines a Cacher interfaces that can be implemented by any +// key value store system. +package cache + +import ( + "context" +) + +// Cacher specifies an interface for remote clients connecting to the cache service. +type Cacher interface { + Get(ctx context.Context, key string) (value []byte, err error) + Set(ctx context.Context, key string, value []byte) error + Close() error +} diff --git a/internal/grpc/cache/cache.pb.go b/internal/grpc/cache/cache.pb.go index ebb73477f..6cd3c7aff 100644 --- a/internal/grpc/cache/cache.pb.go +++ b/internal/grpc/cache/cache.pb.go @@ -1,217 +1,356 @@ // Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.21.0 +// protoc v3.11.4 // source: cache.proto package cache import ( context "context" - fmt "fmt" proto "github.com/golang/protobuf/proto" grpc "google.golang.org/grpc" codes "google.golang.org/grpc/codes" status "google.golang.org/grpc/status" - math "math" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" ) -// Reference imports to suppress errors if they are not otherwise used. -var _ = proto.Marshal -var _ = fmt.Errorf -var _ = math.Inf +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) -// This is a compile-time assertion to ensure that this generated file -// is compatible with the proto package it is being compiled against. -// A compilation error at this line likely means your copy of the -// proto package needs to be updated. -const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package +// This is a compile-time assertion that a sufficiently up-to-date version +// of the legacy proto package is being used. +const _ = proto.ProtoPackageIsVersion4 type GetRequest struct { - Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"` - XXX_NoUnkeyedLiteral struct{} `json:"-"` - XXX_unrecognized []byte `json:"-"` - XXX_sizecache int32 `json:"-"` + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"` } -func (m *GetRequest) Reset() { *m = GetRequest{} } -func (m *GetRequest) String() string { return proto.CompactTextString(m) } -func (*GetRequest) ProtoMessage() {} +func (x *GetRequest) Reset() { + *x = GetRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_cache_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *GetRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetRequest) ProtoMessage() {} + +func (x *GetRequest) ProtoReflect() protoreflect.Message { + mi := &file_cache_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetRequest.ProtoReflect.Descriptor instead. func (*GetRequest) Descriptor() ([]byte, []int) { - return fileDescriptor_5fca3b110c9bbf3a, []int{0} + return file_cache_proto_rawDescGZIP(), []int{0} } -func (m *GetRequest) XXX_Unmarshal(b []byte) error { - return xxx_messageInfo_GetRequest.Unmarshal(m, b) -} -func (m *GetRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - return xxx_messageInfo_GetRequest.Marshal(b, m, deterministic) -} -func (m *GetRequest) XXX_Merge(src proto.Message) { - xxx_messageInfo_GetRequest.Merge(m, src) -} -func (m *GetRequest) XXX_Size() int { - return xxx_messageInfo_GetRequest.Size(m) -} -func (m *GetRequest) XXX_DiscardUnknown() { - xxx_messageInfo_GetRequest.DiscardUnknown(m) -} - -var xxx_messageInfo_GetRequest proto.InternalMessageInfo - -func (m *GetRequest) GetKey() string { - if m != nil { - return m.Key +func (x *GetRequest) GetKey() string { + if x != nil { + return x.Key } return "" } type GetReply struct { - Exists bool `protobuf:"varint,1,opt,name=exists,proto3" json:"exists,omitempty"` - Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"` - XXX_NoUnkeyedLiteral struct{} `json:"-"` - XXX_unrecognized []byte `json:"-"` - XXX_sizecache int32 `json:"-"` + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Exists bool `protobuf:"varint,1,opt,name=exists,proto3" json:"exists,omitempty"` + Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"` } -func (m *GetReply) Reset() { *m = GetReply{} } -func (m *GetReply) String() string { return proto.CompactTextString(m) } -func (*GetReply) ProtoMessage() {} +func (x *GetReply) Reset() { + *x = GetReply{} + if protoimpl.UnsafeEnabled { + mi := &file_cache_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *GetReply) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetReply) ProtoMessage() {} + +func (x *GetReply) ProtoReflect() protoreflect.Message { + mi := &file_cache_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetReply.ProtoReflect.Descriptor instead. func (*GetReply) Descriptor() ([]byte, []int) { - return fileDescriptor_5fca3b110c9bbf3a, []int{1} + return file_cache_proto_rawDescGZIP(), []int{1} } -func (m *GetReply) XXX_Unmarshal(b []byte) error { - return xxx_messageInfo_GetReply.Unmarshal(m, b) -} -func (m *GetReply) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - return xxx_messageInfo_GetReply.Marshal(b, m, deterministic) -} -func (m *GetReply) XXX_Merge(src proto.Message) { - xxx_messageInfo_GetReply.Merge(m, src) -} -func (m *GetReply) XXX_Size() int { - return xxx_messageInfo_GetReply.Size(m) -} -func (m *GetReply) XXX_DiscardUnknown() { - xxx_messageInfo_GetReply.DiscardUnknown(m) -} - -var xxx_messageInfo_GetReply proto.InternalMessageInfo - -func (m *GetReply) GetExists() bool { - if m != nil { - return m.Exists +func (x *GetReply) GetExists() bool { + if x != nil { + return x.Exists } return false } -func (m *GetReply) GetValue() []byte { - if m != nil { - return m.Value +func (x *GetReply) GetValue() []byte { + if x != nil { + return x.Value } return nil } type SetRequest struct { - Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"` - Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"` - XXX_NoUnkeyedLiteral struct{} `json:"-"` - XXX_unrecognized []byte `json:"-"` - XXX_sizecache int32 `json:"-"` + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"` + Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"` } -func (m *SetRequest) Reset() { *m = SetRequest{} } -func (m *SetRequest) String() string { return proto.CompactTextString(m) } -func (*SetRequest) ProtoMessage() {} +func (x *SetRequest) Reset() { + *x = SetRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_cache_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SetRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SetRequest) ProtoMessage() {} + +func (x *SetRequest) ProtoReflect() protoreflect.Message { + mi := &file_cache_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SetRequest.ProtoReflect.Descriptor instead. func (*SetRequest) Descriptor() ([]byte, []int) { - return fileDescriptor_5fca3b110c9bbf3a, []int{2} + return file_cache_proto_rawDescGZIP(), []int{2} } -func (m *SetRequest) XXX_Unmarshal(b []byte) error { - return xxx_messageInfo_SetRequest.Unmarshal(m, b) -} -func (m *SetRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - return xxx_messageInfo_SetRequest.Marshal(b, m, deterministic) -} -func (m *SetRequest) XXX_Merge(src proto.Message) { - xxx_messageInfo_SetRequest.Merge(m, src) -} -func (m *SetRequest) XXX_Size() int { - return xxx_messageInfo_SetRequest.Size(m) -} -func (m *SetRequest) XXX_DiscardUnknown() { - xxx_messageInfo_SetRequest.DiscardUnknown(m) -} - -var xxx_messageInfo_SetRequest proto.InternalMessageInfo - -func (m *SetRequest) GetKey() string { - if m != nil { - return m.Key +func (x *SetRequest) GetKey() string { + if x != nil { + return x.Key } return "" } -func (m *SetRequest) GetValue() []byte { - if m != nil { - return m.Value +func (x *SetRequest) GetValue() []byte { + if x != nil { + return x.Value } return nil } type SetReply struct { - XXX_NoUnkeyedLiteral struct{} `json:"-"` - XXX_unrecognized []byte `json:"-"` - XXX_sizecache int32 `json:"-"` + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields } -func (m *SetReply) Reset() { *m = SetReply{} } -func (m *SetReply) String() string { return proto.CompactTextString(m) } -func (*SetReply) ProtoMessage() {} +func (x *SetReply) Reset() { + *x = SetReply{} + if protoimpl.UnsafeEnabled { + mi := &file_cache_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SetReply) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SetReply) ProtoMessage() {} + +func (x *SetReply) ProtoReflect() protoreflect.Message { + mi := &file_cache_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SetReply.ProtoReflect.Descriptor instead. func (*SetReply) Descriptor() ([]byte, []int) { - return fileDescriptor_5fca3b110c9bbf3a, []int{3} + return file_cache_proto_rawDescGZIP(), []int{3} } -func (m *SetReply) XXX_Unmarshal(b []byte) error { - return xxx_messageInfo_SetReply.Unmarshal(m, b) -} -func (m *SetReply) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - return xxx_messageInfo_SetReply.Marshal(b, m, deterministic) -} -func (m *SetReply) XXX_Merge(src proto.Message) { - xxx_messageInfo_SetReply.Merge(m, src) -} -func (m *SetReply) XXX_Size() int { - return xxx_messageInfo_SetReply.Size(m) -} -func (m *SetReply) XXX_DiscardUnknown() { - xxx_messageInfo_SetReply.DiscardUnknown(m) +var File_cache_proto protoreflect.FileDescriptor + +var file_cache_proto_rawDesc = []byte{ + 0x0a, 0x0b, 0x63, 0x61, 0x63, 0x68, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x05, 0x63, + 0x61, 0x63, 0x68, 0x65, 0x22, 0x1e, 0x0a, 0x0a, 0x47, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x03, 0x6b, 0x65, 0x79, 0x22, 0x38, 0x0a, 0x08, 0x47, 0x65, 0x74, 0x52, 0x65, 0x70, 0x6c, 0x79, + 0x12, 0x16, 0x0a, 0x06, 0x65, 0x78, 0x69, 0x73, 0x74, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x06, 0x65, 0x78, 0x69, 0x73, 0x74, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, + 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x22, 0x34, + 0x0a, 0x0a, 0x53, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x10, 0x0a, 0x03, + 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, + 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x05, 0x76, + 0x61, 0x6c, 0x75, 0x65, 0x22, 0x0a, 0x0a, 0x08, 0x53, 0x65, 0x74, 0x52, 0x65, 0x70, 0x6c, 0x79, + 0x32, 0x61, 0x0a, 0x05, 0x43, 0x61, 0x63, 0x68, 0x65, 0x12, 0x2b, 0x0a, 0x03, 0x47, 0x65, 0x74, + 0x12, 0x11, 0x2e, 0x63, 0x61, 0x63, 0x68, 0x65, 0x2e, 0x47, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x1a, 0x0f, 0x2e, 0x63, 0x61, 0x63, 0x68, 0x65, 0x2e, 0x47, 0x65, 0x74, 0x52, + 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x2b, 0x0a, 0x03, 0x53, 0x65, 0x74, 0x12, 0x11, 0x2e, + 0x63, 0x61, 0x63, 0x68, 0x65, 0x2e, 0x53, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x1a, 0x0f, 0x2e, 0x63, 0x61, 0x63, 0x68, 0x65, 0x2e, 0x53, 0x65, 0x74, 0x52, 0x65, 0x70, 0x6c, + 0x79, 0x22, 0x00, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } -var xxx_messageInfo_SetReply proto.InternalMessageInfo +var ( + file_cache_proto_rawDescOnce sync.Once + file_cache_proto_rawDescData = file_cache_proto_rawDesc +) -func init() { - proto.RegisterType((*GetRequest)(nil), "cache.GetRequest") - proto.RegisterType((*GetReply)(nil), "cache.GetReply") - proto.RegisterType((*SetRequest)(nil), "cache.SetRequest") - proto.RegisterType((*SetReply)(nil), "cache.SetReply") +func file_cache_proto_rawDescGZIP() []byte { + file_cache_proto_rawDescOnce.Do(func() { + file_cache_proto_rawDescData = protoimpl.X.CompressGZIP(file_cache_proto_rawDescData) + }) + return file_cache_proto_rawDescData } -func init() { - proto.RegisterFile("cache.proto", fileDescriptor_5fca3b110c9bbf3a) +var file_cache_proto_msgTypes = make([]protoimpl.MessageInfo, 4) +var file_cache_proto_goTypes = []interface{}{ + (*GetRequest)(nil), // 0: cache.GetRequest + (*GetReply)(nil), // 1: cache.GetReply + (*SetRequest)(nil), // 2: cache.SetRequest + (*SetReply)(nil), // 3: cache.SetReply +} +var file_cache_proto_depIdxs = []int32{ + 0, // 0: cache.Cache.Get:input_type -> cache.GetRequest + 2, // 1: cache.Cache.Set:input_type -> cache.SetRequest + 1, // 2: cache.Cache.Get:output_type -> cache.GetReply + 3, // 3: cache.Cache.Set:output_type -> cache.SetReply + 2, // [2:4] is the sub-list for method output_type + 0, // [0:2] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name } -var fileDescriptor_5fca3b110c9bbf3a = []byte{ - // 176 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x4e, 0x4e, 0x4c, 0xce, - 0x48, 0xd5, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x05, 0x73, 0x94, 0xe4, 0xb8, 0xb8, 0xdc, - 0x53, 0x4b, 0x82, 0x52, 0x0b, 0x4b, 0x53, 0x8b, 0x4b, 0x84, 0x04, 0xb8, 0x98, 0xb3, 0x53, 0x2b, - 0x25, 0x18, 0x15, 0x18, 0x35, 0x38, 0x83, 0x40, 0x4c, 0x25, 0x0b, 0x2e, 0x0e, 0xb0, 0x7c, 0x41, - 0x4e, 0xa5, 0x90, 0x18, 0x17, 0x5b, 0x6a, 0x45, 0x66, 0x71, 0x49, 0x31, 0x58, 0x01, 0x47, 0x10, - 0x94, 0x27, 0x24, 0xc2, 0xc5, 0x5a, 0x96, 0x98, 0x53, 0x9a, 0x2a, 0xc1, 0xa4, 0xc0, 0xa8, 0xc1, - 0x13, 0x04, 0xe1, 0x28, 0x99, 0x70, 0x71, 0x05, 0xe3, 0x31, 0x19, 0x87, 0x2e, 0x2e, 0x2e, 0x8e, - 0x60, 0xa8, 0x7d, 0x46, 0x89, 0x5c, 0xac, 0xce, 0x20, 0x47, 0x0a, 0x69, 0x73, 0x31, 0xbb, 0xa7, - 0x96, 0x08, 0x09, 0xea, 0x41, 0x3c, 0x80, 0x70, 0xb0, 0x14, 0x3f, 0xb2, 0x50, 0x41, 0x4e, 0xa5, - 0x12, 0x03, 0x48, 0x71, 0x30, 0x92, 0xe2, 0x60, 0x4c, 0xc5, 0xc1, 0x70, 0xc5, 0x49, 0x6c, 0xe0, - 0xc0, 0x30, 0x06, 0x04, 0x00, 0x00, 0xff, 0xff, 0x0e, 0xef, 0x5f, 0x9e, 0x1b, 0x01, 0x00, 0x00, +func init() { file_cache_proto_init() } +func file_cache_proto_init() { + if File_cache_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_cache_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*GetRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_cache_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*GetReply); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_cache_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SetRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_cache_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SetReply); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_cache_proto_rawDesc, + NumEnums: 0, + NumMessages: 4, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_cache_proto_goTypes, + DependencyIndexes: file_cache_proto_depIdxs, + MessageInfos: file_cache_proto_msgTypes, + }.Build() + File_cache_proto = out.File + file_cache_proto_rawDesc = nil + file_cache_proto_goTypes = nil + file_cache_proto_depIdxs = nil } // Reference imports to suppress errors if they are not otherwise used. @@ -266,10 +405,10 @@ type CacheServer interface { type UnimplementedCacheServer struct { } -func (*UnimplementedCacheServer) Get(ctx context.Context, req *GetRequest) (*GetReply, error) { +func (*UnimplementedCacheServer) Get(context.Context, *GetRequest) (*GetReply, error) { return nil, status.Errorf(codes.Unimplemented, "method Get not implemented") } -func (*UnimplementedCacheServer) Set(ctx context.Context, req *SetRequest) (*SetReply, error) { +func (*UnimplementedCacheServer) Set(context.Context, *SetRequest) (*SetReply, error) { return nil, status.Errorf(codes.Unimplemented, "method Set not implemented") } diff --git a/internal/grpc/cache/client/cache_client.go b/internal/grpc/cache/client/cache_client.go index 983c6ee34..d1765a65c 100644 --- a/internal/grpc/cache/client/cache_client.go +++ b/internal/grpc/cache/client/cache_client.go @@ -3,6 +3,7 @@ package client import ( "context" + "errors" "github.com/pomerium/pomerium/internal/grpc/cache" "github.com/pomerium/pomerium/internal/telemetry/trace" @@ -10,12 +11,7 @@ import ( "google.golang.org/grpc" ) -// Cacher specifies an interface for remote clients connecting to the cache service. -type Cacher interface { - Get(ctx context.Context, key string) (keyExists bool, value []byte, err error) - Set(ctx context.Context, key string, value []byte) error - Close() error -} +var errKeyNotFound = errors.New("cache/client: key not found") // Client represents a gRPC cache service client. type Client struct { @@ -29,15 +25,18 @@ func New(conn *grpc.ClientConn) (p *Client) { } // Get retrieves a value from the cache service. -func (a *Client) Get(ctx context.Context, key string) (keyExists bool, value []byte, err error) { +func (a *Client) Get(ctx context.Context, key string) (value []byte, err error) { ctx, span := trace.StartSpan(ctx, "grpc.cache.client.Get") defer span.End() response, err := a.client.Get(ctx, &cache.GetRequest{Key: key}) if err != nil { - return false, nil, err + return nil, err } - return response.GetExists(), response.GetValue(), nil + if !response.GetExists() { + return nil, errKeyNotFound + } + return response.GetValue(), nil } // Set stores a key value pair in the cache service. diff --git a/internal/grpc/cache/mock/mock_cacher.go b/internal/grpc/cache/mock/mock_cacher.go new file mode 100644 index 000000000..08ca944ed --- /dev/null +++ b/internal/grpc/cache/mock/mock_cacher.go @@ -0,0 +1,77 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/pomerium/pomerium/internal/grpc/cache (interfaces: Cacher) + +// Package mock_cache is a generated GoMock package. +package mock_cache + +import ( + context "context" + gomock "github.com/golang/mock/gomock" + reflect "reflect" +) + +// MockCacher is a mock of Cacher interface +type MockCacher struct { + ctrl *gomock.Controller + recorder *MockCacherMockRecorder +} + +// MockCacherMockRecorder is the mock recorder for MockCacher +type MockCacherMockRecorder struct { + mock *MockCacher +} + +// NewMockCacher creates a new mock instance +func NewMockCacher(ctrl *gomock.Controller) *MockCacher { + mock := &MockCacher{ctrl: ctrl} + mock.recorder = &MockCacherMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockCacher) EXPECT() *MockCacherMockRecorder { + return m.recorder +} + +// Close mocks base method +func (m *MockCacher) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close +func (mr *MockCacherMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockCacher)(nil).Close)) +} + +// Get mocks base method +func (m *MockCacher) Get(arg0 context.Context, arg1 string) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", arg0, arg1) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get +func (mr *MockCacherMockRecorder) Get(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockCacher)(nil).Get), arg0, arg1) +} + +// Set mocks base method +func (m *MockCacher) Set(arg0 context.Context, arg1 string, arg2 []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Set", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// Set indicates an expected call of Set +func (mr *MockCacherMockRecorder) Set(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Set", reflect.TypeOf((*MockCacher)(nil).Set), arg0, arg1, arg2) +} diff --git a/internal/hashutil/hashutil.go b/internal/hashutil/hashutil.go new file mode 100644 index 000000000..8b2882cd4 --- /dev/null +++ b/internal/hashutil/hashutil.go @@ -0,0 +1,22 @@ +// Package hashutil provides NON-CRYPTOGRAPHIC utility functions for hashing +package hashutil + +import ( + "github.com/cespare/xxhash/v2" + "github.com/mitchellh/hashstructure" +) + +// Hash returns the xxhash value of an arbitrary value or struct. Returns 0 +// on error. NOT SUITABLE FOR CRYTOGRAPHIC HASHING. +// +// http://cyan4973.github.io/xxHash/ +func Hash(v interface{}) uint64 { + opts := &hashstructure.HashOptions{ + Hasher: xxhash.New(), + } + hash, err := hashstructure.Hash(v, opts) + if err != nil { + hash = 0 + } + return hash +} diff --git a/internal/hashutil/hashutil_test.go b/internal/hashutil/hashutil_test.go new file mode 100644 index 000000000..245ca3853 --- /dev/null +++ b/internal/hashutil/hashutil_test.go @@ -0,0 +1,37 @@ +// Package hashutil provides NON-CRYPTOGRAPHIC utility functions for hashing +package hashutil + +import "testing" + +func TestHash(t *testing.T) { + t.Parallel() + tests := []struct { + name string + v interface{} + want uint64 + }{ + {"string", "string", 6134271061086542852}, + {"num", 7, 609900476111905877}, + {"compound struct", struct { + NESCarts []string + numberOfCarts int + }{ + []string{"Battletoads", "Mega Man 1", "Clash at Demonhead"}, + 12, + }, + 9061978360207659575}, + {"compound struct with embedded func (errors!)", struct { + AnswerToEverythingFn func() int + }{ + func() int { return 42 }, + }, + 0}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := Hash(tt.v); got != tt.want { + t.Errorf("Hash() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/identity/mock_provider.go b/internal/identity/mock_provider.go index c26ed5c6f..943d9f977 100644 --- a/internal/identity/mock_provider.go +++ b/internal/identity/mock_provider.go @@ -5,15 +5,13 @@ import ( "net/url" "golang.org/x/oauth2" - - "github.com/pomerium/pomerium/internal/sessions" ) // MockProvider provides a mocked implementation of the providers interface. type MockProvider struct { - AuthenticateResponse sessions.State + AuthenticateResponse oauth2.Token AuthenticateError error - RefreshResponse sessions.State + RefreshResponse oauth2.Token RefreshError error RevokeError error GetSignInURLResponse string @@ -22,12 +20,12 @@ type MockProvider struct { } // Authenticate is a mocked providers function. -func (mp MockProvider) Authenticate(ctx context.Context, code string) (*sessions.State, error) { +func (mp MockProvider) Authenticate(context.Context, string, interface{}) (*oauth2.Token, error) { return &mp.AuthenticateResponse, mp.AuthenticateError } // Refresh is a mocked providers function. -func (mp MockProvider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) { +func (mp MockProvider) Refresh(context.Context, *oauth2.Token, interface{}) (*oauth2.Token, error) { return &mp.RefreshResponse, mp.RefreshError } diff --git a/internal/identity/oauth/github/github.go b/internal/identity/oauth/github/github.go index faa226207..7c4caea05 100644 --- a/internal/identity/oauth/github/github.go +++ b/internal/identity/oauth/github/github.go @@ -20,7 +20,6 @@ import ( "github.com/pomerium/pomerium/internal/identity/oauth" "github.com/pomerium/pomerium/internal/identity/oidc" "github.com/pomerium/pomerium/internal/log" - "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/version" ) @@ -77,48 +76,36 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) { // Authenticate creates an identity session with github from a authorization code, and follows up // call to the user and user group endpoint with the -func (p *Provider) Authenticate(ctx context.Context, code string) (*sessions.State, error) { - resp, err := p.Oauth.Exchange(ctx, code) +func (p *Provider) Authenticate(ctx context.Context, code string, v interface{}) (*oauth2.Token, error) { + oauth2Token, err := p.Oauth.Exchange(ctx, code) if err != nil { return nil, fmt.Errorf("github: token exchange failed %v", err) } - s := &sessions.State{ - AccessToken: &oauth2.Token{ - AccessToken: resp.AccessToken, - TokenType: resp.TokenType, - }, - AccessTokenID: resp.AccessToken, - } - - err = p.updateSessionState(ctx, s) + err = p.updateSessionState(ctx, oauth2Token, v) if err != nil { return nil, err } - return s, nil + return oauth2Token, nil } // updateSessionState will get the user information from github and also retrieve the user's team(s) // // https://developer.github.com/v3/users/#get-the-authenticated-user -func (p *Provider) updateSessionState(ctx context.Context, s *sessions.State) error { - if s == nil || s.AccessToken == nil { - return errors.New("github: user session cannot be empty") - } - accessToken := s.AccessToken.AccessToken +func (p *Provider) updateSessionState(ctx context.Context, t *oauth2.Token, v interface{}) error { - err := p.userInfo(ctx, accessToken, s) + err := p.userInfo(ctx, t, v) if err != nil { return fmt.Errorf("github: could not retrieve user info %w", err) } - err = p.userEmail(ctx, accessToken, s) + err = p.userEmail(ctx, t, v) if err != nil { return fmt.Errorf("github: could not retrieve user email %w", err) } - err = p.userTeams(ctx, accessToken, s) + err = p.userTeams(ctx, t, v) if err != nil { return fmt.Errorf("github: could not retrieve groups %w", err) } @@ -127,14 +114,12 @@ func (p *Provider) updateSessionState(ctx context.Context, s *sessions.State) er } // Refresh renews a user's session by making a new userInfo request. -func (p *Provider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) { - if s.AccessToken == nil { - return nil, errors.New("github: missing oauth2 access token") - } - if err := p.updateSessionState(ctx, s); err != nil { +func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v interface{}) (*oauth2.Token, error) { + err := p.updateSessionState(ctx, t, v) + if err != nil { return nil, err } - return s, nil + return t, nil } // userTeams returns a slice of teams the user belongs by making a request @@ -142,7 +127,7 @@ func (p *Provider) Refresh(ctx context.Context, s *sessions.State) (*sessions.St // // https://developer.github.com/v3/teams/#list-user-teams // https://developer.github.com/v3/auth/ -func (p *Provider) userTeams(ctx context.Context, at string, s *sessions.State) error { +func (p *Provider) userTeams(ctx context.Context, t *oauth2.Token, v interface{}) error { var response []struct { ID json.Number `json:"id"` @@ -154,20 +139,24 @@ func (p *Provider) userTeams(ctx context.Context, at string, s *sessions.State) Privacy string `json:"privacy,omitempty"` } - headers := map[string]string{"Authorization": fmt.Sprintf("token %s", at)} + headers := map[string]string{"Authorization": fmt.Sprintf("token %s", t.AccessToken)} teamURL := githubAPIURL + teamPath err := httputil.Client(ctx, http.MethodGet, teamURL, version.UserAgent(), headers, nil, &response) if err != nil { return err } - log.Debug().Interface("teams", response).Msg("github: user teams") - s.Groups = nil - for _, org := range response { - s.Groups = append(s.Groups, org.ID.String()) + var out struct { + Groups []string `json:"groups"` } - - return nil + for _, org := range response { + out.Groups = append(out.Groups, org.ID.String()) + } + b, err := json.Marshal(out) + if err != nil { + return err + } + return json.Unmarshal(b, v) } // userEmail returns the primary email of the user by making @@ -175,7 +164,7 @@ func (p *Provider) userTeams(ctx context.Context, at string, s *sessions.State) // // https://developer.github.com/v3/users/emails/#list-email-addresses-for-a-user // https://developer.github.com/v3/auth/ -func (p *Provider) userEmail(ctx context.Context, at string, s *sessions.State) error { +func (p *Provider) userEmail(ctx context.Context, t *oauth2.Token, v interface{}) error { // response represents the github user email // https://developer.github.com/v3/users/emails/#response var response []struct { @@ -184,48 +173,67 @@ func (p *Provider) userEmail(ctx context.Context, at string, s *sessions.State) Primary bool `json:"primary"` Visibility string `json:"visibility"` } - headers := map[string]string{"Authorization": fmt.Sprintf("token %s", at)} + headers := map[string]string{"Authorization": fmt.Sprintf("token %s", t.AccessToken)} emailURL := githubAPIURL + emailPath err := httputil.Client(ctx, http.MethodGet, emailURL, version.UserAgent(), headers, nil, &response) if err != nil { return err } - + var out struct { + Email string `json:"email"` + Verified bool `json:"email_verified"` + } log.Debug().Interface("emails", response).Msg("github: user emails") for _, email := range response { if email.Primary && email.Verified { - s.Email = email.Email - s.EmailVerified = true - return nil + out.Email = email.Email + out.Verified = true + break } } - return nil + b, err := json.Marshal(out) + if err != nil { + return err + } + return json.Unmarshal(b, v) } -func (p *Provider) userInfo(ctx context.Context, at string, s *sessions.State) error { +func (p *Provider) userInfo(ctx context.Context, t *oauth2.Token, v interface{}) error { var response struct { ID int `json:"id"` Login string `json:"login"` Name string `json:"name"` - Email string `json:"email"` AvatarURL string `json:"avatar_url,omitempty"` } headers := map[string]string{ - "Authorization": fmt.Sprintf("token %s", at), + "Authorization": fmt.Sprintf("token %s", t.AccessToken), "Accept": "application/vnd.github.v3+json", } err := httputil.Client(ctx, http.MethodGet, p.userEndpoint, version.UserAgent(), headers, nil, &response) if err != nil { return err } + var out struct { + Subject string `json:"sub"` + Name string `json:"name,omitempty"` + User string `json:"user"` + Picture string `json:"picture,omitempty"` + // needs to be set manually + Expiry *jwt.NumericDate `json:"exp,omitempty"` + } - s.User = response.Login - s.Name = response.Name - s.Picture = response.AvatarURL + out.User = response.Login + out.Subject = response.Login + out.Name = response.Name + out.Picture = response.AvatarURL // set the session expiry - s.Expiry = jwt.NewNumericDate(time.Now().Add(refreshDeadline)) - return nil + out.Expiry = jwt.NewNumericDate(time.Now().Add(refreshDeadline)) + b, err := json.Marshal(out) + if err != nil { + return err + } + return json.Unmarshal(b, v) } // Revoke method will remove all the github grants the user diff --git a/internal/identity/oidc/azure/microsoft.go b/internal/identity/oidc/azure/microsoft.go index c6bbee8b7..736a2f5c2 100644 --- a/internal/identity/oidc/azure/microsoft.go +++ b/internal/identity/oidc/azure/microsoft.go @@ -5,7 +5,7 @@ package azure import ( "context" - "errors" + "encoding/json" "fmt" "net/http" "time" @@ -16,7 +16,6 @@ import ( "github.com/pomerium/pomerium/internal/identity/oauth" pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc" "github.com/pomerium/pomerium/internal/log" - "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/version" ) @@ -60,10 +59,7 @@ func (p *Provider) GetSignInURL(state string) string { // `Directory.Read.All` is required. // https://docs.microsoft.com/en-us/graph/api/resources/directoryobject?view=graph-rest-1.0 // https://docs.microsoft.com/en-us/graph/api/user-list-memberof?view=graph-rest-1.0 -func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string, error) { - if s == nil || s.AccessToken == nil { - return nil, errors.New("identity/azure: session cannot be nil") - } +func (p *Provider) UserGroups(ctx context.Context, t *oauth2.Token, v interface{}) error { var response struct { Groups []struct { ID string `json:"id"` @@ -73,15 +69,23 @@ func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string, GroupTypes []string `json:"groupTypes,omitempty"` } `json:"value"` } - headers := map[string]string{"Authorization": fmt.Sprintf("Bearer %s", s.AccessToken.AccessToken)} + headers := map[string]string{"Authorization": fmt.Sprintf("Bearer %s", t.AccessToken)} err := httputil.Client(ctx, http.MethodGet, defaultGroupURL, version.UserAgent(), headers, nil, &response) if err != nil { - return nil, err + return err + } + + log.Debug().Interface("response", response).Msg("microsoft: groups") + var out struct { + Groups []string `json:"groups"` } - var groups []string for _, group := range response.Groups { - log.Debug().Str("DisplayName", group.DisplayName).Str("ID", group.ID).Msg("microsoft: group") - groups = append(groups, group.ID) + out.Groups = append(out.Groups, group.ID) } - return groups, nil + b, err := json.Marshal(out) + if err != nil { + return err + } + + return json.Unmarshal(b, v) } diff --git a/internal/identity/oidc/errors.go b/internal/identity/oidc/errors.go index 197269b3f..be4ee9743 100644 --- a/internal/identity/oidc/errors.go +++ b/internal/identity/oidc/errors.go @@ -1,12 +1,14 @@ package oidc -import "errors" +import ( + "errors" +) -// ErrRevokeNotImplemented error type when Revoke method is not implemented -// by an identity provider +// ErrRevokeNotImplemented is returned when revoke is not implemented +// by an identity provider. var ErrRevokeNotImplemented = errors.New("identity/oidc: revoke not implemented") -// ErrSignoutNotImplemented error type when end session is not implemented +// ErrSignoutNotImplemented is returned when end session is not implemented // by an identity provider // https://openid.net/specs/openid-connect-frontchannel-1_0.html#RPInitiated var ErrSignoutNotImplemented = errors.New("identity/oidc: end session not implemented") @@ -14,3 +16,13 @@ var ErrSignoutNotImplemented = errors.New("identity/oidc: end session not implem // ErrMissingProviderURL is returned when an identity provider requires a provider url // does not receive one. var ErrMissingProviderURL = errors.New("identity/oidc: missing provider url") + +// ErrMissingIDToken is returned when (usually on refresh) and identity provider +// failed to include an id_token in a oauth2 token. +var ErrMissingIDToken = errors.New("identity/oidc: missing id_token") + +// ErrMissingRefreshToken is returned if no refresh token was found. +var ErrMissingRefreshToken = errors.New("identity/oidc: missing refresh token") + +// ErrMissingAccessToken is returned when no access token was found. +var ErrMissingAccessToken = errors.New("identity/oidc: missing access token") diff --git a/internal/identity/oidc/gitlab/gitlab.go b/internal/identity/oidc/gitlab/gitlab.go index 9cc8f1144..d86b22277 100644 --- a/internal/identity/oidc/gitlab/gitlab.go +++ b/internal/identity/oidc/gitlab/gitlab.go @@ -6,7 +6,6 @@ package gitlab import ( "context" "encoding/json" - "errors" "fmt" "net/http" @@ -15,14 +14,14 @@ import ( "github.com/pomerium/pomerium/internal/identity/oauth" pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc" "github.com/pomerium/pomerium/internal/log" - "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/version" + "golang.org/x/oauth2" ) -// Name identifies the GitLab identity provider +// Name identifies the GitLab identity provider. const Name = "gitlab" -var defaultScopes = []string{oidc.ScopeOpenID, "read_api", "read_user", "profile", "email"} +var defaultScopes = []string{oidc.ScopeOpenID, "profile", "email", "api"} const ( defaultProviderURL = "https://gitlab.com" @@ -64,11 +63,7 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) { // // Returns 20 results at a time because the API results are paginated. // https://docs.gitlab.com/ee/api/groups.html#list-groups -func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string, error) { - if s == nil || s.AccessToken == nil { - return nil, errors.New("gitlab: user session cannot be empty") - } - +func (p *Provider) UserGroups(ctx context.Context, t *oauth2.Token, v interface{}) error { var response []struct { ID json.Number `json:"id"` Name string `json:"name,omitempty"` @@ -81,17 +76,22 @@ func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string, FullName string `json:"full_name,omitempty"` FullPath string `json:"full_path,omitempty"` } - headers := map[string]string{"Authorization": fmt.Sprintf("Bearer %s", s.AccessToken.AccessToken)} + headers := map[string]string{"Authorization": fmt.Sprintf("Bearer %s", t.AccessToken)} err := httputil.Client(ctx, http.MethodGet, p.userGroupURL, version.UserAgent(), headers, nil, &response) if err != nil { - return nil, err + return err } - - var groups []string log.Debug().Interface("response", response).Msg("gitlab: groups") + var out struct { + Groups []string `json:"groups"` + } for _, group := range response { - groups = append(groups, group.ID.String()) + out.Groups = append(out.Groups, group.ID.String()) + } + b, err := json.Marshal(out) + if err != nil { + return err } - return groups, nil + return json.Unmarshal(b, v) } diff --git a/internal/identity/oidc/google/google.go b/internal/identity/oidc/google/google.go index b2d4191e5..d7ef90e27 100644 --- a/internal/identity/oidc/google/google.go +++ b/internal/identity/oidc/google/google.go @@ -8,6 +8,7 @@ import ( "context" "encoding/base64" "encoding/json" + "errors" "fmt" oidc "github.com/coreos/go-oidc" @@ -18,7 +19,6 @@ import ( "github.com/pomerium/pomerium/internal/identity/oauth" pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc" "github.com/pomerium/pomerium/internal/log" - "github.com/pomerium/pomerium/internal/sessions" ) const ( @@ -54,36 +54,35 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) { return nil, fmt.Errorf("%s: failed creating oidc provider: %w", Name, err) } p.Provider = genericOidc - - // if service account set, configure admin sdk calls - if o.ServiceAccount != "" { - apiCreds, err := base64.StdEncoding.DecodeString(o.ServiceAccount) - if err != nil { - return nil, fmt.Errorf("google: could not decode service account json %w", err) - } - // Required scopes for groups api - // https://developers.google.com/admin-sdk/directory/v1/reference/groups/list - conf, err := google.JWTConfigFromJSON(apiCreds, admin.AdminDirectoryUserReadonlyScope, admin.AdminDirectoryGroupReadonlyScope) - if err != nil { - return nil, fmt.Errorf("google: failed making jwt config from json %w", err) - } - var credentialsFile struct { - ImpersonateUser string `json:"impersonate_user"` - } - if err := json.Unmarshal(apiCreds, &credentialsFile); err != nil { - return nil, err - } - conf.Subject = credentialsFile.ImpersonateUser - client := conf.Client(context.TODO()) - p.apiClient, err = admin.New(client) - if err != nil { - return nil, fmt.Errorf("google: failed creating admin service %w", err) - } - p.UserGroupFn = p.UserGroups - } else { - log.Warn().Msg("google: no service account, cannot retrieve groups") + if o.ServiceAccount == "" { + log.Warn().Msg("google: no service account, will not fetch groups") + return &p, nil } + apiCreds, err := base64.StdEncoding.DecodeString(o.ServiceAccount) + if err != nil { + return nil, fmt.Errorf("google: could not decode service account json %w", err) + } + // Required scopes for groups api + // https://developers.google.com/admin-sdk/directory/v1/reference/groups/list + conf, err := google.JWTConfigFromJSON(apiCreds, admin.AdminDirectoryUserReadonlyScope, admin.AdminDirectoryGroupReadonlyScope) + if err != nil { + return nil, fmt.Errorf("google: failed making jwt config from json %w", err) + } + var credentialsFile struct { + ImpersonateUser string `json:"impersonate_user"` + } + if err := json.Unmarshal(apiCreds, &credentialsFile); err != nil { + return nil, err + } + conf.Subject = credentialsFile.ImpersonateUser + client := conf.Client(context.TODO()) + p.apiClient, err = admin.New(client) + if err != nil { + return nil, fmt.Errorf("google: failed creating admin service %w", err) + } + p.UserGroupFn = p.UserGroups + return &p, nil } @@ -106,17 +105,34 @@ func (p *Provider) GetSignInURL(state string) string { // NOTE: groups via Directory API is limited to 1 QPS! // https://developers.google.com/admin-sdk/directory/v1/reference/groups/list // https://developers.google.com/admin-sdk/directory/v1/limits -func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string, error) { - var groups []string - if p.apiClient != nil { - req := p.apiClient.Groups.List().UserKey(s.Subject).MaxResults(100) - resp, err := req.Do() - if err != nil { - return nil, fmt.Errorf("google: group api request failed %w", err) - } - for _, group := range resp.Groups { - groups = append(groups, group.Email) - } +func (p *Provider) UserGroups(ctx context.Context, t *oauth2.Token, v interface{}) error { + if p.apiClient == nil { + return errors.New("google: trying to fetch groups, but no api client") } - return groups, nil + s, err := p.GetSubject(v) + if err != nil { + return err + } + var out struct { + Groups []string `json:"groups"` + } + req := p.apiClient.Groups.List().Context(ctx).UserKey(s) + err = req.Pages(ctx, func(resp *admin.Groups) error { + for _, group := range resp.Groups { + out.Groups = append(out.Groups, group.Email) + } + return nil + }) + if err != nil { + return err + } + _, err = req.Do() + if err != nil { + return fmt.Errorf("google: group api request failed %w", err) + } + b, err := json.Marshal(out) + if err != nil { + return err + } + return json.Unmarshal(b, v) } diff --git a/internal/identity/oidc/oidc.go b/internal/identity/oidc/oidc.go index 68db63048..d18b4cd5b 100644 --- a/internal/identity/oidc/oidc.go +++ b/internal/identity/oidc/oidc.go @@ -5,6 +5,7 @@ package oidc import ( "context" + "encoding/json" "errors" "fmt" "net/http" @@ -15,12 +16,11 @@ import ( "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/identity/oauth" - "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/version" ) -// Name identifies the generic OpenID Connect provider +// Name identifies the generic OpenID Connect provider. const Name = "oidc" var defaultScopes = []string{go_oidc.ScopeOpenID, "profile", "email", "offline_access"} @@ -37,11 +37,6 @@ type Provider struct { // client application information and the server's endpoint URLs. Oauth *oauth2.Config - // UserInfoURL specifies the endpoint responsible for returning claims - // about the authenticated End-User. - // https://openid.net/specs/openid-connect-core-1_0.html#UserInfo - UserInfoURL string `json:"userinfo_endpoint,omitempty"` - // RevocationURL is the location of the OAuth 2.0 token revocation endpoint. // https://tools.ietf.org/html/rfc7009 RevocationURL string `json:"revocation_endpoint,omitempty"` @@ -53,7 +48,7 @@ type Provider struct { // UserGroupFn is, if set, used to return a slice of group IDs the // user is a member of - UserGroupFn func(context.Context, *sessions.State) ([]string, error) + UserGroupFn func(context.Context, *oauth2.Token, interface{}) error } // New creates a new instance of a generic OpenID Connect provider. @@ -100,81 +95,89 @@ func (p *Provider) GetSignInURL(state string) string { // Authenticate converts an authorization code returned from the identity // provider into a token which is then converted into a user session. -func (p *Provider) Authenticate(ctx context.Context, code string) (*sessions.State, error) { +func (p *Provider) Authenticate(ctx context.Context, code string, v interface{}) (*oauth2.Token, error) { + // Exchange converts an authorization code into a token. oauth2Token, err := p.Oauth.Exchange(ctx, code) if err != nil { return nil, fmt.Errorf("identity/oidc: token exchange failed: %w", err) } - idToken, err := p.IdentityFromToken(ctx, oauth2Token) + + idToken, err := p.getIDToken(ctx, oauth2Token) if err != nil { return nil, fmt.Errorf("identity/oidc: failed getting id_token: %w", err) } - aud, err := urlutil.ParseAndValidateURL(p.Oauth.RedirectURL) + // hydrate `v` using claims inside the returned `id_token` + // https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint + if err := idToken.Claims(v); err != nil { + return nil, fmt.Errorf("identity/oidc: couldn't unmarshal extra claims %w", err) + } + + if err := p.updateUserInfo(ctx, oauth2Token, v); err != nil { + return nil, fmt.Errorf("identity/oidc: couldn't update user info %w", err) + } + + return oauth2Token, nil +} + +// updateUserInfo calls the OIDC (spec required) UserInfo Endpoint as well as any +// groups endpoint (non-spec) to populate the rest of the user's information. +// +// https://openid.net/specs/openid-connect-core-1_0.html#UserInfo +func (p *Provider) updateUserInfo(ctx context.Context, t *oauth2.Token, v interface{}) error { + userInfo, err := p.Provider.UserInfo(ctx, oauth2.StaticTokenSource(t)) if err != nil { - return nil, fmt.Errorf("identity/oidc: bad redirect uri: %w", err) + return fmt.Errorf("identity/oidc: user info endpoint: %w", err) } - - s, err := sessions.NewStateFromTokens(idToken, oauth2Token, aud.Hostname()) - if err != nil { - return nil, err + if err := userInfo.Claims(v); err != nil { + return fmt.Errorf("identity/oidc: failed parsing user info endpoint claims: %w", err) } - - if err := p.Provider.Claims(&p); err == nil && p.UserInfoURL != "" { - userInfo, err := p.Provider.UserInfo(ctx, oauth2.StaticTokenSource(oauth2Token)) - if err != nil { - return nil, fmt.Errorf("identity/oidc: could not retrieve user info %w", err) - } - if err := userInfo.Claims(&s); err != nil { - return nil, fmt.Errorf("identity/oidc: could not parse user claims %w", err) - } - } - if p.UserGroupFn != nil { - s.Groups, err = p.UserGroupFn(ctx, s) - if err != nil { - return nil, fmt.Errorf("internal/oidc: could not retrieve groups %w", err) + if err := p.UserGroupFn(ctx, t, v); err != nil { + return fmt.Errorf("identity/oidc: could not retrieve groups: %w", err) } } - return s, nil + return nil } // Refresh renews a user's session using an oidc refresh token without reprompting the user. // Group membership is also refreshed. // https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens -func (p *Provider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) { - if s.AccessToken == nil || s.AccessToken.RefreshToken == "" { - return nil, errors.New("internal/oidc: missing refresh token") +func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v interface{}) (*oauth2.Token, error) { + if t == nil { + return nil, ErrMissingAccessToken + } + if t.RefreshToken == "" { + return nil, ErrMissingRefreshToken + } + var err error + newToken, err := p.Oauth.TokenSource(ctx, t).Token() + if err != nil { + return nil, fmt.Errorf("identity/oidc: refresh failed: %w", err) } - t := oauth2.Token{RefreshToken: s.AccessToken.RefreshToken} - oauthToken, err := p.Oauth.TokenSource(ctx, &t).Token() - if err != nil { - return nil, fmt.Errorf("internal/oidc: refresh failed %w", err) - } - idToken, err := p.IdentityFromToken(ctx, oauthToken) - if err != nil { - return nil, fmt.Errorf("identity/oidc: failed getting id_token: %w", err) - } - if err := s.UpdateState(idToken, oauthToken); err != nil { - return nil, fmt.Errorf("internal/oidc: state update failed %w", err) - } - if p.UserGroupFn != nil { - s.Groups, err = p.UserGroupFn(ctx, s) - if err != nil { - return nil, fmt.Errorf("internal/oidc: could not retrieve groups %w", err) + // Many identity providers _will not_ return `id_token` on refresh + // https://github.com/FusionAuth/fusionauth-issues/issues/110#issuecomment-481526544 + idToken, err := p.getIDToken(ctx, newToken) + if err == nil { + if err := idToken.Claims(v); err != nil { + return nil, fmt.Errorf("identity/oidc: couldn't unmarshal extra claims %w", err) } } - return s, nil + if err := p.updateUserInfo(ctx, newToken, v); err != nil { + return nil, fmt.Errorf("identity/oidc: couldn't update user info %w", err) + } + return newToken, nil } -// IdentityFromToken takes an identity provider issued JWT as input ('id_token') -// and returns a session state. The provided token's audience ('aud') must -// match Pomerium's client_id. -func (p *Provider) IdentityFromToken(ctx context.Context, t *oauth2.Token) (*go_oidc.IDToken, error) { +// getIDToken returns the raw jwt payload for `id_token` from the oauth2 token +// returned following oidc code flow +// +// https://openid.net/specs/openid-connect-core-1_0.html#TokenResponse +func (p *Provider) getIDToken(ctx context.Context, t *oauth2.Token) (*go_oidc.IDToken, error) { rawIDToken, ok := t.Extra("id_token").(string) if !ok { - return nil, fmt.Errorf("internal/oidc: id_token not found") + return nil, ErrMissingIDToken } return p.Verifier.Verify(ctx, rawIDToken) } @@ -183,13 +186,16 @@ func (p *Provider) IdentityFromToken(ctx context.Context, t *oauth2.Token) (*go_ // support revocation an error is thrown. // // https://tools.ietf.org/html/rfc7009#section-2.1 -func (p *Provider) Revoke(ctx context.Context, token *oauth2.Token) error { +func (p *Provider) Revoke(ctx context.Context, t *oauth2.Token) error { if p.RevocationURL == "" { return ErrRevokeNotImplemented } + if t == nil { + return ErrMissingAccessToken + } params := url.Values{} - params.Add("token", token.AccessToken) + params.Add("token", t.AccessToken) params.Add("token_type_hint", "access_token") // Some providers like okta / onelogin require "client authentication" // https://developer.okta.com/docs/reference/api/oidc/#client-secret @@ -198,7 +204,7 @@ func (p *Provider) Revoke(ctx context.Context, token *oauth2.Token) error { params.Add("client_secret", p.Oauth.ClientSecret) err := httputil.Client(ctx, http.MethodPost, p.RevocationURL, version.UserAgent(), nil, params, nil) - if err != nil && err != httputil.ErrTokenRevoked { + if err != nil && errors.Is(err, httputil.ErrTokenRevoked) { return fmt.Errorf("internal/oidc: unexpected revoke error: %w", err) } @@ -214,3 +220,20 @@ func (p *Provider) LogOut() (*url.URL, error) { } return urlutil.ParseAndValidateURL(p.EndSessionURL) } + +// GetSubject gets the RFC 7519 Subject claim (`sub`) from a +func (p *Provider) GetSubject(v interface{}) (string, error) { + b, err := json.Marshal(v) + if err != nil { + return "", err + } + var s struct { + Subject string `json:"sub"` + } + + err = json.Unmarshal(b, &s) + if err != nil { + return "", err + } + return s.Subject, nil +} diff --git a/internal/identity/oidc/okta/okta.go b/internal/identity/oidc/okta/okta.go index 1c8efc4e7..261c7e7df 100644 --- a/internal/identity/oidc/okta/okta.go +++ b/internal/identity/oidc/okta/okta.go @@ -5,6 +5,7 @@ package okta import ( "context" + "encoding/json" "fmt" "net/http" "net/url" @@ -13,9 +14,9 @@ import ( "github.com/pomerium/pomerium/internal/identity/oauth" pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc" "github.com/pomerium/pomerium/internal/log" - "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/version" + "golang.org/x/oauth2" ) const ( @@ -64,7 +65,11 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) { // UserGroups fetches the groups of which the user is a member // https://developer.okta.com/docs/reference/api/users/#get-user-s-groups -func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string, error) { +func (p *Provider) UserGroups(ctx context.Context, t *oauth2.Token, v interface{}) error { + s, err := p.GetSubject(v) + if err != nil { + return err + } var response []struct { ID string `json:"id"` Profile struct { @@ -74,15 +79,22 @@ func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string, } headers := map[string]string{"Authorization": fmt.Sprintf("SSWS %s", p.serviceAccount)} - uri := fmt.Sprintf("%s/%s/groups", p.userAPI.String(), s.Subject) - err := httputil.Client(ctx, http.MethodGet, uri, version.UserAgent(), headers, nil, &response) + uri := fmt.Sprintf("%s/%s/groups", p.userAPI.String(), s) + err = httputil.Client(ctx, http.MethodGet, uri, version.UserAgent(), headers, nil, &response) if err != nil { - return nil, err + return err + } + log.Debug().Interface("response", response).Msg("okta: groups") + var out struct { + Groups []string `json:"groups"` } - var groups []string for _, group := range response { - log.Debug().Interface("group", group).Msg("okta: group") - groups = append(groups, group.ID) + out.Groups = append(out.Groups, group.ID) } - return groups, nil + b, err := json.Marshal(out) + if err != nil { + return err + } + + return json.Unmarshal(b, v) } diff --git a/internal/identity/oidc/onelogin/onelogin.go b/internal/identity/oidc/onelogin/onelogin.go index bccd13d5e..59e3d58ec 100644 --- a/internal/identity/oidc/onelogin/onelogin.go +++ b/internal/identity/oidc/onelogin/onelogin.go @@ -5,17 +5,15 @@ package onelogin import ( "context" - "errors" "fmt" "net/http" - "time" oidc "github.com/coreos/go-oidc" + "golang.org/x/oauth2" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/identity/oauth" pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc" - "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/version" ) @@ -55,24 +53,10 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) { // UserGroups returns a slice of group names a given user is in. // https://developers.onelogin.com/openid-connect/api/user-info -func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string, error) { - if s == nil || s.AccessToken == nil { - return nil, errors.New("identity/onelogin: session cannot be nil") +func (p *Provider) UserGroups(ctx context.Context, t *oauth2.Token, v interface{}) error { + if t == nil { + return pom_oidc.ErrMissingAccessToken } - var response struct { - User string `json:"sub"` - Email string `json:"email"` - PreferredUsername string `json:"preferred_username"` - Name string `json:"name"` - UpdatedAt time.Time `json:"updated_at"` - GivenName string `json:"given_name"` - FamilyName string `json:"family_name"` - Groups []string `json:"groups"` - } - headers := map[string]string{"Authorization": fmt.Sprintf("Bearer %s", s.AccessToken.AccessToken)} - err := httputil.Client(ctx, http.MethodGet, defaultOneloginGroupURL, version.UserAgent(), headers, nil, &response) - if err != nil { - return nil, err - } - return response.Groups, nil + headers := map[string]string{"Authorization": fmt.Sprintf("Bearer %s", t.AccessToken)} + return httputil.Client(ctx, http.MethodGet, defaultOneloginGroupURL, version.UserAgent(), headers, nil, v) } diff --git a/internal/identity/providers.go b/internal/identity/providers.go index e375d8be7..874ef3ebe 100644 --- a/internal/identity/providers.go +++ b/internal/identity/providers.go @@ -17,25 +17,24 @@ import ( "github.com/pomerium/pomerium/internal/identity/oidc/google" "github.com/pomerium/pomerium/internal/identity/oidc/okta" "github.com/pomerium/pomerium/internal/identity/oidc/onelogin" - "github.com/pomerium/pomerium/internal/sessions" ) var ( // compile time assertions that providers are satisfying the interface _ Authenticator = &azure.Provider{} - _ Authenticator = &gitlab.Provider{} _ Authenticator = &github.Provider{} + _ Authenticator = &gitlab.Provider{} _ Authenticator = &google.Provider{} + _ Authenticator = &MockProvider{} _ Authenticator = &oidc.Provider{} _ Authenticator = &okta.Provider{} _ Authenticator = &onelogin.Provider{} - _ Authenticator = &MockProvider{} ) // Authenticator is an interface representing the ability to authenticate with an identity provider. type Authenticator interface { - Authenticate(context.Context, string) (*sessions.State, error) - Refresh(context.Context, *sessions.State) (*sessions.State, error) + Authenticate(context.Context, string, interface{}) (*oauth2.Token, error) + Refresh(context.Context, *oauth2.Token, interface{}) (*oauth2.Token, error) Revoke(context.Context, *oauth2.Token) error GetSignInURL(state string) string LogOut() (*url.URL, error) diff --git a/internal/kv/autocache/autocache.go b/internal/kv/autocache/autocache.go index 2e8360b50..d3a7785b0 100644 --- a/internal/kv/autocache/autocache.go +++ b/internal/kv/autocache/autocache.go @@ -37,6 +37,9 @@ type Store struct { srv *http.Server } +// ErrCacheMiss is returned when the cache misses for a given key. +var ErrCacheMiss = errors.New("cache miss") + // Options represent autocache options. type Options struct { Addr string @@ -60,7 +63,7 @@ var DefaultOptions = &Options{ GetterFn: func(ctx context.Context, id string, dest groupcache.Sink) error { b := fromContext(ctx) if len(b) == 0 { - return fmt.Errorf("autocache: empty ctx for id: %s", id) + return fmt.Errorf("autocache: id %s : %w", id, ErrCacheMiss) } if err := dest.SetBytes(b); err != nil { return fmt.Errorf("autocache: sink error %w", err) diff --git a/internal/sessions/cache/cache_store.go b/internal/sessions/cache/cache_store.go deleted file mode 100644 index 57ab3ef7f..000000000 --- a/internal/sessions/cache/cache_store.go +++ /dev/null @@ -1,95 +0,0 @@ -// Package cache provides a remote cache based implementation of session store -// and loader. See pomerium's cache service for more details. -package cache - -import ( - "errors" - "fmt" - "net/http" - - "github.com/pomerium/pomerium/internal/encoding" - "github.com/pomerium/pomerium/internal/grpc/cache/client" - "github.com/pomerium/pomerium/internal/log" - "github.com/pomerium/pomerium/internal/sessions" -) - -var _ sessions.SessionStore = &Store{} -var _ sessions.SessionLoader = &Store{} - -// Store implements the session store interface using a cache service. -type Store struct { - cache client.Cacher - encoder encoding.MarshalUnmarshaler - queryParam string - wrappedStore sessions.SessionStore -} - -// Options represent cache store's available configurations. -type Options struct { - Cache client.Cacher - Encoder encoding.MarshalUnmarshaler - QueryParam string - WrappedStore sessions.SessionStore -} - -var defaultOptions = &Options{ - QueryParam: "cache_store_key", -} - -// NewStore creates a new cache -func NewStore(o *Options) *Store { - if o.QueryParam == "" { - o.QueryParam = defaultOptions.QueryParam - } - return &Store{ - cache: o.Cache, - encoder: o.Encoder, - queryParam: o.QueryParam, - wrappedStore: o.WrappedStore, - } -} - -// LoadSession looks for a preset query parameter in the request body -// representing the key to lookup from the cache. -func (s *Store) LoadSession(r *http.Request) (string, error) { - // look for our cache's key in the default query param - sessionID := r.URL.Query().Get(s.queryParam) - if sessionID == "" { - return "", sessions.ErrNoSessionFound - } - exists, val, err := s.cache.Get(r.Context(), sessionID) - if err != nil { - log.FromRequest(r).Debug().Msg("sessions/cache: miss, trying wrapped loader") - return "", err - } - if !exists { - return "", sessions.ErrNoSessionFound - } - - return string(val), nil -} - -// ClearSession clears the session from the wrapped store. -func (s *Store) ClearSession(w http.ResponseWriter, r *http.Request) { - s.wrappedStore.ClearSession(w, r) -} - -// SaveSession saves the session to the cache, and wrapped store. -func (s *Store) SaveSession(w http.ResponseWriter, r *http.Request, x interface{}) error { - err := s.wrappedStore.SaveSession(w, r, x) - if err != nil { - return fmt.Errorf("sessions/cache: wrapped store save error %w", err) - } - - state, ok := x.(*sessions.State) - if !ok { - return errors.New("sessions/cache: cannot cache non state type") - } - - data, err := s.encoder.Marshal(&state) - if err != nil { - return fmt.Errorf("sessions/cache: marshal %w", err) - } - - return s.cache.Set(r.Context(), state.AccessTokenID, data) -} diff --git a/internal/sessions/cache/cache_store_test.go b/internal/sessions/cache/cache_store_test.go deleted file mode 100644 index eae675a99..000000000 --- a/internal/sessions/cache/cache_store_test.go +++ /dev/null @@ -1,190 +0,0 @@ -package cache - -import ( - "context" - "errors" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "github.com/pomerium/pomerium/internal/cryptutil" - "github.com/pomerium/pomerium/internal/encoding" - "github.com/pomerium/pomerium/internal/encoding/ecjson" - mock_encoder "github.com/pomerium/pomerium/internal/encoding/mock" - "github.com/pomerium/pomerium/internal/grpc/cache/client" - "github.com/pomerium/pomerium/internal/sessions" - "github.com/pomerium/pomerium/internal/sessions/cookie" - "github.com/pomerium/pomerium/internal/sessions/mock" - "gopkg.in/square/go-jose.v2/jwt" -) - -type mockCache struct { - Key string - KeyExists bool - Value []byte - Err error -} - -func (mc *mockCache) Get(ctx context.Context, key string) (keyExists bool, value []byte, err error) { - return mc.KeyExists, mc.Value, mc.Err -} -func (mc *mockCache) Set(ctx context.Context, key string, value []byte) error { - return mc.Err -} -func (mc *mockCache) Close() error { - return mc.Err -} - -func TestNewStore(t *testing.T) { - - tests := []struct { - name string - Options *Options - State *sessions.State - - wantErr bool - wantLoadErr bool - wantStatus int - }{ - {"simple good", - &Options{ - Cache: &mockCache{}, - WrappedStore: &mock.Store{}, - Encoder: mock_encoder.Encoder{MarshalResponse: []byte("ok")}, - }, - &sessions.State{Email: "user@domain.com", User: "user"}, - false, false, - http.StatusOK}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := NewStore(tt.Options) - - r := httptest.NewRequest("GET", "/", nil) - w := httptest.NewRecorder() - - if err := got.SaveSession(w, r, tt.State); (err != nil) != tt.wantErr { - t.Errorf("NewStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr) - } - - r = httptest.NewRequest("GET", "/", nil) - w = httptest.NewRecorder() - - got.ClearSession(w, r) - status := w.Result().StatusCode - if diff := cmp.Diff(status, tt.wantStatus); diff != "" { - t.Errorf("ClearSession() = %v", diff) - } - }) - } -} - -func TestStore_SaveSession(t *testing.T) { - cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key()) - encoder := ecjson.New(cipher) - if err != nil { - t.Fatal(err) - } - cs, err := cookie.NewStore(&cookie.Options{ - Name: "_pomerium", - }, encoder) - if err != nil { - t.Fatal(err) - } - tests := []struct { - name string - Options *Options - - x interface{} - wantErr bool - }{ - {"good", &Options{Cache: &mockCache{}, WrappedStore: cs, Encoder: mock_encoder.Encoder{MarshalResponse: []byte("ok")}}, &sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, false}, - {"encoder error", &Options{Cache: &mockCache{}, WrappedStore: cs, Encoder: mock_encoder.Encoder{MarshalError: errors.New("err")}}, &sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, true}, - {"good", &Options{Cache: &mockCache{}, WrappedStore: &mock.Store{SaveError: errors.New("err")}}, &sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, true}, - {"bad type", &Options{Cache: &mockCache{}, WrappedStore: cs, Encoder: mock_encoder.Encoder{MarshalError: errors.New("err")}}, "bad type!", true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - o := tt.Options - if o.WrappedStore == nil { - o.WrappedStore = cs - - } - cacheStore := NewStore(tt.Options) - r := httptest.NewRequest(http.MethodGet, "/", nil) - r.Header.Set("Accept", "application/json") - w := httptest.NewRecorder() - if err := cacheStore.SaveSession(w, r, tt.x); (err != nil) != tt.wantErr { - t.Errorf("Store.SaveSession() error = %v, wantErr %v", err, tt.wantErr) - } - - }) - } -} - -func TestStore_LoadSession(t *testing.T) { - key := cryptutil.NewBase64Key() - tests := []struct { - name string - state *sessions.State - cache client.Cacher - encoder encoding.MarshalUnmarshaler - queryParam string - wrappedStore sessions.SessionStore - wantErr bool - }{ - {"good", - &sessions.State{AccessTokenID: key, Email: "user@pomerium.io"}, - &mockCache{KeyExists: true}, - mock_encoder.Encoder{MarshalResponse: []byte("ok")}, - defaultOptions.QueryParam, - &mock.Store{Session: &sessions.State{AccessTokenID: key, Email: "user@pomerium.io"}}, - false}, - {"missing param with key", - &sessions.State{AccessTokenID: key, Email: "user@pomerium.io"}, - &mockCache{KeyExists: true}, - mock_encoder.Encoder{MarshalResponse: []byte("ok")}, - "bad_query", - &mock.Store{Session: &sessions.State{AccessTokenID: key, Email: "user@pomerium.io"}}, - true}, - {"doesn't exist", - &sessions.State{AccessTokenID: key, Email: "user@pomerium.io"}, - &mockCache{KeyExists: false}, - mock_encoder.Encoder{MarshalResponse: []byte("ok")}, - defaultOptions.QueryParam, - &mock.Store{Session: &sessions.State{AccessTokenID: key, Email: "user@pomerium.io"}}, - true}, - {"retrieval error", - &sessions.State{AccessTokenID: key, Email: "user@pomerium.io"}, - &mockCache{Err: errors.New("err")}, - mock_encoder.Encoder{MarshalResponse: []byte("ok")}, - defaultOptions.QueryParam, - &mock.Store{Session: &sessions.State{AccessTokenID: key, Email: "user@pomerium.io"}}, - true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - s := &Store{ - cache: tt.cache, - encoder: tt.encoder, - queryParam: tt.queryParam, - wrappedStore: tt.wrappedStore, - } - - r := httptest.NewRequest(http.MethodGet, "/", nil) - q := r.URL.Query() - - q.Set(defaultOptions.QueryParam, tt.state.AccessTokenID) - r.URL.RawQuery = q.Encode() - r.Header.Set("Accept", "application/json") - - _, err := s.LoadSession(r) - if (err != nil) != tt.wantErr { - t.Errorf("Store.LoadSession() error = %v, wantErr %v", err, tt.wantErr) - return - } - }) - } -} diff --git a/internal/sessions/state.go b/internal/sessions/state.go index 457fcf925..6d5d48af6 100644 --- a/internal/sessions/state.go +++ b/internal/sessions/state.go @@ -1,15 +1,11 @@ package sessions import ( - "encoding/json" - "errors" "fmt" "strings" "time" - "github.com/cespare/xxhash/v2" - oidc "github.com/coreos/go-oidc" - "github.com/mitchellh/hashstructure" + "github.com/pomerium/pomerium/internal/hashutil" "golang.org/x/oauth2" "gopkg.in/square/go-jose.v2/jwt" ) @@ -27,6 +23,9 @@ type State struct { NotBefore *jwt.NumericDate `json:"nbf,omitempty"` IssuedAt *jwt.NumericDate `json:"iat,omitempty"` ID string `json:"jti,omitempty"` + // At_hash is an OPTIONAL Access Token hash value + // https://ldapwiki.com/wiki/At_hash + AccessTokenHash string `json:"at_hash,omitempty"` // core pomerium identity claims ; not standard to RFC 7519 Email string `json:"email"` @@ -48,84 +47,24 @@ type State struct { // Programmatic whether this state is used for machine-to-machine // programatic access. Programmatic bool `json:"programatic"` - - AccessToken *oauth2.Token `json:"act,omitempty"` - AccessTokenID string `json:"ati,omitempty"` - - idToken *oidc.IDToken -} - -// NewStateFromTokens returns a session state built from oidc and oauth2 -// tokens as part of OpenID Connect flow with a new audience appended to the -// audience claim. -func NewStateFromTokens(idToken *oidc.IDToken, accessToken *oauth2.Token, audience string) (*State, error) { - if idToken == nil { - return nil, errors.New("sessions: oidc id token missing") - } - if accessToken == nil { - return nil, errors.New("sessions: oauth2 token missing") - } - s := &State{} - if err := idToken.Claims(s); err != nil { - return nil, fmt.Errorf("sessions: couldn't unmarshal extra claims %w", err) - } - s.Audience = []string{audience} - s.idToken = idToken - s.AccessToken = accessToken - s.AccessTokenID = s.accessTokenHash() - return s, nil -} - -// UpdateState updates the current state given a new identity (oidc) and authorization -// (oauth2) tokens following a oidc refresh. NB, unlike during authentication, -// refresh typically provides fewer claims in the token so we want to build from -// our previous state. -func (s *State) UpdateState(idToken *oidc.IDToken, accessToken *oauth2.Token) error { - if idToken == nil { - return errors.New("sessions: oidc id token missing") - } - if accessToken == nil { - return errors.New("sessions: oauth2 token missing") - } - audience := append(s.Audience[:0:0], s.Audience...) - s.AccessToken = accessToken - if err := idToken.Claims(s); err != nil { - return fmt.Errorf("sessions: update state failed %w", err) - } - s.Audience = audience - s.Expiry = jwt.NewNumericDate(accessToken.Expiry) - s.AccessTokenID = s.accessTokenHash() - return nil } // NewSession updates issuer, audience, and issuance timestamps but keeps // parent expiry. -func (s State) NewSession(issuer string, audience []string) *State { - s.IssuedAt = jwt.NewNumericDate(timeNow()) - s.NotBefore = s.IssuedAt - s.Audience = audience - s.Issuer = issuer - return &s -} - -// RouteSession creates a route session with access tokens stripped. -func (s State) RouteSession() *State { - s.AccessToken = nil - return &s +func NewSession(s *State, issuer string, audience []string, accessToken *oauth2.Token) State { + newState := *s + newState.IssuedAt = jwt.NewNumericDate(timeNow()) + newState.NotBefore = newState.IssuedAt + newState.Audience = audience + newState.Issuer = issuer + newState.AccessTokenHash = fmt.Sprintf("%x", hashutil.Hash(accessToken)) + newState.Expiry = jwt.NewNumericDate(accessToken.Expiry) + return newState } // IsExpired returns true if the users's session is expired. func (s *State) IsExpired() bool { - - if s.Expiry != nil && timeNow().After(s.Expiry.Time()) { - return true - } - - if s.AccessToken != nil && timeNow().After(s.AccessToken.Expiry) { - return true - } - - return false + return s.Expiry != nil && timeNow().After(s.Expiry.Time()) } // Impersonating returns if the request is impersonating. @@ -133,23 +72,6 @@ func (s *State) Impersonating() bool { return s.ImpersonateEmail != "" || len(s.ImpersonateGroups) != 0 } -// RequestEmail is the email to make the request as. -func (s *State) RequestEmail() string { - if s.ImpersonateEmail != "" { - return s.ImpersonateEmail - } - return s.Email -} - -// RequestGroups returns the groups of the Groups making the request; uses -// impersonating user if set. -func (s *State) RequestGroups() string { - if len(s.ImpersonateGroups) != 0 { - return strings.Join(s.ImpersonateGroups, ",") - } - return strings.Join(s.Groups, ",") -} - // SetImpersonation sets impersonation user and groups. func (s *State) SetImpersonation(email, groups string) { s.ImpersonateEmail = email @@ -159,34 +81,3 @@ func (s *State) SetImpersonation(email, groups string) { s.ImpersonateGroups = strings.Split(groups, ",") } } - -func (s *State) accessTokenHash() string { - hash, err := hashstructure.Hash( - s.AccessToken, - &hashstructure.HashOptions{Hasher: xxhash.New()}) - if err != nil { - return "" - } - return fmt.Sprintf("%x", hash) -} - -// UnmarshalJSON parses the JSON-encoded session state. -// TODO(BDD): remove in v0.8.0 -func (s *State) UnmarshalJSON(b []byte) error { - type Alias State - t := &struct { - *Alias - OldToken *oauth2.Token `json:"access_token,omitempty"` // < v0.5.0 - }{ - Alias: (*Alias)(s), - } - if err := json.Unmarshal(b, &t); err != nil { - return err - } - if t.AccessToken == nil { - t.AccessToken = t.OldToken - } - *s = *(*State)(t.Alias) - - return nil -} diff --git a/internal/sessions/state_test.go b/internal/sessions/state_test.go index d9db5e4df..8e8bb612c 100644 --- a/internal/sessions/state_test.go +++ b/internal/sessions/state_test.go @@ -5,8 +5,6 @@ import ( "testing" "time" - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" "golang.org/x/oauth2" "gopkg.in/square/go-jose.v2/jwt" ) @@ -38,12 +36,6 @@ func TestState_Impersonating(t *testing.T) { if got := s.Impersonating(); got != tt.want { t.Errorf("State.Impersonating() = %v, want %v", got, tt.want) } - if gotEmail := s.RequestEmail(); gotEmail != tt.wantResponseEmail { - t.Errorf("State.RequestEmail() = %v, want %v", gotEmail, tt.wantResponseEmail) - } - if gotGroups := s.RequestGroups(); gotGroups != tt.wantResponseGroups { - t.Errorf("State.v() = %v, want %v", gotGroups, tt.wantResponseGroups) - } }) } } @@ -63,16 +55,14 @@ func TestState_IsExpired(t *testing.T) { }{ {"good", []string{"a", "b", "c"}, jwt.NewNumericDate(time.Now().Add(time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), &oauth2.Token{Expiry: time.Now().Add(time.Hour)}, "a", false}, {"bad expiry", []string{"a", "b", "c"}, jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), &oauth2.Token{Expiry: time.Now().Add(time.Hour)}, "a", true}, - {"bad access token expiry", []string{"a", "b", "c"}, jwt.NewNumericDate(time.Now().Add(time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), &oauth2.Token{Expiry: time.Now().Add(-time.Hour)}, "a", true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := &State{ - Audience: tt.Audience, - Expiry: tt.Expiry, - NotBefore: tt.NotBefore, - IssuedAt: tt.IssuedAt, - AccessToken: tt.AccessToken, + Audience: tt.Audience, + Expiry: tt.Expiry, + NotBefore: tt.NotBefore, + IssuedAt: tt.IssuedAt, } if exp := s.IsExpired(); exp != tt.wantErr { t.Errorf("State.IsExpired() error = %v, wantErr %v", exp, tt.wantErr) @@ -80,67 +70,3 @@ func TestState_IsExpired(t *testing.T) { }) } } - -func TestState_RouteSession(t *testing.T) { - now := time.Now() - timeNow = func() time.Time { - return now - } - tests := []struct { - name string - Issuer string - Audience jwt.Audience - Expiry *jwt.NumericDate - AccessToken *oauth2.Token - - issuer string - - audience []string - - want *State - }{ - {"good", "authenticate.x.y.z", []string{"http.x.y.z"}, jwt.NewNumericDate(timeNow()), nil, "authenticate.a.b.c", []string{"http.a.b.c"}, &State{Issuer: "authenticate.a.b.c", Audience: []string{"http.a.b.c"}, NotBefore: jwt.NewNumericDate(timeNow()), IssuedAt: jwt.NewNumericDate(timeNow()), Expiry: jwt.NewNumericDate(timeNow())}}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - s := State{ - Issuer: tt.Issuer, - Audience: tt.Audience, - Expiry: tt.Expiry, - AccessToken: tt.AccessToken, - } - cmpOpts := []cmp.Option{ - cmpopts.IgnoreUnexported(State{}), - } - got := s.NewSession(tt.issuer, tt.audience) - got = got.RouteSession() - if diff := cmp.Diff(got, tt.want, cmpOpts...); diff != "" { - t.Errorf("State.RouteSession() = %s", diff) - } - - }) - } -} - -func TestState_accessTokenHash(t *testing.T) { - t.Parallel() - tests := []struct { - name string - state State - want string - }{ - {"empty access token", State{}, "34c96acdcadb1bbb"}, - {"no change to access token", State{Subject: "test"}, "34c96acdcadb1bbb"}, - {"empty oauth2 token", State{AccessToken: &oauth2.Token{}}, "bbd82197d215198f"}, - {"refresh token a", State{AccessToken: &oauth2.Token{RefreshToken: "a"}}, "76316ac79b301bd6"}, - {"refresh token b", State{AccessToken: &oauth2.Token{RefreshToken: "b"}}, "fab7cb29e50161f1"}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - s := &tt.state - if got := s.accessTokenHash(); got != tt.want { - t.Errorf("State.accessTokenHash() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/internal/urlutil/query_params.go b/internal/urlutil/query_params.go index 94d7768aa..524c0628a 100644 --- a/internal/urlutil/query_params.go +++ b/internal/urlutil/query_params.go @@ -17,6 +17,7 @@ const ( QueryRefreshToken = "pomerium_refresh_token" QueryAccessTokenID = "pomerium_session_access_token_id" QueryAudience = "pomerium_session_audience" + QueryProgrammaticToken = "pomerium_programmatic_token" ) // URL signature based query params used for verifying the authenticity of a URL.