authenticate: save oauth2 tokens to cache (#698)

Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
This commit is contained in:
Bobby DeSimone 2020-05-18 10:45:07 -07:00 committed by Travis Groth
parent ef399380b7
commit 666fd6aa35
31 changed files with 1127 additions and 1061 deletions

View file

@ -38,6 +38,12 @@ GETENVOY_VERSION = v0.1.8
all: clean build-deps test lint spellcheck build ## Runs a clean, build, fmt, lint, test, and vet.
.PHONY: generate-mocks
generate-mocks: ## Generate mocks
@echo "==> $@"
@go run github.com/golang/mock/mockgen -destination authorize/evaluator/mock_evaluator/mock.go github.com/pomerium/pomerium/authorize/evaluator Evaluator
@go run github.com/golang/mock/mockgen -destination internal/grpc/cache/mock/mock_cacher.go github.com/pomerium/pomerium/internal/grpc/cache Cacher
.PHONY: build-deps
build-deps: ## Install build dependencies
@echo "==> $@"

View file

@ -17,12 +17,12 @@ import (
"github.com/pomerium/pomerium/internal/encoding/jws"
"github.com/pomerium/pomerium/internal/frontend"
"github.com/pomerium/pomerium/internal/grpc"
"github.com/pomerium/pomerium/internal/grpc/cache"
"github.com/pomerium/pomerium/internal/grpc/cache/client"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/internal/identity/oauth"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/sessions/cache"
"github.com/pomerium/pomerium/internal/sessions/cookie"
"github.com/pomerium/pomerium/internal/sessions/header"
"github.com/pomerium/pomerium/internal/sessions/queryparam"
@ -93,7 +93,7 @@ type Authenticate struct {
provider identity.Authenticator
// cacheClient is the interface for setting and getting sessions from a cache
cacheClient client.Cacher
cacheClient cache.Cacher
templates *template.Template
}
@ -106,12 +106,12 @@ func New(opts config.Options) (*Authenticate, error) {
// shared state encoder setup
sharedCipher, _ := cryptutil.NewAEADCipherFromBase64(opts.SharedKey)
signedEncoder, err := jws.NewHS256Signer([]byte(opts.SharedKey), opts.AuthenticateURL.Host)
sharedEncoder, err := jws.NewHS256Signer([]byte(opts.SharedKey), opts.AuthenticateURL.Host)
if err != nil {
return nil, err
}
// private state encoder setup
// private state encoder setup, used to encrypt oauth2 tokens
decodedCookieSecret, _ := base64.StdEncoding.DecodeString(opts.CookieSecret)
cookieCipher, _ := cryptutil.NewAEADCipher(decodedCookieSecret)
encryptedEncoder := ecjson.New(cookieCipher)
@ -124,7 +124,7 @@ func New(opts config.Options) (*Authenticate, error) {
Expire: opts.CookieExpire,
}
cookieStore, err := cookie.NewStore(cookieOptions, encryptedEncoder)
cookieStore, err := cookie.NewStore(cookieOptions, sharedEncoder)
if err != nil {
return nil, err
}
@ -145,13 +145,7 @@ func New(opts config.Options) (*Authenticate, error) {
cacheClient := client.New(cacheConn)
cacheStore := cache.NewStore(&cache.Options{
Cache: cacheClient,
Encoder: encryptedEncoder,
QueryParam: urlutil.QueryAccessTokenID,
WrappedStore: cookieStore})
qpStore := queryparam.NewStore(encryptedEncoder, "pomerium_programmatic_token")
qpStore := queryparam.NewStore(encryptedEncoder, urlutil.QueryProgrammaticToken)
headerStore := header.NewStore(encryptedEncoder, httputil.AuthorizationTypePomerium)
redirectURL, _ := urlutil.DeepCopy(opts.AuthenticateURL)
@ -177,14 +171,14 @@ func New(opts config.Options) (*Authenticate, error) {
// shared state
sharedKey: opts.SharedKey,
sharedCipher: sharedCipher,
sharedEncoder: signedEncoder,
sharedEncoder: sharedEncoder,
// private state
cookieSecret: decodedCookieSecret,
cookieCipher: cookieCipher,
cookieOptions: cookieOptions,
sessionStore: cacheStore,
sessionStore: cookieStore,
encryptedEncoder: encryptedEncoder,
sessionLoaders: []sessions.SessionLoader{cacheStore, qpStore, headerStore, cookieStore},
sessionLoaders: []sessions.SessionLoader{qpStore, headerStore, cookieStore},
// IdP
provider: provider,
// grpc client for cache

View file

@ -91,6 +91,10 @@ func TestNew(t *testing.T) {
badGRPCConn.CacheURL = nil
badGRPCConn.CookieName = "D"
emptyProviderURL := newTestOptions(t)
emptyProviderURL.Provider = "oidc"
emptyProviderURL.ProviderURL = ""
tests := []struct {
name string
opts *config.Options
@ -103,6 +107,7 @@ func TestNew(t *testing.T) {
{"bad cookie name", badCookieName, true},
{"bad provider", badProvider, true},
{"bad cache url", badGRPCConn, true},
{"empty provider url", emptyProviderURL, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

View file

@ -13,6 +13,7 @@ import (
"github.com/pomerium/csrf"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/hashutil"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/identity/oidc"
"github.com/pomerium/pomerium/internal/log"
@ -23,6 +24,7 @@ import (
"github.com/gorilla/mux"
"github.com/rs/cors"
"golang.org/x/oauth2"
)
// Handler returns the authenticate service's handler chain.
@ -80,19 +82,15 @@ func (a *Authenticate) Mount(r *mux.Router) {
// session state is attached to the users's request context.
func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
jwt, err := sessions.FromContext(ctx)
ctx, span := trace.StartSpan(r.Context(), "authenticate.VerifySession")
defer span.End()
s, err := a.getSessionFromCtx(ctx)
if err != nil {
log.FromRequest(r).Info().Err(err).Msg("authenticate: session load error")
return a.reauthenticateOrFail(w, r, err)
}
var s sessions.State
if err := a.encryptedEncoder.Unmarshal([]byte(jwt), &s); err != nil {
return httputil.NewError(http.StatusBadRequest, err)
}
if s.IsExpired() {
ctx, err = a.refresh(w, r, &s)
ctx, err = a.refresh(w, r, s)
if err != nil {
log.FromRequest(r).Info().Err(err).Msg("authenticate: verify session, refresh")
return a.reauthenticateOrFail(w, r, err)
@ -106,18 +104,34 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
func (a *Authenticate) refresh(w http.ResponseWriter, r *http.Request, s *sessions.State) (context.Context, error) {
ctx, span := trace.StartSpan(r.Context(), "authenticate.VerifySession/refresh")
defer span.End()
newSession, err := a.provider.Refresh(ctx, s)
accessToken, err := a.getAccessToken(ctx, s)
if err != nil {
return nil, err
}
// we are going to keep the same audiences for the refreshed token
// otherwise this will be rewritten to be the ClientID of the provider
oldAudience := s.Audience
newAccessToken, err := a.provider.Refresh(ctx, accessToken, s)
if err != nil {
return nil, fmt.Errorf("authenticate: refresh failed: %w", err)
}
if err := a.sessionStore.SaveSession(w, r, newSession); err != nil {
return nil, fmt.Errorf("authenticate: refresh save failed: %w", err)
}
newSession = newSession.NewSession(s.Issuer, s.Audience)
encSession, err := a.encryptedEncoder.Marshal(newSession)
newSession := sessions.NewSession(s, a.RedirectURL.Hostname(), oldAudience, newAccessToken)
encSession, err := a.sharedEncoder.Marshal(newSession)
if err != nil {
return nil, err
}
if err := a.sessionStore.SaveSession(w, r, newSession); err != nil {
return nil, fmt.Errorf("authenticate: error saving new session: %w", err)
}
if err := a.setAccessToken(ctx, newAccessToken); err != nil {
return nil, fmt.Errorf("authenticate: error saving refreshed access token: %w", err)
}
// return the new session and add it to the current request context
return sessions.NewContext(ctx, string(encSession), err), nil
}
@ -129,8 +143,11 @@ func (a *Authenticate) RobotsTxt(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "User-agent: *\nDisallow: /")
}
// SignIn handles to authenticating a user.
// SignIn handles authenticating a user.
func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
ctx, span := trace.StartSpan(r.Context(), "authenticate.SignOut")
defer span.End()
redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI))
if err != nil {
return httputil.NewError(http.StatusBadRequest, err)
@ -158,32 +175,30 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
jwtAudience = append(jwtAudience, fwdAuth)
}
jwt, err := sessions.FromContext(r.Context())
s, err := a.getSessionFromCtx(ctx)
if err != nil {
return httputil.NewError(http.StatusBadRequest, err)
return err
}
var s sessions.State
if err := a.encryptedEncoder.Unmarshal([]byte(jwt), &s); err != nil {
return httputil.NewError(http.StatusBadRequest, err)
accessToken, err := a.getAccessToken(ctx, s)
if err != nil {
return err
}
// user impersonation
if impersonate := r.FormValue(urlutil.QueryImpersonateAction); impersonate != "" {
s.SetImpersonation(r.FormValue(urlutil.QueryImpersonateEmail), r.FormValue(urlutil.QueryImpersonateGroups))
}
newSession := sessions.NewSession(s, a.RedirectURL.Host, jwtAudience, accessToken)
// re-persist the session, useful when session was evicted from session
if err := a.sessionStore.SaveSession(w, r, &s); err != nil {
if err := a.sessionStore.SaveSession(w, r, s); err != nil {
return httputil.NewError(http.StatusBadRequest, err)
}
newSession := s.NewSession(a.RedirectURL.Host, jwtAudience)
callbackParams := callbackURL.Query()
if r.FormValue(urlutil.QueryIsProgrammatic) == "true" {
newSession.Programmatic = true
encSession, err := a.encryptedEncoder.Marshal(newSession)
encSession, err := a.encryptedEncoder.Marshal(accessToken)
if err != nil {
return httputil.NewError(http.StatusBadRequest, err)
}
@ -192,7 +207,7 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
}
// sign the route session, as a JWT
signedJWT, err := a.sharedEncoder.Marshal(newSession.RouteSession())
signedJWT, err := a.sharedEncoder.Marshal(newSession)
if err != nil {
return httputil.NewError(http.StatusBadRequest, err)
}
@ -217,27 +232,12 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
// SignOut signs the user out and attempts to revoke the user's identity session
// Handles both GET and POST.
func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error {
// no matter what happens, we want to clear the local session store
ctx, span := trace.StartSpan(r.Context(), "authenticate.SignOut")
defer span.End()
// no matter what happens, we want to clear the session store
a.sessionStore.ClearSession(w, r)
jwt, err := sessions.FromContext(r.Context())
if err != nil {
return httputil.NewError(http.StatusBadRequest, err)
}
var s sessions.State
if err := a.encryptedEncoder.Unmarshal([]byte(jwt), &s); err != nil {
return httputil.NewError(http.StatusBadRequest, err)
}
redirectString := r.FormValue(urlutil.QueryRedirectURI)
// first, try to revoke the session if implemented
err = a.provider.Revoke(r.Context(), s.AccessToken)
if err != nil && !errors.Is(err, oidc.ErrRevokeNotImplemented) {
return httputil.NewError(http.StatusBadRequest, err)
}
// next, try to build a logout url if implemented
endSessionURL, err := a.provider.LogOut()
if err == nil {
params := url.Values{}
@ -245,14 +245,29 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error {
endSessionURL.RawQuery = params.Encode()
redirectString = endSessionURL.String()
} else if !errors.Is(err, oidc.ErrSignoutNotImplemented) {
return httputil.NewError(http.StatusBadRequest, err)
log.Warn().Err(err).Msg("authenticate.SignOut: failed getting session")
}
redirectURL, err := urlutil.ParseAndValidateURL(redirectString)
httputil.Redirect(w, r, redirectString, http.StatusFound)
s, err := a.getSessionFromCtx(ctx)
if err != nil {
return httputil.NewError(http.StatusBadRequest, err)
log.Warn().Err(err).Msg("authenticate.SignOut: failed getting session")
return nil
}
accessToken, err := a.getAccessToken(ctx, s)
if err != nil {
log.Warn().Err(err).Msg("authenticate.SignOut: failed getting access token")
return nil
}
// first, try to revoke the session if implemented
err = a.provider.Revoke(ctx, accessToken)
if err != nil && !errors.Is(err, oidc.ErrRevokeNotImplemented) {
log.Warn().Err(err).Msg("authenticate.SignOut: failed revoking token")
return nil
}
httputil.Redirect(w, r, redirectURL.String(), http.StatusFound)
return nil
}
@ -267,7 +282,8 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error {
// https://tools.ietf.org/html/rfc6749#section-4.2.1
// https://developer.mozilla.org/en-US/docs/Web/API/XMLHttpRequest
func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Request, err error) error {
// If request AJAX/XHR request, return a 401 instead .
// If request AJAX/XHR request, return a 401 instead because the redirect
// will almost certainly violate their CORs policy
if reqType := r.Header.Get("X-Requested-With"); strings.EqualFold(reqType, "XmlHttpRequest") {
return httputil.NewError(http.StatusUnauthorized, err)
}
@ -290,7 +306,7 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque
func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) error {
redirect, err := a.getOAuthCallback(w, r)
if err != nil {
return fmt.Errorf("oauth callback : %w", err)
return fmt.Errorf("authenticate.OAuthCallback: %w", err)
}
httputil.Redirect(w, r, redirect.String(), http.StatusFound)
return nil
@ -306,6 +322,9 @@ func (a *Authenticate) statusForErrorCode(errorCode string) int {
}
func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (*url.URL, error) {
ctx, span := trace.StartSpan(r.Context(), "authenticate.getOAuthCallback")
defer span.End()
// Error Authentication Response: rfc6749#section-4.1.2.1 & OIDC#3.1.2.6
//
// first, check if the identity provider returned an error
@ -321,14 +340,22 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
// Successful Authentication Response: rfc6749#section-4.1.2 & OIDC#3.1.2.5
//
// Exchange the supplied Authorization Code for a valid user session.
session, err := a.provider.Authenticate(r.Context(), code)
var s sessions.State
accessToken, err := a.provider.Authenticate(ctx, code, &s)
if err != nil {
return nil, fmt.Errorf("error redeeming authenticate code: %w", err)
}
newState := sessions.NewSession(
&s,
a.RedirectURL.Hostname(),
[]string{a.RedirectURL.Hostname()},
accessToken)
// state includes a csrf nonce (validated by middleware) and redirect uri
bytes, err := base64.URLEncoding.DecodeString(r.FormValue("state"))
if err != nil {
return nil, httputil.NewError(http.StatusBadRequest, err)
return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("bad bytes: %w", err))
}
// split state into concat'd components
@ -357,8 +384,13 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
return nil, httputil.NewError(http.StatusBadRequest, err)
}
// OK. Looks good so let's persist our user session
if err := a.sessionStore.SaveSession(w, r, session); err != nil {
// Ok -- We've got a valid session here. Let's now persist the access
// token to cache ...
if err := a.setAccessToken(ctx, accessToken); err != nil {
return nil, fmt.Errorf("failed saving access token: %w", err)
}
// ... and the user state to local storage.
if err := a.sessionStore.SaveSession(w, r, &newState); err != nil {
return nil, fmt.Errorf("failed saving new session: %w", err)
}
return redirectURL, nil
@ -368,26 +400,32 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
// tokens and state with the identity provider. If successful, a new signed JWT
// and refresh token (`refresh_token`) are returned as JSON
func (a *Authenticate) RefreshAPI(w http.ResponseWriter, r *http.Request) error {
jwt, err := sessions.FromContext(r.Context())
if err != nil {
return httputil.NewError(http.StatusBadRequest, err)
}
var s sessions.State
if err := a.encryptedEncoder.Unmarshal([]byte(jwt), &s); err != nil {
return httputil.NewError(http.StatusBadRequest, err)
}
newSession, err := a.provider.Refresh(r.Context(), &s)
if err != nil {
return err
}
newSession = newSession.NewSession(s.Issuer, s.Audience)
ctx, span := trace.StartSpan(r.Context(), "authenticate.RefreshAPI")
defer span.End()
encSession, err := a.encryptedEncoder.Marshal(newSession)
s, err := a.getSessionFromCtx(ctx)
if err != nil {
return err
}
signedJWT, err := a.sharedEncoder.Marshal(newSession.RouteSession())
accessToken, err := a.getAccessToken(ctx, s)
if err != nil {
return err
}
newAccessToken, err := a.provider.Refresh(ctx, accessToken, s)
if err != nil {
return err
}
routeNewSession := sessions.NewSession(s, a.RedirectURL.Hostname(), s.Audience, newAccessToken)
encSession, err := a.encryptedEncoder.Marshal(accessToken)
if err != nil {
return err
}
signedJWT, err := a.sharedEncoder.Marshal(routeNewSession)
if err != nil {
return err
}
@ -410,28 +448,79 @@ func (a *Authenticate) RefreshAPI(w http.ResponseWriter, r *http.Request) error
// Refresh is called by the proxy service to handle backend session refresh.
//
// NOTE: The actual refresh is handled as part of the "VerifySession"
// middleware. This handler is responsible for creating a new route scoped
// session and returning it.
// middleware. This handler is simply responsible for returning that jwt.
func (a *Authenticate) Refresh(w http.ResponseWriter, r *http.Request) error {
jwt, err := sessions.FromContext(r.Context())
ctx, span := trace.StartSpan(r.Context(), "authenticate.Refresh")
defer span.End()
jwt, err := sessions.FromContext(ctx)
if err != nil {
return httputil.NewError(http.StatusBadRequest, err)
return fmt.Errorf("authenticate.Refresh: %w", err)
}
var s sessions.State
if err := a.encryptedEncoder.Unmarshal([]byte(jwt), &s); err != nil {
return httputil.NewError(http.StatusBadRequest, err)
w.Header().Set("Content-Type", "application/jwt") // RFC 7519 : 10.3.1
fmt.Fprint(w, jwt)
return nil
}
// getAccessToken gets an associated oauth2 access token from a session state
func (a *Authenticate) getAccessToken(ctx context.Context, s *sessions.State) (*oauth2.Token, error) {
ctx, span := trace.StartSpan(ctx, "authenticate.getAccessToken")
defer span.End()
var accessToken oauth2.Token
tokenBytes, err := a.cacheClient.Get(ctx, s.AccessTokenHash)
if err != nil {
return nil, err
}
if err := a.encryptedEncoder.Unmarshal(tokenBytes, &accessToken); err != nil {
return nil, err
}
if accessToken.Valid() {
return &accessToken, nil // this token is still valid, use it!
}
tokenBytes, err = a.cacheClient.Get(ctx, a.timestampedHash(accessToken.RefreshToken))
if err == nil {
// we found another possibly newer access token associated with the
// existing refresh token so let's try that.
if err := a.encryptedEncoder.Unmarshal(tokenBytes, &accessToken); err != nil {
return nil, err
}
}
aud := strings.Split(r.FormValue(urlutil.QueryAudience), ",")
routeSession := s.NewSession(r.Host, aud)
routeSession.AccessTokenID = s.AccessTokenID
return &accessToken, nil
}
signedJWT, err := a.sharedEncoder.Marshal(routeSession.RouteSession())
func (a *Authenticate) setAccessToken(ctx context.Context, accessToken *oauth2.Token) error {
encToken, err := a.encryptedEncoder.Marshal(accessToken)
if err != nil {
return err
}
// set this specific access token
key := fmt.Sprintf("%x", hashutil.Hash(accessToken))
if err := a.cacheClient.Set(ctx, key, encToken); err != nil {
return fmt.Errorf("authenticate: setAccessToken failed key: %s :%w", key, err)
}
// set this as the "latest" token for this access token
key = a.timestampedHash(accessToken.RefreshToken)
if err := a.cacheClient.Set(ctx, key, encToken); err != nil {
return fmt.Errorf("authenticate: setAccessToken failed key: %s :%w", key, err)
}
w.Header().Set("Content-Type", "application/jwt") // RFC 7519 : 10.3.1
w.Write(signedJWT)
return nil
}
func (a *Authenticate) timestampedHash(s string) string {
return fmt.Sprintf("%x-%v", hashutil.Hash(s), time.Now().Truncate(time.Minute).Unix())
}
func (a *Authenticate) getSessionFromCtx(ctx context.Context) (*sessions.State, error) {
jwt, err := sessions.FromContext(ctx)
if err != nil {
return nil, httputil.NewError(http.StatusBadRequest, err)
}
var s sessions.State
if err := a.sharedEncoder.Unmarshal([]byte(jwt), &s); err != nil {
return nil, httputil.NewError(http.StatusBadRequest, err)
}
return &s, nil
}

View file

@ -11,20 +11,24 @@ import (
"testing"
"time"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/encoding/jws"
"github.com/pomerium/pomerium/internal/encoding/mock"
"github.com/pomerium/pomerium/internal/frontend"
mock_cache "github.com/pomerium/pomerium/internal/grpc/cache/mock"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/internal/identity/oidc"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/sessions/cookie"
mstore "github.com/pomerium/pomerium/internal/sessions/mock"
"github.com/google/go-cmp/cmp"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/golang/mock/gomock"
"github.com/google/go-cmp/cmp"
"github.com/gorilla/mux"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2/jwt"
@ -115,23 +119,28 @@ func TestAuthenticate_SignIn(t *testing.T) {
encoder encoding.MarshalUnmarshaler
wantCode int
}{
{"good", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"good alternate port", "https", "corp.example.example:8443", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"session not valid", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(-10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"bad redirect uri query", "", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "^^^"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
{"bad marshal", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
{"good", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"good alternate port", "https", "corp.example.example:8443", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"session not valid", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"bad redirect uri query", "", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "^^^"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
{"bad marshal", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
{"session error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{LoadError: errors.New("error")}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
{"good with different programmatic redirect", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"encrypted encoder error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
{"good with callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "https://some.example/", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"bad callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "^", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
{"good programmatic request", "https", "corp.example.example", map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"good additional audience", "https", "corp.example.example", map[string]string{urlutil.QueryForwardAuth: "x.y.z", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"good user impersonate", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"bad user impersonate save failure", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
{"good with different programmatic redirect", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"encrypted encoder error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
{"good with callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "https://some.example/", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"bad callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "^", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
{"good programmatic request", "https", "corp.example.example", map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"good additional audience", "https", "corp.example.example", map[string]string{urlutil.QueryForwardAuth: "x.y.z", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"good user impersonate", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"bad user impersonate save failure", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mc := mock_cache.NewMockCacher(ctrl)
mc.EXPECT().Get(gomock.Any(), gomock.Any()).Return([]byte("hi"), nil).AnyTimes()
a := &Authenticate{
sessionStore: tt.session,
provider: tt.provider,
@ -144,6 +153,7 @@ func TestAuthenticate_SignIn(t *testing.T) {
Name: "cookie",
Domain: "foo",
},
cacheClient: mc,
}
uri := &url.URL{Scheme: tt.scheme, Host: tt.host}
@ -176,6 +186,7 @@ func uriParseHelper(s string) *url.URL {
func TestAuthenticate_SignOut(t *testing.T) {
t.Parallel()
tests := []struct {
name string
method string
@ -190,19 +201,24 @@ func TestAuthenticate_SignOut(t *testing.T) {
wantCode int
wantBody string
}{
{"good post", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{LogOutResponse: (*uriParseHelper("https://microsoft.com"))}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusFound, ""},
{"failed revoke", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: OH NO\"}\n"},
{"load session error", http.MethodPost, errors.New("error"), "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: error\"}\n"},
{"bad redirect uri", http.MethodPost, nil, "corp.pomerium.io/", "sig", "ts", identity.MockProvider{LogOutError: oidc.ErrSignoutNotImplemented}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: corp.pomerium.io/ url does contain a valid scheme\"}\n"},
{"good post", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{LogOutResponse: (*uriParseHelper("https://microsoft.com"))}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io"}}, http.StatusFound, ""},
{"failed revoke", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io"}}, http.StatusFound, ""},
{"load session error", http.MethodPost, errors.New("error"), "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io"}}, http.StatusFound, ""},
{"bad redirect uri", http.MethodPost, nil, "corp.pomerium.io/", "sig", "ts", identity.MockProvider{LogOutError: oidc.ErrSignoutNotImplemented}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io"}}, http.StatusFound, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mc := mock_cache.NewMockCacher(ctrl)
mc.EXPECT().Get(gomock.Any(), gomock.Any()).Return([]byte("hi"), nil).AnyTimes()
a := &Authenticate{
sessionStore: tt.sessionStore,
provider: tt.provider,
encryptedEncoder: mock.Encoder{},
templates: template.Must(frontend.NewTemplates()),
sharedEncoder: mock.Encoder{},
cacheClient: mc,
}
u, _ := url.Parse("/sign_out")
params, _ := url.ParseQuery(u.RawQuery)
@ -256,40 +272,50 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
want string
wantCode int
}{
{"good", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusFound},
{"failed authenticate", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateError: errors.New("error")}, "", http.StatusInternalServerError},
{"failed save session", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{SaveError: errors.New("error")}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusInternalServerError},
{"provider returned error", http.MethodGet, time.Now().Unix(), "", "", "", "idp error", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest},
{"provider returned error imply 401", http.MethodGet, time.Now().Unix(), "", "", "", "access_denied", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusUnauthorized},
{"empty code", http.MethodGet, time.Now().Unix(), "", "", "", "", "", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest},
{"invalid redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest},
{"bad redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "http://^^^", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad timing - too soon", http.MethodGet, time.Now().Add(1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad timing - expired", http.MethodGet, time.Now().Add(-1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad base64", http.MethodGet, time.Now().Unix(), "", "", "^", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"too many seperators", http.MethodGet, time.Now().Unix(), "", "", "|ok|now|what", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad hmac", http.MethodGet, time.Now().Unix(), "", "NOTMAC", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad hmac", http.MethodGet, time.Now().Unix(), base64.URLEncoding.EncodeToString([]byte("malformed_state")), "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"good", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusFound},
{"failed authenticate", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}, AuthenticateError: errors.New("error")}, "", http.StatusInternalServerError},
{"failed save session", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{SaveError: errors.New("error")}, identity.MockProvider{}, "", http.StatusInternalServerError},
{"provider returned error", http.MethodGet, time.Now().Unix(), "", "", "", "idp error", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "", http.StatusBadRequest},
{"provider returned error imply 401", http.MethodGet, time.Now().Unix(), "", "", "", "access_denied", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "", http.StatusUnauthorized},
{"empty code", http.MethodGet, time.Now().Unix(), "", "", "", "", "", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "", http.StatusBadRequest},
{"invalid redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "", http.StatusBadRequest},
{"bad redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "http://^^^", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad timing - too soon", http.MethodGet, time.Now().Add(1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad timing - expired", http.MethodGet, time.Now().Add(-1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad base64", http.MethodGet, time.Now().Unix(), "", "", "^", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"too many seperators", http.MethodGet, time.Now().Unix(), "", "", "|ok|now|what", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad hmac", http.MethodGet, time.Now().Unix(), "", "NOTMAC", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad hmac", http.MethodGet, time.Now().Unix(), base64.URLEncoding.EncodeToString([]byte("malformed_state")), "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mc := mock_cache.NewMockCacher(ctrl)
mc.EXPECT().Get(gomock.Any(), gomock.Any()).Return([]byte("hi"), nil).AnyTimes()
mc.EXPECT().Set(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
aead, err := chacha20poly1305.NewX(cryptutil.NewKey())
if err != nil {
t.Fatal(err)
}
signer, err := jws.NewHS256Signer(nil, "mock")
if err != nil {
t.Fatal(err)
}
authURL, _ := url.Parse(tt.authenticateURL)
a := &Authenticate{
RedirectURL: authURL,
sessionStore: tt.session,
provider: tt.provider,
cookieCipher: aead,
RedirectURL: authURL,
sessionStore: tt.session,
provider: tt.provider,
cookieCipher: aead,
cacheClient: mc,
encryptedEncoder: signer,
}
u, _ := url.Parse("/oauthGet")
params, _ := url.ParseQuery(u.RawQuery)
params.Add("error", tt.paramErr)
params.Add("code", tt.code)
nonce := cryptutil.NewBase64Key() // mock csrf
// (nonce|timestamp|redirect_url|encrypt(redirect_url),mac(nonce,ts))
b := []byte(fmt.Sprintf("%s|%d|%s", nonce, tt.ts, tt.extraMac))
@ -336,15 +362,21 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
wantStatus int
}{
{"good", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusOK},
{"good", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, http.StatusOK},
{"invalid session", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound},
{"good refresh expired", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusOK},
{"good refresh expired", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, http.StatusOK},
{"expired,refresh error", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusFound},
{"expired,save error", nil, &mstore.Store{SaveError: errors.New("error"), Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusFound},
{"expired,save error", nil, &mstore.Store{SaveError: errors.New("error"), Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, http.StatusFound},
{"expired XHR,refresh error", map[string]string{"X-Requested-With": "XmlHttpRequest"}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusUnauthorized},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mc := mock_cache.NewMockCacher(ctrl)
mc.EXPECT().Get(gomock.Any(), gomock.Any()).Return([]byte("hi"), nil).AnyTimes()
mc.EXPECT().Set(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
aead, err := chacha20poly1305.NewX(cryptutil.NewKey())
if err != nil {
t.Fatal(err)
@ -361,6 +393,8 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
provider: tt.provider,
cookieCipher: aead,
encryptedEncoder: signer,
cacheClient: mc,
sharedEncoder: mock.Encoder{},
}
r := httptest.NewRequest("GET", "/", nil)
state, err := tt.session.LoadSession(r)
@ -402,14 +436,20 @@ func TestAuthenticate_RefreshAPI(t *testing.T) {
wantStatus int
}{
{"good", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusOK},
{"good", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusOK},
{"refresh error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshError: errors.New("error")}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError},
{"session is not refreshable error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("session error"), identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusBadRequest},
{"secret encoder failed", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalError: errors.New("error")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError},
{"shared encoder failed", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalError: errors.New("error")}, http.StatusInternalServerError},
{"session is not refreshable error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("session error"), identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusBadRequest},
{"secret encoder failed", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, mock.Encoder{MarshalError: errors.New("error")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError},
{"shared encoder failed", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalError: errors.New("error")}, http.StatusInternalServerError},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mc := mock_cache.NewMockCacher(ctrl)
mc.EXPECT().Get(gomock.Any(), gomock.Any()).Return([]byte("hi"), nil).AnyTimes()
mc.EXPECT().Set(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
aead, err := chacha20poly1305.NewX(cryptutil.NewKey())
if err != nil {
t.Fatal(err)
@ -423,6 +463,7 @@ func TestAuthenticate_RefreshAPI(t *testing.T) {
sessionStore: tt.session,
provider: tt.provider,
cookieCipher: aead,
cacheClient: mc,
}
r := httptest.NewRequest("GET", "/", nil)
state, _ := tt.session.LoadSession(r)
@ -441,53 +482,111 @@ func TestAuthenticate_RefreshAPI(t *testing.T) {
})
}
}
func TestAuthenticate_Refresh(t *testing.T) {
t.Parallel()
tests := []struct {
name string
session sessions.SessionStore
ctxError error
session *sessions.State
at *oauth2.Token
provider identity.Authenticator
secretEncoder encoding.MarshalUnmarshaler
sharedEncoder encoding.MarshalUnmarshaler
wantStatus int
}{
{"good", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusOK},
{"bad session", &mstore.Store{}, errors.New("err"), identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusBadRequest},
{"encoder error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalError: errors.New("err")}, http.StatusInternalServerError},
{"good",
&sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))},
&oauth2.Token{AccessToken: "mock", Expiry: time.Now().Add(10 * time.Minute)},
identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
mock.Encoder{MarshalResponse: []byte("ok")},
200},
{"session and oauth2 expired",
&sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))},
&oauth2.Token{AccessToken: "mock", Expiry: time.Now().Add(-10 * time.Minute)},
identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
mock.Encoder{MarshalResponse: []byte("ok")},
200},
{"session expired",
&sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))},
&oauth2.Token{AccessToken: "mock", Expiry: time.Now().Add(10 * time.Minute)},
identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
mock.Encoder{MarshalResponse: []byte("ok")},
200},
{"failed refresh",
&sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))},
&oauth2.Token{AccessToken: "mock", Expiry: time.Now().Add(10 * time.Minute)},
identity.MockProvider{RefreshError: errors.New("oh no")},
mock.Encoder{MarshalResponse: []byte("ok")},
302},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
aead, err := chacha20poly1305.NewX(cryptutil.NewKey())
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mc := mock_cache.NewMockCacher(ctrl)
// just enough is stubbed out here so we can use our own mock provider
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.Header().Set("Content-Type", "application/json")
out := fmt.Sprintf(`{"issuer":"http://%s"}`, r.Host)
fmt.Fprintln(w, out)
}))
defer ts.Close()
rURL := ts.URL
a, err := New(config.Options{
SharedKey: cryptutil.NewBase64Key(),
CookieSecret: cryptutil.NewBase64Key(),
AuthenticateURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"),
Provider: "oidc",
ClientID: "mock",
ClientSecret: "mock",
ProviderURL: rURL,
AuthenticateCallbackPath: "mock",
CookieName: "pomerium",
Addr: ":0",
CacheURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"),
})
if err != nil {
t.Fatal(err)
}
a := Authenticate{
sharedKey: cryptutil.NewBase64Key(),
cookieSecret: cryptutil.NewKey(),
RedirectURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"),
encryptedEncoder: tt.secretEncoder,
sharedEncoder: tt.sharedEncoder,
sessionStore: tt.session,
provider: tt.provider,
cookieCipher: aead,
}
r := httptest.NewRequest("GET", "/", nil)
state, _ := tt.session.LoadSession(r)
ctx := r.Context()
ctx = sessions.NewContext(ctx, state, tt.ctxError)
r = r.WithContext(ctx)
a.cacheClient = mc
a.provider = tt.provider
u, _ := url.Parse("/oauthGet")
params, _ := url.ParseQuery(u.RawQuery)
destination := urlutil.NewSignedURL(a.sharedKey,
&url.URL{
Scheme: "https",
Host: "example.com",
Path: "/.pomerium/refresh"})
u.RawQuery = params.Encode()
r := httptest.NewRequest(http.MethodGet, destination.String(), nil)
jwt, err := a.sharedEncoder.Marshal(tt.session)
if err != nil {
t.Fatal(err)
}
rawToken, err := a.encryptedEncoder.Marshal(tt.at)
if err != nil {
t.Fatal(err)
}
mc.EXPECT().Get(gomock.Any(), gomock.Any()).Return(rawToken, nil).AnyTimes()
mc.EXPECT().Set(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
a.cacheClient = mc
r.Header.Set("Authorization", fmt.Sprintf("Pomerium %s", jwt))
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
httputil.HandlerFunc(a.Refresh).ServeHTTP(w, r)
router := mux.NewRouter()
a.Mount(router)
router.ServeHTTP(w, r)
if status := w.Code; status != tt.wantStatus {
t.Errorf("VerifySession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
t.Errorf("Refresh() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
}
})
}

View file

@ -1,5 +1,3 @@
//go:generate mockgen -destination mock_evaluator/mock.go github.com/pomerium/pomerium/authorize/evaluator Evaluator
// Package evaluator defines a Evaluator interfaces that can be implemented by
// a policy evaluator framework.
package evaluator

View file

@ -218,10 +218,6 @@ func (a *Authorize) refreshSession(ctx context.Context, rawSession []byte) (newS
// 1 - build a signed url to call refresh on authenticate service
refreshURI := options.AuthenticateURL.ResolveReference(&url.URL{Path: "/.pomerium/refresh"})
q := refreshURI.Query()
q.Set(urlutil.QueryAccessTokenID, state.AccessTokenID) // hash value points to parent token
q.Set(urlutil.QueryAudience, strings.Join(state.Audience, ",")) // request's audience, this route
refreshURI.RawQuery = q.Encode()
signedRefreshURL := urlutil.NewSignedURL(options.SharedKey, refreshURI).String()
// 2 - http call to authenticate service
@ -229,6 +225,7 @@ func (a *Authorize) refreshSession(ctx context.Context, rawSession []byte) (newS
if err != nil {
return nil, fmt.Errorf("authorize: refresh request: %w", err)
}
req.Header.Set("Authorization", fmt.Sprintf("Pomerium %s", rawSession))
req.Header.Set("X-Requested-With", "XmlHttpRequest")
req.Header.Set("Accept", "application/json")

30
cache/grpc_test.go vendored
View file

@ -11,6 +11,7 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/grpc/cache"
@ -48,11 +49,8 @@ func TestCache_Get_and_Set(t *testing.T) {
&cache.SetReply{},
&cache.GetRequest{Key: "key"},
&cache.GetReply{
Exists: true,
Value: []byte("hello"),
XXX_NoUnkeyedLiteral: struct{}{},
XXX_unrecognized: nil,
XXX_sizecache: 0,
Exists: true,
Value: []byte("hello"),
},
false,
false,
@ -63,11 +61,8 @@ func TestCache_Get_and_Set(t *testing.T) {
&cache.SetReply{},
&cache.GetRequest{Key: "no-such-key"},
&cache.GetReply{
Exists: false,
Value: nil,
XXX_NoUnkeyedLiteral: struct{}{},
XXX_unrecognized: nil,
XXX_sizecache: 0,
Exists: false,
Value: nil,
},
false,
false,
@ -78,11 +73,8 @@ func TestCache_Get_and_Set(t *testing.T) {
nil,
&cache.GetRequest{Key: hugeKey},
&cache.GetReply{
Exists: false,
Value: nil,
XXX_NoUnkeyedLiteral: struct{}{},
XXX_unrecognized: nil,
XXX_sizecache: 0,
Exists: false,
Value: nil,
},
true,
false,
@ -96,7 +88,11 @@ func TestCache_Get_and_Set(t *testing.T) {
t.Errorf("Cache.Set() error = %v, wantSetError %v", err, tt.wantSetError)
return
}
if diff := cmp.Diff(setGot, tt.SetReply); diff != "" {
cmpOpts := []cmp.Option{
cmpopts.IgnoreUnexported(cache.SetReply{}, cache.GetReply{}),
}
if diff := cmp.Diff(setGot, tt.SetReply, cmpOpts...); diff != "" {
t.Errorf("Cache.Set() = %v", diff)
}
getGot, err := c.Get(tt.ctx, tt.GetRequest)
@ -104,7 +100,7 @@ func TestCache_Get_and_Set(t *testing.T) {
t.Errorf("Cache.Get() error = %v, wantGetError %v", err, tt.wantGetError)
return
}
if diff := cmp.Diff(getGot, tt.GetReply); diff != "" {
if diff := cmp.Diff(getGot, tt.GetReply, cmpOpts...); diff != "" {
t.Errorf("Cache.Get() = %v", diff)
}
})

View file

@ -5,6 +5,7 @@ import (
"crypto/x509"
"encoding/pem"
"sort"
"time"
envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
envoy_config_listener_v3 "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3"
@ -151,6 +152,7 @@ func (srv *Server) buildMainHTTPConnectionManagerFilter(options *config.Options,
},
Services: &envoy_extensions_filters_http_ext_authz_v3.ExtAuthz_GrpcService{
GrpcService: &envoy_config_core_v3.GrpcService{
Timeout: ptypes.DurationProto(time.Second * 30),
TargetSpecifier: &envoy_config_core_v3.GrpcService_EnvoyGrpc_{
EnvoyGrpc: &envoy_config_core_v3.GrpcService_EnvoyGrpc{
ClusterName: "pomerium-authz",

14
internal/grpc/cache/cache.go vendored Normal file
View file

@ -0,0 +1,14 @@
// Package cache defines a Cacher interfaces that can be implemented by any
// key value store system.
package cache
import (
"context"
)
// Cacher specifies an interface for remote clients connecting to the cache service.
type Cacher interface {
Get(ctx context.Context, key string) (value []byte, err error)
Set(ctx context.Context, key string, value []byte) error
Close() error
}

View file

@ -1,217 +1,356 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.21.0
// protoc v3.11.4
// source: cache.proto
package cache
import (
context "context"
fmt "fmt"
proto "github.com/golang/protobuf/proto"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
math "math"
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
// This is a compile-time assertion that a sufficiently up-to-date version
// of the legacy proto package is being used.
const _ = proto.ProtoPackageIsVersion4
type GetRequest struct {
Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"`
}
func (m *GetRequest) Reset() { *m = GetRequest{} }
func (m *GetRequest) String() string { return proto.CompactTextString(m) }
func (*GetRequest) ProtoMessage() {}
func (x *GetRequest) Reset() {
*x = GetRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_cache_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *GetRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*GetRequest) ProtoMessage() {}
func (x *GetRequest) ProtoReflect() protoreflect.Message {
mi := &file_cache_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use GetRequest.ProtoReflect.Descriptor instead.
func (*GetRequest) Descriptor() ([]byte, []int) {
return fileDescriptor_5fca3b110c9bbf3a, []int{0}
return file_cache_proto_rawDescGZIP(), []int{0}
}
func (m *GetRequest) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_GetRequest.Unmarshal(m, b)
}
func (m *GetRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_GetRequest.Marshal(b, m, deterministic)
}
func (m *GetRequest) XXX_Merge(src proto.Message) {
xxx_messageInfo_GetRequest.Merge(m, src)
}
func (m *GetRequest) XXX_Size() int {
return xxx_messageInfo_GetRequest.Size(m)
}
func (m *GetRequest) XXX_DiscardUnknown() {
xxx_messageInfo_GetRequest.DiscardUnknown(m)
}
var xxx_messageInfo_GetRequest proto.InternalMessageInfo
func (m *GetRequest) GetKey() string {
if m != nil {
return m.Key
func (x *GetRequest) GetKey() string {
if x != nil {
return x.Key
}
return ""
}
type GetReply struct {
Exists bool `protobuf:"varint,1,opt,name=exists,proto3" json:"exists,omitempty"`
Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Exists bool `protobuf:"varint,1,opt,name=exists,proto3" json:"exists,omitempty"`
Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"`
}
func (m *GetReply) Reset() { *m = GetReply{} }
func (m *GetReply) String() string { return proto.CompactTextString(m) }
func (*GetReply) ProtoMessage() {}
func (x *GetReply) Reset() {
*x = GetReply{}
if protoimpl.UnsafeEnabled {
mi := &file_cache_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *GetReply) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*GetReply) ProtoMessage() {}
func (x *GetReply) ProtoReflect() protoreflect.Message {
mi := &file_cache_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use GetReply.ProtoReflect.Descriptor instead.
func (*GetReply) Descriptor() ([]byte, []int) {
return fileDescriptor_5fca3b110c9bbf3a, []int{1}
return file_cache_proto_rawDescGZIP(), []int{1}
}
func (m *GetReply) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_GetReply.Unmarshal(m, b)
}
func (m *GetReply) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_GetReply.Marshal(b, m, deterministic)
}
func (m *GetReply) XXX_Merge(src proto.Message) {
xxx_messageInfo_GetReply.Merge(m, src)
}
func (m *GetReply) XXX_Size() int {
return xxx_messageInfo_GetReply.Size(m)
}
func (m *GetReply) XXX_DiscardUnknown() {
xxx_messageInfo_GetReply.DiscardUnknown(m)
}
var xxx_messageInfo_GetReply proto.InternalMessageInfo
func (m *GetReply) GetExists() bool {
if m != nil {
return m.Exists
func (x *GetReply) GetExists() bool {
if x != nil {
return x.Exists
}
return false
}
func (m *GetReply) GetValue() []byte {
if m != nil {
return m.Value
func (x *GetReply) GetValue() []byte {
if x != nil {
return x.Value
}
return nil
}
type SetRequest struct {
Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"`
Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"`
Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"`
}
func (m *SetRequest) Reset() { *m = SetRequest{} }
func (m *SetRequest) String() string { return proto.CompactTextString(m) }
func (*SetRequest) ProtoMessage() {}
func (x *SetRequest) Reset() {
*x = SetRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_cache_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SetRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*SetRequest) ProtoMessage() {}
func (x *SetRequest) ProtoReflect() protoreflect.Message {
mi := &file_cache_proto_msgTypes[2]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use SetRequest.ProtoReflect.Descriptor instead.
func (*SetRequest) Descriptor() ([]byte, []int) {
return fileDescriptor_5fca3b110c9bbf3a, []int{2}
return file_cache_proto_rawDescGZIP(), []int{2}
}
func (m *SetRequest) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_SetRequest.Unmarshal(m, b)
}
func (m *SetRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_SetRequest.Marshal(b, m, deterministic)
}
func (m *SetRequest) XXX_Merge(src proto.Message) {
xxx_messageInfo_SetRequest.Merge(m, src)
}
func (m *SetRequest) XXX_Size() int {
return xxx_messageInfo_SetRequest.Size(m)
}
func (m *SetRequest) XXX_DiscardUnknown() {
xxx_messageInfo_SetRequest.DiscardUnknown(m)
}
var xxx_messageInfo_SetRequest proto.InternalMessageInfo
func (m *SetRequest) GetKey() string {
if m != nil {
return m.Key
func (x *SetRequest) GetKey() string {
if x != nil {
return x.Key
}
return ""
}
func (m *SetRequest) GetValue() []byte {
if m != nil {
return m.Value
func (x *SetRequest) GetValue() []byte {
if x != nil {
return x.Value
}
return nil
}
type SetReply struct {
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
}
func (m *SetReply) Reset() { *m = SetReply{} }
func (m *SetReply) String() string { return proto.CompactTextString(m) }
func (*SetReply) ProtoMessage() {}
func (x *SetReply) Reset() {
*x = SetReply{}
if protoimpl.UnsafeEnabled {
mi := &file_cache_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SetReply) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*SetReply) ProtoMessage() {}
func (x *SetReply) ProtoReflect() protoreflect.Message {
mi := &file_cache_proto_msgTypes[3]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use SetReply.ProtoReflect.Descriptor instead.
func (*SetReply) Descriptor() ([]byte, []int) {
return fileDescriptor_5fca3b110c9bbf3a, []int{3}
return file_cache_proto_rawDescGZIP(), []int{3}
}
func (m *SetReply) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_SetReply.Unmarshal(m, b)
}
func (m *SetReply) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_SetReply.Marshal(b, m, deterministic)
}
func (m *SetReply) XXX_Merge(src proto.Message) {
xxx_messageInfo_SetReply.Merge(m, src)
}
func (m *SetReply) XXX_Size() int {
return xxx_messageInfo_SetReply.Size(m)
}
func (m *SetReply) XXX_DiscardUnknown() {
xxx_messageInfo_SetReply.DiscardUnknown(m)
var File_cache_proto protoreflect.FileDescriptor
var file_cache_proto_rawDesc = []byte{
0x0a, 0x0b, 0x63, 0x61, 0x63, 0x68, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x05, 0x63,
0x61, 0x63, 0x68, 0x65, 0x22, 0x1e, 0x0a, 0x0a, 0x47, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65,
0x73, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52,
0x03, 0x6b, 0x65, 0x79, 0x22, 0x38, 0x0a, 0x08, 0x47, 0x65, 0x74, 0x52, 0x65, 0x70, 0x6c, 0x79,
0x12, 0x16, 0x0a, 0x06, 0x65, 0x78, 0x69, 0x73, 0x74, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08,
0x52, 0x06, 0x65, 0x78, 0x69, 0x73, 0x74, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75,
0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x22, 0x34,
0x0a, 0x0a, 0x53, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x10, 0x0a, 0x03,
0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14,
0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x05, 0x76,
0x61, 0x6c, 0x75, 0x65, 0x22, 0x0a, 0x0a, 0x08, 0x53, 0x65, 0x74, 0x52, 0x65, 0x70, 0x6c, 0x79,
0x32, 0x61, 0x0a, 0x05, 0x43, 0x61, 0x63, 0x68, 0x65, 0x12, 0x2b, 0x0a, 0x03, 0x47, 0x65, 0x74,
0x12, 0x11, 0x2e, 0x63, 0x61, 0x63, 0x68, 0x65, 0x2e, 0x47, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75,
0x65, 0x73, 0x74, 0x1a, 0x0f, 0x2e, 0x63, 0x61, 0x63, 0x68, 0x65, 0x2e, 0x47, 0x65, 0x74, 0x52,
0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x2b, 0x0a, 0x03, 0x53, 0x65, 0x74, 0x12, 0x11, 0x2e,
0x63, 0x61, 0x63, 0x68, 0x65, 0x2e, 0x53, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
0x1a, 0x0f, 0x2e, 0x63, 0x61, 0x63, 0x68, 0x65, 0x2e, 0x53, 0x65, 0x74, 0x52, 0x65, 0x70, 0x6c,
0x79, 0x22, 0x00, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var xxx_messageInfo_SetReply proto.InternalMessageInfo
var (
file_cache_proto_rawDescOnce sync.Once
file_cache_proto_rawDescData = file_cache_proto_rawDesc
)
func init() {
proto.RegisterType((*GetRequest)(nil), "cache.GetRequest")
proto.RegisterType((*GetReply)(nil), "cache.GetReply")
proto.RegisterType((*SetRequest)(nil), "cache.SetRequest")
proto.RegisterType((*SetReply)(nil), "cache.SetReply")
func file_cache_proto_rawDescGZIP() []byte {
file_cache_proto_rawDescOnce.Do(func() {
file_cache_proto_rawDescData = protoimpl.X.CompressGZIP(file_cache_proto_rawDescData)
})
return file_cache_proto_rawDescData
}
func init() {
proto.RegisterFile("cache.proto", fileDescriptor_5fca3b110c9bbf3a)
var file_cache_proto_msgTypes = make([]protoimpl.MessageInfo, 4)
var file_cache_proto_goTypes = []interface{}{
(*GetRequest)(nil), // 0: cache.GetRequest
(*GetReply)(nil), // 1: cache.GetReply
(*SetRequest)(nil), // 2: cache.SetRequest
(*SetReply)(nil), // 3: cache.SetReply
}
var file_cache_proto_depIdxs = []int32{
0, // 0: cache.Cache.Get:input_type -> cache.GetRequest
2, // 1: cache.Cache.Set:input_type -> cache.SetRequest
1, // 2: cache.Cache.Get:output_type -> cache.GetReply
3, // 3: cache.Cache.Set:output_type -> cache.SetReply
2, // [2:4] is the sub-list for method output_type
0, // [0:2] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
var fileDescriptor_5fca3b110c9bbf3a = []byte{
// 176 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x4e, 0x4e, 0x4c, 0xce,
0x48, 0xd5, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x05, 0x73, 0x94, 0xe4, 0xb8, 0xb8, 0xdc,
0x53, 0x4b, 0x82, 0x52, 0x0b, 0x4b, 0x53, 0x8b, 0x4b, 0x84, 0x04, 0xb8, 0x98, 0xb3, 0x53, 0x2b,
0x25, 0x18, 0x15, 0x18, 0x35, 0x38, 0x83, 0x40, 0x4c, 0x25, 0x0b, 0x2e, 0x0e, 0xb0, 0x7c, 0x41,
0x4e, 0xa5, 0x90, 0x18, 0x17, 0x5b, 0x6a, 0x45, 0x66, 0x71, 0x49, 0x31, 0x58, 0x01, 0x47, 0x10,
0x94, 0x27, 0x24, 0xc2, 0xc5, 0x5a, 0x96, 0x98, 0x53, 0x9a, 0x2a, 0xc1, 0xa4, 0xc0, 0xa8, 0xc1,
0x13, 0x04, 0xe1, 0x28, 0x99, 0x70, 0x71, 0x05, 0xe3, 0x31, 0x19, 0x87, 0x2e, 0x2e, 0x2e, 0x8e,
0x60, 0xa8, 0x7d, 0x46, 0x89, 0x5c, 0xac, 0xce, 0x20, 0x47, 0x0a, 0x69, 0x73, 0x31, 0xbb, 0xa7,
0x96, 0x08, 0x09, 0xea, 0x41, 0x3c, 0x80, 0x70, 0xb0, 0x14, 0x3f, 0xb2, 0x50, 0x41, 0x4e, 0xa5,
0x12, 0x03, 0x48, 0x71, 0x30, 0x92, 0xe2, 0x60, 0x4c, 0xc5, 0xc1, 0x70, 0xc5, 0x49, 0x6c, 0xe0,
0xc0, 0x30, 0x06, 0x04, 0x00, 0x00, 0xff, 0xff, 0x0e, 0xef, 0x5f, 0x9e, 0x1b, 0x01, 0x00, 0x00,
func init() { file_cache_proto_init() }
func file_cache_proto_init() {
if File_cache_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_cache_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*GetRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_cache_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*GetReply); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_cache_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SetRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_cache_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SetReply); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_cache_proto_rawDesc,
NumEnums: 0,
NumMessages: 4,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_cache_proto_goTypes,
DependencyIndexes: file_cache_proto_depIdxs,
MessageInfos: file_cache_proto_msgTypes,
}.Build()
File_cache_proto = out.File
file_cache_proto_rawDesc = nil
file_cache_proto_goTypes = nil
file_cache_proto_depIdxs = nil
}
// Reference imports to suppress errors if they are not otherwise used.
@ -266,10 +405,10 @@ type CacheServer interface {
type UnimplementedCacheServer struct {
}
func (*UnimplementedCacheServer) Get(ctx context.Context, req *GetRequest) (*GetReply, error) {
func (*UnimplementedCacheServer) Get(context.Context, *GetRequest) (*GetReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method Get not implemented")
}
func (*UnimplementedCacheServer) Set(ctx context.Context, req *SetRequest) (*SetReply, error) {
func (*UnimplementedCacheServer) Set(context.Context, *SetRequest) (*SetReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method Set not implemented")
}

View file

@ -3,6 +3,7 @@ package client
import (
"context"
"errors"
"github.com/pomerium/pomerium/internal/grpc/cache"
"github.com/pomerium/pomerium/internal/telemetry/trace"
@ -10,12 +11,7 @@ import (
"google.golang.org/grpc"
)
// Cacher specifies an interface for remote clients connecting to the cache service.
type Cacher interface {
Get(ctx context.Context, key string) (keyExists bool, value []byte, err error)
Set(ctx context.Context, key string, value []byte) error
Close() error
}
var errKeyNotFound = errors.New("cache/client: key not found")
// Client represents a gRPC cache service client.
type Client struct {
@ -29,15 +25,18 @@ func New(conn *grpc.ClientConn) (p *Client) {
}
// Get retrieves a value from the cache service.
func (a *Client) Get(ctx context.Context, key string) (keyExists bool, value []byte, err error) {
func (a *Client) Get(ctx context.Context, key string) (value []byte, err error) {
ctx, span := trace.StartSpan(ctx, "grpc.cache.client.Get")
defer span.End()
response, err := a.client.Get(ctx, &cache.GetRequest{Key: key})
if err != nil {
return false, nil, err
return nil, err
}
return response.GetExists(), response.GetValue(), nil
if !response.GetExists() {
return nil, errKeyNotFound
}
return response.GetValue(), nil
}
// Set stores a key value pair in the cache service.

77
internal/grpc/cache/mock/mock_cacher.go vendored Normal file
View file

@ -0,0 +1,77 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/pomerium/pomerium/internal/grpc/cache (interfaces: Cacher)
// Package mock_cache is a generated GoMock package.
package mock_cache
import (
context "context"
gomock "github.com/golang/mock/gomock"
reflect "reflect"
)
// MockCacher is a mock of Cacher interface
type MockCacher struct {
ctrl *gomock.Controller
recorder *MockCacherMockRecorder
}
// MockCacherMockRecorder is the mock recorder for MockCacher
type MockCacherMockRecorder struct {
mock *MockCacher
}
// NewMockCacher creates a new mock instance
func NewMockCacher(ctrl *gomock.Controller) *MockCacher {
mock := &MockCacher{ctrl: ctrl}
mock.recorder = &MockCacherMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockCacher) EXPECT() *MockCacherMockRecorder {
return m.recorder
}
// Close mocks base method
func (m *MockCacher) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close
func (mr *MockCacherMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockCacher)(nil).Close))
}
// Get mocks base method
func (m *MockCacher) Get(arg0 context.Context, arg1 string) ([]byte, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Get", arg0, arg1)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Get indicates an expected call of Get
func (mr *MockCacherMockRecorder) Get(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockCacher)(nil).Get), arg0, arg1)
}
// Set mocks base method
func (m *MockCacher) Set(arg0 context.Context, arg1 string, arg2 []byte) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Set", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// Set indicates an expected call of Set
func (mr *MockCacherMockRecorder) Set(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Set", reflect.TypeOf((*MockCacher)(nil).Set), arg0, arg1, arg2)
}

View file

@ -0,0 +1,22 @@
// Package hashutil provides NON-CRYPTOGRAPHIC utility functions for hashing
package hashutil
import (
"github.com/cespare/xxhash/v2"
"github.com/mitchellh/hashstructure"
)
// Hash returns the xxhash value of an arbitrary value or struct. Returns 0
// on error. NOT SUITABLE FOR CRYTOGRAPHIC HASHING.
//
// http://cyan4973.github.io/xxHash/
func Hash(v interface{}) uint64 {
opts := &hashstructure.HashOptions{
Hasher: xxhash.New(),
}
hash, err := hashstructure.Hash(v, opts)
if err != nil {
hash = 0
}
return hash
}

View file

@ -0,0 +1,37 @@
// Package hashutil provides NON-CRYPTOGRAPHIC utility functions for hashing
package hashutil
import "testing"
func TestHash(t *testing.T) {
t.Parallel()
tests := []struct {
name string
v interface{}
want uint64
}{
{"string", "string", 6134271061086542852},
{"num", 7, 609900476111905877},
{"compound struct", struct {
NESCarts []string
numberOfCarts int
}{
[]string{"Battletoads", "Mega Man 1", "Clash at Demonhead"},
12,
},
9061978360207659575},
{"compound struct with embedded func (errors!)", struct {
AnswerToEverythingFn func() int
}{
func() int { return 42 },
},
0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := Hash(tt.v); got != tt.want {
t.Errorf("Hash() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -5,15 +5,13 @@ import (
"net/url"
"golang.org/x/oauth2"
"github.com/pomerium/pomerium/internal/sessions"
)
// MockProvider provides a mocked implementation of the providers interface.
type MockProvider struct {
AuthenticateResponse sessions.State
AuthenticateResponse oauth2.Token
AuthenticateError error
RefreshResponse sessions.State
RefreshResponse oauth2.Token
RefreshError error
RevokeError error
GetSignInURLResponse string
@ -22,12 +20,12 @@ type MockProvider struct {
}
// Authenticate is a mocked providers function.
func (mp MockProvider) Authenticate(ctx context.Context, code string) (*sessions.State, error) {
func (mp MockProvider) Authenticate(context.Context, string, interface{}) (*oauth2.Token, error) {
return &mp.AuthenticateResponse, mp.AuthenticateError
}
// Refresh is a mocked providers function.
func (mp MockProvider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) {
func (mp MockProvider) Refresh(context.Context, *oauth2.Token, interface{}) (*oauth2.Token, error) {
return &mp.RefreshResponse, mp.RefreshError
}

View file

@ -20,7 +20,6 @@ import (
"github.com/pomerium/pomerium/internal/identity/oauth"
"github.com/pomerium/pomerium/internal/identity/oidc"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/version"
)
@ -77,48 +76,36 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
// Authenticate creates an identity session with github from a authorization code, and follows up
// call to the user and user group endpoint with the
func (p *Provider) Authenticate(ctx context.Context, code string) (*sessions.State, error) {
resp, err := p.Oauth.Exchange(ctx, code)
func (p *Provider) Authenticate(ctx context.Context, code string, v interface{}) (*oauth2.Token, error) {
oauth2Token, err := p.Oauth.Exchange(ctx, code)
if err != nil {
return nil, fmt.Errorf("github: token exchange failed %v", err)
}
s := &sessions.State{
AccessToken: &oauth2.Token{
AccessToken: resp.AccessToken,
TokenType: resp.TokenType,
},
AccessTokenID: resp.AccessToken,
}
err = p.updateSessionState(ctx, s)
err = p.updateSessionState(ctx, oauth2Token, v)
if err != nil {
return nil, err
}
return s, nil
return oauth2Token, nil
}
// updateSessionState will get the user information from github and also retrieve the user's team(s)
//
// https://developer.github.com/v3/users/#get-the-authenticated-user
func (p *Provider) updateSessionState(ctx context.Context, s *sessions.State) error {
if s == nil || s.AccessToken == nil {
return errors.New("github: user session cannot be empty")
}
accessToken := s.AccessToken.AccessToken
func (p *Provider) updateSessionState(ctx context.Context, t *oauth2.Token, v interface{}) error {
err := p.userInfo(ctx, accessToken, s)
err := p.userInfo(ctx, t, v)
if err != nil {
return fmt.Errorf("github: could not retrieve user info %w", err)
}
err = p.userEmail(ctx, accessToken, s)
err = p.userEmail(ctx, t, v)
if err != nil {
return fmt.Errorf("github: could not retrieve user email %w", err)
}
err = p.userTeams(ctx, accessToken, s)
err = p.userTeams(ctx, t, v)
if err != nil {
return fmt.Errorf("github: could not retrieve groups %w", err)
}
@ -127,14 +114,12 @@ func (p *Provider) updateSessionState(ctx context.Context, s *sessions.State) er
}
// Refresh renews a user's session by making a new userInfo request.
func (p *Provider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) {
if s.AccessToken == nil {
return nil, errors.New("github: missing oauth2 access token")
}
if err := p.updateSessionState(ctx, s); err != nil {
func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v interface{}) (*oauth2.Token, error) {
err := p.updateSessionState(ctx, t, v)
if err != nil {
return nil, err
}
return s, nil
return t, nil
}
// userTeams returns a slice of teams the user belongs by making a request
@ -142,7 +127,7 @@ func (p *Provider) Refresh(ctx context.Context, s *sessions.State) (*sessions.St
//
// https://developer.github.com/v3/teams/#list-user-teams
// https://developer.github.com/v3/auth/
func (p *Provider) userTeams(ctx context.Context, at string, s *sessions.State) error {
func (p *Provider) userTeams(ctx context.Context, t *oauth2.Token, v interface{}) error {
var response []struct {
ID json.Number `json:"id"`
@ -154,20 +139,24 @@ func (p *Provider) userTeams(ctx context.Context, at string, s *sessions.State)
Privacy string `json:"privacy,omitempty"`
}
headers := map[string]string{"Authorization": fmt.Sprintf("token %s", at)}
headers := map[string]string{"Authorization": fmt.Sprintf("token %s", t.AccessToken)}
teamURL := githubAPIURL + teamPath
err := httputil.Client(ctx, http.MethodGet, teamURL, version.UserAgent(), headers, nil, &response)
if err != nil {
return err
}
log.Debug().Interface("teams", response).Msg("github: user teams")
s.Groups = nil
for _, org := range response {
s.Groups = append(s.Groups, org.ID.String())
var out struct {
Groups []string `json:"groups"`
}
return nil
for _, org := range response {
out.Groups = append(out.Groups, org.ID.String())
}
b, err := json.Marshal(out)
if err != nil {
return err
}
return json.Unmarshal(b, v)
}
// userEmail returns the primary email of the user by making
@ -175,7 +164,7 @@ func (p *Provider) userTeams(ctx context.Context, at string, s *sessions.State)
//
// https://developer.github.com/v3/users/emails/#list-email-addresses-for-a-user
// https://developer.github.com/v3/auth/
func (p *Provider) userEmail(ctx context.Context, at string, s *sessions.State) error {
func (p *Provider) userEmail(ctx context.Context, t *oauth2.Token, v interface{}) error {
// response represents the github user email
// https://developer.github.com/v3/users/emails/#response
var response []struct {
@ -184,48 +173,67 @@ func (p *Provider) userEmail(ctx context.Context, at string, s *sessions.State)
Primary bool `json:"primary"`
Visibility string `json:"visibility"`
}
headers := map[string]string{"Authorization": fmt.Sprintf("token %s", at)}
headers := map[string]string{"Authorization": fmt.Sprintf("token %s", t.AccessToken)}
emailURL := githubAPIURL + emailPath
err := httputil.Client(ctx, http.MethodGet, emailURL, version.UserAgent(), headers, nil, &response)
if err != nil {
return err
}
var out struct {
Email string `json:"email"`
Verified bool `json:"email_verified"`
}
log.Debug().Interface("emails", response).Msg("github: user emails")
for _, email := range response {
if email.Primary && email.Verified {
s.Email = email.Email
s.EmailVerified = true
return nil
out.Email = email.Email
out.Verified = true
break
}
}
return nil
b, err := json.Marshal(out)
if err != nil {
return err
}
return json.Unmarshal(b, v)
}
func (p *Provider) userInfo(ctx context.Context, at string, s *sessions.State) error {
func (p *Provider) userInfo(ctx context.Context, t *oauth2.Token, v interface{}) error {
var response struct {
ID int `json:"id"`
Login string `json:"login"`
Name string `json:"name"`
Email string `json:"email"`
AvatarURL string `json:"avatar_url,omitempty"`
}
headers := map[string]string{
"Authorization": fmt.Sprintf("token %s", at),
"Authorization": fmt.Sprintf("token %s", t.AccessToken),
"Accept": "application/vnd.github.v3+json",
}
err := httputil.Client(ctx, http.MethodGet, p.userEndpoint, version.UserAgent(), headers, nil, &response)
if err != nil {
return err
}
var out struct {
Subject string `json:"sub"`
Name string `json:"name,omitempty"`
User string `json:"user"`
Picture string `json:"picture,omitempty"`
// needs to be set manually
Expiry *jwt.NumericDate `json:"exp,omitempty"`
}
s.User = response.Login
s.Name = response.Name
s.Picture = response.AvatarURL
out.User = response.Login
out.Subject = response.Login
out.Name = response.Name
out.Picture = response.AvatarURL
// set the session expiry
s.Expiry = jwt.NewNumericDate(time.Now().Add(refreshDeadline))
return nil
out.Expiry = jwt.NewNumericDate(time.Now().Add(refreshDeadline))
b, err := json.Marshal(out)
if err != nil {
return err
}
return json.Unmarshal(b, v)
}
// Revoke method will remove all the github grants the user

View file

@ -5,7 +5,7 @@ package azure
import (
"context"
"errors"
"encoding/json"
"fmt"
"net/http"
"time"
@ -16,7 +16,6 @@ import (
"github.com/pomerium/pomerium/internal/identity/oauth"
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/version"
)
@ -60,10 +59,7 @@ func (p *Provider) GetSignInURL(state string) string {
// `Directory.Read.All` is required.
// https://docs.microsoft.com/en-us/graph/api/resources/directoryobject?view=graph-rest-1.0
// https://docs.microsoft.com/en-us/graph/api/user-list-memberof?view=graph-rest-1.0
func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string, error) {
if s == nil || s.AccessToken == nil {
return nil, errors.New("identity/azure: session cannot be nil")
}
func (p *Provider) UserGroups(ctx context.Context, t *oauth2.Token, v interface{}) error {
var response struct {
Groups []struct {
ID string `json:"id"`
@ -73,15 +69,23 @@ func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string,
GroupTypes []string `json:"groupTypes,omitempty"`
} `json:"value"`
}
headers := map[string]string{"Authorization": fmt.Sprintf("Bearer %s", s.AccessToken.AccessToken)}
headers := map[string]string{"Authorization": fmt.Sprintf("Bearer %s", t.AccessToken)}
err := httputil.Client(ctx, http.MethodGet, defaultGroupURL, version.UserAgent(), headers, nil, &response)
if err != nil {
return nil, err
return err
}
log.Debug().Interface("response", response).Msg("microsoft: groups")
var out struct {
Groups []string `json:"groups"`
}
var groups []string
for _, group := range response.Groups {
log.Debug().Str("DisplayName", group.DisplayName).Str("ID", group.ID).Msg("microsoft: group")
groups = append(groups, group.ID)
out.Groups = append(out.Groups, group.ID)
}
return groups, nil
b, err := json.Marshal(out)
if err != nil {
return err
}
return json.Unmarshal(b, v)
}

View file

@ -1,12 +1,14 @@
package oidc
import "errors"
import (
"errors"
)
// ErrRevokeNotImplemented error type when Revoke method is not implemented
// by an identity provider
// ErrRevokeNotImplemented is returned when revoke is not implemented
// by an identity provider.
var ErrRevokeNotImplemented = errors.New("identity/oidc: revoke not implemented")
// ErrSignoutNotImplemented error type when end session is not implemented
// ErrSignoutNotImplemented is returned when end session is not implemented
// by an identity provider
// https://openid.net/specs/openid-connect-frontchannel-1_0.html#RPInitiated
var ErrSignoutNotImplemented = errors.New("identity/oidc: end session not implemented")
@ -14,3 +16,13 @@ var ErrSignoutNotImplemented = errors.New("identity/oidc: end session not implem
// ErrMissingProviderURL is returned when an identity provider requires a provider url
// does not receive one.
var ErrMissingProviderURL = errors.New("identity/oidc: missing provider url")
// ErrMissingIDToken is returned when (usually on refresh) and identity provider
// failed to include an id_token in a oauth2 token.
var ErrMissingIDToken = errors.New("identity/oidc: missing id_token")
// ErrMissingRefreshToken is returned if no refresh token was found.
var ErrMissingRefreshToken = errors.New("identity/oidc: missing refresh token")
// ErrMissingAccessToken is returned when no access token was found.
var ErrMissingAccessToken = errors.New("identity/oidc: missing access token")

View file

@ -6,7 +6,6 @@ package gitlab
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
@ -15,14 +14,14 @@ import (
"github.com/pomerium/pomerium/internal/identity/oauth"
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/version"
"golang.org/x/oauth2"
)
// Name identifies the GitLab identity provider
// Name identifies the GitLab identity provider.
const Name = "gitlab"
var defaultScopes = []string{oidc.ScopeOpenID, "read_api", "read_user", "profile", "email"}
var defaultScopes = []string{oidc.ScopeOpenID, "profile", "email", "api"}
const (
defaultProviderURL = "https://gitlab.com"
@ -64,11 +63,7 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
//
// Returns 20 results at a time because the API results are paginated.
// https://docs.gitlab.com/ee/api/groups.html#list-groups
func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string, error) {
if s == nil || s.AccessToken == nil {
return nil, errors.New("gitlab: user session cannot be empty")
}
func (p *Provider) UserGroups(ctx context.Context, t *oauth2.Token, v interface{}) error {
var response []struct {
ID json.Number `json:"id"`
Name string `json:"name,omitempty"`
@ -81,17 +76,22 @@ func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string,
FullName string `json:"full_name,omitempty"`
FullPath string `json:"full_path,omitempty"`
}
headers := map[string]string{"Authorization": fmt.Sprintf("Bearer %s", s.AccessToken.AccessToken)}
headers := map[string]string{"Authorization": fmt.Sprintf("Bearer %s", t.AccessToken)}
err := httputil.Client(ctx, http.MethodGet, p.userGroupURL, version.UserAgent(), headers, nil, &response)
if err != nil {
return nil, err
return err
}
var groups []string
log.Debug().Interface("response", response).Msg("gitlab: groups")
var out struct {
Groups []string `json:"groups"`
}
for _, group := range response {
groups = append(groups, group.ID.String())
out.Groups = append(out.Groups, group.ID.String())
}
b, err := json.Marshal(out)
if err != nil {
return err
}
return groups, nil
return json.Unmarshal(b, v)
}

View file

@ -8,6 +8,7 @@ import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
oidc "github.com/coreos/go-oidc"
@ -18,7 +19,6 @@ import (
"github.com/pomerium/pomerium/internal/identity/oauth"
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
)
const (
@ -54,36 +54,35 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
return nil, fmt.Errorf("%s: failed creating oidc provider: %w", Name, err)
}
p.Provider = genericOidc
// if service account set, configure admin sdk calls
if o.ServiceAccount != "" {
apiCreds, err := base64.StdEncoding.DecodeString(o.ServiceAccount)
if err != nil {
return nil, fmt.Errorf("google: could not decode service account json %w", err)
}
// Required scopes for groups api
// https://developers.google.com/admin-sdk/directory/v1/reference/groups/list
conf, err := google.JWTConfigFromJSON(apiCreds, admin.AdminDirectoryUserReadonlyScope, admin.AdminDirectoryGroupReadonlyScope)
if err != nil {
return nil, fmt.Errorf("google: failed making jwt config from json %w", err)
}
var credentialsFile struct {
ImpersonateUser string `json:"impersonate_user"`
}
if err := json.Unmarshal(apiCreds, &credentialsFile); err != nil {
return nil, err
}
conf.Subject = credentialsFile.ImpersonateUser
client := conf.Client(context.TODO())
p.apiClient, err = admin.New(client)
if err != nil {
return nil, fmt.Errorf("google: failed creating admin service %w", err)
}
p.UserGroupFn = p.UserGroups
} else {
log.Warn().Msg("google: no service account, cannot retrieve groups")
if o.ServiceAccount == "" {
log.Warn().Msg("google: no service account, will not fetch groups")
return &p, nil
}
apiCreds, err := base64.StdEncoding.DecodeString(o.ServiceAccount)
if err != nil {
return nil, fmt.Errorf("google: could not decode service account json %w", err)
}
// Required scopes for groups api
// https://developers.google.com/admin-sdk/directory/v1/reference/groups/list
conf, err := google.JWTConfigFromJSON(apiCreds, admin.AdminDirectoryUserReadonlyScope, admin.AdminDirectoryGroupReadonlyScope)
if err != nil {
return nil, fmt.Errorf("google: failed making jwt config from json %w", err)
}
var credentialsFile struct {
ImpersonateUser string `json:"impersonate_user"`
}
if err := json.Unmarshal(apiCreds, &credentialsFile); err != nil {
return nil, err
}
conf.Subject = credentialsFile.ImpersonateUser
client := conf.Client(context.TODO())
p.apiClient, err = admin.New(client)
if err != nil {
return nil, fmt.Errorf("google: failed creating admin service %w", err)
}
p.UserGroupFn = p.UserGroups
return &p, nil
}
@ -106,17 +105,34 @@ func (p *Provider) GetSignInURL(state string) string {
// NOTE: groups via Directory API is limited to 1 QPS!
// https://developers.google.com/admin-sdk/directory/v1/reference/groups/list
// https://developers.google.com/admin-sdk/directory/v1/limits
func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string, error) {
var groups []string
if p.apiClient != nil {
req := p.apiClient.Groups.List().UserKey(s.Subject).MaxResults(100)
resp, err := req.Do()
if err != nil {
return nil, fmt.Errorf("google: group api request failed %w", err)
}
for _, group := range resp.Groups {
groups = append(groups, group.Email)
}
func (p *Provider) UserGroups(ctx context.Context, t *oauth2.Token, v interface{}) error {
if p.apiClient == nil {
return errors.New("google: trying to fetch groups, but no api client")
}
return groups, nil
s, err := p.GetSubject(v)
if err != nil {
return err
}
var out struct {
Groups []string `json:"groups"`
}
req := p.apiClient.Groups.List().Context(ctx).UserKey(s)
err = req.Pages(ctx, func(resp *admin.Groups) error {
for _, group := range resp.Groups {
out.Groups = append(out.Groups, group.Email)
}
return nil
})
if err != nil {
return err
}
_, err = req.Do()
if err != nil {
return fmt.Errorf("google: group api request failed %w", err)
}
b, err := json.Marshal(out)
if err != nil {
return err
}
return json.Unmarshal(b, v)
}

View file

@ -5,6 +5,7 @@ package oidc
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
@ -15,12 +16,11 @@ import (
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/identity/oauth"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/internal/version"
)
// Name identifies the generic OpenID Connect provider
// Name identifies the generic OpenID Connect provider.
const Name = "oidc"
var defaultScopes = []string{go_oidc.ScopeOpenID, "profile", "email", "offline_access"}
@ -37,11 +37,6 @@ type Provider struct {
// client application information and the server's endpoint URLs.
Oauth *oauth2.Config
// UserInfoURL specifies the endpoint responsible for returning claims
// about the authenticated End-User.
// https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
UserInfoURL string `json:"userinfo_endpoint,omitempty"`
// RevocationURL is the location of the OAuth 2.0 token revocation endpoint.
// https://tools.ietf.org/html/rfc7009
RevocationURL string `json:"revocation_endpoint,omitempty"`
@ -53,7 +48,7 @@ type Provider struct {
// UserGroupFn is, if set, used to return a slice of group IDs the
// user is a member of
UserGroupFn func(context.Context, *sessions.State) ([]string, error)
UserGroupFn func(context.Context, *oauth2.Token, interface{}) error
}
// New creates a new instance of a generic OpenID Connect provider.
@ -100,81 +95,89 @@ func (p *Provider) GetSignInURL(state string) string {
// Authenticate converts an authorization code returned from the identity
// provider into a token which is then converted into a user session.
func (p *Provider) Authenticate(ctx context.Context, code string) (*sessions.State, error) {
func (p *Provider) Authenticate(ctx context.Context, code string, v interface{}) (*oauth2.Token, error) {
// Exchange converts an authorization code into a token.
oauth2Token, err := p.Oauth.Exchange(ctx, code)
if err != nil {
return nil, fmt.Errorf("identity/oidc: token exchange failed: %w", err)
}
idToken, err := p.IdentityFromToken(ctx, oauth2Token)
idToken, err := p.getIDToken(ctx, oauth2Token)
if err != nil {
return nil, fmt.Errorf("identity/oidc: failed getting id_token: %w", err)
}
aud, err := urlutil.ParseAndValidateURL(p.Oauth.RedirectURL)
// hydrate `v` using claims inside the returned `id_token`
// https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint
if err := idToken.Claims(v); err != nil {
return nil, fmt.Errorf("identity/oidc: couldn't unmarshal extra claims %w", err)
}
if err := p.updateUserInfo(ctx, oauth2Token, v); err != nil {
return nil, fmt.Errorf("identity/oidc: couldn't update user info %w", err)
}
return oauth2Token, nil
}
// updateUserInfo calls the OIDC (spec required) UserInfo Endpoint as well as any
// groups endpoint (non-spec) to populate the rest of the user's information.
//
// https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
func (p *Provider) updateUserInfo(ctx context.Context, t *oauth2.Token, v interface{}) error {
userInfo, err := p.Provider.UserInfo(ctx, oauth2.StaticTokenSource(t))
if err != nil {
return nil, fmt.Errorf("identity/oidc: bad redirect uri: %w", err)
return fmt.Errorf("identity/oidc: user info endpoint: %w", err)
}
s, err := sessions.NewStateFromTokens(idToken, oauth2Token, aud.Hostname())
if err != nil {
return nil, err
if err := userInfo.Claims(v); err != nil {
return fmt.Errorf("identity/oidc: failed parsing user info endpoint claims: %w", err)
}
if err := p.Provider.Claims(&p); err == nil && p.UserInfoURL != "" {
userInfo, err := p.Provider.UserInfo(ctx, oauth2.StaticTokenSource(oauth2Token))
if err != nil {
return nil, fmt.Errorf("identity/oidc: could not retrieve user info %w", err)
}
if err := userInfo.Claims(&s); err != nil {
return nil, fmt.Errorf("identity/oidc: could not parse user claims %w", err)
}
}
if p.UserGroupFn != nil {
s.Groups, err = p.UserGroupFn(ctx, s)
if err != nil {
return nil, fmt.Errorf("internal/oidc: could not retrieve groups %w", err)
if err := p.UserGroupFn(ctx, t, v); err != nil {
return fmt.Errorf("identity/oidc: could not retrieve groups: %w", err)
}
}
return s, nil
return nil
}
// Refresh renews a user's session using an oidc refresh token without reprompting the user.
// Group membership is also refreshed.
// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens
func (p *Provider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) {
if s.AccessToken == nil || s.AccessToken.RefreshToken == "" {
return nil, errors.New("internal/oidc: missing refresh token")
func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v interface{}) (*oauth2.Token, error) {
if t == nil {
return nil, ErrMissingAccessToken
}
if t.RefreshToken == "" {
return nil, ErrMissingRefreshToken
}
var err error
newToken, err := p.Oauth.TokenSource(ctx, t).Token()
if err != nil {
return nil, fmt.Errorf("identity/oidc: refresh failed: %w", err)
}
t := oauth2.Token{RefreshToken: s.AccessToken.RefreshToken}
oauthToken, err := p.Oauth.TokenSource(ctx, &t).Token()
if err != nil {
return nil, fmt.Errorf("internal/oidc: refresh failed %w", err)
}
idToken, err := p.IdentityFromToken(ctx, oauthToken)
if err != nil {
return nil, fmt.Errorf("identity/oidc: failed getting id_token: %w", err)
}
if err := s.UpdateState(idToken, oauthToken); err != nil {
return nil, fmt.Errorf("internal/oidc: state update failed %w", err)
}
if p.UserGroupFn != nil {
s.Groups, err = p.UserGroupFn(ctx, s)
if err != nil {
return nil, fmt.Errorf("internal/oidc: could not retrieve groups %w", err)
// Many identity providers _will not_ return `id_token` on refresh
// https://github.com/FusionAuth/fusionauth-issues/issues/110#issuecomment-481526544
idToken, err := p.getIDToken(ctx, newToken)
if err == nil {
if err := idToken.Claims(v); err != nil {
return nil, fmt.Errorf("identity/oidc: couldn't unmarshal extra claims %w", err)
}
}
return s, nil
if err := p.updateUserInfo(ctx, newToken, v); err != nil {
return nil, fmt.Errorf("identity/oidc: couldn't update user info %w", err)
}
return newToken, nil
}
// IdentityFromToken takes an identity provider issued JWT as input ('id_token')
// and returns a session state. The provided token's audience ('aud') must
// match Pomerium's client_id.
func (p *Provider) IdentityFromToken(ctx context.Context, t *oauth2.Token) (*go_oidc.IDToken, error) {
// getIDToken returns the raw jwt payload for `id_token` from the oauth2 token
// returned following oidc code flow
//
// https://openid.net/specs/openid-connect-core-1_0.html#TokenResponse
func (p *Provider) getIDToken(ctx context.Context, t *oauth2.Token) (*go_oidc.IDToken, error) {
rawIDToken, ok := t.Extra("id_token").(string)
if !ok {
return nil, fmt.Errorf("internal/oidc: id_token not found")
return nil, ErrMissingIDToken
}
return p.Verifier.Verify(ctx, rawIDToken)
}
@ -183,13 +186,16 @@ func (p *Provider) IdentityFromToken(ctx context.Context, t *oauth2.Token) (*go_
// support revocation an error is thrown.
//
// https://tools.ietf.org/html/rfc7009#section-2.1
func (p *Provider) Revoke(ctx context.Context, token *oauth2.Token) error {
func (p *Provider) Revoke(ctx context.Context, t *oauth2.Token) error {
if p.RevocationURL == "" {
return ErrRevokeNotImplemented
}
if t == nil {
return ErrMissingAccessToken
}
params := url.Values{}
params.Add("token", token.AccessToken)
params.Add("token", t.AccessToken)
params.Add("token_type_hint", "access_token")
// Some providers like okta / onelogin require "client authentication"
// https://developer.okta.com/docs/reference/api/oidc/#client-secret
@ -198,7 +204,7 @@ func (p *Provider) Revoke(ctx context.Context, token *oauth2.Token) error {
params.Add("client_secret", p.Oauth.ClientSecret)
err := httputil.Client(ctx, http.MethodPost, p.RevocationURL, version.UserAgent(), nil, params, nil)
if err != nil && err != httputil.ErrTokenRevoked {
if err != nil && errors.Is(err, httputil.ErrTokenRevoked) {
return fmt.Errorf("internal/oidc: unexpected revoke error: %w", err)
}
@ -214,3 +220,20 @@ func (p *Provider) LogOut() (*url.URL, error) {
}
return urlutil.ParseAndValidateURL(p.EndSessionURL)
}
// GetSubject gets the RFC 7519 Subject claim (`sub`) from a
func (p *Provider) GetSubject(v interface{}) (string, error) {
b, err := json.Marshal(v)
if err != nil {
return "", err
}
var s struct {
Subject string `json:"sub"`
}
err = json.Unmarshal(b, &s)
if err != nil {
return "", err
}
return s.Subject, nil
}

View file

@ -5,6 +5,7 @@ package okta
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
@ -13,9 +14,9 @@ import (
"github.com/pomerium/pomerium/internal/identity/oauth"
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/internal/version"
"golang.org/x/oauth2"
)
const (
@ -64,7 +65,11 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
// UserGroups fetches the groups of which the user is a member
// https://developer.okta.com/docs/reference/api/users/#get-user-s-groups
func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string, error) {
func (p *Provider) UserGroups(ctx context.Context, t *oauth2.Token, v interface{}) error {
s, err := p.GetSubject(v)
if err != nil {
return err
}
var response []struct {
ID string `json:"id"`
Profile struct {
@ -74,15 +79,22 @@ func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string,
}
headers := map[string]string{"Authorization": fmt.Sprintf("SSWS %s", p.serviceAccount)}
uri := fmt.Sprintf("%s/%s/groups", p.userAPI.String(), s.Subject)
err := httputil.Client(ctx, http.MethodGet, uri, version.UserAgent(), headers, nil, &response)
uri := fmt.Sprintf("%s/%s/groups", p.userAPI.String(), s)
err = httputil.Client(ctx, http.MethodGet, uri, version.UserAgent(), headers, nil, &response)
if err != nil {
return nil, err
return err
}
log.Debug().Interface("response", response).Msg("okta: groups")
var out struct {
Groups []string `json:"groups"`
}
var groups []string
for _, group := range response {
log.Debug().Interface("group", group).Msg("okta: group")
groups = append(groups, group.ID)
out.Groups = append(out.Groups, group.ID)
}
return groups, nil
b, err := json.Marshal(out)
if err != nil {
return err
}
return json.Unmarshal(b, v)
}

View file

@ -5,17 +5,15 @@ package onelogin
import (
"context"
"errors"
"fmt"
"net/http"
"time"
oidc "github.com/coreos/go-oidc"
"golang.org/x/oauth2"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/identity/oauth"
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/version"
)
@ -55,24 +53,10 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
// UserGroups returns a slice of group names a given user is in.
// https://developers.onelogin.com/openid-connect/api/user-info
func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string, error) {
if s == nil || s.AccessToken == nil {
return nil, errors.New("identity/onelogin: session cannot be nil")
func (p *Provider) UserGroups(ctx context.Context, t *oauth2.Token, v interface{}) error {
if t == nil {
return pom_oidc.ErrMissingAccessToken
}
var response struct {
User string `json:"sub"`
Email string `json:"email"`
PreferredUsername string `json:"preferred_username"`
Name string `json:"name"`
UpdatedAt time.Time `json:"updated_at"`
GivenName string `json:"given_name"`
FamilyName string `json:"family_name"`
Groups []string `json:"groups"`
}
headers := map[string]string{"Authorization": fmt.Sprintf("Bearer %s", s.AccessToken.AccessToken)}
err := httputil.Client(ctx, http.MethodGet, defaultOneloginGroupURL, version.UserAgent(), headers, nil, &response)
if err != nil {
return nil, err
}
return response.Groups, nil
headers := map[string]string{"Authorization": fmt.Sprintf("Bearer %s", t.AccessToken)}
return httputil.Client(ctx, http.MethodGet, defaultOneloginGroupURL, version.UserAgent(), headers, nil, v)
}

View file

@ -17,25 +17,24 @@ import (
"github.com/pomerium/pomerium/internal/identity/oidc/google"
"github.com/pomerium/pomerium/internal/identity/oidc/okta"
"github.com/pomerium/pomerium/internal/identity/oidc/onelogin"
"github.com/pomerium/pomerium/internal/sessions"
)
var (
// compile time assertions that providers are satisfying the interface
_ Authenticator = &azure.Provider{}
_ Authenticator = &gitlab.Provider{}
_ Authenticator = &github.Provider{}
_ Authenticator = &gitlab.Provider{}
_ Authenticator = &google.Provider{}
_ Authenticator = &MockProvider{}
_ Authenticator = &oidc.Provider{}
_ Authenticator = &okta.Provider{}
_ Authenticator = &onelogin.Provider{}
_ Authenticator = &MockProvider{}
)
// Authenticator is an interface representing the ability to authenticate with an identity provider.
type Authenticator interface {
Authenticate(context.Context, string) (*sessions.State, error)
Refresh(context.Context, *sessions.State) (*sessions.State, error)
Authenticate(context.Context, string, interface{}) (*oauth2.Token, error)
Refresh(context.Context, *oauth2.Token, interface{}) (*oauth2.Token, error)
Revoke(context.Context, *oauth2.Token) error
GetSignInURL(state string) string
LogOut() (*url.URL, error)

View file

@ -37,6 +37,9 @@ type Store struct {
srv *http.Server
}
// ErrCacheMiss is returned when the cache misses for a given key.
var ErrCacheMiss = errors.New("cache miss")
// Options represent autocache options.
type Options struct {
Addr string
@ -60,7 +63,7 @@ var DefaultOptions = &Options{
GetterFn: func(ctx context.Context, id string, dest groupcache.Sink) error {
b := fromContext(ctx)
if len(b) == 0 {
return fmt.Errorf("autocache: empty ctx for id: %s", id)
return fmt.Errorf("autocache: id %s : %w", id, ErrCacheMiss)
}
if err := dest.SetBytes(b); err != nil {
return fmt.Errorf("autocache: sink error %w", err)

View file

@ -1,95 +0,0 @@
// Package cache provides a remote cache based implementation of session store
// and loader. See pomerium's cache service for more details.
package cache
import (
"errors"
"fmt"
"net/http"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/grpc/cache/client"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
)
var _ sessions.SessionStore = &Store{}
var _ sessions.SessionLoader = &Store{}
// Store implements the session store interface using a cache service.
type Store struct {
cache client.Cacher
encoder encoding.MarshalUnmarshaler
queryParam string
wrappedStore sessions.SessionStore
}
// Options represent cache store's available configurations.
type Options struct {
Cache client.Cacher
Encoder encoding.MarshalUnmarshaler
QueryParam string
WrappedStore sessions.SessionStore
}
var defaultOptions = &Options{
QueryParam: "cache_store_key",
}
// NewStore creates a new cache
func NewStore(o *Options) *Store {
if o.QueryParam == "" {
o.QueryParam = defaultOptions.QueryParam
}
return &Store{
cache: o.Cache,
encoder: o.Encoder,
queryParam: o.QueryParam,
wrappedStore: o.WrappedStore,
}
}
// LoadSession looks for a preset query parameter in the request body
// representing the key to lookup from the cache.
func (s *Store) LoadSession(r *http.Request) (string, error) {
// look for our cache's key in the default query param
sessionID := r.URL.Query().Get(s.queryParam)
if sessionID == "" {
return "", sessions.ErrNoSessionFound
}
exists, val, err := s.cache.Get(r.Context(), sessionID)
if err != nil {
log.FromRequest(r).Debug().Msg("sessions/cache: miss, trying wrapped loader")
return "", err
}
if !exists {
return "", sessions.ErrNoSessionFound
}
return string(val), nil
}
// ClearSession clears the session from the wrapped store.
func (s *Store) ClearSession(w http.ResponseWriter, r *http.Request) {
s.wrappedStore.ClearSession(w, r)
}
// SaveSession saves the session to the cache, and wrapped store.
func (s *Store) SaveSession(w http.ResponseWriter, r *http.Request, x interface{}) error {
err := s.wrappedStore.SaveSession(w, r, x)
if err != nil {
return fmt.Errorf("sessions/cache: wrapped store save error %w", err)
}
state, ok := x.(*sessions.State)
if !ok {
return errors.New("sessions/cache: cannot cache non state type")
}
data, err := s.encoder.Marshal(&state)
if err != nil {
return fmt.Errorf("sessions/cache: marshal %w", err)
}
return s.cache.Set(r.Context(), state.AccessTokenID, data)
}

View file

@ -1,190 +0,0 @@
package cache
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/encoding/ecjson"
mock_encoder "github.com/pomerium/pomerium/internal/encoding/mock"
"github.com/pomerium/pomerium/internal/grpc/cache/client"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/sessions/cookie"
"github.com/pomerium/pomerium/internal/sessions/mock"
"gopkg.in/square/go-jose.v2/jwt"
)
type mockCache struct {
Key string
KeyExists bool
Value []byte
Err error
}
func (mc *mockCache) Get(ctx context.Context, key string) (keyExists bool, value []byte, err error) {
return mc.KeyExists, mc.Value, mc.Err
}
func (mc *mockCache) Set(ctx context.Context, key string, value []byte) error {
return mc.Err
}
func (mc *mockCache) Close() error {
return mc.Err
}
func TestNewStore(t *testing.T) {
tests := []struct {
name string
Options *Options
State *sessions.State
wantErr bool
wantLoadErr bool
wantStatus int
}{
{"simple good",
&Options{
Cache: &mockCache{},
WrappedStore: &mock.Store{},
Encoder: mock_encoder.Encoder{MarshalResponse: []byte("ok")},
},
&sessions.State{Email: "user@domain.com", User: "user"},
false, false,
http.StatusOK},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := NewStore(tt.Options)
r := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
if err := got.SaveSession(w, r, tt.State); (err != nil) != tt.wantErr {
t.Errorf("NewStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
}
r = httptest.NewRequest("GET", "/", nil)
w = httptest.NewRecorder()
got.ClearSession(w, r)
status := w.Result().StatusCode
if diff := cmp.Diff(status, tt.wantStatus); diff != "" {
t.Errorf("ClearSession() = %v", diff)
}
})
}
}
func TestStore_SaveSession(t *testing.T) {
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
encoder := ecjson.New(cipher)
if err != nil {
t.Fatal(err)
}
cs, err := cookie.NewStore(&cookie.Options{
Name: "_pomerium",
}, encoder)
if err != nil {
t.Fatal(err)
}
tests := []struct {
name string
Options *Options
x interface{}
wantErr bool
}{
{"good", &Options{Cache: &mockCache{}, WrappedStore: cs, Encoder: mock_encoder.Encoder{MarshalResponse: []byte("ok")}}, &sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, false},
{"encoder error", &Options{Cache: &mockCache{}, WrappedStore: cs, Encoder: mock_encoder.Encoder{MarshalError: errors.New("err")}}, &sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, true},
{"good", &Options{Cache: &mockCache{}, WrappedStore: &mock.Store{SaveError: errors.New("err")}}, &sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, true},
{"bad type", &Options{Cache: &mockCache{}, WrappedStore: cs, Encoder: mock_encoder.Encoder{MarshalError: errors.New("err")}}, "bad type!", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
o := tt.Options
if o.WrappedStore == nil {
o.WrappedStore = cs
}
cacheStore := NewStore(tt.Options)
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
if err := cacheStore.SaveSession(w, r, tt.x); (err != nil) != tt.wantErr {
t.Errorf("Store.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestStore_LoadSession(t *testing.T) {
key := cryptutil.NewBase64Key()
tests := []struct {
name string
state *sessions.State
cache client.Cacher
encoder encoding.MarshalUnmarshaler
queryParam string
wrappedStore sessions.SessionStore
wantErr bool
}{
{"good",
&sessions.State{AccessTokenID: key, Email: "user@pomerium.io"},
&mockCache{KeyExists: true},
mock_encoder.Encoder{MarshalResponse: []byte("ok")},
defaultOptions.QueryParam,
&mock.Store{Session: &sessions.State{AccessTokenID: key, Email: "user@pomerium.io"}},
false},
{"missing param with key",
&sessions.State{AccessTokenID: key, Email: "user@pomerium.io"},
&mockCache{KeyExists: true},
mock_encoder.Encoder{MarshalResponse: []byte("ok")},
"bad_query",
&mock.Store{Session: &sessions.State{AccessTokenID: key, Email: "user@pomerium.io"}},
true},
{"doesn't exist",
&sessions.State{AccessTokenID: key, Email: "user@pomerium.io"},
&mockCache{KeyExists: false},
mock_encoder.Encoder{MarshalResponse: []byte("ok")},
defaultOptions.QueryParam,
&mock.Store{Session: &sessions.State{AccessTokenID: key, Email: "user@pomerium.io"}},
true},
{"retrieval error",
&sessions.State{AccessTokenID: key, Email: "user@pomerium.io"},
&mockCache{Err: errors.New("err")},
mock_encoder.Encoder{MarshalResponse: []byte("ok")},
defaultOptions.QueryParam,
&mock.Store{Session: &sessions.State{AccessTokenID: key, Email: "user@pomerium.io"}},
true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &Store{
cache: tt.cache,
encoder: tt.encoder,
queryParam: tt.queryParam,
wrappedStore: tt.wrappedStore,
}
r := httptest.NewRequest(http.MethodGet, "/", nil)
q := r.URL.Query()
q.Set(defaultOptions.QueryParam, tt.state.AccessTokenID)
r.URL.RawQuery = q.Encode()
r.Header.Set("Accept", "application/json")
_, err := s.LoadSession(r)
if (err != nil) != tt.wantErr {
t.Errorf("Store.LoadSession() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}

View file

@ -1,15 +1,11 @@
package sessions
import (
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/cespare/xxhash/v2"
oidc "github.com/coreos/go-oidc"
"github.com/mitchellh/hashstructure"
"github.com/pomerium/pomerium/internal/hashutil"
"golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2/jwt"
)
@ -27,6 +23,9 @@ type State struct {
NotBefore *jwt.NumericDate `json:"nbf,omitempty"`
IssuedAt *jwt.NumericDate `json:"iat,omitempty"`
ID string `json:"jti,omitempty"`
// At_hash is an OPTIONAL Access Token hash value
// https://ldapwiki.com/wiki/At_hash
AccessTokenHash string `json:"at_hash,omitempty"`
// core pomerium identity claims ; not standard to RFC 7519
Email string `json:"email"`
@ -48,84 +47,24 @@ type State struct {
// Programmatic whether this state is used for machine-to-machine
// programatic access.
Programmatic bool `json:"programatic"`
AccessToken *oauth2.Token `json:"act,omitempty"`
AccessTokenID string `json:"ati,omitempty"`
idToken *oidc.IDToken
}
// NewStateFromTokens returns a session state built from oidc and oauth2
// tokens as part of OpenID Connect flow with a new audience appended to the
// audience claim.
func NewStateFromTokens(idToken *oidc.IDToken, accessToken *oauth2.Token, audience string) (*State, error) {
if idToken == nil {
return nil, errors.New("sessions: oidc id token missing")
}
if accessToken == nil {
return nil, errors.New("sessions: oauth2 token missing")
}
s := &State{}
if err := idToken.Claims(s); err != nil {
return nil, fmt.Errorf("sessions: couldn't unmarshal extra claims %w", err)
}
s.Audience = []string{audience}
s.idToken = idToken
s.AccessToken = accessToken
s.AccessTokenID = s.accessTokenHash()
return s, nil
}
// UpdateState updates the current state given a new identity (oidc) and authorization
// (oauth2) tokens following a oidc refresh. NB, unlike during authentication,
// refresh typically provides fewer claims in the token so we want to build from
// our previous state.
func (s *State) UpdateState(idToken *oidc.IDToken, accessToken *oauth2.Token) error {
if idToken == nil {
return errors.New("sessions: oidc id token missing")
}
if accessToken == nil {
return errors.New("sessions: oauth2 token missing")
}
audience := append(s.Audience[:0:0], s.Audience...)
s.AccessToken = accessToken
if err := idToken.Claims(s); err != nil {
return fmt.Errorf("sessions: update state failed %w", err)
}
s.Audience = audience
s.Expiry = jwt.NewNumericDate(accessToken.Expiry)
s.AccessTokenID = s.accessTokenHash()
return nil
}
// NewSession updates issuer, audience, and issuance timestamps but keeps
// parent expiry.
func (s State) NewSession(issuer string, audience []string) *State {
s.IssuedAt = jwt.NewNumericDate(timeNow())
s.NotBefore = s.IssuedAt
s.Audience = audience
s.Issuer = issuer
return &s
}
// RouteSession creates a route session with access tokens stripped.
func (s State) RouteSession() *State {
s.AccessToken = nil
return &s
func NewSession(s *State, issuer string, audience []string, accessToken *oauth2.Token) State {
newState := *s
newState.IssuedAt = jwt.NewNumericDate(timeNow())
newState.NotBefore = newState.IssuedAt
newState.Audience = audience
newState.Issuer = issuer
newState.AccessTokenHash = fmt.Sprintf("%x", hashutil.Hash(accessToken))
newState.Expiry = jwt.NewNumericDate(accessToken.Expiry)
return newState
}
// IsExpired returns true if the users's session is expired.
func (s *State) IsExpired() bool {
if s.Expiry != nil && timeNow().After(s.Expiry.Time()) {
return true
}
if s.AccessToken != nil && timeNow().After(s.AccessToken.Expiry) {
return true
}
return false
return s.Expiry != nil && timeNow().After(s.Expiry.Time())
}
// Impersonating returns if the request is impersonating.
@ -133,23 +72,6 @@ func (s *State) Impersonating() bool {
return s.ImpersonateEmail != "" || len(s.ImpersonateGroups) != 0
}
// RequestEmail is the email to make the request as.
func (s *State) RequestEmail() string {
if s.ImpersonateEmail != "" {
return s.ImpersonateEmail
}
return s.Email
}
// RequestGroups returns the groups of the Groups making the request; uses
// impersonating user if set.
func (s *State) RequestGroups() string {
if len(s.ImpersonateGroups) != 0 {
return strings.Join(s.ImpersonateGroups, ",")
}
return strings.Join(s.Groups, ",")
}
// SetImpersonation sets impersonation user and groups.
func (s *State) SetImpersonation(email, groups string) {
s.ImpersonateEmail = email
@ -159,34 +81,3 @@ func (s *State) SetImpersonation(email, groups string) {
s.ImpersonateGroups = strings.Split(groups, ",")
}
}
func (s *State) accessTokenHash() string {
hash, err := hashstructure.Hash(
s.AccessToken,
&hashstructure.HashOptions{Hasher: xxhash.New()})
if err != nil {
return ""
}
return fmt.Sprintf("%x", hash)
}
// UnmarshalJSON parses the JSON-encoded session state.
// TODO(BDD): remove in v0.8.0
func (s *State) UnmarshalJSON(b []byte) error {
type Alias State
t := &struct {
*Alias
OldToken *oauth2.Token `json:"access_token,omitempty"` // < v0.5.0
}{
Alias: (*Alias)(s),
}
if err := json.Unmarshal(b, &t); err != nil {
return err
}
if t.AccessToken == nil {
t.AccessToken = t.OldToken
}
*s = *(*State)(t.Alias)
return nil
}

View file

@ -5,8 +5,6 @@ import (
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2/jwt"
)
@ -38,12 +36,6 @@ func TestState_Impersonating(t *testing.T) {
if got := s.Impersonating(); got != tt.want {
t.Errorf("State.Impersonating() = %v, want %v", got, tt.want)
}
if gotEmail := s.RequestEmail(); gotEmail != tt.wantResponseEmail {
t.Errorf("State.RequestEmail() = %v, want %v", gotEmail, tt.wantResponseEmail)
}
if gotGroups := s.RequestGroups(); gotGroups != tt.wantResponseGroups {
t.Errorf("State.v() = %v, want %v", gotGroups, tt.wantResponseGroups)
}
})
}
}
@ -63,16 +55,14 @@ func TestState_IsExpired(t *testing.T) {
}{
{"good", []string{"a", "b", "c"}, jwt.NewNumericDate(time.Now().Add(time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), &oauth2.Token{Expiry: time.Now().Add(time.Hour)}, "a", false},
{"bad expiry", []string{"a", "b", "c"}, jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), &oauth2.Token{Expiry: time.Now().Add(time.Hour)}, "a", true},
{"bad access token expiry", []string{"a", "b", "c"}, jwt.NewNumericDate(time.Now().Add(time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), &oauth2.Token{Expiry: time.Now().Add(-time.Hour)}, "a", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &State{
Audience: tt.Audience,
Expiry: tt.Expiry,
NotBefore: tt.NotBefore,
IssuedAt: tt.IssuedAt,
AccessToken: tt.AccessToken,
Audience: tt.Audience,
Expiry: tt.Expiry,
NotBefore: tt.NotBefore,
IssuedAt: tt.IssuedAt,
}
if exp := s.IsExpired(); exp != tt.wantErr {
t.Errorf("State.IsExpired() error = %v, wantErr %v", exp, tt.wantErr)
@ -80,67 +70,3 @@ func TestState_IsExpired(t *testing.T) {
})
}
}
func TestState_RouteSession(t *testing.T) {
now := time.Now()
timeNow = func() time.Time {
return now
}
tests := []struct {
name string
Issuer string
Audience jwt.Audience
Expiry *jwt.NumericDate
AccessToken *oauth2.Token
issuer string
audience []string
want *State
}{
{"good", "authenticate.x.y.z", []string{"http.x.y.z"}, jwt.NewNumericDate(timeNow()), nil, "authenticate.a.b.c", []string{"http.a.b.c"}, &State{Issuer: "authenticate.a.b.c", Audience: []string{"http.a.b.c"}, NotBefore: jwt.NewNumericDate(timeNow()), IssuedAt: jwt.NewNumericDate(timeNow()), Expiry: jwt.NewNumericDate(timeNow())}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := State{
Issuer: tt.Issuer,
Audience: tt.Audience,
Expiry: tt.Expiry,
AccessToken: tt.AccessToken,
}
cmpOpts := []cmp.Option{
cmpopts.IgnoreUnexported(State{}),
}
got := s.NewSession(tt.issuer, tt.audience)
got = got.RouteSession()
if diff := cmp.Diff(got, tt.want, cmpOpts...); diff != "" {
t.Errorf("State.RouteSession() = %s", diff)
}
})
}
}
func TestState_accessTokenHash(t *testing.T) {
t.Parallel()
tests := []struct {
name string
state State
want string
}{
{"empty access token", State{}, "34c96acdcadb1bbb"},
{"no change to access token", State{Subject: "test"}, "34c96acdcadb1bbb"},
{"empty oauth2 token", State{AccessToken: &oauth2.Token{}}, "bbd82197d215198f"},
{"refresh token a", State{AccessToken: &oauth2.Token{RefreshToken: "a"}}, "76316ac79b301bd6"},
{"refresh token b", State{AccessToken: &oauth2.Token{RefreshToken: "b"}}, "fab7cb29e50161f1"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &tt.state
if got := s.accessTokenHash(); got != tt.want {
t.Errorf("State.accessTokenHash() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -17,6 +17,7 @@ const (
QueryRefreshToken = "pomerium_refresh_token"
QueryAccessTokenID = "pomerium_session_access_token_id"
QueryAudience = "pomerium_session_audience"
QueryProgrammaticToken = "pomerium_programmatic_token"
)
// URL signature based query params used for verifying the authenticity of a URL.