authenticate: save oauth2 tokens to cache (#698)

Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
This commit is contained in:
Bobby DeSimone 2020-05-18 10:45:07 -07:00 committed by Travis Groth
parent ef399380b7
commit 666fd6aa35
31 changed files with 1127 additions and 1061 deletions

View file

@ -38,6 +38,12 @@ GETENVOY_VERSION = v0.1.8
all: clean build-deps test lint spellcheck build ## Runs a clean, build, fmt, lint, test, and vet. all: clean build-deps test lint spellcheck build ## Runs a clean, build, fmt, lint, test, and vet.
.PHONY: generate-mocks
generate-mocks: ## Generate mocks
@echo "==> $@"
@go run github.com/golang/mock/mockgen -destination authorize/evaluator/mock_evaluator/mock.go github.com/pomerium/pomerium/authorize/evaluator Evaluator
@go run github.com/golang/mock/mockgen -destination internal/grpc/cache/mock/mock_cacher.go github.com/pomerium/pomerium/internal/grpc/cache Cacher
.PHONY: build-deps .PHONY: build-deps
build-deps: ## Install build dependencies build-deps: ## Install build dependencies
@echo "==> $@" @echo "==> $@"

View file

@ -17,12 +17,12 @@ import (
"github.com/pomerium/pomerium/internal/encoding/jws" "github.com/pomerium/pomerium/internal/encoding/jws"
"github.com/pomerium/pomerium/internal/frontend" "github.com/pomerium/pomerium/internal/frontend"
"github.com/pomerium/pomerium/internal/grpc" "github.com/pomerium/pomerium/internal/grpc"
"github.com/pomerium/pomerium/internal/grpc/cache"
"github.com/pomerium/pomerium/internal/grpc/cache/client" "github.com/pomerium/pomerium/internal/grpc/cache/client"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/identity" "github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/internal/identity/oauth" "github.com/pomerium/pomerium/internal/identity/oauth"
"github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/sessions/cache"
"github.com/pomerium/pomerium/internal/sessions/cookie" "github.com/pomerium/pomerium/internal/sessions/cookie"
"github.com/pomerium/pomerium/internal/sessions/header" "github.com/pomerium/pomerium/internal/sessions/header"
"github.com/pomerium/pomerium/internal/sessions/queryparam" "github.com/pomerium/pomerium/internal/sessions/queryparam"
@ -93,7 +93,7 @@ type Authenticate struct {
provider identity.Authenticator provider identity.Authenticator
// cacheClient is the interface for setting and getting sessions from a cache // cacheClient is the interface for setting and getting sessions from a cache
cacheClient client.Cacher cacheClient cache.Cacher
templates *template.Template templates *template.Template
} }
@ -106,12 +106,12 @@ func New(opts config.Options) (*Authenticate, error) {
// shared state encoder setup // shared state encoder setup
sharedCipher, _ := cryptutil.NewAEADCipherFromBase64(opts.SharedKey) sharedCipher, _ := cryptutil.NewAEADCipherFromBase64(opts.SharedKey)
signedEncoder, err := jws.NewHS256Signer([]byte(opts.SharedKey), opts.AuthenticateURL.Host) sharedEncoder, err := jws.NewHS256Signer([]byte(opts.SharedKey), opts.AuthenticateURL.Host)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// private state encoder setup // private state encoder setup, used to encrypt oauth2 tokens
decodedCookieSecret, _ := base64.StdEncoding.DecodeString(opts.CookieSecret) decodedCookieSecret, _ := base64.StdEncoding.DecodeString(opts.CookieSecret)
cookieCipher, _ := cryptutil.NewAEADCipher(decodedCookieSecret) cookieCipher, _ := cryptutil.NewAEADCipher(decodedCookieSecret)
encryptedEncoder := ecjson.New(cookieCipher) encryptedEncoder := ecjson.New(cookieCipher)
@ -124,7 +124,7 @@ func New(opts config.Options) (*Authenticate, error) {
Expire: opts.CookieExpire, Expire: opts.CookieExpire,
} }
cookieStore, err := cookie.NewStore(cookieOptions, encryptedEncoder) cookieStore, err := cookie.NewStore(cookieOptions, sharedEncoder)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -145,13 +145,7 @@ func New(opts config.Options) (*Authenticate, error) {
cacheClient := client.New(cacheConn) cacheClient := client.New(cacheConn)
cacheStore := cache.NewStore(&cache.Options{ qpStore := queryparam.NewStore(encryptedEncoder, urlutil.QueryProgrammaticToken)
Cache: cacheClient,
Encoder: encryptedEncoder,
QueryParam: urlutil.QueryAccessTokenID,
WrappedStore: cookieStore})
qpStore := queryparam.NewStore(encryptedEncoder, "pomerium_programmatic_token")
headerStore := header.NewStore(encryptedEncoder, httputil.AuthorizationTypePomerium) headerStore := header.NewStore(encryptedEncoder, httputil.AuthorizationTypePomerium)
redirectURL, _ := urlutil.DeepCopy(opts.AuthenticateURL) redirectURL, _ := urlutil.DeepCopy(opts.AuthenticateURL)
@ -177,14 +171,14 @@ func New(opts config.Options) (*Authenticate, error) {
// shared state // shared state
sharedKey: opts.SharedKey, sharedKey: opts.SharedKey,
sharedCipher: sharedCipher, sharedCipher: sharedCipher,
sharedEncoder: signedEncoder, sharedEncoder: sharedEncoder,
// private state // private state
cookieSecret: decodedCookieSecret, cookieSecret: decodedCookieSecret,
cookieCipher: cookieCipher, cookieCipher: cookieCipher,
cookieOptions: cookieOptions, cookieOptions: cookieOptions,
sessionStore: cacheStore, sessionStore: cookieStore,
encryptedEncoder: encryptedEncoder, encryptedEncoder: encryptedEncoder,
sessionLoaders: []sessions.SessionLoader{cacheStore, qpStore, headerStore, cookieStore}, sessionLoaders: []sessions.SessionLoader{qpStore, headerStore, cookieStore},
// IdP // IdP
provider: provider, provider: provider,
// grpc client for cache // grpc client for cache

View file

@ -91,6 +91,10 @@ func TestNew(t *testing.T) {
badGRPCConn.CacheURL = nil badGRPCConn.CacheURL = nil
badGRPCConn.CookieName = "D" badGRPCConn.CookieName = "D"
emptyProviderURL := newTestOptions(t)
emptyProviderURL.Provider = "oidc"
emptyProviderURL.ProviderURL = ""
tests := []struct { tests := []struct {
name string name string
opts *config.Options opts *config.Options
@ -103,6 +107,7 @@ func TestNew(t *testing.T) {
{"bad cookie name", badCookieName, true}, {"bad cookie name", badCookieName, true},
{"bad provider", badProvider, true}, {"bad provider", badProvider, true},
{"bad cache url", badGRPCConn, true}, {"bad cache url", badGRPCConn, true},
{"empty provider url", emptyProviderURL, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {

View file

@ -13,6 +13,7 @@ import (
"github.com/pomerium/csrf" "github.com/pomerium/csrf"
"github.com/pomerium/pomerium/internal/cryptutil" "github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/hashutil"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/identity/oidc" "github.com/pomerium/pomerium/internal/identity/oidc"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
@ -23,6 +24,7 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/rs/cors" "github.com/rs/cors"
"golang.org/x/oauth2"
) )
// Handler returns the authenticate service's handler chain. // Handler returns the authenticate service's handler chain.
@ -80,19 +82,15 @@ func (a *Authenticate) Mount(r *mux.Router) {
// session state is attached to the users's request context. // session state is attached to the users's request context.
func (a *Authenticate) VerifySession(next http.Handler) http.Handler { func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context() ctx, span := trace.StartSpan(r.Context(), "authenticate.VerifySession")
jwt, err := sessions.FromContext(ctx) defer span.End()
s, err := a.getSessionFromCtx(ctx)
if err != nil { if err != nil {
log.FromRequest(r).Info().Err(err).Msg("authenticate: session load error") log.FromRequest(r).Info().Err(err).Msg("authenticate: session load error")
return a.reauthenticateOrFail(w, r, err) return a.reauthenticateOrFail(w, r, err)
} }
var s sessions.State
if err := a.encryptedEncoder.Unmarshal([]byte(jwt), &s); err != nil {
return httputil.NewError(http.StatusBadRequest, err)
}
if s.IsExpired() { if s.IsExpired() {
ctx, err = a.refresh(w, r, &s) ctx, err = a.refresh(w, r, s)
if err != nil { if err != nil {
log.FromRequest(r).Info().Err(err).Msg("authenticate: verify session, refresh") log.FromRequest(r).Info().Err(err).Msg("authenticate: verify session, refresh")
return a.reauthenticateOrFail(w, r, err) return a.reauthenticateOrFail(w, r, err)
@ -106,18 +104,34 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
func (a *Authenticate) refresh(w http.ResponseWriter, r *http.Request, s *sessions.State) (context.Context, error) { func (a *Authenticate) refresh(w http.ResponseWriter, r *http.Request, s *sessions.State) (context.Context, error) {
ctx, span := trace.StartSpan(r.Context(), "authenticate.VerifySession/refresh") ctx, span := trace.StartSpan(r.Context(), "authenticate.VerifySession/refresh")
defer span.End() defer span.End()
newSession, err := a.provider.Refresh(ctx, s) accessToken, err := a.getAccessToken(ctx, s)
if err != nil {
return nil, err
}
// we are going to keep the same audiences for the refreshed token
// otherwise this will be rewritten to be the ClientID of the provider
oldAudience := s.Audience
newAccessToken, err := a.provider.Refresh(ctx, accessToken, s)
if err != nil { if err != nil {
return nil, fmt.Errorf("authenticate: refresh failed: %w", err) return nil, fmt.Errorf("authenticate: refresh failed: %w", err)
} }
if err := a.sessionStore.SaveSession(w, r, newSession); err != nil {
return nil, fmt.Errorf("authenticate: refresh save failed: %w", err) newSession := sessions.NewSession(s, a.RedirectURL.Hostname(), oldAudience, newAccessToken)
}
newSession = newSession.NewSession(s.Issuer, s.Audience) encSession, err := a.sharedEncoder.Marshal(newSession)
encSession, err := a.encryptedEncoder.Marshal(newSession)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := a.sessionStore.SaveSession(w, r, newSession); err != nil {
return nil, fmt.Errorf("authenticate: error saving new session: %w", err)
}
if err := a.setAccessToken(ctx, newAccessToken); err != nil {
return nil, fmt.Errorf("authenticate: error saving refreshed access token: %w", err)
}
// return the new session and add it to the current request context // return the new session and add it to the current request context
return sessions.NewContext(ctx, string(encSession), err), nil return sessions.NewContext(ctx, string(encSession), err), nil
} }
@ -129,8 +143,11 @@ func (a *Authenticate) RobotsTxt(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "User-agent: *\nDisallow: /") fmt.Fprintf(w, "User-agent: *\nDisallow: /")
} }
// SignIn handles to authenticating a user. // SignIn handles authenticating a user.
func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error { func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
ctx, span := trace.StartSpan(r.Context(), "authenticate.SignOut")
defer span.End()
redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI)) redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI))
if err != nil { if err != nil {
return httputil.NewError(http.StatusBadRequest, err) return httputil.NewError(http.StatusBadRequest, err)
@ -158,32 +175,30 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
jwtAudience = append(jwtAudience, fwdAuth) jwtAudience = append(jwtAudience, fwdAuth)
} }
jwt, err := sessions.FromContext(r.Context()) s, err := a.getSessionFromCtx(ctx)
if err != nil { if err != nil {
return httputil.NewError(http.StatusBadRequest, err) return err
} }
var s sessions.State accessToken, err := a.getAccessToken(ctx, s)
if err := a.encryptedEncoder.Unmarshal([]byte(jwt), &s); err != nil { if err != nil {
return httputil.NewError(http.StatusBadRequest, err) return err
} }
// user impersonation // user impersonation
if impersonate := r.FormValue(urlutil.QueryImpersonateAction); impersonate != "" { if impersonate := r.FormValue(urlutil.QueryImpersonateAction); impersonate != "" {
s.SetImpersonation(r.FormValue(urlutil.QueryImpersonateEmail), r.FormValue(urlutil.QueryImpersonateGroups)) s.SetImpersonation(r.FormValue(urlutil.QueryImpersonateEmail), r.FormValue(urlutil.QueryImpersonateGroups))
} }
newSession := sessions.NewSession(s, a.RedirectURL.Host, jwtAudience, accessToken)
// re-persist the session, useful when session was evicted from session // re-persist the session, useful when session was evicted from session
if err := a.sessionStore.SaveSession(w, r, &s); err != nil { if err := a.sessionStore.SaveSession(w, r, s); err != nil {
return httputil.NewError(http.StatusBadRequest, err) return httputil.NewError(http.StatusBadRequest, err)
} }
newSession := s.NewSession(a.RedirectURL.Host, jwtAudience)
callbackParams := callbackURL.Query() callbackParams := callbackURL.Query()
if r.FormValue(urlutil.QueryIsProgrammatic) == "true" { if r.FormValue(urlutil.QueryIsProgrammatic) == "true" {
newSession.Programmatic = true newSession.Programmatic = true
encSession, err := a.encryptedEncoder.Marshal(newSession) encSession, err := a.encryptedEncoder.Marshal(accessToken)
if err != nil { if err != nil {
return httputil.NewError(http.StatusBadRequest, err) return httputil.NewError(http.StatusBadRequest, err)
} }
@ -192,7 +207,7 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
} }
// sign the route session, as a JWT // sign the route session, as a JWT
signedJWT, err := a.sharedEncoder.Marshal(newSession.RouteSession()) signedJWT, err := a.sharedEncoder.Marshal(newSession)
if err != nil { if err != nil {
return httputil.NewError(http.StatusBadRequest, err) return httputil.NewError(http.StatusBadRequest, err)
} }
@ -217,27 +232,12 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
// SignOut signs the user out and attempts to revoke the user's identity session // SignOut signs the user out and attempts to revoke the user's identity session
// Handles both GET and POST. // Handles both GET and POST.
func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error { func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error {
// no matter what happens, we want to clear the local session store ctx, span := trace.StartSpan(r.Context(), "authenticate.SignOut")
defer span.End()
// no matter what happens, we want to clear the session store
a.sessionStore.ClearSession(w, r) a.sessionStore.ClearSession(w, r)
jwt, err := sessions.FromContext(r.Context())
if err != nil {
return httputil.NewError(http.StatusBadRequest, err)
}
var s sessions.State
if err := a.encryptedEncoder.Unmarshal([]byte(jwt), &s); err != nil {
return httputil.NewError(http.StatusBadRequest, err)
}
redirectString := r.FormValue(urlutil.QueryRedirectURI) redirectString := r.FormValue(urlutil.QueryRedirectURI)
// first, try to revoke the session if implemented
err = a.provider.Revoke(r.Context(), s.AccessToken)
if err != nil && !errors.Is(err, oidc.ErrRevokeNotImplemented) {
return httputil.NewError(http.StatusBadRequest, err)
}
// next, try to build a logout url if implemented
endSessionURL, err := a.provider.LogOut() endSessionURL, err := a.provider.LogOut()
if err == nil { if err == nil {
params := url.Values{} params := url.Values{}
@ -245,14 +245,29 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error {
endSessionURL.RawQuery = params.Encode() endSessionURL.RawQuery = params.Encode()
redirectString = endSessionURL.String() redirectString = endSessionURL.String()
} else if !errors.Is(err, oidc.ErrSignoutNotImplemented) { } else if !errors.Is(err, oidc.ErrSignoutNotImplemented) {
return httputil.NewError(http.StatusBadRequest, err) log.Warn().Err(err).Msg("authenticate.SignOut: failed getting session")
} }
redirectURL, err := urlutil.ParseAndValidateURL(redirectString) httputil.Redirect(w, r, redirectString, http.StatusFound)
s, err := a.getSessionFromCtx(ctx)
if err != nil { if err != nil {
return httputil.NewError(http.StatusBadRequest, err) log.Warn().Err(err).Msg("authenticate.SignOut: failed getting session")
return nil
}
accessToken, err := a.getAccessToken(ctx, s)
if err != nil {
log.Warn().Err(err).Msg("authenticate.SignOut: failed getting access token")
return nil
}
// first, try to revoke the session if implemented
err = a.provider.Revoke(ctx, accessToken)
if err != nil && !errors.Is(err, oidc.ErrRevokeNotImplemented) {
log.Warn().Err(err).Msg("authenticate.SignOut: failed revoking token")
return nil
} }
httputil.Redirect(w, r, redirectURL.String(), http.StatusFound)
return nil return nil
} }
@ -267,7 +282,8 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error {
// https://tools.ietf.org/html/rfc6749#section-4.2.1 // https://tools.ietf.org/html/rfc6749#section-4.2.1
// https://developer.mozilla.org/en-US/docs/Web/API/XMLHttpRequest // https://developer.mozilla.org/en-US/docs/Web/API/XMLHttpRequest
func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Request, err error) error { func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Request, err error) error {
// If request AJAX/XHR request, return a 401 instead . // If request AJAX/XHR request, return a 401 instead because the redirect
// will almost certainly violate their CORs policy
if reqType := r.Header.Get("X-Requested-With"); strings.EqualFold(reqType, "XmlHttpRequest") { if reqType := r.Header.Get("X-Requested-With"); strings.EqualFold(reqType, "XmlHttpRequest") {
return httputil.NewError(http.StatusUnauthorized, err) return httputil.NewError(http.StatusUnauthorized, err)
} }
@ -290,7 +306,7 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque
func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) error { func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) error {
redirect, err := a.getOAuthCallback(w, r) redirect, err := a.getOAuthCallback(w, r)
if err != nil { if err != nil {
return fmt.Errorf("oauth callback : %w", err) return fmt.Errorf("authenticate.OAuthCallback: %w", err)
} }
httputil.Redirect(w, r, redirect.String(), http.StatusFound) httputil.Redirect(w, r, redirect.String(), http.StatusFound)
return nil return nil
@ -306,6 +322,9 @@ func (a *Authenticate) statusForErrorCode(errorCode string) int {
} }
func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (*url.URL, error) { func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (*url.URL, error) {
ctx, span := trace.StartSpan(r.Context(), "authenticate.getOAuthCallback")
defer span.End()
// Error Authentication Response: rfc6749#section-4.1.2.1 & OIDC#3.1.2.6 // Error Authentication Response: rfc6749#section-4.1.2.1 & OIDC#3.1.2.6
// //
// first, check if the identity provider returned an error // first, check if the identity provider returned an error
@ -321,14 +340,22 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
// Successful Authentication Response: rfc6749#section-4.1.2 & OIDC#3.1.2.5 // Successful Authentication Response: rfc6749#section-4.1.2 & OIDC#3.1.2.5
// //
// Exchange the supplied Authorization Code for a valid user session. // Exchange the supplied Authorization Code for a valid user session.
session, err := a.provider.Authenticate(r.Context(), code) var s sessions.State
accessToken, err := a.provider.Authenticate(ctx, code, &s)
if err != nil { if err != nil {
return nil, fmt.Errorf("error redeeming authenticate code: %w", err) return nil, fmt.Errorf("error redeeming authenticate code: %w", err)
} }
newState := sessions.NewSession(
&s,
a.RedirectURL.Hostname(),
[]string{a.RedirectURL.Hostname()},
accessToken)
// state includes a csrf nonce (validated by middleware) and redirect uri // state includes a csrf nonce (validated by middleware) and redirect uri
bytes, err := base64.URLEncoding.DecodeString(r.FormValue("state")) bytes, err := base64.URLEncoding.DecodeString(r.FormValue("state"))
if err != nil { if err != nil {
return nil, httputil.NewError(http.StatusBadRequest, err) return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("bad bytes: %w", err))
} }
// split state into concat'd components // split state into concat'd components
@ -357,8 +384,13 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
return nil, httputil.NewError(http.StatusBadRequest, err) return nil, httputil.NewError(http.StatusBadRequest, err)
} }
// OK. Looks good so let's persist our user session // Ok -- We've got a valid session here. Let's now persist the access
if err := a.sessionStore.SaveSession(w, r, session); err != nil { // token to cache ...
if err := a.setAccessToken(ctx, accessToken); err != nil {
return nil, fmt.Errorf("failed saving access token: %w", err)
}
// ... and the user state to local storage.
if err := a.sessionStore.SaveSession(w, r, &newState); err != nil {
return nil, fmt.Errorf("failed saving new session: %w", err) return nil, fmt.Errorf("failed saving new session: %w", err)
} }
return redirectURL, nil return redirectURL, nil
@ -368,26 +400,32 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
// tokens and state with the identity provider. If successful, a new signed JWT // tokens and state with the identity provider. If successful, a new signed JWT
// and refresh token (`refresh_token`) are returned as JSON // and refresh token (`refresh_token`) are returned as JSON
func (a *Authenticate) RefreshAPI(w http.ResponseWriter, r *http.Request) error { func (a *Authenticate) RefreshAPI(w http.ResponseWriter, r *http.Request) error {
jwt, err := sessions.FromContext(r.Context()) ctx, span := trace.StartSpan(r.Context(), "authenticate.RefreshAPI")
if err != nil { defer span.End()
return httputil.NewError(http.StatusBadRequest, err)
}
var s sessions.State
if err := a.encryptedEncoder.Unmarshal([]byte(jwt), &s); err != nil {
return httputil.NewError(http.StatusBadRequest, err)
}
newSession, err := a.provider.Refresh(r.Context(), &s)
if err != nil {
return err
}
newSession = newSession.NewSession(s.Issuer, s.Audience)
encSession, err := a.encryptedEncoder.Marshal(newSession) s, err := a.getSessionFromCtx(ctx)
if err != nil { if err != nil {
return err return err
} }
signedJWT, err := a.sharedEncoder.Marshal(newSession.RouteSession()) accessToken, err := a.getAccessToken(ctx, s)
if err != nil {
return err
}
newAccessToken, err := a.provider.Refresh(ctx, accessToken, s)
if err != nil {
return err
}
routeNewSession := sessions.NewSession(s, a.RedirectURL.Hostname(), s.Audience, newAccessToken)
encSession, err := a.encryptedEncoder.Marshal(accessToken)
if err != nil {
return err
}
signedJWT, err := a.sharedEncoder.Marshal(routeNewSession)
if err != nil { if err != nil {
return err return err
} }
@ -410,28 +448,79 @@ func (a *Authenticate) RefreshAPI(w http.ResponseWriter, r *http.Request) error
// Refresh is called by the proxy service to handle backend session refresh. // Refresh is called by the proxy service to handle backend session refresh.
// //
// NOTE: The actual refresh is handled as part of the "VerifySession" // NOTE: The actual refresh is handled as part of the "VerifySession"
// middleware. This handler is responsible for creating a new route scoped // middleware. This handler is simply responsible for returning that jwt.
// session and returning it.
func (a *Authenticate) Refresh(w http.ResponseWriter, r *http.Request) error { func (a *Authenticate) Refresh(w http.ResponseWriter, r *http.Request) error {
jwt, err := sessions.FromContext(r.Context()) ctx, span := trace.StartSpan(r.Context(), "authenticate.Refresh")
defer span.End()
jwt, err := sessions.FromContext(ctx)
if err != nil { if err != nil {
return httputil.NewError(http.StatusBadRequest, err) return fmt.Errorf("authenticate.Refresh: %w", err)
} }
var s sessions.State w.Header().Set("Content-Type", "application/jwt") // RFC 7519 : 10.3.1
if err := a.encryptedEncoder.Unmarshal([]byte(jwt), &s); err != nil { fmt.Fprint(w, jwt)
return httputil.NewError(http.StatusBadRequest, err) return nil
} }
aud := strings.Split(r.FormValue(urlutil.QueryAudience), ",") // getAccessToken gets an associated oauth2 access token from a session state
routeSession := s.NewSession(r.Host, aud) func (a *Authenticate) getAccessToken(ctx context.Context, s *sessions.State) (*oauth2.Token, error) {
routeSession.AccessTokenID = s.AccessTokenID ctx, span := trace.StartSpan(ctx, "authenticate.getAccessToken")
defer span.End()
signedJWT, err := a.sharedEncoder.Marshal(routeSession.RouteSession()) var accessToken oauth2.Token
tokenBytes, err := a.cacheClient.Get(ctx, s.AccessTokenHash)
if err != nil {
return nil, err
}
if err := a.encryptedEncoder.Unmarshal(tokenBytes, &accessToken); err != nil {
return nil, err
}
if accessToken.Valid() {
return &accessToken, nil // this token is still valid, use it!
}
tokenBytes, err = a.cacheClient.Get(ctx, a.timestampedHash(accessToken.RefreshToken))
if err == nil {
// we found another possibly newer access token associated with the
// existing refresh token so let's try that.
if err := a.encryptedEncoder.Unmarshal(tokenBytes, &accessToken); err != nil {
return nil, err
}
}
return &accessToken, nil
}
func (a *Authenticate) setAccessToken(ctx context.Context, accessToken *oauth2.Token) error {
encToken, err := a.encryptedEncoder.Marshal(accessToken)
if err != nil { if err != nil {
return err return err
} }
// set this specific access token
key := fmt.Sprintf("%x", hashutil.Hash(accessToken))
if err := a.cacheClient.Set(ctx, key, encToken); err != nil {
return fmt.Errorf("authenticate: setAccessToken failed key: %s :%w", key, err)
}
// set this as the "latest" token for this access token
key = a.timestampedHash(accessToken.RefreshToken)
if err := a.cacheClient.Set(ctx, key, encToken); err != nil {
return fmt.Errorf("authenticate: setAccessToken failed key: %s :%w", key, err)
}
w.Header().Set("Content-Type", "application/jwt") // RFC 7519 : 10.3.1
w.Write(signedJWT)
return nil return nil
} }
func (a *Authenticate) timestampedHash(s string) string {
return fmt.Sprintf("%x-%v", hashutil.Hash(s), time.Now().Truncate(time.Minute).Unix())
}
func (a *Authenticate) getSessionFromCtx(ctx context.Context) (*sessions.State, error) {
jwt, err := sessions.FromContext(ctx)
if err != nil {
return nil, httputil.NewError(http.StatusBadRequest, err)
}
var s sessions.State
if err := a.sharedEncoder.Unmarshal([]byte(jwt), &s); err != nil {
return nil, httputil.NewError(http.StatusBadRequest, err)
}
return &s, nil
}

View file

@ -11,20 +11,24 @@ import (
"testing" "testing"
"time" "time"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/cryptutil" "github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/encoding" "github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/encoding/jws" "github.com/pomerium/pomerium/internal/encoding/jws"
"github.com/pomerium/pomerium/internal/encoding/mock" "github.com/pomerium/pomerium/internal/encoding/mock"
"github.com/pomerium/pomerium/internal/frontend" "github.com/pomerium/pomerium/internal/frontend"
mock_cache "github.com/pomerium/pomerium/internal/grpc/cache/mock"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/identity" "github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/internal/identity/oidc" "github.com/pomerium/pomerium/internal/identity/oidc"
"github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/sessions/cookie" "github.com/pomerium/pomerium/internal/sessions/cookie"
mstore "github.com/pomerium/pomerium/internal/sessions/mock" mstore "github.com/pomerium/pomerium/internal/sessions/mock"
"github.com/google/go-cmp/cmp"
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/urlutil"
"github.com/golang/mock/gomock"
"github.com/google/go-cmp/cmp"
"github.com/gorilla/mux"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2/jwt" "gopkg.in/square/go-jose.v2/jwt"
@ -115,23 +119,28 @@ func TestAuthenticate_SignIn(t *testing.T) {
encoder encoding.MarshalUnmarshaler encoder encoding.MarshalUnmarshaler
wantCode int wantCode int
}{ }{
{"good", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, {"good", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"good alternate port", "https", "corp.example.example:8443", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, {"good alternate port", "https", "corp.example.example:8443", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"session not valid", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(-10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, {"session not valid", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"bad redirect uri query", "", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "^^^"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest}, {"bad redirect uri query", "", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "^^^"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
{"bad marshal", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest}, {"bad marshal", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
{"session error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{LoadError: errors.New("error")}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest}, {"session error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{LoadError: errors.New("error")}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
{"good with different programmatic redirect", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, {"good with different programmatic redirect", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"encrypted encoder error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest}, {"encrypted encoder error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
{"good with callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "https://some.example/", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, {"good with callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "https://some.example/", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"bad callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "^", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest}, {"bad callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "^", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
{"good programmatic request", "https", "corp.example.example", map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, {"good programmatic request", "https", "corp.example.example", map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"good additional audience", "https", "corp.example.example", map[string]string{urlutil.QueryForwardAuth: "x.y.z", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, {"good additional audience", "https", "corp.example.example", map[string]string{urlutil.QueryForwardAuth: "x.y.z", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"good user impersonate", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, {"good user impersonate", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"bad user impersonate save failure", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest}, {"bad user impersonate save failure", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mc := mock_cache.NewMockCacher(ctrl)
mc.EXPECT().Get(gomock.Any(), gomock.Any()).Return([]byte("hi"), nil).AnyTimes()
a := &Authenticate{ a := &Authenticate{
sessionStore: tt.session, sessionStore: tt.session,
provider: tt.provider, provider: tt.provider,
@ -144,6 +153,7 @@ func TestAuthenticate_SignIn(t *testing.T) {
Name: "cookie", Name: "cookie",
Domain: "foo", Domain: "foo",
}, },
cacheClient: mc,
} }
uri := &url.URL{Scheme: tt.scheme, Host: tt.host} uri := &url.URL{Scheme: tt.scheme, Host: tt.host}
@ -176,6 +186,7 @@ func uriParseHelper(s string) *url.URL {
func TestAuthenticate_SignOut(t *testing.T) { func TestAuthenticate_SignOut(t *testing.T) {
t.Parallel() t.Parallel()
tests := []struct { tests := []struct {
name string name string
method string method string
@ -190,19 +201,24 @@ func TestAuthenticate_SignOut(t *testing.T) {
wantCode int wantCode int
wantBody string wantBody string
}{ }{
{"good post", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{LogOutResponse: (*uriParseHelper("https://microsoft.com"))}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusFound, ""}, {"good post", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{LogOutResponse: (*uriParseHelper("https://microsoft.com"))}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io"}}, http.StatusFound, ""},
{"failed revoke", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: OH NO\"}\n"}, {"failed revoke", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io"}}, http.StatusFound, ""},
{"load session error", http.MethodPost, errors.New("error"), "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: error\"}\n"}, {"load session error", http.MethodPost, errors.New("error"), "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io"}}, http.StatusFound, ""},
{"bad redirect uri", http.MethodPost, nil, "corp.pomerium.io/", "sig", "ts", identity.MockProvider{LogOutError: oidc.ErrSignoutNotImplemented}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: corp.pomerium.io/ url does contain a valid scheme\"}\n"}, {"bad redirect uri", http.MethodPost, nil, "corp.pomerium.io/", "sig", "ts", identity.MockProvider{LogOutError: oidc.ErrSignoutNotImplemented}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io"}}, http.StatusFound, ""},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mc := mock_cache.NewMockCacher(ctrl)
mc.EXPECT().Get(gomock.Any(), gomock.Any()).Return([]byte("hi"), nil).AnyTimes()
a := &Authenticate{ a := &Authenticate{
sessionStore: tt.sessionStore, sessionStore: tt.sessionStore,
provider: tt.provider, provider: tt.provider,
encryptedEncoder: mock.Encoder{}, encryptedEncoder: mock.Encoder{},
templates: template.Must(frontend.NewTemplates()), templates: template.Must(frontend.NewTemplates()),
sharedEncoder: mock.Encoder{},
cacheClient: mc,
} }
u, _ := url.Parse("/sign_out") u, _ := url.Parse("/sign_out")
params, _ := url.ParseQuery(u.RawQuery) params, _ := url.ParseQuery(u.RawQuery)
@ -256,40 +272,50 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
want string want string
wantCode int wantCode int
}{ }{
{"good", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusFound}, {"good", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusFound},
{"failed authenticate", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateError: errors.New("error")}, "", http.StatusInternalServerError}, {"failed authenticate", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}, AuthenticateError: errors.New("error")}, "", http.StatusInternalServerError},
{"failed save session", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{SaveError: errors.New("error")}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusInternalServerError}, {"failed save session", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{SaveError: errors.New("error")}, identity.MockProvider{}, "", http.StatusInternalServerError},
{"provider returned error", http.MethodGet, time.Now().Unix(), "", "", "", "idp error", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest}, {"provider returned error", http.MethodGet, time.Now().Unix(), "", "", "", "idp error", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "", http.StatusBadRequest},
{"provider returned error imply 401", http.MethodGet, time.Now().Unix(), "", "", "", "access_denied", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusUnauthorized}, {"provider returned error imply 401", http.MethodGet, time.Now().Unix(), "", "", "", "access_denied", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "", http.StatusUnauthorized},
{"empty code", http.MethodGet, time.Now().Unix(), "", "", "", "", "", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest}, {"empty code", http.MethodGet, time.Now().Unix(), "", "", "", "", "", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "", http.StatusBadRequest},
{"invalid redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest}, {"invalid redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "", http.StatusBadRequest},
{"bad redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "http://^^^", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest}, {"bad redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "http://^^^", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad timing - too soon", http.MethodGet, time.Now().Add(1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest}, {"bad timing - too soon", http.MethodGet, time.Now().Add(1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad timing - expired", http.MethodGet, time.Now().Add(-1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest}, {"bad timing - expired", http.MethodGet, time.Now().Add(-1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad base64", http.MethodGet, time.Now().Unix(), "", "", "^", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest}, {"bad base64", http.MethodGet, time.Now().Unix(), "", "", "^", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"too many seperators", http.MethodGet, time.Now().Unix(), "", "", "|ok|now|what", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest}, {"too many seperators", http.MethodGet, time.Now().Unix(), "", "", "|ok|now|what", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad hmac", http.MethodGet, time.Now().Unix(), "", "NOTMAC", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest}, {"bad hmac", http.MethodGet, time.Now().Unix(), "", "NOTMAC", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad hmac", http.MethodGet, time.Now().Unix(), base64.URLEncoding.EncodeToString([]byte("malformed_state")), "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest}, {"bad hmac", http.MethodGet, time.Now().Unix(), base64.URLEncoding.EncodeToString([]byte("malformed_state")), "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mc := mock_cache.NewMockCacher(ctrl)
mc.EXPECT().Get(gomock.Any(), gomock.Any()).Return([]byte("hi"), nil).AnyTimes()
mc.EXPECT().Set(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
aead, err := chacha20poly1305.NewX(cryptutil.NewKey()) aead, err := chacha20poly1305.NewX(cryptutil.NewKey())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
signer, err := jws.NewHS256Signer(nil, "mock")
if err != nil {
t.Fatal(err)
}
authURL, _ := url.Parse(tt.authenticateURL) authURL, _ := url.Parse(tt.authenticateURL)
a := &Authenticate{ a := &Authenticate{
RedirectURL: authURL, RedirectURL: authURL,
sessionStore: tt.session, sessionStore: tt.session,
provider: tt.provider, provider: tt.provider,
cookieCipher: aead, cookieCipher: aead,
cacheClient: mc,
encryptedEncoder: signer,
} }
u, _ := url.Parse("/oauthGet") u, _ := url.Parse("/oauthGet")
params, _ := url.ParseQuery(u.RawQuery) params, _ := url.ParseQuery(u.RawQuery)
params.Add("error", tt.paramErr) params.Add("error", tt.paramErr)
params.Add("code", tt.code) params.Add("code", tt.code)
nonce := cryptutil.NewBase64Key() // mock csrf nonce := cryptutil.NewBase64Key() // mock csrf
// (nonce|timestamp|redirect_url|encrypt(redirect_url),mac(nonce,ts)) // (nonce|timestamp|redirect_url|encrypt(redirect_url),mac(nonce,ts))
b := []byte(fmt.Sprintf("%s|%d|%s", nonce, tt.ts, tt.extraMac)) b := []byte(fmt.Sprintf("%s|%d|%s", nonce, tt.ts, tt.extraMac))
@ -336,15 +362,21 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
wantStatus int wantStatus int
}{ }{
{"good", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusOK}, {"good", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, http.StatusOK},
{"invalid session", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound}, {"invalid session", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound},
{"good refresh expired", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusOK}, {"good refresh expired", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, http.StatusOK},
{"expired,refresh error", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusFound}, {"expired,refresh error", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusFound},
{"expired,save error", nil, &mstore.Store{SaveError: errors.New("error"), Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusFound}, {"expired,save error", nil, &mstore.Store{SaveError: errors.New("error"), Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, http.StatusFound},
{"expired XHR,refresh error", map[string]string{"X-Requested-With": "XmlHttpRequest"}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusUnauthorized}, {"expired XHR,refresh error", map[string]string{"X-Requested-With": "XmlHttpRequest"}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusUnauthorized},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mc := mock_cache.NewMockCacher(ctrl)
mc.EXPECT().Get(gomock.Any(), gomock.Any()).Return([]byte("hi"), nil).AnyTimes()
mc.EXPECT().Set(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
aead, err := chacha20poly1305.NewX(cryptutil.NewKey()) aead, err := chacha20poly1305.NewX(cryptutil.NewKey())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -361,6 +393,8 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
provider: tt.provider, provider: tt.provider,
cookieCipher: aead, cookieCipher: aead,
encryptedEncoder: signer, encryptedEncoder: signer,
cacheClient: mc,
sharedEncoder: mock.Encoder{},
} }
r := httptest.NewRequest("GET", "/", nil) r := httptest.NewRequest("GET", "/", nil)
state, err := tt.session.LoadSession(r) state, err := tt.session.LoadSession(r)
@ -402,14 +436,20 @@ func TestAuthenticate_RefreshAPI(t *testing.T) {
wantStatus int wantStatus int
}{ }{
{"good", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusOK}, {"good", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusOK},
{"refresh error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshError: errors.New("error")}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError}, {"refresh error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshError: errors.New("error")}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError},
{"session is not refreshable error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("session error"), identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusBadRequest}, {"session is not refreshable error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("session error"), identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusBadRequest},
{"secret encoder failed", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalError: errors.New("error")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError}, {"secret encoder failed", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, mock.Encoder{MarshalError: errors.New("error")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError},
{"shared encoder failed", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalError: errors.New("error")}, http.StatusInternalServerError}, {"shared encoder failed", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalError: errors.New("error")}, http.StatusInternalServerError},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mc := mock_cache.NewMockCacher(ctrl)
mc.EXPECT().Get(gomock.Any(), gomock.Any()).Return([]byte("hi"), nil).AnyTimes()
mc.EXPECT().Set(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
aead, err := chacha20poly1305.NewX(cryptutil.NewKey()) aead, err := chacha20poly1305.NewX(cryptutil.NewKey())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -423,6 +463,7 @@ func TestAuthenticate_RefreshAPI(t *testing.T) {
sessionStore: tt.session, sessionStore: tt.session,
provider: tt.provider, provider: tt.provider,
cookieCipher: aead, cookieCipher: aead,
cacheClient: mc,
} }
r := httptest.NewRequest("GET", "/", nil) r := httptest.NewRequest("GET", "/", nil)
state, _ := tt.session.LoadSession(r) state, _ := tt.session.LoadSession(r)
@ -441,53 +482,111 @@ func TestAuthenticate_RefreshAPI(t *testing.T) {
}) })
} }
} }
func TestAuthenticate_Refresh(t *testing.T) { func TestAuthenticate_Refresh(t *testing.T) {
t.Parallel() t.Parallel()
tests := []struct { tests := []struct {
name string name string
session sessions.SessionStore session *sessions.State
ctxError error at *oauth2.Token
provider identity.Authenticator provider identity.Authenticator
secretEncoder encoding.MarshalUnmarshaler secretEncoder encoding.MarshalUnmarshaler
sharedEncoder encoding.MarshalUnmarshaler
wantStatus int wantStatus int
}{ }{
{"good", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusOK}, {"good",
{"bad session", &mstore.Store{}, errors.New("err"), identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusBadRequest}, &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))},
{"encoder error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalError: errors.New("err")}, http.StatusInternalServerError}, &oauth2.Token{AccessToken: "mock", Expiry: time.Now().Add(10 * time.Minute)},
identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
mock.Encoder{MarshalResponse: []byte("ok")},
200},
{"session and oauth2 expired",
&sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))},
&oauth2.Token{AccessToken: "mock", Expiry: time.Now().Add(-10 * time.Minute)},
identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
mock.Encoder{MarshalResponse: []byte("ok")},
200},
{"session expired",
&sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))},
&oauth2.Token{AccessToken: "mock", Expiry: time.Now().Add(10 * time.Minute)},
identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
mock.Encoder{MarshalResponse: []byte("ok")},
200},
{"failed refresh",
&sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))},
&oauth2.Token{AccessToken: "mock", Expiry: time.Now().Add(10 * time.Minute)},
identity.MockProvider{RefreshError: errors.New("oh no")},
mock.Encoder{MarshalResponse: []byte("ok")},
302},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
aead, err := chacha20poly1305.NewX(cryptutil.NewKey()) ctrl := gomock.NewController(t)
defer ctrl.Finish()
mc := mock_cache.NewMockCacher(ctrl)
// just enough is stubbed out here so we can use our own mock provider
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.Header().Set("Content-Type", "application/json")
out := fmt.Sprintf(`{"issuer":"http://%s"}`, r.Host)
fmt.Fprintln(w, out)
}))
defer ts.Close()
rURL := ts.URL
a, err := New(config.Options{
SharedKey: cryptutil.NewBase64Key(),
CookieSecret: cryptutil.NewBase64Key(),
AuthenticateURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"),
Provider: "oidc",
ClientID: "mock",
ClientSecret: "mock",
ProviderURL: rURL,
AuthenticateCallbackPath: "mock",
CookieName: "pomerium",
Addr: ":0",
CacheURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"),
})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
a := Authenticate{ a.cacheClient = mc
sharedKey: cryptutil.NewBase64Key(), a.provider = tt.provider
cookieSecret: cryptutil.NewKey(),
RedirectURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"),
encryptedEncoder: tt.secretEncoder,
sharedEncoder: tt.sharedEncoder,
sessionStore: tt.session,
provider: tt.provider,
cookieCipher: aead,
}
r := httptest.NewRequest("GET", "/", nil)
state, _ := tt.session.LoadSession(r)
ctx := r.Context()
ctx = sessions.NewContext(ctx, state, tt.ctxError)
r = r.WithContext(ctx)
u, _ := url.Parse("/oauthGet")
params, _ := url.ParseQuery(u.RawQuery)
destination := urlutil.NewSignedURL(a.sharedKey,
&url.URL{
Scheme: "https",
Host: "example.com",
Path: "/.pomerium/refresh"})
u.RawQuery = params.Encode()
r := httptest.NewRequest(http.MethodGet, destination.String(), nil)
jwt, err := a.sharedEncoder.Marshal(tt.session)
if err != nil {
t.Fatal(err)
}
rawToken, err := a.encryptedEncoder.Marshal(tt.at)
if err != nil {
t.Fatal(err)
}
mc.EXPECT().Get(gomock.Any(), gomock.Any()).Return(rawToken, nil).AnyTimes()
mc.EXPECT().Set(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
a.cacheClient = mc
r.Header.Set("Authorization", fmt.Sprintf("Pomerium %s", jwt))
r.Header.Set("Accept", "application/json") r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
httputil.HandlerFunc(a.Refresh).ServeHTTP(w, r) router := mux.NewRouter()
a.Mount(router)
router.ServeHTTP(w, r)
if status := w.Code; status != tt.wantStatus { if status := w.Code; status != tt.wantStatus {
t.Errorf("VerifySession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String()) t.Errorf("Refresh() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
} }
}) })
} }

View file

@ -1,5 +1,3 @@
//go:generate mockgen -destination mock_evaluator/mock.go github.com/pomerium/pomerium/authorize/evaluator Evaluator
// Package evaluator defines a Evaluator interfaces that can be implemented by // Package evaluator defines a Evaluator interfaces that can be implemented by
// a policy evaluator framework. // a policy evaluator framework.
package evaluator package evaluator

View file

@ -218,10 +218,6 @@ func (a *Authorize) refreshSession(ctx context.Context, rawSession []byte) (newS
// 1 - build a signed url to call refresh on authenticate service // 1 - build a signed url to call refresh on authenticate service
refreshURI := options.AuthenticateURL.ResolveReference(&url.URL{Path: "/.pomerium/refresh"}) refreshURI := options.AuthenticateURL.ResolveReference(&url.URL{Path: "/.pomerium/refresh"})
q := refreshURI.Query()
q.Set(urlutil.QueryAccessTokenID, state.AccessTokenID) // hash value points to parent token
q.Set(urlutil.QueryAudience, strings.Join(state.Audience, ",")) // request's audience, this route
refreshURI.RawQuery = q.Encode()
signedRefreshURL := urlutil.NewSignedURL(options.SharedKey, refreshURI).String() signedRefreshURL := urlutil.NewSignedURL(options.SharedKey, refreshURI).String()
// 2 - http call to authenticate service // 2 - http call to authenticate service
@ -229,6 +225,7 @@ func (a *Authorize) refreshSession(ctx context.Context, rawSession []byte) (newS
if err != nil { if err != nil {
return nil, fmt.Errorf("authorize: refresh request: %w", err) return nil, fmt.Errorf("authorize: refresh request: %w", err)
} }
req.Header.Set("Authorization", fmt.Sprintf("Pomerium %s", rawSession))
req.Header.Set("X-Requested-With", "XmlHttpRequest") req.Header.Set("X-Requested-With", "XmlHttpRequest")
req.Header.Set("Accept", "application/json") req.Header.Set("Accept", "application/json")

18
cache/grpc_test.go vendored
View file

@ -11,6 +11,7 @@ import (
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/cryptutil" "github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/grpc/cache" "github.com/pomerium/pomerium/internal/grpc/cache"
@ -50,9 +51,6 @@ func TestCache_Get_and_Set(t *testing.T) {
&cache.GetReply{ &cache.GetReply{
Exists: true, Exists: true,
Value: []byte("hello"), Value: []byte("hello"),
XXX_NoUnkeyedLiteral: struct{}{},
XXX_unrecognized: nil,
XXX_sizecache: 0,
}, },
false, false,
false, false,
@ -65,9 +63,6 @@ func TestCache_Get_and_Set(t *testing.T) {
&cache.GetReply{ &cache.GetReply{
Exists: false, Exists: false,
Value: nil, Value: nil,
XXX_NoUnkeyedLiteral: struct{}{},
XXX_unrecognized: nil,
XXX_sizecache: 0,
}, },
false, false,
false, false,
@ -80,9 +75,6 @@ func TestCache_Get_and_Set(t *testing.T) {
&cache.GetReply{ &cache.GetReply{
Exists: false, Exists: false,
Value: nil, Value: nil,
XXX_NoUnkeyedLiteral: struct{}{},
XXX_unrecognized: nil,
XXX_sizecache: 0,
}, },
true, true,
false, false,
@ -96,7 +88,11 @@ func TestCache_Get_and_Set(t *testing.T) {
t.Errorf("Cache.Set() error = %v, wantSetError %v", err, tt.wantSetError) t.Errorf("Cache.Set() error = %v, wantSetError %v", err, tt.wantSetError)
return return
} }
if diff := cmp.Diff(setGot, tt.SetReply); diff != "" { cmpOpts := []cmp.Option{
cmpopts.IgnoreUnexported(cache.SetReply{}, cache.GetReply{}),
}
if diff := cmp.Diff(setGot, tt.SetReply, cmpOpts...); diff != "" {
t.Errorf("Cache.Set() = %v", diff) t.Errorf("Cache.Set() = %v", diff)
} }
getGot, err := c.Get(tt.ctx, tt.GetRequest) getGot, err := c.Get(tt.ctx, tt.GetRequest)
@ -104,7 +100,7 @@ func TestCache_Get_and_Set(t *testing.T) {
t.Errorf("Cache.Get() error = %v, wantGetError %v", err, tt.wantGetError) t.Errorf("Cache.Get() error = %v, wantGetError %v", err, tt.wantGetError)
return return
} }
if diff := cmp.Diff(getGot, tt.GetReply); diff != "" { if diff := cmp.Diff(getGot, tt.GetReply, cmpOpts...); diff != "" {
t.Errorf("Cache.Get() = %v", diff) t.Errorf("Cache.Get() = %v", diff)
} }
}) })

View file

@ -5,6 +5,7 @@ import (
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"sort" "sort"
"time"
envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
envoy_config_listener_v3 "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" envoy_config_listener_v3 "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3"
@ -151,6 +152,7 @@ func (srv *Server) buildMainHTTPConnectionManagerFilter(options *config.Options,
}, },
Services: &envoy_extensions_filters_http_ext_authz_v3.ExtAuthz_GrpcService{ Services: &envoy_extensions_filters_http_ext_authz_v3.ExtAuthz_GrpcService{
GrpcService: &envoy_config_core_v3.GrpcService{ GrpcService: &envoy_config_core_v3.GrpcService{
Timeout: ptypes.DurationProto(time.Second * 30),
TargetSpecifier: &envoy_config_core_v3.GrpcService_EnvoyGrpc_{ TargetSpecifier: &envoy_config_core_v3.GrpcService_EnvoyGrpc_{
EnvoyGrpc: &envoy_config_core_v3.GrpcService_EnvoyGrpc{ EnvoyGrpc: &envoy_config_core_v3.GrpcService_EnvoyGrpc{
ClusterName: "pomerium-authz", ClusterName: "pomerium-authz",

14
internal/grpc/cache/cache.go vendored Normal file
View file

@ -0,0 +1,14 @@
// Package cache defines a Cacher interfaces that can be implemented by any
// key value store system.
package cache
import (
"context"
)
// Cacher specifies an interface for remote clients connecting to the cache service.
type Cacher interface {
Get(ctx context.Context, key string) (value []byte, err error)
Set(ctx context.Context, key string, value []byte) error
Close() error
}

View file

@ -1,217 +1,356 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.21.0
// protoc v3.11.4
// source: cache.proto // source: cache.proto
package cache package cache
import ( import (
context "context" context "context"
fmt "fmt"
proto "github.com/golang/protobuf/proto" proto "github.com/golang/protobuf/proto"
grpc "google.golang.org/grpc" grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes" codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status" status "google.golang.org/grpc/status"
math "math" protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
) )
// Reference imports to suppress errors if they are not otherwise used. const (
var _ = proto.Marshal // Verify that this generated code is sufficiently up-to-date.
var _ = fmt.Errorf _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
var _ = math.Inf // Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
// This is a compile-time assertion to ensure that this generated file // This is a compile-time assertion that a sufficiently up-to-date version
// is compatible with the proto package it is being compiled against. // of the legacy proto package is being used.
// A compilation error at this line likely means your copy of the const _ = proto.ProtoPackageIsVersion4
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
type GetRequest struct { type GetRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"` Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
} }
func (m *GetRequest) Reset() { *m = GetRequest{} } func (x *GetRequest) Reset() {
func (m *GetRequest) String() string { return proto.CompactTextString(m) } *x = GetRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_cache_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *GetRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*GetRequest) ProtoMessage() {} func (*GetRequest) ProtoMessage() {}
func (x *GetRequest) ProtoReflect() protoreflect.Message {
mi := &file_cache_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use GetRequest.ProtoReflect.Descriptor instead.
func (*GetRequest) Descriptor() ([]byte, []int) { func (*GetRequest) Descriptor() ([]byte, []int) {
return fileDescriptor_5fca3b110c9bbf3a, []int{0} return file_cache_proto_rawDescGZIP(), []int{0}
} }
func (m *GetRequest) XXX_Unmarshal(b []byte) error { func (x *GetRequest) GetKey() string {
return xxx_messageInfo_GetRequest.Unmarshal(m, b) if x != nil {
} return x.Key
func (m *GetRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_GetRequest.Marshal(b, m, deterministic)
}
func (m *GetRequest) XXX_Merge(src proto.Message) {
xxx_messageInfo_GetRequest.Merge(m, src)
}
func (m *GetRequest) XXX_Size() int {
return xxx_messageInfo_GetRequest.Size(m)
}
func (m *GetRequest) XXX_DiscardUnknown() {
xxx_messageInfo_GetRequest.DiscardUnknown(m)
}
var xxx_messageInfo_GetRequest proto.InternalMessageInfo
func (m *GetRequest) GetKey() string {
if m != nil {
return m.Key
} }
return "" return ""
} }
type GetReply struct { type GetReply struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Exists bool `protobuf:"varint,1,opt,name=exists,proto3" json:"exists,omitempty"` Exists bool `protobuf:"varint,1,opt,name=exists,proto3" json:"exists,omitempty"`
Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"` Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
} }
func (m *GetReply) Reset() { *m = GetReply{} } func (x *GetReply) Reset() {
func (m *GetReply) String() string { return proto.CompactTextString(m) } *x = GetReply{}
if protoimpl.UnsafeEnabled {
mi := &file_cache_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *GetReply) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*GetReply) ProtoMessage() {} func (*GetReply) ProtoMessage() {}
func (x *GetReply) ProtoReflect() protoreflect.Message {
mi := &file_cache_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use GetReply.ProtoReflect.Descriptor instead.
func (*GetReply) Descriptor() ([]byte, []int) { func (*GetReply) Descriptor() ([]byte, []int) {
return fileDescriptor_5fca3b110c9bbf3a, []int{1} return file_cache_proto_rawDescGZIP(), []int{1}
} }
func (m *GetReply) XXX_Unmarshal(b []byte) error { func (x *GetReply) GetExists() bool {
return xxx_messageInfo_GetReply.Unmarshal(m, b) if x != nil {
} return x.Exists
func (m *GetReply) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_GetReply.Marshal(b, m, deterministic)
}
func (m *GetReply) XXX_Merge(src proto.Message) {
xxx_messageInfo_GetReply.Merge(m, src)
}
func (m *GetReply) XXX_Size() int {
return xxx_messageInfo_GetReply.Size(m)
}
func (m *GetReply) XXX_DiscardUnknown() {
xxx_messageInfo_GetReply.DiscardUnknown(m)
}
var xxx_messageInfo_GetReply proto.InternalMessageInfo
func (m *GetReply) GetExists() bool {
if m != nil {
return m.Exists
} }
return false return false
} }
func (m *GetReply) GetValue() []byte { func (x *GetReply) GetValue() []byte {
if m != nil { if x != nil {
return m.Value return x.Value
} }
return nil return nil
} }
type SetRequest struct { type SetRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"` Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"`
Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"` Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
} }
func (m *SetRequest) Reset() { *m = SetRequest{} } func (x *SetRequest) Reset() {
func (m *SetRequest) String() string { return proto.CompactTextString(m) } *x = SetRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_cache_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SetRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*SetRequest) ProtoMessage() {} func (*SetRequest) ProtoMessage() {}
func (x *SetRequest) ProtoReflect() protoreflect.Message {
mi := &file_cache_proto_msgTypes[2]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use SetRequest.ProtoReflect.Descriptor instead.
func (*SetRequest) Descriptor() ([]byte, []int) { func (*SetRequest) Descriptor() ([]byte, []int) {
return fileDescriptor_5fca3b110c9bbf3a, []int{2} return file_cache_proto_rawDescGZIP(), []int{2}
} }
func (m *SetRequest) XXX_Unmarshal(b []byte) error { func (x *SetRequest) GetKey() string {
return xxx_messageInfo_SetRequest.Unmarshal(m, b) if x != nil {
} return x.Key
func (m *SetRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_SetRequest.Marshal(b, m, deterministic)
}
func (m *SetRequest) XXX_Merge(src proto.Message) {
xxx_messageInfo_SetRequest.Merge(m, src)
}
func (m *SetRequest) XXX_Size() int {
return xxx_messageInfo_SetRequest.Size(m)
}
func (m *SetRequest) XXX_DiscardUnknown() {
xxx_messageInfo_SetRequest.DiscardUnknown(m)
}
var xxx_messageInfo_SetRequest proto.InternalMessageInfo
func (m *SetRequest) GetKey() string {
if m != nil {
return m.Key
} }
return "" return ""
} }
func (m *SetRequest) GetValue() []byte { func (x *SetRequest) GetValue() []byte {
if m != nil { if x != nil {
return m.Value return x.Value
} }
return nil return nil
} }
type SetReply struct { type SetReply struct {
XXX_NoUnkeyedLiteral struct{} `json:"-"` state protoimpl.MessageState
XXX_unrecognized []byte `json:"-"` sizeCache protoimpl.SizeCache
XXX_sizecache int32 `json:"-"` unknownFields protoimpl.UnknownFields
}
func (x *SetReply) Reset() {
*x = SetReply{}
if protoimpl.UnsafeEnabled {
mi := &file_cache_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SetReply) String() string {
return protoimpl.X.MessageStringOf(x)
} }
func (m *SetReply) Reset() { *m = SetReply{} }
func (m *SetReply) String() string { return proto.CompactTextString(m) }
func (*SetReply) ProtoMessage() {} func (*SetReply) ProtoMessage() {}
func (x *SetReply) ProtoReflect() protoreflect.Message {
mi := &file_cache_proto_msgTypes[3]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use SetReply.ProtoReflect.Descriptor instead.
func (*SetReply) Descriptor() ([]byte, []int) { func (*SetReply) Descriptor() ([]byte, []int) {
return fileDescriptor_5fca3b110c9bbf3a, []int{3} return file_cache_proto_rawDescGZIP(), []int{3}
} }
func (m *SetReply) XXX_Unmarshal(b []byte) error { var File_cache_proto protoreflect.FileDescriptor
return xxx_messageInfo_SetReply.Unmarshal(m, b)
} var file_cache_proto_rawDesc = []byte{
func (m *SetReply) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { 0x0a, 0x0b, 0x63, 0x61, 0x63, 0x68, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x05, 0x63,
return xxx_messageInfo_SetReply.Marshal(b, m, deterministic) 0x61, 0x63, 0x68, 0x65, 0x22, 0x1e, 0x0a, 0x0a, 0x47, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65,
} 0x73, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52,
func (m *SetReply) XXX_Merge(src proto.Message) { 0x03, 0x6b, 0x65, 0x79, 0x22, 0x38, 0x0a, 0x08, 0x47, 0x65, 0x74, 0x52, 0x65, 0x70, 0x6c, 0x79,
xxx_messageInfo_SetReply.Merge(m, src) 0x12, 0x16, 0x0a, 0x06, 0x65, 0x78, 0x69, 0x73, 0x74, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08,
} 0x52, 0x06, 0x65, 0x78, 0x69, 0x73, 0x74, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75,
func (m *SetReply) XXX_Size() int { 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x22, 0x34,
return xxx_messageInfo_SetReply.Size(m) 0x0a, 0x0a, 0x53, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x10, 0x0a, 0x03,
} 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14,
func (m *SetReply) XXX_DiscardUnknown() { 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x05, 0x76,
xxx_messageInfo_SetReply.DiscardUnknown(m) 0x61, 0x6c, 0x75, 0x65, 0x22, 0x0a, 0x0a, 0x08, 0x53, 0x65, 0x74, 0x52, 0x65, 0x70, 0x6c, 0x79,
0x32, 0x61, 0x0a, 0x05, 0x43, 0x61, 0x63, 0x68, 0x65, 0x12, 0x2b, 0x0a, 0x03, 0x47, 0x65, 0x74,
0x12, 0x11, 0x2e, 0x63, 0x61, 0x63, 0x68, 0x65, 0x2e, 0x47, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75,
0x65, 0x73, 0x74, 0x1a, 0x0f, 0x2e, 0x63, 0x61, 0x63, 0x68, 0x65, 0x2e, 0x47, 0x65, 0x74, 0x52,
0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x2b, 0x0a, 0x03, 0x53, 0x65, 0x74, 0x12, 0x11, 0x2e,
0x63, 0x61, 0x63, 0x68, 0x65, 0x2e, 0x53, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
0x1a, 0x0f, 0x2e, 0x63, 0x61, 0x63, 0x68, 0x65, 0x2e, 0x53, 0x65, 0x74, 0x52, 0x65, 0x70, 0x6c,
0x79, 0x22, 0x00, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
} }
var xxx_messageInfo_SetReply proto.InternalMessageInfo var (
file_cache_proto_rawDescOnce sync.Once
file_cache_proto_rawDescData = file_cache_proto_rawDesc
)
func init() { func file_cache_proto_rawDescGZIP() []byte {
proto.RegisterType((*GetRequest)(nil), "cache.GetRequest") file_cache_proto_rawDescOnce.Do(func() {
proto.RegisterType((*GetReply)(nil), "cache.GetReply") file_cache_proto_rawDescData = protoimpl.X.CompressGZIP(file_cache_proto_rawDescData)
proto.RegisterType((*SetRequest)(nil), "cache.SetRequest") })
proto.RegisterType((*SetReply)(nil), "cache.SetReply") return file_cache_proto_rawDescData
} }
func init() { var file_cache_proto_msgTypes = make([]protoimpl.MessageInfo, 4)
proto.RegisterFile("cache.proto", fileDescriptor_5fca3b110c9bbf3a) var file_cache_proto_goTypes = []interface{}{
(*GetRequest)(nil), // 0: cache.GetRequest
(*GetReply)(nil), // 1: cache.GetReply
(*SetRequest)(nil), // 2: cache.SetRequest
(*SetReply)(nil), // 3: cache.SetReply
}
var file_cache_proto_depIdxs = []int32{
0, // 0: cache.Cache.Get:input_type -> cache.GetRequest
2, // 1: cache.Cache.Set:input_type -> cache.SetRequest
1, // 2: cache.Cache.Get:output_type -> cache.GetReply
3, // 3: cache.Cache.Set:output_type -> cache.SetReply
2, // [2:4] is the sub-list for method output_type
0, // [0:2] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
} }
var fileDescriptor_5fca3b110c9bbf3a = []byte{ func init() { file_cache_proto_init() }
// 176 bytes of a gzipped FileDescriptorProto func file_cache_proto_init() {
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x4e, 0x4e, 0x4c, 0xce, if File_cache_proto != nil {
0x48, 0xd5, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x05, 0x73, 0x94, 0xe4, 0xb8, 0xb8, 0xdc, return
0x53, 0x4b, 0x82, 0x52, 0x0b, 0x4b, 0x53, 0x8b, 0x4b, 0x84, 0x04, 0xb8, 0x98, 0xb3, 0x53, 0x2b, }
0x25, 0x18, 0x15, 0x18, 0x35, 0x38, 0x83, 0x40, 0x4c, 0x25, 0x0b, 0x2e, 0x0e, 0xb0, 0x7c, 0x41, if !protoimpl.UnsafeEnabled {
0x4e, 0xa5, 0x90, 0x18, 0x17, 0x5b, 0x6a, 0x45, 0x66, 0x71, 0x49, 0x31, 0x58, 0x01, 0x47, 0x10, file_cache_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
0x94, 0x27, 0x24, 0xc2, 0xc5, 0x5a, 0x96, 0x98, 0x53, 0x9a, 0x2a, 0xc1, 0xa4, 0xc0, 0xa8, 0xc1, switch v := v.(*GetRequest); i {
0x13, 0x04, 0xe1, 0x28, 0x99, 0x70, 0x71, 0x05, 0xe3, 0x31, 0x19, 0x87, 0x2e, 0x2e, 0x2e, 0x8e, case 0:
0x60, 0xa8, 0x7d, 0x46, 0x89, 0x5c, 0xac, 0xce, 0x20, 0x47, 0x0a, 0x69, 0x73, 0x31, 0xbb, 0xa7, return &v.state
0x96, 0x08, 0x09, 0xea, 0x41, 0x3c, 0x80, 0x70, 0xb0, 0x14, 0x3f, 0xb2, 0x50, 0x41, 0x4e, 0xa5, case 1:
0x12, 0x03, 0x48, 0x71, 0x30, 0x92, 0xe2, 0x60, 0x4c, 0xc5, 0xc1, 0x70, 0xc5, 0x49, 0x6c, 0xe0, return &v.sizeCache
0xc0, 0x30, 0x06, 0x04, 0x00, 0x00, 0xff, 0xff, 0x0e, 0xef, 0x5f, 0x9e, 0x1b, 0x01, 0x00, 0x00, case 2:
return &v.unknownFields
default:
return nil
}
}
file_cache_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*GetReply); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_cache_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SetRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_cache_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SetReply); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_cache_proto_rawDesc,
NumEnums: 0,
NumMessages: 4,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_cache_proto_goTypes,
DependencyIndexes: file_cache_proto_depIdxs,
MessageInfos: file_cache_proto_msgTypes,
}.Build()
File_cache_proto = out.File
file_cache_proto_rawDesc = nil
file_cache_proto_goTypes = nil
file_cache_proto_depIdxs = nil
} }
// Reference imports to suppress errors if they are not otherwise used. // Reference imports to suppress errors if they are not otherwise used.
@ -266,10 +405,10 @@ type CacheServer interface {
type UnimplementedCacheServer struct { type UnimplementedCacheServer struct {
} }
func (*UnimplementedCacheServer) Get(ctx context.Context, req *GetRequest) (*GetReply, error) { func (*UnimplementedCacheServer) Get(context.Context, *GetRequest) (*GetReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method Get not implemented") return nil, status.Errorf(codes.Unimplemented, "method Get not implemented")
} }
func (*UnimplementedCacheServer) Set(ctx context.Context, req *SetRequest) (*SetReply, error) { func (*UnimplementedCacheServer) Set(context.Context, *SetRequest) (*SetReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method Set not implemented") return nil, status.Errorf(codes.Unimplemented, "method Set not implemented")
} }

View file

@ -3,6 +3,7 @@ package client
import ( import (
"context" "context"
"errors"
"github.com/pomerium/pomerium/internal/grpc/cache" "github.com/pomerium/pomerium/internal/grpc/cache"
"github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/telemetry/trace"
@ -10,12 +11,7 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
) )
// Cacher specifies an interface for remote clients connecting to the cache service. var errKeyNotFound = errors.New("cache/client: key not found")
type Cacher interface {
Get(ctx context.Context, key string) (keyExists bool, value []byte, err error)
Set(ctx context.Context, key string, value []byte) error
Close() error
}
// Client represents a gRPC cache service client. // Client represents a gRPC cache service client.
type Client struct { type Client struct {
@ -29,15 +25,18 @@ func New(conn *grpc.ClientConn) (p *Client) {
} }
// Get retrieves a value from the cache service. // Get retrieves a value from the cache service.
func (a *Client) Get(ctx context.Context, key string) (keyExists bool, value []byte, err error) { func (a *Client) Get(ctx context.Context, key string) (value []byte, err error) {
ctx, span := trace.StartSpan(ctx, "grpc.cache.client.Get") ctx, span := trace.StartSpan(ctx, "grpc.cache.client.Get")
defer span.End() defer span.End()
response, err := a.client.Get(ctx, &cache.GetRequest{Key: key}) response, err := a.client.Get(ctx, &cache.GetRequest{Key: key})
if err != nil { if err != nil {
return false, nil, err return nil, err
} }
return response.GetExists(), response.GetValue(), nil if !response.GetExists() {
return nil, errKeyNotFound
}
return response.GetValue(), nil
} }
// Set stores a key value pair in the cache service. // Set stores a key value pair in the cache service.

77
internal/grpc/cache/mock/mock_cacher.go vendored Normal file
View file

@ -0,0 +1,77 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/pomerium/pomerium/internal/grpc/cache (interfaces: Cacher)
// Package mock_cache is a generated GoMock package.
package mock_cache
import (
context "context"
gomock "github.com/golang/mock/gomock"
reflect "reflect"
)
// MockCacher is a mock of Cacher interface
type MockCacher struct {
ctrl *gomock.Controller
recorder *MockCacherMockRecorder
}
// MockCacherMockRecorder is the mock recorder for MockCacher
type MockCacherMockRecorder struct {
mock *MockCacher
}
// NewMockCacher creates a new mock instance
func NewMockCacher(ctrl *gomock.Controller) *MockCacher {
mock := &MockCacher{ctrl: ctrl}
mock.recorder = &MockCacherMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockCacher) EXPECT() *MockCacherMockRecorder {
return m.recorder
}
// Close mocks base method
func (m *MockCacher) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close
func (mr *MockCacherMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockCacher)(nil).Close))
}
// Get mocks base method
func (m *MockCacher) Get(arg0 context.Context, arg1 string) ([]byte, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Get", arg0, arg1)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Get indicates an expected call of Get
func (mr *MockCacherMockRecorder) Get(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockCacher)(nil).Get), arg0, arg1)
}
// Set mocks base method
func (m *MockCacher) Set(arg0 context.Context, arg1 string, arg2 []byte) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Set", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// Set indicates an expected call of Set
func (mr *MockCacherMockRecorder) Set(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Set", reflect.TypeOf((*MockCacher)(nil).Set), arg0, arg1, arg2)
}

View file

@ -0,0 +1,22 @@
// Package hashutil provides NON-CRYPTOGRAPHIC utility functions for hashing
package hashutil
import (
"github.com/cespare/xxhash/v2"
"github.com/mitchellh/hashstructure"
)
// Hash returns the xxhash value of an arbitrary value or struct. Returns 0
// on error. NOT SUITABLE FOR CRYTOGRAPHIC HASHING.
//
// http://cyan4973.github.io/xxHash/
func Hash(v interface{}) uint64 {
opts := &hashstructure.HashOptions{
Hasher: xxhash.New(),
}
hash, err := hashstructure.Hash(v, opts)
if err != nil {
hash = 0
}
return hash
}

View file

@ -0,0 +1,37 @@
// Package hashutil provides NON-CRYPTOGRAPHIC utility functions for hashing
package hashutil
import "testing"
func TestHash(t *testing.T) {
t.Parallel()
tests := []struct {
name string
v interface{}
want uint64
}{
{"string", "string", 6134271061086542852},
{"num", 7, 609900476111905877},
{"compound struct", struct {
NESCarts []string
numberOfCarts int
}{
[]string{"Battletoads", "Mega Man 1", "Clash at Demonhead"},
12,
},
9061978360207659575},
{"compound struct with embedded func (errors!)", struct {
AnswerToEverythingFn func() int
}{
func() int { return 42 },
},
0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := Hash(tt.v); got != tt.want {
t.Errorf("Hash() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -5,15 +5,13 @@ import (
"net/url" "net/url"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"github.com/pomerium/pomerium/internal/sessions"
) )
// MockProvider provides a mocked implementation of the providers interface. // MockProvider provides a mocked implementation of the providers interface.
type MockProvider struct { type MockProvider struct {
AuthenticateResponse sessions.State AuthenticateResponse oauth2.Token
AuthenticateError error AuthenticateError error
RefreshResponse sessions.State RefreshResponse oauth2.Token
RefreshError error RefreshError error
RevokeError error RevokeError error
GetSignInURLResponse string GetSignInURLResponse string
@ -22,12 +20,12 @@ type MockProvider struct {
} }
// Authenticate is a mocked providers function. // Authenticate is a mocked providers function.
func (mp MockProvider) Authenticate(ctx context.Context, code string) (*sessions.State, error) { func (mp MockProvider) Authenticate(context.Context, string, interface{}) (*oauth2.Token, error) {
return &mp.AuthenticateResponse, mp.AuthenticateError return &mp.AuthenticateResponse, mp.AuthenticateError
} }
// Refresh is a mocked providers function. // Refresh is a mocked providers function.
func (mp MockProvider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) { func (mp MockProvider) Refresh(context.Context, *oauth2.Token, interface{}) (*oauth2.Token, error) {
return &mp.RefreshResponse, mp.RefreshError return &mp.RefreshResponse, mp.RefreshError
} }

View file

@ -20,7 +20,6 @@ import (
"github.com/pomerium/pomerium/internal/identity/oauth" "github.com/pomerium/pomerium/internal/identity/oauth"
"github.com/pomerium/pomerium/internal/identity/oidc" "github.com/pomerium/pomerium/internal/identity/oidc"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/version" "github.com/pomerium/pomerium/internal/version"
) )
@ -77,48 +76,36 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
// Authenticate creates an identity session with github from a authorization code, and follows up // Authenticate creates an identity session with github from a authorization code, and follows up
// call to the user and user group endpoint with the // call to the user and user group endpoint with the
func (p *Provider) Authenticate(ctx context.Context, code string) (*sessions.State, error) { func (p *Provider) Authenticate(ctx context.Context, code string, v interface{}) (*oauth2.Token, error) {
resp, err := p.Oauth.Exchange(ctx, code) oauth2Token, err := p.Oauth.Exchange(ctx, code)
if err != nil { if err != nil {
return nil, fmt.Errorf("github: token exchange failed %v", err) return nil, fmt.Errorf("github: token exchange failed %v", err)
} }
s := &sessions.State{ err = p.updateSessionState(ctx, oauth2Token, v)
AccessToken: &oauth2.Token{
AccessToken: resp.AccessToken,
TokenType: resp.TokenType,
},
AccessTokenID: resp.AccessToken,
}
err = p.updateSessionState(ctx, s)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return s, nil return oauth2Token, nil
} }
// updateSessionState will get the user information from github and also retrieve the user's team(s) // updateSessionState will get the user information from github and also retrieve the user's team(s)
// //
// https://developer.github.com/v3/users/#get-the-authenticated-user // https://developer.github.com/v3/users/#get-the-authenticated-user
func (p *Provider) updateSessionState(ctx context.Context, s *sessions.State) error { func (p *Provider) updateSessionState(ctx context.Context, t *oauth2.Token, v interface{}) error {
if s == nil || s.AccessToken == nil {
return errors.New("github: user session cannot be empty")
}
accessToken := s.AccessToken.AccessToken
err := p.userInfo(ctx, accessToken, s) err := p.userInfo(ctx, t, v)
if err != nil { if err != nil {
return fmt.Errorf("github: could not retrieve user info %w", err) return fmt.Errorf("github: could not retrieve user info %w", err)
} }
err = p.userEmail(ctx, accessToken, s) err = p.userEmail(ctx, t, v)
if err != nil { if err != nil {
return fmt.Errorf("github: could not retrieve user email %w", err) return fmt.Errorf("github: could not retrieve user email %w", err)
} }
err = p.userTeams(ctx, accessToken, s) err = p.userTeams(ctx, t, v)
if err != nil { if err != nil {
return fmt.Errorf("github: could not retrieve groups %w", err) return fmt.Errorf("github: could not retrieve groups %w", err)
} }
@ -127,14 +114,12 @@ func (p *Provider) updateSessionState(ctx context.Context, s *sessions.State) er
} }
// Refresh renews a user's session by making a new userInfo request. // Refresh renews a user's session by making a new userInfo request.
func (p *Provider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) { func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v interface{}) (*oauth2.Token, error) {
if s.AccessToken == nil { err := p.updateSessionState(ctx, t, v)
return nil, errors.New("github: missing oauth2 access token") if err != nil {
}
if err := p.updateSessionState(ctx, s); err != nil {
return nil, err return nil, err
} }
return s, nil return t, nil
} }
// userTeams returns a slice of teams the user belongs by making a request // userTeams returns a slice of teams the user belongs by making a request
@ -142,7 +127,7 @@ func (p *Provider) Refresh(ctx context.Context, s *sessions.State) (*sessions.St
// //
// https://developer.github.com/v3/teams/#list-user-teams // https://developer.github.com/v3/teams/#list-user-teams
// https://developer.github.com/v3/auth/ // https://developer.github.com/v3/auth/
func (p *Provider) userTeams(ctx context.Context, at string, s *sessions.State) error { func (p *Provider) userTeams(ctx context.Context, t *oauth2.Token, v interface{}) error {
var response []struct { var response []struct {
ID json.Number `json:"id"` ID json.Number `json:"id"`
@ -154,20 +139,24 @@ func (p *Provider) userTeams(ctx context.Context, at string, s *sessions.State)
Privacy string `json:"privacy,omitempty"` Privacy string `json:"privacy,omitempty"`
} }
headers := map[string]string{"Authorization": fmt.Sprintf("token %s", at)} headers := map[string]string{"Authorization": fmt.Sprintf("token %s", t.AccessToken)}
teamURL := githubAPIURL + teamPath teamURL := githubAPIURL + teamPath
err := httputil.Client(ctx, http.MethodGet, teamURL, version.UserAgent(), headers, nil, &response) err := httputil.Client(ctx, http.MethodGet, teamURL, version.UserAgent(), headers, nil, &response)
if err != nil { if err != nil {
return err return err
} }
log.Debug().Interface("teams", response).Msg("github: user teams") log.Debug().Interface("teams", response).Msg("github: user teams")
s.Groups = nil var out struct {
for _, org := range response { Groups []string `json:"groups"`
s.Groups = append(s.Groups, org.ID.String())
} }
for _, org := range response {
return nil out.Groups = append(out.Groups, org.ID.String())
}
b, err := json.Marshal(out)
if err != nil {
return err
}
return json.Unmarshal(b, v)
} }
// userEmail returns the primary email of the user by making // userEmail returns the primary email of the user by making
@ -175,7 +164,7 @@ func (p *Provider) userTeams(ctx context.Context, at string, s *sessions.State)
// //
// https://developer.github.com/v3/users/emails/#list-email-addresses-for-a-user // https://developer.github.com/v3/users/emails/#list-email-addresses-for-a-user
// https://developer.github.com/v3/auth/ // https://developer.github.com/v3/auth/
func (p *Provider) userEmail(ctx context.Context, at string, s *sessions.State) error { func (p *Provider) userEmail(ctx context.Context, t *oauth2.Token, v interface{}) error {
// response represents the github user email // response represents the github user email
// https://developer.github.com/v3/users/emails/#response // https://developer.github.com/v3/users/emails/#response
var response []struct { var response []struct {
@ -184,48 +173,67 @@ func (p *Provider) userEmail(ctx context.Context, at string, s *sessions.State)
Primary bool `json:"primary"` Primary bool `json:"primary"`
Visibility string `json:"visibility"` Visibility string `json:"visibility"`
} }
headers := map[string]string{"Authorization": fmt.Sprintf("token %s", at)} headers := map[string]string{"Authorization": fmt.Sprintf("token %s", t.AccessToken)}
emailURL := githubAPIURL + emailPath emailURL := githubAPIURL + emailPath
err := httputil.Client(ctx, http.MethodGet, emailURL, version.UserAgent(), headers, nil, &response) err := httputil.Client(ctx, http.MethodGet, emailURL, version.UserAgent(), headers, nil, &response)
if err != nil { if err != nil {
return err return err
} }
var out struct {
Email string `json:"email"`
Verified bool `json:"email_verified"`
}
log.Debug().Interface("emails", response).Msg("github: user emails") log.Debug().Interface("emails", response).Msg("github: user emails")
for _, email := range response { for _, email := range response {
if email.Primary && email.Verified { if email.Primary && email.Verified {
s.Email = email.Email out.Email = email.Email
s.EmailVerified = true out.Verified = true
return nil break
} }
} }
return nil b, err := json.Marshal(out)
if err != nil {
return err
}
return json.Unmarshal(b, v)
} }
func (p *Provider) userInfo(ctx context.Context, at string, s *sessions.State) error { func (p *Provider) userInfo(ctx context.Context, t *oauth2.Token, v interface{}) error {
var response struct { var response struct {
ID int `json:"id"` ID int `json:"id"`
Login string `json:"login"` Login string `json:"login"`
Name string `json:"name"` Name string `json:"name"`
Email string `json:"email"`
AvatarURL string `json:"avatar_url,omitempty"` AvatarURL string `json:"avatar_url,omitempty"`
} }
headers := map[string]string{ headers := map[string]string{
"Authorization": fmt.Sprintf("token %s", at), "Authorization": fmt.Sprintf("token %s", t.AccessToken),
"Accept": "application/vnd.github.v3+json", "Accept": "application/vnd.github.v3+json",
} }
err := httputil.Client(ctx, http.MethodGet, p.userEndpoint, version.UserAgent(), headers, nil, &response) err := httputil.Client(ctx, http.MethodGet, p.userEndpoint, version.UserAgent(), headers, nil, &response)
if err != nil { if err != nil {
return err return err
} }
var out struct {
Subject string `json:"sub"`
Name string `json:"name,omitempty"`
User string `json:"user"`
Picture string `json:"picture,omitempty"`
// needs to be set manually
Expiry *jwt.NumericDate `json:"exp,omitempty"`
}
s.User = response.Login out.User = response.Login
s.Name = response.Name out.Subject = response.Login
s.Picture = response.AvatarURL out.Name = response.Name
out.Picture = response.AvatarURL
// set the session expiry // set the session expiry
s.Expiry = jwt.NewNumericDate(time.Now().Add(refreshDeadline)) out.Expiry = jwt.NewNumericDate(time.Now().Add(refreshDeadline))
return nil b, err := json.Marshal(out)
if err != nil {
return err
}
return json.Unmarshal(b, v)
} }
// Revoke method will remove all the github grants the user // Revoke method will remove all the github grants the user

View file

@ -5,7 +5,7 @@ package azure
import ( import (
"context" "context"
"errors" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"time" "time"
@ -16,7 +16,6 @@ import (
"github.com/pomerium/pomerium/internal/identity/oauth" "github.com/pomerium/pomerium/internal/identity/oauth"
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc" pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/version" "github.com/pomerium/pomerium/internal/version"
) )
@ -60,10 +59,7 @@ func (p *Provider) GetSignInURL(state string) string {
// `Directory.Read.All` is required. // `Directory.Read.All` is required.
// https://docs.microsoft.com/en-us/graph/api/resources/directoryobject?view=graph-rest-1.0 // https://docs.microsoft.com/en-us/graph/api/resources/directoryobject?view=graph-rest-1.0
// https://docs.microsoft.com/en-us/graph/api/user-list-memberof?view=graph-rest-1.0 // https://docs.microsoft.com/en-us/graph/api/user-list-memberof?view=graph-rest-1.0
func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string, error) { func (p *Provider) UserGroups(ctx context.Context, t *oauth2.Token, v interface{}) error {
if s == nil || s.AccessToken == nil {
return nil, errors.New("identity/azure: session cannot be nil")
}
var response struct { var response struct {
Groups []struct { Groups []struct {
ID string `json:"id"` ID string `json:"id"`
@ -73,15 +69,23 @@ func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string,
GroupTypes []string `json:"groupTypes,omitempty"` GroupTypes []string `json:"groupTypes,omitempty"`
} `json:"value"` } `json:"value"`
} }
headers := map[string]string{"Authorization": fmt.Sprintf("Bearer %s", s.AccessToken.AccessToken)} headers := map[string]string{"Authorization": fmt.Sprintf("Bearer %s", t.AccessToken)}
err := httputil.Client(ctx, http.MethodGet, defaultGroupURL, version.UserAgent(), headers, nil, &response) err := httputil.Client(ctx, http.MethodGet, defaultGroupURL, version.UserAgent(), headers, nil, &response)
if err != nil { if err != nil {
return nil, err return err
}
log.Debug().Interface("response", response).Msg("microsoft: groups")
var out struct {
Groups []string `json:"groups"`
} }
var groups []string
for _, group := range response.Groups { for _, group := range response.Groups {
log.Debug().Str("DisplayName", group.DisplayName).Str("ID", group.ID).Msg("microsoft: group") out.Groups = append(out.Groups, group.ID)
groups = append(groups, group.ID)
} }
return groups, nil b, err := json.Marshal(out)
if err != nil {
return err
}
return json.Unmarshal(b, v)
} }

View file

@ -1,12 +1,14 @@
package oidc package oidc
import "errors" import (
"errors"
)
// ErrRevokeNotImplemented error type when Revoke method is not implemented // ErrRevokeNotImplemented is returned when revoke is not implemented
// by an identity provider // by an identity provider.
var ErrRevokeNotImplemented = errors.New("identity/oidc: revoke not implemented") var ErrRevokeNotImplemented = errors.New("identity/oidc: revoke not implemented")
// ErrSignoutNotImplemented error type when end session is not implemented // ErrSignoutNotImplemented is returned when end session is not implemented
// by an identity provider // by an identity provider
// https://openid.net/specs/openid-connect-frontchannel-1_0.html#RPInitiated // https://openid.net/specs/openid-connect-frontchannel-1_0.html#RPInitiated
var ErrSignoutNotImplemented = errors.New("identity/oidc: end session not implemented") var ErrSignoutNotImplemented = errors.New("identity/oidc: end session not implemented")
@ -14,3 +16,13 @@ var ErrSignoutNotImplemented = errors.New("identity/oidc: end session not implem
// ErrMissingProviderURL is returned when an identity provider requires a provider url // ErrMissingProviderURL is returned when an identity provider requires a provider url
// does not receive one. // does not receive one.
var ErrMissingProviderURL = errors.New("identity/oidc: missing provider url") var ErrMissingProviderURL = errors.New("identity/oidc: missing provider url")
// ErrMissingIDToken is returned when (usually on refresh) and identity provider
// failed to include an id_token in a oauth2 token.
var ErrMissingIDToken = errors.New("identity/oidc: missing id_token")
// ErrMissingRefreshToken is returned if no refresh token was found.
var ErrMissingRefreshToken = errors.New("identity/oidc: missing refresh token")
// ErrMissingAccessToken is returned when no access token was found.
var ErrMissingAccessToken = errors.New("identity/oidc: missing access token")

View file

@ -6,7 +6,6 @@ package gitlab
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"net/http" "net/http"
@ -15,14 +14,14 @@ import (
"github.com/pomerium/pomerium/internal/identity/oauth" "github.com/pomerium/pomerium/internal/identity/oauth"
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc" pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/version" "github.com/pomerium/pomerium/internal/version"
"golang.org/x/oauth2"
) )
// Name identifies the GitLab identity provider // Name identifies the GitLab identity provider.
const Name = "gitlab" const Name = "gitlab"
var defaultScopes = []string{oidc.ScopeOpenID, "read_api", "read_user", "profile", "email"} var defaultScopes = []string{oidc.ScopeOpenID, "profile", "email", "api"}
const ( const (
defaultProviderURL = "https://gitlab.com" defaultProviderURL = "https://gitlab.com"
@ -64,11 +63,7 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
// //
// Returns 20 results at a time because the API results are paginated. // Returns 20 results at a time because the API results are paginated.
// https://docs.gitlab.com/ee/api/groups.html#list-groups // https://docs.gitlab.com/ee/api/groups.html#list-groups
func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string, error) { func (p *Provider) UserGroups(ctx context.Context, t *oauth2.Token, v interface{}) error {
if s == nil || s.AccessToken == nil {
return nil, errors.New("gitlab: user session cannot be empty")
}
var response []struct { var response []struct {
ID json.Number `json:"id"` ID json.Number `json:"id"`
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
@ -81,17 +76,22 @@ func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string,
FullName string `json:"full_name,omitempty"` FullName string `json:"full_name,omitempty"`
FullPath string `json:"full_path,omitempty"` FullPath string `json:"full_path,omitempty"`
} }
headers := map[string]string{"Authorization": fmt.Sprintf("Bearer %s", s.AccessToken.AccessToken)} headers := map[string]string{"Authorization": fmt.Sprintf("Bearer %s", t.AccessToken)}
err := httputil.Client(ctx, http.MethodGet, p.userGroupURL, version.UserAgent(), headers, nil, &response) err := httputil.Client(ctx, http.MethodGet, p.userGroupURL, version.UserAgent(), headers, nil, &response)
if err != nil { if err != nil {
return nil, err return err
} }
var groups []string
log.Debug().Interface("response", response).Msg("gitlab: groups") log.Debug().Interface("response", response).Msg("gitlab: groups")
var out struct {
Groups []string `json:"groups"`
}
for _, group := range response { for _, group := range response {
groups = append(groups, group.ID.String()) out.Groups = append(out.Groups, group.ID.String())
}
b, err := json.Marshal(out)
if err != nil {
return err
} }
return groups, nil return json.Unmarshal(b, v)
} }

View file

@ -8,6 +8,7 @@ import (
"context" "context"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
oidc "github.com/coreos/go-oidc" oidc "github.com/coreos/go-oidc"
@ -18,7 +19,6 @@ import (
"github.com/pomerium/pomerium/internal/identity/oauth" "github.com/pomerium/pomerium/internal/identity/oauth"
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc" pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
) )
const ( const (
@ -54,9 +54,11 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
return nil, fmt.Errorf("%s: failed creating oidc provider: %w", Name, err) return nil, fmt.Errorf("%s: failed creating oidc provider: %w", Name, err)
} }
p.Provider = genericOidc p.Provider = genericOidc
if o.ServiceAccount == "" {
log.Warn().Msg("google: no service account, will not fetch groups")
return &p, nil
}
// if service account set, configure admin sdk calls
if o.ServiceAccount != "" {
apiCreds, err := base64.StdEncoding.DecodeString(o.ServiceAccount) apiCreds, err := base64.StdEncoding.DecodeString(o.ServiceAccount)
if err != nil { if err != nil {
return nil, fmt.Errorf("google: could not decode service account json %w", err) return nil, fmt.Errorf("google: could not decode service account json %w", err)
@ -80,9 +82,6 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
return nil, fmt.Errorf("google: failed creating admin service %w", err) return nil, fmt.Errorf("google: failed creating admin service %w", err)
} }
p.UserGroupFn = p.UserGroups p.UserGroupFn = p.UserGroups
} else {
log.Warn().Msg("google: no service account, cannot retrieve groups")
}
return &p, nil return &p, nil
} }
@ -106,17 +105,34 @@ func (p *Provider) GetSignInURL(state string) string {
// NOTE: groups via Directory API is limited to 1 QPS! // NOTE: groups via Directory API is limited to 1 QPS!
// https://developers.google.com/admin-sdk/directory/v1/reference/groups/list // https://developers.google.com/admin-sdk/directory/v1/reference/groups/list
// https://developers.google.com/admin-sdk/directory/v1/limits // https://developers.google.com/admin-sdk/directory/v1/limits
func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string, error) { func (p *Provider) UserGroups(ctx context.Context, t *oauth2.Token, v interface{}) error {
var groups []string if p.apiClient == nil {
if p.apiClient != nil { return errors.New("google: trying to fetch groups, but no api client")
req := p.apiClient.Groups.List().UserKey(s.Subject).MaxResults(100) }
resp, err := req.Do() s, err := p.GetSubject(v)
if err != nil { if err != nil {
return nil, fmt.Errorf("google: group api request failed %w", err) return err
} }
var out struct {
Groups []string `json:"groups"`
}
req := p.apiClient.Groups.List().Context(ctx).UserKey(s)
err = req.Pages(ctx, func(resp *admin.Groups) error {
for _, group := range resp.Groups { for _, group := range resp.Groups {
groups = append(groups, group.Email) out.Groups = append(out.Groups, group.Email)
} }
return nil
})
if err != nil {
return err
} }
return groups, nil _, err = req.Do()
if err != nil {
return fmt.Errorf("google: group api request failed %w", err)
}
b, err := json.Marshal(out)
if err != nil {
return err
}
return json.Unmarshal(b, v)
} }

View file

@ -5,6 +5,7 @@ package oidc
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
@ -15,12 +16,11 @@ import (
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/identity/oauth" "github.com/pomerium/pomerium/internal/identity/oauth"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/internal/version" "github.com/pomerium/pomerium/internal/version"
) )
// Name identifies the generic OpenID Connect provider // Name identifies the generic OpenID Connect provider.
const Name = "oidc" const Name = "oidc"
var defaultScopes = []string{go_oidc.ScopeOpenID, "profile", "email", "offline_access"} var defaultScopes = []string{go_oidc.ScopeOpenID, "profile", "email", "offline_access"}
@ -37,11 +37,6 @@ type Provider struct {
// client application information and the server's endpoint URLs. // client application information and the server's endpoint URLs.
Oauth *oauth2.Config Oauth *oauth2.Config
// UserInfoURL specifies the endpoint responsible for returning claims
// about the authenticated End-User.
// https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
UserInfoURL string `json:"userinfo_endpoint,omitempty"`
// RevocationURL is the location of the OAuth 2.0 token revocation endpoint. // RevocationURL is the location of the OAuth 2.0 token revocation endpoint.
// https://tools.ietf.org/html/rfc7009 // https://tools.ietf.org/html/rfc7009
RevocationURL string `json:"revocation_endpoint,omitempty"` RevocationURL string `json:"revocation_endpoint,omitempty"`
@ -53,7 +48,7 @@ type Provider struct {
// UserGroupFn is, if set, used to return a slice of group IDs the // UserGroupFn is, if set, used to return a slice of group IDs the
// user is a member of // user is a member of
UserGroupFn func(context.Context, *sessions.State) ([]string, error) UserGroupFn func(context.Context, *oauth2.Token, interface{}) error
} }
// New creates a new instance of a generic OpenID Connect provider. // New creates a new instance of a generic OpenID Connect provider.
@ -100,81 +95,89 @@ func (p *Provider) GetSignInURL(state string) string {
// Authenticate converts an authorization code returned from the identity // Authenticate converts an authorization code returned from the identity
// provider into a token which is then converted into a user session. // provider into a token which is then converted into a user session.
func (p *Provider) Authenticate(ctx context.Context, code string) (*sessions.State, error) { func (p *Provider) Authenticate(ctx context.Context, code string, v interface{}) (*oauth2.Token, error) {
// Exchange converts an authorization code into a token.
oauth2Token, err := p.Oauth.Exchange(ctx, code) oauth2Token, err := p.Oauth.Exchange(ctx, code)
if err != nil { if err != nil {
return nil, fmt.Errorf("identity/oidc: token exchange failed: %w", err) return nil, fmt.Errorf("identity/oidc: token exchange failed: %w", err)
} }
idToken, err := p.IdentityFromToken(ctx, oauth2Token)
idToken, err := p.getIDToken(ctx, oauth2Token)
if err != nil { if err != nil {
return nil, fmt.Errorf("identity/oidc: failed getting id_token: %w", err) return nil, fmt.Errorf("identity/oidc: failed getting id_token: %w", err)
} }
aud, err := urlutil.ParseAndValidateURL(p.Oauth.RedirectURL) // hydrate `v` using claims inside the returned `id_token`
if err != nil { // https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint
return nil, fmt.Errorf("identity/oidc: bad redirect uri: %w", err) if err := idToken.Claims(v); err != nil {
return nil, fmt.Errorf("identity/oidc: couldn't unmarshal extra claims %w", err)
} }
s, err := sessions.NewStateFromTokens(idToken, oauth2Token, aud.Hostname()) if err := p.updateUserInfo(ctx, oauth2Token, v); err != nil {
if err != nil { return nil, fmt.Errorf("identity/oidc: couldn't update user info %w", err)
return nil, err
} }
if err := p.Provider.Claims(&p); err == nil && p.UserInfoURL != "" { return oauth2Token, nil
userInfo, err := p.Provider.UserInfo(ctx, oauth2.StaticTokenSource(oauth2Token))
if err != nil {
return nil, fmt.Errorf("identity/oidc: could not retrieve user info %w", err)
}
if err := userInfo.Claims(&s); err != nil {
return nil, fmt.Errorf("identity/oidc: could not parse user claims %w", err)
}
} }
// updateUserInfo calls the OIDC (spec required) UserInfo Endpoint as well as any
// groups endpoint (non-spec) to populate the rest of the user's information.
//
// https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
func (p *Provider) updateUserInfo(ctx context.Context, t *oauth2.Token, v interface{}) error {
userInfo, err := p.Provider.UserInfo(ctx, oauth2.StaticTokenSource(t))
if err != nil {
return fmt.Errorf("identity/oidc: user info endpoint: %w", err)
}
if err := userInfo.Claims(v); err != nil {
return fmt.Errorf("identity/oidc: failed parsing user info endpoint claims: %w", err)
}
if p.UserGroupFn != nil { if p.UserGroupFn != nil {
s.Groups, err = p.UserGroupFn(ctx, s) if err := p.UserGroupFn(ctx, t, v); err != nil {
if err != nil { return fmt.Errorf("identity/oidc: could not retrieve groups: %w", err)
return nil, fmt.Errorf("internal/oidc: could not retrieve groups %w", err)
} }
} }
return s, nil return nil
} }
// Refresh renews a user's session using an oidc refresh token without reprompting the user. // Refresh renews a user's session using an oidc refresh token without reprompting the user.
// Group membership is also refreshed. // Group membership is also refreshed.
// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens // https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens
func (p *Provider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) { func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v interface{}) (*oauth2.Token, error) {
if s.AccessToken == nil || s.AccessToken.RefreshToken == "" { if t == nil {
return nil, errors.New("internal/oidc: missing refresh token") return nil, ErrMissingAccessToken
}
if t.RefreshToken == "" {
return nil, ErrMissingRefreshToken
}
var err error
newToken, err := p.Oauth.TokenSource(ctx, t).Token()
if err != nil {
return nil, fmt.Errorf("identity/oidc: refresh failed: %w", err)
} }
t := oauth2.Token{RefreshToken: s.AccessToken.RefreshToken} // Many identity providers _will not_ return `id_token` on refresh
oauthToken, err := p.Oauth.TokenSource(ctx, &t).Token() // https://github.com/FusionAuth/fusionauth-issues/issues/110#issuecomment-481526544
if err != nil { idToken, err := p.getIDToken(ctx, newToken)
return nil, fmt.Errorf("internal/oidc: refresh failed %w", err) if err == nil {
} if err := idToken.Claims(v); err != nil {
idToken, err := p.IdentityFromToken(ctx, oauthToken) return nil, fmt.Errorf("identity/oidc: couldn't unmarshal extra claims %w", err)
if err != nil {
return nil, fmt.Errorf("identity/oidc: failed getting id_token: %w", err)
}
if err := s.UpdateState(idToken, oauthToken); err != nil {
return nil, fmt.Errorf("internal/oidc: state update failed %w", err)
}
if p.UserGroupFn != nil {
s.Groups, err = p.UserGroupFn(ctx, s)
if err != nil {
return nil, fmt.Errorf("internal/oidc: could not retrieve groups %w", err)
} }
} }
return s, nil if err := p.updateUserInfo(ctx, newToken, v); err != nil {
return nil, fmt.Errorf("identity/oidc: couldn't update user info %w", err)
}
return newToken, nil
} }
// IdentityFromToken takes an identity provider issued JWT as input ('id_token') // getIDToken returns the raw jwt payload for `id_token` from the oauth2 token
// and returns a session state. The provided token's audience ('aud') must // returned following oidc code flow
// match Pomerium's client_id. //
func (p *Provider) IdentityFromToken(ctx context.Context, t *oauth2.Token) (*go_oidc.IDToken, error) { // https://openid.net/specs/openid-connect-core-1_0.html#TokenResponse
func (p *Provider) getIDToken(ctx context.Context, t *oauth2.Token) (*go_oidc.IDToken, error) {
rawIDToken, ok := t.Extra("id_token").(string) rawIDToken, ok := t.Extra("id_token").(string)
if !ok { if !ok {
return nil, fmt.Errorf("internal/oidc: id_token not found") return nil, ErrMissingIDToken
} }
return p.Verifier.Verify(ctx, rawIDToken) return p.Verifier.Verify(ctx, rawIDToken)
} }
@ -183,13 +186,16 @@ func (p *Provider) IdentityFromToken(ctx context.Context, t *oauth2.Token) (*go_
// support revocation an error is thrown. // support revocation an error is thrown.
// //
// https://tools.ietf.org/html/rfc7009#section-2.1 // https://tools.ietf.org/html/rfc7009#section-2.1
func (p *Provider) Revoke(ctx context.Context, token *oauth2.Token) error { func (p *Provider) Revoke(ctx context.Context, t *oauth2.Token) error {
if p.RevocationURL == "" { if p.RevocationURL == "" {
return ErrRevokeNotImplemented return ErrRevokeNotImplemented
} }
if t == nil {
return ErrMissingAccessToken
}
params := url.Values{} params := url.Values{}
params.Add("token", token.AccessToken) params.Add("token", t.AccessToken)
params.Add("token_type_hint", "access_token") params.Add("token_type_hint", "access_token")
// Some providers like okta / onelogin require "client authentication" // Some providers like okta / onelogin require "client authentication"
// https://developer.okta.com/docs/reference/api/oidc/#client-secret // https://developer.okta.com/docs/reference/api/oidc/#client-secret
@ -198,7 +204,7 @@ func (p *Provider) Revoke(ctx context.Context, token *oauth2.Token) error {
params.Add("client_secret", p.Oauth.ClientSecret) params.Add("client_secret", p.Oauth.ClientSecret)
err := httputil.Client(ctx, http.MethodPost, p.RevocationURL, version.UserAgent(), nil, params, nil) err := httputil.Client(ctx, http.MethodPost, p.RevocationURL, version.UserAgent(), nil, params, nil)
if err != nil && err != httputil.ErrTokenRevoked { if err != nil && errors.Is(err, httputil.ErrTokenRevoked) {
return fmt.Errorf("internal/oidc: unexpected revoke error: %w", err) return fmt.Errorf("internal/oidc: unexpected revoke error: %w", err)
} }
@ -214,3 +220,20 @@ func (p *Provider) LogOut() (*url.URL, error) {
} }
return urlutil.ParseAndValidateURL(p.EndSessionURL) return urlutil.ParseAndValidateURL(p.EndSessionURL)
} }
// GetSubject gets the RFC 7519 Subject claim (`sub`) from a
func (p *Provider) GetSubject(v interface{}) (string, error) {
b, err := json.Marshal(v)
if err != nil {
return "", err
}
var s struct {
Subject string `json:"sub"`
}
err = json.Unmarshal(b, &s)
if err != nil {
return "", err
}
return s.Subject, nil
}

View file

@ -5,6 +5,7 @@ package okta
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
@ -13,9 +14,9 @@ import (
"github.com/pomerium/pomerium/internal/identity/oauth" "github.com/pomerium/pomerium/internal/identity/oauth"
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc" pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/internal/version" "github.com/pomerium/pomerium/internal/version"
"golang.org/x/oauth2"
) )
const ( const (
@ -64,7 +65,11 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
// UserGroups fetches the groups of which the user is a member // UserGroups fetches the groups of which the user is a member
// https://developer.okta.com/docs/reference/api/users/#get-user-s-groups // https://developer.okta.com/docs/reference/api/users/#get-user-s-groups
func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string, error) { func (p *Provider) UserGroups(ctx context.Context, t *oauth2.Token, v interface{}) error {
s, err := p.GetSubject(v)
if err != nil {
return err
}
var response []struct { var response []struct {
ID string `json:"id"` ID string `json:"id"`
Profile struct { Profile struct {
@ -74,15 +79,22 @@ func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string,
} }
headers := map[string]string{"Authorization": fmt.Sprintf("SSWS %s", p.serviceAccount)} headers := map[string]string{"Authorization": fmt.Sprintf("SSWS %s", p.serviceAccount)}
uri := fmt.Sprintf("%s/%s/groups", p.userAPI.String(), s.Subject) uri := fmt.Sprintf("%s/%s/groups", p.userAPI.String(), s)
err := httputil.Client(ctx, http.MethodGet, uri, version.UserAgent(), headers, nil, &response) err = httputil.Client(ctx, http.MethodGet, uri, version.UserAgent(), headers, nil, &response)
if err != nil { if err != nil {
return nil, err return err
}
log.Debug().Interface("response", response).Msg("okta: groups")
var out struct {
Groups []string `json:"groups"`
} }
var groups []string
for _, group := range response { for _, group := range response {
log.Debug().Interface("group", group).Msg("okta: group") out.Groups = append(out.Groups, group.ID)
groups = append(groups, group.ID)
} }
return groups, nil b, err := json.Marshal(out)
if err != nil {
return err
}
return json.Unmarshal(b, v)
} }

View file

@ -5,17 +5,15 @@ package onelogin
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"time"
oidc "github.com/coreos/go-oidc" oidc "github.com/coreos/go-oidc"
"golang.org/x/oauth2"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/identity/oauth" "github.com/pomerium/pomerium/internal/identity/oauth"
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc" pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/version" "github.com/pomerium/pomerium/internal/version"
) )
@ -55,24 +53,10 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
// UserGroups returns a slice of group names a given user is in. // UserGroups returns a slice of group names a given user is in.
// https://developers.onelogin.com/openid-connect/api/user-info // https://developers.onelogin.com/openid-connect/api/user-info
func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string, error) { func (p *Provider) UserGroups(ctx context.Context, t *oauth2.Token, v interface{}) error {
if s == nil || s.AccessToken == nil { if t == nil {
return nil, errors.New("identity/onelogin: session cannot be nil") return pom_oidc.ErrMissingAccessToken
} }
var response struct { headers := map[string]string{"Authorization": fmt.Sprintf("Bearer %s", t.AccessToken)}
User string `json:"sub"` return httputil.Client(ctx, http.MethodGet, defaultOneloginGroupURL, version.UserAgent(), headers, nil, v)
Email string `json:"email"`
PreferredUsername string `json:"preferred_username"`
Name string `json:"name"`
UpdatedAt time.Time `json:"updated_at"`
GivenName string `json:"given_name"`
FamilyName string `json:"family_name"`
Groups []string `json:"groups"`
}
headers := map[string]string{"Authorization": fmt.Sprintf("Bearer %s", s.AccessToken.AccessToken)}
err := httputil.Client(ctx, http.MethodGet, defaultOneloginGroupURL, version.UserAgent(), headers, nil, &response)
if err != nil {
return nil, err
}
return response.Groups, nil
} }

View file

@ -17,25 +17,24 @@ import (
"github.com/pomerium/pomerium/internal/identity/oidc/google" "github.com/pomerium/pomerium/internal/identity/oidc/google"
"github.com/pomerium/pomerium/internal/identity/oidc/okta" "github.com/pomerium/pomerium/internal/identity/oidc/okta"
"github.com/pomerium/pomerium/internal/identity/oidc/onelogin" "github.com/pomerium/pomerium/internal/identity/oidc/onelogin"
"github.com/pomerium/pomerium/internal/sessions"
) )
var ( var (
// compile time assertions that providers are satisfying the interface // compile time assertions that providers are satisfying the interface
_ Authenticator = &azure.Provider{} _ Authenticator = &azure.Provider{}
_ Authenticator = &gitlab.Provider{}
_ Authenticator = &github.Provider{} _ Authenticator = &github.Provider{}
_ Authenticator = &gitlab.Provider{}
_ Authenticator = &google.Provider{} _ Authenticator = &google.Provider{}
_ Authenticator = &MockProvider{}
_ Authenticator = &oidc.Provider{} _ Authenticator = &oidc.Provider{}
_ Authenticator = &okta.Provider{} _ Authenticator = &okta.Provider{}
_ Authenticator = &onelogin.Provider{} _ Authenticator = &onelogin.Provider{}
_ Authenticator = &MockProvider{}
) )
// Authenticator is an interface representing the ability to authenticate with an identity provider. // Authenticator is an interface representing the ability to authenticate with an identity provider.
type Authenticator interface { type Authenticator interface {
Authenticate(context.Context, string) (*sessions.State, error) Authenticate(context.Context, string, interface{}) (*oauth2.Token, error)
Refresh(context.Context, *sessions.State) (*sessions.State, error) Refresh(context.Context, *oauth2.Token, interface{}) (*oauth2.Token, error)
Revoke(context.Context, *oauth2.Token) error Revoke(context.Context, *oauth2.Token) error
GetSignInURL(state string) string GetSignInURL(state string) string
LogOut() (*url.URL, error) LogOut() (*url.URL, error)

View file

@ -37,6 +37,9 @@ type Store struct {
srv *http.Server srv *http.Server
} }
// ErrCacheMiss is returned when the cache misses for a given key.
var ErrCacheMiss = errors.New("cache miss")
// Options represent autocache options. // Options represent autocache options.
type Options struct { type Options struct {
Addr string Addr string
@ -60,7 +63,7 @@ var DefaultOptions = &Options{
GetterFn: func(ctx context.Context, id string, dest groupcache.Sink) error { GetterFn: func(ctx context.Context, id string, dest groupcache.Sink) error {
b := fromContext(ctx) b := fromContext(ctx)
if len(b) == 0 { if len(b) == 0 {
return fmt.Errorf("autocache: empty ctx for id: %s", id) return fmt.Errorf("autocache: id %s : %w", id, ErrCacheMiss)
} }
if err := dest.SetBytes(b); err != nil { if err := dest.SetBytes(b); err != nil {
return fmt.Errorf("autocache: sink error %w", err) return fmt.Errorf("autocache: sink error %w", err)

View file

@ -1,95 +0,0 @@
// Package cache provides a remote cache based implementation of session store
// and loader. See pomerium's cache service for more details.
package cache
import (
"errors"
"fmt"
"net/http"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/grpc/cache/client"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
)
var _ sessions.SessionStore = &Store{}
var _ sessions.SessionLoader = &Store{}
// Store implements the session store interface using a cache service.
type Store struct {
cache client.Cacher
encoder encoding.MarshalUnmarshaler
queryParam string
wrappedStore sessions.SessionStore
}
// Options represent cache store's available configurations.
type Options struct {
Cache client.Cacher
Encoder encoding.MarshalUnmarshaler
QueryParam string
WrappedStore sessions.SessionStore
}
var defaultOptions = &Options{
QueryParam: "cache_store_key",
}
// NewStore creates a new cache
func NewStore(o *Options) *Store {
if o.QueryParam == "" {
o.QueryParam = defaultOptions.QueryParam
}
return &Store{
cache: o.Cache,
encoder: o.Encoder,
queryParam: o.QueryParam,
wrappedStore: o.WrappedStore,
}
}
// LoadSession looks for a preset query parameter in the request body
// representing the key to lookup from the cache.
func (s *Store) LoadSession(r *http.Request) (string, error) {
// look for our cache's key in the default query param
sessionID := r.URL.Query().Get(s.queryParam)
if sessionID == "" {
return "", sessions.ErrNoSessionFound
}
exists, val, err := s.cache.Get(r.Context(), sessionID)
if err != nil {
log.FromRequest(r).Debug().Msg("sessions/cache: miss, trying wrapped loader")
return "", err
}
if !exists {
return "", sessions.ErrNoSessionFound
}
return string(val), nil
}
// ClearSession clears the session from the wrapped store.
func (s *Store) ClearSession(w http.ResponseWriter, r *http.Request) {
s.wrappedStore.ClearSession(w, r)
}
// SaveSession saves the session to the cache, and wrapped store.
func (s *Store) SaveSession(w http.ResponseWriter, r *http.Request, x interface{}) error {
err := s.wrappedStore.SaveSession(w, r, x)
if err != nil {
return fmt.Errorf("sessions/cache: wrapped store save error %w", err)
}
state, ok := x.(*sessions.State)
if !ok {
return errors.New("sessions/cache: cannot cache non state type")
}
data, err := s.encoder.Marshal(&state)
if err != nil {
return fmt.Errorf("sessions/cache: marshal %w", err)
}
return s.cache.Set(r.Context(), state.AccessTokenID, data)
}

View file

@ -1,190 +0,0 @@
package cache
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/encoding/ecjson"
mock_encoder "github.com/pomerium/pomerium/internal/encoding/mock"
"github.com/pomerium/pomerium/internal/grpc/cache/client"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/sessions/cookie"
"github.com/pomerium/pomerium/internal/sessions/mock"
"gopkg.in/square/go-jose.v2/jwt"
)
type mockCache struct {
Key string
KeyExists bool
Value []byte
Err error
}
func (mc *mockCache) Get(ctx context.Context, key string) (keyExists bool, value []byte, err error) {
return mc.KeyExists, mc.Value, mc.Err
}
func (mc *mockCache) Set(ctx context.Context, key string, value []byte) error {
return mc.Err
}
func (mc *mockCache) Close() error {
return mc.Err
}
func TestNewStore(t *testing.T) {
tests := []struct {
name string
Options *Options
State *sessions.State
wantErr bool
wantLoadErr bool
wantStatus int
}{
{"simple good",
&Options{
Cache: &mockCache{},
WrappedStore: &mock.Store{},
Encoder: mock_encoder.Encoder{MarshalResponse: []byte("ok")},
},
&sessions.State{Email: "user@domain.com", User: "user"},
false, false,
http.StatusOK},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := NewStore(tt.Options)
r := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
if err := got.SaveSession(w, r, tt.State); (err != nil) != tt.wantErr {
t.Errorf("NewStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
}
r = httptest.NewRequest("GET", "/", nil)
w = httptest.NewRecorder()
got.ClearSession(w, r)
status := w.Result().StatusCode
if diff := cmp.Diff(status, tt.wantStatus); diff != "" {
t.Errorf("ClearSession() = %v", diff)
}
})
}
}
func TestStore_SaveSession(t *testing.T) {
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
encoder := ecjson.New(cipher)
if err != nil {
t.Fatal(err)
}
cs, err := cookie.NewStore(&cookie.Options{
Name: "_pomerium",
}, encoder)
if err != nil {
t.Fatal(err)
}
tests := []struct {
name string
Options *Options
x interface{}
wantErr bool
}{
{"good", &Options{Cache: &mockCache{}, WrappedStore: cs, Encoder: mock_encoder.Encoder{MarshalResponse: []byte("ok")}}, &sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, false},
{"encoder error", &Options{Cache: &mockCache{}, WrappedStore: cs, Encoder: mock_encoder.Encoder{MarshalError: errors.New("err")}}, &sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, true},
{"good", &Options{Cache: &mockCache{}, WrappedStore: &mock.Store{SaveError: errors.New("err")}}, &sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, true},
{"bad type", &Options{Cache: &mockCache{}, WrappedStore: cs, Encoder: mock_encoder.Encoder{MarshalError: errors.New("err")}}, "bad type!", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
o := tt.Options
if o.WrappedStore == nil {
o.WrappedStore = cs
}
cacheStore := NewStore(tt.Options)
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
if err := cacheStore.SaveSession(w, r, tt.x); (err != nil) != tt.wantErr {
t.Errorf("Store.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestStore_LoadSession(t *testing.T) {
key := cryptutil.NewBase64Key()
tests := []struct {
name string
state *sessions.State
cache client.Cacher
encoder encoding.MarshalUnmarshaler
queryParam string
wrappedStore sessions.SessionStore
wantErr bool
}{
{"good",
&sessions.State{AccessTokenID: key, Email: "user@pomerium.io"},
&mockCache{KeyExists: true},
mock_encoder.Encoder{MarshalResponse: []byte("ok")},
defaultOptions.QueryParam,
&mock.Store{Session: &sessions.State{AccessTokenID: key, Email: "user@pomerium.io"}},
false},
{"missing param with key",
&sessions.State{AccessTokenID: key, Email: "user@pomerium.io"},
&mockCache{KeyExists: true},
mock_encoder.Encoder{MarshalResponse: []byte("ok")},
"bad_query",
&mock.Store{Session: &sessions.State{AccessTokenID: key, Email: "user@pomerium.io"}},
true},
{"doesn't exist",
&sessions.State{AccessTokenID: key, Email: "user@pomerium.io"},
&mockCache{KeyExists: false},
mock_encoder.Encoder{MarshalResponse: []byte("ok")},
defaultOptions.QueryParam,
&mock.Store{Session: &sessions.State{AccessTokenID: key, Email: "user@pomerium.io"}},
true},
{"retrieval error",
&sessions.State{AccessTokenID: key, Email: "user@pomerium.io"},
&mockCache{Err: errors.New("err")},
mock_encoder.Encoder{MarshalResponse: []byte("ok")},
defaultOptions.QueryParam,
&mock.Store{Session: &sessions.State{AccessTokenID: key, Email: "user@pomerium.io"}},
true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &Store{
cache: tt.cache,
encoder: tt.encoder,
queryParam: tt.queryParam,
wrappedStore: tt.wrappedStore,
}
r := httptest.NewRequest(http.MethodGet, "/", nil)
q := r.URL.Query()
q.Set(defaultOptions.QueryParam, tt.state.AccessTokenID)
r.URL.RawQuery = q.Encode()
r.Header.Set("Accept", "application/json")
_, err := s.LoadSession(r)
if (err != nil) != tt.wantErr {
t.Errorf("Store.LoadSession() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}

View file

@ -1,15 +1,11 @@
package sessions package sessions
import ( import (
"encoding/json"
"errors"
"fmt" "fmt"
"strings" "strings"
"time" "time"
"github.com/cespare/xxhash/v2" "github.com/pomerium/pomerium/internal/hashutil"
oidc "github.com/coreos/go-oidc"
"github.com/mitchellh/hashstructure"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2/jwt" "gopkg.in/square/go-jose.v2/jwt"
) )
@ -27,6 +23,9 @@ type State struct {
NotBefore *jwt.NumericDate `json:"nbf,omitempty"` NotBefore *jwt.NumericDate `json:"nbf,omitempty"`
IssuedAt *jwt.NumericDate `json:"iat,omitempty"` IssuedAt *jwt.NumericDate `json:"iat,omitempty"`
ID string `json:"jti,omitempty"` ID string `json:"jti,omitempty"`
// At_hash is an OPTIONAL Access Token hash value
// https://ldapwiki.com/wiki/At_hash
AccessTokenHash string `json:"at_hash,omitempty"`
// core pomerium identity claims ; not standard to RFC 7519 // core pomerium identity claims ; not standard to RFC 7519
Email string `json:"email"` Email string `json:"email"`
@ -48,84 +47,24 @@ type State struct {
// Programmatic whether this state is used for machine-to-machine // Programmatic whether this state is used for machine-to-machine
// programatic access. // programatic access.
Programmatic bool `json:"programatic"` Programmatic bool `json:"programatic"`
AccessToken *oauth2.Token `json:"act,omitempty"`
AccessTokenID string `json:"ati,omitempty"`
idToken *oidc.IDToken
}
// NewStateFromTokens returns a session state built from oidc and oauth2
// tokens as part of OpenID Connect flow with a new audience appended to the
// audience claim.
func NewStateFromTokens(idToken *oidc.IDToken, accessToken *oauth2.Token, audience string) (*State, error) {
if idToken == nil {
return nil, errors.New("sessions: oidc id token missing")
}
if accessToken == nil {
return nil, errors.New("sessions: oauth2 token missing")
}
s := &State{}
if err := idToken.Claims(s); err != nil {
return nil, fmt.Errorf("sessions: couldn't unmarshal extra claims %w", err)
}
s.Audience = []string{audience}
s.idToken = idToken
s.AccessToken = accessToken
s.AccessTokenID = s.accessTokenHash()
return s, nil
}
// UpdateState updates the current state given a new identity (oidc) and authorization
// (oauth2) tokens following a oidc refresh. NB, unlike during authentication,
// refresh typically provides fewer claims in the token so we want to build from
// our previous state.
func (s *State) UpdateState(idToken *oidc.IDToken, accessToken *oauth2.Token) error {
if idToken == nil {
return errors.New("sessions: oidc id token missing")
}
if accessToken == nil {
return errors.New("sessions: oauth2 token missing")
}
audience := append(s.Audience[:0:0], s.Audience...)
s.AccessToken = accessToken
if err := idToken.Claims(s); err != nil {
return fmt.Errorf("sessions: update state failed %w", err)
}
s.Audience = audience
s.Expiry = jwt.NewNumericDate(accessToken.Expiry)
s.AccessTokenID = s.accessTokenHash()
return nil
} }
// NewSession updates issuer, audience, and issuance timestamps but keeps // NewSession updates issuer, audience, and issuance timestamps but keeps
// parent expiry. // parent expiry.
func (s State) NewSession(issuer string, audience []string) *State { func NewSession(s *State, issuer string, audience []string, accessToken *oauth2.Token) State {
s.IssuedAt = jwt.NewNumericDate(timeNow()) newState := *s
s.NotBefore = s.IssuedAt newState.IssuedAt = jwt.NewNumericDate(timeNow())
s.Audience = audience newState.NotBefore = newState.IssuedAt
s.Issuer = issuer newState.Audience = audience
return &s newState.Issuer = issuer
} newState.AccessTokenHash = fmt.Sprintf("%x", hashutil.Hash(accessToken))
newState.Expiry = jwt.NewNumericDate(accessToken.Expiry)
// RouteSession creates a route session with access tokens stripped. return newState
func (s State) RouteSession() *State {
s.AccessToken = nil
return &s
} }
// IsExpired returns true if the users's session is expired. // IsExpired returns true if the users's session is expired.
func (s *State) IsExpired() bool { func (s *State) IsExpired() bool {
return s.Expiry != nil && timeNow().After(s.Expiry.Time())
if s.Expiry != nil && timeNow().After(s.Expiry.Time()) {
return true
}
if s.AccessToken != nil && timeNow().After(s.AccessToken.Expiry) {
return true
}
return false
} }
// Impersonating returns if the request is impersonating. // Impersonating returns if the request is impersonating.
@ -133,23 +72,6 @@ func (s *State) Impersonating() bool {
return s.ImpersonateEmail != "" || len(s.ImpersonateGroups) != 0 return s.ImpersonateEmail != "" || len(s.ImpersonateGroups) != 0
} }
// RequestEmail is the email to make the request as.
func (s *State) RequestEmail() string {
if s.ImpersonateEmail != "" {
return s.ImpersonateEmail
}
return s.Email
}
// RequestGroups returns the groups of the Groups making the request; uses
// impersonating user if set.
func (s *State) RequestGroups() string {
if len(s.ImpersonateGroups) != 0 {
return strings.Join(s.ImpersonateGroups, ",")
}
return strings.Join(s.Groups, ",")
}
// SetImpersonation sets impersonation user and groups. // SetImpersonation sets impersonation user and groups.
func (s *State) SetImpersonation(email, groups string) { func (s *State) SetImpersonation(email, groups string) {
s.ImpersonateEmail = email s.ImpersonateEmail = email
@ -159,34 +81,3 @@ func (s *State) SetImpersonation(email, groups string) {
s.ImpersonateGroups = strings.Split(groups, ",") s.ImpersonateGroups = strings.Split(groups, ",")
} }
} }
func (s *State) accessTokenHash() string {
hash, err := hashstructure.Hash(
s.AccessToken,
&hashstructure.HashOptions{Hasher: xxhash.New()})
if err != nil {
return ""
}
return fmt.Sprintf("%x", hash)
}
// UnmarshalJSON parses the JSON-encoded session state.
// TODO(BDD): remove in v0.8.0
func (s *State) UnmarshalJSON(b []byte) error {
type Alias State
t := &struct {
*Alias
OldToken *oauth2.Token `json:"access_token,omitempty"` // < v0.5.0
}{
Alias: (*Alias)(s),
}
if err := json.Unmarshal(b, &t); err != nil {
return err
}
if t.AccessToken == nil {
t.AccessToken = t.OldToken
}
*s = *(*State)(t.Alias)
return nil
}

View file

@ -5,8 +5,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2/jwt" "gopkg.in/square/go-jose.v2/jwt"
) )
@ -38,12 +36,6 @@ func TestState_Impersonating(t *testing.T) {
if got := s.Impersonating(); got != tt.want { if got := s.Impersonating(); got != tt.want {
t.Errorf("State.Impersonating() = %v, want %v", got, tt.want) t.Errorf("State.Impersonating() = %v, want %v", got, tt.want)
} }
if gotEmail := s.RequestEmail(); gotEmail != tt.wantResponseEmail {
t.Errorf("State.RequestEmail() = %v, want %v", gotEmail, tt.wantResponseEmail)
}
if gotGroups := s.RequestGroups(); gotGroups != tt.wantResponseGroups {
t.Errorf("State.v() = %v, want %v", gotGroups, tt.wantResponseGroups)
}
}) })
} }
} }
@ -63,7 +55,6 @@ func TestState_IsExpired(t *testing.T) {
}{ }{
{"good", []string{"a", "b", "c"}, jwt.NewNumericDate(time.Now().Add(time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), &oauth2.Token{Expiry: time.Now().Add(time.Hour)}, "a", false}, {"good", []string{"a", "b", "c"}, jwt.NewNumericDate(time.Now().Add(time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), &oauth2.Token{Expiry: time.Now().Add(time.Hour)}, "a", false},
{"bad expiry", []string{"a", "b", "c"}, jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), &oauth2.Token{Expiry: time.Now().Add(time.Hour)}, "a", true}, {"bad expiry", []string{"a", "b", "c"}, jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), &oauth2.Token{Expiry: time.Now().Add(time.Hour)}, "a", true},
{"bad access token expiry", []string{"a", "b", "c"}, jwt.NewNumericDate(time.Now().Add(time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), &oauth2.Token{Expiry: time.Now().Add(-time.Hour)}, "a", true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -72,7 +63,6 @@ func TestState_IsExpired(t *testing.T) {
Expiry: tt.Expiry, Expiry: tt.Expiry,
NotBefore: tt.NotBefore, NotBefore: tt.NotBefore,
IssuedAt: tt.IssuedAt, IssuedAt: tt.IssuedAt,
AccessToken: tt.AccessToken,
} }
if exp := s.IsExpired(); exp != tt.wantErr { if exp := s.IsExpired(); exp != tt.wantErr {
t.Errorf("State.IsExpired() error = %v, wantErr %v", exp, tt.wantErr) t.Errorf("State.IsExpired() error = %v, wantErr %v", exp, tt.wantErr)
@ -80,67 +70,3 @@ func TestState_IsExpired(t *testing.T) {
}) })
} }
} }
func TestState_RouteSession(t *testing.T) {
now := time.Now()
timeNow = func() time.Time {
return now
}
tests := []struct {
name string
Issuer string
Audience jwt.Audience
Expiry *jwt.NumericDate
AccessToken *oauth2.Token
issuer string
audience []string
want *State
}{
{"good", "authenticate.x.y.z", []string{"http.x.y.z"}, jwt.NewNumericDate(timeNow()), nil, "authenticate.a.b.c", []string{"http.a.b.c"}, &State{Issuer: "authenticate.a.b.c", Audience: []string{"http.a.b.c"}, NotBefore: jwt.NewNumericDate(timeNow()), IssuedAt: jwt.NewNumericDate(timeNow()), Expiry: jwt.NewNumericDate(timeNow())}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := State{
Issuer: tt.Issuer,
Audience: tt.Audience,
Expiry: tt.Expiry,
AccessToken: tt.AccessToken,
}
cmpOpts := []cmp.Option{
cmpopts.IgnoreUnexported(State{}),
}
got := s.NewSession(tt.issuer, tt.audience)
got = got.RouteSession()
if diff := cmp.Diff(got, tt.want, cmpOpts...); diff != "" {
t.Errorf("State.RouteSession() = %s", diff)
}
})
}
}
func TestState_accessTokenHash(t *testing.T) {
t.Parallel()
tests := []struct {
name string
state State
want string
}{
{"empty access token", State{}, "34c96acdcadb1bbb"},
{"no change to access token", State{Subject: "test"}, "34c96acdcadb1bbb"},
{"empty oauth2 token", State{AccessToken: &oauth2.Token{}}, "bbd82197d215198f"},
{"refresh token a", State{AccessToken: &oauth2.Token{RefreshToken: "a"}}, "76316ac79b301bd6"},
{"refresh token b", State{AccessToken: &oauth2.Token{RefreshToken: "b"}}, "fab7cb29e50161f1"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &tt.state
if got := s.accessTokenHash(); got != tt.want {
t.Errorf("State.accessTokenHash() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -17,6 +17,7 @@ const (
QueryRefreshToken = "pomerium_refresh_token" QueryRefreshToken = "pomerium_refresh_token"
QueryAccessTokenID = "pomerium_session_access_token_id" QueryAccessTokenID = "pomerium_session_access_token_id"
QueryAudience = "pomerium_session_audience" QueryAudience = "pomerium_session_audience"
QueryProgrammaticToken = "pomerium_programmatic_token"
) )
// URL signature based query params used for verifying the authenticity of a URL. // URL signature based query params used for verifying the authenticity of a URL.