mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 10:26:29 +02:00
authenticate: save oauth2 tokens to cache (#698)
Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
This commit is contained in:
parent
ef399380b7
commit
666fd6aa35
31 changed files with 1127 additions and 1061 deletions
6
Makefile
6
Makefile
|
@ -38,6 +38,12 @@ GETENVOY_VERSION = v0.1.8
|
|||
all: clean build-deps test lint spellcheck build ## Runs a clean, build, fmt, lint, test, and vet.
|
||||
|
||||
|
||||
.PHONY: generate-mocks
|
||||
generate-mocks: ## Generate mocks
|
||||
@echo "==> $@"
|
||||
@go run github.com/golang/mock/mockgen -destination authorize/evaluator/mock_evaluator/mock.go github.com/pomerium/pomerium/authorize/evaluator Evaluator
|
||||
@go run github.com/golang/mock/mockgen -destination internal/grpc/cache/mock/mock_cacher.go github.com/pomerium/pomerium/internal/grpc/cache Cacher
|
||||
|
||||
.PHONY: build-deps
|
||||
build-deps: ## Install build dependencies
|
||||
@echo "==> $@"
|
||||
|
|
|
@ -17,12 +17,12 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/encoding/jws"
|
||||
"github.com/pomerium/pomerium/internal/frontend"
|
||||
"github.com/pomerium/pomerium/internal/grpc"
|
||||
"github.com/pomerium/pomerium/internal/grpc/cache"
|
||||
"github.com/pomerium/pomerium/internal/grpc/cache/client"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/identity"
|
||||
"github.com/pomerium/pomerium/internal/identity/oauth"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/sessions/cache"
|
||||
"github.com/pomerium/pomerium/internal/sessions/cookie"
|
||||
"github.com/pomerium/pomerium/internal/sessions/header"
|
||||
"github.com/pomerium/pomerium/internal/sessions/queryparam"
|
||||
|
@ -93,7 +93,7 @@ type Authenticate struct {
|
|||
provider identity.Authenticator
|
||||
|
||||
// cacheClient is the interface for setting and getting sessions from a cache
|
||||
cacheClient client.Cacher
|
||||
cacheClient cache.Cacher
|
||||
|
||||
templates *template.Template
|
||||
}
|
||||
|
@ -106,12 +106,12 @@ func New(opts config.Options) (*Authenticate, error) {
|
|||
|
||||
// shared state encoder setup
|
||||
sharedCipher, _ := cryptutil.NewAEADCipherFromBase64(opts.SharedKey)
|
||||
signedEncoder, err := jws.NewHS256Signer([]byte(opts.SharedKey), opts.AuthenticateURL.Host)
|
||||
sharedEncoder, err := jws.NewHS256Signer([]byte(opts.SharedKey), opts.AuthenticateURL.Host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// private state encoder setup
|
||||
// private state encoder setup, used to encrypt oauth2 tokens
|
||||
decodedCookieSecret, _ := base64.StdEncoding.DecodeString(opts.CookieSecret)
|
||||
cookieCipher, _ := cryptutil.NewAEADCipher(decodedCookieSecret)
|
||||
encryptedEncoder := ecjson.New(cookieCipher)
|
||||
|
@ -124,7 +124,7 @@ func New(opts config.Options) (*Authenticate, error) {
|
|||
Expire: opts.CookieExpire,
|
||||
}
|
||||
|
||||
cookieStore, err := cookie.NewStore(cookieOptions, encryptedEncoder)
|
||||
cookieStore, err := cookie.NewStore(cookieOptions, sharedEncoder)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -145,13 +145,7 @@ func New(opts config.Options) (*Authenticate, error) {
|
|||
|
||||
cacheClient := client.New(cacheConn)
|
||||
|
||||
cacheStore := cache.NewStore(&cache.Options{
|
||||
Cache: cacheClient,
|
||||
Encoder: encryptedEncoder,
|
||||
QueryParam: urlutil.QueryAccessTokenID,
|
||||
WrappedStore: cookieStore})
|
||||
|
||||
qpStore := queryparam.NewStore(encryptedEncoder, "pomerium_programmatic_token")
|
||||
qpStore := queryparam.NewStore(encryptedEncoder, urlutil.QueryProgrammaticToken)
|
||||
headerStore := header.NewStore(encryptedEncoder, httputil.AuthorizationTypePomerium)
|
||||
|
||||
redirectURL, _ := urlutil.DeepCopy(opts.AuthenticateURL)
|
||||
|
@ -177,14 +171,14 @@ func New(opts config.Options) (*Authenticate, error) {
|
|||
// shared state
|
||||
sharedKey: opts.SharedKey,
|
||||
sharedCipher: sharedCipher,
|
||||
sharedEncoder: signedEncoder,
|
||||
sharedEncoder: sharedEncoder,
|
||||
// private state
|
||||
cookieSecret: decodedCookieSecret,
|
||||
cookieCipher: cookieCipher,
|
||||
cookieOptions: cookieOptions,
|
||||
sessionStore: cacheStore,
|
||||
sessionStore: cookieStore,
|
||||
encryptedEncoder: encryptedEncoder,
|
||||
sessionLoaders: []sessions.SessionLoader{cacheStore, qpStore, headerStore, cookieStore},
|
||||
sessionLoaders: []sessions.SessionLoader{qpStore, headerStore, cookieStore},
|
||||
// IdP
|
||||
provider: provider,
|
||||
// grpc client for cache
|
||||
|
|
|
@ -91,6 +91,10 @@ func TestNew(t *testing.T) {
|
|||
badGRPCConn.CacheURL = nil
|
||||
badGRPCConn.CookieName = "D"
|
||||
|
||||
emptyProviderURL := newTestOptions(t)
|
||||
emptyProviderURL.Provider = "oidc"
|
||||
emptyProviderURL.ProviderURL = ""
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
opts *config.Options
|
||||
|
@ -103,6 +107,7 @@ func TestNew(t *testing.T) {
|
|||
{"bad cookie name", badCookieName, true},
|
||||
{"bad provider", badProvider, true},
|
||||
{"bad cache url", badGRPCConn, true},
|
||||
{"empty provider url", emptyProviderURL, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
|
||||
"github.com/pomerium/csrf"
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/pomerium/pomerium/internal/hashutil"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/identity/oidc"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
|
@ -23,6 +24,7 @@ import (
|
|||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/rs/cors"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// Handler returns the authenticate service's handler chain.
|
||||
|
@ -80,19 +82,15 @@ func (a *Authenticate) Mount(r *mux.Router) {
|
|||
// session state is attached to the users's request context.
|
||||
func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
|
||||
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
ctx := r.Context()
|
||||
jwt, err := sessions.FromContext(ctx)
|
||||
ctx, span := trace.StartSpan(r.Context(), "authenticate.VerifySession")
|
||||
defer span.End()
|
||||
s, err := a.getSessionFromCtx(ctx)
|
||||
if err != nil {
|
||||
log.FromRequest(r).Info().Err(err).Msg("authenticate: session load error")
|
||||
return a.reauthenticateOrFail(w, r, err)
|
||||
}
|
||||
var s sessions.State
|
||||
if err := a.encryptedEncoder.Unmarshal([]byte(jwt), &s); err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
}
|
||||
|
||||
if s.IsExpired() {
|
||||
ctx, err = a.refresh(w, r, &s)
|
||||
ctx, err = a.refresh(w, r, s)
|
||||
if err != nil {
|
||||
log.FromRequest(r).Info().Err(err).Msg("authenticate: verify session, refresh")
|
||||
return a.reauthenticateOrFail(w, r, err)
|
||||
|
@ -106,18 +104,34 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
|
|||
func (a *Authenticate) refresh(w http.ResponseWriter, r *http.Request, s *sessions.State) (context.Context, error) {
|
||||
ctx, span := trace.StartSpan(r.Context(), "authenticate.VerifySession/refresh")
|
||||
defer span.End()
|
||||
newSession, err := a.provider.Refresh(ctx, s)
|
||||
accessToken, err := a.getAccessToken(ctx, s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// we are going to keep the same audiences for the refreshed token
|
||||
// otherwise this will be rewritten to be the ClientID of the provider
|
||||
oldAudience := s.Audience
|
||||
|
||||
newAccessToken, err := a.provider.Refresh(ctx, accessToken, s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authenticate: refresh failed: %w", err)
|
||||
}
|
||||
if err := a.sessionStore.SaveSession(w, r, newSession); err != nil {
|
||||
return nil, fmt.Errorf("authenticate: refresh save failed: %w", err)
|
||||
}
|
||||
newSession = newSession.NewSession(s.Issuer, s.Audience)
|
||||
encSession, err := a.encryptedEncoder.Marshal(newSession)
|
||||
|
||||
newSession := sessions.NewSession(s, a.RedirectURL.Hostname(), oldAudience, newAccessToken)
|
||||
|
||||
encSession, err := a.sharedEncoder.Marshal(newSession)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := a.sessionStore.SaveSession(w, r, newSession); err != nil {
|
||||
return nil, fmt.Errorf("authenticate: error saving new session: %w", err)
|
||||
}
|
||||
|
||||
if err := a.setAccessToken(ctx, newAccessToken); err != nil {
|
||||
return nil, fmt.Errorf("authenticate: error saving refreshed access token: %w", err)
|
||||
}
|
||||
// return the new session and add it to the current request context
|
||||
return sessions.NewContext(ctx, string(encSession), err), nil
|
||||
}
|
||||
|
@ -129,8 +143,11 @@ func (a *Authenticate) RobotsTxt(w http.ResponseWriter, r *http.Request) {
|
|||
fmt.Fprintf(w, "User-agent: *\nDisallow: /")
|
||||
}
|
||||
|
||||
// SignIn handles to authenticating a user.
|
||||
// SignIn handles authenticating a user.
|
||||
func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
|
||||
ctx, span := trace.StartSpan(r.Context(), "authenticate.SignOut")
|
||||
defer span.End()
|
||||
|
||||
redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI))
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
|
@ -158,32 +175,30 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
|
|||
jwtAudience = append(jwtAudience, fwdAuth)
|
||||
}
|
||||
|
||||
jwt, err := sessions.FromContext(r.Context())
|
||||
s, err := a.getSessionFromCtx(ctx)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
return err
|
||||
}
|
||||
var s sessions.State
|
||||
if err := a.encryptedEncoder.Unmarshal([]byte(jwt), &s); err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
accessToken, err := a.getAccessToken(ctx, s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// user impersonation
|
||||
if impersonate := r.FormValue(urlutil.QueryImpersonateAction); impersonate != "" {
|
||||
s.SetImpersonation(r.FormValue(urlutil.QueryImpersonateEmail), r.FormValue(urlutil.QueryImpersonateGroups))
|
||||
}
|
||||
newSession := sessions.NewSession(s, a.RedirectURL.Host, jwtAudience, accessToken)
|
||||
|
||||
// re-persist the session, useful when session was evicted from session
|
||||
if err := a.sessionStore.SaveSession(w, r, &s); err != nil {
|
||||
if err := a.sessionStore.SaveSession(w, r, s); err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
}
|
||||
|
||||
newSession := s.NewSession(a.RedirectURL.Host, jwtAudience)
|
||||
|
||||
callbackParams := callbackURL.Query()
|
||||
|
||||
if r.FormValue(urlutil.QueryIsProgrammatic) == "true" {
|
||||
newSession.Programmatic = true
|
||||
encSession, err := a.encryptedEncoder.Marshal(newSession)
|
||||
encSession, err := a.encryptedEncoder.Marshal(accessToken)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
}
|
||||
|
@ -192,7 +207,7 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
|
|||
}
|
||||
|
||||
// sign the route session, as a JWT
|
||||
signedJWT, err := a.sharedEncoder.Marshal(newSession.RouteSession())
|
||||
signedJWT, err := a.sharedEncoder.Marshal(newSession)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
}
|
||||
|
@ -217,27 +232,12 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
|
|||
// SignOut signs the user out and attempts to revoke the user's identity session
|
||||
// Handles both GET and POST.
|
||||
func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error {
|
||||
// no matter what happens, we want to clear the local session store
|
||||
ctx, span := trace.StartSpan(r.Context(), "authenticate.SignOut")
|
||||
defer span.End()
|
||||
|
||||
// no matter what happens, we want to clear the session store
|
||||
a.sessionStore.ClearSession(w, r)
|
||||
|
||||
jwt, err := sessions.FromContext(r.Context())
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
}
|
||||
var s sessions.State
|
||||
if err := a.encryptedEncoder.Unmarshal([]byte(jwt), &s); err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
}
|
||||
|
||||
redirectString := r.FormValue(urlutil.QueryRedirectURI)
|
||||
|
||||
// first, try to revoke the session if implemented
|
||||
err = a.provider.Revoke(r.Context(), s.AccessToken)
|
||||
if err != nil && !errors.Is(err, oidc.ErrRevokeNotImplemented) {
|
||||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
}
|
||||
|
||||
// next, try to build a logout url if implemented
|
||||
endSessionURL, err := a.provider.LogOut()
|
||||
if err == nil {
|
||||
params := url.Values{}
|
||||
|
@ -245,14 +245,29 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error {
|
|||
endSessionURL.RawQuery = params.Encode()
|
||||
redirectString = endSessionURL.String()
|
||||
} else if !errors.Is(err, oidc.ErrSignoutNotImplemented) {
|
||||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
log.Warn().Err(err).Msg("authenticate.SignOut: failed getting session")
|
||||
}
|
||||
|
||||
redirectURL, err := urlutil.ParseAndValidateURL(redirectString)
|
||||
httputil.Redirect(w, r, redirectString, http.StatusFound)
|
||||
|
||||
s, err := a.getSessionFromCtx(ctx)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
log.Warn().Err(err).Msg("authenticate.SignOut: failed getting session")
|
||||
return nil
|
||||
}
|
||||
|
||||
accessToken, err := a.getAccessToken(ctx, s)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("authenticate.SignOut: failed getting access token")
|
||||
return nil
|
||||
}
|
||||
|
||||
// first, try to revoke the session if implemented
|
||||
err = a.provider.Revoke(ctx, accessToken)
|
||||
if err != nil && !errors.Is(err, oidc.ErrRevokeNotImplemented) {
|
||||
log.Warn().Err(err).Msg("authenticate.SignOut: failed revoking token")
|
||||
return nil
|
||||
}
|
||||
httputil.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -267,7 +282,8 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error {
|
|||
// https://tools.ietf.org/html/rfc6749#section-4.2.1
|
||||
// https://developer.mozilla.org/en-US/docs/Web/API/XMLHttpRequest
|
||||
func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Request, err error) error {
|
||||
// If request AJAX/XHR request, return a 401 instead .
|
||||
// If request AJAX/XHR request, return a 401 instead because the redirect
|
||||
// will almost certainly violate their CORs policy
|
||||
if reqType := r.Header.Get("X-Requested-With"); strings.EqualFold(reqType, "XmlHttpRequest") {
|
||||
return httputil.NewError(http.StatusUnauthorized, err)
|
||||
}
|
||||
|
@ -290,7 +306,7 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque
|
|||
func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) error {
|
||||
redirect, err := a.getOAuthCallback(w, r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("oauth callback : %w", err)
|
||||
return fmt.Errorf("authenticate.OAuthCallback: %w", err)
|
||||
}
|
||||
httputil.Redirect(w, r, redirect.String(), http.StatusFound)
|
||||
return nil
|
||||
|
@ -306,6 +322,9 @@ func (a *Authenticate) statusForErrorCode(errorCode string) int {
|
|||
}
|
||||
|
||||
func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (*url.URL, error) {
|
||||
ctx, span := trace.StartSpan(r.Context(), "authenticate.getOAuthCallback")
|
||||
defer span.End()
|
||||
|
||||
// Error Authentication Response: rfc6749#section-4.1.2.1 & OIDC#3.1.2.6
|
||||
//
|
||||
// first, check if the identity provider returned an error
|
||||
|
@ -321,14 +340,22 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
|
|||
// Successful Authentication Response: rfc6749#section-4.1.2 & OIDC#3.1.2.5
|
||||
//
|
||||
// Exchange the supplied Authorization Code for a valid user session.
|
||||
session, err := a.provider.Authenticate(r.Context(), code)
|
||||
var s sessions.State
|
||||
accessToken, err := a.provider.Authenticate(ctx, code, &s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error redeeming authenticate code: %w", err)
|
||||
}
|
||||
|
||||
newState := sessions.NewSession(
|
||||
&s,
|
||||
a.RedirectURL.Hostname(),
|
||||
[]string{a.RedirectURL.Hostname()},
|
||||
accessToken)
|
||||
|
||||
// state includes a csrf nonce (validated by middleware) and redirect uri
|
||||
bytes, err := base64.URLEncoding.DecodeString(r.FormValue("state"))
|
||||
if err != nil {
|
||||
return nil, httputil.NewError(http.StatusBadRequest, err)
|
||||
return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("bad bytes: %w", err))
|
||||
}
|
||||
|
||||
// split state into concat'd components
|
||||
|
@ -357,8 +384,13 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
|
|||
return nil, httputil.NewError(http.StatusBadRequest, err)
|
||||
}
|
||||
|
||||
// OK. Looks good so let's persist our user session
|
||||
if err := a.sessionStore.SaveSession(w, r, session); err != nil {
|
||||
// Ok -- We've got a valid session here. Let's now persist the access
|
||||
// token to cache ...
|
||||
if err := a.setAccessToken(ctx, accessToken); err != nil {
|
||||
return nil, fmt.Errorf("failed saving access token: %w", err)
|
||||
}
|
||||
// ... and the user state to local storage.
|
||||
if err := a.sessionStore.SaveSession(w, r, &newState); err != nil {
|
||||
return nil, fmt.Errorf("failed saving new session: %w", err)
|
||||
}
|
||||
return redirectURL, nil
|
||||
|
@ -368,26 +400,32 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
|
|||
// tokens and state with the identity provider. If successful, a new signed JWT
|
||||
// and refresh token (`refresh_token`) are returned as JSON
|
||||
func (a *Authenticate) RefreshAPI(w http.ResponseWriter, r *http.Request) error {
|
||||
jwt, err := sessions.FromContext(r.Context())
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
}
|
||||
var s sessions.State
|
||||
if err := a.encryptedEncoder.Unmarshal([]byte(jwt), &s); err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
}
|
||||
newSession, err := a.provider.Refresh(r.Context(), &s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newSession = newSession.NewSession(s.Issuer, s.Audience)
|
||||
ctx, span := trace.StartSpan(r.Context(), "authenticate.RefreshAPI")
|
||||
defer span.End()
|
||||
|
||||
encSession, err := a.encryptedEncoder.Marshal(newSession)
|
||||
s, err := a.getSessionFromCtx(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
signedJWT, err := a.sharedEncoder.Marshal(newSession.RouteSession())
|
||||
accessToken, err := a.getAccessToken(ctx, s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newAccessToken, err := a.provider.Refresh(ctx, accessToken, s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
routeNewSession := sessions.NewSession(s, a.RedirectURL.Hostname(), s.Audience, newAccessToken)
|
||||
|
||||
encSession, err := a.encryptedEncoder.Marshal(accessToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
signedJWT, err := a.sharedEncoder.Marshal(routeNewSession)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -410,28 +448,79 @@ func (a *Authenticate) RefreshAPI(w http.ResponseWriter, r *http.Request) error
|
|||
// Refresh is called by the proxy service to handle backend session refresh.
|
||||
//
|
||||
// NOTE: The actual refresh is handled as part of the "VerifySession"
|
||||
// middleware. This handler is responsible for creating a new route scoped
|
||||
// session and returning it.
|
||||
// middleware. This handler is simply responsible for returning that jwt.
|
||||
func (a *Authenticate) Refresh(w http.ResponseWriter, r *http.Request) error {
|
||||
jwt, err := sessions.FromContext(r.Context())
|
||||
ctx, span := trace.StartSpan(r.Context(), "authenticate.Refresh")
|
||||
defer span.End()
|
||||
jwt, err := sessions.FromContext(ctx)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
return fmt.Errorf("authenticate.Refresh: %w", err)
|
||||
}
|
||||
var s sessions.State
|
||||
if err := a.encryptedEncoder.Unmarshal([]byte(jwt), &s); err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
w.Header().Set("Content-Type", "application/jwt") // RFC 7519 : 10.3.1
|
||||
fmt.Fprint(w, jwt)
|
||||
return nil
|
||||
}
|
||||
|
||||
// getAccessToken gets an associated oauth2 access token from a session state
|
||||
func (a *Authenticate) getAccessToken(ctx context.Context, s *sessions.State) (*oauth2.Token, error) {
|
||||
ctx, span := trace.StartSpan(ctx, "authenticate.getAccessToken")
|
||||
defer span.End()
|
||||
|
||||
var accessToken oauth2.Token
|
||||
tokenBytes, err := a.cacheClient.Get(ctx, s.AccessTokenHash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := a.encryptedEncoder.Unmarshal(tokenBytes, &accessToken); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if accessToken.Valid() {
|
||||
return &accessToken, nil // this token is still valid, use it!
|
||||
}
|
||||
tokenBytes, err = a.cacheClient.Get(ctx, a.timestampedHash(accessToken.RefreshToken))
|
||||
if err == nil {
|
||||
// we found another possibly newer access token associated with the
|
||||
// existing refresh token so let's try that.
|
||||
if err := a.encryptedEncoder.Unmarshal(tokenBytes, &accessToken); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
aud := strings.Split(r.FormValue(urlutil.QueryAudience), ",")
|
||||
routeSession := s.NewSession(r.Host, aud)
|
||||
routeSession.AccessTokenID = s.AccessTokenID
|
||||
return &accessToken, nil
|
||||
}
|
||||
|
||||
signedJWT, err := a.sharedEncoder.Marshal(routeSession.RouteSession())
|
||||
func (a *Authenticate) setAccessToken(ctx context.Context, accessToken *oauth2.Token) error {
|
||||
encToken, err := a.encryptedEncoder.Marshal(accessToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// set this specific access token
|
||||
key := fmt.Sprintf("%x", hashutil.Hash(accessToken))
|
||||
if err := a.cacheClient.Set(ctx, key, encToken); err != nil {
|
||||
return fmt.Errorf("authenticate: setAccessToken failed key: %s :%w", key, err)
|
||||
}
|
||||
|
||||
// set this as the "latest" token for this access token
|
||||
key = a.timestampedHash(accessToken.RefreshToken)
|
||||
if err := a.cacheClient.Set(ctx, key, encToken); err != nil {
|
||||
return fmt.Errorf("authenticate: setAccessToken failed key: %s :%w", key, err)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/jwt") // RFC 7519 : 10.3.1
|
||||
w.Write(signedJWT)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Authenticate) timestampedHash(s string) string {
|
||||
return fmt.Sprintf("%x-%v", hashutil.Hash(s), time.Now().Truncate(time.Minute).Unix())
|
||||
}
|
||||
|
||||
func (a *Authenticate) getSessionFromCtx(ctx context.Context) (*sessions.State, error) {
|
||||
jwt, err := sessions.FromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, httputil.NewError(http.StatusBadRequest, err)
|
||||
}
|
||||
var s sessions.State
|
||||
if err := a.sharedEncoder.Unmarshal([]byte(jwt), &s); err != nil {
|
||||
return nil, httputil.NewError(http.StatusBadRequest, err)
|
||||
}
|
||||
return &s, nil
|
||||
}
|
||||
|
|
|
@ -11,20 +11,24 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/encoding/jws"
|
||||
"github.com/pomerium/pomerium/internal/encoding/mock"
|
||||
"github.com/pomerium/pomerium/internal/frontend"
|
||||
mock_cache "github.com/pomerium/pomerium/internal/grpc/cache/mock"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/identity"
|
||||
"github.com/pomerium/pomerium/internal/identity/oidc"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/sessions/cookie"
|
||||
mstore "github.com/pomerium/pomerium/internal/sessions/mock"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/gorilla/mux"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/oauth2"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
|
@ -115,23 +119,28 @@ func TestAuthenticate_SignIn(t *testing.T) {
|
|||
encoder encoding.MarshalUnmarshaler
|
||||
wantCode int
|
||||
}{
|
||||
{"good", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||
{"good alternate port", "https", "corp.example.example:8443", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||
{"session not valid", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(-10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||
{"bad redirect uri query", "", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "^^^"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
|
||||
{"bad marshal", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
|
||||
{"good", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||
{"good alternate port", "https", "corp.example.example:8443", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||
{"session not valid", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||
{"bad redirect uri query", "", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "^^^"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
|
||||
{"bad marshal", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
|
||||
{"session error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{LoadError: errors.New("error")}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
|
||||
{"good with different programmatic redirect", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||
{"encrypted encoder error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
|
||||
{"good with callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "https://some.example/", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||
{"bad callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "^", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
|
||||
{"good programmatic request", "https", "corp.example.example", map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||
{"good additional audience", "https", "corp.example.example", map[string]string{urlutil.QueryForwardAuth: "x.y.z", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||
{"good user impersonate", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||
{"bad user impersonate save failure", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
|
||||
{"good with different programmatic redirect", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||
{"encrypted encoder error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
|
||||
{"good with callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "https://some.example/", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||
{"bad callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "^", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
|
||||
{"good programmatic request", "https", "corp.example.example", map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||
{"good additional audience", "https", "corp.example.example", map[string]string{urlutil.QueryForwardAuth: "x.y.z", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||
{"good user impersonate", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||
{"bad user impersonate save failure", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mc := mock_cache.NewMockCacher(ctrl)
|
||||
mc.EXPECT().Get(gomock.Any(), gomock.Any()).Return([]byte("hi"), nil).AnyTimes()
|
||||
|
||||
a := &Authenticate{
|
||||
sessionStore: tt.session,
|
||||
provider: tt.provider,
|
||||
|
@ -144,6 +153,7 @@ func TestAuthenticate_SignIn(t *testing.T) {
|
|||
Name: "cookie",
|
||||
Domain: "foo",
|
||||
},
|
||||
cacheClient: mc,
|
||||
}
|
||||
uri := &url.URL{Scheme: tt.scheme, Host: tt.host}
|
||||
|
||||
|
@ -176,6 +186,7 @@ func uriParseHelper(s string) *url.URL {
|
|||
|
||||
func TestAuthenticate_SignOut(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
|
@ -190,19 +201,24 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
|||
wantCode int
|
||||
wantBody string
|
||||
}{
|
||||
{"good post", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{LogOutResponse: (*uriParseHelper("https://microsoft.com"))}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusFound, ""},
|
||||
{"failed revoke", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: OH NO\"}\n"},
|
||||
{"load session error", http.MethodPost, errors.New("error"), "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: error\"}\n"},
|
||||
{"bad redirect uri", http.MethodPost, nil, "corp.pomerium.io/", "sig", "ts", identity.MockProvider{LogOutError: oidc.ErrSignoutNotImplemented}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: corp.pomerium.io/ url does contain a valid scheme\"}\n"},
|
||||
{"good post", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{LogOutResponse: (*uriParseHelper("https://microsoft.com"))}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io"}}, http.StatusFound, ""},
|
||||
{"failed revoke", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io"}}, http.StatusFound, ""},
|
||||
{"load session error", http.MethodPost, errors.New("error"), "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io"}}, http.StatusFound, ""},
|
||||
{"bad redirect uri", http.MethodPost, nil, "corp.pomerium.io/", "sig", "ts", identity.MockProvider{LogOutError: oidc.ErrSignoutNotImplemented}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io"}}, http.StatusFound, ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mc := mock_cache.NewMockCacher(ctrl)
|
||||
mc.EXPECT().Get(gomock.Any(), gomock.Any()).Return([]byte("hi"), nil).AnyTimes()
|
||||
a := &Authenticate{
|
||||
sessionStore: tt.sessionStore,
|
||||
provider: tt.provider,
|
||||
encryptedEncoder: mock.Encoder{},
|
||||
templates: template.Must(frontend.NewTemplates()),
|
||||
sharedEncoder: mock.Encoder{},
|
||||
cacheClient: mc,
|
||||
}
|
||||
u, _ := url.Parse("/sign_out")
|
||||
params, _ := url.ParseQuery(u.RawQuery)
|
||||
|
@ -256,40 +272,50 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
|
|||
want string
|
||||
wantCode int
|
||||
}{
|
||||
{"good", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusFound},
|
||||
{"failed authenticate", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateError: errors.New("error")}, "", http.StatusInternalServerError},
|
||||
{"failed save session", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{SaveError: errors.New("error")}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusInternalServerError},
|
||||
{"provider returned error", http.MethodGet, time.Now().Unix(), "", "", "", "idp error", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest},
|
||||
{"provider returned error imply 401", http.MethodGet, time.Now().Unix(), "", "", "", "access_denied", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusUnauthorized},
|
||||
{"empty code", http.MethodGet, time.Now().Unix(), "", "", "", "", "", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest},
|
||||
{"invalid redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest},
|
||||
{"bad redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "http://^^^", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"bad timing - too soon", http.MethodGet, time.Now().Add(1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"bad timing - expired", http.MethodGet, time.Now().Add(-1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"bad base64", http.MethodGet, time.Now().Unix(), "", "", "^", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"too many seperators", http.MethodGet, time.Now().Unix(), "", "", "|ok|now|what", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"bad hmac", http.MethodGet, time.Now().Unix(), "", "NOTMAC", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"bad hmac", http.MethodGet, time.Now().Unix(), base64.URLEncoding.EncodeToString([]byte("malformed_state")), "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"good", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusFound},
|
||||
{"failed authenticate", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}, AuthenticateError: errors.New("error")}, "", http.StatusInternalServerError},
|
||||
{"failed save session", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{SaveError: errors.New("error")}, identity.MockProvider{}, "", http.StatusInternalServerError},
|
||||
{"provider returned error", http.MethodGet, time.Now().Unix(), "", "", "", "idp error", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "", http.StatusBadRequest},
|
||||
{"provider returned error imply 401", http.MethodGet, time.Now().Unix(), "", "", "", "access_denied", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "", http.StatusUnauthorized},
|
||||
{"empty code", http.MethodGet, time.Now().Unix(), "", "", "", "", "", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "", http.StatusBadRequest},
|
||||
{"invalid redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "", http.StatusBadRequest},
|
||||
{"bad redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "http://^^^", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"bad timing - too soon", http.MethodGet, time.Now().Add(1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"bad timing - expired", http.MethodGet, time.Now().Add(-1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"bad base64", http.MethodGet, time.Now().Unix(), "", "", "^", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"too many seperators", http.MethodGet, time.Now().Unix(), "", "", "|ok|now|what", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"bad hmac", http.MethodGet, time.Now().Unix(), "", "NOTMAC", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"bad hmac", http.MethodGet, time.Now().Unix(), base64.URLEncoding.EncodeToString([]byte("malformed_state")), "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mc := mock_cache.NewMockCacher(ctrl)
|
||||
mc.EXPECT().Get(gomock.Any(), gomock.Any()).Return([]byte("hi"), nil).AnyTimes()
|
||||
mc.EXPECT().Set(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
|
||||
aead, err := chacha20poly1305.NewX(cryptutil.NewKey())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
signer, err := jws.NewHS256Signer(nil, "mock")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
authURL, _ := url.Parse(tt.authenticateURL)
|
||||
a := &Authenticate{
|
||||
RedirectURL: authURL,
|
||||
sessionStore: tt.session,
|
||||
provider: tt.provider,
|
||||
cookieCipher: aead,
|
||||
RedirectURL: authURL,
|
||||
sessionStore: tt.session,
|
||||
provider: tt.provider,
|
||||
cookieCipher: aead,
|
||||
cacheClient: mc,
|
||||
encryptedEncoder: signer,
|
||||
}
|
||||
u, _ := url.Parse("/oauthGet")
|
||||
params, _ := url.ParseQuery(u.RawQuery)
|
||||
params.Add("error", tt.paramErr)
|
||||
params.Add("code", tt.code)
|
||||
nonce := cryptutil.NewBase64Key() // mock csrf
|
||||
|
||||
// (nonce|timestamp|redirect_url|encrypt(redirect_url),mac(nonce,ts))
|
||||
b := []byte(fmt.Sprintf("%s|%d|%s", nonce, tt.ts, tt.extraMac))
|
||||
|
||||
|
@ -336,15 +362,21 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
|||
|
||||
wantStatus int
|
||||
}{
|
||||
{"good", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusOK},
|
||||
{"good", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, http.StatusOK},
|
||||
{"invalid session", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound},
|
||||
{"good refresh expired", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusOK},
|
||||
{"good refresh expired", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, http.StatusOK},
|
||||
{"expired,refresh error", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusFound},
|
||||
{"expired,save error", nil, &mstore.Store{SaveError: errors.New("error"), Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusFound},
|
||||
{"expired,save error", nil, &mstore.Store{SaveError: errors.New("error"), Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, http.StatusFound},
|
||||
{"expired XHR,refresh error", map[string]string{"X-Requested-With": "XmlHttpRequest"}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusUnauthorized},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mc := mock_cache.NewMockCacher(ctrl)
|
||||
mc.EXPECT().Get(gomock.Any(), gomock.Any()).Return([]byte("hi"), nil).AnyTimes()
|
||||
mc.EXPECT().Set(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
|
||||
|
||||
aead, err := chacha20poly1305.NewX(cryptutil.NewKey())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -361,6 +393,8 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
|||
provider: tt.provider,
|
||||
cookieCipher: aead,
|
||||
encryptedEncoder: signer,
|
||||
cacheClient: mc,
|
||||
sharedEncoder: mock.Encoder{},
|
||||
}
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
state, err := tt.session.LoadSession(r)
|
||||
|
@ -402,14 +436,20 @@ func TestAuthenticate_RefreshAPI(t *testing.T) {
|
|||
|
||||
wantStatus int
|
||||
}{
|
||||
{"good", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusOK},
|
||||
{"good", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusOK},
|
||||
{"refresh error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshError: errors.New("error")}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError},
|
||||
{"session is not refreshable error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("session error"), identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusBadRequest},
|
||||
{"secret encoder failed", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalError: errors.New("error")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError},
|
||||
{"shared encoder failed", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalError: errors.New("error")}, http.StatusInternalServerError},
|
||||
{"session is not refreshable error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("session error"), identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusBadRequest},
|
||||
{"secret encoder failed", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, mock.Encoder{MarshalError: errors.New("error")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError},
|
||||
{"shared encoder failed", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalError: errors.New("error")}, http.StatusInternalServerError},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mc := mock_cache.NewMockCacher(ctrl)
|
||||
mc.EXPECT().Get(gomock.Any(), gomock.Any()).Return([]byte("hi"), nil).AnyTimes()
|
||||
mc.EXPECT().Set(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
|
||||
|
||||
aead, err := chacha20poly1305.NewX(cryptutil.NewKey())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -423,6 +463,7 @@ func TestAuthenticate_RefreshAPI(t *testing.T) {
|
|||
sessionStore: tt.session,
|
||||
provider: tt.provider,
|
||||
cookieCipher: aead,
|
||||
cacheClient: mc,
|
||||
}
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
state, _ := tt.session.LoadSession(r)
|
||||
|
@ -441,53 +482,111 @@ func TestAuthenticate_RefreshAPI(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthenticate_Refresh(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
session sessions.SessionStore
|
||||
ctxError error
|
||||
session *sessions.State
|
||||
at *oauth2.Token
|
||||
|
||||
provider identity.Authenticator
|
||||
secretEncoder encoding.MarshalUnmarshaler
|
||||
sharedEncoder encoding.MarshalUnmarshaler
|
||||
|
||||
wantStatus int
|
||||
}{
|
||||
{"good", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusOK},
|
||||
{"bad session", &mstore.Store{}, errors.New("err"), identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusBadRequest},
|
||||
{"encoder error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalError: errors.New("err")}, http.StatusInternalServerError},
|
||||
{"good",
|
||||
&sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))},
|
||||
&oauth2.Token{AccessToken: "mock", Expiry: time.Now().Add(10 * time.Minute)},
|
||||
identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
|
||||
mock.Encoder{MarshalResponse: []byte("ok")},
|
||||
200},
|
||||
{"session and oauth2 expired",
|
||||
&sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))},
|
||||
&oauth2.Token{AccessToken: "mock", Expiry: time.Now().Add(-10 * time.Minute)},
|
||||
identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
|
||||
mock.Encoder{MarshalResponse: []byte("ok")},
|
||||
200},
|
||||
{"session expired",
|
||||
&sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))},
|
||||
&oauth2.Token{AccessToken: "mock", Expiry: time.Now().Add(10 * time.Minute)},
|
||||
identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
|
||||
mock.Encoder{MarshalResponse: []byte("ok")},
|
||||
200},
|
||||
{"failed refresh",
|
||||
&sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))},
|
||||
&oauth2.Token{AccessToken: "mock", Expiry: time.Now().Add(10 * time.Minute)},
|
||||
identity.MockProvider{RefreshError: errors.New("oh no")},
|
||||
mock.Encoder{MarshalResponse: []byte("ok")},
|
||||
302},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
aead, err := chacha20poly1305.NewX(cryptutil.NewKey())
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mc := mock_cache.NewMockCacher(ctrl)
|
||||
// just enough is stubbed out here so we can use our own mock provider
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
out := fmt.Sprintf(`{"issuer":"http://%s"}`, r.Host)
|
||||
fmt.Fprintln(w, out)
|
||||
}))
|
||||
defer ts.Close()
|
||||
rURL := ts.URL
|
||||
a, err := New(config.Options{
|
||||
SharedKey: cryptutil.NewBase64Key(),
|
||||
CookieSecret: cryptutil.NewBase64Key(),
|
||||
AuthenticateURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"),
|
||||
Provider: "oidc",
|
||||
ClientID: "mock",
|
||||
ClientSecret: "mock",
|
||||
ProviderURL: rURL,
|
||||
AuthenticateCallbackPath: "mock",
|
||||
CookieName: "pomerium",
|
||||
Addr: ":0",
|
||||
CacheURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
a := Authenticate{
|
||||
sharedKey: cryptutil.NewBase64Key(),
|
||||
cookieSecret: cryptutil.NewKey(),
|
||||
RedirectURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"),
|
||||
encryptedEncoder: tt.secretEncoder,
|
||||
sharedEncoder: tt.sharedEncoder,
|
||||
sessionStore: tt.session,
|
||||
provider: tt.provider,
|
||||
cookieCipher: aead,
|
||||
}
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
state, _ := tt.session.LoadSession(r)
|
||||
ctx := r.Context()
|
||||
ctx = sessions.NewContext(ctx, state, tt.ctxError)
|
||||
r = r.WithContext(ctx)
|
||||
a.cacheClient = mc
|
||||
a.provider = tt.provider
|
||||
|
||||
u, _ := url.Parse("/oauthGet")
|
||||
params, _ := url.ParseQuery(u.RawQuery)
|
||||
destination := urlutil.NewSignedURL(a.sharedKey,
|
||||
&url.URL{
|
||||
Scheme: "https",
|
||||
Host: "example.com",
|
||||
Path: "/.pomerium/refresh"})
|
||||
|
||||
u.RawQuery = params.Encode()
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, destination.String(), nil)
|
||||
|
||||
jwt, err := a.sharedEncoder.Marshal(tt.session)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rawToken, err := a.encryptedEncoder.Marshal(tt.at)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
mc.EXPECT().Get(gomock.Any(), gomock.Any()).Return(rawToken, nil).AnyTimes()
|
||||
mc.EXPECT().Set(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
|
||||
a.cacheClient = mc
|
||||
|
||||
r.Header.Set("Authorization", fmt.Sprintf("Pomerium %s", jwt))
|
||||
r.Header.Set("Accept", "application/json")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
httputil.HandlerFunc(a.Refresh).ServeHTTP(w, r)
|
||||
router := mux.NewRouter()
|
||||
a.Mount(router)
|
||||
router.ServeHTTP(w, r)
|
||||
if status := w.Code; status != tt.wantStatus {
|
||||
t.Errorf("VerifySession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
|
||||
|
||||
t.Errorf("Refresh() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
//go:generate mockgen -destination mock_evaluator/mock.go github.com/pomerium/pomerium/authorize/evaluator Evaluator
|
||||
|
||||
// Package evaluator defines a Evaluator interfaces that can be implemented by
|
||||
// a policy evaluator framework.
|
||||
package evaluator
|
||||
|
|
|
@ -218,10 +218,6 @@ func (a *Authorize) refreshSession(ctx context.Context, rawSession []byte) (newS
|
|||
|
||||
// 1 - build a signed url to call refresh on authenticate service
|
||||
refreshURI := options.AuthenticateURL.ResolveReference(&url.URL{Path: "/.pomerium/refresh"})
|
||||
q := refreshURI.Query()
|
||||
q.Set(urlutil.QueryAccessTokenID, state.AccessTokenID) // hash value points to parent token
|
||||
q.Set(urlutil.QueryAudience, strings.Join(state.Audience, ",")) // request's audience, this route
|
||||
refreshURI.RawQuery = q.Encode()
|
||||
signedRefreshURL := urlutil.NewSignedURL(options.SharedKey, refreshURI).String()
|
||||
|
||||
// 2 - http call to authenticate service
|
||||
|
@ -229,6 +225,7 @@ func (a *Authorize) refreshSession(ctx context.Context, rawSession []byte) (newS
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("authorize: refresh request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Pomerium %s", rawSession))
|
||||
req.Header.Set("X-Requested-With", "XmlHttpRequest")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
|
|
30
cache/grpc_test.go
vendored
30
cache/grpc_test.go
vendored
|
@ -11,6 +11,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/pomerium/pomerium/internal/grpc/cache"
|
||||
|
@ -48,11 +49,8 @@ func TestCache_Get_and_Set(t *testing.T) {
|
|||
&cache.SetReply{},
|
||||
&cache.GetRequest{Key: "key"},
|
||||
&cache.GetReply{
|
||||
Exists: true,
|
||||
Value: []byte("hello"),
|
||||
XXX_NoUnkeyedLiteral: struct{}{},
|
||||
XXX_unrecognized: nil,
|
||||
XXX_sizecache: 0,
|
||||
Exists: true,
|
||||
Value: []byte("hello"),
|
||||
},
|
||||
false,
|
||||
false,
|
||||
|
@ -63,11 +61,8 @@ func TestCache_Get_and_Set(t *testing.T) {
|
|||
&cache.SetReply{},
|
||||
&cache.GetRequest{Key: "no-such-key"},
|
||||
&cache.GetReply{
|
||||
Exists: false,
|
||||
Value: nil,
|
||||
XXX_NoUnkeyedLiteral: struct{}{},
|
||||
XXX_unrecognized: nil,
|
||||
XXX_sizecache: 0,
|
||||
Exists: false,
|
||||
Value: nil,
|
||||
},
|
||||
false,
|
||||
false,
|
||||
|
@ -78,11 +73,8 @@ func TestCache_Get_and_Set(t *testing.T) {
|
|||
nil,
|
||||
&cache.GetRequest{Key: hugeKey},
|
||||
&cache.GetReply{
|
||||
Exists: false,
|
||||
Value: nil,
|
||||
XXX_NoUnkeyedLiteral: struct{}{},
|
||||
XXX_unrecognized: nil,
|
||||
XXX_sizecache: 0,
|
||||
Exists: false,
|
||||
Value: nil,
|
||||
},
|
||||
true,
|
||||
false,
|
||||
|
@ -96,7 +88,11 @@ func TestCache_Get_and_Set(t *testing.T) {
|
|||
t.Errorf("Cache.Set() error = %v, wantSetError %v", err, tt.wantSetError)
|
||||
return
|
||||
}
|
||||
if diff := cmp.Diff(setGot, tt.SetReply); diff != "" {
|
||||
cmpOpts := []cmp.Option{
|
||||
cmpopts.IgnoreUnexported(cache.SetReply{}, cache.GetReply{}),
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(setGot, tt.SetReply, cmpOpts...); diff != "" {
|
||||
t.Errorf("Cache.Set() = %v", diff)
|
||||
}
|
||||
getGot, err := c.Get(tt.ctx, tt.GetRequest)
|
||||
|
@ -104,7 +100,7 @@ func TestCache_Get_and_Set(t *testing.T) {
|
|||
t.Errorf("Cache.Get() error = %v, wantGetError %v", err, tt.wantGetError)
|
||||
return
|
||||
}
|
||||
if diff := cmp.Diff(getGot, tt.GetReply); diff != "" {
|
||||
if diff := cmp.Diff(getGot, tt.GetReply, cmpOpts...); diff != "" {
|
||||
t.Errorf("Cache.Get() = %v", diff)
|
||||
}
|
||||
})
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
|
||||
envoy_config_listener_v3 "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3"
|
||||
|
@ -151,6 +152,7 @@ func (srv *Server) buildMainHTTPConnectionManagerFilter(options *config.Options,
|
|||
},
|
||||
Services: &envoy_extensions_filters_http_ext_authz_v3.ExtAuthz_GrpcService{
|
||||
GrpcService: &envoy_config_core_v3.GrpcService{
|
||||
Timeout: ptypes.DurationProto(time.Second * 30),
|
||||
TargetSpecifier: &envoy_config_core_v3.GrpcService_EnvoyGrpc_{
|
||||
EnvoyGrpc: &envoy_config_core_v3.GrpcService_EnvoyGrpc{
|
||||
ClusterName: "pomerium-authz",
|
||||
|
|
14
internal/grpc/cache/cache.go
vendored
Normal file
14
internal/grpc/cache/cache.go
vendored
Normal file
|
@ -0,0 +1,14 @@
|
|||
// Package cache defines a Cacher interfaces that can be implemented by any
|
||||
// key value store system.
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// Cacher specifies an interface for remote clients connecting to the cache service.
|
||||
type Cacher interface {
|
||||
Get(ctx context.Context, key string) (value []byte, err error)
|
||||
Set(ctx context.Context, key string, value []byte) error
|
||||
Close() error
|
||||
}
|
439
internal/grpc/cache/cache.pb.go
vendored
439
internal/grpc/cache/cache.pb.go
vendored
|
@ -1,217 +1,356 @@
|
|||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.21.0
|
||||
// protoc v3.11.4
|
||||
// source: cache.proto
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
context "context"
|
||||
fmt "fmt"
|
||||
proto "github.com/golang/protobuf/proto"
|
||||
grpc "google.golang.org/grpc"
|
||||
codes "google.golang.org/grpc/codes"
|
||||
status "google.golang.org/grpc/status"
|
||||
math "math"
|
||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||
reflect "reflect"
|
||||
sync "sync"
|
||||
)
|
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ = proto.Marshal
|
||||
var _ = fmt.Errorf
|
||||
var _ = math.Inf
|
||||
const (
|
||||
// Verify that this generated code is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
|
||||
// Verify that runtime/protoimpl is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||
)
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the proto package it is being compiled against.
|
||||
// A compilation error at this line likely means your copy of the
|
||||
// proto package needs to be updated.
|
||||
const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
|
||||
// This is a compile-time assertion that a sufficiently up-to-date version
|
||||
// of the legacy proto package is being used.
|
||||
const _ = proto.ProtoPackageIsVersion4
|
||||
|
||||
type GetRequest struct {
|
||||
Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"`
|
||||
XXX_NoUnkeyedLiteral struct{} `json:"-"`
|
||||
XXX_unrecognized []byte `json:"-"`
|
||||
XXX_sizecache int32 `json:"-"`
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"`
|
||||
}
|
||||
|
||||
func (m *GetRequest) Reset() { *m = GetRequest{} }
|
||||
func (m *GetRequest) String() string { return proto.CompactTextString(m) }
|
||||
func (*GetRequest) ProtoMessage() {}
|
||||
func (x *GetRequest) Reset() {
|
||||
*x = GetRequest{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_cache_proto_msgTypes[0]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *GetRequest) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*GetRequest) ProtoMessage() {}
|
||||
|
||||
func (x *GetRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_cache_proto_msgTypes[0]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use GetRequest.ProtoReflect.Descriptor instead.
|
||||
func (*GetRequest) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_5fca3b110c9bbf3a, []int{0}
|
||||
return file_cache_proto_rawDescGZIP(), []int{0}
|
||||
}
|
||||
|
||||
func (m *GetRequest) XXX_Unmarshal(b []byte) error {
|
||||
return xxx_messageInfo_GetRequest.Unmarshal(m, b)
|
||||
}
|
||||
func (m *GetRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
|
||||
return xxx_messageInfo_GetRequest.Marshal(b, m, deterministic)
|
||||
}
|
||||
func (m *GetRequest) XXX_Merge(src proto.Message) {
|
||||
xxx_messageInfo_GetRequest.Merge(m, src)
|
||||
}
|
||||
func (m *GetRequest) XXX_Size() int {
|
||||
return xxx_messageInfo_GetRequest.Size(m)
|
||||
}
|
||||
func (m *GetRequest) XXX_DiscardUnknown() {
|
||||
xxx_messageInfo_GetRequest.DiscardUnknown(m)
|
||||
}
|
||||
|
||||
var xxx_messageInfo_GetRequest proto.InternalMessageInfo
|
||||
|
||||
func (m *GetRequest) GetKey() string {
|
||||
if m != nil {
|
||||
return m.Key
|
||||
func (x *GetRequest) GetKey() string {
|
||||
if x != nil {
|
||||
return x.Key
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type GetReply struct {
|
||||
Exists bool `protobuf:"varint,1,opt,name=exists,proto3" json:"exists,omitempty"`
|
||||
Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"`
|
||||
XXX_NoUnkeyedLiteral struct{} `json:"-"`
|
||||
XXX_unrecognized []byte `json:"-"`
|
||||
XXX_sizecache int32 `json:"-"`
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Exists bool `protobuf:"varint,1,opt,name=exists,proto3" json:"exists,omitempty"`
|
||||
Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"`
|
||||
}
|
||||
|
||||
func (m *GetReply) Reset() { *m = GetReply{} }
|
||||
func (m *GetReply) String() string { return proto.CompactTextString(m) }
|
||||
func (*GetReply) ProtoMessage() {}
|
||||
func (x *GetReply) Reset() {
|
||||
*x = GetReply{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_cache_proto_msgTypes[1]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *GetReply) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*GetReply) ProtoMessage() {}
|
||||
|
||||
func (x *GetReply) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_cache_proto_msgTypes[1]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use GetReply.ProtoReflect.Descriptor instead.
|
||||
func (*GetReply) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_5fca3b110c9bbf3a, []int{1}
|
||||
return file_cache_proto_rawDescGZIP(), []int{1}
|
||||
}
|
||||
|
||||
func (m *GetReply) XXX_Unmarshal(b []byte) error {
|
||||
return xxx_messageInfo_GetReply.Unmarshal(m, b)
|
||||
}
|
||||
func (m *GetReply) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
|
||||
return xxx_messageInfo_GetReply.Marshal(b, m, deterministic)
|
||||
}
|
||||
func (m *GetReply) XXX_Merge(src proto.Message) {
|
||||
xxx_messageInfo_GetReply.Merge(m, src)
|
||||
}
|
||||
func (m *GetReply) XXX_Size() int {
|
||||
return xxx_messageInfo_GetReply.Size(m)
|
||||
}
|
||||
func (m *GetReply) XXX_DiscardUnknown() {
|
||||
xxx_messageInfo_GetReply.DiscardUnknown(m)
|
||||
}
|
||||
|
||||
var xxx_messageInfo_GetReply proto.InternalMessageInfo
|
||||
|
||||
func (m *GetReply) GetExists() bool {
|
||||
if m != nil {
|
||||
return m.Exists
|
||||
func (x *GetReply) GetExists() bool {
|
||||
if x != nil {
|
||||
return x.Exists
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *GetReply) GetValue() []byte {
|
||||
if m != nil {
|
||||
return m.Value
|
||||
func (x *GetReply) GetValue() []byte {
|
||||
if x != nil {
|
||||
return x.Value
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type SetRequest struct {
|
||||
Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"`
|
||||
Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"`
|
||||
XXX_NoUnkeyedLiteral struct{} `json:"-"`
|
||||
XXX_unrecognized []byte `json:"-"`
|
||||
XXX_sizecache int32 `json:"-"`
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"`
|
||||
Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"`
|
||||
}
|
||||
|
||||
func (m *SetRequest) Reset() { *m = SetRequest{} }
|
||||
func (m *SetRequest) String() string { return proto.CompactTextString(m) }
|
||||
func (*SetRequest) ProtoMessage() {}
|
||||
func (x *SetRequest) Reset() {
|
||||
*x = SetRequest{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_cache_proto_msgTypes[2]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *SetRequest) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*SetRequest) ProtoMessage() {}
|
||||
|
||||
func (x *SetRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_cache_proto_msgTypes[2]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use SetRequest.ProtoReflect.Descriptor instead.
|
||||
func (*SetRequest) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_5fca3b110c9bbf3a, []int{2}
|
||||
return file_cache_proto_rawDescGZIP(), []int{2}
|
||||
}
|
||||
|
||||
func (m *SetRequest) XXX_Unmarshal(b []byte) error {
|
||||
return xxx_messageInfo_SetRequest.Unmarshal(m, b)
|
||||
}
|
||||
func (m *SetRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
|
||||
return xxx_messageInfo_SetRequest.Marshal(b, m, deterministic)
|
||||
}
|
||||
func (m *SetRequest) XXX_Merge(src proto.Message) {
|
||||
xxx_messageInfo_SetRequest.Merge(m, src)
|
||||
}
|
||||
func (m *SetRequest) XXX_Size() int {
|
||||
return xxx_messageInfo_SetRequest.Size(m)
|
||||
}
|
||||
func (m *SetRequest) XXX_DiscardUnknown() {
|
||||
xxx_messageInfo_SetRequest.DiscardUnknown(m)
|
||||
}
|
||||
|
||||
var xxx_messageInfo_SetRequest proto.InternalMessageInfo
|
||||
|
||||
func (m *SetRequest) GetKey() string {
|
||||
if m != nil {
|
||||
return m.Key
|
||||
func (x *SetRequest) GetKey() string {
|
||||
if x != nil {
|
||||
return x.Key
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *SetRequest) GetValue() []byte {
|
||||
if m != nil {
|
||||
return m.Value
|
||||
func (x *SetRequest) GetValue() []byte {
|
||||
if x != nil {
|
||||
return x.Value
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type SetReply struct {
|
||||
XXX_NoUnkeyedLiteral struct{} `json:"-"`
|
||||
XXX_unrecognized []byte `json:"-"`
|
||||
XXX_sizecache int32 `json:"-"`
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
}
|
||||
|
||||
func (m *SetReply) Reset() { *m = SetReply{} }
|
||||
func (m *SetReply) String() string { return proto.CompactTextString(m) }
|
||||
func (*SetReply) ProtoMessage() {}
|
||||
func (x *SetReply) Reset() {
|
||||
*x = SetReply{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_cache_proto_msgTypes[3]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *SetReply) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*SetReply) ProtoMessage() {}
|
||||
|
||||
func (x *SetReply) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_cache_proto_msgTypes[3]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use SetReply.ProtoReflect.Descriptor instead.
|
||||
func (*SetReply) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_5fca3b110c9bbf3a, []int{3}
|
||||
return file_cache_proto_rawDescGZIP(), []int{3}
|
||||
}
|
||||
|
||||
func (m *SetReply) XXX_Unmarshal(b []byte) error {
|
||||
return xxx_messageInfo_SetReply.Unmarshal(m, b)
|
||||
}
|
||||
func (m *SetReply) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
|
||||
return xxx_messageInfo_SetReply.Marshal(b, m, deterministic)
|
||||
}
|
||||
func (m *SetReply) XXX_Merge(src proto.Message) {
|
||||
xxx_messageInfo_SetReply.Merge(m, src)
|
||||
}
|
||||
func (m *SetReply) XXX_Size() int {
|
||||
return xxx_messageInfo_SetReply.Size(m)
|
||||
}
|
||||
func (m *SetReply) XXX_DiscardUnknown() {
|
||||
xxx_messageInfo_SetReply.DiscardUnknown(m)
|
||||
var File_cache_proto protoreflect.FileDescriptor
|
||||
|
||||
var file_cache_proto_rawDesc = []byte{
|
||||
0x0a, 0x0b, 0x63, 0x61, 0x63, 0x68, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x05, 0x63,
|
||||
0x61, 0x63, 0x68, 0x65, 0x22, 0x1e, 0x0a, 0x0a, 0x47, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65,
|
||||
0x73, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52,
|
||||
0x03, 0x6b, 0x65, 0x79, 0x22, 0x38, 0x0a, 0x08, 0x47, 0x65, 0x74, 0x52, 0x65, 0x70, 0x6c, 0x79,
|
||||
0x12, 0x16, 0x0a, 0x06, 0x65, 0x78, 0x69, 0x73, 0x74, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08,
|
||||
0x52, 0x06, 0x65, 0x78, 0x69, 0x73, 0x74, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75,
|
||||
0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x22, 0x34,
|
||||
0x0a, 0x0a, 0x53, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x10, 0x0a, 0x03,
|
||||
0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14,
|
||||
0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x05, 0x76,
|
||||
0x61, 0x6c, 0x75, 0x65, 0x22, 0x0a, 0x0a, 0x08, 0x53, 0x65, 0x74, 0x52, 0x65, 0x70, 0x6c, 0x79,
|
||||
0x32, 0x61, 0x0a, 0x05, 0x43, 0x61, 0x63, 0x68, 0x65, 0x12, 0x2b, 0x0a, 0x03, 0x47, 0x65, 0x74,
|
||||
0x12, 0x11, 0x2e, 0x63, 0x61, 0x63, 0x68, 0x65, 0x2e, 0x47, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75,
|
||||
0x65, 0x73, 0x74, 0x1a, 0x0f, 0x2e, 0x63, 0x61, 0x63, 0x68, 0x65, 0x2e, 0x47, 0x65, 0x74, 0x52,
|
||||
0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x2b, 0x0a, 0x03, 0x53, 0x65, 0x74, 0x12, 0x11, 0x2e,
|
||||
0x63, 0x61, 0x63, 0x68, 0x65, 0x2e, 0x53, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
|
||||
0x1a, 0x0f, 0x2e, 0x63, 0x61, 0x63, 0x68, 0x65, 0x2e, 0x53, 0x65, 0x74, 0x52, 0x65, 0x70, 0x6c,
|
||||
0x79, 0x22, 0x00, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
}
|
||||
|
||||
var xxx_messageInfo_SetReply proto.InternalMessageInfo
|
||||
var (
|
||||
file_cache_proto_rawDescOnce sync.Once
|
||||
file_cache_proto_rawDescData = file_cache_proto_rawDesc
|
||||
)
|
||||
|
||||
func init() {
|
||||
proto.RegisterType((*GetRequest)(nil), "cache.GetRequest")
|
||||
proto.RegisterType((*GetReply)(nil), "cache.GetReply")
|
||||
proto.RegisterType((*SetRequest)(nil), "cache.SetRequest")
|
||||
proto.RegisterType((*SetReply)(nil), "cache.SetReply")
|
||||
func file_cache_proto_rawDescGZIP() []byte {
|
||||
file_cache_proto_rawDescOnce.Do(func() {
|
||||
file_cache_proto_rawDescData = protoimpl.X.CompressGZIP(file_cache_proto_rawDescData)
|
||||
})
|
||||
return file_cache_proto_rawDescData
|
||||
}
|
||||
|
||||
func init() {
|
||||
proto.RegisterFile("cache.proto", fileDescriptor_5fca3b110c9bbf3a)
|
||||
var file_cache_proto_msgTypes = make([]protoimpl.MessageInfo, 4)
|
||||
var file_cache_proto_goTypes = []interface{}{
|
||||
(*GetRequest)(nil), // 0: cache.GetRequest
|
||||
(*GetReply)(nil), // 1: cache.GetReply
|
||||
(*SetRequest)(nil), // 2: cache.SetRequest
|
||||
(*SetReply)(nil), // 3: cache.SetReply
|
||||
}
|
||||
var file_cache_proto_depIdxs = []int32{
|
||||
0, // 0: cache.Cache.Get:input_type -> cache.GetRequest
|
||||
2, // 1: cache.Cache.Set:input_type -> cache.SetRequest
|
||||
1, // 2: cache.Cache.Get:output_type -> cache.GetReply
|
||||
3, // 3: cache.Cache.Set:output_type -> cache.SetReply
|
||||
2, // [2:4] is the sub-list for method output_type
|
||||
0, // [0:2] is the sub-list for method input_type
|
||||
0, // [0:0] is the sub-list for extension type_name
|
||||
0, // [0:0] is the sub-list for extension extendee
|
||||
0, // [0:0] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
var fileDescriptor_5fca3b110c9bbf3a = []byte{
|
||||
// 176 bytes of a gzipped FileDescriptorProto
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x4e, 0x4e, 0x4c, 0xce,
|
||||
0x48, 0xd5, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x05, 0x73, 0x94, 0xe4, 0xb8, 0xb8, 0xdc,
|
||||
0x53, 0x4b, 0x82, 0x52, 0x0b, 0x4b, 0x53, 0x8b, 0x4b, 0x84, 0x04, 0xb8, 0x98, 0xb3, 0x53, 0x2b,
|
||||
0x25, 0x18, 0x15, 0x18, 0x35, 0x38, 0x83, 0x40, 0x4c, 0x25, 0x0b, 0x2e, 0x0e, 0xb0, 0x7c, 0x41,
|
||||
0x4e, 0xa5, 0x90, 0x18, 0x17, 0x5b, 0x6a, 0x45, 0x66, 0x71, 0x49, 0x31, 0x58, 0x01, 0x47, 0x10,
|
||||
0x94, 0x27, 0x24, 0xc2, 0xc5, 0x5a, 0x96, 0x98, 0x53, 0x9a, 0x2a, 0xc1, 0xa4, 0xc0, 0xa8, 0xc1,
|
||||
0x13, 0x04, 0xe1, 0x28, 0x99, 0x70, 0x71, 0x05, 0xe3, 0x31, 0x19, 0x87, 0x2e, 0x2e, 0x2e, 0x8e,
|
||||
0x60, 0xa8, 0x7d, 0x46, 0x89, 0x5c, 0xac, 0xce, 0x20, 0x47, 0x0a, 0x69, 0x73, 0x31, 0xbb, 0xa7,
|
||||
0x96, 0x08, 0x09, 0xea, 0x41, 0x3c, 0x80, 0x70, 0xb0, 0x14, 0x3f, 0xb2, 0x50, 0x41, 0x4e, 0xa5,
|
||||
0x12, 0x03, 0x48, 0x71, 0x30, 0x92, 0xe2, 0x60, 0x4c, 0xc5, 0xc1, 0x70, 0xc5, 0x49, 0x6c, 0xe0,
|
||||
0xc0, 0x30, 0x06, 0x04, 0x00, 0x00, 0xff, 0xff, 0x0e, 0xef, 0x5f, 0x9e, 0x1b, 0x01, 0x00, 0x00,
|
||||
func init() { file_cache_proto_init() }
|
||||
func file_cache_proto_init() {
|
||||
if File_cache_proto != nil {
|
||||
return
|
||||
}
|
||||
if !protoimpl.UnsafeEnabled {
|
||||
file_cache_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*GetRequest); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_cache_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*GetReply); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_cache_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*SetRequest); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_cache_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*SetReply); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
File: protoimpl.DescBuilder{
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: file_cache_proto_rawDesc,
|
||||
NumEnums: 0,
|
||||
NumMessages: 4,
|
||||
NumExtensions: 0,
|
||||
NumServices: 1,
|
||||
},
|
||||
GoTypes: file_cache_proto_goTypes,
|
||||
DependencyIndexes: file_cache_proto_depIdxs,
|
||||
MessageInfos: file_cache_proto_msgTypes,
|
||||
}.Build()
|
||||
File_cache_proto = out.File
|
||||
file_cache_proto_rawDesc = nil
|
||||
file_cache_proto_goTypes = nil
|
||||
file_cache_proto_depIdxs = nil
|
||||
}
|
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
|
@ -266,10 +405,10 @@ type CacheServer interface {
|
|||
type UnimplementedCacheServer struct {
|
||||
}
|
||||
|
||||
func (*UnimplementedCacheServer) Get(ctx context.Context, req *GetRequest) (*GetReply, error) {
|
||||
func (*UnimplementedCacheServer) Get(context.Context, *GetRequest) (*GetReply, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method Get not implemented")
|
||||
}
|
||||
func (*UnimplementedCacheServer) Set(ctx context.Context, req *SetRequest) (*SetReply, error) {
|
||||
func (*UnimplementedCacheServer) Set(context.Context, *SetRequest) (*SetReply, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method Set not implemented")
|
||||
}
|
||||
|
||||
|
|
17
internal/grpc/cache/client/cache_client.go
vendored
17
internal/grpc/cache/client/cache_client.go
vendored
|
@ -3,6 +3,7 @@ package client
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/grpc/cache"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
|
@ -10,12 +11,7 @@ import (
|
|||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
// Cacher specifies an interface for remote clients connecting to the cache service.
|
||||
type Cacher interface {
|
||||
Get(ctx context.Context, key string) (keyExists bool, value []byte, err error)
|
||||
Set(ctx context.Context, key string, value []byte) error
|
||||
Close() error
|
||||
}
|
||||
var errKeyNotFound = errors.New("cache/client: key not found")
|
||||
|
||||
// Client represents a gRPC cache service client.
|
||||
type Client struct {
|
||||
|
@ -29,15 +25,18 @@ func New(conn *grpc.ClientConn) (p *Client) {
|
|||
}
|
||||
|
||||
// Get retrieves a value from the cache service.
|
||||
func (a *Client) Get(ctx context.Context, key string) (keyExists bool, value []byte, err error) {
|
||||
func (a *Client) Get(ctx context.Context, key string) (value []byte, err error) {
|
||||
ctx, span := trace.StartSpan(ctx, "grpc.cache.client.Get")
|
||||
defer span.End()
|
||||
|
||||
response, err := a.client.Get(ctx, &cache.GetRequest{Key: key})
|
||||
if err != nil {
|
||||
return false, nil, err
|
||||
return nil, err
|
||||
}
|
||||
return response.GetExists(), response.GetValue(), nil
|
||||
if !response.GetExists() {
|
||||
return nil, errKeyNotFound
|
||||
}
|
||||
return response.GetValue(), nil
|
||||
}
|
||||
|
||||
// Set stores a key value pair in the cache service.
|
||||
|
|
77
internal/grpc/cache/mock/mock_cacher.go
vendored
Normal file
77
internal/grpc/cache/mock/mock_cacher.go
vendored
Normal file
|
@ -0,0 +1,77 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/pomerium/pomerium/internal/grpc/cache (interfaces: Cacher)
|
||||
|
||||
// Package mock_cache is a generated GoMock package.
|
||||
package mock_cache
|
||||
|
||||
import (
|
||||
context "context"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// MockCacher is a mock of Cacher interface
|
||||
type MockCacher struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockCacherMockRecorder
|
||||
}
|
||||
|
||||
// MockCacherMockRecorder is the mock recorder for MockCacher
|
||||
type MockCacherMockRecorder struct {
|
||||
mock *MockCacher
|
||||
}
|
||||
|
||||
// NewMockCacher creates a new mock instance
|
||||
func NewMockCacher(ctrl *gomock.Controller) *MockCacher {
|
||||
mock := &MockCacher{ctrl: ctrl}
|
||||
mock.recorder = &MockCacherMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockCacher) EXPECT() *MockCacherMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Close mocks base method
|
||||
func (m *MockCacher) Close() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Close")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Close indicates an expected call of Close
|
||||
func (mr *MockCacherMockRecorder) Close() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockCacher)(nil).Close))
|
||||
}
|
||||
|
||||
// Get mocks base method
|
||||
func (m *MockCacher) Get(arg0 context.Context, arg1 string) ([]byte, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Get", arg0, arg1)
|
||||
ret0, _ := ret[0].([]byte)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Get indicates an expected call of Get
|
||||
func (mr *MockCacherMockRecorder) Get(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockCacher)(nil).Get), arg0, arg1)
|
||||
}
|
||||
|
||||
// Set mocks base method
|
||||
func (m *MockCacher) Set(arg0 context.Context, arg1 string, arg2 []byte) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Set", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Set indicates an expected call of Set
|
||||
func (mr *MockCacherMockRecorder) Set(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Set", reflect.TypeOf((*MockCacher)(nil).Set), arg0, arg1, arg2)
|
||||
}
|
22
internal/hashutil/hashutil.go
Normal file
22
internal/hashutil/hashutil.go
Normal file
|
@ -0,0 +1,22 @@
|
|||
// Package hashutil provides NON-CRYPTOGRAPHIC utility functions for hashing
|
||||
package hashutil
|
||||
|
||||
import (
|
||||
"github.com/cespare/xxhash/v2"
|
||||
"github.com/mitchellh/hashstructure"
|
||||
)
|
||||
|
||||
// Hash returns the xxhash value of an arbitrary value or struct. Returns 0
|
||||
// on error. NOT SUITABLE FOR CRYTOGRAPHIC HASHING.
|
||||
//
|
||||
// http://cyan4973.github.io/xxHash/
|
||||
func Hash(v interface{}) uint64 {
|
||||
opts := &hashstructure.HashOptions{
|
||||
Hasher: xxhash.New(),
|
||||
}
|
||||
hash, err := hashstructure.Hash(v, opts)
|
||||
if err != nil {
|
||||
hash = 0
|
||||
}
|
||||
return hash
|
||||
}
|
37
internal/hashutil/hashutil_test.go
Normal file
37
internal/hashutil/hashutil_test.go
Normal file
|
@ -0,0 +1,37 @@
|
|||
// Package hashutil provides NON-CRYPTOGRAPHIC utility functions for hashing
|
||||
package hashutil
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestHash(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
v interface{}
|
||||
want uint64
|
||||
}{
|
||||
{"string", "string", 6134271061086542852},
|
||||
{"num", 7, 609900476111905877},
|
||||
{"compound struct", struct {
|
||||
NESCarts []string
|
||||
numberOfCarts int
|
||||
}{
|
||||
[]string{"Battletoads", "Mega Man 1", "Clash at Demonhead"},
|
||||
12,
|
||||
},
|
||||
9061978360207659575},
|
||||
{"compound struct with embedded func (errors!)", struct {
|
||||
AnswerToEverythingFn func() int
|
||||
}{
|
||||
func() int { return 42 },
|
||||
},
|
||||
0},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := Hash(tt.v); got != tt.want {
|
||||
t.Errorf("Hash() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -5,15 +5,13 @@ import (
|
|||
"net/url"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
)
|
||||
|
||||
// MockProvider provides a mocked implementation of the providers interface.
|
||||
type MockProvider struct {
|
||||
AuthenticateResponse sessions.State
|
||||
AuthenticateResponse oauth2.Token
|
||||
AuthenticateError error
|
||||
RefreshResponse sessions.State
|
||||
RefreshResponse oauth2.Token
|
||||
RefreshError error
|
||||
RevokeError error
|
||||
GetSignInURLResponse string
|
||||
|
@ -22,12 +20,12 @@ type MockProvider struct {
|
|||
}
|
||||
|
||||
// Authenticate is a mocked providers function.
|
||||
func (mp MockProvider) Authenticate(ctx context.Context, code string) (*sessions.State, error) {
|
||||
func (mp MockProvider) Authenticate(context.Context, string, interface{}) (*oauth2.Token, error) {
|
||||
return &mp.AuthenticateResponse, mp.AuthenticateError
|
||||
}
|
||||
|
||||
// Refresh is a mocked providers function.
|
||||
func (mp MockProvider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) {
|
||||
func (mp MockProvider) Refresh(context.Context, *oauth2.Token, interface{}) (*oauth2.Token, error) {
|
||||
return &mp.RefreshResponse, mp.RefreshError
|
||||
}
|
||||
|
||||
|
|
|
@ -20,7 +20,6 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/identity/oauth"
|
||||
"github.com/pomerium/pomerium/internal/identity/oidc"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/version"
|
||||
)
|
||||
|
||||
|
@ -77,48 +76,36 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
|
|||
|
||||
// Authenticate creates an identity session with github from a authorization code, and follows up
|
||||
// call to the user and user group endpoint with the
|
||||
func (p *Provider) Authenticate(ctx context.Context, code string) (*sessions.State, error) {
|
||||
resp, err := p.Oauth.Exchange(ctx, code)
|
||||
func (p *Provider) Authenticate(ctx context.Context, code string, v interface{}) (*oauth2.Token, error) {
|
||||
oauth2Token, err := p.Oauth.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("github: token exchange failed %v", err)
|
||||
}
|
||||
|
||||
s := &sessions.State{
|
||||
AccessToken: &oauth2.Token{
|
||||
AccessToken: resp.AccessToken,
|
||||
TokenType: resp.TokenType,
|
||||
},
|
||||
AccessTokenID: resp.AccessToken,
|
||||
}
|
||||
|
||||
err = p.updateSessionState(ctx, s)
|
||||
err = p.updateSessionState(ctx, oauth2Token, v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, nil
|
||||
return oauth2Token, nil
|
||||
}
|
||||
|
||||
// updateSessionState will get the user information from github and also retrieve the user's team(s)
|
||||
//
|
||||
// https://developer.github.com/v3/users/#get-the-authenticated-user
|
||||
func (p *Provider) updateSessionState(ctx context.Context, s *sessions.State) error {
|
||||
if s == nil || s.AccessToken == nil {
|
||||
return errors.New("github: user session cannot be empty")
|
||||
}
|
||||
accessToken := s.AccessToken.AccessToken
|
||||
func (p *Provider) updateSessionState(ctx context.Context, t *oauth2.Token, v interface{}) error {
|
||||
|
||||
err := p.userInfo(ctx, accessToken, s)
|
||||
err := p.userInfo(ctx, t, v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("github: could not retrieve user info %w", err)
|
||||
}
|
||||
|
||||
err = p.userEmail(ctx, accessToken, s)
|
||||
err = p.userEmail(ctx, t, v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("github: could not retrieve user email %w", err)
|
||||
}
|
||||
|
||||
err = p.userTeams(ctx, accessToken, s)
|
||||
err = p.userTeams(ctx, t, v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("github: could not retrieve groups %w", err)
|
||||
}
|
||||
|
@ -127,14 +114,12 @@ func (p *Provider) updateSessionState(ctx context.Context, s *sessions.State) er
|
|||
}
|
||||
|
||||
// Refresh renews a user's session by making a new userInfo request.
|
||||
func (p *Provider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) {
|
||||
if s.AccessToken == nil {
|
||||
return nil, errors.New("github: missing oauth2 access token")
|
||||
}
|
||||
if err := p.updateSessionState(ctx, s); err != nil {
|
||||
func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v interface{}) (*oauth2.Token, error) {
|
||||
err := p.updateSessionState(ctx, t, v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// userTeams returns a slice of teams the user belongs by making a request
|
||||
|
@ -142,7 +127,7 @@ func (p *Provider) Refresh(ctx context.Context, s *sessions.State) (*sessions.St
|
|||
//
|
||||
// https://developer.github.com/v3/teams/#list-user-teams
|
||||
// https://developer.github.com/v3/auth/
|
||||
func (p *Provider) userTeams(ctx context.Context, at string, s *sessions.State) error {
|
||||
func (p *Provider) userTeams(ctx context.Context, t *oauth2.Token, v interface{}) error {
|
||||
|
||||
var response []struct {
|
||||
ID json.Number `json:"id"`
|
||||
|
@ -154,20 +139,24 @@ func (p *Provider) userTeams(ctx context.Context, at string, s *sessions.State)
|
|||
Privacy string `json:"privacy,omitempty"`
|
||||
}
|
||||
|
||||
headers := map[string]string{"Authorization": fmt.Sprintf("token %s", at)}
|
||||
headers := map[string]string{"Authorization": fmt.Sprintf("token %s", t.AccessToken)}
|
||||
teamURL := githubAPIURL + teamPath
|
||||
err := httputil.Client(ctx, http.MethodGet, teamURL, version.UserAgent(), headers, nil, &response)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug().Interface("teams", response).Msg("github: user teams")
|
||||
s.Groups = nil
|
||||
for _, org := range response {
|
||||
s.Groups = append(s.Groups, org.ID.String())
|
||||
var out struct {
|
||||
Groups []string `json:"groups"`
|
||||
}
|
||||
|
||||
return nil
|
||||
for _, org := range response {
|
||||
out.Groups = append(out.Groups, org.ID.String())
|
||||
}
|
||||
b, err := json.Marshal(out)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(b, v)
|
||||
}
|
||||
|
||||
// userEmail returns the primary email of the user by making
|
||||
|
@ -175,7 +164,7 @@ func (p *Provider) userTeams(ctx context.Context, at string, s *sessions.State)
|
|||
//
|
||||
// https://developer.github.com/v3/users/emails/#list-email-addresses-for-a-user
|
||||
// https://developer.github.com/v3/auth/
|
||||
func (p *Provider) userEmail(ctx context.Context, at string, s *sessions.State) error {
|
||||
func (p *Provider) userEmail(ctx context.Context, t *oauth2.Token, v interface{}) error {
|
||||
// response represents the github user email
|
||||
// https://developer.github.com/v3/users/emails/#response
|
||||
var response []struct {
|
||||
|
@ -184,48 +173,67 @@ func (p *Provider) userEmail(ctx context.Context, at string, s *sessions.State)
|
|||
Primary bool `json:"primary"`
|
||||
Visibility string `json:"visibility"`
|
||||
}
|
||||
headers := map[string]string{"Authorization": fmt.Sprintf("token %s", at)}
|
||||
headers := map[string]string{"Authorization": fmt.Sprintf("token %s", t.AccessToken)}
|
||||
emailURL := githubAPIURL + emailPath
|
||||
err := httputil.Client(ctx, http.MethodGet, emailURL, version.UserAgent(), headers, nil, &response)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var out struct {
|
||||
Email string `json:"email"`
|
||||
Verified bool `json:"email_verified"`
|
||||
}
|
||||
log.Debug().Interface("emails", response).Msg("github: user emails")
|
||||
for _, email := range response {
|
||||
if email.Primary && email.Verified {
|
||||
s.Email = email.Email
|
||||
s.EmailVerified = true
|
||||
return nil
|
||||
out.Email = email.Email
|
||||
out.Verified = true
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
b, err := json.Marshal(out)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(b, v)
|
||||
}
|
||||
|
||||
func (p *Provider) userInfo(ctx context.Context, at string, s *sessions.State) error {
|
||||
func (p *Provider) userInfo(ctx context.Context, t *oauth2.Token, v interface{}) error {
|
||||
var response struct {
|
||||
ID int `json:"id"`
|
||||
Login string `json:"login"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
AvatarURL string `json:"avatar_url,omitempty"`
|
||||
}
|
||||
|
||||
headers := map[string]string{
|
||||
"Authorization": fmt.Sprintf("token %s", at),
|
||||
"Authorization": fmt.Sprintf("token %s", t.AccessToken),
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
}
|
||||
err := httputil.Client(ctx, http.MethodGet, p.userEndpoint, version.UserAgent(), headers, nil, &response)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var out struct {
|
||||
Subject string `json:"sub"`
|
||||
Name string `json:"name,omitempty"`
|
||||
User string `json:"user"`
|
||||
Picture string `json:"picture,omitempty"`
|
||||
// needs to be set manually
|
||||
Expiry *jwt.NumericDate `json:"exp,omitempty"`
|
||||
}
|
||||
|
||||
s.User = response.Login
|
||||
s.Name = response.Name
|
||||
s.Picture = response.AvatarURL
|
||||
out.User = response.Login
|
||||
out.Subject = response.Login
|
||||
out.Name = response.Name
|
||||
out.Picture = response.AvatarURL
|
||||
// set the session expiry
|
||||
s.Expiry = jwt.NewNumericDate(time.Now().Add(refreshDeadline))
|
||||
return nil
|
||||
out.Expiry = jwt.NewNumericDate(time.Now().Add(refreshDeadline))
|
||||
b, err := json.Marshal(out)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(b, v)
|
||||
}
|
||||
|
||||
// Revoke method will remove all the github grants the user
|
||||
|
|
|
@ -5,7 +5,7 @@ package azure
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
@ -16,7 +16,6 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/identity/oauth"
|
||||
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/version"
|
||||
)
|
||||
|
||||
|
@ -60,10 +59,7 @@ func (p *Provider) GetSignInURL(state string) string {
|
|||
// `Directory.Read.All` is required.
|
||||
// https://docs.microsoft.com/en-us/graph/api/resources/directoryobject?view=graph-rest-1.0
|
||||
// https://docs.microsoft.com/en-us/graph/api/user-list-memberof?view=graph-rest-1.0
|
||||
func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string, error) {
|
||||
if s == nil || s.AccessToken == nil {
|
||||
return nil, errors.New("identity/azure: session cannot be nil")
|
||||
}
|
||||
func (p *Provider) UserGroups(ctx context.Context, t *oauth2.Token, v interface{}) error {
|
||||
var response struct {
|
||||
Groups []struct {
|
||||
ID string `json:"id"`
|
||||
|
@ -73,15 +69,23 @@ func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string,
|
|||
GroupTypes []string `json:"groupTypes,omitempty"`
|
||||
} `json:"value"`
|
||||
}
|
||||
headers := map[string]string{"Authorization": fmt.Sprintf("Bearer %s", s.AccessToken.AccessToken)}
|
||||
headers := map[string]string{"Authorization": fmt.Sprintf("Bearer %s", t.AccessToken)}
|
||||
err := httputil.Client(ctx, http.MethodGet, defaultGroupURL, version.UserAgent(), headers, nil, &response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug().Interface("response", response).Msg("microsoft: groups")
|
||||
var out struct {
|
||||
Groups []string `json:"groups"`
|
||||
}
|
||||
var groups []string
|
||||
for _, group := range response.Groups {
|
||||
log.Debug().Str("DisplayName", group.DisplayName).Str("ID", group.ID).Msg("microsoft: group")
|
||||
groups = append(groups, group.ID)
|
||||
out.Groups = append(out.Groups, group.ID)
|
||||
}
|
||||
return groups, nil
|
||||
b, err := json.Marshal(out)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return json.Unmarshal(b, v)
|
||||
}
|
||||
|
|
|
@ -1,12 +1,14 @@
|
|||
package oidc
|
||||
|
||||
import "errors"
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
// ErrRevokeNotImplemented error type when Revoke method is not implemented
|
||||
// by an identity provider
|
||||
// ErrRevokeNotImplemented is returned when revoke is not implemented
|
||||
// by an identity provider.
|
||||
var ErrRevokeNotImplemented = errors.New("identity/oidc: revoke not implemented")
|
||||
|
||||
// ErrSignoutNotImplemented error type when end session is not implemented
|
||||
// ErrSignoutNotImplemented is returned when end session is not implemented
|
||||
// by an identity provider
|
||||
// https://openid.net/specs/openid-connect-frontchannel-1_0.html#RPInitiated
|
||||
var ErrSignoutNotImplemented = errors.New("identity/oidc: end session not implemented")
|
||||
|
@ -14,3 +16,13 @@ var ErrSignoutNotImplemented = errors.New("identity/oidc: end session not implem
|
|||
// ErrMissingProviderURL is returned when an identity provider requires a provider url
|
||||
// does not receive one.
|
||||
var ErrMissingProviderURL = errors.New("identity/oidc: missing provider url")
|
||||
|
||||
// ErrMissingIDToken is returned when (usually on refresh) and identity provider
|
||||
// failed to include an id_token in a oauth2 token.
|
||||
var ErrMissingIDToken = errors.New("identity/oidc: missing id_token")
|
||||
|
||||
// ErrMissingRefreshToken is returned if no refresh token was found.
|
||||
var ErrMissingRefreshToken = errors.New("identity/oidc: missing refresh token")
|
||||
|
||||
// ErrMissingAccessToken is returned when no access token was found.
|
||||
var ErrMissingAccessToken = errors.New("identity/oidc: missing access token")
|
||||
|
|
|
@ -6,7 +6,6 @@ package gitlab
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
|
@ -15,14 +14,14 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/identity/oauth"
|
||||
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/version"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// Name identifies the GitLab identity provider
|
||||
// Name identifies the GitLab identity provider.
|
||||
const Name = "gitlab"
|
||||
|
||||
var defaultScopes = []string{oidc.ScopeOpenID, "read_api", "read_user", "profile", "email"}
|
||||
var defaultScopes = []string{oidc.ScopeOpenID, "profile", "email", "api"}
|
||||
|
||||
const (
|
||||
defaultProviderURL = "https://gitlab.com"
|
||||
|
@ -64,11 +63,7 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
|
|||
//
|
||||
// Returns 20 results at a time because the API results are paginated.
|
||||
// https://docs.gitlab.com/ee/api/groups.html#list-groups
|
||||
func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string, error) {
|
||||
if s == nil || s.AccessToken == nil {
|
||||
return nil, errors.New("gitlab: user session cannot be empty")
|
||||
}
|
||||
|
||||
func (p *Provider) UserGroups(ctx context.Context, t *oauth2.Token, v interface{}) error {
|
||||
var response []struct {
|
||||
ID json.Number `json:"id"`
|
||||
Name string `json:"name,omitempty"`
|
||||
|
@ -81,17 +76,22 @@ func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string,
|
|||
FullName string `json:"full_name,omitempty"`
|
||||
FullPath string `json:"full_path,omitempty"`
|
||||
}
|
||||
headers := map[string]string{"Authorization": fmt.Sprintf("Bearer %s", s.AccessToken.AccessToken)}
|
||||
headers := map[string]string{"Authorization": fmt.Sprintf("Bearer %s", t.AccessToken)}
|
||||
err := httputil.Client(ctx, http.MethodGet, p.userGroupURL, version.UserAgent(), headers, nil, &response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
var groups []string
|
||||
log.Debug().Interface("response", response).Msg("gitlab: groups")
|
||||
var out struct {
|
||||
Groups []string `json:"groups"`
|
||||
}
|
||||
for _, group := range response {
|
||||
groups = append(groups, group.ID.String())
|
||||
out.Groups = append(out.Groups, group.ID.String())
|
||||
}
|
||||
b, err := json.Marshal(out)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
return json.Unmarshal(b, v)
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
oidc "github.com/coreos/go-oidc"
|
||||
|
@ -18,7 +19,6 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/identity/oauth"
|
||||
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -54,36 +54,35 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
|
|||
return nil, fmt.Errorf("%s: failed creating oidc provider: %w", Name, err)
|
||||
}
|
||||
p.Provider = genericOidc
|
||||
|
||||
// if service account set, configure admin sdk calls
|
||||
if o.ServiceAccount != "" {
|
||||
apiCreds, err := base64.StdEncoding.DecodeString(o.ServiceAccount)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("google: could not decode service account json %w", err)
|
||||
}
|
||||
// Required scopes for groups api
|
||||
// https://developers.google.com/admin-sdk/directory/v1/reference/groups/list
|
||||
conf, err := google.JWTConfigFromJSON(apiCreds, admin.AdminDirectoryUserReadonlyScope, admin.AdminDirectoryGroupReadonlyScope)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("google: failed making jwt config from json %w", err)
|
||||
}
|
||||
var credentialsFile struct {
|
||||
ImpersonateUser string `json:"impersonate_user"`
|
||||
}
|
||||
if err := json.Unmarshal(apiCreds, &credentialsFile); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conf.Subject = credentialsFile.ImpersonateUser
|
||||
client := conf.Client(context.TODO())
|
||||
p.apiClient, err = admin.New(client)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("google: failed creating admin service %w", err)
|
||||
}
|
||||
p.UserGroupFn = p.UserGroups
|
||||
} else {
|
||||
log.Warn().Msg("google: no service account, cannot retrieve groups")
|
||||
if o.ServiceAccount == "" {
|
||||
log.Warn().Msg("google: no service account, will not fetch groups")
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
apiCreds, err := base64.StdEncoding.DecodeString(o.ServiceAccount)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("google: could not decode service account json %w", err)
|
||||
}
|
||||
// Required scopes for groups api
|
||||
// https://developers.google.com/admin-sdk/directory/v1/reference/groups/list
|
||||
conf, err := google.JWTConfigFromJSON(apiCreds, admin.AdminDirectoryUserReadonlyScope, admin.AdminDirectoryGroupReadonlyScope)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("google: failed making jwt config from json %w", err)
|
||||
}
|
||||
var credentialsFile struct {
|
||||
ImpersonateUser string `json:"impersonate_user"`
|
||||
}
|
||||
if err := json.Unmarshal(apiCreds, &credentialsFile); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conf.Subject = credentialsFile.ImpersonateUser
|
||||
client := conf.Client(context.TODO())
|
||||
p.apiClient, err = admin.New(client)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("google: failed creating admin service %w", err)
|
||||
}
|
||||
p.UserGroupFn = p.UserGroups
|
||||
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
|
@ -106,17 +105,34 @@ func (p *Provider) GetSignInURL(state string) string {
|
|||
// NOTE: groups via Directory API is limited to 1 QPS!
|
||||
// https://developers.google.com/admin-sdk/directory/v1/reference/groups/list
|
||||
// https://developers.google.com/admin-sdk/directory/v1/limits
|
||||
func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string, error) {
|
||||
var groups []string
|
||||
if p.apiClient != nil {
|
||||
req := p.apiClient.Groups.List().UserKey(s.Subject).MaxResults(100)
|
||||
resp, err := req.Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("google: group api request failed %w", err)
|
||||
}
|
||||
for _, group := range resp.Groups {
|
||||
groups = append(groups, group.Email)
|
||||
}
|
||||
func (p *Provider) UserGroups(ctx context.Context, t *oauth2.Token, v interface{}) error {
|
||||
if p.apiClient == nil {
|
||||
return errors.New("google: trying to fetch groups, but no api client")
|
||||
}
|
||||
return groups, nil
|
||||
s, err := p.GetSubject(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var out struct {
|
||||
Groups []string `json:"groups"`
|
||||
}
|
||||
req := p.apiClient.Groups.List().Context(ctx).UserKey(s)
|
||||
err = req.Pages(ctx, func(resp *admin.Groups) error {
|
||||
for _, group := range resp.Groups {
|
||||
out.Groups = append(out.Groups, group.Email)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = req.Do()
|
||||
if err != nil {
|
||||
return fmt.Errorf("google: group api request failed %w", err)
|
||||
}
|
||||
b, err := json.Marshal(out)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(b, v)
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ package oidc
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
@ -15,12 +16,11 @@ import (
|
|||
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/identity/oauth"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/internal/version"
|
||||
)
|
||||
|
||||
// Name identifies the generic OpenID Connect provider
|
||||
// Name identifies the generic OpenID Connect provider.
|
||||
const Name = "oidc"
|
||||
|
||||
var defaultScopes = []string{go_oidc.ScopeOpenID, "profile", "email", "offline_access"}
|
||||
|
@ -37,11 +37,6 @@ type Provider struct {
|
|||
// client application information and the server's endpoint URLs.
|
||||
Oauth *oauth2.Config
|
||||
|
||||
// UserInfoURL specifies the endpoint responsible for returning claims
|
||||
// about the authenticated End-User.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
|
||||
UserInfoURL string `json:"userinfo_endpoint,omitempty"`
|
||||
|
||||
// RevocationURL is the location of the OAuth 2.0 token revocation endpoint.
|
||||
// https://tools.ietf.org/html/rfc7009
|
||||
RevocationURL string `json:"revocation_endpoint,omitempty"`
|
||||
|
@ -53,7 +48,7 @@ type Provider struct {
|
|||
|
||||
// UserGroupFn is, if set, used to return a slice of group IDs the
|
||||
// user is a member of
|
||||
UserGroupFn func(context.Context, *sessions.State) ([]string, error)
|
||||
UserGroupFn func(context.Context, *oauth2.Token, interface{}) error
|
||||
}
|
||||
|
||||
// New creates a new instance of a generic OpenID Connect provider.
|
||||
|
@ -100,81 +95,89 @@ func (p *Provider) GetSignInURL(state string) string {
|
|||
|
||||
// Authenticate converts an authorization code returned from the identity
|
||||
// provider into a token which is then converted into a user session.
|
||||
func (p *Provider) Authenticate(ctx context.Context, code string) (*sessions.State, error) {
|
||||
func (p *Provider) Authenticate(ctx context.Context, code string, v interface{}) (*oauth2.Token, error) {
|
||||
// Exchange converts an authorization code into a token.
|
||||
oauth2Token, err := p.Oauth.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("identity/oidc: token exchange failed: %w", err)
|
||||
}
|
||||
idToken, err := p.IdentityFromToken(ctx, oauth2Token)
|
||||
|
||||
idToken, err := p.getIDToken(ctx, oauth2Token)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("identity/oidc: failed getting id_token: %w", err)
|
||||
}
|
||||
|
||||
aud, err := urlutil.ParseAndValidateURL(p.Oauth.RedirectURL)
|
||||
// hydrate `v` using claims inside the returned `id_token`
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint
|
||||
if err := idToken.Claims(v); err != nil {
|
||||
return nil, fmt.Errorf("identity/oidc: couldn't unmarshal extra claims %w", err)
|
||||
}
|
||||
|
||||
if err := p.updateUserInfo(ctx, oauth2Token, v); err != nil {
|
||||
return nil, fmt.Errorf("identity/oidc: couldn't update user info %w", err)
|
||||
}
|
||||
|
||||
return oauth2Token, nil
|
||||
}
|
||||
|
||||
// updateUserInfo calls the OIDC (spec required) UserInfo Endpoint as well as any
|
||||
// groups endpoint (non-spec) to populate the rest of the user's information.
|
||||
//
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
|
||||
func (p *Provider) updateUserInfo(ctx context.Context, t *oauth2.Token, v interface{}) error {
|
||||
userInfo, err := p.Provider.UserInfo(ctx, oauth2.StaticTokenSource(t))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("identity/oidc: bad redirect uri: %w", err)
|
||||
return fmt.Errorf("identity/oidc: user info endpoint: %w", err)
|
||||
}
|
||||
|
||||
s, err := sessions.NewStateFromTokens(idToken, oauth2Token, aud.Hostname())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if err := userInfo.Claims(v); err != nil {
|
||||
return fmt.Errorf("identity/oidc: failed parsing user info endpoint claims: %w", err)
|
||||
}
|
||||
|
||||
if err := p.Provider.Claims(&p); err == nil && p.UserInfoURL != "" {
|
||||
userInfo, err := p.Provider.UserInfo(ctx, oauth2.StaticTokenSource(oauth2Token))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("identity/oidc: could not retrieve user info %w", err)
|
||||
}
|
||||
if err := userInfo.Claims(&s); err != nil {
|
||||
return nil, fmt.Errorf("identity/oidc: could not parse user claims %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if p.UserGroupFn != nil {
|
||||
s.Groups, err = p.UserGroupFn(ctx, s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("internal/oidc: could not retrieve groups %w", err)
|
||||
if err := p.UserGroupFn(ctx, t, v); err != nil {
|
||||
return fmt.Errorf("identity/oidc: could not retrieve groups: %w", err)
|
||||
}
|
||||
}
|
||||
return s, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// Refresh renews a user's session using an oidc refresh token without reprompting the user.
|
||||
// Group membership is also refreshed.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens
|
||||
func (p *Provider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) {
|
||||
if s.AccessToken == nil || s.AccessToken.RefreshToken == "" {
|
||||
return nil, errors.New("internal/oidc: missing refresh token")
|
||||
func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v interface{}) (*oauth2.Token, error) {
|
||||
if t == nil {
|
||||
return nil, ErrMissingAccessToken
|
||||
}
|
||||
if t.RefreshToken == "" {
|
||||
return nil, ErrMissingRefreshToken
|
||||
}
|
||||
var err error
|
||||
newToken, err := p.Oauth.TokenSource(ctx, t).Token()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("identity/oidc: refresh failed: %w", err)
|
||||
}
|
||||
|
||||
t := oauth2.Token{RefreshToken: s.AccessToken.RefreshToken}
|
||||
oauthToken, err := p.Oauth.TokenSource(ctx, &t).Token()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("internal/oidc: refresh failed %w", err)
|
||||
}
|
||||
idToken, err := p.IdentityFromToken(ctx, oauthToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("identity/oidc: failed getting id_token: %w", err)
|
||||
}
|
||||
if err := s.UpdateState(idToken, oauthToken); err != nil {
|
||||
return nil, fmt.Errorf("internal/oidc: state update failed %w", err)
|
||||
}
|
||||
if p.UserGroupFn != nil {
|
||||
s.Groups, err = p.UserGroupFn(ctx, s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("internal/oidc: could not retrieve groups %w", err)
|
||||
// Many identity providers _will not_ return `id_token` on refresh
|
||||
// https://github.com/FusionAuth/fusionauth-issues/issues/110#issuecomment-481526544
|
||||
idToken, err := p.getIDToken(ctx, newToken)
|
||||
if err == nil {
|
||||
if err := idToken.Claims(v); err != nil {
|
||||
return nil, fmt.Errorf("identity/oidc: couldn't unmarshal extra claims %w", err)
|
||||
}
|
||||
}
|
||||
return s, nil
|
||||
if err := p.updateUserInfo(ctx, newToken, v); err != nil {
|
||||
return nil, fmt.Errorf("identity/oidc: couldn't update user info %w", err)
|
||||
}
|
||||
return newToken, nil
|
||||
}
|
||||
|
||||
// IdentityFromToken takes an identity provider issued JWT as input ('id_token')
|
||||
// and returns a session state. The provided token's audience ('aud') must
|
||||
// match Pomerium's client_id.
|
||||
func (p *Provider) IdentityFromToken(ctx context.Context, t *oauth2.Token) (*go_oidc.IDToken, error) {
|
||||
// getIDToken returns the raw jwt payload for `id_token` from the oauth2 token
|
||||
// returned following oidc code flow
|
||||
//
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#TokenResponse
|
||||
func (p *Provider) getIDToken(ctx context.Context, t *oauth2.Token) (*go_oidc.IDToken, error) {
|
||||
rawIDToken, ok := t.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("internal/oidc: id_token not found")
|
||||
return nil, ErrMissingIDToken
|
||||
}
|
||||
return p.Verifier.Verify(ctx, rawIDToken)
|
||||
}
|
||||
|
@ -183,13 +186,16 @@ func (p *Provider) IdentityFromToken(ctx context.Context, t *oauth2.Token) (*go_
|
|||
// support revocation an error is thrown.
|
||||
//
|
||||
// https://tools.ietf.org/html/rfc7009#section-2.1
|
||||
func (p *Provider) Revoke(ctx context.Context, token *oauth2.Token) error {
|
||||
func (p *Provider) Revoke(ctx context.Context, t *oauth2.Token) error {
|
||||
if p.RevocationURL == "" {
|
||||
return ErrRevokeNotImplemented
|
||||
}
|
||||
if t == nil {
|
||||
return ErrMissingAccessToken
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
params.Add("token", token.AccessToken)
|
||||
params.Add("token", t.AccessToken)
|
||||
params.Add("token_type_hint", "access_token")
|
||||
// Some providers like okta / onelogin require "client authentication"
|
||||
// https://developer.okta.com/docs/reference/api/oidc/#client-secret
|
||||
|
@ -198,7 +204,7 @@ func (p *Provider) Revoke(ctx context.Context, token *oauth2.Token) error {
|
|||
params.Add("client_secret", p.Oauth.ClientSecret)
|
||||
|
||||
err := httputil.Client(ctx, http.MethodPost, p.RevocationURL, version.UserAgent(), nil, params, nil)
|
||||
if err != nil && err != httputil.ErrTokenRevoked {
|
||||
if err != nil && errors.Is(err, httputil.ErrTokenRevoked) {
|
||||
return fmt.Errorf("internal/oidc: unexpected revoke error: %w", err)
|
||||
}
|
||||
|
||||
|
@ -214,3 +220,20 @@ func (p *Provider) LogOut() (*url.URL, error) {
|
|||
}
|
||||
return urlutil.ParseAndValidateURL(p.EndSessionURL)
|
||||
}
|
||||
|
||||
// GetSubject gets the RFC 7519 Subject claim (`sub`) from a
|
||||
func (p *Provider) GetSubject(v interface{}) (string, error) {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
var s struct {
|
||||
Subject string `json:"sub"`
|
||||
}
|
||||
|
||||
err = json.Unmarshal(b, &s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return s.Subject, nil
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ package okta
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
@ -13,9 +14,9 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/identity/oauth"
|
||||
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/internal/version"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -64,7 +65,11 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
|
|||
|
||||
// UserGroups fetches the groups of which the user is a member
|
||||
// https://developer.okta.com/docs/reference/api/users/#get-user-s-groups
|
||||
func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string, error) {
|
||||
func (p *Provider) UserGroups(ctx context.Context, t *oauth2.Token, v interface{}) error {
|
||||
s, err := p.GetSubject(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var response []struct {
|
||||
ID string `json:"id"`
|
||||
Profile struct {
|
||||
|
@ -74,15 +79,22 @@ func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string,
|
|||
}
|
||||
|
||||
headers := map[string]string{"Authorization": fmt.Sprintf("SSWS %s", p.serviceAccount)}
|
||||
uri := fmt.Sprintf("%s/%s/groups", p.userAPI.String(), s.Subject)
|
||||
err := httputil.Client(ctx, http.MethodGet, uri, version.UserAgent(), headers, nil, &response)
|
||||
uri := fmt.Sprintf("%s/%s/groups", p.userAPI.String(), s)
|
||||
err = httputil.Client(ctx, http.MethodGet, uri, version.UserAgent(), headers, nil, &response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
log.Debug().Interface("response", response).Msg("okta: groups")
|
||||
var out struct {
|
||||
Groups []string `json:"groups"`
|
||||
}
|
||||
var groups []string
|
||||
for _, group := range response {
|
||||
log.Debug().Interface("group", group).Msg("okta: group")
|
||||
groups = append(groups, group.ID)
|
||||
out.Groups = append(out.Groups, group.ID)
|
||||
}
|
||||
return groups, nil
|
||||
b, err := json.Marshal(out)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return json.Unmarshal(b, v)
|
||||
}
|
||||
|
|
|
@ -5,17 +5,15 @@ package onelogin
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
oidc "github.com/coreos/go-oidc"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/identity/oauth"
|
||||
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/version"
|
||||
)
|
||||
|
||||
|
@ -55,24 +53,10 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
|
|||
|
||||
// UserGroups returns a slice of group names a given user is in.
|
||||
// https://developers.onelogin.com/openid-connect/api/user-info
|
||||
func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string, error) {
|
||||
if s == nil || s.AccessToken == nil {
|
||||
return nil, errors.New("identity/onelogin: session cannot be nil")
|
||||
func (p *Provider) UserGroups(ctx context.Context, t *oauth2.Token, v interface{}) error {
|
||||
if t == nil {
|
||||
return pom_oidc.ErrMissingAccessToken
|
||||
}
|
||||
var response struct {
|
||||
User string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
PreferredUsername string `json:"preferred_username"`
|
||||
Name string `json:"name"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
GivenName string `json:"given_name"`
|
||||
FamilyName string `json:"family_name"`
|
||||
Groups []string `json:"groups"`
|
||||
}
|
||||
headers := map[string]string{"Authorization": fmt.Sprintf("Bearer %s", s.AccessToken.AccessToken)}
|
||||
err := httputil.Client(ctx, http.MethodGet, defaultOneloginGroupURL, version.UserAgent(), headers, nil, &response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return response.Groups, nil
|
||||
headers := map[string]string{"Authorization": fmt.Sprintf("Bearer %s", t.AccessToken)}
|
||||
return httputil.Client(ctx, http.MethodGet, defaultOneloginGroupURL, version.UserAgent(), headers, nil, v)
|
||||
}
|
||||
|
|
|
@ -17,25 +17,24 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/identity/oidc/google"
|
||||
"github.com/pomerium/pomerium/internal/identity/oidc/okta"
|
||||
"github.com/pomerium/pomerium/internal/identity/oidc/onelogin"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
)
|
||||
|
||||
var (
|
||||
// compile time assertions that providers are satisfying the interface
|
||||
_ Authenticator = &azure.Provider{}
|
||||
_ Authenticator = &gitlab.Provider{}
|
||||
_ Authenticator = &github.Provider{}
|
||||
_ Authenticator = &gitlab.Provider{}
|
||||
_ Authenticator = &google.Provider{}
|
||||
_ Authenticator = &MockProvider{}
|
||||
_ Authenticator = &oidc.Provider{}
|
||||
_ Authenticator = &okta.Provider{}
|
||||
_ Authenticator = &onelogin.Provider{}
|
||||
_ Authenticator = &MockProvider{}
|
||||
)
|
||||
|
||||
// Authenticator is an interface representing the ability to authenticate with an identity provider.
|
||||
type Authenticator interface {
|
||||
Authenticate(context.Context, string) (*sessions.State, error)
|
||||
Refresh(context.Context, *sessions.State) (*sessions.State, error)
|
||||
Authenticate(context.Context, string, interface{}) (*oauth2.Token, error)
|
||||
Refresh(context.Context, *oauth2.Token, interface{}) (*oauth2.Token, error)
|
||||
Revoke(context.Context, *oauth2.Token) error
|
||||
GetSignInURL(state string) string
|
||||
LogOut() (*url.URL, error)
|
||||
|
|
|
@ -37,6 +37,9 @@ type Store struct {
|
|||
srv *http.Server
|
||||
}
|
||||
|
||||
// ErrCacheMiss is returned when the cache misses for a given key.
|
||||
var ErrCacheMiss = errors.New("cache miss")
|
||||
|
||||
// Options represent autocache options.
|
||||
type Options struct {
|
||||
Addr string
|
||||
|
@ -60,7 +63,7 @@ var DefaultOptions = &Options{
|
|||
GetterFn: func(ctx context.Context, id string, dest groupcache.Sink) error {
|
||||
b := fromContext(ctx)
|
||||
if len(b) == 0 {
|
||||
return fmt.Errorf("autocache: empty ctx for id: %s", id)
|
||||
return fmt.Errorf("autocache: id %s : %w", id, ErrCacheMiss)
|
||||
}
|
||||
if err := dest.SetBytes(b); err != nil {
|
||||
return fmt.Errorf("autocache: sink error %w", err)
|
||||
|
|
95
internal/sessions/cache/cache_store.go
vendored
95
internal/sessions/cache/cache_store.go
vendored
|
@ -1,95 +0,0 @@
|
|||
// Package cache provides a remote cache based implementation of session store
|
||||
// and loader. See pomerium's cache service for more details.
|
||||
package cache
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/grpc/cache/client"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
)
|
||||
|
||||
var _ sessions.SessionStore = &Store{}
|
||||
var _ sessions.SessionLoader = &Store{}
|
||||
|
||||
// Store implements the session store interface using a cache service.
|
||||
type Store struct {
|
||||
cache client.Cacher
|
||||
encoder encoding.MarshalUnmarshaler
|
||||
queryParam string
|
||||
wrappedStore sessions.SessionStore
|
||||
}
|
||||
|
||||
// Options represent cache store's available configurations.
|
||||
type Options struct {
|
||||
Cache client.Cacher
|
||||
Encoder encoding.MarshalUnmarshaler
|
||||
QueryParam string
|
||||
WrappedStore sessions.SessionStore
|
||||
}
|
||||
|
||||
var defaultOptions = &Options{
|
||||
QueryParam: "cache_store_key",
|
||||
}
|
||||
|
||||
// NewStore creates a new cache
|
||||
func NewStore(o *Options) *Store {
|
||||
if o.QueryParam == "" {
|
||||
o.QueryParam = defaultOptions.QueryParam
|
||||
}
|
||||
return &Store{
|
||||
cache: o.Cache,
|
||||
encoder: o.Encoder,
|
||||
queryParam: o.QueryParam,
|
||||
wrappedStore: o.WrappedStore,
|
||||
}
|
||||
}
|
||||
|
||||
// LoadSession looks for a preset query parameter in the request body
|
||||
// representing the key to lookup from the cache.
|
||||
func (s *Store) LoadSession(r *http.Request) (string, error) {
|
||||
// look for our cache's key in the default query param
|
||||
sessionID := r.URL.Query().Get(s.queryParam)
|
||||
if sessionID == "" {
|
||||
return "", sessions.ErrNoSessionFound
|
||||
}
|
||||
exists, val, err := s.cache.Get(r.Context(), sessionID)
|
||||
if err != nil {
|
||||
log.FromRequest(r).Debug().Msg("sessions/cache: miss, trying wrapped loader")
|
||||
return "", err
|
||||
}
|
||||
if !exists {
|
||||
return "", sessions.ErrNoSessionFound
|
||||
}
|
||||
|
||||
return string(val), nil
|
||||
}
|
||||
|
||||
// ClearSession clears the session from the wrapped store.
|
||||
func (s *Store) ClearSession(w http.ResponseWriter, r *http.Request) {
|
||||
s.wrappedStore.ClearSession(w, r)
|
||||
}
|
||||
|
||||
// SaveSession saves the session to the cache, and wrapped store.
|
||||
func (s *Store) SaveSession(w http.ResponseWriter, r *http.Request, x interface{}) error {
|
||||
err := s.wrappedStore.SaveSession(w, r, x)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sessions/cache: wrapped store save error %w", err)
|
||||
}
|
||||
|
||||
state, ok := x.(*sessions.State)
|
||||
if !ok {
|
||||
return errors.New("sessions/cache: cannot cache non state type")
|
||||
}
|
||||
|
||||
data, err := s.encoder.Marshal(&state)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sessions/cache: marshal %w", err)
|
||||
}
|
||||
|
||||
return s.cache.Set(r.Context(), state.AccessTokenID, data)
|
||||
}
|
190
internal/sessions/cache/cache_store_test.go
vendored
190
internal/sessions/cache/cache_store_test.go
vendored
|
@ -1,190 +0,0 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/encoding/ecjson"
|
||||
mock_encoder "github.com/pomerium/pomerium/internal/encoding/mock"
|
||||
"github.com/pomerium/pomerium/internal/grpc/cache/client"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/sessions/cookie"
|
||||
"github.com/pomerium/pomerium/internal/sessions/mock"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
)
|
||||
|
||||
type mockCache struct {
|
||||
Key string
|
||||
KeyExists bool
|
||||
Value []byte
|
||||
Err error
|
||||
}
|
||||
|
||||
func (mc *mockCache) Get(ctx context.Context, key string) (keyExists bool, value []byte, err error) {
|
||||
return mc.KeyExists, mc.Value, mc.Err
|
||||
}
|
||||
func (mc *mockCache) Set(ctx context.Context, key string, value []byte) error {
|
||||
return mc.Err
|
||||
}
|
||||
func (mc *mockCache) Close() error {
|
||||
return mc.Err
|
||||
}
|
||||
|
||||
func TestNewStore(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
Options *Options
|
||||
State *sessions.State
|
||||
|
||||
wantErr bool
|
||||
wantLoadErr bool
|
||||
wantStatus int
|
||||
}{
|
||||
{"simple good",
|
||||
&Options{
|
||||
Cache: &mockCache{},
|
||||
WrappedStore: &mock.Store{},
|
||||
Encoder: mock_encoder.Encoder{MarshalResponse: []byte("ok")},
|
||||
},
|
||||
&sessions.State{Email: "user@domain.com", User: "user"},
|
||||
false, false,
|
||||
http.StatusOK},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := NewStore(tt.Options)
|
||||
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
if err := got.SaveSession(w, r, tt.State); (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
|
||||
r = httptest.NewRequest("GET", "/", nil)
|
||||
w = httptest.NewRecorder()
|
||||
|
||||
got.ClearSession(w, r)
|
||||
status := w.Result().StatusCode
|
||||
if diff := cmp.Diff(status, tt.wantStatus); diff != "" {
|
||||
t.Errorf("ClearSession() = %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_SaveSession(t *testing.T) {
|
||||
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
|
||||
encoder := ecjson.New(cipher)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cs, err := cookie.NewStore(&cookie.Options{
|
||||
Name: "_pomerium",
|
||||
}, encoder)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
Options *Options
|
||||
|
||||
x interface{}
|
||||
wantErr bool
|
||||
}{
|
||||
{"good", &Options{Cache: &mockCache{}, WrappedStore: cs, Encoder: mock_encoder.Encoder{MarshalResponse: []byte("ok")}}, &sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, false},
|
||||
{"encoder error", &Options{Cache: &mockCache{}, WrappedStore: cs, Encoder: mock_encoder.Encoder{MarshalError: errors.New("err")}}, &sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, true},
|
||||
{"good", &Options{Cache: &mockCache{}, WrappedStore: &mock.Store{SaveError: errors.New("err")}}, &sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, true},
|
||||
{"bad type", &Options{Cache: &mockCache{}, WrappedStore: cs, Encoder: mock_encoder.Encoder{MarshalError: errors.New("err")}}, "bad type!", true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
o := tt.Options
|
||||
if o.WrappedStore == nil {
|
||||
o.WrappedStore = cs
|
||||
|
||||
}
|
||||
cacheStore := NewStore(tt.Options)
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
r.Header.Set("Accept", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
if err := cacheStore.SaveSession(w, r, tt.x); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Store.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_LoadSession(t *testing.T) {
|
||||
key := cryptutil.NewBase64Key()
|
||||
tests := []struct {
|
||||
name string
|
||||
state *sessions.State
|
||||
cache client.Cacher
|
||||
encoder encoding.MarshalUnmarshaler
|
||||
queryParam string
|
||||
wrappedStore sessions.SessionStore
|
||||
wantErr bool
|
||||
}{
|
||||
{"good",
|
||||
&sessions.State{AccessTokenID: key, Email: "user@pomerium.io"},
|
||||
&mockCache{KeyExists: true},
|
||||
mock_encoder.Encoder{MarshalResponse: []byte("ok")},
|
||||
defaultOptions.QueryParam,
|
||||
&mock.Store{Session: &sessions.State{AccessTokenID: key, Email: "user@pomerium.io"}},
|
||||
false},
|
||||
{"missing param with key",
|
||||
&sessions.State{AccessTokenID: key, Email: "user@pomerium.io"},
|
||||
&mockCache{KeyExists: true},
|
||||
mock_encoder.Encoder{MarshalResponse: []byte("ok")},
|
||||
"bad_query",
|
||||
&mock.Store{Session: &sessions.State{AccessTokenID: key, Email: "user@pomerium.io"}},
|
||||
true},
|
||||
{"doesn't exist",
|
||||
&sessions.State{AccessTokenID: key, Email: "user@pomerium.io"},
|
||||
&mockCache{KeyExists: false},
|
||||
mock_encoder.Encoder{MarshalResponse: []byte("ok")},
|
||||
defaultOptions.QueryParam,
|
||||
&mock.Store{Session: &sessions.State{AccessTokenID: key, Email: "user@pomerium.io"}},
|
||||
true},
|
||||
{"retrieval error",
|
||||
&sessions.State{AccessTokenID: key, Email: "user@pomerium.io"},
|
||||
&mockCache{Err: errors.New("err")},
|
||||
mock_encoder.Encoder{MarshalResponse: []byte("ok")},
|
||||
defaultOptions.QueryParam,
|
||||
&mock.Store{Session: &sessions.State{AccessTokenID: key, Email: "user@pomerium.io"}},
|
||||
true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &Store{
|
||||
cache: tt.cache,
|
||||
encoder: tt.encoder,
|
||||
queryParam: tt.queryParam,
|
||||
wrappedStore: tt.wrappedStore,
|
||||
}
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
q := r.URL.Query()
|
||||
|
||||
q.Set(defaultOptions.QueryParam, tt.state.AccessTokenID)
|
||||
r.URL.RawQuery = q.Encode()
|
||||
r.Header.Set("Accept", "application/json")
|
||||
|
||||
_, err := s.LoadSession(r)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Store.LoadSession() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,15 +1,11 @@
|
|||
package sessions
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cespare/xxhash/v2"
|
||||
oidc "github.com/coreos/go-oidc"
|
||||
"github.com/mitchellh/hashstructure"
|
||||
"github.com/pomerium/pomerium/internal/hashutil"
|
||||
"golang.org/x/oauth2"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
)
|
||||
|
@ -27,6 +23,9 @@ type State struct {
|
|||
NotBefore *jwt.NumericDate `json:"nbf,omitempty"`
|
||||
IssuedAt *jwt.NumericDate `json:"iat,omitempty"`
|
||||
ID string `json:"jti,omitempty"`
|
||||
// At_hash is an OPTIONAL Access Token hash value
|
||||
// https://ldapwiki.com/wiki/At_hash
|
||||
AccessTokenHash string `json:"at_hash,omitempty"`
|
||||
|
||||
// core pomerium identity claims ; not standard to RFC 7519
|
||||
Email string `json:"email"`
|
||||
|
@ -48,84 +47,24 @@ type State struct {
|
|||
// Programmatic whether this state is used for machine-to-machine
|
||||
// programatic access.
|
||||
Programmatic bool `json:"programatic"`
|
||||
|
||||
AccessToken *oauth2.Token `json:"act,omitempty"`
|
||||
AccessTokenID string `json:"ati,omitempty"`
|
||||
|
||||
idToken *oidc.IDToken
|
||||
}
|
||||
|
||||
// NewStateFromTokens returns a session state built from oidc and oauth2
|
||||
// tokens as part of OpenID Connect flow with a new audience appended to the
|
||||
// audience claim.
|
||||
func NewStateFromTokens(idToken *oidc.IDToken, accessToken *oauth2.Token, audience string) (*State, error) {
|
||||
if idToken == nil {
|
||||
return nil, errors.New("sessions: oidc id token missing")
|
||||
}
|
||||
if accessToken == nil {
|
||||
return nil, errors.New("sessions: oauth2 token missing")
|
||||
}
|
||||
s := &State{}
|
||||
if err := idToken.Claims(s); err != nil {
|
||||
return nil, fmt.Errorf("sessions: couldn't unmarshal extra claims %w", err)
|
||||
}
|
||||
s.Audience = []string{audience}
|
||||
s.idToken = idToken
|
||||
s.AccessToken = accessToken
|
||||
s.AccessTokenID = s.accessTokenHash()
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// UpdateState updates the current state given a new identity (oidc) and authorization
|
||||
// (oauth2) tokens following a oidc refresh. NB, unlike during authentication,
|
||||
// refresh typically provides fewer claims in the token so we want to build from
|
||||
// our previous state.
|
||||
func (s *State) UpdateState(idToken *oidc.IDToken, accessToken *oauth2.Token) error {
|
||||
if idToken == nil {
|
||||
return errors.New("sessions: oidc id token missing")
|
||||
}
|
||||
if accessToken == nil {
|
||||
return errors.New("sessions: oauth2 token missing")
|
||||
}
|
||||
audience := append(s.Audience[:0:0], s.Audience...)
|
||||
s.AccessToken = accessToken
|
||||
if err := idToken.Claims(s); err != nil {
|
||||
return fmt.Errorf("sessions: update state failed %w", err)
|
||||
}
|
||||
s.Audience = audience
|
||||
s.Expiry = jwt.NewNumericDate(accessToken.Expiry)
|
||||
s.AccessTokenID = s.accessTokenHash()
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewSession updates issuer, audience, and issuance timestamps but keeps
|
||||
// parent expiry.
|
||||
func (s State) NewSession(issuer string, audience []string) *State {
|
||||
s.IssuedAt = jwt.NewNumericDate(timeNow())
|
||||
s.NotBefore = s.IssuedAt
|
||||
s.Audience = audience
|
||||
s.Issuer = issuer
|
||||
return &s
|
||||
}
|
||||
|
||||
// RouteSession creates a route session with access tokens stripped.
|
||||
func (s State) RouteSession() *State {
|
||||
s.AccessToken = nil
|
||||
return &s
|
||||
func NewSession(s *State, issuer string, audience []string, accessToken *oauth2.Token) State {
|
||||
newState := *s
|
||||
newState.IssuedAt = jwt.NewNumericDate(timeNow())
|
||||
newState.NotBefore = newState.IssuedAt
|
||||
newState.Audience = audience
|
||||
newState.Issuer = issuer
|
||||
newState.AccessTokenHash = fmt.Sprintf("%x", hashutil.Hash(accessToken))
|
||||
newState.Expiry = jwt.NewNumericDate(accessToken.Expiry)
|
||||
return newState
|
||||
}
|
||||
|
||||
// IsExpired returns true if the users's session is expired.
|
||||
func (s *State) IsExpired() bool {
|
||||
|
||||
if s.Expiry != nil && timeNow().After(s.Expiry.Time()) {
|
||||
return true
|
||||
}
|
||||
|
||||
if s.AccessToken != nil && timeNow().After(s.AccessToken.Expiry) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
return s.Expiry != nil && timeNow().After(s.Expiry.Time())
|
||||
}
|
||||
|
||||
// Impersonating returns if the request is impersonating.
|
||||
|
@ -133,23 +72,6 @@ func (s *State) Impersonating() bool {
|
|||
return s.ImpersonateEmail != "" || len(s.ImpersonateGroups) != 0
|
||||
}
|
||||
|
||||
// RequestEmail is the email to make the request as.
|
||||
func (s *State) RequestEmail() string {
|
||||
if s.ImpersonateEmail != "" {
|
||||
return s.ImpersonateEmail
|
||||
}
|
||||
return s.Email
|
||||
}
|
||||
|
||||
// RequestGroups returns the groups of the Groups making the request; uses
|
||||
// impersonating user if set.
|
||||
func (s *State) RequestGroups() string {
|
||||
if len(s.ImpersonateGroups) != 0 {
|
||||
return strings.Join(s.ImpersonateGroups, ",")
|
||||
}
|
||||
return strings.Join(s.Groups, ",")
|
||||
}
|
||||
|
||||
// SetImpersonation sets impersonation user and groups.
|
||||
func (s *State) SetImpersonation(email, groups string) {
|
||||
s.ImpersonateEmail = email
|
||||
|
@ -159,34 +81,3 @@ func (s *State) SetImpersonation(email, groups string) {
|
|||
s.ImpersonateGroups = strings.Split(groups, ",")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *State) accessTokenHash() string {
|
||||
hash, err := hashstructure.Hash(
|
||||
s.AccessToken,
|
||||
&hashstructure.HashOptions{Hasher: xxhash.New()})
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%x", hash)
|
||||
}
|
||||
|
||||
// UnmarshalJSON parses the JSON-encoded session state.
|
||||
// TODO(BDD): remove in v0.8.0
|
||||
func (s *State) UnmarshalJSON(b []byte) error {
|
||||
type Alias State
|
||||
t := &struct {
|
||||
*Alias
|
||||
OldToken *oauth2.Token `json:"access_token,omitempty"` // < v0.5.0
|
||||
}{
|
||||
Alias: (*Alias)(s),
|
||||
}
|
||||
if err := json.Unmarshal(b, &t); err != nil {
|
||||
return err
|
||||
}
|
||||
if t.AccessToken == nil {
|
||||
t.AccessToken = t.OldToken
|
||||
}
|
||||
*s = *(*State)(t.Alias)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -5,8 +5,6 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"golang.org/x/oauth2"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
)
|
||||
|
@ -38,12 +36,6 @@ func TestState_Impersonating(t *testing.T) {
|
|||
if got := s.Impersonating(); got != tt.want {
|
||||
t.Errorf("State.Impersonating() = %v, want %v", got, tt.want)
|
||||
}
|
||||
if gotEmail := s.RequestEmail(); gotEmail != tt.wantResponseEmail {
|
||||
t.Errorf("State.RequestEmail() = %v, want %v", gotEmail, tt.wantResponseEmail)
|
||||
}
|
||||
if gotGroups := s.RequestGroups(); gotGroups != tt.wantResponseGroups {
|
||||
t.Errorf("State.v() = %v, want %v", gotGroups, tt.wantResponseGroups)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -63,16 +55,14 @@ func TestState_IsExpired(t *testing.T) {
|
|||
}{
|
||||
{"good", []string{"a", "b", "c"}, jwt.NewNumericDate(time.Now().Add(time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), &oauth2.Token{Expiry: time.Now().Add(time.Hour)}, "a", false},
|
||||
{"bad expiry", []string{"a", "b", "c"}, jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), &oauth2.Token{Expiry: time.Now().Add(time.Hour)}, "a", true},
|
||||
{"bad access token expiry", []string{"a", "b", "c"}, jwt.NewNumericDate(time.Now().Add(time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), &oauth2.Token{Expiry: time.Now().Add(-time.Hour)}, "a", true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &State{
|
||||
Audience: tt.Audience,
|
||||
Expiry: tt.Expiry,
|
||||
NotBefore: tt.NotBefore,
|
||||
IssuedAt: tt.IssuedAt,
|
||||
AccessToken: tt.AccessToken,
|
||||
Audience: tt.Audience,
|
||||
Expiry: tt.Expiry,
|
||||
NotBefore: tt.NotBefore,
|
||||
IssuedAt: tt.IssuedAt,
|
||||
}
|
||||
if exp := s.IsExpired(); exp != tt.wantErr {
|
||||
t.Errorf("State.IsExpired() error = %v, wantErr %v", exp, tt.wantErr)
|
||||
|
@ -80,67 +70,3 @@ func TestState_IsExpired(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestState_RouteSession(t *testing.T) {
|
||||
now := time.Now()
|
||||
timeNow = func() time.Time {
|
||||
return now
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
Issuer string
|
||||
Audience jwt.Audience
|
||||
Expiry *jwt.NumericDate
|
||||
AccessToken *oauth2.Token
|
||||
|
||||
issuer string
|
||||
|
||||
audience []string
|
||||
|
||||
want *State
|
||||
}{
|
||||
{"good", "authenticate.x.y.z", []string{"http.x.y.z"}, jwt.NewNumericDate(timeNow()), nil, "authenticate.a.b.c", []string{"http.a.b.c"}, &State{Issuer: "authenticate.a.b.c", Audience: []string{"http.a.b.c"}, NotBefore: jwt.NewNumericDate(timeNow()), IssuedAt: jwt.NewNumericDate(timeNow()), Expiry: jwt.NewNumericDate(timeNow())}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := State{
|
||||
Issuer: tt.Issuer,
|
||||
Audience: tt.Audience,
|
||||
Expiry: tt.Expiry,
|
||||
AccessToken: tt.AccessToken,
|
||||
}
|
||||
cmpOpts := []cmp.Option{
|
||||
cmpopts.IgnoreUnexported(State{}),
|
||||
}
|
||||
got := s.NewSession(tt.issuer, tt.audience)
|
||||
got = got.RouteSession()
|
||||
if diff := cmp.Diff(got, tt.want, cmpOpts...); diff != "" {
|
||||
t.Errorf("State.RouteSession() = %s", diff)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestState_accessTokenHash(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
state State
|
||||
want string
|
||||
}{
|
||||
{"empty access token", State{}, "34c96acdcadb1bbb"},
|
||||
{"no change to access token", State{Subject: "test"}, "34c96acdcadb1bbb"},
|
||||
{"empty oauth2 token", State{AccessToken: &oauth2.Token{}}, "bbd82197d215198f"},
|
||||
{"refresh token a", State{AccessToken: &oauth2.Token{RefreshToken: "a"}}, "76316ac79b301bd6"},
|
||||
{"refresh token b", State{AccessToken: &oauth2.Token{RefreshToken: "b"}}, "fab7cb29e50161f1"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &tt.state
|
||||
if got := s.accessTokenHash(); got != tt.want {
|
||||
t.Errorf("State.accessTokenHash() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@ const (
|
|||
QueryRefreshToken = "pomerium_refresh_token"
|
||||
QueryAccessTokenID = "pomerium_session_access_token_id"
|
||||
QueryAudience = "pomerium_session_audience"
|
||||
QueryProgrammaticToken = "pomerium_programmatic_token"
|
||||
)
|
||||
|
||||
// URL signature based query params used for verifying the authenticity of a URL.
|
||||
|
|
Loading…
Add table
Reference in a new issue