mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-04 21:06:03 +02:00
authenticate: save oauth2 tokens to cache (#698)
Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
This commit is contained in:
parent
ef399380b7
commit
666fd6aa35
31 changed files with 1127 additions and 1061 deletions
6
Makefile
6
Makefile
|
@ -38,6 +38,12 @@ GETENVOY_VERSION = v0.1.8
|
||||||
all: clean build-deps test lint spellcheck build ## Runs a clean, build, fmt, lint, test, and vet.
|
all: clean build-deps test lint spellcheck build ## Runs a clean, build, fmt, lint, test, and vet.
|
||||||
|
|
||||||
|
|
||||||
|
.PHONY: generate-mocks
|
||||||
|
generate-mocks: ## Generate mocks
|
||||||
|
@echo "==> $@"
|
||||||
|
@go run github.com/golang/mock/mockgen -destination authorize/evaluator/mock_evaluator/mock.go github.com/pomerium/pomerium/authorize/evaluator Evaluator
|
||||||
|
@go run github.com/golang/mock/mockgen -destination internal/grpc/cache/mock/mock_cacher.go github.com/pomerium/pomerium/internal/grpc/cache Cacher
|
||||||
|
|
||||||
.PHONY: build-deps
|
.PHONY: build-deps
|
||||||
build-deps: ## Install build dependencies
|
build-deps: ## Install build dependencies
|
||||||
@echo "==> $@"
|
@echo "==> $@"
|
||||||
|
|
|
@ -17,12 +17,12 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/encoding/jws"
|
"github.com/pomerium/pomerium/internal/encoding/jws"
|
||||||
"github.com/pomerium/pomerium/internal/frontend"
|
"github.com/pomerium/pomerium/internal/frontend"
|
||||||
"github.com/pomerium/pomerium/internal/grpc"
|
"github.com/pomerium/pomerium/internal/grpc"
|
||||||
|
"github.com/pomerium/pomerium/internal/grpc/cache"
|
||||||
"github.com/pomerium/pomerium/internal/grpc/cache/client"
|
"github.com/pomerium/pomerium/internal/grpc/cache/client"
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/identity"
|
"github.com/pomerium/pomerium/internal/identity"
|
||||||
"github.com/pomerium/pomerium/internal/identity/oauth"
|
"github.com/pomerium/pomerium/internal/identity/oauth"
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
"github.com/pomerium/pomerium/internal/sessions/cache"
|
|
||||||
"github.com/pomerium/pomerium/internal/sessions/cookie"
|
"github.com/pomerium/pomerium/internal/sessions/cookie"
|
||||||
"github.com/pomerium/pomerium/internal/sessions/header"
|
"github.com/pomerium/pomerium/internal/sessions/header"
|
||||||
"github.com/pomerium/pomerium/internal/sessions/queryparam"
|
"github.com/pomerium/pomerium/internal/sessions/queryparam"
|
||||||
|
@ -93,7 +93,7 @@ type Authenticate struct {
|
||||||
provider identity.Authenticator
|
provider identity.Authenticator
|
||||||
|
|
||||||
// cacheClient is the interface for setting and getting sessions from a cache
|
// cacheClient is the interface for setting and getting sessions from a cache
|
||||||
cacheClient client.Cacher
|
cacheClient cache.Cacher
|
||||||
|
|
||||||
templates *template.Template
|
templates *template.Template
|
||||||
}
|
}
|
||||||
|
@ -106,12 +106,12 @@ func New(opts config.Options) (*Authenticate, error) {
|
||||||
|
|
||||||
// shared state encoder setup
|
// shared state encoder setup
|
||||||
sharedCipher, _ := cryptutil.NewAEADCipherFromBase64(opts.SharedKey)
|
sharedCipher, _ := cryptutil.NewAEADCipherFromBase64(opts.SharedKey)
|
||||||
signedEncoder, err := jws.NewHS256Signer([]byte(opts.SharedKey), opts.AuthenticateURL.Host)
|
sharedEncoder, err := jws.NewHS256Signer([]byte(opts.SharedKey), opts.AuthenticateURL.Host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// private state encoder setup
|
// private state encoder setup, used to encrypt oauth2 tokens
|
||||||
decodedCookieSecret, _ := base64.StdEncoding.DecodeString(opts.CookieSecret)
|
decodedCookieSecret, _ := base64.StdEncoding.DecodeString(opts.CookieSecret)
|
||||||
cookieCipher, _ := cryptutil.NewAEADCipher(decodedCookieSecret)
|
cookieCipher, _ := cryptutil.NewAEADCipher(decodedCookieSecret)
|
||||||
encryptedEncoder := ecjson.New(cookieCipher)
|
encryptedEncoder := ecjson.New(cookieCipher)
|
||||||
|
@ -124,7 +124,7 @@ func New(opts config.Options) (*Authenticate, error) {
|
||||||
Expire: opts.CookieExpire,
|
Expire: opts.CookieExpire,
|
||||||
}
|
}
|
||||||
|
|
||||||
cookieStore, err := cookie.NewStore(cookieOptions, encryptedEncoder)
|
cookieStore, err := cookie.NewStore(cookieOptions, sharedEncoder)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -145,13 +145,7 @@ func New(opts config.Options) (*Authenticate, error) {
|
||||||
|
|
||||||
cacheClient := client.New(cacheConn)
|
cacheClient := client.New(cacheConn)
|
||||||
|
|
||||||
cacheStore := cache.NewStore(&cache.Options{
|
qpStore := queryparam.NewStore(encryptedEncoder, urlutil.QueryProgrammaticToken)
|
||||||
Cache: cacheClient,
|
|
||||||
Encoder: encryptedEncoder,
|
|
||||||
QueryParam: urlutil.QueryAccessTokenID,
|
|
||||||
WrappedStore: cookieStore})
|
|
||||||
|
|
||||||
qpStore := queryparam.NewStore(encryptedEncoder, "pomerium_programmatic_token")
|
|
||||||
headerStore := header.NewStore(encryptedEncoder, httputil.AuthorizationTypePomerium)
|
headerStore := header.NewStore(encryptedEncoder, httputil.AuthorizationTypePomerium)
|
||||||
|
|
||||||
redirectURL, _ := urlutil.DeepCopy(opts.AuthenticateURL)
|
redirectURL, _ := urlutil.DeepCopy(opts.AuthenticateURL)
|
||||||
|
@ -177,14 +171,14 @@ func New(opts config.Options) (*Authenticate, error) {
|
||||||
// shared state
|
// shared state
|
||||||
sharedKey: opts.SharedKey,
|
sharedKey: opts.SharedKey,
|
||||||
sharedCipher: sharedCipher,
|
sharedCipher: sharedCipher,
|
||||||
sharedEncoder: signedEncoder,
|
sharedEncoder: sharedEncoder,
|
||||||
// private state
|
// private state
|
||||||
cookieSecret: decodedCookieSecret,
|
cookieSecret: decodedCookieSecret,
|
||||||
cookieCipher: cookieCipher,
|
cookieCipher: cookieCipher,
|
||||||
cookieOptions: cookieOptions,
|
cookieOptions: cookieOptions,
|
||||||
sessionStore: cacheStore,
|
sessionStore: cookieStore,
|
||||||
encryptedEncoder: encryptedEncoder,
|
encryptedEncoder: encryptedEncoder,
|
||||||
sessionLoaders: []sessions.SessionLoader{cacheStore, qpStore, headerStore, cookieStore},
|
sessionLoaders: []sessions.SessionLoader{qpStore, headerStore, cookieStore},
|
||||||
// IdP
|
// IdP
|
||||||
provider: provider,
|
provider: provider,
|
||||||
// grpc client for cache
|
// grpc client for cache
|
||||||
|
|
|
@ -91,6 +91,10 @@ func TestNew(t *testing.T) {
|
||||||
badGRPCConn.CacheURL = nil
|
badGRPCConn.CacheURL = nil
|
||||||
badGRPCConn.CookieName = "D"
|
badGRPCConn.CookieName = "D"
|
||||||
|
|
||||||
|
emptyProviderURL := newTestOptions(t)
|
||||||
|
emptyProviderURL.Provider = "oidc"
|
||||||
|
emptyProviderURL.ProviderURL = ""
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
opts *config.Options
|
opts *config.Options
|
||||||
|
@ -103,6 +107,7 @@ func TestNew(t *testing.T) {
|
||||||
{"bad cookie name", badCookieName, true},
|
{"bad cookie name", badCookieName, true},
|
||||||
{"bad provider", badProvider, true},
|
{"bad provider", badProvider, true},
|
||||||
{"bad cache url", badGRPCConn, true},
|
{"bad cache url", badGRPCConn, true},
|
||||||
|
{"empty provider url", emptyProviderURL, true},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
|
|
||||||
"github.com/pomerium/csrf"
|
"github.com/pomerium/csrf"
|
||||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
|
"github.com/pomerium/pomerium/internal/hashutil"
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/identity/oidc"
|
"github.com/pomerium/pomerium/internal/identity/oidc"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
@ -23,6 +24,7 @@ import (
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/rs/cors"
|
"github.com/rs/cors"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Handler returns the authenticate service's handler chain.
|
// Handler returns the authenticate service's handler chain.
|
||||||
|
@ -80,19 +82,15 @@ func (a *Authenticate) Mount(r *mux.Router) {
|
||||||
// session state is attached to the users's request context.
|
// session state is attached to the users's request context.
|
||||||
func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
|
func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
|
||||||
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||||
ctx := r.Context()
|
ctx, span := trace.StartSpan(r.Context(), "authenticate.VerifySession")
|
||||||
jwt, err := sessions.FromContext(ctx)
|
defer span.End()
|
||||||
|
s, err := a.getSessionFromCtx(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.FromRequest(r).Info().Err(err).Msg("authenticate: session load error")
|
log.FromRequest(r).Info().Err(err).Msg("authenticate: session load error")
|
||||||
return a.reauthenticateOrFail(w, r, err)
|
return a.reauthenticateOrFail(w, r, err)
|
||||||
}
|
}
|
||||||
var s sessions.State
|
|
||||||
if err := a.encryptedEncoder.Unmarshal([]byte(jwt), &s); err != nil {
|
|
||||||
return httputil.NewError(http.StatusBadRequest, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.IsExpired() {
|
if s.IsExpired() {
|
||||||
ctx, err = a.refresh(w, r, &s)
|
ctx, err = a.refresh(w, r, s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.FromRequest(r).Info().Err(err).Msg("authenticate: verify session, refresh")
|
log.FromRequest(r).Info().Err(err).Msg("authenticate: verify session, refresh")
|
||||||
return a.reauthenticateOrFail(w, r, err)
|
return a.reauthenticateOrFail(w, r, err)
|
||||||
|
@ -106,18 +104,34 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
|
||||||
func (a *Authenticate) refresh(w http.ResponseWriter, r *http.Request, s *sessions.State) (context.Context, error) {
|
func (a *Authenticate) refresh(w http.ResponseWriter, r *http.Request, s *sessions.State) (context.Context, error) {
|
||||||
ctx, span := trace.StartSpan(r.Context(), "authenticate.VerifySession/refresh")
|
ctx, span := trace.StartSpan(r.Context(), "authenticate.VerifySession/refresh")
|
||||||
defer span.End()
|
defer span.End()
|
||||||
newSession, err := a.provider.Refresh(ctx, s)
|
accessToken, err := a.getAccessToken(ctx, s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// we are going to keep the same audiences for the refreshed token
|
||||||
|
// otherwise this will be rewritten to be the ClientID of the provider
|
||||||
|
oldAudience := s.Audience
|
||||||
|
|
||||||
|
newAccessToken, err := a.provider.Refresh(ctx, accessToken, s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("authenticate: refresh failed: %w", err)
|
return nil, fmt.Errorf("authenticate: refresh failed: %w", err)
|
||||||
}
|
}
|
||||||
if err := a.sessionStore.SaveSession(w, r, newSession); err != nil {
|
|
||||||
return nil, fmt.Errorf("authenticate: refresh save failed: %w", err)
|
newSession := sessions.NewSession(s, a.RedirectURL.Hostname(), oldAudience, newAccessToken)
|
||||||
}
|
|
||||||
newSession = newSession.NewSession(s.Issuer, s.Audience)
|
encSession, err := a.sharedEncoder.Marshal(newSession)
|
||||||
encSession, err := a.encryptedEncoder.Marshal(newSession)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := a.sessionStore.SaveSession(w, r, newSession); err != nil {
|
||||||
|
return nil, fmt.Errorf("authenticate: error saving new session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := a.setAccessToken(ctx, newAccessToken); err != nil {
|
||||||
|
return nil, fmt.Errorf("authenticate: error saving refreshed access token: %w", err)
|
||||||
|
}
|
||||||
// return the new session and add it to the current request context
|
// return the new session and add it to the current request context
|
||||||
return sessions.NewContext(ctx, string(encSession), err), nil
|
return sessions.NewContext(ctx, string(encSession), err), nil
|
||||||
}
|
}
|
||||||
|
@ -129,8 +143,11 @@ func (a *Authenticate) RobotsTxt(w http.ResponseWriter, r *http.Request) {
|
||||||
fmt.Fprintf(w, "User-agent: *\nDisallow: /")
|
fmt.Fprintf(w, "User-agent: *\nDisallow: /")
|
||||||
}
|
}
|
||||||
|
|
||||||
// SignIn handles to authenticating a user.
|
// SignIn handles authenticating a user.
|
||||||
func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
|
func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
|
||||||
|
ctx, span := trace.StartSpan(r.Context(), "authenticate.SignOut")
|
||||||
|
defer span.End()
|
||||||
|
|
||||||
redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI))
|
redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httputil.NewError(http.StatusBadRequest, err)
|
return httputil.NewError(http.StatusBadRequest, err)
|
||||||
|
@ -158,32 +175,30 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
|
||||||
jwtAudience = append(jwtAudience, fwdAuth)
|
jwtAudience = append(jwtAudience, fwdAuth)
|
||||||
}
|
}
|
||||||
|
|
||||||
jwt, err := sessions.FromContext(r.Context())
|
s, err := a.getSessionFromCtx(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httputil.NewError(http.StatusBadRequest, err)
|
return err
|
||||||
}
|
}
|
||||||
var s sessions.State
|
accessToken, err := a.getAccessToken(ctx, s)
|
||||||
if err := a.encryptedEncoder.Unmarshal([]byte(jwt), &s); err != nil {
|
if err != nil {
|
||||||
return httputil.NewError(http.StatusBadRequest, err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// user impersonation
|
// user impersonation
|
||||||
if impersonate := r.FormValue(urlutil.QueryImpersonateAction); impersonate != "" {
|
if impersonate := r.FormValue(urlutil.QueryImpersonateAction); impersonate != "" {
|
||||||
s.SetImpersonation(r.FormValue(urlutil.QueryImpersonateEmail), r.FormValue(urlutil.QueryImpersonateGroups))
|
s.SetImpersonation(r.FormValue(urlutil.QueryImpersonateEmail), r.FormValue(urlutil.QueryImpersonateGroups))
|
||||||
}
|
}
|
||||||
|
newSession := sessions.NewSession(s, a.RedirectURL.Host, jwtAudience, accessToken)
|
||||||
|
|
||||||
// re-persist the session, useful when session was evicted from session
|
// re-persist the session, useful when session was evicted from session
|
||||||
if err := a.sessionStore.SaveSession(w, r, &s); err != nil {
|
if err := a.sessionStore.SaveSession(w, r, s); err != nil {
|
||||||
return httputil.NewError(http.StatusBadRequest, err)
|
return httputil.NewError(http.StatusBadRequest, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
newSession := s.NewSession(a.RedirectURL.Host, jwtAudience)
|
|
||||||
|
|
||||||
callbackParams := callbackURL.Query()
|
callbackParams := callbackURL.Query()
|
||||||
|
|
||||||
if r.FormValue(urlutil.QueryIsProgrammatic) == "true" {
|
if r.FormValue(urlutil.QueryIsProgrammatic) == "true" {
|
||||||
newSession.Programmatic = true
|
newSession.Programmatic = true
|
||||||
encSession, err := a.encryptedEncoder.Marshal(newSession)
|
encSession, err := a.encryptedEncoder.Marshal(accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httputil.NewError(http.StatusBadRequest, err)
|
return httputil.NewError(http.StatusBadRequest, err)
|
||||||
}
|
}
|
||||||
|
@ -192,7 +207,7 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// sign the route session, as a JWT
|
// sign the route session, as a JWT
|
||||||
signedJWT, err := a.sharedEncoder.Marshal(newSession.RouteSession())
|
signedJWT, err := a.sharedEncoder.Marshal(newSession)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httputil.NewError(http.StatusBadRequest, err)
|
return httputil.NewError(http.StatusBadRequest, err)
|
||||||
}
|
}
|
||||||
|
@ -217,27 +232,12 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
|
||||||
// SignOut signs the user out and attempts to revoke the user's identity session
|
// SignOut signs the user out and attempts to revoke the user's identity session
|
||||||
// Handles both GET and POST.
|
// Handles both GET and POST.
|
||||||
func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error {
|
func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error {
|
||||||
// no matter what happens, we want to clear the local session store
|
ctx, span := trace.StartSpan(r.Context(), "authenticate.SignOut")
|
||||||
|
defer span.End()
|
||||||
|
|
||||||
|
// no matter what happens, we want to clear the session store
|
||||||
a.sessionStore.ClearSession(w, r)
|
a.sessionStore.ClearSession(w, r)
|
||||||
|
|
||||||
jwt, err := sessions.FromContext(r.Context())
|
|
||||||
if err != nil {
|
|
||||||
return httputil.NewError(http.StatusBadRequest, err)
|
|
||||||
}
|
|
||||||
var s sessions.State
|
|
||||||
if err := a.encryptedEncoder.Unmarshal([]byte(jwt), &s); err != nil {
|
|
||||||
return httputil.NewError(http.StatusBadRequest, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
redirectString := r.FormValue(urlutil.QueryRedirectURI)
|
redirectString := r.FormValue(urlutil.QueryRedirectURI)
|
||||||
|
|
||||||
// first, try to revoke the session if implemented
|
|
||||||
err = a.provider.Revoke(r.Context(), s.AccessToken)
|
|
||||||
if err != nil && !errors.Is(err, oidc.ErrRevokeNotImplemented) {
|
|
||||||
return httputil.NewError(http.StatusBadRequest, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// next, try to build a logout url if implemented
|
|
||||||
endSessionURL, err := a.provider.LogOut()
|
endSessionURL, err := a.provider.LogOut()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
params := url.Values{}
|
params := url.Values{}
|
||||||
|
@ -245,14 +245,29 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error {
|
||||||
endSessionURL.RawQuery = params.Encode()
|
endSessionURL.RawQuery = params.Encode()
|
||||||
redirectString = endSessionURL.String()
|
redirectString = endSessionURL.String()
|
||||||
} else if !errors.Is(err, oidc.ErrSignoutNotImplemented) {
|
} else if !errors.Is(err, oidc.ErrSignoutNotImplemented) {
|
||||||
return httputil.NewError(http.StatusBadRequest, err)
|
log.Warn().Err(err).Msg("authenticate.SignOut: failed getting session")
|
||||||
}
|
}
|
||||||
|
|
||||||
redirectURL, err := urlutil.ParseAndValidateURL(redirectString)
|
httputil.Redirect(w, r, redirectString, http.StatusFound)
|
||||||
|
|
||||||
|
s, err := a.getSessionFromCtx(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httputil.NewError(http.StatusBadRequest, err)
|
log.Warn().Err(err).Msg("authenticate.SignOut: failed getting session")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
accessToken, err := a.getAccessToken(ctx, s)
|
||||||
|
if err != nil {
|
||||||
|
log.Warn().Err(err).Msg("authenticate.SignOut: failed getting access token")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// first, try to revoke the session if implemented
|
||||||
|
err = a.provider.Revoke(ctx, accessToken)
|
||||||
|
if err != nil && !errors.Is(err, oidc.ErrRevokeNotImplemented) {
|
||||||
|
log.Warn().Err(err).Msg("authenticate.SignOut: failed revoking token")
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
httputil.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -267,7 +282,8 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error {
|
||||||
// https://tools.ietf.org/html/rfc6749#section-4.2.1
|
// https://tools.ietf.org/html/rfc6749#section-4.2.1
|
||||||
// https://developer.mozilla.org/en-US/docs/Web/API/XMLHttpRequest
|
// https://developer.mozilla.org/en-US/docs/Web/API/XMLHttpRequest
|
||||||
func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Request, err error) error {
|
func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Request, err error) error {
|
||||||
// If request AJAX/XHR request, return a 401 instead .
|
// If request AJAX/XHR request, return a 401 instead because the redirect
|
||||||
|
// will almost certainly violate their CORs policy
|
||||||
if reqType := r.Header.Get("X-Requested-With"); strings.EqualFold(reqType, "XmlHttpRequest") {
|
if reqType := r.Header.Get("X-Requested-With"); strings.EqualFold(reqType, "XmlHttpRequest") {
|
||||||
return httputil.NewError(http.StatusUnauthorized, err)
|
return httputil.NewError(http.StatusUnauthorized, err)
|
||||||
}
|
}
|
||||||
|
@ -290,7 +306,7 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque
|
||||||
func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) error {
|
func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) error {
|
||||||
redirect, err := a.getOAuthCallback(w, r)
|
redirect, err := a.getOAuthCallback(w, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("oauth callback : %w", err)
|
return fmt.Errorf("authenticate.OAuthCallback: %w", err)
|
||||||
}
|
}
|
||||||
httputil.Redirect(w, r, redirect.String(), http.StatusFound)
|
httputil.Redirect(w, r, redirect.String(), http.StatusFound)
|
||||||
return nil
|
return nil
|
||||||
|
@ -306,6 +322,9 @@ func (a *Authenticate) statusForErrorCode(errorCode string) int {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (*url.URL, error) {
|
func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (*url.URL, error) {
|
||||||
|
ctx, span := trace.StartSpan(r.Context(), "authenticate.getOAuthCallback")
|
||||||
|
defer span.End()
|
||||||
|
|
||||||
// Error Authentication Response: rfc6749#section-4.1.2.1 & OIDC#3.1.2.6
|
// Error Authentication Response: rfc6749#section-4.1.2.1 & OIDC#3.1.2.6
|
||||||
//
|
//
|
||||||
// first, check if the identity provider returned an error
|
// first, check if the identity provider returned an error
|
||||||
|
@ -321,14 +340,22 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
|
||||||
// Successful Authentication Response: rfc6749#section-4.1.2 & OIDC#3.1.2.5
|
// Successful Authentication Response: rfc6749#section-4.1.2 & OIDC#3.1.2.5
|
||||||
//
|
//
|
||||||
// Exchange the supplied Authorization Code for a valid user session.
|
// Exchange the supplied Authorization Code for a valid user session.
|
||||||
session, err := a.provider.Authenticate(r.Context(), code)
|
var s sessions.State
|
||||||
|
accessToken, err := a.provider.Authenticate(ctx, code, &s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error redeeming authenticate code: %w", err)
|
return nil, fmt.Errorf("error redeeming authenticate code: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
newState := sessions.NewSession(
|
||||||
|
&s,
|
||||||
|
a.RedirectURL.Hostname(),
|
||||||
|
[]string{a.RedirectURL.Hostname()},
|
||||||
|
accessToken)
|
||||||
|
|
||||||
// state includes a csrf nonce (validated by middleware) and redirect uri
|
// state includes a csrf nonce (validated by middleware) and redirect uri
|
||||||
bytes, err := base64.URLEncoding.DecodeString(r.FormValue("state"))
|
bytes, err := base64.URLEncoding.DecodeString(r.FormValue("state"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, httputil.NewError(http.StatusBadRequest, err)
|
return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("bad bytes: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
// split state into concat'd components
|
// split state into concat'd components
|
||||||
|
@ -357,8 +384,13 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
|
||||||
return nil, httputil.NewError(http.StatusBadRequest, err)
|
return nil, httputil.NewError(http.StatusBadRequest, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OK. Looks good so let's persist our user session
|
// Ok -- We've got a valid session here. Let's now persist the access
|
||||||
if err := a.sessionStore.SaveSession(w, r, session); err != nil {
|
// token to cache ...
|
||||||
|
if err := a.setAccessToken(ctx, accessToken); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed saving access token: %w", err)
|
||||||
|
}
|
||||||
|
// ... and the user state to local storage.
|
||||||
|
if err := a.sessionStore.SaveSession(w, r, &newState); err != nil {
|
||||||
return nil, fmt.Errorf("failed saving new session: %w", err)
|
return nil, fmt.Errorf("failed saving new session: %w", err)
|
||||||
}
|
}
|
||||||
return redirectURL, nil
|
return redirectURL, nil
|
||||||
|
@ -368,26 +400,32 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
|
||||||
// tokens and state with the identity provider. If successful, a new signed JWT
|
// tokens and state with the identity provider. If successful, a new signed JWT
|
||||||
// and refresh token (`refresh_token`) are returned as JSON
|
// and refresh token (`refresh_token`) are returned as JSON
|
||||||
func (a *Authenticate) RefreshAPI(w http.ResponseWriter, r *http.Request) error {
|
func (a *Authenticate) RefreshAPI(w http.ResponseWriter, r *http.Request) error {
|
||||||
jwt, err := sessions.FromContext(r.Context())
|
ctx, span := trace.StartSpan(r.Context(), "authenticate.RefreshAPI")
|
||||||
if err != nil {
|
defer span.End()
|
||||||
return httputil.NewError(http.StatusBadRequest, err)
|
|
||||||
}
|
|
||||||
var s sessions.State
|
|
||||||
if err := a.encryptedEncoder.Unmarshal([]byte(jwt), &s); err != nil {
|
|
||||||
return httputil.NewError(http.StatusBadRequest, err)
|
|
||||||
}
|
|
||||||
newSession, err := a.provider.Refresh(r.Context(), &s)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
newSession = newSession.NewSession(s.Issuer, s.Audience)
|
|
||||||
|
|
||||||
encSession, err := a.encryptedEncoder.Marshal(newSession)
|
s, err := a.getSessionFromCtx(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
signedJWT, err := a.sharedEncoder.Marshal(newSession.RouteSession())
|
accessToken, err := a.getAccessToken(ctx, s)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
newAccessToken, err := a.provider.Refresh(ctx, accessToken, s)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
routeNewSession := sessions.NewSession(s, a.RedirectURL.Hostname(), s.Audience, newAccessToken)
|
||||||
|
|
||||||
|
encSession, err := a.encryptedEncoder.Marshal(accessToken)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
signedJWT, err := a.sharedEncoder.Marshal(routeNewSession)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -410,28 +448,79 @@ func (a *Authenticate) RefreshAPI(w http.ResponseWriter, r *http.Request) error
|
||||||
// Refresh is called by the proxy service to handle backend session refresh.
|
// Refresh is called by the proxy service to handle backend session refresh.
|
||||||
//
|
//
|
||||||
// NOTE: The actual refresh is handled as part of the "VerifySession"
|
// NOTE: The actual refresh is handled as part of the "VerifySession"
|
||||||
// middleware. This handler is responsible for creating a new route scoped
|
// middleware. This handler is simply responsible for returning that jwt.
|
||||||
// session and returning it.
|
|
||||||
func (a *Authenticate) Refresh(w http.ResponseWriter, r *http.Request) error {
|
func (a *Authenticate) Refresh(w http.ResponseWriter, r *http.Request) error {
|
||||||
jwt, err := sessions.FromContext(r.Context())
|
ctx, span := trace.StartSpan(r.Context(), "authenticate.Refresh")
|
||||||
|
defer span.End()
|
||||||
|
jwt, err := sessions.FromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httputil.NewError(http.StatusBadRequest, err)
|
return fmt.Errorf("authenticate.Refresh: %w", err)
|
||||||
}
|
}
|
||||||
var s sessions.State
|
w.Header().Set("Content-Type", "application/jwt") // RFC 7519 : 10.3.1
|
||||||
if err := a.encryptedEncoder.Unmarshal([]byte(jwt), &s); err != nil {
|
fmt.Fprint(w, jwt)
|
||||||
return httputil.NewError(http.StatusBadRequest, err)
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
aud := strings.Split(r.FormValue(urlutil.QueryAudience), ",")
|
// getAccessToken gets an associated oauth2 access token from a session state
|
||||||
routeSession := s.NewSession(r.Host, aud)
|
func (a *Authenticate) getAccessToken(ctx context.Context, s *sessions.State) (*oauth2.Token, error) {
|
||||||
routeSession.AccessTokenID = s.AccessTokenID
|
ctx, span := trace.StartSpan(ctx, "authenticate.getAccessToken")
|
||||||
|
defer span.End()
|
||||||
|
|
||||||
signedJWT, err := a.sharedEncoder.Marshal(routeSession.RouteSession())
|
var accessToken oauth2.Token
|
||||||
|
tokenBytes, err := a.cacheClient.Get(ctx, s.AccessTokenHash)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := a.encryptedEncoder.Unmarshal(tokenBytes, &accessToken); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if accessToken.Valid() {
|
||||||
|
return &accessToken, nil // this token is still valid, use it!
|
||||||
|
}
|
||||||
|
tokenBytes, err = a.cacheClient.Get(ctx, a.timestampedHash(accessToken.RefreshToken))
|
||||||
|
if err == nil {
|
||||||
|
// we found another possibly newer access token associated with the
|
||||||
|
// existing refresh token so let's try that.
|
||||||
|
if err := a.encryptedEncoder.Unmarshal(tokenBytes, &accessToken); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &accessToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Authenticate) setAccessToken(ctx context.Context, accessToken *oauth2.Token) error {
|
||||||
|
encToken, err := a.encryptedEncoder.Marshal(accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
// set this specific access token
|
||||||
|
key := fmt.Sprintf("%x", hashutil.Hash(accessToken))
|
||||||
|
if err := a.cacheClient.Set(ctx, key, encToken); err != nil {
|
||||||
|
return fmt.Errorf("authenticate: setAccessToken failed key: %s :%w", key, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// set this as the "latest" token for this access token
|
||||||
|
key = a.timestampedHash(accessToken.RefreshToken)
|
||||||
|
if err := a.cacheClient.Set(ctx, key, encToken); err != nil {
|
||||||
|
return fmt.Errorf("authenticate: setAccessToken failed key: %s :%w", key, err)
|
||||||
|
}
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/jwt") // RFC 7519 : 10.3.1
|
|
||||||
w.Write(signedJWT)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Authenticate) timestampedHash(s string) string {
|
||||||
|
return fmt.Sprintf("%x-%v", hashutil.Hash(s), time.Now().Truncate(time.Minute).Unix())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Authenticate) getSessionFromCtx(ctx context.Context) (*sessions.State, error) {
|
||||||
|
jwt, err := sessions.FromContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, httputil.NewError(http.StatusBadRequest, err)
|
||||||
|
}
|
||||||
|
var s sessions.State
|
||||||
|
if err := a.sharedEncoder.Unmarshal([]byte(jwt), &s); err != nil {
|
||||||
|
return nil, httputil.NewError(http.StatusBadRequest, err)
|
||||||
|
}
|
||||||
|
return &s, nil
|
||||||
|
}
|
||||||
|
|
|
@ -11,20 +11,24 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
"github.com/pomerium/pomerium/internal/encoding"
|
"github.com/pomerium/pomerium/internal/encoding"
|
||||||
"github.com/pomerium/pomerium/internal/encoding/jws"
|
"github.com/pomerium/pomerium/internal/encoding/jws"
|
||||||
"github.com/pomerium/pomerium/internal/encoding/mock"
|
"github.com/pomerium/pomerium/internal/encoding/mock"
|
||||||
"github.com/pomerium/pomerium/internal/frontend"
|
"github.com/pomerium/pomerium/internal/frontend"
|
||||||
|
mock_cache "github.com/pomerium/pomerium/internal/grpc/cache/mock"
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/identity"
|
"github.com/pomerium/pomerium/internal/identity"
|
||||||
"github.com/pomerium/pomerium/internal/identity/oidc"
|
"github.com/pomerium/pomerium/internal/identity/oidc"
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
"github.com/pomerium/pomerium/internal/sessions/cookie"
|
"github.com/pomerium/pomerium/internal/sessions/cookie"
|
||||||
mstore "github.com/pomerium/pomerium/internal/sessions/mock"
|
mstore "github.com/pomerium/pomerium/internal/sessions/mock"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/gorilla/mux"
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"gopkg.in/square/go-jose.v2/jwt"
|
"gopkg.in/square/go-jose.v2/jwt"
|
||||||
|
@ -115,23 +119,28 @@ func TestAuthenticate_SignIn(t *testing.T) {
|
||||||
encoder encoding.MarshalUnmarshaler
|
encoder encoding.MarshalUnmarshaler
|
||||||
wantCode int
|
wantCode int
|
||||||
}{
|
}{
|
||||||
{"good", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
{"good", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||||
{"good alternate port", "https", "corp.example.example:8443", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
{"good alternate port", "https", "corp.example.example:8443", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||||
{"session not valid", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(-10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
{"session not valid", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||||
{"bad redirect uri query", "", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "^^^"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
|
{"bad redirect uri query", "", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "^^^"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
|
||||||
{"bad marshal", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
|
{"bad marshal", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
|
||||||
{"session error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{LoadError: errors.New("error")}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
|
{"session error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{LoadError: errors.New("error")}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
|
||||||
{"good with different programmatic redirect", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
{"good with different programmatic redirect", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||||
{"encrypted encoder error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
|
{"encrypted encoder error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
|
||||||
{"good with callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "https://some.example/", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
{"good with callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "https://some.example/", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||||
{"bad callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "^", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
|
{"bad callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "^", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
|
||||||
{"good programmatic request", "https", "corp.example.example", map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
{"good programmatic request", "https", "corp.example.example", map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||||
{"good additional audience", "https", "corp.example.example", map[string]string{urlutil.QueryForwardAuth: "x.y.z", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
{"good additional audience", "https", "corp.example.example", map[string]string{urlutil.QueryForwardAuth: "x.y.z", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||||
{"good user impersonate", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
{"good user impersonate", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||||
{"bad user impersonate save failure", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
|
{"bad user impersonate save failure", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@pomerium.io"}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
mc := mock_cache.NewMockCacher(ctrl)
|
||||||
|
mc.EXPECT().Get(gomock.Any(), gomock.Any()).Return([]byte("hi"), nil).AnyTimes()
|
||||||
|
|
||||||
a := &Authenticate{
|
a := &Authenticate{
|
||||||
sessionStore: tt.session,
|
sessionStore: tt.session,
|
||||||
provider: tt.provider,
|
provider: tt.provider,
|
||||||
|
@ -144,6 +153,7 @@ func TestAuthenticate_SignIn(t *testing.T) {
|
||||||
Name: "cookie",
|
Name: "cookie",
|
||||||
Domain: "foo",
|
Domain: "foo",
|
||||||
},
|
},
|
||||||
|
cacheClient: mc,
|
||||||
}
|
}
|
||||||
uri := &url.URL{Scheme: tt.scheme, Host: tt.host}
|
uri := &url.URL{Scheme: tt.scheme, Host: tt.host}
|
||||||
|
|
||||||
|
@ -176,6 +186,7 @@ func uriParseHelper(s string) *url.URL {
|
||||||
|
|
||||||
func TestAuthenticate_SignOut(t *testing.T) {
|
func TestAuthenticate_SignOut(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
method string
|
method string
|
||||||
|
@ -190,19 +201,24 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
||||||
wantCode int
|
wantCode int
|
||||||
wantBody string
|
wantBody string
|
||||||
}{
|
}{
|
||||||
{"good post", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{LogOutResponse: (*uriParseHelper("https://microsoft.com"))}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusFound, ""},
|
{"good post", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{LogOutResponse: (*uriParseHelper("https://microsoft.com"))}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io"}}, http.StatusFound, ""},
|
||||||
{"failed revoke", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: OH NO\"}\n"},
|
{"failed revoke", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io"}}, http.StatusFound, ""},
|
||||||
{"load session error", http.MethodPost, errors.New("error"), "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: error\"}\n"},
|
{"load session error", http.MethodPost, errors.New("error"), "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io"}}, http.StatusFound, ""},
|
||||||
{"bad redirect uri", http.MethodPost, nil, "corp.pomerium.io/", "sig", "ts", identity.MockProvider{LogOutError: oidc.ErrSignoutNotImplemented}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: corp.pomerium.io/ url does contain a valid scheme\"}\n"},
|
{"bad redirect uri", http.MethodPost, nil, "corp.pomerium.io/", "sig", "ts", identity.MockProvider{LogOutError: oidc.ErrSignoutNotImplemented}, &mstore.Store{Encrypted: true, Session: &sessions.State{Email: "user@pomerium.io"}}, http.StatusFound, ""},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
mc := mock_cache.NewMockCacher(ctrl)
|
||||||
|
mc.EXPECT().Get(gomock.Any(), gomock.Any()).Return([]byte("hi"), nil).AnyTimes()
|
||||||
a := &Authenticate{
|
a := &Authenticate{
|
||||||
sessionStore: tt.sessionStore,
|
sessionStore: tt.sessionStore,
|
||||||
provider: tt.provider,
|
provider: tt.provider,
|
||||||
encryptedEncoder: mock.Encoder{},
|
encryptedEncoder: mock.Encoder{},
|
||||||
templates: template.Must(frontend.NewTemplates()),
|
templates: template.Must(frontend.NewTemplates()),
|
||||||
|
sharedEncoder: mock.Encoder{},
|
||||||
|
cacheClient: mc,
|
||||||
}
|
}
|
||||||
u, _ := url.Parse("/sign_out")
|
u, _ := url.Parse("/sign_out")
|
||||||
params, _ := url.ParseQuery(u.RawQuery)
|
params, _ := url.ParseQuery(u.RawQuery)
|
||||||
|
@ -256,40 +272,50 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
|
||||||
want string
|
want string
|
||||||
wantCode int
|
wantCode int
|
||||||
}{
|
}{
|
||||||
{"good", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusFound},
|
{"good", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusFound},
|
||||||
{"failed authenticate", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateError: errors.New("error")}, "", http.StatusInternalServerError},
|
{"failed authenticate", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}, AuthenticateError: errors.New("error")}, "", http.StatusInternalServerError},
|
||||||
{"failed save session", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{SaveError: errors.New("error")}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusInternalServerError},
|
{"failed save session", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{SaveError: errors.New("error")}, identity.MockProvider{}, "", http.StatusInternalServerError},
|
||||||
{"provider returned error", http.MethodGet, time.Now().Unix(), "", "", "", "idp error", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest},
|
{"provider returned error", http.MethodGet, time.Now().Unix(), "", "", "", "idp error", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "", http.StatusBadRequest},
|
||||||
{"provider returned error imply 401", http.MethodGet, time.Now().Unix(), "", "", "", "access_denied", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusUnauthorized},
|
{"provider returned error imply 401", http.MethodGet, time.Now().Unix(), "", "", "", "access_denied", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "", http.StatusUnauthorized},
|
||||||
{"empty code", http.MethodGet, time.Now().Unix(), "", "", "", "", "", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest},
|
{"empty code", http.MethodGet, time.Now().Unix(), "", "", "", "", "", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "", http.StatusBadRequest},
|
||||||
{"invalid redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest},
|
{"invalid redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "", http.StatusBadRequest},
|
||||||
{"bad redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "http://^^^", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
{"bad redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "http://^^^", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||||
{"bad timing - too soon", http.MethodGet, time.Now().Add(1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
{"bad timing - too soon", http.MethodGet, time.Now().Add(1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||||
{"bad timing - expired", http.MethodGet, time.Now().Add(-1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
{"bad timing - expired", http.MethodGet, time.Now().Add(-1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||||
{"bad base64", http.MethodGet, time.Now().Unix(), "", "", "^", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
{"bad base64", http.MethodGet, time.Now().Unix(), "", "", "^", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||||
{"too many seperators", http.MethodGet, time.Now().Unix(), "", "", "|ok|now|what", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
{"too many seperators", http.MethodGet, time.Now().Unix(), "", "", "|ok|now|what", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||||
{"bad hmac", http.MethodGet, time.Now().Unix(), "", "NOTMAC", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
{"bad hmac", http.MethodGet, time.Now().Unix(), "", "NOTMAC", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||||
{"bad hmac", http.MethodGet, time.Now().Unix(), base64.URLEncoding.EncodeToString([]byte("malformed_state")), "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
{"bad hmac", http.MethodGet, time.Now().Unix(), base64.URLEncoding.EncodeToString([]byte("malformed_state")), "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
mc := mock_cache.NewMockCacher(ctrl)
|
||||||
|
mc.EXPECT().Get(gomock.Any(), gomock.Any()).Return([]byte("hi"), nil).AnyTimes()
|
||||||
|
mc.EXPECT().Set(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
|
||||||
aead, err := chacha20poly1305.NewX(cryptutil.NewKey())
|
aead, err := chacha20poly1305.NewX(cryptutil.NewKey())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
signer, err := jws.NewHS256Signer(nil, "mock")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
authURL, _ := url.Parse(tt.authenticateURL)
|
authURL, _ := url.Parse(tt.authenticateURL)
|
||||||
a := &Authenticate{
|
a := &Authenticate{
|
||||||
RedirectURL: authURL,
|
RedirectURL: authURL,
|
||||||
sessionStore: tt.session,
|
sessionStore: tt.session,
|
||||||
provider: tt.provider,
|
provider: tt.provider,
|
||||||
cookieCipher: aead,
|
cookieCipher: aead,
|
||||||
|
cacheClient: mc,
|
||||||
|
encryptedEncoder: signer,
|
||||||
}
|
}
|
||||||
u, _ := url.Parse("/oauthGet")
|
u, _ := url.Parse("/oauthGet")
|
||||||
params, _ := url.ParseQuery(u.RawQuery)
|
params, _ := url.ParseQuery(u.RawQuery)
|
||||||
params.Add("error", tt.paramErr)
|
params.Add("error", tt.paramErr)
|
||||||
params.Add("code", tt.code)
|
params.Add("code", tt.code)
|
||||||
nonce := cryptutil.NewBase64Key() // mock csrf
|
nonce := cryptutil.NewBase64Key() // mock csrf
|
||||||
|
|
||||||
// (nonce|timestamp|redirect_url|encrypt(redirect_url),mac(nonce,ts))
|
// (nonce|timestamp|redirect_url|encrypt(redirect_url),mac(nonce,ts))
|
||||||
b := []byte(fmt.Sprintf("%s|%d|%s", nonce, tt.ts, tt.extraMac))
|
b := []byte(fmt.Sprintf("%s|%d|%s", nonce, tt.ts, tt.extraMac))
|
||||||
|
|
||||||
|
@ -336,15 +362,21 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
||||||
|
|
||||||
wantStatus int
|
wantStatus int
|
||||||
}{
|
}{
|
||||||
{"good", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusOK},
|
{"good", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, http.StatusOK},
|
||||||
{"invalid session", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound},
|
{"invalid session", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound},
|
||||||
{"good refresh expired", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusOK},
|
{"good refresh expired", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, http.StatusOK},
|
||||||
{"expired,refresh error", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusFound},
|
{"expired,refresh error", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusFound},
|
||||||
{"expired,save error", nil, &mstore.Store{SaveError: errors.New("error"), Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusFound},
|
{"expired,save error", nil, &mstore.Store{SaveError: errors.New("error"), Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, http.StatusFound},
|
||||||
{"expired XHR,refresh error", map[string]string{"X-Requested-With": "XmlHttpRequest"}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusUnauthorized},
|
{"expired XHR,refresh error", map[string]string{"X-Requested-With": "XmlHttpRequest"}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusUnauthorized},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
mc := mock_cache.NewMockCacher(ctrl)
|
||||||
|
mc.EXPECT().Get(gomock.Any(), gomock.Any()).Return([]byte("hi"), nil).AnyTimes()
|
||||||
|
mc.EXPECT().Set(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
|
||||||
|
|
||||||
aead, err := chacha20poly1305.NewX(cryptutil.NewKey())
|
aead, err := chacha20poly1305.NewX(cryptutil.NewKey())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -361,6 +393,8 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
||||||
provider: tt.provider,
|
provider: tt.provider,
|
||||||
cookieCipher: aead,
|
cookieCipher: aead,
|
||||||
encryptedEncoder: signer,
|
encryptedEncoder: signer,
|
||||||
|
cacheClient: mc,
|
||||||
|
sharedEncoder: mock.Encoder{},
|
||||||
}
|
}
|
||||||
r := httptest.NewRequest("GET", "/", nil)
|
r := httptest.NewRequest("GET", "/", nil)
|
||||||
state, err := tt.session.LoadSession(r)
|
state, err := tt.session.LoadSession(r)
|
||||||
|
@ -402,14 +436,20 @@ func TestAuthenticate_RefreshAPI(t *testing.T) {
|
||||||
|
|
||||||
wantStatus int
|
wantStatus int
|
||||||
}{
|
}{
|
||||||
{"good", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusOK},
|
{"good", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusOK},
|
||||||
{"refresh error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshError: errors.New("error")}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError},
|
{"refresh error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshError: errors.New("error")}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError},
|
||||||
{"session is not refreshable error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("session error"), identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusBadRequest},
|
{"session is not refreshable error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("session error"), identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusBadRequest},
|
||||||
{"secret encoder failed", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalError: errors.New("error")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError},
|
{"secret encoder failed", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, mock.Encoder{MarshalError: errors.New("error")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError},
|
||||||
{"shared encoder failed", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalError: errors.New("error")}, http.StatusInternalServerError},
|
{"shared encoder failed", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalError: errors.New("error")}, http.StatusInternalServerError},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
mc := mock_cache.NewMockCacher(ctrl)
|
||||||
|
mc.EXPECT().Get(gomock.Any(), gomock.Any()).Return([]byte("hi"), nil).AnyTimes()
|
||||||
|
mc.EXPECT().Set(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
|
||||||
|
|
||||||
aead, err := chacha20poly1305.NewX(cryptutil.NewKey())
|
aead, err := chacha20poly1305.NewX(cryptutil.NewKey())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -423,6 +463,7 @@ func TestAuthenticate_RefreshAPI(t *testing.T) {
|
||||||
sessionStore: tt.session,
|
sessionStore: tt.session,
|
||||||
provider: tt.provider,
|
provider: tt.provider,
|
||||||
cookieCipher: aead,
|
cookieCipher: aead,
|
||||||
|
cacheClient: mc,
|
||||||
}
|
}
|
||||||
r := httptest.NewRequest("GET", "/", nil)
|
r := httptest.NewRequest("GET", "/", nil)
|
||||||
state, _ := tt.session.LoadSession(r)
|
state, _ := tt.session.LoadSession(r)
|
||||||
|
@ -441,53 +482,111 @@ func TestAuthenticate_RefreshAPI(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthenticate_Refresh(t *testing.T) {
|
func TestAuthenticate_Refresh(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|
||||||
session sessions.SessionStore
|
session *sessions.State
|
||||||
ctxError error
|
at *oauth2.Token
|
||||||
|
|
||||||
provider identity.Authenticator
|
provider identity.Authenticator
|
||||||
secretEncoder encoding.MarshalUnmarshaler
|
secretEncoder encoding.MarshalUnmarshaler
|
||||||
sharedEncoder encoding.MarshalUnmarshaler
|
|
||||||
|
|
||||||
wantStatus int
|
wantStatus int
|
||||||
}{
|
}{
|
||||||
{"good", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusOK},
|
{"good",
|
||||||
{"bad session", &mstore.Store{}, errors.New("err"), identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusBadRequest},
|
&sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))},
|
||||||
{"encoder error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalError: errors.New("err")}, http.StatusInternalServerError},
|
&oauth2.Token{AccessToken: "mock", Expiry: time.Now().Add(10 * time.Minute)},
|
||||||
|
identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
|
||||||
|
mock.Encoder{MarshalResponse: []byte("ok")},
|
||||||
|
200},
|
||||||
|
{"session and oauth2 expired",
|
||||||
|
&sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))},
|
||||||
|
&oauth2.Token{AccessToken: "mock", Expiry: time.Now().Add(-10 * time.Minute)},
|
||||||
|
identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
|
||||||
|
mock.Encoder{MarshalResponse: []byte("ok")},
|
||||||
|
200},
|
||||||
|
{"session expired",
|
||||||
|
&sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))},
|
||||||
|
&oauth2.Token{AccessToken: "mock", Expiry: time.Now().Add(10 * time.Minute)},
|
||||||
|
identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
|
||||||
|
mock.Encoder{MarshalResponse: []byte("ok")},
|
||||||
|
200},
|
||||||
|
{"failed refresh",
|
||||||
|
&sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))},
|
||||||
|
&oauth2.Token{AccessToken: "mock", Expiry: time.Now().Add(10 * time.Minute)},
|
||||||
|
identity.MockProvider{RefreshError: errors.New("oh no")},
|
||||||
|
mock.Encoder{MarshalResponse: []byte("ok")},
|
||||||
|
302},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
aead, err := chacha20poly1305.NewX(cryptutil.NewKey())
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
mc := mock_cache.NewMockCacher(ctrl)
|
||||||
|
// just enough is stubbed out here so we can use our own mock provider
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(200)
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
out := fmt.Sprintf(`{"issuer":"http://%s"}`, r.Host)
|
||||||
|
fmt.Fprintln(w, out)
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
rURL := ts.URL
|
||||||
|
a, err := New(config.Options{
|
||||||
|
SharedKey: cryptutil.NewBase64Key(),
|
||||||
|
CookieSecret: cryptutil.NewBase64Key(),
|
||||||
|
AuthenticateURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"),
|
||||||
|
Provider: "oidc",
|
||||||
|
ClientID: "mock",
|
||||||
|
ClientSecret: "mock",
|
||||||
|
ProviderURL: rURL,
|
||||||
|
AuthenticateCallbackPath: "mock",
|
||||||
|
CookieName: "pomerium",
|
||||||
|
Addr: ":0",
|
||||||
|
CacheURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"),
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
a := Authenticate{
|
a.cacheClient = mc
|
||||||
sharedKey: cryptutil.NewBase64Key(),
|
a.provider = tt.provider
|
||||||
cookieSecret: cryptutil.NewKey(),
|
|
||||||
RedirectURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"),
|
|
||||||
encryptedEncoder: tt.secretEncoder,
|
|
||||||
sharedEncoder: tt.sharedEncoder,
|
|
||||||
sessionStore: tt.session,
|
|
||||||
provider: tt.provider,
|
|
||||||
cookieCipher: aead,
|
|
||||||
}
|
|
||||||
r := httptest.NewRequest("GET", "/", nil)
|
|
||||||
state, _ := tt.session.LoadSession(r)
|
|
||||||
ctx := r.Context()
|
|
||||||
ctx = sessions.NewContext(ctx, state, tt.ctxError)
|
|
||||||
r = r.WithContext(ctx)
|
|
||||||
|
|
||||||
|
u, _ := url.Parse("/oauthGet")
|
||||||
|
params, _ := url.ParseQuery(u.RawQuery)
|
||||||
|
destination := urlutil.NewSignedURL(a.sharedKey,
|
||||||
|
&url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: "example.com",
|
||||||
|
Path: "/.pomerium/refresh"})
|
||||||
|
|
||||||
|
u.RawQuery = params.Encode()
|
||||||
|
|
||||||
|
r := httptest.NewRequest(http.MethodGet, destination.String(), nil)
|
||||||
|
|
||||||
|
jwt, err := a.sharedEncoder.Marshal(tt.session)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
rawToken, err := a.encryptedEncoder.Marshal(tt.at)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
mc.EXPECT().Get(gomock.Any(), gomock.Any()).Return(rawToken, nil).AnyTimes()
|
||||||
|
mc.EXPECT().Set(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
|
||||||
|
a.cacheClient = mc
|
||||||
|
|
||||||
|
r.Header.Set("Authorization", fmt.Sprintf("Pomerium %s", jwt))
|
||||||
r.Header.Set("Accept", "application/json")
|
r.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
httputil.HandlerFunc(a.Refresh).ServeHTTP(w, r)
|
router := mux.NewRouter()
|
||||||
|
a.Mount(router)
|
||||||
|
router.ServeHTTP(w, r)
|
||||||
if status := w.Code; status != tt.wantStatus {
|
if status := w.Code; status != tt.wantStatus {
|
||||||
t.Errorf("VerifySession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
|
t.Errorf("Refresh() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
|
||||||
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,3 @@
|
||||||
//go:generate mockgen -destination mock_evaluator/mock.go github.com/pomerium/pomerium/authorize/evaluator Evaluator
|
|
||||||
|
|
||||||
// Package evaluator defines a Evaluator interfaces that can be implemented by
|
// Package evaluator defines a Evaluator interfaces that can be implemented by
|
||||||
// a policy evaluator framework.
|
// a policy evaluator framework.
|
||||||
package evaluator
|
package evaluator
|
||||||
|
|
|
@ -218,10 +218,6 @@ func (a *Authorize) refreshSession(ctx context.Context, rawSession []byte) (newS
|
||||||
|
|
||||||
// 1 - build a signed url to call refresh on authenticate service
|
// 1 - build a signed url to call refresh on authenticate service
|
||||||
refreshURI := options.AuthenticateURL.ResolveReference(&url.URL{Path: "/.pomerium/refresh"})
|
refreshURI := options.AuthenticateURL.ResolveReference(&url.URL{Path: "/.pomerium/refresh"})
|
||||||
q := refreshURI.Query()
|
|
||||||
q.Set(urlutil.QueryAccessTokenID, state.AccessTokenID) // hash value points to parent token
|
|
||||||
q.Set(urlutil.QueryAudience, strings.Join(state.Audience, ",")) // request's audience, this route
|
|
||||||
refreshURI.RawQuery = q.Encode()
|
|
||||||
signedRefreshURL := urlutil.NewSignedURL(options.SharedKey, refreshURI).String()
|
signedRefreshURL := urlutil.NewSignedURL(options.SharedKey, refreshURI).String()
|
||||||
|
|
||||||
// 2 - http call to authenticate service
|
// 2 - http call to authenticate service
|
||||||
|
@ -229,6 +225,7 @@ func (a *Authorize) refreshSession(ctx context.Context, rawSession []byte) (newS
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("authorize: refresh request: %w", err)
|
return nil, fmt.Errorf("authorize: refresh request: %w", err)
|
||||||
}
|
}
|
||||||
|
req.Header.Set("Authorization", fmt.Sprintf("Pomerium %s", rawSession))
|
||||||
req.Header.Set("X-Requested-With", "XmlHttpRequest")
|
req.Header.Set("X-Requested-With", "XmlHttpRequest")
|
||||||
req.Header.Set("Accept", "application/json")
|
req.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
|
18
cache/grpc_test.go
vendored
18
cache/grpc_test.go
vendored
|
@ -11,6 +11,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/google/go-cmp/cmp/cmpopts"
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
"github.com/pomerium/pomerium/internal/grpc/cache"
|
"github.com/pomerium/pomerium/internal/grpc/cache"
|
||||||
|
@ -50,9 +51,6 @@ func TestCache_Get_and_Set(t *testing.T) {
|
||||||
&cache.GetReply{
|
&cache.GetReply{
|
||||||
Exists: true,
|
Exists: true,
|
||||||
Value: []byte("hello"),
|
Value: []byte("hello"),
|
||||||
XXX_NoUnkeyedLiteral: struct{}{},
|
|
||||||
XXX_unrecognized: nil,
|
|
||||||
XXX_sizecache: 0,
|
|
||||||
},
|
},
|
||||||
false,
|
false,
|
||||||
false,
|
false,
|
||||||
|
@ -65,9 +63,6 @@ func TestCache_Get_and_Set(t *testing.T) {
|
||||||
&cache.GetReply{
|
&cache.GetReply{
|
||||||
Exists: false,
|
Exists: false,
|
||||||
Value: nil,
|
Value: nil,
|
||||||
XXX_NoUnkeyedLiteral: struct{}{},
|
|
||||||
XXX_unrecognized: nil,
|
|
||||||
XXX_sizecache: 0,
|
|
||||||
},
|
},
|
||||||
false,
|
false,
|
||||||
false,
|
false,
|
||||||
|
@ -80,9 +75,6 @@ func TestCache_Get_and_Set(t *testing.T) {
|
||||||
&cache.GetReply{
|
&cache.GetReply{
|
||||||
Exists: false,
|
Exists: false,
|
||||||
Value: nil,
|
Value: nil,
|
||||||
XXX_NoUnkeyedLiteral: struct{}{},
|
|
||||||
XXX_unrecognized: nil,
|
|
||||||
XXX_sizecache: 0,
|
|
||||||
},
|
},
|
||||||
true,
|
true,
|
||||||
false,
|
false,
|
||||||
|
@ -96,7 +88,11 @@ func TestCache_Get_and_Set(t *testing.T) {
|
||||||
t.Errorf("Cache.Set() error = %v, wantSetError %v", err, tt.wantSetError)
|
t.Errorf("Cache.Set() error = %v, wantSetError %v", err, tt.wantSetError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if diff := cmp.Diff(setGot, tt.SetReply); diff != "" {
|
cmpOpts := []cmp.Option{
|
||||||
|
cmpopts.IgnoreUnexported(cache.SetReply{}, cache.GetReply{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(setGot, tt.SetReply, cmpOpts...); diff != "" {
|
||||||
t.Errorf("Cache.Set() = %v", diff)
|
t.Errorf("Cache.Set() = %v", diff)
|
||||||
}
|
}
|
||||||
getGot, err := c.Get(tt.ctx, tt.GetRequest)
|
getGot, err := c.Get(tt.ctx, tt.GetRequest)
|
||||||
|
@ -104,7 +100,7 @@ func TestCache_Get_and_Set(t *testing.T) {
|
||||||
t.Errorf("Cache.Get() error = %v, wantGetError %v", err, tt.wantGetError)
|
t.Errorf("Cache.Get() error = %v, wantGetError %v", err, tt.wantGetError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if diff := cmp.Diff(getGot, tt.GetReply); diff != "" {
|
if diff := cmp.Diff(getGot, tt.GetReply, cmpOpts...); diff != "" {
|
||||||
t.Errorf("Cache.Get() = %v", diff)
|
t.Errorf("Cache.Get() = %v", diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"sort"
|
"sort"
|
||||||
|
"time"
|
||||||
|
|
||||||
envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
|
envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
|
||||||
envoy_config_listener_v3 "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3"
|
envoy_config_listener_v3 "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3"
|
||||||
|
@ -151,6 +152,7 @@ func (srv *Server) buildMainHTTPConnectionManagerFilter(options *config.Options,
|
||||||
},
|
},
|
||||||
Services: &envoy_extensions_filters_http_ext_authz_v3.ExtAuthz_GrpcService{
|
Services: &envoy_extensions_filters_http_ext_authz_v3.ExtAuthz_GrpcService{
|
||||||
GrpcService: &envoy_config_core_v3.GrpcService{
|
GrpcService: &envoy_config_core_v3.GrpcService{
|
||||||
|
Timeout: ptypes.DurationProto(time.Second * 30),
|
||||||
TargetSpecifier: &envoy_config_core_v3.GrpcService_EnvoyGrpc_{
|
TargetSpecifier: &envoy_config_core_v3.GrpcService_EnvoyGrpc_{
|
||||||
EnvoyGrpc: &envoy_config_core_v3.GrpcService_EnvoyGrpc{
|
EnvoyGrpc: &envoy_config_core_v3.GrpcService_EnvoyGrpc{
|
||||||
ClusterName: "pomerium-authz",
|
ClusterName: "pomerium-authz",
|
||||||
|
|
14
internal/grpc/cache/cache.go
vendored
Normal file
14
internal/grpc/cache/cache.go
vendored
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
// Package cache defines a Cacher interfaces that can be implemented by any
|
||||||
|
// key value store system.
|
||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Cacher specifies an interface for remote clients connecting to the cache service.
|
||||||
|
type Cacher interface {
|
||||||
|
Get(ctx context.Context, key string) (value []byte, err error)
|
||||||
|
Set(ctx context.Context, key string, value []byte) error
|
||||||
|
Close() error
|
||||||
|
}
|
421
internal/grpc/cache/cache.pb.go
vendored
421
internal/grpc/cache/cache.pb.go
vendored
|
@ -1,217 +1,356 @@
|
||||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||||
|
// versions:
|
||||||
|
// protoc-gen-go v1.21.0
|
||||||
|
// protoc v3.11.4
|
||||||
// source: cache.proto
|
// source: cache.proto
|
||||||
|
|
||||||
package cache
|
package cache
|
||||||
|
|
||||||
import (
|
import (
|
||||||
context "context"
|
context "context"
|
||||||
fmt "fmt"
|
|
||||||
proto "github.com/golang/protobuf/proto"
|
proto "github.com/golang/protobuf/proto"
|
||||||
grpc "google.golang.org/grpc"
|
grpc "google.golang.org/grpc"
|
||||||
codes "google.golang.org/grpc/codes"
|
codes "google.golang.org/grpc/codes"
|
||||||
status "google.golang.org/grpc/status"
|
status "google.golang.org/grpc/status"
|
||||||
math "math"
|
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||||
|
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||||
|
reflect "reflect"
|
||||||
|
sync "sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Reference imports to suppress errors if they are not otherwise used.
|
const (
|
||||||
var _ = proto.Marshal
|
// Verify that this generated code is sufficiently up-to-date.
|
||||||
var _ = fmt.Errorf
|
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
|
||||||
var _ = math.Inf
|
// Verify that runtime/protoimpl is sufficiently up-to-date.
|
||||||
|
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||||
|
)
|
||||||
|
|
||||||
// This is a compile-time assertion to ensure that this generated file
|
// This is a compile-time assertion that a sufficiently up-to-date version
|
||||||
// is compatible with the proto package it is being compiled against.
|
// of the legacy proto package is being used.
|
||||||
// A compilation error at this line likely means your copy of the
|
const _ = proto.ProtoPackageIsVersion4
|
||||||
// proto package needs to be updated.
|
|
||||||
const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
|
|
||||||
|
|
||||||
type GetRequest struct {
|
type GetRequest struct {
|
||||||
|
state protoimpl.MessageState
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
|
||||||
Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"`
|
Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"`
|
||||||
XXX_NoUnkeyedLiteral struct{} `json:"-"`
|
|
||||||
XXX_unrecognized []byte `json:"-"`
|
|
||||||
XXX_sizecache int32 `json:"-"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *GetRequest) Reset() { *m = GetRequest{} }
|
func (x *GetRequest) Reset() {
|
||||||
func (m *GetRequest) String() string { return proto.CompactTextString(m) }
|
*x = GetRequest{}
|
||||||
|
if protoimpl.UnsafeEnabled {
|
||||||
|
mi := &file_cache_proto_msgTypes[0]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *GetRequest) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
func (*GetRequest) ProtoMessage() {}
|
func (*GetRequest) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *GetRequest) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_cache_proto_msgTypes[0]
|
||||||
|
if protoimpl.UnsafeEnabled && x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use GetRequest.ProtoReflect.Descriptor instead.
|
||||||
func (*GetRequest) Descriptor() ([]byte, []int) {
|
func (*GetRequest) Descriptor() ([]byte, []int) {
|
||||||
return fileDescriptor_5fca3b110c9bbf3a, []int{0}
|
return file_cache_proto_rawDescGZIP(), []int{0}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *GetRequest) XXX_Unmarshal(b []byte) error {
|
func (x *GetRequest) GetKey() string {
|
||||||
return xxx_messageInfo_GetRequest.Unmarshal(m, b)
|
if x != nil {
|
||||||
}
|
return x.Key
|
||||||
func (m *GetRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
|
|
||||||
return xxx_messageInfo_GetRequest.Marshal(b, m, deterministic)
|
|
||||||
}
|
|
||||||
func (m *GetRequest) XXX_Merge(src proto.Message) {
|
|
||||||
xxx_messageInfo_GetRequest.Merge(m, src)
|
|
||||||
}
|
|
||||||
func (m *GetRequest) XXX_Size() int {
|
|
||||||
return xxx_messageInfo_GetRequest.Size(m)
|
|
||||||
}
|
|
||||||
func (m *GetRequest) XXX_DiscardUnknown() {
|
|
||||||
xxx_messageInfo_GetRequest.DiscardUnknown(m)
|
|
||||||
}
|
|
||||||
|
|
||||||
var xxx_messageInfo_GetRequest proto.InternalMessageInfo
|
|
||||||
|
|
||||||
func (m *GetRequest) GetKey() string {
|
|
||||||
if m != nil {
|
|
||||||
return m.Key
|
|
||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
type GetReply struct {
|
type GetReply struct {
|
||||||
|
state protoimpl.MessageState
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
|
||||||
Exists bool `protobuf:"varint,1,opt,name=exists,proto3" json:"exists,omitempty"`
|
Exists bool `protobuf:"varint,1,opt,name=exists,proto3" json:"exists,omitempty"`
|
||||||
Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"`
|
Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"`
|
||||||
XXX_NoUnkeyedLiteral struct{} `json:"-"`
|
|
||||||
XXX_unrecognized []byte `json:"-"`
|
|
||||||
XXX_sizecache int32 `json:"-"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *GetReply) Reset() { *m = GetReply{} }
|
func (x *GetReply) Reset() {
|
||||||
func (m *GetReply) String() string { return proto.CompactTextString(m) }
|
*x = GetReply{}
|
||||||
|
if protoimpl.UnsafeEnabled {
|
||||||
|
mi := &file_cache_proto_msgTypes[1]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *GetReply) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
func (*GetReply) ProtoMessage() {}
|
func (*GetReply) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *GetReply) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_cache_proto_msgTypes[1]
|
||||||
|
if protoimpl.UnsafeEnabled && x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use GetReply.ProtoReflect.Descriptor instead.
|
||||||
func (*GetReply) Descriptor() ([]byte, []int) {
|
func (*GetReply) Descriptor() ([]byte, []int) {
|
||||||
return fileDescriptor_5fca3b110c9bbf3a, []int{1}
|
return file_cache_proto_rawDescGZIP(), []int{1}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *GetReply) XXX_Unmarshal(b []byte) error {
|
func (x *GetReply) GetExists() bool {
|
||||||
return xxx_messageInfo_GetReply.Unmarshal(m, b)
|
if x != nil {
|
||||||
}
|
return x.Exists
|
||||||
func (m *GetReply) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
|
|
||||||
return xxx_messageInfo_GetReply.Marshal(b, m, deterministic)
|
|
||||||
}
|
|
||||||
func (m *GetReply) XXX_Merge(src proto.Message) {
|
|
||||||
xxx_messageInfo_GetReply.Merge(m, src)
|
|
||||||
}
|
|
||||||
func (m *GetReply) XXX_Size() int {
|
|
||||||
return xxx_messageInfo_GetReply.Size(m)
|
|
||||||
}
|
|
||||||
func (m *GetReply) XXX_DiscardUnknown() {
|
|
||||||
xxx_messageInfo_GetReply.DiscardUnknown(m)
|
|
||||||
}
|
|
||||||
|
|
||||||
var xxx_messageInfo_GetReply proto.InternalMessageInfo
|
|
||||||
|
|
||||||
func (m *GetReply) GetExists() bool {
|
|
||||||
if m != nil {
|
|
||||||
return m.Exists
|
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *GetReply) GetValue() []byte {
|
func (x *GetReply) GetValue() []byte {
|
||||||
if m != nil {
|
if x != nil {
|
||||||
return m.Value
|
return x.Value
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type SetRequest struct {
|
type SetRequest struct {
|
||||||
|
state protoimpl.MessageState
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
|
||||||
Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"`
|
Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"`
|
||||||
Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"`
|
Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"`
|
||||||
XXX_NoUnkeyedLiteral struct{} `json:"-"`
|
|
||||||
XXX_unrecognized []byte `json:"-"`
|
|
||||||
XXX_sizecache int32 `json:"-"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *SetRequest) Reset() { *m = SetRequest{} }
|
func (x *SetRequest) Reset() {
|
||||||
func (m *SetRequest) String() string { return proto.CompactTextString(m) }
|
*x = SetRequest{}
|
||||||
|
if protoimpl.UnsafeEnabled {
|
||||||
|
mi := &file_cache_proto_msgTypes[2]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *SetRequest) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
func (*SetRequest) ProtoMessage() {}
|
func (*SetRequest) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *SetRequest) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_cache_proto_msgTypes[2]
|
||||||
|
if protoimpl.UnsafeEnabled && x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use SetRequest.ProtoReflect.Descriptor instead.
|
||||||
func (*SetRequest) Descriptor() ([]byte, []int) {
|
func (*SetRequest) Descriptor() ([]byte, []int) {
|
||||||
return fileDescriptor_5fca3b110c9bbf3a, []int{2}
|
return file_cache_proto_rawDescGZIP(), []int{2}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *SetRequest) XXX_Unmarshal(b []byte) error {
|
func (x *SetRequest) GetKey() string {
|
||||||
return xxx_messageInfo_SetRequest.Unmarshal(m, b)
|
if x != nil {
|
||||||
}
|
return x.Key
|
||||||
func (m *SetRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
|
|
||||||
return xxx_messageInfo_SetRequest.Marshal(b, m, deterministic)
|
|
||||||
}
|
|
||||||
func (m *SetRequest) XXX_Merge(src proto.Message) {
|
|
||||||
xxx_messageInfo_SetRequest.Merge(m, src)
|
|
||||||
}
|
|
||||||
func (m *SetRequest) XXX_Size() int {
|
|
||||||
return xxx_messageInfo_SetRequest.Size(m)
|
|
||||||
}
|
|
||||||
func (m *SetRequest) XXX_DiscardUnknown() {
|
|
||||||
xxx_messageInfo_SetRequest.DiscardUnknown(m)
|
|
||||||
}
|
|
||||||
|
|
||||||
var xxx_messageInfo_SetRequest proto.InternalMessageInfo
|
|
||||||
|
|
||||||
func (m *SetRequest) GetKey() string {
|
|
||||||
if m != nil {
|
|
||||||
return m.Key
|
|
||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *SetRequest) GetValue() []byte {
|
func (x *SetRequest) GetValue() []byte {
|
||||||
if m != nil {
|
if x != nil {
|
||||||
return m.Value
|
return x.Value
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type SetReply struct {
|
type SetReply struct {
|
||||||
XXX_NoUnkeyedLiteral struct{} `json:"-"`
|
state protoimpl.MessageState
|
||||||
XXX_unrecognized []byte `json:"-"`
|
sizeCache protoimpl.SizeCache
|
||||||
XXX_sizecache int32 `json:"-"`
|
unknownFields protoimpl.UnknownFields
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *SetReply) Reset() {
|
||||||
|
*x = SetReply{}
|
||||||
|
if protoimpl.UnsafeEnabled {
|
||||||
|
mi := &file_cache_proto_msgTypes[3]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *SetReply) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *SetReply) Reset() { *m = SetReply{} }
|
|
||||||
func (m *SetReply) String() string { return proto.CompactTextString(m) }
|
|
||||||
func (*SetReply) ProtoMessage() {}
|
func (*SetReply) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *SetReply) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_cache_proto_msgTypes[3]
|
||||||
|
if protoimpl.UnsafeEnabled && x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use SetReply.ProtoReflect.Descriptor instead.
|
||||||
func (*SetReply) Descriptor() ([]byte, []int) {
|
func (*SetReply) Descriptor() ([]byte, []int) {
|
||||||
return fileDescriptor_5fca3b110c9bbf3a, []int{3}
|
return file_cache_proto_rawDescGZIP(), []int{3}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *SetReply) XXX_Unmarshal(b []byte) error {
|
var File_cache_proto protoreflect.FileDescriptor
|
||||||
return xxx_messageInfo_SetReply.Unmarshal(m, b)
|
|
||||||
}
|
var file_cache_proto_rawDesc = []byte{
|
||||||
func (m *SetReply) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
|
0x0a, 0x0b, 0x63, 0x61, 0x63, 0x68, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x05, 0x63,
|
||||||
return xxx_messageInfo_SetReply.Marshal(b, m, deterministic)
|
0x61, 0x63, 0x68, 0x65, 0x22, 0x1e, 0x0a, 0x0a, 0x47, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65,
|
||||||
}
|
0x73, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52,
|
||||||
func (m *SetReply) XXX_Merge(src proto.Message) {
|
0x03, 0x6b, 0x65, 0x79, 0x22, 0x38, 0x0a, 0x08, 0x47, 0x65, 0x74, 0x52, 0x65, 0x70, 0x6c, 0x79,
|
||||||
xxx_messageInfo_SetReply.Merge(m, src)
|
0x12, 0x16, 0x0a, 0x06, 0x65, 0x78, 0x69, 0x73, 0x74, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08,
|
||||||
}
|
0x52, 0x06, 0x65, 0x78, 0x69, 0x73, 0x74, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75,
|
||||||
func (m *SetReply) XXX_Size() int {
|
0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x22, 0x34,
|
||||||
return xxx_messageInfo_SetReply.Size(m)
|
0x0a, 0x0a, 0x53, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x10, 0x0a, 0x03,
|
||||||
}
|
0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14,
|
||||||
func (m *SetReply) XXX_DiscardUnknown() {
|
0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x05, 0x76,
|
||||||
xxx_messageInfo_SetReply.DiscardUnknown(m)
|
0x61, 0x6c, 0x75, 0x65, 0x22, 0x0a, 0x0a, 0x08, 0x53, 0x65, 0x74, 0x52, 0x65, 0x70, 0x6c, 0x79,
|
||||||
|
0x32, 0x61, 0x0a, 0x05, 0x43, 0x61, 0x63, 0x68, 0x65, 0x12, 0x2b, 0x0a, 0x03, 0x47, 0x65, 0x74,
|
||||||
|
0x12, 0x11, 0x2e, 0x63, 0x61, 0x63, 0x68, 0x65, 0x2e, 0x47, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75,
|
||||||
|
0x65, 0x73, 0x74, 0x1a, 0x0f, 0x2e, 0x63, 0x61, 0x63, 0x68, 0x65, 0x2e, 0x47, 0x65, 0x74, 0x52,
|
||||||
|
0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x2b, 0x0a, 0x03, 0x53, 0x65, 0x74, 0x12, 0x11, 0x2e,
|
||||||
|
0x63, 0x61, 0x63, 0x68, 0x65, 0x2e, 0x53, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
|
||||||
|
0x1a, 0x0f, 0x2e, 0x63, 0x61, 0x63, 0x68, 0x65, 0x2e, 0x53, 0x65, 0x74, 0x52, 0x65, 0x70, 0x6c,
|
||||||
|
0x79, 0x22, 0x00, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||||
}
|
}
|
||||||
|
|
||||||
var xxx_messageInfo_SetReply proto.InternalMessageInfo
|
var (
|
||||||
|
file_cache_proto_rawDescOnce sync.Once
|
||||||
|
file_cache_proto_rawDescData = file_cache_proto_rawDesc
|
||||||
|
)
|
||||||
|
|
||||||
func init() {
|
func file_cache_proto_rawDescGZIP() []byte {
|
||||||
proto.RegisterType((*GetRequest)(nil), "cache.GetRequest")
|
file_cache_proto_rawDescOnce.Do(func() {
|
||||||
proto.RegisterType((*GetReply)(nil), "cache.GetReply")
|
file_cache_proto_rawDescData = protoimpl.X.CompressGZIP(file_cache_proto_rawDescData)
|
||||||
proto.RegisterType((*SetRequest)(nil), "cache.SetRequest")
|
})
|
||||||
proto.RegisterType((*SetReply)(nil), "cache.SetReply")
|
return file_cache_proto_rawDescData
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
var file_cache_proto_msgTypes = make([]protoimpl.MessageInfo, 4)
|
||||||
proto.RegisterFile("cache.proto", fileDescriptor_5fca3b110c9bbf3a)
|
var file_cache_proto_goTypes = []interface{}{
|
||||||
|
(*GetRequest)(nil), // 0: cache.GetRequest
|
||||||
|
(*GetReply)(nil), // 1: cache.GetReply
|
||||||
|
(*SetRequest)(nil), // 2: cache.SetRequest
|
||||||
|
(*SetReply)(nil), // 3: cache.SetReply
|
||||||
|
}
|
||||||
|
var file_cache_proto_depIdxs = []int32{
|
||||||
|
0, // 0: cache.Cache.Get:input_type -> cache.GetRequest
|
||||||
|
2, // 1: cache.Cache.Set:input_type -> cache.SetRequest
|
||||||
|
1, // 2: cache.Cache.Get:output_type -> cache.GetReply
|
||||||
|
3, // 3: cache.Cache.Set:output_type -> cache.SetReply
|
||||||
|
2, // [2:4] is the sub-list for method output_type
|
||||||
|
0, // [0:2] is the sub-list for method input_type
|
||||||
|
0, // [0:0] is the sub-list for extension type_name
|
||||||
|
0, // [0:0] is the sub-list for extension extendee
|
||||||
|
0, // [0:0] is the sub-list for field type_name
|
||||||
}
|
}
|
||||||
|
|
||||||
var fileDescriptor_5fca3b110c9bbf3a = []byte{
|
func init() { file_cache_proto_init() }
|
||||||
// 176 bytes of a gzipped FileDescriptorProto
|
func file_cache_proto_init() {
|
||||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x4e, 0x4e, 0x4c, 0xce,
|
if File_cache_proto != nil {
|
||||||
0x48, 0xd5, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x05, 0x73, 0x94, 0xe4, 0xb8, 0xb8, 0xdc,
|
return
|
||||||
0x53, 0x4b, 0x82, 0x52, 0x0b, 0x4b, 0x53, 0x8b, 0x4b, 0x84, 0x04, 0xb8, 0x98, 0xb3, 0x53, 0x2b,
|
}
|
||||||
0x25, 0x18, 0x15, 0x18, 0x35, 0x38, 0x83, 0x40, 0x4c, 0x25, 0x0b, 0x2e, 0x0e, 0xb0, 0x7c, 0x41,
|
if !protoimpl.UnsafeEnabled {
|
||||||
0x4e, 0xa5, 0x90, 0x18, 0x17, 0x5b, 0x6a, 0x45, 0x66, 0x71, 0x49, 0x31, 0x58, 0x01, 0x47, 0x10,
|
file_cache_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
|
||||||
0x94, 0x27, 0x24, 0xc2, 0xc5, 0x5a, 0x96, 0x98, 0x53, 0x9a, 0x2a, 0xc1, 0xa4, 0xc0, 0xa8, 0xc1,
|
switch v := v.(*GetRequest); i {
|
||||||
0x13, 0x04, 0xe1, 0x28, 0x99, 0x70, 0x71, 0x05, 0xe3, 0x31, 0x19, 0x87, 0x2e, 0x2e, 0x2e, 0x8e,
|
case 0:
|
||||||
0x60, 0xa8, 0x7d, 0x46, 0x89, 0x5c, 0xac, 0xce, 0x20, 0x47, 0x0a, 0x69, 0x73, 0x31, 0xbb, 0xa7,
|
return &v.state
|
||||||
0x96, 0x08, 0x09, 0xea, 0x41, 0x3c, 0x80, 0x70, 0xb0, 0x14, 0x3f, 0xb2, 0x50, 0x41, 0x4e, 0xa5,
|
case 1:
|
||||||
0x12, 0x03, 0x48, 0x71, 0x30, 0x92, 0xe2, 0x60, 0x4c, 0xc5, 0xc1, 0x70, 0xc5, 0x49, 0x6c, 0xe0,
|
return &v.sizeCache
|
||||||
0xc0, 0x30, 0x06, 0x04, 0x00, 0x00, 0xff, 0xff, 0x0e, 0xef, 0x5f, 0x9e, 0x1b, 0x01, 0x00, 0x00,
|
case 2:
|
||||||
|
return &v.unknownFields
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
file_cache_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
|
||||||
|
switch v := v.(*GetReply); i {
|
||||||
|
case 0:
|
||||||
|
return &v.state
|
||||||
|
case 1:
|
||||||
|
return &v.sizeCache
|
||||||
|
case 2:
|
||||||
|
return &v.unknownFields
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
file_cache_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
|
||||||
|
switch v := v.(*SetRequest); i {
|
||||||
|
case 0:
|
||||||
|
return &v.state
|
||||||
|
case 1:
|
||||||
|
return &v.sizeCache
|
||||||
|
case 2:
|
||||||
|
return &v.unknownFields
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
file_cache_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
|
||||||
|
switch v := v.(*SetReply); i {
|
||||||
|
case 0:
|
||||||
|
return &v.state
|
||||||
|
case 1:
|
||||||
|
return &v.sizeCache
|
||||||
|
case 2:
|
||||||
|
return &v.unknownFields
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
type x struct{}
|
||||||
|
out := protoimpl.TypeBuilder{
|
||||||
|
File: protoimpl.DescBuilder{
|
||||||
|
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||||
|
RawDescriptor: file_cache_proto_rawDesc,
|
||||||
|
NumEnums: 0,
|
||||||
|
NumMessages: 4,
|
||||||
|
NumExtensions: 0,
|
||||||
|
NumServices: 1,
|
||||||
|
},
|
||||||
|
GoTypes: file_cache_proto_goTypes,
|
||||||
|
DependencyIndexes: file_cache_proto_depIdxs,
|
||||||
|
MessageInfos: file_cache_proto_msgTypes,
|
||||||
|
}.Build()
|
||||||
|
File_cache_proto = out.File
|
||||||
|
file_cache_proto_rawDesc = nil
|
||||||
|
file_cache_proto_goTypes = nil
|
||||||
|
file_cache_proto_depIdxs = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reference imports to suppress errors if they are not otherwise used.
|
// Reference imports to suppress errors if they are not otherwise used.
|
||||||
|
@ -266,10 +405,10 @@ type CacheServer interface {
|
||||||
type UnimplementedCacheServer struct {
|
type UnimplementedCacheServer struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*UnimplementedCacheServer) Get(ctx context.Context, req *GetRequest) (*GetReply, error) {
|
func (*UnimplementedCacheServer) Get(context.Context, *GetRequest) (*GetReply, error) {
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method Get not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method Get not implemented")
|
||||||
}
|
}
|
||||||
func (*UnimplementedCacheServer) Set(ctx context.Context, req *SetRequest) (*SetReply, error) {
|
func (*UnimplementedCacheServer) Set(context.Context, *SetRequest) (*SetReply, error) {
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method Set not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method Set not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
17
internal/grpc/cache/client/cache_client.go
vendored
17
internal/grpc/cache/client/cache_client.go
vendored
|
@ -3,6 +3,7 @@ package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/grpc/cache"
|
"github.com/pomerium/pomerium/internal/grpc/cache"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||||
|
@ -10,12 +11,7 @@ import (
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Cacher specifies an interface for remote clients connecting to the cache service.
|
var errKeyNotFound = errors.New("cache/client: key not found")
|
||||||
type Cacher interface {
|
|
||||||
Get(ctx context.Context, key string) (keyExists bool, value []byte, err error)
|
|
||||||
Set(ctx context.Context, key string, value []byte) error
|
|
||||||
Close() error
|
|
||||||
}
|
|
||||||
|
|
||||||
// Client represents a gRPC cache service client.
|
// Client represents a gRPC cache service client.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
|
@ -29,15 +25,18 @@ func New(conn *grpc.ClientConn) (p *Client) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get retrieves a value from the cache service.
|
// Get retrieves a value from the cache service.
|
||||||
func (a *Client) Get(ctx context.Context, key string) (keyExists bool, value []byte, err error) {
|
func (a *Client) Get(ctx context.Context, key string) (value []byte, err error) {
|
||||||
ctx, span := trace.StartSpan(ctx, "grpc.cache.client.Get")
|
ctx, span := trace.StartSpan(ctx, "grpc.cache.client.Get")
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
response, err := a.client.Get(ctx, &cache.GetRequest{Key: key})
|
response, err := a.client.Get(ctx, &cache.GetRequest{Key: key})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return response.GetExists(), response.GetValue(), nil
|
if !response.GetExists() {
|
||||||
|
return nil, errKeyNotFound
|
||||||
|
}
|
||||||
|
return response.GetValue(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set stores a key value pair in the cache service.
|
// Set stores a key value pair in the cache service.
|
||||||
|
|
77
internal/grpc/cache/mock/mock_cacher.go
vendored
Normal file
77
internal/grpc/cache/mock/mock_cacher.go
vendored
Normal file
|
@ -0,0 +1,77 @@
|
||||||
|
// Code generated by MockGen. DO NOT EDIT.
|
||||||
|
// Source: github.com/pomerium/pomerium/internal/grpc/cache (interfaces: Cacher)
|
||||||
|
|
||||||
|
// Package mock_cache is a generated GoMock package.
|
||||||
|
package mock_cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
context "context"
|
||||||
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
reflect "reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockCacher is a mock of Cacher interface
|
||||||
|
type MockCacher struct {
|
||||||
|
ctrl *gomock.Controller
|
||||||
|
recorder *MockCacherMockRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockCacherMockRecorder is the mock recorder for MockCacher
|
||||||
|
type MockCacherMockRecorder struct {
|
||||||
|
mock *MockCacher
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockCacher creates a new mock instance
|
||||||
|
func NewMockCacher(ctrl *gomock.Controller) *MockCacher {
|
||||||
|
mock := &MockCacher{ctrl: ctrl}
|
||||||
|
mock.recorder = &MockCacherMockRecorder{mock}
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT returns an object that allows the caller to indicate expected use
|
||||||
|
func (m *MockCacher) EXPECT() *MockCacherMockRecorder {
|
||||||
|
return m.recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close mocks base method
|
||||||
|
func (m *MockCacher) Close() error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "Close")
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close indicates an expected call of Close
|
||||||
|
func (mr *MockCacherMockRecorder) Close() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockCacher)(nil).Close))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get mocks base method
|
||||||
|
func (m *MockCacher) Get(arg0 context.Context, arg1 string) ([]byte, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "Get", arg0, arg1)
|
||||||
|
ret0, _ := ret[0].([]byte)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get indicates an expected call of Get
|
||||||
|
func (mr *MockCacherMockRecorder) Get(arg0, arg1 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockCacher)(nil).Get), arg0, arg1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set mocks base method
|
||||||
|
func (m *MockCacher) Set(arg0 context.Context, arg1 string, arg2 []byte) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "Set", arg0, arg1, arg2)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set indicates an expected call of Set
|
||||||
|
func (mr *MockCacherMockRecorder) Set(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Set", reflect.TypeOf((*MockCacher)(nil).Set), arg0, arg1, arg2)
|
||||||
|
}
|
22
internal/hashutil/hashutil.go
Normal file
22
internal/hashutil/hashutil.go
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
// Package hashutil provides NON-CRYPTOGRAPHIC utility functions for hashing
|
||||||
|
package hashutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/cespare/xxhash/v2"
|
||||||
|
"github.com/mitchellh/hashstructure"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Hash returns the xxhash value of an arbitrary value or struct. Returns 0
|
||||||
|
// on error. NOT SUITABLE FOR CRYTOGRAPHIC HASHING.
|
||||||
|
//
|
||||||
|
// http://cyan4973.github.io/xxHash/
|
||||||
|
func Hash(v interface{}) uint64 {
|
||||||
|
opts := &hashstructure.HashOptions{
|
||||||
|
Hasher: xxhash.New(),
|
||||||
|
}
|
||||||
|
hash, err := hashstructure.Hash(v, opts)
|
||||||
|
if err != nil {
|
||||||
|
hash = 0
|
||||||
|
}
|
||||||
|
return hash
|
||||||
|
}
|
37
internal/hashutil/hashutil_test.go
Normal file
37
internal/hashutil/hashutil_test.go
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
// Package hashutil provides NON-CRYPTOGRAPHIC utility functions for hashing
|
||||||
|
package hashutil
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestHash(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
v interface{}
|
||||||
|
want uint64
|
||||||
|
}{
|
||||||
|
{"string", "string", 6134271061086542852},
|
||||||
|
{"num", 7, 609900476111905877},
|
||||||
|
{"compound struct", struct {
|
||||||
|
NESCarts []string
|
||||||
|
numberOfCarts int
|
||||||
|
}{
|
||||||
|
[]string{"Battletoads", "Mega Man 1", "Clash at Demonhead"},
|
||||||
|
12,
|
||||||
|
},
|
||||||
|
9061978360207659575},
|
||||||
|
{"compound struct with embedded func (errors!)", struct {
|
||||||
|
AnswerToEverythingFn func() int
|
||||||
|
}{
|
||||||
|
func() int { return 42 },
|
||||||
|
},
|
||||||
|
0},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := Hash(tt.v); got != tt.want {
|
||||||
|
t.Errorf("Hash() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -5,15 +5,13 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// MockProvider provides a mocked implementation of the providers interface.
|
// MockProvider provides a mocked implementation of the providers interface.
|
||||||
type MockProvider struct {
|
type MockProvider struct {
|
||||||
AuthenticateResponse sessions.State
|
AuthenticateResponse oauth2.Token
|
||||||
AuthenticateError error
|
AuthenticateError error
|
||||||
RefreshResponse sessions.State
|
RefreshResponse oauth2.Token
|
||||||
RefreshError error
|
RefreshError error
|
||||||
RevokeError error
|
RevokeError error
|
||||||
GetSignInURLResponse string
|
GetSignInURLResponse string
|
||||||
|
@ -22,12 +20,12 @@ type MockProvider struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Authenticate is a mocked providers function.
|
// Authenticate is a mocked providers function.
|
||||||
func (mp MockProvider) Authenticate(ctx context.Context, code string) (*sessions.State, error) {
|
func (mp MockProvider) Authenticate(context.Context, string, interface{}) (*oauth2.Token, error) {
|
||||||
return &mp.AuthenticateResponse, mp.AuthenticateError
|
return &mp.AuthenticateResponse, mp.AuthenticateError
|
||||||
}
|
}
|
||||||
|
|
||||||
// Refresh is a mocked providers function.
|
// Refresh is a mocked providers function.
|
||||||
func (mp MockProvider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) {
|
func (mp MockProvider) Refresh(context.Context, *oauth2.Token, interface{}) (*oauth2.Token, error) {
|
||||||
return &mp.RefreshResponse, mp.RefreshError
|
return &mp.RefreshResponse, mp.RefreshError
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,6 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/identity/oauth"
|
"github.com/pomerium/pomerium/internal/identity/oauth"
|
||||||
"github.com/pomerium/pomerium/internal/identity/oidc"
|
"github.com/pomerium/pomerium/internal/identity/oidc"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
|
||||||
"github.com/pomerium/pomerium/internal/version"
|
"github.com/pomerium/pomerium/internal/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -77,48 +76,36 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
|
||||||
|
|
||||||
// Authenticate creates an identity session with github from a authorization code, and follows up
|
// Authenticate creates an identity session with github from a authorization code, and follows up
|
||||||
// call to the user and user group endpoint with the
|
// call to the user and user group endpoint with the
|
||||||
func (p *Provider) Authenticate(ctx context.Context, code string) (*sessions.State, error) {
|
func (p *Provider) Authenticate(ctx context.Context, code string, v interface{}) (*oauth2.Token, error) {
|
||||||
resp, err := p.Oauth.Exchange(ctx, code)
|
oauth2Token, err := p.Oauth.Exchange(ctx, code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("github: token exchange failed %v", err)
|
return nil, fmt.Errorf("github: token exchange failed %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s := &sessions.State{
|
err = p.updateSessionState(ctx, oauth2Token, v)
|
||||||
AccessToken: &oauth2.Token{
|
|
||||||
AccessToken: resp.AccessToken,
|
|
||||||
TokenType: resp.TokenType,
|
|
||||||
},
|
|
||||||
AccessTokenID: resp.AccessToken,
|
|
||||||
}
|
|
||||||
|
|
||||||
err = p.updateSessionState(ctx, s)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return s, nil
|
return oauth2Token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateSessionState will get the user information from github and also retrieve the user's team(s)
|
// updateSessionState will get the user information from github and also retrieve the user's team(s)
|
||||||
//
|
//
|
||||||
// https://developer.github.com/v3/users/#get-the-authenticated-user
|
// https://developer.github.com/v3/users/#get-the-authenticated-user
|
||||||
func (p *Provider) updateSessionState(ctx context.Context, s *sessions.State) error {
|
func (p *Provider) updateSessionState(ctx context.Context, t *oauth2.Token, v interface{}) error {
|
||||||
if s == nil || s.AccessToken == nil {
|
|
||||||
return errors.New("github: user session cannot be empty")
|
|
||||||
}
|
|
||||||
accessToken := s.AccessToken.AccessToken
|
|
||||||
|
|
||||||
err := p.userInfo(ctx, accessToken, s)
|
err := p.userInfo(ctx, t, v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("github: could not retrieve user info %w", err)
|
return fmt.Errorf("github: could not retrieve user info %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = p.userEmail(ctx, accessToken, s)
|
err = p.userEmail(ctx, t, v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("github: could not retrieve user email %w", err)
|
return fmt.Errorf("github: could not retrieve user email %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = p.userTeams(ctx, accessToken, s)
|
err = p.userTeams(ctx, t, v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("github: could not retrieve groups %w", err)
|
return fmt.Errorf("github: could not retrieve groups %w", err)
|
||||||
}
|
}
|
||||||
|
@ -127,14 +114,12 @@ func (p *Provider) updateSessionState(ctx context.Context, s *sessions.State) er
|
||||||
}
|
}
|
||||||
|
|
||||||
// Refresh renews a user's session by making a new userInfo request.
|
// Refresh renews a user's session by making a new userInfo request.
|
||||||
func (p *Provider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) {
|
func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v interface{}) (*oauth2.Token, error) {
|
||||||
if s.AccessToken == nil {
|
err := p.updateSessionState(ctx, t, v)
|
||||||
return nil, errors.New("github: missing oauth2 access token")
|
if err != nil {
|
||||||
}
|
|
||||||
if err := p.updateSessionState(ctx, s); err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return s, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// userTeams returns a slice of teams the user belongs by making a request
|
// userTeams returns a slice of teams the user belongs by making a request
|
||||||
|
@ -142,7 +127,7 @@ func (p *Provider) Refresh(ctx context.Context, s *sessions.State) (*sessions.St
|
||||||
//
|
//
|
||||||
// https://developer.github.com/v3/teams/#list-user-teams
|
// https://developer.github.com/v3/teams/#list-user-teams
|
||||||
// https://developer.github.com/v3/auth/
|
// https://developer.github.com/v3/auth/
|
||||||
func (p *Provider) userTeams(ctx context.Context, at string, s *sessions.State) error {
|
func (p *Provider) userTeams(ctx context.Context, t *oauth2.Token, v interface{}) error {
|
||||||
|
|
||||||
var response []struct {
|
var response []struct {
|
||||||
ID json.Number `json:"id"`
|
ID json.Number `json:"id"`
|
||||||
|
@ -154,20 +139,24 @@ func (p *Provider) userTeams(ctx context.Context, at string, s *sessions.State)
|
||||||
Privacy string `json:"privacy,omitempty"`
|
Privacy string `json:"privacy,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
headers := map[string]string{"Authorization": fmt.Sprintf("token %s", at)}
|
headers := map[string]string{"Authorization": fmt.Sprintf("token %s", t.AccessToken)}
|
||||||
teamURL := githubAPIURL + teamPath
|
teamURL := githubAPIURL + teamPath
|
||||||
err := httputil.Client(ctx, http.MethodGet, teamURL, version.UserAgent(), headers, nil, &response)
|
err := httputil.Client(ctx, http.MethodGet, teamURL, version.UserAgent(), headers, nil, &response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Interface("teams", response).Msg("github: user teams")
|
log.Debug().Interface("teams", response).Msg("github: user teams")
|
||||||
s.Groups = nil
|
var out struct {
|
||||||
for _, org := range response {
|
Groups []string `json:"groups"`
|
||||||
s.Groups = append(s.Groups, org.ID.String())
|
|
||||||
}
|
}
|
||||||
|
for _, org := range response {
|
||||||
return nil
|
out.Groups = append(out.Groups, org.ID.String())
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(out)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return json.Unmarshal(b, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
// userEmail returns the primary email of the user by making
|
// userEmail returns the primary email of the user by making
|
||||||
|
@ -175,7 +164,7 @@ func (p *Provider) userTeams(ctx context.Context, at string, s *sessions.State)
|
||||||
//
|
//
|
||||||
// https://developer.github.com/v3/users/emails/#list-email-addresses-for-a-user
|
// https://developer.github.com/v3/users/emails/#list-email-addresses-for-a-user
|
||||||
// https://developer.github.com/v3/auth/
|
// https://developer.github.com/v3/auth/
|
||||||
func (p *Provider) userEmail(ctx context.Context, at string, s *sessions.State) error {
|
func (p *Provider) userEmail(ctx context.Context, t *oauth2.Token, v interface{}) error {
|
||||||
// response represents the github user email
|
// response represents the github user email
|
||||||
// https://developer.github.com/v3/users/emails/#response
|
// https://developer.github.com/v3/users/emails/#response
|
||||||
var response []struct {
|
var response []struct {
|
||||||
|
@ -184,48 +173,67 @@ func (p *Provider) userEmail(ctx context.Context, at string, s *sessions.State)
|
||||||
Primary bool `json:"primary"`
|
Primary bool `json:"primary"`
|
||||||
Visibility string `json:"visibility"`
|
Visibility string `json:"visibility"`
|
||||||
}
|
}
|
||||||
headers := map[string]string{"Authorization": fmt.Sprintf("token %s", at)}
|
headers := map[string]string{"Authorization": fmt.Sprintf("token %s", t.AccessToken)}
|
||||||
emailURL := githubAPIURL + emailPath
|
emailURL := githubAPIURL + emailPath
|
||||||
err := httputil.Client(ctx, http.MethodGet, emailURL, version.UserAgent(), headers, nil, &response)
|
err := httputil.Client(ctx, http.MethodGet, emailURL, version.UserAgent(), headers, nil, &response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
var out struct {
|
||||||
|
Email string `json:"email"`
|
||||||
|
Verified bool `json:"email_verified"`
|
||||||
|
}
|
||||||
log.Debug().Interface("emails", response).Msg("github: user emails")
|
log.Debug().Interface("emails", response).Msg("github: user emails")
|
||||||
for _, email := range response {
|
for _, email := range response {
|
||||||
if email.Primary && email.Verified {
|
if email.Primary && email.Verified {
|
||||||
s.Email = email.Email
|
out.Email = email.Email
|
||||||
s.EmailVerified = true
|
out.Verified = true
|
||||||
return nil
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
b, err := json.Marshal(out)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return json.Unmarshal(b, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Provider) userInfo(ctx context.Context, at string, s *sessions.State) error {
|
func (p *Provider) userInfo(ctx context.Context, t *oauth2.Token, v interface{}) error {
|
||||||
var response struct {
|
var response struct {
|
||||||
ID int `json:"id"`
|
ID int `json:"id"`
|
||||||
Login string `json:"login"`
|
Login string `json:"login"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Email string `json:"email"`
|
|
||||||
AvatarURL string `json:"avatar_url,omitempty"`
|
AvatarURL string `json:"avatar_url,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
headers := map[string]string{
|
headers := map[string]string{
|
||||||
"Authorization": fmt.Sprintf("token %s", at),
|
"Authorization": fmt.Sprintf("token %s", t.AccessToken),
|
||||||
"Accept": "application/vnd.github.v3+json",
|
"Accept": "application/vnd.github.v3+json",
|
||||||
}
|
}
|
||||||
err := httputil.Client(ctx, http.MethodGet, p.userEndpoint, version.UserAgent(), headers, nil, &response)
|
err := httputil.Client(ctx, http.MethodGet, p.userEndpoint, version.UserAgent(), headers, nil, &response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
var out struct {
|
||||||
|
Subject string `json:"sub"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
User string `json:"user"`
|
||||||
|
Picture string `json:"picture,omitempty"`
|
||||||
|
// needs to be set manually
|
||||||
|
Expiry *jwt.NumericDate `json:"exp,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
s.User = response.Login
|
out.User = response.Login
|
||||||
s.Name = response.Name
|
out.Subject = response.Login
|
||||||
s.Picture = response.AvatarURL
|
out.Name = response.Name
|
||||||
|
out.Picture = response.AvatarURL
|
||||||
// set the session expiry
|
// set the session expiry
|
||||||
s.Expiry = jwt.NewNumericDate(time.Now().Add(refreshDeadline))
|
out.Expiry = jwt.NewNumericDate(time.Now().Add(refreshDeadline))
|
||||||
return nil
|
b, err := json.Marshal(out)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return json.Unmarshal(b, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Revoke method will remove all the github grants the user
|
// Revoke method will remove all the github grants the user
|
||||||
|
|
|
@ -5,7 +5,7 @@ package azure
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
@ -16,7 +16,6 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/identity/oauth"
|
"github.com/pomerium/pomerium/internal/identity/oauth"
|
||||||
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
|
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
|
||||||
"github.com/pomerium/pomerium/internal/version"
|
"github.com/pomerium/pomerium/internal/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -60,10 +59,7 @@ func (p *Provider) GetSignInURL(state string) string {
|
||||||
// `Directory.Read.All` is required.
|
// `Directory.Read.All` is required.
|
||||||
// https://docs.microsoft.com/en-us/graph/api/resources/directoryobject?view=graph-rest-1.0
|
// https://docs.microsoft.com/en-us/graph/api/resources/directoryobject?view=graph-rest-1.0
|
||||||
// https://docs.microsoft.com/en-us/graph/api/user-list-memberof?view=graph-rest-1.0
|
// https://docs.microsoft.com/en-us/graph/api/user-list-memberof?view=graph-rest-1.0
|
||||||
func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string, error) {
|
func (p *Provider) UserGroups(ctx context.Context, t *oauth2.Token, v interface{}) error {
|
||||||
if s == nil || s.AccessToken == nil {
|
|
||||||
return nil, errors.New("identity/azure: session cannot be nil")
|
|
||||||
}
|
|
||||||
var response struct {
|
var response struct {
|
||||||
Groups []struct {
|
Groups []struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
|
@ -73,15 +69,23 @@ func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string,
|
||||||
GroupTypes []string `json:"groupTypes,omitempty"`
|
GroupTypes []string `json:"groupTypes,omitempty"`
|
||||||
} `json:"value"`
|
} `json:"value"`
|
||||||
}
|
}
|
||||||
headers := map[string]string{"Authorization": fmt.Sprintf("Bearer %s", s.AccessToken.AccessToken)}
|
headers := map[string]string{"Authorization": fmt.Sprintf("Bearer %s", t.AccessToken)}
|
||||||
err := httputil.Client(ctx, http.MethodGet, defaultGroupURL, version.UserAgent(), headers, nil, &response)
|
err := httputil.Client(ctx, http.MethodGet, defaultGroupURL, version.UserAgent(), headers, nil, &response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug().Interface("response", response).Msg("microsoft: groups")
|
||||||
|
var out struct {
|
||||||
|
Groups []string `json:"groups"`
|
||||||
}
|
}
|
||||||
var groups []string
|
|
||||||
for _, group := range response.Groups {
|
for _, group := range response.Groups {
|
||||||
log.Debug().Str("DisplayName", group.DisplayName).Str("ID", group.ID).Msg("microsoft: group")
|
out.Groups = append(out.Groups, group.ID)
|
||||||
groups = append(groups, group.ID)
|
|
||||||
}
|
}
|
||||||
return groups, nil
|
b, err := json.Marshal(out)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.Unmarshal(b, v)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,12 +1,14 @@
|
||||||
package oidc
|
package oidc
|
||||||
|
|
||||||
import "errors"
|
import (
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
// ErrRevokeNotImplemented error type when Revoke method is not implemented
|
// ErrRevokeNotImplemented is returned when revoke is not implemented
|
||||||
// by an identity provider
|
// by an identity provider.
|
||||||
var ErrRevokeNotImplemented = errors.New("identity/oidc: revoke not implemented")
|
var ErrRevokeNotImplemented = errors.New("identity/oidc: revoke not implemented")
|
||||||
|
|
||||||
// ErrSignoutNotImplemented error type when end session is not implemented
|
// ErrSignoutNotImplemented is returned when end session is not implemented
|
||||||
// by an identity provider
|
// by an identity provider
|
||||||
// https://openid.net/specs/openid-connect-frontchannel-1_0.html#RPInitiated
|
// https://openid.net/specs/openid-connect-frontchannel-1_0.html#RPInitiated
|
||||||
var ErrSignoutNotImplemented = errors.New("identity/oidc: end session not implemented")
|
var ErrSignoutNotImplemented = errors.New("identity/oidc: end session not implemented")
|
||||||
|
@ -14,3 +16,13 @@ var ErrSignoutNotImplemented = errors.New("identity/oidc: end session not implem
|
||||||
// ErrMissingProviderURL is returned when an identity provider requires a provider url
|
// ErrMissingProviderURL is returned when an identity provider requires a provider url
|
||||||
// does not receive one.
|
// does not receive one.
|
||||||
var ErrMissingProviderURL = errors.New("identity/oidc: missing provider url")
|
var ErrMissingProviderURL = errors.New("identity/oidc: missing provider url")
|
||||||
|
|
||||||
|
// ErrMissingIDToken is returned when (usually on refresh) and identity provider
|
||||||
|
// failed to include an id_token in a oauth2 token.
|
||||||
|
var ErrMissingIDToken = errors.New("identity/oidc: missing id_token")
|
||||||
|
|
||||||
|
// ErrMissingRefreshToken is returned if no refresh token was found.
|
||||||
|
var ErrMissingRefreshToken = errors.New("identity/oidc: missing refresh token")
|
||||||
|
|
||||||
|
// ErrMissingAccessToken is returned when no access token was found.
|
||||||
|
var ErrMissingAccessToken = errors.New("identity/oidc: missing access token")
|
||||||
|
|
|
@ -6,7 +6,6 @@ package gitlab
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
@ -15,14 +14,14 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/identity/oauth"
|
"github.com/pomerium/pomerium/internal/identity/oauth"
|
||||||
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
|
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
|
||||||
"github.com/pomerium/pomerium/internal/version"
|
"github.com/pomerium/pomerium/internal/version"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Name identifies the GitLab identity provider
|
// Name identifies the GitLab identity provider.
|
||||||
const Name = "gitlab"
|
const Name = "gitlab"
|
||||||
|
|
||||||
var defaultScopes = []string{oidc.ScopeOpenID, "read_api", "read_user", "profile", "email"}
|
var defaultScopes = []string{oidc.ScopeOpenID, "profile", "email", "api"}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
defaultProviderURL = "https://gitlab.com"
|
defaultProviderURL = "https://gitlab.com"
|
||||||
|
@ -64,11 +63,7 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
|
||||||
//
|
//
|
||||||
// Returns 20 results at a time because the API results are paginated.
|
// Returns 20 results at a time because the API results are paginated.
|
||||||
// https://docs.gitlab.com/ee/api/groups.html#list-groups
|
// https://docs.gitlab.com/ee/api/groups.html#list-groups
|
||||||
func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string, error) {
|
func (p *Provider) UserGroups(ctx context.Context, t *oauth2.Token, v interface{}) error {
|
||||||
if s == nil || s.AccessToken == nil {
|
|
||||||
return nil, errors.New("gitlab: user session cannot be empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
var response []struct {
|
var response []struct {
|
||||||
ID json.Number `json:"id"`
|
ID json.Number `json:"id"`
|
||||||
Name string `json:"name,omitempty"`
|
Name string `json:"name,omitempty"`
|
||||||
|
@ -81,17 +76,22 @@ func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string,
|
||||||
FullName string `json:"full_name,omitempty"`
|
FullName string `json:"full_name,omitempty"`
|
||||||
FullPath string `json:"full_path,omitempty"`
|
FullPath string `json:"full_path,omitempty"`
|
||||||
}
|
}
|
||||||
headers := map[string]string{"Authorization": fmt.Sprintf("Bearer %s", s.AccessToken.AccessToken)}
|
headers := map[string]string{"Authorization": fmt.Sprintf("Bearer %s", t.AccessToken)}
|
||||||
err := httputil.Client(ctx, http.MethodGet, p.userGroupURL, version.UserAgent(), headers, nil, &response)
|
err := httputil.Client(ctx, http.MethodGet, p.userGroupURL, version.UserAgent(), headers, nil, &response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var groups []string
|
|
||||||
log.Debug().Interface("response", response).Msg("gitlab: groups")
|
log.Debug().Interface("response", response).Msg("gitlab: groups")
|
||||||
|
var out struct {
|
||||||
|
Groups []string `json:"groups"`
|
||||||
|
}
|
||||||
for _, group := range response {
|
for _, group := range response {
|
||||||
groups = append(groups, group.ID.String())
|
out.Groups = append(out.Groups, group.ID.String())
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(out)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return groups, nil
|
return json.Unmarshal(b, v)
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
oidc "github.com/coreos/go-oidc"
|
oidc "github.com/coreos/go-oidc"
|
||||||
|
@ -18,7 +19,6 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/identity/oauth"
|
"github.com/pomerium/pomerium/internal/identity/oauth"
|
||||||
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
|
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -54,9 +54,11 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
|
||||||
return nil, fmt.Errorf("%s: failed creating oidc provider: %w", Name, err)
|
return nil, fmt.Errorf("%s: failed creating oidc provider: %w", Name, err)
|
||||||
}
|
}
|
||||||
p.Provider = genericOidc
|
p.Provider = genericOidc
|
||||||
|
if o.ServiceAccount == "" {
|
||||||
|
log.Warn().Msg("google: no service account, will not fetch groups")
|
||||||
|
return &p, nil
|
||||||
|
}
|
||||||
|
|
||||||
// if service account set, configure admin sdk calls
|
|
||||||
if o.ServiceAccount != "" {
|
|
||||||
apiCreds, err := base64.StdEncoding.DecodeString(o.ServiceAccount)
|
apiCreds, err := base64.StdEncoding.DecodeString(o.ServiceAccount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("google: could not decode service account json %w", err)
|
return nil, fmt.Errorf("google: could not decode service account json %w", err)
|
||||||
|
@ -80,9 +82,6 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
|
||||||
return nil, fmt.Errorf("google: failed creating admin service %w", err)
|
return nil, fmt.Errorf("google: failed creating admin service %w", err)
|
||||||
}
|
}
|
||||||
p.UserGroupFn = p.UserGroups
|
p.UserGroupFn = p.UserGroups
|
||||||
} else {
|
|
||||||
log.Warn().Msg("google: no service account, cannot retrieve groups")
|
|
||||||
}
|
|
||||||
|
|
||||||
return &p, nil
|
return &p, nil
|
||||||
}
|
}
|
||||||
|
@ -106,17 +105,34 @@ func (p *Provider) GetSignInURL(state string) string {
|
||||||
// NOTE: groups via Directory API is limited to 1 QPS!
|
// NOTE: groups via Directory API is limited to 1 QPS!
|
||||||
// https://developers.google.com/admin-sdk/directory/v1/reference/groups/list
|
// https://developers.google.com/admin-sdk/directory/v1/reference/groups/list
|
||||||
// https://developers.google.com/admin-sdk/directory/v1/limits
|
// https://developers.google.com/admin-sdk/directory/v1/limits
|
||||||
func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string, error) {
|
func (p *Provider) UserGroups(ctx context.Context, t *oauth2.Token, v interface{}) error {
|
||||||
var groups []string
|
if p.apiClient == nil {
|
||||||
if p.apiClient != nil {
|
return errors.New("google: trying to fetch groups, but no api client")
|
||||||
req := p.apiClient.Groups.List().UserKey(s.Subject).MaxResults(100)
|
}
|
||||||
resp, err := req.Do()
|
s, err := p.GetSubject(v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("google: group api request failed %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
var out struct {
|
||||||
|
Groups []string `json:"groups"`
|
||||||
|
}
|
||||||
|
req := p.apiClient.Groups.List().Context(ctx).UserKey(s)
|
||||||
|
err = req.Pages(ctx, func(resp *admin.Groups) error {
|
||||||
for _, group := range resp.Groups {
|
for _, group := range resp.Groups {
|
||||||
groups = append(groups, group.Email)
|
out.Groups = append(out.Groups, group.Email)
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
return groups, nil
|
_, err = req.Do()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("google: group api request failed %w", err)
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(out)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return json.Unmarshal(b, v)
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@ package oidc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -15,12 +16,11 @@ import (
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/identity/oauth"
|
"github.com/pomerium/pomerium/internal/identity/oauth"
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
"github.com/pomerium/pomerium/internal/version"
|
"github.com/pomerium/pomerium/internal/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Name identifies the generic OpenID Connect provider
|
// Name identifies the generic OpenID Connect provider.
|
||||||
const Name = "oidc"
|
const Name = "oidc"
|
||||||
|
|
||||||
var defaultScopes = []string{go_oidc.ScopeOpenID, "profile", "email", "offline_access"}
|
var defaultScopes = []string{go_oidc.ScopeOpenID, "profile", "email", "offline_access"}
|
||||||
|
@ -37,11 +37,6 @@ type Provider struct {
|
||||||
// client application information and the server's endpoint URLs.
|
// client application information and the server's endpoint URLs.
|
||||||
Oauth *oauth2.Config
|
Oauth *oauth2.Config
|
||||||
|
|
||||||
// UserInfoURL specifies the endpoint responsible for returning claims
|
|
||||||
// about the authenticated End-User.
|
|
||||||
// https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
|
|
||||||
UserInfoURL string `json:"userinfo_endpoint,omitempty"`
|
|
||||||
|
|
||||||
// RevocationURL is the location of the OAuth 2.0 token revocation endpoint.
|
// RevocationURL is the location of the OAuth 2.0 token revocation endpoint.
|
||||||
// https://tools.ietf.org/html/rfc7009
|
// https://tools.ietf.org/html/rfc7009
|
||||||
RevocationURL string `json:"revocation_endpoint,omitempty"`
|
RevocationURL string `json:"revocation_endpoint,omitempty"`
|
||||||
|
@ -53,7 +48,7 @@ type Provider struct {
|
||||||
|
|
||||||
// UserGroupFn is, if set, used to return a slice of group IDs the
|
// UserGroupFn is, if set, used to return a slice of group IDs the
|
||||||
// user is a member of
|
// user is a member of
|
||||||
UserGroupFn func(context.Context, *sessions.State) ([]string, error)
|
UserGroupFn func(context.Context, *oauth2.Token, interface{}) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new instance of a generic OpenID Connect provider.
|
// New creates a new instance of a generic OpenID Connect provider.
|
||||||
|
@ -100,81 +95,89 @@ func (p *Provider) GetSignInURL(state string) string {
|
||||||
|
|
||||||
// Authenticate converts an authorization code returned from the identity
|
// Authenticate converts an authorization code returned from the identity
|
||||||
// provider into a token which is then converted into a user session.
|
// provider into a token which is then converted into a user session.
|
||||||
func (p *Provider) Authenticate(ctx context.Context, code string) (*sessions.State, error) {
|
func (p *Provider) Authenticate(ctx context.Context, code string, v interface{}) (*oauth2.Token, error) {
|
||||||
|
// Exchange converts an authorization code into a token.
|
||||||
oauth2Token, err := p.Oauth.Exchange(ctx, code)
|
oauth2Token, err := p.Oauth.Exchange(ctx, code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("identity/oidc: token exchange failed: %w", err)
|
return nil, fmt.Errorf("identity/oidc: token exchange failed: %w", err)
|
||||||
}
|
}
|
||||||
idToken, err := p.IdentityFromToken(ctx, oauth2Token)
|
|
||||||
|
idToken, err := p.getIDToken(ctx, oauth2Token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("identity/oidc: failed getting id_token: %w", err)
|
return nil, fmt.Errorf("identity/oidc: failed getting id_token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
aud, err := urlutil.ParseAndValidateURL(p.Oauth.RedirectURL)
|
// hydrate `v` using claims inside the returned `id_token`
|
||||||
if err != nil {
|
// https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint
|
||||||
return nil, fmt.Errorf("identity/oidc: bad redirect uri: %w", err)
|
if err := idToken.Claims(v); err != nil {
|
||||||
|
return nil, fmt.Errorf("identity/oidc: couldn't unmarshal extra claims %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s, err := sessions.NewStateFromTokens(idToken, oauth2Token, aud.Hostname())
|
if err := p.updateUserInfo(ctx, oauth2Token, v); err != nil {
|
||||||
if err != nil {
|
return nil, fmt.Errorf("identity/oidc: couldn't update user info %w", err)
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := p.Provider.Claims(&p); err == nil && p.UserInfoURL != "" {
|
return oauth2Token, nil
|
||||||
userInfo, err := p.Provider.UserInfo(ctx, oauth2.StaticTokenSource(oauth2Token))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("identity/oidc: could not retrieve user info %w", err)
|
|
||||||
}
|
|
||||||
if err := userInfo.Claims(&s); err != nil {
|
|
||||||
return nil, fmt.Errorf("identity/oidc: could not parse user claims %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// updateUserInfo calls the OIDC (spec required) UserInfo Endpoint as well as any
|
||||||
|
// groups endpoint (non-spec) to populate the rest of the user's information.
|
||||||
|
//
|
||||||
|
// https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
|
||||||
|
func (p *Provider) updateUserInfo(ctx context.Context, t *oauth2.Token, v interface{}) error {
|
||||||
|
userInfo, err := p.Provider.UserInfo(ctx, oauth2.StaticTokenSource(t))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("identity/oidc: user info endpoint: %w", err)
|
||||||
|
}
|
||||||
|
if err := userInfo.Claims(v); err != nil {
|
||||||
|
return fmt.Errorf("identity/oidc: failed parsing user info endpoint claims: %w", err)
|
||||||
|
}
|
||||||
if p.UserGroupFn != nil {
|
if p.UserGroupFn != nil {
|
||||||
s.Groups, err = p.UserGroupFn(ctx, s)
|
if err := p.UserGroupFn(ctx, t, v); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("identity/oidc: could not retrieve groups: %w", err)
|
||||||
return nil, fmt.Errorf("internal/oidc: could not retrieve groups %w", err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return s, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Refresh renews a user's session using an oidc refresh token without reprompting the user.
|
// Refresh renews a user's session using an oidc refresh token without reprompting the user.
|
||||||
// Group membership is also refreshed.
|
// Group membership is also refreshed.
|
||||||
// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens
|
// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens
|
||||||
func (p *Provider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) {
|
func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v interface{}) (*oauth2.Token, error) {
|
||||||
if s.AccessToken == nil || s.AccessToken.RefreshToken == "" {
|
if t == nil {
|
||||||
return nil, errors.New("internal/oidc: missing refresh token")
|
return nil, ErrMissingAccessToken
|
||||||
|
}
|
||||||
|
if t.RefreshToken == "" {
|
||||||
|
return nil, ErrMissingRefreshToken
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
newToken, err := p.Oauth.TokenSource(ctx, t).Token()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("identity/oidc: refresh failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
t := oauth2.Token{RefreshToken: s.AccessToken.RefreshToken}
|
// Many identity providers _will not_ return `id_token` on refresh
|
||||||
oauthToken, err := p.Oauth.TokenSource(ctx, &t).Token()
|
// https://github.com/FusionAuth/fusionauth-issues/issues/110#issuecomment-481526544
|
||||||
if err != nil {
|
idToken, err := p.getIDToken(ctx, newToken)
|
||||||
return nil, fmt.Errorf("internal/oidc: refresh failed %w", err)
|
if err == nil {
|
||||||
}
|
if err := idToken.Claims(v); err != nil {
|
||||||
idToken, err := p.IdentityFromToken(ctx, oauthToken)
|
return nil, fmt.Errorf("identity/oidc: couldn't unmarshal extra claims %w", err)
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("identity/oidc: failed getting id_token: %w", err)
|
|
||||||
}
|
|
||||||
if err := s.UpdateState(idToken, oauthToken); err != nil {
|
|
||||||
return nil, fmt.Errorf("internal/oidc: state update failed %w", err)
|
|
||||||
}
|
|
||||||
if p.UserGroupFn != nil {
|
|
||||||
s.Groups, err = p.UserGroupFn(ctx, s)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("internal/oidc: could not retrieve groups %w", err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return s, nil
|
if err := p.updateUserInfo(ctx, newToken, v); err != nil {
|
||||||
|
return nil, fmt.Errorf("identity/oidc: couldn't update user info %w", err)
|
||||||
|
}
|
||||||
|
return newToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// IdentityFromToken takes an identity provider issued JWT as input ('id_token')
|
// getIDToken returns the raw jwt payload for `id_token` from the oauth2 token
|
||||||
// and returns a session state. The provided token's audience ('aud') must
|
// returned following oidc code flow
|
||||||
// match Pomerium's client_id.
|
//
|
||||||
func (p *Provider) IdentityFromToken(ctx context.Context, t *oauth2.Token) (*go_oidc.IDToken, error) {
|
// https://openid.net/specs/openid-connect-core-1_0.html#TokenResponse
|
||||||
|
func (p *Provider) getIDToken(ctx context.Context, t *oauth2.Token) (*go_oidc.IDToken, error) {
|
||||||
rawIDToken, ok := t.Extra("id_token").(string)
|
rawIDToken, ok := t.Extra("id_token").(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("internal/oidc: id_token not found")
|
return nil, ErrMissingIDToken
|
||||||
}
|
}
|
||||||
return p.Verifier.Verify(ctx, rawIDToken)
|
return p.Verifier.Verify(ctx, rawIDToken)
|
||||||
}
|
}
|
||||||
|
@ -183,13 +186,16 @@ func (p *Provider) IdentityFromToken(ctx context.Context, t *oauth2.Token) (*go_
|
||||||
// support revocation an error is thrown.
|
// support revocation an error is thrown.
|
||||||
//
|
//
|
||||||
// https://tools.ietf.org/html/rfc7009#section-2.1
|
// https://tools.ietf.org/html/rfc7009#section-2.1
|
||||||
func (p *Provider) Revoke(ctx context.Context, token *oauth2.Token) error {
|
func (p *Provider) Revoke(ctx context.Context, t *oauth2.Token) error {
|
||||||
if p.RevocationURL == "" {
|
if p.RevocationURL == "" {
|
||||||
return ErrRevokeNotImplemented
|
return ErrRevokeNotImplemented
|
||||||
}
|
}
|
||||||
|
if t == nil {
|
||||||
|
return ErrMissingAccessToken
|
||||||
|
}
|
||||||
|
|
||||||
params := url.Values{}
|
params := url.Values{}
|
||||||
params.Add("token", token.AccessToken)
|
params.Add("token", t.AccessToken)
|
||||||
params.Add("token_type_hint", "access_token")
|
params.Add("token_type_hint", "access_token")
|
||||||
// Some providers like okta / onelogin require "client authentication"
|
// Some providers like okta / onelogin require "client authentication"
|
||||||
// https://developer.okta.com/docs/reference/api/oidc/#client-secret
|
// https://developer.okta.com/docs/reference/api/oidc/#client-secret
|
||||||
|
@ -198,7 +204,7 @@ func (p *Provider) Revoke(ctx context.Context, token *oauth2.Token) error {
|
||||||
params.Add("client_secret", p.Oauth.ClientSecret)
|
params.Add("client_secret", p.Oauth.ClientSecret)
|
||||||
|
|
||||||
err := httputil.Client(ctx, http.MethodPost, p.RevocationURL, version.UserAgent(), nil, params, nil)
|
err := httputil.Client(ctx, http.MethodPost, p.RevocationURL, version.UserAgent(), nil, params, nil)
|
||||||
if err != nil && err != httputil.ErrTokenRevoked {
|
if err != nil && errors.Is(err, httputil.ErrTokenRevoked) {
|
||||||
return fmt.Errorf("internal/oidc: unexpected revoke error: %w", err)
|
return fmt.Errorf("internal/oidc: unexpected revoke error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -214,3 +220,20 @@ func (p *Provider) LogOut() (*url.URL, error) {
|
||||||
}
|
}
|
||||||
return urlutil.ParseAndValidateURL(p.EndSessionURL)
|
return urlutil.ParseAndValidateURL(p.EndSessionURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetSubject gets the RFC 7519 Subject claim (`sub`) from a
|
||||||
|
func (p *Provider) GetSubject(v interface{}) (string, error) {
|
||||||
|
b, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
var s struct {
|
||||||
|
Subject string `json:"sub"`
|
||||||
|
}
|
||||||
|
|
||||||
|
err = json.Unmarshal(b, &s)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return s.Subject, nil
|
||||||
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@ package okta
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
@ -13,9 +14,9 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/identity/oauth"
|
"github.com/pomerium/pomerium/internal/identity/oauth"
|
||||||
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
|
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
"github.com/pomerium/pomerium/internal/version"
|
"github.com/pomerium/pomerium/internal/version"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -64,7 +65,11 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
|
||||||
|
|
||||||
// UserGroups fetches the groups of which the user is a member
|
// UserGroups fetches the groups of which the user is a member
|
||||||
// https://developer.okta.com/docs/reference/api/users/#get-user-s-groups
|
// https://developer.okta.com/docs/reference/api/users/#get-user-s-groups
|
||||||
func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string, error) {
|
func (p *Provider) UserGroups(ctx context.Context, t *oauth2.Token, v interface{}) error {
|
||||||
|
s, err := p.GetSubject(v)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
var response []struct {
|
var response []struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
Profile struct {
|
Profile struct {
|
||||||
|
@ -74,15 +79,22 @@ func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string,
|
||||||
}
|
}
|
||||||
|
|
||||||
headers := map[string]string{"Authorization": fmt.Sprintf("SSWS %s", p.serviceAccount)}
|
headers := map[string]string{"Authorization": fmt.Sprintf("SSWS %s", p.serviceAccount)}
|
||||||
uri := fmt.Sprintf("%s/%s/groups", p.userAPI.String(), s.Subject)
|
uri := fmt.Sprintf("%s/%s/groups", p.userAPI.String(), s)
|
||||||
err := httputil.Client(ctx, http.MethodGet, uri, version.UserAgent(), headers, nil, &response)
|
err = httputil.Client(ctx, http.MethodGet, uri, version.UserAgent(), headers, nil, &response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
|
}
|
||||||
|
log.Debug().Interface("response", response).Msg("okta: groups")
|
||||||
|
var out struct {
|
||||||
|
Groups []string `json:"groups"`
|
||||||
}
|
}
|
||||||
var groups []string
|
|
||||||
for _, group := range response {
|
for _, group := range response {
|
||||||
log.Debug().Interface("group", group).Msg("okta: group")
|
out.Groups = append(out.Groups, group.ID)
|
||||||
groups = append(groups, group.ID)
|
|
||||||
}
|
}
|
||||||
return groups, nil
|
b, err := json.Marshal(out)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.Unmarshal(b, v)
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,17 +5,15 @@ package onelogin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
|
||||||
|
|
||||||
oidc "github.com/coreos/go-oidc"
|
oidc "github.com/coreos/go-oidc"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/identity/oauth"
|
"github.com/pomerium/pomerium/internal/identity/oauth"
|
||||||
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
|
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
|
||||||
"github.com/pomerium/pomerium/internal/version"
|
"github.com/pomerium/pomerium/internal/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -55,24 +53,10 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
|
||||||
|
|
||||||
// UserGroups returns a slice of group names a given user is in.
|
// UserGroups returns a slice of group names a given user is in.
|
||||||
// https://developers.onelogin.com/openid-connect/api/user-info
|
// https://developers.onelogin.com/openid-connect/api/user-info
|
||||||
func (p *Provider) UserGroups(ctx context.Context, s *sessions.State) ([]string, error) {
|
func (p *Provider) UserGroups(ctx context.Context, t *oauth2.Token, v interface{}) error {
|
||||||
if s == nil || s.AccessToken == nil {
|
if t == nil {
|
||||||
return nil, errors.New("identity/onelogin: session cannot be nil")
|
return pom_oidc.ErrMissingAccessToken
|
||||||
}
|
}
|
||||||
var response struct {
|
headers := map[string]string{"Authorization": fmt.Sprintf("Bearer %s", t.AccessToken)}
|
||||||
User string `json:"sub"`
|
return httputil.Client(ctx, http.MethodGet, defaultOneloginGroupURL, version.UserAgent(), headers, nil, v)
|
||||||
Email string `json:"email"`
|
|
||||||
PreferredUsername string `json:"preferred_username"`
|
|
||||||
Name string `json:"name"`
|
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
|
||||||
GivenName string `json:"given_name"`
|
|
||||||
FamilyName string `json:"family_name"`
|
|
||||||
Groups []string `json:"groups"`
|
|
||||||
}
|
|
||||||
headers := map[string]string{"Authorization": fmt.Sprintf("Bearer %s", s.AccessToken.AccessToken)}
|
|
||||||
err := httputil.Client(ctx, http.MethodGet, defaultOneloginGroupURL, version.UserAgent(), headers, nil, &response)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return response.Groups, nil
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,25 +17,24 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/identity/oidc/google"
|
"github.com/pomerium/pomerium/internal/identity/oidc/google"
|
||||||
"github.com/pomerium/pomerium/internal/identity/oidc/okta"
|
"github.com/pomerium/pomerium/internal/identity/oidc/okta"
|
||||||
"github.com/pomerium/pomerium/internal/identity/oidc/onelogin"
|
"github.com/pomerium/pomerium/internal/identity/oidc/onelogin"
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// compile time assertions that providers are satisfying the interface
|
// compile time assertions that providers are satisfying the interface
|
||||||
_ Authenticator = &azure.Provider{}
|
_ Authenticator = &azure.Provider{}
|
||||||
_ Authenticator = &gitlab.Provider{}
|
|
||||||
_ Authenticator = &github.Provider{}
|
_ Authenticator = &github.Provider{}
|
||||||
|
_ Authenticator = &gitlab.Provider{}
|
||||||
_ Authenticator = &google.Provider{}
|
_ Authenticator = &google.Provider{}
|
||||||
|
_ Authenticator = &MockProvider{}
|
||||||
_ Authenticator = &oidc.Provider{}
|
_ Authenticator = &oidc.Provider{}
|
||||||
_ Authenticator = &okta.Provider{}
|
_ Authenticator = &okta.Provider{}
|
||||||
_ Authenticator = &onelogin.Provider{}
|
_ Authenticator = &onelogin.Provider{}
|
||||||
_ Authenticator = &MockProvider{}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Authenticator is an interface representing the ability to authenticate with an identity provider.
|
// Authenticator is an interface representing the ability to authenticate with an identity provider.
|
||||||
type Authenticator interface {
|
type Authenticator interface {
|
||||||
Authenticate(context.Context, string) (*sessions.State, error)
|
Authenticate(context.Context, string, interface{}) (*oauth2.Token, error)
|
||||||
Refresh(context.Context, *sessions.State) (*sessions.State, error)
|
Refresh(context.Context, *oauth2.Token, interface{}) (*oauth2.Token, error)
|
||||||
Revoke(context.Context, *oauth2.Token) error
|
Revoke(context.Context, *oauth2.Token) error
|
||||||
GetSignInURL(state string) string
|
GetSignInURL(state string) string
|
||||||
LogOut() (*url.URL, error)
|
LogOut() (*url.URL, error)
|
||||||
|
|
|
@ -37,6 +37,9 @@ type Store struct {
|
||||||
srv *http.Server
|
srv *http.Server
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ErrCacheMiss is returned when the cache misses for a given key.
|
||||||
|
var ErrCacheMiss = errors.New("cache miss")
|
||||||
|
|
||||||
// Options represent autocache options.
|
// Options represent autocache options.
|
||||||
type Options struct {
|
type Options struct {
|
||||||
Addr string
|
Addr string
|
||||||
|
@ -60,7 +63,7 @@ var DefaultOptions = &Options{
|
||||||
GetterFn: func(ctx context.Context, id string, dest groupcache.Sink) error {
|
GetterFn: func(ctx context.Context, id string, dest groupcache.Sink) error {
|
||||||
b := fromContext(ctx)
|
b := fromContext(ctx)
|
||||||
if len(b) == 0 {
|
if len(b) == 0 {
|
||||||
return fmt.Errorf("autocache: empty ctx for id: %s", id)
|
return fmt.Errorf("autocache: id %s : %w", id, ErrCacheMiss)
|
||||||
}
|
}
|
||||||
if err := dest.SetBytes(b); err != nil {
|
if err := dest.SetBytes(b); err != nil {
|
||||||
return fmt.Errorf("autocache: sink error %w", err)
|
return fmt.Errorf("autocache: sink error %w", err)
|
||||||
|
|
95
internal/sessions/cache/cache_store.go
vendored
95
internal/sessions/cache/cache_store.go
vendored
|
@ -1,95 +0,0 @@
|
||||||
// Package cache provides a remote cache based implementation of session store
|
|
||||||
// and loader. See pomerium's cache service for more details.
|
|
||||||
package cache
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/encoding"
|
|
||||||
"github.com/pomerium/pomerium/internal/grpc/cache/client"
|
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ sessions.SessionStore = &Store{}
|
|
||||||
var _ sessions.SessionLoader = &Store{}
|
|
||||||
|
|
||||||
// Store implements the session store interface using a cache service.
|
|
||||||
type Store struct {
|
|
||||||
cache client.Cacher
|
|
||||||
encoder encoding.MarshalUnmarshaler
|
|
||||||
queryParam string
|
|
||||||
wrappedStore sessions.SessionStore
|
|
||||||
}
|
|
||||||
|
|
||||||
// Options represent cache store's available configurations.
|
|
||||||
type Options struct {
|
|
||||||
Cache client.Cacher
|
|
||||||
Encoder encoding.MarshalUnmarshaler
|
|
||||||
QueryParam string
|
|
||||||
WrappedStore sessions.SessionStore
|
|
||||||
}
|
|
||||||
|
|
||||||
var defaultOptions = &Options{
|
|
||||||
QueryParam: "cache_store_key",
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewStore creates a new cache
|
|
||||||
func NewStore(o *Options) *Store {
|
|
||||||
if o.QueryParam == "" {
|
|
||||||
o.QueryParam = defaultOptions.QueryParam
|
|
||||||
}
|
|
||||||
return &Store{
|
|
||||||
cache: o.Cache,
|
|
||||||
encoder: o.Encoder,
|
|
||||||
queryParam: o.QueryParam,
|
|
||||||
wrappedStore: o.WrappedStore,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoadSession looks for a preset query parameter in the request body
|
|
||||||
// representing the key to lookup from the cache.
|
|
||||||
func (s *Store) LoadSession(r *http.Request) (string, error) {
|
|
||||||
// look for our cache's key in the default query param
|
|
||||||
sessionID := r.URL.Query().Get(s.queryParam)
|
|
||||||
if sessionID == "" {
|
|
||||||
return "", sessions.ErrNoSessionFound
|
|
||||||
}
|
|
||||||
exists, val, err := s.cache.Get(r.Context(), sessionID)
|
|
||||||
if err != nil {
|
|
||||||
log.FromRequest(r).Debug().Msg("sessions/cache: miss, trying wrapped loader")
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
if !exists {
|
|
||||||
return "", sessions.ErrNoSessionFound
|
|
||||||
}
|
|
||||||
|
|
||||||
return string(val), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearSession clears the session from the wrapped store.
|
|
||||||
func (s *Store) ClearSession(w http.ResponseWriter, r *http.Request) {
|
|
||||||
s.wrappedStore.ClearSession(w, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SaveSession saves the session to the cache, and wrapped store.
|
|
||||||
func (s *Store) SaveSession(w http.ResponseWriter, r *http.Request, x interface{}) error {
|
|
||||||
err := s.wrappedStore.SaveSession(w, r, x)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("sessions/cache: wrapped store save error %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
state, ok := x.(*sessions.State)
|
|
||||||
if !ok {
|
|
||||||
return errors.New("sessions/cache: cannot cache non state type")
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := s.encoder.Marshal(&state)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("sessions/cache: marshal %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.cache.Set(r.Context(), state.AccessTokenID, data)
|
|
||||||
}
|
|
190
internal/sessions/cache/cache_store_test.go
vendored
190
internal/sessions/cache/cache_store_test.go
vendored
|
@ -1,190 +0,0 @@
|
||||||
package cache
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
|
||||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
|
||||||
"github.com/pomerium/pomerium/internal/encoding"
|
|
||||||
"github.com/pomerium/pomerium/internal/encoding/ecjson"
|
|
||||||
mock_encoder "github.com/pomerium/pomerium/internal/encoding/mock"
|
|
||||||
"github.com/pomerium/pomerium/internal/grpc/cache/client"
|
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
|
||||||
"github.com/pomerium/pomerium/internal/sessions/cookie"
|
|
||||||
"github.com/pomerium/pomerium/internal/sessions/mock"
|
|
||||||
"gopkg.in/square/go-jose.v2/jwt"
|
|
||||||
)
|
|
||||||
|
|
||||||
type mockCache struct {
|
|
||||||
Key string
|
|
||||||
KeyExists bool
|
|
||||||
Value []byte
|
|
||||||
Err error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mc *mockCache) Get(ctx context.Context, key string) (keyExists bool, value []byte, err error) {
|
|
||||||
return mc.KeyExists, mc.Value, mc.Err
|
|
||||||
}
|
|
||||||
func (mc *mockCache) Set(ctx context.Context, key string, value []byte) error {
|
|
||||||
return mc.Err
|
|
||||||
}
|
|
||||||
func (mc *mockCache) Close() error {
|
|
||||||
return mc.Err
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewStore(t *testing.T) {
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
Options *Options
|
|
||||||
State *sessions.State
|
|
||||||
|
|
||||||
wantErr bool
|
|
||||||
wantLoadErr bool
|
|
||||||
wantStatus int
|
|
||||||
}{
|
|
||||||
{"simple good",
|
|
||||||
&Options{
|
|
||||||
Cache: &mockCache{},
|
|
||||||
WrappedStore: &mock.Store{},
|
|
||||||
Encoder: mock_encoder.Encoder{MarshalResponse: []byte("ok")},
|
|
||||||
},
|
|
||||||
&sessions.State{Email: "user@domain.com", User: "user"},
|
|
||||||
false, false,
|
|
||||||
http.StatusOK},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got := NewStore(tt.Options)
|
|
||||||
|
|
||||||
r := httptest.NewRequest("GET", "/", nil)
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
if err := got.SaveSession(w, r, tt.State); (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("NewStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
r = httptest.NewRequest("GET", "/", nil)
|
|
||||||
w = httptest.NewRecorder()
|
|
||||||
|
|
||||||
got.ClearSession(w, r)
|
|
||||||
status := w.Result().StatusCode
|
|
||||||
if diff := cmp.Diff(status, tt.wantStatus); diff != "" {
|
|
||||||
t.Errorf("ClearSession() = %v", diff)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStore_SaveSession(t *testing.T) {
|
|
||||||
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
|
|
||||||
encoder := ecjson.New(cipher)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
cs, err := cookie.NewStore(&cookie.Options{
|
|
||||||
Name: "_pomerium",
|
|
||||||
}, encoder)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
Options *Options
|
|
||||||
|
|
||||||
x interface{}
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"good", &Options{Cache: &mockCache{}, WrappedStore: cs, Encoder: mock_encoder.Encoder{MarshalResponse: []byte("ok")}}, &sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, false},
|
|
||||||
{"encoder error", &Options{Cache: &mockCache{}, WrappedStore: cs, Encoder: mock_encoder.Encoder{MarshalError: errors.New("err")}}, &sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, true},
|
|
||||||
{"good", &Options{Cache: &mockCache{}, WrappedStore: &mock.Store{SaveError: errors.New("err")}}, &sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, true},
|
|
||||||
{"bad type", &Options{Cache: &mockCache{}, WrappedStore: cs, Encoder: mock_encoder.Encoder{MarshalError: errors.New("err")}}, "bad type!", true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
o := tt.Options
|
|
||||||
if o.WrappedStore == nil {
|
|
||||||
o.WrappedStore = cs
|
|
||||||
|
|
||||||
}
|
|
||||||
cacheStore := NewStore(tt.Options)
|
|
||||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
r.Header.Set("Accept", "application/json")
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
if err := cacheStore.SaveSession(w, r, tt.x); (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("Store.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStore_LoadSession(t *testing.T) {
|
|
||||||
key := cryptutil.NewBase64Key()
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
state *sessions.State
|
|
||||||
cache client.Cacher
|
|
||||||
encoder encoding.MarshalUnmarshaler
|
|
||||||
queryParam string
|
|
||||||
wrappedStore sessions.SessionStore
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"good",
|
|
||||||
&sessions.State{AccessTokenID: key, Email: "user@pomerium.io"},
|
|
||||||
&mockCache{KeyExists: true},
|
|
||||||
mock_encoder.Encoder{MarshalResponse: []byte("ok")},
|
|
||||||
defaultOptions.QueryParam,
|
|
||||||
&mock.Store{Session: &sessions.State{AccessTokenID: key, Email: "user@pomerium.io"}},
|
|
||||||
false},
|
|
||||||
{"missing param with key",
|
|
||||||
&sessions.State{AccessTokenID: key, Email: "user@pomerium.io"},
|
|
||||||
&mockCache{KeyExists: true},
|
|
||||||
mock_encoder.Encoder{MarshalResponse: []byte("ok")},
|
|
||||||
"bad_query",
|
|
||||||
&mock.Store{Session: &sessions.State{AccessTokenID: key, Email: "user@pomerium.io"}},
|
|
||||||
true},
|
|
||||||
{"doesn't exist",
|
|
||||||
&sessions.State{AccessTokenID: key, Email: "user@pomerium.io"},
|
|
||||||
&mockCache{KeyExists: false},
|
|
||||||
mock_encoder.Encoder{MarshalResponse: []byte("ok")},
|
|
||||||
defaultOptions.QueryParam,
|
|
||||||
&mock.Store{Session: &sessions.State{AccessTokenID: key, Email: "user@pomerium.io"}},
|
|
||||||
true},
|
|
||||||
{"retrieval error",
|
|
||||||
&sessions.State{AccessTokenID: key, Email: "user@pomerium.io"},
|
|
||||||
&mockCache{Err: errors.New("err")},
|
|
||||||
mock_encoder.Encoder{MarshalResponse: []byte("ok")},
|
|
||||||
defaultOptions.QueryParam,
|
|
||||||
&mock.Store{Session: &sessions.State{AccessTokenID: key, Email: "user@pomerium.io"}},
|
|
||||||
true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
s := &Store{
|
|
||||||
cache: tt.cache,
|
|
||||||
encoder: tt.encoder,
|
|
||||||
queryParam: tt.queryParam,
|
|
||||||
wrappedStore: tt.wrappedStore,
|
|
||||||
}
|
|
||||||
|
|
||||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
q := r.URL.Query()
|
|
||||||
|
|
||||||
q.Set(defaultOptions.QueryParam, tt.state.AccessTokenID)
|
|
||||||
r.URL.RawQuery = q.Encode()
|
|
||||||
r.Header.Set("Accept", "application/json")
|
|
||||||
|
|
||||||
_, err := s.LoadSession(r)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("Store.LoadSession() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,15 +1,11 @@
|
||||||
package sessions
|
package sessions
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cespare/xxhash/v2"
|
"github.com/pomerium/pomerium/internal/hashutil"
|
||||||
oidc "github.com/coreos/go-oidc"
|
|
||||||
"github.com/mitchellh/hashstructure"
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"gopkg.in/square/go-jose.v2/jwt"
|
"gopkg.in/square/go-jose.v2/jwt"
|
||||||
)
|
)
|
||||||
|
@ -27,6 +23,9 @@ type State struct {
|
||||||
NotBefore *jwt.NumericDate `json:"nbf,omitempty"`
|
NotBefore *jwt.NumericDate `json:"nbf,omitempty"`
|
||||||
IssuedAt *jwt.NumericDate `json:"iat,omitempty"`
|
IssuedAt *jwt.NumericDate `json:"iat,omitempty"`
|
||||||
ID string `json:"jti,omitempty"`
|
ID string `json:"jti,omitempty"`
|
||||||
|
// At_hash is an OPTIONAL Access Token hash value
|
||||||
|
// https://ldapwiki.com/wiki/At_hash
|
||||||
|
AccessTokenHash string `json:"at_hash,omitempty"`
|
||||||
|
|
||||||
// core pomerium identity claims ; not standard to RFC 7519
|
// core pomerium identity claims ; not standard to RFC 7519
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
|
@ -48,84 +47,24 @@ type State struct {
|
||||||
// Programmatic whether this state is used for machine-to-machine
|
// Programmatic whether this state is used for machine-to-machine
|
||||||
// programatic access.
|
// programatic access.
|
||||||
Programmatic bool `json:"programatic"`
|
Programmatic bool `json:"programatic"`
|
||||||
|
|
||||||
AccessToken *oauth2.Token `json:"act,omitempty"`
|
|
||||||
AccessTokenID string `json:"ati,omitempty"`
|
|
||||||
|
|
||||||
idToken *oidc.IDToken
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewStateFromTokens returns a session state built from oidc and oauth2
|
|
||||||
// tokens as part of OpenID Connect flow with a new audience appended to the
|
|
||||||
// audience claim.
|
|
||||||
func NewStateFromTokens(idToken *oidc.IDToken, accessToken *oauth2.Token, audience string) (*State, error) {
|
|
||||||
if idToken == nil {
|
|
||||||
return nil, errors.New("sessions: oidc id token missing")
|
|
||||||
}
|
|
||||||
if accessToken == nil {
|
|
||||||
return nil, errors.New("sessions: oauth2 token missing")
|
|
||||||
}
|
|
||||||
s := &State{}
|
|
||||||
if err := idToken.Claims(s); err != nil {
|
|
||||||
return nil, fmt.Errorf("sessions: couldn't unmarshal extra claims %w", err)
|
|
||||||
}
|
|
||||||
s.Audience = []string{audience}
|
|
||||||
s.idToken = idToken
|
|
||||||
s.AccessToken = accessToken
|
|
||||||
s.AccessTokenID = s.accessTokenHash()
|
|
||||||
return s, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateState updates the current state given a new identity (oidc) and authorization
|
|
||||||
// (oauth2) tokens following a oidc refresh. NB, unlike during authentication,
|
|
||||||
// refresh typically provides fewer claims in the token so we want to build from
|
|
||||||
// our previous state.
|
|
||||||
func (s *State) UpdateState(idToken *oidc.IDToken, accessToken *oauth2.Token) error {
|
|
||||||
if idToken == nil {
|
|
||||||
return errors.New("sessions: oidc id token missing")
|
|
||||||
}
|
|
||||||
if accessToken == nil {
|
|
||||||
return errors.New("sessions: oauth2 token missing")
|
|
||||||
}
|
|
||||||
audience := append(s.Audience[:0:0], s.Audience...)
|
|
||||||
s.AccessToken = accessToken
|
|
||||||
if err := idToken.Claims(s); err != nil {
|
|
||||||
return fmt.Errorf("sessions: update state failed %w", err)
|
|
||||||
}
|
|
||||||
s.Audience = audience
|
|
||||||
s.Expiry = jwt.NewNumericDate(accessToken.Expiry)
|
|
||||||
s.AccessTokenID = s.accessTokenHash()
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSession updates issuer, audience, and issuance timestamps but keeps
|
// NewSession updates issuer, audience, and issuance timestamps but keeps
|
||||||
// parent expiry.
|
// parent expiry.
|
||||||
func (s State) NewSession(issuer string, audience []string) *State {
|
func NewSession(s *State, issuer string, audience []string, accessToken *oauth2.Token) State {
|
||||||
s.IssuedAt = jwt.NewNumericDate(timeNow())
|
newState := *s
|
||||||
s.NotBefore = s.IssuedAt
|
newState.IssuedAt = jwt.NewNumericDate(timeNow())
|
||||||
s.Audience = audience
|
newState.NotBefore = newState.IssuedAt
|
||||||
s.Issuer = issuer
|
newState.Audience = audience
|
||||||
return &s
|
newState.Issuer = issuer
|
||||||
}
|
newState.AccessTokenHash = fmt.Sprintf("%x", hashutil.Hash(accessToken))
|
||||||
|
newState.Expiry = jwt.NewNumericDate(accessToken.Expiry)
|
||||||
// RouteSession creates a route session with access tokens stripped.
|
return newState
|
||||||
func (s State) RouteSession() *State {
|
|
||||||
s.AccessToken = nil
|
|
||||||
return &s
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsExpired returns true if the users's session is expired.
|
// IsExpired returns true if the users's session is expired.
|
||||||
func (s *State) IsExpired() bool {
|
func (s *State) IsExpired() bool {
|
||||||
|
return s.Expiry != nil && timeNow().After(s.Expiry.Time())
|
||||||
if s.Expiry != nil && timeNow().After(s.Expiry.Time()) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.AccessToken != nil && timeNow().After(s.AccessToken.Expiry) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Impersonating returns if the request is impersonating.
|
// Impersonating returns if the request is impersonating.
|
||||||
|
@ -133,23 +72,6 @@ func (s *State) Impersonating() bool {
|
||||||
return s.ImpersonateEmail != "" || len(s.ImpersonateGroups) != 0
|
return s.ImpersonateEmail != "" || len(s.ImpersonateGroups) != 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// RequestEmail is the email to make the request as.
|
|
||||||
func (s *State) RequestEmail() string {
|
|
||||||
if s.ImpersonateEmail != "" {
|
|
||||||
return s.ImpersonateEmail
|
|
||||||
}
|
|
||||||
return s.Email
|
|
||||||
}
|
|
||||||
|
|
||||||
// RequestGroups returns the groups of the Groups making the request; uses
|
|
||||||
// impersonating user if set.
|
|
||||||
func (s *State) RequestGroups() string {
|
|
||||||
if len(s.ImpersonateGroups) != 0 {
|
|
||||||
return strings.Join(s.ImpersonateGroups, ",")
|
|
||||||
}
|
|
||||||
return strings.Join(s.Groups, ",")
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetImpersonation sets impersonation user and groups.
|
// SetImpersonation sets impersonation user and groups.
|
||||||
func (s *State) SetImpersonation(email, groups string) {
|
func (s *State) SetImpersonation(email, groups string) {
|
||||||
s.ImpersonateEmail = email
|
s.ImpersonateEmail = email
|
||||||
|
@ -159,34 +81,3 @@ func (s *State) SetImpersonation(email, groups string) {
|
||||||
s.ImpersonateGroups = strings.Split(groups, ",")
|
s.ImpersonateGroups = strings.Split(groups, ",")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *State) accessTokenHash() string {
|
|
||||||
hash, err := hashstructure.Hash(
|
|
||||||
s.AccessToken,
|
|
||||||
&hashstructure.HashOptions{Hasher: xxhash.New()})
|
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("%x", hash)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UnmarshalJSON parses the JSON-encoded session state.
|
|
||||||
// TODO(BDD): remove in v0.8.0
|
|
||||||
func (s *State) UnmarshalJSON(b []byte) error {
|
|
||||||
type Alias State
|
|
||||||
t := &struct {
|
|
||||||
*Alias
|
|
||||||
OldToken *oauth2.Token `json:"access_token,omitempty"` // < v0.5.0
|
|
||||||
}{
|
|
||||||
Alias: (*Alias)(s),
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(b, &t); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if t.AccessToken == nil {
|
|
||||||
t.AccessToken = t.OldToken
|
|
||||||
}
|
|
||||||
*s = *(*State)(t.Alias)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
|
@ -5,8 +5,6 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
|
||||||
"github.com/google/go-cmp/cmp/cmpopts"
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"gopkg.in/square/go-jose.v2/jwt"
|
"gopkg.in/square/go-jose.v2/jwt"
|
||||||
)
|
)
|
||||||
|
@ -38,12 +36,6 @@ func TestState_Impersonating(t *testing.T) {
|
||||||
if got := s.Impersonating(); got != tt.want {
|
if got := s.Impersonating(); got != tt.want {
|
||||||
t.Errorf("State.Impersonating() = %v, want %v", got, tt.want)
|
t.Errorf("State.Impersonating() = %v, want %v", got, tt.want)
|
||||||
}
|
}
|
||||||
if gotEmail := s.RequestEmail(); gotEmail != tt.wantResponseEmail {
|
|
||||||
t.Errorf("State.RequestEmail() = %v, want %v", gotEmail, tt.wantResponseEmail)
|
|
||||||
}
|
|
||||||
if gotGroups := s.RequestGroups(); gotGroups != tt.wantResponseGroups {
|
|
||||||
t.Errorf("State.v() = %v, want %v", gotGroups, tt.wantResponseGroups)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -63,7 +55,6 @@ func TestState_IsExpired(t *testing.T) {
|
||||||
}{
|
}{
|
||||||
{"good", []string{"a", "b", "c"}, jwt.NewNumericDate(time.Now().Add(time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), &oauth2.Token{Expiry: time.Now().Add(time.Hour)}, "a", false},
|
{"good", []string{"a", "b", "c"}, jwt.NewNumericDate(time.Now().Add(time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), &oauth2.Token{Expiry: time.Now().Add(time.Hour)}, "a", false},
|
||||||
{"bad expiry", []string{"a", "b", "c"}, jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), &oauth2.Token{Expiry: time.Now().Add(time.Hour)}, "a", true},
|
{"bad expiry", []string{"a", "b", "c"}, jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), &oauth2.Token{Expiry: time.Now().Add(time.Hour)}, "a", true},
|
||||||
{"bad access token expiry", []string{"a", "b", "c"}, jwt.NewNumericDate(time.Now().Add(time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), &oauth2.Token{Expiry: time.Now().Add(-time.Hour)}, "a", true},
|
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -72,7 +63,6 @@ func TestState_IsExpired(t *testing.T) {
|
||||||
Expiry: tt.Expiry,
|
Expiry: tt.Expiry,
|
||||||
NotBefore: tt.NotBefore,
|
NotBefore: tt.NotBefore,
|
||||||
IssuedAt: tt.IssuedAt,
|
IssuedAt: tt.IssuedAt,
|
||||||
AccessToken: tt.AccessToken,
|
|
||||||
}
|
}
|
||||||
if exp := s.IsExpired(); exp != tt.wantErr {
|
if exp := s.IsExpired(); exp != tt.wantErr {
|
||||||
t.Errorf("State.IsExpired() error = %v, wantErr %v", exp, tt.wantErr)
|
t.Errorf("State.IsExpired() error = %v, wantErr %v", exp, tt.wantErr)
|
||||||
|
@ -80,67 +70,3 @@ func TestState_IsExpired(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestState_RouteSession(t *testing.T) {
|
|
||||||
now := time.Now()
|
|
||||||
timeNow = func() time.Time {
|
|
||||||
return now
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
Issuer string
|
|
||||||
Audience jwt.Audience
|
|
||||||
Expiry *jwt.NumericDate
|
|
||||||
AccessToken *oauth2.Token
|
|
||||||
|
|
||||||
issuer string
|
|
||||||
|
|
||||||
audience []string
|
|
||||||
|
|
||||||
want *State
|
|
||||||
}{
|
|
||||||
{"good", "authenticate.x.y.z", []string{"http.x.y.z"}, jwt.NewNumericDate(timeNow()), nil, "authenticate.a.b.c", []string{"http.a.b.c"}, &State{Issuer: "authenticate.a.b.c", Audience: []string{"http.a.b.c"}, NotBefore: jwt.NewNumericDate(timeNow()), IssuedAt: jwt.NewNumericDate(timeNow()), Expiry: jwt.NewNumericDate(timeNow())}},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
s := State{
|
|
||||||
Issuer: tt.Issuer,
|
|
||||||
Audience: tt.Audience,
|
|
||||||
Expiry: tt.Expiry,
|
|
||||||
AccessToken: tt.AccessToken,
|
|
||||||
}
|
|
||||||
cmpOpts := []cmp.Option{
|
|
||||||
cmpopts.IgnoreUnexported(State{}),
|
|
||||||
}
|
|
||||||
got := s.NewSession(tt.issuer, tt.audience)
|
|
||||||
got = got.RouteSession()
|
|
||||||
if diff := cmp.Diff(got, tt.want, cmpOpts...); diff != "" {
|
|
||||||
t.Errorf("State.RouteSession() = %s", diff)
|
|
||||||
}
|
|
||||||
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestState_accessTokenHash(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
state State
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{"empty access token", State{}, "34c96acdcadb1bbb"},
|
|
||||||
{"no change to access token", State{Subject: "test"}, "34c96acdcadb1bbb"},
|
|
||||||
{"empty oauth2 token", State{AccessToken: &oauth2.Token{}}, "bbd82197d215198f"},
|
|
||||||
{"refresh token a", State{AccessToken: &oauth2.Token{RefreshToken: "a"}}, "76316ac79b301bd6"},
|
|
||||||
{"refresh token b", State{AccessToken: &oauth2.Token{RefreshToken: "b"}}, "fab7cb29e50161f1"},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
s := &tt.state
|
|
||||||
if got := s.accessTokenHash(); got != tt.want {
|
|
||||||
t.Errorf("State.accessTokenHash() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -17,6 +17,7 @@ const (
|
||||||
QueryRefreshToken = "pomerium_refresh_token"
|
QueryRefreshToken = "pomerium_refresh_token"
|
||||||
QueryAccessTokenID = "pomerium_session_access_token_id"
|
QueryAccessTokenID = "pomerium_session_access_token_id"
|
||||||
QueryAudience = "pomerium_session_audience"
|
QueryAudience = "pomerium_session_audience"
|
||||||
|
QueryProgrammaticToken = "pomerium_programmatic_token"
|
||||||
)
|
)
|
||||||
|
|
||||||
// URL signature based query params used for verifying the authenticity of a URL.
|
// URL signature based query params used for verifying the authenticity of a URL.
|
||||||
|
|
Loading…
Add table
Reference in a new issue