From 380d314404428bde09e2c31d9029868ef30a58e3 Mon Sep 17 00:00:00 2001 From: Bobby DeSimone Date: Thu, 29 Aug 2019 22:12:29 -0700 Subject: [PATCH] authenticate: make service http only - Rename SessionState to State to avoid stutter. - Simplified option validation to use a wrapper function for base64 secrets. - Removed authenticates grpc code. - Abstracted logic to load and validate a user's authenticate session. - Removed instances of url.Parse in favor of urlutil's version. - proxy: replaces grpc refresh logic with forced deadline advancement. - internal/sessions: remove rest store; parse authorize header as part of session store. - proxy: refactor request signer - sessions: remove extend deadline (fixes #294) - remove AuthenticateInternalAddr - remove AuthenticateInternalAddrString - omit type tag.Key from declaration of vars TagKey* it will be inferred from the right-hand side - remove compatibility package xerrors - use cloned http.DefaultTransport as base transport --- Makefile | 1 - authenticate/authenticate.go | 40 +- authenticate/grpc.go | 62 --- authenticate/grpc_test.go | 153 ------- authenticate/handlers.go | 153 ++++--- authenticate/handlers_test.go | 75 ++-- cmd/pomerium/main.go | 9 +- cmd/pomerium/main_test.go | 5 +- docs/docs/quick-start/synology.md | 1 - .../examples/docker/nginx.docker-compose.yml | 1 - .../kubernetes/kubernetes-config.yaml | 1 - docs/docs/reference/reference.md | 20 +- go.mod | 1 - go.sum | 2 - internal/config/options.go | 14 - internal/config/options_test.go | 18 +- internal/cryptutil/encrypt.go | 12 + internal/cryptutil/encrypt_test.go | 23 + internal/cryptutil/sign.go | 12 +- internal/cryptutil/sign_test.go | 10 +- internal/httputil/errors.go | 17 +- internal/identity/google.go | 9 +- internal/identity/microsoft.go | 8 +- internal/identity/mock_provider.go | 12 +- internal/identity/okta.go | 2 +- internal/identity/onelogin.go | 2 +- internal/identity/providers.go | 14 +- internal/middleware/middleware.go | 7 +- internal/middleware/reverse_proxy_test.go | 3 +- internal/sessions/cookie_store.go | 165 ++++---- internal/sessions/cookie_store_test.go | 80 ++-- internal/sessions/mock_store.go | 6 +- internal/sessions/rest_store.go | 106 ----- internal/sessions/rest_store_test.go | 135 ------ .../sessions/{session_state.go => state.go} | 53 +-- .../{session_state_test.go => state_test.go} | 96 +++-- internal/sessions/store.go | 26 ++ internal/telemetry/metrics/const.go | 12 +- proto/authenticate/authenticate.pb.go | 399 ------------------ proto/authenticate/authenticate.proto | 26 -- proto/authenticate/convert.go | 49 --- .../mock_authenticate/mock_authenticate.go | 165 -------- proxy/clients/authenticate_client.go | 117 ----- proxy/clients/authenticate_client_test.go | 242 ----------- proxy/clients/authorize_client.go | 8 +- proxy/clients/authorize_client_test.go | 50 ++- proxy/clients/clients.go | 5 +- proxy/clients/mock_clients.go | 33 +- proxy/clients/mock_clients_test.go | 57 --- proxy/handlers.go | 172 +++----- proxy/handlers_test.go | 193 ++++----- proxy/proxy.go | 111 +---- proxy/proxy_test.go | 5 - 53 files changed, 718 insertions(+), 2280 deletions(-) delete mode 100644 authenticate/grpc.go delete mode 100644 authenticate/grpc_test.go delete mode 100644 internal/sessions/rest_store.go delete mode 100644 internal/sessions/rest_store_test.go rename internal/sessions/{session_state.go => state.go} (74%) rename internal/sessions/{session_state_test.go => state_test.go} (75%) create mode 100644 internal/sessions/store.go delete mode 100644 proto/authenticate/authenticate.pb.go delete mode 100644 proto/authenticate/authenticate.proto delete mode 100644 proto/authenticate/convert.go delete mode 100644 proto/authenticate/mock_authenticate/mock_authenticate.go delete mode 100644 proxy/clients/authenticate_client.go delete mode 100644 proxy/clients/authenticate_client_test.go delete mode 100644 proxy/clients/mock_clients_test.go diff --git a/Makefile b/Makefile index 018d1322d..a7885974f 100644 --- a/Makefile +++ b/Makefile @@ -45,7 +45,6 @@ tag: ## Create a new git tag to prepare to build a release .PHONY: build build: ## Builds dynamic executables and/or packages. @echo "==> $@" - @echo Untracked changes? dirty? $(BUILDMETA) files? $(GITUNTRACKEDCHANGES) @CGO_ENABLED=0 GO111MODULE=on go build -tags "$(BUILDTAGS)" ${GO_LDFLAGS} -o $(BINDIR)/$(NAME) ./cmd/"$(NAME)" .PHONY: lint diff --git a/authenticate/authenticate.go b/authenticate/authenticate.go index 76fc777b7..0971b0b38 100644 --- a/authenticate/authenticate.go +++ b/authenticate/authenticate.go @@ -15,36 +15,31 @@ import ( "github.com/pomerium/pomerium/internal/urlutil" ) -// ValidateOptions checks to see if configuration values are valid for the authenticate service. -// The checks do not modify the internal state of the Option structure. Returns -// on first error found. +// ValidateOptions checks that configuration are complete and valid. +// Returns on first error found. func ValidateOptions(o config.Options) error { + if _, err := cryptutil.NewCipherFromBase64(o.SharedKey); err != nil { + return fmt.Errorf("authenticate: 'SHARED_SECRET' invalid: %v", err) + } + if _, err := cryptutil.NewCipherFromBase64(o.CookieSecret); err != nil { + return fmt.Errorf("authenticate: 'COOKIE_SECRET' invalid %v", err) + } if o.AuthenticateURL == nil { - return errors.New("authenticate: missing setting: authenticate-service-url") + return errors.New("authenticate: 'AUTHENTICATE_SERVICE_URL' is required") } if _, err := urlutil.ParseAndValidateURL(o.AuthenticateURL.String()); err != nil { - return fmt.Errorf("authenticate: error parsing authenticate url: %v", err) + return fmt.Errorf("authenticate: couldn't parse 'AUTHENTICATE_SERVICE_URL': %v", err) } if o.ClientID == "" { - return errors.New("authenticate: 'IDP_CLIENT_ID' missing") + return errors.New("authenticate: 'IDP_CLIENT_ID' is required") } if o.ClientSecret == "" { - return errors.New("authenticate: 'IDP_CLIENT_SECRET' missing") - } - if o.SharedKey == "" { - return errors.New("authenticate: 'SHARED_SECRET' missing") - } - decodedCookieSecret, err := base64.StdEncoding.DecodeString(o.CookieSecret) - if err != nil { - return fmt.Errorf("authenticate: 'COOKIE_SECRET' must be base64 encoded: %v", err) - } - if len(decodedCookieSecret) != 32 { - return fmt.Errorf("authenticate: 'COOKIE_SECRET' %s be 32; got %d", o.CookieSecret, len(decodedCookieSecret)) + return errors.New("authenticate: 'IDP_CLIENT_SECRET' is required") } return nil } -// Authenticate validates a user's identity +// Authenticate contains data required to run the authenticate service. type Authenticate struct { SharedKey string RedirectURL *url.URL @@ -52,12 +47,11 @@ type Authenticate struct { templates *template.Template csrfStore sessions.CSRFStore sessionStore sessions.SessionStore - restStore sessions.SessionStore cipher cryptutil.Cipher provider identity.Authenticator } -// New validates and creates a new authenticate service from a set of Options +// New validates and creates a new authenticate service from a set of Options. func New(opts config.Options) (*Authenticate, error) { if err := ValidateOptions(opts); err != nil { return nil, err @@ -95,17 +89,13 @@ func New(opts config.Options) (*Authenticate, error) { if err != nil { return nil, err } - restStore, err := sessions.NewRestStore(&sessions.RestStoreOptions{Cipher: cipher}) - if err != nil { - return nil, err - } + return &Authenticate{ SharedKey: opts.SharedKey, RedirectURL: redirectURL, templates: templates.New(), csrfStore: cookieStore, sessionStore: cookieStore, - restStore: restStore, cipher: cipher, provider: provider, }, nil diff --git a/authenticate/grpc.go b/authenticate/grpc.go deleted file mode 100644 index c0e2e2801..000000000 --- a/authenticate/grpc.go +++ /dev/null @@ -1,62 +0,0 @@ -//go:generate protoc -I ../proto/authenticate --go_out=plugins=grpc:../proto/authenticate ../proto/authenticate/authenticate.proto - -package authenticate // import "github.com/pomerium/pomerium/authenticate" -import ( - "context" - "fmt" - - "github.com/pomerium/pomerium/internal/sessions" - "github.com/pomerium/pomerium/internal/telemetry/trace" - pb "github.com/pomerium/pomerium/proto/authenticate" -) - -// Authenticate takes an encrypted code, and returns the authentication result. -func (p *Authenticate) Authenticate(ctx context.Context, in *pb.AuthenticateRequest) (*pb.Session, error) { - _, span := trace.StartSpan(ctx, "authenticate.grpc.Validate") - defer span.End() - session, err := sessions.UnmarshalSession(in.Code, p.cipher) - if err != nil { - return nil, fmt.Errorf("authenticate/grpc: authenticate %v", err) - } - newSessionProto, err := pb.ProtoFromSession(session) - if err != nil { - return nil, err - } - return newSessionProto, nil -} - -// Validate locally validates a JWT id_token; does NOT do nonce or revokation validation. -// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation -func (p *Authenticate) Validate(ctx context.Context, in *pb.ValidateRequest) (*pb.ValidateReply, error) { - ctx, span := trace.StartSpan(ctx, "authenticate.grpc.Validate") - defer span.End() - - isValid, err := p.provider.Validate(ctx, in.IdToken) - if err != nil { - return &pb.ValidateReply{IsValid: false}, fmt.Errorf("authenticate/grpc: validate %v", err) - } - return &pb.ValidateReply{IsValid: isValid}, nil -} - -// Refresh renews a user's session checks if the session has been revoked using an access token -// without reprompting the user. -func (p *Authenticate) Refresh(ctx context.Context, in *pb.Session) (*pb.Session, error) { - ctx, span := trace.StartSpan(ctx, "authenticate.grpc.Refresh") - defer span.End() - if in == nil { - return nil, fmt.Errorf("authenticate/grpc: session cannot be nil") - } - oldSession, err := pb.SessionFromProto(in) - if err != nil { - return nil, err - } - newSession, err := p.provider.Refresh(ctx, oldSession) - if err != nil { - return nil, fmt.Errorf("authenticate/grpc: refresh failed %v", err) - } - newSessionProto, err := pb.ProtoFromSession(newSession) - if err != nil { - return nil, err - } - return newSessionProto, nil -} diff --git a/authenticate/grpc_test.go b/authenticate/grpc_test.go deleted file mode 100644 index bf90730fb..000000000 --- a/authenticate/grpc_test.go +++ /dev/null @@ -1,153 +0,0 @@ -package authenticate - -import ( - "context" - "errors" - "reflect" - "testing" - "time" - - "github.com/pomerium/pomerium/internal/identity" - - "github.com/golang/protobuf/ptypes" - "github.com/pomerium/pomerium/internal/cryptutil" - "github.com/pomerium/pomerium/internal/sessions" - pb "github.com/pomerium/pomerium/proto/authenticate" -) - -var fixedDate = time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC) - -func TestAuthenticate_Validate(t *testing.T) { - tests := []struct { - name string - idToken string - mp *identity.MockProvider - want bool - wantErr bool - }{ - {"good", "example", &identity.MockProvider{}, false, false}, - {"error", "error", &identity.MockProvider{ValidateError: errors.New("err")}, false, true}, - {"not error", "not error", &identity.MockProvider{ValidateError: nil}, false, false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - p := &Authenticate{provider: tt.mp} - got, err := p.Validate(context.Background(), &pb.ValidateRequest{IdToken: tt.idToken}) - if (err != nil) != tt.wantErr { - t.Errorf("Authenticate.Validate() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got.IsValid, tt.want) { - t.Errorf("Authenticate.Validate() = %v, want %v", got.IsValid, tt.want) - } - }) - } -} - -func TestAuthenticate_Refresh(t *testing.T) { - fixedProtoTime, err := ptypes.TimestampProto(fixedDate) - if err != nil { - t.Fatal("failed to parse timestamp") - } - - tests := []struct { - name string - mock *identity.MockProvider - originalSession *pb.Session - want *pb.Session - wantErr bool - }{ - {"good", - &identity.MockProvider{ - RefreshResponse: &sessions.SessionState{ - AccessToken: "updated", - RefreshDeadline: fixedDate, - }}, - &pb.Session{ - AccessToken: "original", - RefreshDeadline: fixedProtoTime, - }, - &pb.Session{ - AccessToken: "updated", - RefreshDeadline: fixedProtoTime, - }, - false}, - {"test error", &identity.MockProvider{RefreshError: errors.New("hi")}, &pb.Session{RefreshToken: "refresh token", RefreshDeadline: fixedProtoTime}, nil, true}, - {"test catch nil", nil, nil, nil, true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - p := &Authenticate{provider: tt.mock} - - got, err := p.Refresh(context.Background(), tt.originalSession) - if (err != nil) != tt.wantErr { - t.Errorf("Authenticate.Refresh() error = %v, wantErr %v", err, tt.wantErr) - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("Authenticate.Refresh() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestAuthenticate_Authenticate(t *testing.T) { - secret := cryptutil.GenerateKey() - c, err := cryptutil.NewCipher(secret) - if err != nil { - t.Fatalf("expected to be able to create cipher: %v", err) - } - newSecret := cryptutil.GenerateKey() - c2, err := cryptutil.NewCipher(newSecret) - if err != nil { - t.Fatalf("expected to be able to create cipher: %v", err) - } - rt := time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC() - vtProto, err := ptypes.TimestampProto(rt) - if err != nil { - t.Fatal("failed to parse timestamp") - } - - want := &sessions.SessionState{ - AccessToken: "token1234", - RefreshToken: "refresh4321", - RefreshDeadline: rt, - Email: "user@domain.com", - User: "user", - } - - goodReply := &pb.Session{ - AccessToken: "token1234", - RefreshToken: "refresh4321", - RefreshDeadline: vtProto, - Email: "user@domain.com", - User: "user"} - ciphertext, err := sessions.MarshalSession(want, c) - if err != nil { - t.Fatalf("expected to be encode session: %v", err) - } - - tests := []struct { - name string - cipher cryptutil.Cipher - code string - want *pb.Session - wantErr bool - }{ - {"good", c, ciphertext, goodReply, false}, - {"bad cipher", c2, ciphertext, nil, true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - p := &Authenticate{cipher: tt.cipher} - got, err := p.Authenticate(context.Background(), &pb.AuthenticateRequest{Code: tt.code}) - if (err != nil) != tt.wantErr { - t.Errorf("Authenticate.Authenticate() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("Authenticate.Authenticate() = got: \n%vwant:\n%v", got, tt.want) - } - }) - } -} diff --git a/authenticate/handlers.go b/authenticate/handlers.go index dda857922..f5a9b94e0 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -2,12 +2,13 @@ package authenticate // import "github.com/pomerium/pomerium/authenticate" import ( "encoding/base64" + "encoding/json" + "errors" "fmt" "net/http" "net/url" "strings" - - "golang.org/x/xerrors" + "time" "github.com/pomerium/pomerium/internal/cryptutil" "github.com/pomerium/pomerium/internal/httputil" @@ -18,6 +19,7 @@ import ( ) // CSPHeaders are the content security headers added to the service's handlers +// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Security-Policy/script-src var CSPHeaders = map[string]string{ "Content-Security-Policy": "default-src 'none'; style-src 'self'" + " 'sha256-z9MsgkMbQjRSLxzAfN55jB3a9pP0PQ4OHFH8b4iDP6s=' " + @@ -27,22 +29,24 @@ var CSPHeaders = map[string]string{ "Referrer-Policy": "Same-origin", } -// Handler returns the authenticate service's HTTP request multiplexer, and routes. +// Handler returns the authenticate service's HTTP multiplexer, and routes. func (a *Authenticate) Handler() http.Handler { // validation middleware chain c := middleware.NewChain() c = c.Append(middleware.SetHeaders(CSPHeaders)) - validate := c.Append(middleware.ValidateSignature(a.SharedKey)) - validate = validate.Append(middleware.ValidateRedirectURI(a.RedirectURL)) mux := http.NewServeMux() mux.Handle("/robots.txt", c.ThenFunc(a.RobotsTxt)) - // Identity Provider (IdP) callback endpoints and callbacks - mux.Handle("/start", c.ThenFunc(a.OAuthStart)) + // Identity Provider (IdP) endpoints + mux.Handle("/oauth2", c.ThenFunc(a.OAuthStart)) mux.Handle("/oauth2/callback", c.ThenFunc(a.OAuthCallback)) - // authenticate-server endpoints - mux.Handle("/sign_in", validate.ThenFunc(a.SignIn)) - mux.Handle("/sign_out", validate.ThenFunc(a.SignOut)) // POST - // programmatic authentication endpoints + // Proxy service endpoints + validationMiddlewares := c.Append( + middleware.ValidateSignature(a.SharedKey), + middleware.ValidateRedirectURI(a.RedirectURL), + ) + mux.Handle("/sign_in", validationMiddlewares.ThenFunc(a.SignIn)) + mux.Handle("/sign_out", validationMiddlewares.ThenFunc(a.SignOut)) // POST + // Direct user access endpoints mux.Handle("/api/v1/token", c.ThenFunc(a.ExchangeToken)) return mux } @@ -55,43 +59,46 @@ func (a *Authenticate) RobotsTxt(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, "User-agent: *\nDisallow: /") } -func (a *Authenticate) authenticate(w http.ResponseWriter, r *http.Request, session *sessions.SessionState) error { - if session.RefreshPeriodExpired() { - session, err := a.provider.Refresh(r.Context(), session) - if err != nil { - return xerrors.Errorf("session refresh failed : %w", err) - } - if err = a.sessionStore.SaveSession(w, r, session); err != nil { - return xerrors.Errorf("failed saving refreshed session : %w", err) - } - } else { - valid, err := a.provider.Validate(r.Context(), session.IDToken) - if err != nil || !valid { - return xerrors.Errorf("session valid: %v : %w", valid, err) - } +func (a *Authenticate) loadExisting(w http.ResponseWriter, r *http.Request) (*sessions.State, error) { + session, err := a.sessionStore.LoadSession(r) + if err != nil { + return nil, err } - return nil + err = session.Valid() + if err == nil { + return session, nil + } else if !errors.Is(err, sessions.ErrExpired) { + return nil, fmt.Errorf("authenticate: non-refreshable error: %w", err) + } else { + return a.refresh(w, r, session) + } +} + +func (a *Authenticate) refresh(w http.ResponseWriter, r *http.Request, s *sessions.State) (*sessions.State, error) { + newSession, err := a.provider.Refresh(r.Context(), s) + if err != nil { + return nil, fmt.Errorf("authenticate: refresh failed: %w", err) + } + if err := a.sessionStore.SaveSession(w, r, newSession); err != nil { + return nil, fmt.Errorf("authenticate: refresh save failed: %w", err) + } + return newSession, nil + } // SignIn handles to authenticating a user. func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) { - session, err := a.sessionStore.LoadSession(r) + session, err := a.loadExisting(w, r) if err != nil { - log.FromRequest(r).Debug().Err(err).Msg("no session loaded, restart auth") + log.FromRequest(r).Debug().Err(err).Msg("authenticate: need new session") a.sessionStore.ClearSession(w, r) a.OAuthStart(w, r) return } - // if a session already exists, authenticate it - if err := a.authenticate(w, r, session); err != nil { - httputil.ErrorResponse(w, r, err) - return - } if err := r.ParseForm(); err != nil { httputil.ErrorResponse(w, r, err) return } - state := r.Form.Get("state") if state == "" { httputil.ErrorResponse(w, r, httputil.Error("sign in state empty", http.StatusBadRequest, nil)) @@ -100,21 +107,20 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) { redirectURL, err := urlutil.ParseAndValidateURL(r.Form.Get("redirect_uri")) if err != nil { - httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri parameter passed", http.StatusBadRequest, err)) + httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err)) return } // encrypt session state as json blob encrypted, err := sessions.MarshalSession(session, a.cipher) if err != nil { - httputil.ErrorResponse(w, r, httputil.Error("couldn't marshall session", http.StatusInternalServerError, err)) + httputil.ErrorResponse(w, r, httputil.Error("couldn't marshal session", http.StatusInternalServerError, err)) return } - http.Redirect(w, r, getAuthCodeRedirectURL(redirectURL, state, encrypted), http.StatusFound) } func getAuthCodeRedirectURL(redirectURL *url.URL, state, authCode string) string { - // error handled by go's mux stack + // ParseQuery err handled by go's mux stack params, _ := url.ParseQuery(redirectURL.RawQuery) params.Set("code", authCode) params.Set("state", state) @@ -122,8 +128,8 @@ func getAuthCodeRedirectURL(redirectURL *url.URL, state, authCode string) string return redirectURL.String() } -// SignOut signs the user out by trying to revoke the user's remote identity session along with -// the associated local session state. Handles both GET and POST. +// SignOut signs the user out and attempts to revoke the user's identity session +// Handles both GET and POST. func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) { if err := r.ParseForm(); err != nil { httputil.ErrorResponse(w, r, err) @@ -156,7 +162,6 @@ func (a *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) { // OIDC : 3.1.2.1. Authentication Request nonce := fmt.Sprintf("%x", cryptutil.GenerateKey()) a.csrfStore.SetCSRF(w, r, nonce) - // Redirection URI to which the response will be sent. This URI MUST exactly // match one of the Redirection URI values for the Client pre-registered at // at your identity provider @@ -173,7 +178,6 @@ func (a *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) { httputil.ErrorResponse(w, r, httputil.Error("invalid signature", http.StatusBadRequest, nil)) return } - // State is the opaque value used to maintain state between the request and // the callback; contains both the nonce and redirect URI state := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%v:%v", nonce, authRedirectURL.String()))) @@ -183,74 +187,69 @@ func (a *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, signInURL, http.StatusFound) } -// OAuthCallback handles the callback from the identity provider. Displays an error page if there -// was an error. If successful, the user is redirected back to the proxy-service. +// OAuthCallback handles the callback from the identity provider. // https://openid.net/specs/openid-connect-core-1_0.html#AuthResponse func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) { redirect, err := a.getOAuthCallback(w, r) if err != nil { - httputil.ErrorResponse(w, r, xerrors.Errorf("oauth callback : %w", err)) + httputil.ErrorResponse(w, r, fmt.Errorf("oauth callback : %w", err)) return } // redirect back to the proxy-service via sign_in - http.Redirect(w, r, redirect, http.StatusFound) + http.Redirect(w, r, redirect.String(), http.StatusFound) } -func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (string, error) { +func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (*url.URL, error) { if err := r.ParseForm(); err != nil { - return "", httputil.Error("invalid signature", http.StatusBadRequest, err) + return nil, httputil.Error("invalid signature", http.StatusBadRequest, err) } // OIDC : 3.1.2.6. Authentication Error Response // https://openid.net/specs/openid-connect-core-1_0-final.html#AuthError - if errorString := r.Form.Get("error"); errorString != "" { - return "", httputil.Error("provider returned an error", http.StatusBadRequest, fmt.Errorf("provider returned error: %v", errorString)) + if idpError := r.Form.Get("error"); idpError != "" { + return nil, httputil.Error("provider returned an error", http.StatusBadRequest, fmt.Errorf("provider error: %v", idpError)) } - // OIDC : 3.1.2.5. Successful Authentication Response - // https://openid.net/specs/openid-connect-core-1_0.html#AuthResponse code := r.Form.Get("code") if code == "" { - return "", httputil.Error("provider didn't reply with code", http.StatusBadRequest, nil) + return nil, httputil.Error("provider didn't reply with code", http.StatusBadRequest, nil) } // validate the returned code with the identity provider session, err := a.provider.Authenticate(r.Context(), code) if err != nil { - return "", xerrors.Errorf("error redeeming authenticate code: %w", err) + return nil, fmt.Errorf("error redeeming authenticate code: %w", err) } - // Opaque value used to maintain state between the request and the callback. // OIDC : 3.1.2.5. Successful Authentication Response - // https://openid.net/specs/openid-connect-core-1_0.html#AuthResponse + // Opaque value used to maintain state between the request and the callback. bytes, err := base64.URLEncoding.DecodeString(r.Form.Get("state")) if err != nil { - return "", xerrors.Errorf("failed decoding state: %w", err) + return nil, fmt.Errorf("failed decoding state: %w", err) } s := strings.SplitN(string(bytes), ":", 2) if len(s) != 2 { - return "", xerrors.Errorf("invalid state size: %v", len(s)) + return nil, fmt.Errorf("invalid state size: %d", len(s)) } - // state contains both our csrf nonce and the redirect uri + // state contains the csrf nonce and redirect uri nonce := s[0] redirect := s[1] c, err := a.csrfStore.GetCSRF(r) defer a.csrfStore.ClearCSRF(w, r) if err != nil || c.Value != nonce { - return "", xerrors.Errorf("csrf failure: %w", err) - + return nil, fmt.Errorf("csrf failure: %w", err) } redirectURL, err := urlutil.ParseAndValidateURL(redirect) if err != nil { - return "", httputil.Error(fmt.Sprintf("invalid redirect uri %s", redirect), http.StatusBadRequest, err) + return nil, httputil.Error(fmt.Sprintf("invalid redirect uri %s", redirect), http.StatusBadRequest, err) } // sanity check, we are redirecting back to the same subdomain right? if !middleware.SameDomain(redirectURL, a.RedirectURL) { - return "", httputil.Error(fmt.Sprintf("invalid redirect domain %v, %v", redirectURL, a.RedirectURL), http.StatusBadRequest, nil) + return nil, httputil.Error(fmt.Sprintf("invalid redirect domain %v, %v", redirectURL, a.RedirectURL), http.StatusBadRequest, nil) } if err := a.sessionStore.SaveSession(w, r, session); err != nil { - return "", xerrors.Errorf("failed saving new session: %w", err) + return nil, fmt.Errorf("failed saving new session: %w", err) } - return redirect, nil + return redirectURL, nil } // ExchangeToken takes an identity provider issued JWT as input ('id_token) @@ -263,16 +262,32 @@ func (a *Authenticate) ExchangeToken(w http.ResponseWriter, r *http.Request) { } code := r.Form.Get("id_token") if code == "" { - httputil.ErrorResponse(w, r, httputil.Error("provider missing id token", http.StatusBadRequest, nil)) + httputil.ErrorResponse(w, r, httputil.Error("missing id token", http.StatusBadRequest, nil)) return } session, err := a.provider.IDTokenToSession(r.Context(), code) if err != nil { - httputil.ErrorResponse(w, r, httputil.Error("could not exchange identity for session", http.StatusInternalServerError, err)) + httputil.ErrorResponse(w, r, err) return } - if err := a.restStore.SaveSession(w, r, session); err != nil { - httputil.ErrorResponse(w, r, httputil.Error("failed returning new session", http.StatusInternalServerError, err)) + encToken, err := sessions.MarshalSession(session, a.cipher) + if err != nil { + httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err)) return } + restSession := struct { + Token string + Expiry time.Time `json:",omitempty"` + }{ + Token: encToken, + Expiry: session.RefreshDeadline, + } + + jsonBytes, err := json.Marshal(restSession) + if err != nil { + httputil.ErrorResponse(w, r, err) + return + } + w.Header().Set("Content-Type", "application/json") + w.Write(jsonBytes) } diff --git a/authenticate/handlers_test.go b/authenticate/handlers_test.go index ae549b381..a273e1aa0 100644 --- a/authenticate/handlers_test.go +++ b/authenticate/handlers_test.go @@ -68,22 +68,25 @@ func TestAuthenticate_SignIn(t *testing.T) { state string redirectURI string session sessions.SessionStore + restStore sessions.SessionStore provider identity.MockProvider cipher cryptutil.Cipher wantCode int }{ - {"good", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusFound}, - {"session not valid", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: false}, &cryptutil.MockCipher{}, http.StatusInternalServerError}, - {"session refresh error", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshError: errors.New("error")}, &cryptutil.MockCipher{}, http.StatusInternalServerError}, - {"session save after refresh error", "state=example", "https://some.example", &sessions.MockSessionStore{SaveError: errors.New("error"), Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusInternalServerError}, - {"no cookie found trying to load", "state=example", "https://some.example", &sessions.MockSessionStore{LoadError: http.ErrNoCookie, Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest}, - {"unexpected error trying to load session", "state=example", "https://some.example", &sessions.MockSessionStore{LoadError: errors.New("error"), Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest}, - {"malformed form", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusInternalServerError}, - {"empty state", "state=", "https://some.example", &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest}, - {"malformed redirect uri", "state=example", "https://accounts.google.^", &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest}, + {"good", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusFound}, + {"session not valid", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: false}, &cryptutil.MockCipher{}, http.StatusFound}, + {"session expired good refresh", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshResponse: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &cryptutil.MockCipher{}, http.StatusFound}, + {"session expired bad refresh", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshError: errors.New("error")}, &cryptutil.MockCipher{}, http.StatusBadRequest}, // mocking hmac is meh + {"session expired bad refresh save", "state=example", "https://some.example", &sessions.MockSessionStore{SaveError: errors.New("ruh roh"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshResponse: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &cryptutil.MockCipher{}, http.StatusBadRequest}, + + // {"no cookie found trying to load", "state=example", "https://some.example", &sessions.MockSessionStore{LoadError: http.ErrNoCookie, Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusInternalServerError}, + {"unexpected error trying to load session", "state=example", "https://some.example", &sessions.MockSessionStore{LoadError: errors.New("error"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest}, + {"malformed form", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusInternalServerError}, + {"empty state", "state=", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest}, + {"malformed redirect uri", "state=example", "https://accounts.google.^", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest}, // actually caught by go's handler, but we should keep the test. - {"bad redirect uri query", "state=nonce", "%gh&%ij", &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusInternalServerError}, - {"marshal session failure", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{MarshalError: errors.New("error")}, http.StatusInternalServerError}, + {"bad redirect uri query", "state=nonce", "%gh&%ij", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusInternalServerError}, + {"marshal session failure", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{MarshalError: errors.New("error")}, http.StatusInternalServerError}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -178,10 +181,10 @@ func TestAuthenticate_SignOut(t *testing.T) { wantCode int wantBody string }{ - {"good post", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusFound, ""}, - {"failed revoke", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusBadRequest, "could not revoke"}, - {"malformed form", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusInternalServerError, ""}, - {"load session error", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{LoadError: errors.New("hi"), Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusFound, ""}, + {"good post", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusFound, ""}, + {"failed revoke", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusBadRequest, "could not revoke"}, + {"malformed form", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusInternalServerError, ""}, + {"load session error", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{LoadError: errors.New("hi"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusFound, ""}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -288,19 +291,19 @@ func TestAuthenticate_OAuthCallback(t *testing.T) { want string wantCode int }{ - {"good", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "https://corp.pomerium.io", http.StatusFound}, - {"get csrf error", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", GetError: errors.New("error"), Cookie: &http.Cookie{Value: "not nonce"}}, "", http.StatusInternalServerError}, - {"csrf nonce error", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "not nonce"}}, "", http.StatusInternalServerError}, + {"good", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "https://corp.pomerium.io", http.StatusFound}, + {"get csrf error", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", GetError: errors.New("error"), Cookie: &http.Cookie{Value: "not nonce"}}, "", http.StatusInternalServerError}, + {"csrf nonce error", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "not nonce"}}, "", http.StatusInternalServerError}, {"failed authenticate", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateError: errors.New("error")}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusInternalServerError}, - {"failed save session", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{SaveError: errors.New("error")}, identity.MockProvider{AuthenticateResponse: sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusInternalServerError}, - {"provider returned error", http.MethodGet, "idp error", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest}, - {"empty code", http.MethodGet, "", "", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest}, - {"invalid state string", http.MethodGet, "", "code", "nonce:https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusInternalServerError}, - {"malformed state", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusInternalServerError}, - {"invalid redirect uri", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest}, - {"malformed form", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest}, - {"bad redirect uri", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:http://^^^")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "https://corp.pomerium.io", http.StatusBadRequest}, - {"different domains", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:http://some.example.notpomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "https://corp.pomerium.io", http.StatusBadRequest}, + {"failed save session", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{SaveError: errors.New("error")}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusInternalServerError}, + {"provider returned error", http.MethodGet, "idp error", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest}, + {"empty code", http.MethodGet, "", "", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest}, + {"invalid state string", http.MethodGet, "", "code", "nonce:https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusInternalServerError}, + {"malformed state", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusInternalServerError}, + {"invalid redirect uri", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest}, + {"malformed form", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest}, + {"bad redirect uri", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:http://^^^")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "https://corp.pomerium.io", http.StatusBadRequest}, + {"different domains", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:http://some.example.notpomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "https://corp.pomerium.io", http.StatusBadRequest}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -336,7 +339,6 @@ func TestAuthenticate_OAuthCallback(t *testing.T) { } func TestAuthenticate_ExchangeToken(t *testing.T) { - cipher := &cryptutil.MockCipher{} tests := []struct { name string method string @@ -346,18 +348,18 @@ func TestAuthenticate_ExchangeToken(t *testing.T) { provider identity.MockProvider want string }{ - {"good", http.MethodPost, "token", &sessions.RestStore{Cipher: cipher}, cipher, identity.MockProvider{IDTokenToSessionResponse: sessions.SessionState{IDToken: "ok"}}, ""}, - {"could not exchange identity for session", http.MethodPost, "token", &sessions.RestStore{Cipher: cipher}, cipher, identity.MockProvider{IDTokenToSessionError: errors.New("error")}, "could not exchange identity for session"}, - {"missing token", http.MethodPost, "", &sessions.RestStore{Cipher: cipher}, cipher, identity.MockProvider{IDTokenToSessionResponse: sessions.SessionState{IDToken: "ok"}}, "missing id token"}, - {"save error", http.MethodPost, "token", &sessions.MockSessionStore{SaveError: errors.New("error")}, cipher, identity.MockProvider{IDTokenToSessionResponse: sessions.SessionState{IDToken: "ok"}}, "failed returning new session"}, - {"malformed form", http.MethodPost, "token", &sessions.RestStore{Cipher: cipher}, cipher, identity.MockProvider{IDTokenToSessionResponse: sessions.SessionState{IDToken: "ok"}}, ""}, + {"good", http.MethodPost, "token", &sessions.MockSessionStore{}, &cryptutil.MockCipher{}, identity.MockProvider{IDTokenToSessionResponse: sessions.State{IDToken: "ok"}}, ""}, + {"could not exchange identity for session", http.MethodPost, "token", &sessions.MockSessionStore{}, &cryptutil.MockCipher{}, identity.MockProvider{IDTokenToSessionError: errors.New("error")}, ""}, + {"missing token", http.MethodPost, "", &sessions.MockSessionStore{}, &cryptutil.MockCipher{}, identity.MockProvider{IDTokenToSessionResponse: sessions.State{IDToken: "ok"}}, "missing id token"}, + {"malformed form", http.MethodPost, "token", &sessions.MockSessionStore{}, &cryptutil.MockCipher{}, identity.MockProvider{IDTokenToSessionResponse: sessions.State{IDToken: "ok"}}, ""}, + {"can't marshal token", http.MethodPost, "token", &sessions.MockSessionStore{}, &cryptutil.MockCipher{MarshalError: errors.New("can't marshal token")}, identity.MockProvider{IDTokenToSessionResponse: sessions.State{IDToken: "ok"}}, "can't marshal token"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &Authenticate{ - restStore: tt.restStore, - cipher: tt.cipher, - provider: tt.provider, + cipher: tt.cipher, + provider: tt.provider, + sessionStore: tt.restStore, } form := url.Values{} if tt.idToken != "" { @@ -370,6 +372,7 @@ func TestAuthenticate_ExchangeToken(t *testing.T) { } r := httptest.NewRequest(tt.method, "/", strings.NewReader(rawForm)) r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + r.Header.Set("Accept", "application/json") w := httptest.NewRecorder() diff --git a/cmd/pomerium/main.go b/cmd/pomerium/main.go index 67f222c42..540e12b2e 100644 --- a/cmd/pomerium/main.go +++ b/cmd/pomerium/main.go @@ -21,7 +21,6 @@ import ( "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/version" - pbAuthenticate "github.com/pomerium/pomerium/proto/authenticate" pbAuthorize "github.com/pomerium/pomerium/proto/authorize" "github.com/pomerium/pomerium/proxy" ) @@ -47,7 +46,7 @@ func main() { mux := http.NewServeMux() grpcServer := setupGRPCServer(opt) - _, err = newAuthenticateService(*opt, mux, grpcServer) + _, err = newAuthenticateService(*opt, mux) if err != nil { log.Fatal().Err(err).Msg("cmd/pomerium: authenticate") } @@ -62,7 +61,6 @@ func main() { log.Fatal().Err(err).Msg("cmd/pomerium: proxy") } if proxy != nil { - defer proxy.AuthenticateClient.Close() defer proxy.AuthorizeClient.Close() } @@ -82,7 +80,7 @@ func main() { os.Exit(0) } -func newAuthenticateService(opt config.Options, mux *http.ServeMux, rpc *grpc.Server) (*authenticate.Authenticate, error) { +func newAuthenticateService(opt config.Options, mux *http.ServeMux) (*authenticate.Authenticate, error) { if !config.IsAuthenticate(opt.Services) { return nil, nil } @@ -90,7 +88,6 @@ func newAuthenticateService(opt config.Options, mux *http.ServeMux, rpc *grpc.Se if err != nil { return nil, err } - pbAuthenticate.RegisterAuthenticatorServer(rpc, service) mux.Handle(urlutil.StripPort(opt.AuthenticateURL.Host)+"/", service.Handler()) return service, nil } @@ -164,7 +161,7 @@ func configToServerOptions(opt *config.Options) *httputil.ServerOptions { func setupMetrics(opt *config.Options) { if opt.MetricsAddr != "" { if handler, err := metrics.PrometheusHandler(); err != nil { - log.Error().Err(err).Msg("cmd/pomerium: couldn't start metrics server") + log.Error().Err(err).Msg("cmd/pomerium: metrics failed to start") } else { metrics.SetBuildInfo(opt.Services) metrics.RegisterInfoMetrics() diff --git a/cmd/pomerium/main_test.go b/cmd/pomerium/main_test.go index 4ffebf4c6..1d24b1cf5 100644 --- a/cmd/pomerium/main_test.go +++ b/cmd/pomerium/main_test.go @@ -21,9 +21,6 @@ import ( ) func Test_newAuthenticateService(t *testing.T) { - grpcAuth := middleware.NewSharedSecretCred("test") - grpcOpts := []grpc.ServerOption{grpc.UnaryInterceptor(grpcAuth.ValidateRequest)} - grpcServer := grpc.NewServer(grpcOpts...) mux := http.NewServeMux() tests := []struct { @@ -56,7 +53,7 @@ func Test_newAuthenticateService(t *testing.T) { testOptsField.Set(reflect.ValueOf(tt).FieldByName("Value")) } - _, err = newAuthenticateService(*testOpts, mux, grpcServer) + _, err = newAuthenticateService(*testOpts, mux) if (err != nil) != tt.wantErr { t.Errorf("newAuthenticateService() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/docs/docs/quick-start/synology.md b/docs/docs/quick-start/synology.md index a2a68af76..271f3d7b1 100644 --- a/docs/docs/quick-start/synology.md +++ b/docs/docs/quick-start/synology.md @@ -176,7 +176,6 @@ Go to **Environment** tab. | SHARED_SECRET | output of `head -c32 /dev/urandom | base64` | | AUTHORIZE_SERVICE_URL | `https://localhost` | | AUTHENTICATE_SERVICE_URL | `https://authenticate.int.nas.example` | -| AUTHENTICATE_INTERNAL_URL | `https://localhost` | For a detailed explanation, and additional options, please refer to the [configuration variable docs]. Also note, though not covered in this guide, settings can be made via a mounted configuration file. diff --git a/docs/docs/reference/examples/docker/nginx.docker-compose.yml b/docs/docs/reference/examples/docker/nginx.docker-compose.yml index bbeadf072..c45c24647 100644 --- a/docs/docs/reference/examples/docker/nginx.docker-compose.yml +++ b/docs/docs/reference/examples/docker/nginx.docker-compose.yml @@ -48,7 +48,6 @@ services: - SERVICES=proxy # IMPORTANT! If you are running pomerium behind another ingress (loadbalancer/firewall/etc) # you must tell pomerium proxy how to communicate using an internal hostname for RPC - - AUTHENTICATE_INTERNAL_URL=https://pomerium-authenticate - AUTHORIZE_SERVICE_URL=https://pomerium-authorize # When communicating internally, rPC is going to get a name conflict expecting an external # facing certificate name (i.e. authenticate-service.local vs *.corp.example.com). diff --git a/docs/docs/reference/examples/kubernetes/kubernetes-config.yaml b/docs/docs/reference/examples/kubernetes/kubernetes-config.yaml index 91f88ee8f..73ae5428a 100644 --- a/docs/docs/reference/examples/kubernetes/kubernetes-config.yaml +++ b/docs/docs/reference/examples/kubernetes/kubernetes-config.yaml @@ -1,6 +1,5 @@ # Main configuration flags : https://www.pomerium.io/reference/ authenticate_service_url: https://authenticate.corp.beyondperimeter.com -authenticate_internal_url: https://pomerium-authenticate-service.default.svc.cluster.local authorize_service_url: https://pomerium-authorize-service.default.svc.cluster.local override_certificate_name: "*.corp.beyondperimeter.com" diff --git a/docs/docs/reference/reference.md b/docs/docs/reference/reference.md index 597cb678f..a53a803b1 100644 --- a/docs/docs/reference/reference.md +++ b/docs/docs/reference/reference.md @@ -146,7 +146,7 @@ Timeouts set the global server timeouts. For route-specific timeouts, see [polic ## GRPC Options -These settings control upstream connections to the Authorize and Authenticate services. +These settings control upstream connections to the Authorize service. ### GRPC Client Timeout @@ -228,8 +228,8 @@ Each unit work is called a Span in a trace. Spans include metadata about the wor | Config Key | Description | Required | | :--------------- | :---------------------------------------------------------------- | -------- | -| tracing_provider | The name of the tracing provider. (e.g. jaeger) | ✅ | -| tracing_debug | Will disable [sampling](https://opencensus.io/tracing/sampling/). | ❌ | +| tracing_provider | The name of the tracing provider. (e.g. jaeger) | ✅ | +| tracing_debug | Will disable [sampling](https://opencensus.io/tracing/sampling/). | ❌ | ### Jaeger @@ -243,8 +243,8 @@ Each unit work is called a Span in a trace. Spans include metadata about the wor | Config Key | Description | Required | | :-------------------------------- | :------------------------------------------ | -------- | -| tracing_jaeger_collector_endpoint | Url to the Jaeger HTTP Thrift collector. | ✅ | -| tracing_jaeger_agent_endpoint | Send spans to jaeger-agent at this address. | ✅ | +| tracing_jaeger_collector_endpoint | Url to the Jaeger HTTP Thrift collector. | ✅ | +| tracing_jaeger_agent_endpoint | Send spans to jaeger-agent at this address. | ✅ | #### Example @@ -464,16 +464,6 @@ Signing key is the base64 encoded key used to sign outbound requests. For more i Authenticate Service URL is the externally accessible URL for the authenticate service. -## Authenticate Internal Service URL - -- Environmental Variable: `AUTHENTICATE_INTERNAL_URL` -- Config File Key: `authenticate_internal_url` -- Type: `URL` -- Optional -- Example: `https://pomerium-authenticate-service.default.svc.cluster.local` - -Authenticate Internal Service URL is the internally routed dns name of the authenticate service. This setting is typically used with load balancers that do not gRPC, thus allowing you to specify an internally accessible name. - ## Authorize Service URL - Environmental Variable: `AUTHORIZE_SERVICE_URL` diff --git a/go.mod b/go.mod index 9e46f6584..0dcdc8412 100644 --- a/go.mod +++ b/go.mod @@ -26,7 +26,6 @@ require ( golang.org/x/net v0.0.0-20190611141213-3f473d35a33a golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45 golang.org/x/sys v0.0.0-20190610200419-93c9922d18ae // indirect - golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 google.golang.org/api v0.6.0 google.golang.org/appengine v1.6.1 // indirect google.golang.org/genproto v0.0.0-20190611190212-a7e196e89fd3 // indirect diff --git a/go.sum b/go.sum index 117ebe2c0..e06dfd01c 100644 --- a/go.sum +++ b/go.sum @@ -257,8 +257,6 @@ golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBn golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/api v0.3.1/go.mod h1:6wY9I6uQWHQ8EM57III9mq/AjF+i8G65rmVagqKMtkk= google.golang.org/api v0.3.2/go.mod h1:6wY9I6uQWHQ8EM57III9mq/AjF+i8G65rmVagqKMtkk= google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= diff --git a/internal/config/options.go b/internal/config/options.go index 3d8f64dae..cfb4c5e96 100644 --- a/internal/config/options.go +++ b/internal/config/options.go @@ -97,13 +97,6 @@ type Options struct { // (sudo) access including the ability to impersonate other users' access Administrators []string `mapstructure:"administrators"` - // AuthenticateInternalAddr is used override the routable destination of - // authenticate service's GRPC endpoint. - // NOTE: As many load balancers do not support externally routed gRPC so - // this may be an internal location. - AuthenticateInternalAddrString string `mapstructure:"authenticate_internal_url"` - AuthenticateInternalAddr *url.URL - // AuthorizeURL is the routable destination of the authorize service's // gRPC endpoint. NOTE: As many load balancers do not support // externally routed gRPC so this may be an internal location. @@ -246,13 +239,6 @@ func (o *Options) Validate() error { o.AuthorizeURL = u } - if o.AuthenticateInternalAddrString != "" { - u, err := urlutil.ParseAndValidateURL(o.AuthenticateInternalAddrString) - if err != nil { - return fmt.Errorf("bad authenticate-internal-addr %s : %v", o.AuthenticateInternalAddrString, err) - } - o.AuthenticateInternalAddr = u - } if o.PolicyFile != "" { return errors.New("policy file setting is deprecated") } diff --git a/internal/config/options_test.go b/internal/config/options_test.go index 1890ee688..d8eddc98e 100644 --- a/internal/config/options_test.go +++ b/internal/config/options_test.go @@ -337,7 +337,7 @@ func TestNewOptions(t *testing.T) { func TestOptionsFromViper(t *testing.T) { opts := []cmp.Option{ - cmpopts.IgnoreFields(Options{}, "AuthenticateInternalAddr", "DefaultUpstreamTimeout", "CookieRefresh", "CookieExpire", "Services", "Addr", "RefreshCooldown", "LogLevel", "KeyFile", "CertFile", "SharedKey", "ReadTimeout", "ReadHeaderTimeout", "IdleTimeout", "GRPCClientTimeout", "GRPCClientDNSRoundRobin"), + cmpopts.IgnoreFields(Options{}, "DefaultUpstreamTimeout", "CookieRefresh", "CookieExpire", "Services", "Addr", "RefreshCooldown", "LogLevel", "KeyFile", "CertFile", "SharedKey", "ReadTimeout", "ReadHeaderTimeout", "IdleTimeout", "GRPCClientTimeout", "GRPCClientDNSRoundRobin"), cmpopts.IgnoreFields(Policy{}, "Source", "Destination"), } @@ -361,21 +361,6 @@ func TestOptionsFromViper(t *testing.T) { "X-XSS-Protection": "1; mode=block", }}, false}, - {"good with authenticate internal url", - []byte(`{"authenticate_internal_url": "https://internal.example","policy":[{"from": "https://from.example","to":"https://to.example"}]}`), - &Options{ - AuthenticateInternalAddrString: "https://internal.example", - Policies: []Policy{{From: "https://from.example", To: "https://to.example"}}, - CookieName: "_pomerium", - CookieSecure: true, - CookieHTTPOnly: true, - Headers: map[string]string{ - "Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload", - "X-Content-Type-Options": "nosniff", - "X-Frame-Options": "SAMEORIGIN", - "X-XSS-Protection": "1; mode=block", - }}, - false}, {"good disable header", []byte(`{"headers": {"disable":"true"},"policy":[{"from": "https://from.example","to":"https://to.example"}]}`), &Options{ @@ -385,7 +370,6 @@ func TestOptionsFromViper(t *testing.T) { CookieHTTPOnly: true, Headers: map[string]string{}}, false}, - {"bad authenticate internal url", []byte(`{"authenticate_internal_url": "internal.example","policy":[{"from": "https://from.example","to":"https://to.example"}]}`), nil, true}, {"bad url", []byte(`{"policy":[{"from": "https://","to":"https://to.example"}]}`), nil, true}, {"bad policy", []byte(`{"policy":[{"allow_public_unauthenticated_access": "dog","to":"https://to.example"}]}`), nil, true}, diff --git a/internal/cryptutil/encrypt.go b/internal/cryptutil/encrypt.go index b974be7d7..82d79647f 100644 --- a/internal/cryptutil/encrypt.go +++ b/internal/cryptutil/encrypt.go @@ -67,6 +67,18 @@ func NewCipher(secret []byte) (*XChaCha20Cipher, error) { }, nil } +// NewCipherFromBase64 takes a base64 encoded secret key and returns a new XChacha20poly1305 cipher. +func NewCipherFromBase64(s string) (*XChaCha20Cipher, error) { + decoded, err := base64.StdEncoding.DecodeString(s) + if err != nil { + return nil, fmt.Errorf("cryptutil: invalid base64: %v", err) + } + if len(decoded) != 32 { + return nil, fmt.Errorf("cryptutil: got %d bytes but want 32", len(decoded)) + } + return NewCipher(decoded) +} + // GenerateNonce generates a random nonce. // Panics if source of randomness fails. func (c *XChaCha20Cipher) GenerateNonce() []byte { diff --git a/internal/cryptutil/encrypt_test.go b/internal/cryptutil/encrypt_test.go index d46316600..eddd45a37 100644 --- a/internal/cryptutil/encrypt_test.go +++ b/internal/cryptutil/encrypt_test.go @@ -259,3 +259,26 @@ func TestNewCipher(t *testing.T) { }) } } + +func TestNewCipherFromBase64(t *testing.T) { + + tests := []struct { + name string + s string + wantErr bool + }{ + {"simple 32 byte key", base64.StdEncoding.EncodeToString(GenerateKey()), false}, + {"key too short", base64.StdEncoding.EncodeToString([]byte("what is entropy")), true}, + {"key too long", GenerateRandomString(33), true}, + {"bad base 64", string(GenerateKey()), true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewCipherFromBase64(tt.s) + if (err != nil) != tt.wantErr { + t.Errorf("NewCipherFromBase64() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} diff --git a/internal/cryptutil/sign.go b/internal/cryptutil/sign.go index aa7ce6ac3..657ad3aff 100644 --- a/internal/cryptutil/sign.go +++ b/internal/cryptutil/sign.go @@ -1,5 +1,6 @@ package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil" import ( + "encoding/base64" "fmt" "sync" "time" @@ -48,15 +49,20 @@ type ES256Signer struct { NotBefore jwt.NumericDate `json:"nbf,omitempty"` } -// NewES256Signer creates an Elliptic Curve, NIST P-256 (aka secp256r1 aka prime256v1) JWT signer. +// NewES256Signer creates a NIST P-256 (aka secp256r1 aka prime256v1) JWT signer +// from a base64 encoded private key. // // RSA is not supported due to performance considerations of needing to sign each request. // Go's P-256 is constant-time and SHA-256 is faster on 64-bit machines and immune // to length extension attacks. // See also: // - https://cloud.google.com/iot/docs/how-tos/credentials/keys -func NewES256Signer(privKey []byte, audience string) (*ES256Signer, error) { - key, err := DecodePrivateKey(privKey) +func NewES256Signer(privKey, audience string) (*ES256Signer, error) { + decodedSigningKey, err := base64.StdEncoding.DecodeString(privKey) + if err != nil { + return nil, err + } + key, err := DecodePrivateKey(decodedSigningKey) if err != nil { return nil, fmt.Errorf("cryptutil: parsing key failed %v", err) } diff --git a/internal/cryptutil/sign_test.go b/internal/cryptutil/sign_test.go index ff31afe9e..0b2f3f2bd 100644 --- a/internal/cryptutil/sign_test.go +++ b/internal/cryptutil/sign_test.go @@ -1,11 +1,12 @@ package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil" import ( + "encoding/base64" "testing" ) func TestES256Signer(t *testing.T) { - signer, err := NewES256Signer([]byte(pemECPrivateKeyP256), "destination-url") + signer, err := NewES256Signer(base64.StdEncoding.EncodeToString([]byte(pemECPrivateKeyP256)), "destination-url") if err != nil { t.Fatal(err) } @@ -25,12 +26,13 @@ func TestNewES256Signer(t *testing.T) { t.Parallel() tests := []struct { name string - privKey []byte + privKey string audience string wantErr bool }{ - {"working example", []byte(pemECPrivateKeyP256), "some-domain.com", false}, - {"bad private key", []byte(garbagePEM), "some-domain.com", true}, + {"working example", base64.StdEncoding.EncodeToString([]byte(pemECPrivateKeyP256)), "some-domain.com", false}, + {"bad private key", base64.StdEncoding.EncodeToString([]byte(garbagePEM)), "some-domain.com", true}, + {"bad base64 key", garbagePEM, "some-domain.com", true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/internal/httputil/errors.go b/internal/httputil/errors.go index e05e6932c..27c17f414 100644 --- a/internal/httputil/errors.go +++ b/internal/httputil/errors.go @@ -2,20 +2,18 @@ package httputil // import "github.com/pomerium/pomerium/internal/httputil" import ( "encoding/json" + "errors" "fmt" "io" "net/http" - "golang.org/x/xerrors" - "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/templates" ) // Error formats creates a HTTP error with code, user friendly (and safe) error -// message. If nil or empty: -// HTTP status code defaults to 500. -// Message defaults to the text of the status code. +// message. If nil or empty, HTTP status code defaults to 500 and message +// defaults to the text of the status code. func Error(message string, code int, err error) error { if code == 0 { code = http.StatusInternalServerError @@ -45,7 +43,9 @@ func (e *httpError) Error() string { func (e *httpError) Unwrap() error { return e.Err } // Timeout reports whether this error represents a user debuggable error. -func (e *httpError) Debugable() bool { return e.Code == http.StatusUnauthorized } +func (e *httpError) Debugable() bool { + return e.Code == http.StatusUnauthorized || e.Code == http.StatusForbidden +} // ErrorResponse renders an error page given an error. If the error is a // http error from this package, a user friendly message is set, http status code, @@ -57,11 +57,12 @@ func ErrorResponse(rw http.ResponseWriter, r *http.Request, e error) { var requestID string var httpError *httpError // if this is an HTTPError, we can add some additional useful information - if xerrors.As(e, &httpError) { + if errors.As(e, &httpError) { canDebug = httpError.Debugable() statusCode = httpError.Code errorString = httpError.Message } + log.FromRequest(r).Error().Err(e).Str("http-message", errorString).Int("http-code", statusCode).Msg("http-error") if id, ok := log.IDFromRequest(r); ok { @@ -71,7 +72,7 @@ func ErrorResponse(rw http.ResponseWriter, r *http.Request, e error) { var response struct { Error string `json:"error"` } - response.Error = e.Error() + response.Error = errorString writeJSONResponse(rw, statusCode, response) } else { rw.WriteHeader(statusCode) diff --git a/internal/identity/google.go b/internal/identity/google.go index c133e8dd9..faabb8b01 100644 --- a/internal/identity/google.go +++ b/internal/identity/google.go @@ -129,8 +129,7 @@ func (p *GoogleProvider) GetSignInURL(state string) string { // Authenticate creates an identity session with google from a authorization code, and follows up // call to the admin/group api to check what groups the user is in. -func (p *GoogleProvider) Authenticate(ctx context.Context, code string) (*sessions.SessionState, error) { - // convert authorization code into a token +func (p *GoogleProvider) Authenticate(ctx context.Context, code string) (*sessions.State, error) { oauth2Token, err := p.oauth.Exchange(ctx, code) if err != nil { return nil, fmt.Errorf("identity/google: token exchange failed %v", err) @@ -153,7 +152,7 @@ func (p *GoogleProvider) Authenticate(ctx context.Context, code string) (*sessio // Refresh renews a user's session using an oidc refresh token withoutreprompting the user. // Group membership is also refreshed. // https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens -func (p *GoogleProvider) Refresh(ctx context.Context, s *sessions.SessionState) (*sessions.SessionState, error) { +func (p *GoogleProvider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) { if s.RefreshToken == "" { return nil, errors.New("identity: missing refresh token") } @@ -180,7 +179,7 @@ func (p *GoogleProvider) Refresh(ctx context.Context, s *sessions.SessionState) // IDTokenToSession takes an identity provider issued JWT as input ('id_token') // and returns a session state. The provided token's audience ('aud') must // match Pomerium's client_id. -func (p *GoogleProvider) IDTokenToSession(ctx context.Context, rawIDToken string) (*sessions.SessionState, error) { +func (p *GoogleProvider) IDTokenToSession(ctx context.Context, rawIDToken string) (*sessions.State, error) { idToken, err := p.verifier.Verify(ctx, rawIDToken) if err != nil { return nil, fmt.Errorf("identity/google: could not verify id_token %v", err) @@ -200,7 +199,7 @@ func (p *GoogleProvider) IDTokenToSession(ctx context.Context, rawIDToken string return nil, fmt.Errorf("identity/google: could not retrieve groups %v", err) } - return &sessions.SessionState{ + return &sessions.State{ IDToken: rawIDToken, RefreshDeadline: idToken.Expiry.Truncate(time.Second), Email: claims.Email, diff --git a/internal/identity/microsoft.go b/internal/identity/microsoft.go index e67b4d4a9..aa5cf22e0 100644 --- a/internal/identity/microsoft.go +++ b/internal/identity/microsoft.go @@ -74,7 +74,7 @@ func NewAzureProvider(p *Provider) (*AzureProvider, error) { // Authenticate creates an identity session with azure from a authorization code, and follows up // call to the groups api to check what groups the user is in. -func (p *AzureProvider) Authenticate(ctx context.Context, code string) (*sessions.SessionState, error) { +func (p *AzureProvider) Authenticate(ctx context.Context, code string) (*sessions.State, error) { // convert authorization code into a token oauth2Token, err := p.oauth.Exchange(ctx, code) if err != nil { @@ -104,7 +104,7 @@ func (p *AzureProvider) Authenticate(ctx context.Context, code string) (*session // IDTokenToSession takes an identity provider issued JWT as input ('id_token') // and returns a session state. The provided token's audience ('aud') must // match Pomerium's client_id. -func (p *AzureProvider) IDTokenToSession(ctx context.Context, rawIDToken string) (*sessions.SessionState, error) { +func (p *AzureProvider) IDTokenToSession(ctx context.Context, rawIDToken string) (*sessions.State, error) { idToken, err := p.verifier.Verify(ctx, rawIDToken) if err != nil { return nil, fmt.Errorf("identity/microsoft: could not verify id_token %v", err) @@ -118,7 +118,7 @@ func (p *AzureProvider) IDTokenToSession(ctx context.Context, rawIDToken string) return nil, fmt.Errorf("identity/microsoft: failed to parse id_token claims %v", err) } - return &sessions.SessionState{ + return &sessions.State{ IDToken: rawIDToken, RefreshDeadline: idToken.Expiry.Truncate(time.Second), Email: claims.Email, @@ -146,7 +146,7 @@ func (p *AzureProvider) GetSignInURL(state string) string { // Refresh renews a user's session using an oid refresh token without reprompting the user. // Group membership is also refreshed. // https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens -func (p *AzureProvider) Refresh(ctx context.Context, s *sessions.SessionState) (*sessions.SessionState, error) { +func (p *AzureProvider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) { if s.RefreshToken == "" { return nil, errors.New("identity/microsoft: missing refresh token") } diff --git a/internal/identity/mock_provider.go b/internal/identity/mock_provider.go index cb173f7bd..eb6c2c636 100644 --- a/internal/identity/mock_provider.go +++ b/internal/identity/mock_provider.go @@ -8,25 +8,25 @@ import ( // MockProvider provides a mocked implementation of the providers interface. type MockProvider struct { - AuthenticateResponse sessions.SessionState + AuthenticateResponse sessions.State AuthenticateError error - IDTokenToSessionResponse sessions.SessionState + IDTokenToSessionResponse sessions.State IDTokenToSessionError error ValidateResponse bool ValidateError error - RefreshResponse *sessions.SessionState + RefreshResponse *sessions.State RefreshError error RevokeError error GetSignInURLResponse string } // Authenticate is a mocked providers function. -func (mp MockProvider) Authenticate(ctx context.Context, code string) (*sessions.SessionState, error) { +func (mp MockProvider) Authenticate(ctx context.Context, code string) (*sessions.State, error) { return &mp.AuthenticateResponse, mp.AuthenticateError } // IDTokenToSession is a mocked providers function. -func (mp MockProvider) IDTokenToSession(ctx context.Context, code string) (*sessions.SessionState, error) { +func (mp MockProvider) IDTokenToSession(ctx context.Context, code string) (*sessions.State, error) { return &mp.IDTokenToSessionResponse, mp.IDTokenToSessionError } @@ -36,7 +36,7 @@ func (mp MockProvider) Validate(ctx context.Context, s string) (bool, error) { } // Refresh is a mocked providers function. -func (mp MockProvider) Refresh(ctx context.Context, s *sessions.SessionState) (*sessions.SessionState, error) { +func (mp MockProvider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) { return mp.RefreshResponse, mp.RefreshError } diff --git a/internal/identity/okta.go b/internal/identity/okta.go index a101f5fd5..81ddb921d 100644 --- a/internal/identity/okta.go +++ b/internal/identity/okta.go @@ -91,7 +91,7 @@ type accessToken struct { // Refresh renews a user's session using an oid refresh token without reprompting the user. // Group membership is also refreshed. If configured properly, Okta is we can configure the access token // to include group membership claims which allows us to avoid a follow up oauth2 call. -func (p *OktaProvider) Refresh(ctx context.Context, s *sessions.SessionState) (*sessions.SessionState, error) { +func (p *OktaProvider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) { if s.RefreshToken == "" { return nil, errors.New("identity/okta: missing refresh token") } diff --git a/internal/identity/onelogin.go b/internal/identity/onelogin.go index a5dac69f6..c95c1fe32 100644 --- a/internal/identity/onelogin.go +++ b/internal/identity/onelogin.go @@ -93,7 +93,7 @@ func (p *OneLoginProvider) GetSignInURL(state string) string { // Refresh renews a user's session using an oid refresh token without reprompting the user. // Group membership is also refreshed. // https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens -func (p *OneLoginProvider) Refresh(ctx context.Context, s *sessions.SessionState) (*sessions.SessionState, error) { +func (p *OneLoginProvider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) { if s.RefreshToken == "" { return nil, errors.New("identity/microsoft: missing refresh token") } diff --git a/internal/identity/providers.go b/internal/identity/providers.go index 394e69566..ef28690c3 100644 --- a/internal/identity/providers.go +++ b/internal/identity/providers.go @@ -45,10 +45,10 @@ type UserGrouper interface { // Authenticator is an interface representing the ability to authenticate with an identity provider. type Authenticator interface { - Authenticate(context.Context, string) (*sessions.SessionState, error) - IDTokenToSession(context.Context, string) (*sessions.SessionState, error) + Authenticate(context.Context, string) (*sessions.State, error) + IDTokenToSession(context.Context, string) (*sessions.State, error) Validate(context.Context, string) (bool, error) - Refresh(context.Context, *sessions.SessionState) (*sessions.SessionState, error) + Refresh(context.Context, *sessions.State) (*sessions.State, error) Revoke(string) error GetSignInURL(state string) string } @@ -131,7 +131,7 @@ func (p *Provider) Validate(ctx context.Context, idToken string) (bool, error) { // IDTokenToSession takes an identity provider issued JWT as input ('id_token') // and returns a session state. The provided token's audience ('aud') must // match Pomerium's client_id. -func (p *Provider) IDTokenToSession(ctx context.Context, rawIDToken string) (*sessions.SessionState, error) { +func (p *Provider) IDTokenToSession(ctx context.Context, rawIDToken string) (*sessions.State, error) { idToken, err := p.verifier.Verify(ctx, rawIDToken) if err != nil { return nil, fmt.Errorf("identity: could not verify id_token: %v", err) @@ -146,7 +146,7 @@ func (p *Provider) IDTokenToSession(ctx context.Context, rawIDToken string) (*se return nil, fmt.Errorf("identity: failed to parse id_token claims: %v", err) } - return &sessions.SessionState{ + return &sessions.State{ IDToken: rawIDToken, User: idToken.Subject, RefreshDeadline: idToken.Expiry.Truncate(time.Second), @@ -157,7 +157,7 @@ func (p *Provider) IDTokenToSession(ctx context.Context, rawIDToken string) (*se } // Authenticate creates a session with an identity provider from a authorization code -func (p *Provider) Authenticate(ctx context.Context, code string) (*sessions.SessionState, error) { +func (p *Provider) Authenticate(ctx context.Context, code string) (*sessions.State, error) { // exchange authorization for a oidc token oauth2Token, err := p.oauth.Exchange(ctx, code) if err != nil { @@ -181,7 +181,7 @@ func (p *Provider) Authenticate(ctx context.Context, code string) (*sessions.Ses // Refresh renews a user's session using therefresh_token without reprompting // the user. If supported, group membership is also refreshed. // https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens -func (p *Provider) Refresh(ctx context.Context, s *sessions.SessionState) (*sessions.SessionState, error) { +func (p *Provider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) { if s.RefreshToken == "" { return nil, errors.New("identity: missing refresh token") } diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go index 0bb4328e0..d9d46f2d4 100644 --- a/internal/middleware/middleware.go +++ b/internal/middleware/middleware.go @@ -13,6 +13,7 @@ import ( "github.com/pomerium/pomerium/internal/cryptutil" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/telemetry/trace" + "github.com/pomerium/pomerium/internal/urlutil" "golang.org/x/net/publicsuffix" ) @@ -70,7 +71,7 @@ func ValidateRedirectURI(rootDomain *url.URL) func(next http.Handler) http.Handl httputil.ErrorResponse(w, r, httputil.Error("couldn't parse form", http.StatusBadRequest, err)) return } - redirectURI, err := url.Parse(r.Form.Get("redirect_uri")) + redirectURI, err := urlutil.ParseAndValidateURL(r.Form.Get("redirect_uri")) if err != nil { httputil.ErrorResponse(w, r, httputil.Error("bad redirect_uri", http.StatusBadRequest, err)) return @@ -131,7 +132,7 @@ func ValidateHost(validHost func(host string) bool) func(next http.Handler) http defer span.End() if !validHost(r.Host) { - httputil.ErrorResponse(w, r, httputil.Error(fmt.Sprintf("%s is not a known route.", r.Host), http.StatusNotFound, nil)) + httputil.ErrorResponse(w, r, httputil.Error("", http.StatusNotFound, nil)) return } next.ServeHTTP(w, r.WithContext(ctx)) @@ -168,7 +169,7 @@ func ValidSignature(redirectURI, sigVal, timestamp, secret string) bool { if redirectURI == "" || sigVal == "" || timestamp == "" || secret == "" { return false } - _, err := url.Parse(redirectURI) + _, err := urlutil.ParseAndValidateURL(redirectURI) if err != nil { return false } diff --git a/internal/middleware/reverse_proxy_test.go b/internal/middleware/reverse_proxy_test.go index d29d8fd33..ea83d906a 100644 --- a/internal/middleware/reverse_proxy_test.go +++ b/internal/middleware/reverse_proxy_test.go @@ -1,6 +1,7 @@ package middleware // import "github.com/pomerium/pomerium/internal/middleware" import ( + "encoding/base64" "fmt" "net/http" "net/http/httptest" @@ -40,7 +41,7 @@ func TestSignRequest(t *testing.T) { }) rr := httptest.NewRecorder() - signer, err := cryptutil.NewES256Signer([]byte(exampleKey), "audience") + signer, err := cryptutil.NewES256Signer(base64.StdEncoding.EncodeToString([]byte(exampleKey)), "audience") if err != nil { t.Fatal(err) } diff --git a/internal/sessions/cookie_store.go b/internal/sessions/cookie_store.go index ce6e973bf..92c0452f2 100644 --- a/internal/sessions/cookie_store.go +++ b/internal/sessions/cookie_store.go @@ -1,7 +1,6 @@ package sessions // import "github.com/pomerium/pomerium/internal/sessions" import ( - "errors" "fmt" "net" "net/http" @@ -11,15 +10,17 @@ import ( "github.com/pomerium/pomerium/internal/cryptutil" ) -// ErrInvalidSession is an error for invalid sessions. -var ErrInvalidSession = errors.New("internal/sessions: invalid session") - // ChunkedCanaryByte is the byte value used as a canary prefix to distinguish if // the cookie is multi-part or not. This constant *should not* be valid // base64. It's important this byte is ASCII to avoid UTF-8 variable sized runes. // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie#Directives const ChunkedCanaryByte byte = '%' +// DefaultBearerTokenHeader is default header name for the authorization bearer +// token header as defined in rfc2617 +// https://tools.ietf.org/html/rfc6750#section-2.1 +const DefaultBearerTokenHeader = "Authorization" + // MaxChunkSize sets the upper bound on a cookie chunks payload value. // Note, this should be lower than the actual cookie's max size (4096 bytes) // which includes metadata. @@ -29,39 +30,27 @@ const MaxChunkSize = 3800 // set to prevent any abuse. const MaxNumChunks = 5 -// CSRFStore has the functions for setting, getting, and clearing the CSRF cookie -type CSRFStore interface { - SetCSRF(http.ResponseWriter, *http.Request, string) - GetCSRF(*http.Request) (*http.Cookie, error) - ClearCSRF(http.ResponseWriter, *http.Request) -} - -// SessionStore has the functions for setting, getting, and clearing the Session cookie -type SessionStore interface { - ClearSession(http.ResponseWriter, *http.Request) - LoadSession(*http.Request) (*SessionState, error) - SaveSession(http.ResponseWriter, *http.Request, *SessionState) error -} - // CookieStore represents all the cookie related configurations type CookieStore struct { - Name string - CookieCipher cryptutil.Cipher - CookieExpire time.Duration - CookieRefresh time.Duration - CookieSecure bool - CookieHTTPOnly bool - CookieDomain string + Name string + CookieCipher cryptutil.Cipher + CookieExpire time.Duration + CookieRefresh time.Duration + CookieSecure bool + CookieHTTPOnly bool + CookieDomain string + BearerTokenHeader string } // CookieStoreOptions holds options for CookieStore type CookieStoreOptions struct { - Name string - CookieSecure bool - CookieHTTPOnly bool - CookieDomain string - CookieExpire time.Duration - CookieCipher cryptutil.Cipher + Name string + CookieSecure bool + CookieHTTPOnly bool + CookieDomain string + BearerTokenHeader string + CookieExpire time.Duration + CookieCipher cryptutil.Cipher } // NewCookieStore returns a new session with ciphers for each of the cookie secrets @@ -72,23 +61,28 @@ func NewCookieStore(opts *CookieStoreOptions) (*CookieStore, error) { if opts.CookieCipher == nil { return nil, fmt.Errorf("internal/sessions: cipher cannot be nil") } + if opts.BearerTokenHeader == "" { + opts.BearerTokenHeader = DefaultBearerTokenHeader + } + return &CookieStore{ - Name: opts.Name, - CookieSecure: opts.CookieSecure, - CookieHTTPOnly: opts.CookieHTTPOnly, - CookieDomain: opts.CookieDomain, - CookieExpire: opts.CookieExpire, - CookieCipher: opts.CookieCipher, + Name: opts.Name, + CookieSecure: opts.CookieSecure, + CookieHTTPOnly: opts.CookieHTTPOnly, + CookieDomain: opts.CookieDomain, + CookieExpire: opts.CookieExpire, + CookieCipher: opts.CookieCipher, + BearerTokenHeader: opts.BearerTokenHeader, }, nil } -func (s *CookieStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie { +func (cs *CookieStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie { domain := req.Host - if name == s.csrfName() { + if name == cs.csrfName() { domain = req.Host - } else if s.CookieDomain != "" { - domain = s.CookieDomain + } else if cs.CookieDomain != "" { + domain = cs.CookieDomain } else { domain = splitDomain(domain) } @@ -101,8 +95,8 @@ func (s *CookieStore) makeCookie(req *http.Request, name string, value string, e Value: value, Path: "/", Domain: domain, - HttpOnly: s.CookieHTTPOnly, - Secure: s.CookieSecure, + HttpOnly: cs.CookieHTTPOnly, + Secure: cs.CookieSecure, } // only set an expiration if we want one, otherwise default to non perm session based if expiration != 0 { @@ -111,22 +105,20 @@ func (s *CookieStore) makeCookie(req *http.Request, name string, value string, e return c } -func (s *CookieStore) csrfName() string { - return fmt.Sprintf("%s_csrf", s.Name) +func (cs *CookieStore) csrfName() string { + return fmt.Sprintf("%s_csrf", cs.Name) } // makeSessionCookie constructs a session cookie given the request, an expiration time and the current time. -func (s *CookieStore) makeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { - return s.makeCookie(req, s.Name, value, expiration, now) +func (cs *CookieStore) makeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { + return cs.makeCookie(req, cs.Name, value, expiration, now) } -// makeCSRFCookie creates a CSRF cookie given the request, an expiration time, and the current time. -// CSRF cookies should be scoped to the actual domain -func (s *CookieStore) makeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { - return s.makeCookie(req, s.csrfName(), value, expiration, now) +func (cs *CookieStore) makeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { + return cs.makeCookie(req, cs.csrfName(), value, expiration, now) } -func (s *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) { +func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) { if len(cookie.String()) <= MaxChunkSize { http.SetCookie(w, cookie) return @@ -142,9 +134,9 @@ func (s *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) { nc.Name = fmt.Sprintf("%s_%d", cookie.Name, i) nc.Value = c } + fmt.Println(i) http.SetCookie(w, &nc) } - } func chunk(s string, size int) []string { @@ -159,43 +151,54 @@ func chunk(s string, size int) []string { } // ClearCSRF clears the CSRF cookie from the request -func (s *CookieStore) ClearCSRF(w http.ResponseWriter, req *http.Request) { - http.SetCookie(w, s.makeCSRFCookie(req, "", time.Hour*-1, time.Now())) +func (cs *CookieStore) ClearCSRF(w http.ResponseWriter, req *http.Request) { + http.SetCookie(w, cs.makeCSRFCookie(req, "", time.Hour*-1, time.Now())) } // SetCSRF sets the CSRFCookie creates a CSRF cookie in a given request -func (s *CookieStore) SetCSRF(w http.ResponseWriter, req *http.Request, val string) { - http.SetCookie(w, s.makeCSRFCookie(req, val, s.CookieExpire, time.Now())) +func (cs *CookieStore) SetCSRF(w http.ResponseWriter, req *http.Request, val string) { + http.SetCookie(w, cs.makeCSRFCookie(req, val, cs.CookieExpire, time.Now())) } // GetCSRF gets the CSRFCookie creates a CSRF cookie in a given request -func (s *CookieStore) GetCSRF(req *http.Request) (*http.Cookie, error) { - return req.Cookie(s.csrfName()) +func (cs *CookieStore) GetCSRF(req *http.Request) (*http.Cookie, error) { + c, err := req.Cookie(cs.csrfName()) + if err != nil { + return nil, ErrEmptyCSRF // ErrNoCookie is confusing in this context + } + return c, nil } // ClearSession clears the session cookie from a request -func (s *CookieStore) ClearSession(w http.ResponseWriter, req *http.Request) { - http.SetCookie(w, s.makeSessionCookie(req, "", time.Hour*-1, time.Now())) +func (cs *CookieStore) ClearSession(w http.ResponseWriter, req *http.Request) { + http.SetCookie(w, cs.makeCookie(req, cs.Name, "", time.Hour*-1, time.Now())) } -func (s *CookieStore) setSessionCookie(w http.ResponseWriter, req *http.Request, val string) { - s.setCookie(w, s.makeSessionCookie(req, val, s.CookieExpire, time.Now())) +func (cs *CookieStore) setSessionCookie(w http.ResponseWriter, req *http.Request, val string) { + cs.setCookie(w, cs.makeSessionCookie(req, val, cs.CookieExpire, time.Now())) } -// LoadSession returns a SessionState from the cookie in the request. -func (s *CookieStore) LoadSession(req *http.Request) (*SessionState, error) { - c, err := req.Cookie(s.Name) +func loadBearerToken(r *http.Request, headerKey string) string { + authHeader := r.Header.Get(headerKey) + split := strings.Split(authHeader, "Bearer") + if authHeader == "" || len(split) != 2 { + return "" + } + return strings.TrimSpace(split[1]) +} + +func loadChunkedCookie(r *http.Request, cookieName string) string { + c, err := r.Cookie(cookieName) if err != nil { - return nil, err // http.ErrNoCookie + return "" } cipherText := c.Value - // if the first byte is our canary byte, we need to handle the multipart bit if []byte(c.Value)[0] == ChunkedCanaryByte { var b strings.Builder fmt.Fprintf(&b, "%s", cipherText[1:]) - for i := 1; i < MaxNumChunks; i++ { - next, err := req.Cookie(fmt.Sprintf("%s_%d", s.Name, i)) + for i := 1; i <= MaxNumChunks; i++ { + next, err := r.Cookie(fmt.Sprintf("%s_%d", cookieName, i)) if err != nil { break // break if we can't find the next cookie } @@ -203,20 +206,32 @@ func (s *CookieStore) LoadSession(req *http.Request) (*SessionState, error) { } cipherText = b.String() } - session, err := UnmarshalSession(cipherText, s.CookieCipher) + return cipherText +} + +// LoadSession returns a State from the cookie in the request. +func (cs *CookieStore) LoadSession(req *http.Request) (*State, error) { + cipherText := loadChunkedCookie(req, cs.Name) + if cipherText == "" { + cipherText = loadBearerToken(req, cs.BearerTokenHeader) + } + if cipherText == "" { + return nil, ErrEmptySession + } + session, err := UnmarshalSession(cipherText, cs.CookieCipher) if err != nil { - return nil, ErrInvalidSession + return nil, err } return session, nil } // SaveSession saves a session state to a request sessions. -func (s *CookieStore) SaveSession(w http.ResponseWriter, req *http.Request, sessionState *SessionState) error { - value, err := MarshalSession(sessionState, s.CookieCipher) +func (cs *CookieStore) SaveSession(w http.ResponseWriter, req *http.Request, s *State) error { + value, err := MarshalSession(s, cs.CookieCipher) if err != nil { return err } - s.setSessionCookie(w, req, value) + cs.setSessionCookie(w, req, value) return nil } diff --git a/internal/sessions/cookie_store_test.go b/internal/sessions/cookie_store_test.go index 53dc98394..64e61b228 100644 --- a/internal/sessions/cookie_store_test.go +++ b/internal/sessions/cookie_store_test.go @@ -11,6 +11,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/pomerium/pomerium/internal/cryptutil" ) @@ -49,30 +50,33 @@ func TestNewCookieStore(t *testing.T) { }{ {"good", &CookieStoreOptions{ - Name: "_cookie", - CookieSecure: true, - CookieHTTPOnly: true, - CookieDomain: "pomerium.io", - CookieExpire: 10 * time.Second, - CookieCipher: cipher, + Name: "_cookie", + CookieSecure: true, + CookieHTTPOnly: true, + CookieDomain: "pomerium.io", + CookieExpire: 10 * time.Second, + CookieCipher: cipher, + BearerTokenHeader: "Authorization", }, &CookieStore{ - Name: "_cookie", - CookieSecure: true, - CookieHTTPOnly: true, - CookieDomain: "pomerium.io", - CookieExpire: 10 * time.Second, - CookieCipher: cipher, + Name: "_cookie", + CookieSecure: true, + CookieHTTPOnly: true, + CookieDomain: "pomerium.io", + CookieExpire: 10 * time.Second, + CookieCipher: cipher, + BearerTokenHeader: "Authorization", }, false}, {"missing name", &CookieStoreOptions{ - Name: "", - CookieSecure: true, - CookieHTTPOnly: true, - CookieDomain: "pomerium.io", - CookieExpire: 10 * time.Second, - CookieCipher: cipher, + Name: "", + CookieSecure: true, + CookieHTTPOnly: true, + CookieDomain: "pomerium.io", + CookieExpire: 10 * time.Second, + CookieCipher: cipher, + BearerTokenHeader: "Authorization", }, nil, true}, @@ -95,8 +99,12 @@ func TestNewCookieStore(t *testing.T) { t.Errorf("NewCookieStore() error = %v, wantErr %v", err, tt.wantErr) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("NewCookieStore() = %#v, want %#v", got, tt.want) + cmpOpts := []cmp.Option{ + cmpopts.IgnoreUnexported(cryptutil.XChaCha20Cipher{}), + } + + if diff := cmp.Diff(got, tt.want, cmpOpts...); diff != "" { + t.Errorf("NewCookieStore() = %s", diff) } }) } @@ -211,15 +219,15 @@ func TestCookieStore_SaveSession(t *testing.T) { t.Fatal(err) } tests := []struct { - name string - sessionState *SessionState - cipher cryptutil.Cipher - wantErr bool - wantLoadErr bool + name string + State *State + cipher cryptutil.Cipher + wantErr bool + wantLoadErr bool }{ - {"good", &SessionState{AccessToken: "token1234", RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, cipher, false, false}, - {"bad cipher", &SessionState{AccessToken: "token1234", RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, mockCipher{}, true, true}, - {"huge cookie", &SessionState{AccessToken: fmt.Sprintf("%x", hugeString), RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, cipher, false, false}, + {"good", &State{AccessToken: "token1234", RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, cipher, false, false}, + {"bad cipher", &State{AccessToken: "token1234", RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, mockCipher{}, true, true}, + {"huge cookie", &State{AccessToken: fmt.Sprintf("%x", hugeString), RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, cipher, false, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -234,12 +242,12 @@ func TestCookieStore_SaveSession(t *testing.T) { r := httptest.NewRequest("GET", "/", nil) w := httptest.NewRecorder() - if err := s.SaveSession(w, r, tt.sessionState); (err != nil) != tt.wantErr { + if err := s.SaveSession(w, r, tt.State); (err != nil) != tt.wantErr { t.Errorf("CookieStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr) } r = httptest.NewRequest("GET", "/", nil) for _, cookie := range w.Result().Cookies() { - t.Log(cookie) + // t.Log(cookie) r.AddCookie(cookie) } @@ -248,8 +256,10 @@ func TestCookieStore_SaveSession(t *testing.T) { t.Errorf("LoadSession() error = %v, wantErr %v", err, tt.wantLoadErr) return } - if err == nil && !reflect.DeepEqual(state, tt.sessionState) { - t.Errorf("CookieStore.LoadSession() got = \n%v, want \n%v", state, tt.sessionState) + if err == nil { + if diff := cmp.Diff(state, tt.State); diff != "" { + t.Errorf("CookieStore.LoadSession() got = %s", diff) + } } }) } @@ -291,18 +301,18 @@ func TestMockSessionStore(t *testing.T) { tests := []struct { name string mockCSRF *MockSessionStore - saveSession *SessionState + saveSession *State wantLoadErr bool wantSaveErr bool }{ {"basic", &MockSessionStore{ ResponseSession: "test", - Session: &SessionState{AccessToken: "AccessToken"}, + Session: &State{AccessToken: "AccessToken"}, SaveError: nil, LoadError: nil, }, - &SessionState{AccessToken: "AccessToken"}, + &State{AccessToken: "AccessToken"}, false, false}, } diff --git a/internal/sessions/mock_store.go b/internal/sessions/mock_store.go index 8ac0585b5..bbed23b1f 100644 --- a/internal/sessions/mock_store.go +++ b/internal/sessions/mock_store.go @@ -29,7 +29,7 @@ func (ms MockCSRFStore) GetCSRF(*http.Request) (*http.Cookie, error) { // MockSessionStore is a mock implementation of the SessionStore interface type MockSessionStore struct { ResponseSession string - Session *SessionState + Session *State SaveError error LoadError error } @@ -40,11 +40,11 @@ func (ms *MockSessionStore) ClearSession(http.ResponseWriter, *http.Request) { } // LoadSession returns the session and a error -func (ms MockSessionStore) LoadSession(*http.Request) (*SessionState, error) { +func (ms MockSessionStore) LoadSession(*http.Request) (*State, error) { return ms.Session, ms.LoadError } // SaveSession returns a save error. -func (ms MockSessionStore) SaveSession(http.ResponseWriter, *http.Request, *SessionState) error { +func (ms MockSessionStore) SaveSession(http.ResponseWriter, *http.Request, *State) error { return ms.SaveError } diff --git a/internal/sessions/rest_store.go b/internal/sessions/rest_store.go deleted file mode 100644 index f2eadd3ff..000000000 --- a/internal/sessions/rest_store.go +++ /dev/null @@ -1,106 +0,0 @@ -package sessions // import "github.com/pomerium/pomerium/internal/sessions" - -import ( - "encoding/json" - "errors" - "fmt" - "net/http" - "strings" - "time" - - "github.com/pomerium/pomerium/internal/cryptutil" -) - -// DefaultBearerTokenHeader is default header name for the authorization bearer -// token header as defined in rfc2617 -// https://tools.ietf.org/html/rfc6750#section-2.1 -const DefaultBearerTokenHeader = "Authorization" - -// RestStore is a session store suitable for REST -type RestStore struct { - Name string - Cipher cryptutil.Cipher - // Expire time.Duration -} - -// RestStoreOptions contains the options required to build a new RestStore. -type RestStoreOptions struct { - Name string - Cipher cryptutil.Cipher - // Expire time.Duration -} - -// NewRestStore creates a new RestStore from a set of RestStoreOptions. -func NewRestStore(opts *RestStoreOptions) (*RestStore, error) { - if opts.Name == "" { - opts.Name = DefaultBearerTokenHeader - } - if opts.Cipher == nil { - return nil, fmt.Errorf("internal/sessions: cipher cannot be nil") - } - return &RestStore{ - Name: opts.Name, - // Expire: opts.Expire, - Cipher: opts.Cipher, - }, nil -} - -// ClearSession functions differently because REST is stateless, we instead -// inform the client that this token is no longer valid. -// https://tools.ietf.org/html/rfc6750 -func (s *RestStore) ClearSession(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusUnauthorized) - errMsg := ` - { - "error": "invalid_token", - "token_type": "Bearer", - "error_description": "The token has expired." - }` - w.Write([]byte(errMsg)) -} - -// LoadSession attempts to load a pomerium session from a Bearer Token set -// in the authorization header. -func (s *RestStore) LoadSession(r *http.Request) (*SessionState, error) { - authHeader := r.Header.Get(s.Name) - split := strings.Split(authHeader, "Bearer") - if authHeader == "" || len(split) != 2 { - return nil, errors.New("internal/sessions: no bearer token header found") - } - token := strings.TrimSpace(split[1]) - session, err := UnmarshalSession(token, s.Cipher) - if err != nil { - return nil, err - } - return session, nil -} - -// RestStoreResponse is the JSON struct returned to the client. -type RestStoreResponse struct { - // Token is the encrypted pomerium session that can be used to - // programmatically authenticate with pomerium. - Token string - // In addition to the token, non-sensitive meta data is returned to help - // the client manage token renewals. - Expiry time.Time -} - -// SaveSession returns an encrypted pomerium session as a JSON object with -// associated, non sensitive meta-data like -func (s *RestStore) SaveSession(w http.ResponseWriter, r *http.Request, sessionState *SessionState) error { - encToken, err := MarshalSession(sessionState, s.Cipher) - if err != nil { - return err - } - jsonBytes, err := json.Marshal( - &RestStoreResponse{ - Token: encToken, - Expiry: sessionState.RefreshDeadline, - }) - if err != nil { - return fmt.Errorf("internal/sessions: couldn't marshal token struct: %v", err) - } - w.Header().Set("Content-Type", "application/json") - w.Write(jsonBytes) - return nil -} diff --git a/internal/sessions/rest_store_test.go b/internal/sessions/rest_store_test.go deleted file mode 100644 index 8f5943f6f..000000000 --- a/internal/sessions/rest_store_test.go +++ /dev/null @@ -1,135 +0,0 @@ -package sessions - -import ( - "errors" - "fmt" - "io/ioutil" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "github.com/pomerium/pomerium/internal/cryptutil" -) - -func TestRestStore_SaveSession(t *testing.T) { - now := time.Date(2008, 1, 8, 17, 5, 5, 0, time.UTC) - - tests := []struct { - name string - optionsName string - optionsCipher cryptutil.Cipher - sessionState *SessionState - wantErr bool - wantSaveResponse string - }{ - {"good", "Authenticate", &cryptutil.MockCipher{MarshalResponse: "test"}, &SessionState{RefreshDeadline: now}, false, `{"Token":"test","Expiry":"2008-01-08T17:05:05Z"}`}, - {"bad session marshal", "Authenticate", &cryptutil.MockCipher{MarshalError: errors.New("error")}, &SessionState{RefreshDeadline: now}, true, ""}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - s, err := NewRestStore( - &RestStoreOptions{ - Name: tt.optionsName, - Cipher: tt.optionsCipher, - }) - if err != nil { - t.Fatalf("NewRestStore err %v", err) - } - r := httptest.NewRequest(http.MethodGet, "/", nil) - w := httptest.NewRecorder() - if err := s.SaveSession(w, r, tt.sessionState); (err != nil) != tt.wantErr { - t.Errorf("RestStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr) - } - resp := w.Result() - body, _ := ioutil.ReadAll(resp.Body) - if diff := cmp.Diff(string(body), tt.wantSaveResponse); diff != "" { - t.Errorf("RestStore.SaveSession() got / want diff \n%s\n", diff) - } - }) - } -} - -func TestNewRestStore(t *testing.T) { - - tests := []struct { - name string - optionsName string - optionsCipher cryptutil.Cipher - wantErr bool - }{ - {"good", "Authenticate", &cryptutil.MockCipher{}, false}, - {"good default to authenticate", "", &cryptutil.MockCipher{}, false}, - {"empty cipher", "Authenticate", nil, true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - _, err := NewRestStore( - &RestStoreOptions{ - Name: tt.optionsName, - Cipher: tt.optionsCipher, - }) - if (err != nil) != tt.wantErr { - t.Errorf("NewRestStore() error = %v, wantErr %v", err, tt.wantErr) - return - } - }) - } -} - -func TestRestStore_ClearSession(t *testing.T) { - tests := []struct { - name string - expectedStatus int - }{ - {"always returns reset!", http.StatusUnauthorized}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - s := &RestStore{Name: "Authenticate", Cipher: &cryptutil.MockCipher{}} - r := httptest.NewRequest(http.MethodGet, "/", nil) - w := httptest.NewRecorder() - s.ClearSession(w, r) - resp := w.Result() - if diff := cmp.Diff(resp.StatusCode, tt.expectedStatus); diff != "" { - t.Errorf("RestStore.ClearSession() got / want diff \n%s\n", diff) - } - - }) - } -} - -func TestRestStore_LoadSession(t *testing.T) { - - tests := []struct { - name string - optionsName string - optionsCipher cryptutil.Cipher - token string - wantErr bool - }{ - {"good", "Authorization", &cryptutil.MockCipher{}, "test", false}, - {"empty auth header", "", &cryptutil.MockCipher{}, "", true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - s := &RestStore{ - Name: tt.optionsName, - Cipher: tt.optionsCipher, - } - - r := httptest.NewRequest(http.MethodGet, "/", nil) - - if tt.optionsName != "" { - r.Header.Set(tt.optionsName, fmt.Sprintf(("Bearer %s"), tt.token)) - - } - _, err := s.LoadSession(r) - if (err != nil) != tt.wantErr { - t.Errorf("RestStore.LoadSession() error = %v, wantErr %v", err, tt.wantErr) - return - } - }) - } -} diff --git a/internal/sessions/session_state.go b/internal/sessions/state.go similarity index 74% rename from internal/sessions/session_state.go rename to internal/sessions/state.go index 41e9df2b5..bc17c4c36 100644 --- a/internal/sessions/session_state.go +++ b/internal/sessions/state.go @@ -3,7 +3,6 @@ package sessions // import "github.com/pomerium/pomerium/internal/sessions" import ( "encoding/base64" "encoding/json" - "errors" "fmt" "strings" "time" @@ -11,13 +10,11 @@ import ( "github.com/pomerium/pomerium/internal/cryptutil" ) -var ( - // ErrLifetimeExpired is an error for the lifetime deadline expiring - ErrLifetimeExpired = errors.New("user lifetime expired") -) +// ErrExpired is an error for a expired sessions. +var ErrExpired = fmt.Errorf("internal/sessions: expired session") -// SessionState is our object that keeps track of a user's session state -type SessionState struct { +// State is our object that keeps track of a user's session state +type State struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` IDToken string `json:"id_token"` @@ -31,18 +28,31 @@ type SessionState struct { ImpersonateGroups []string } -// RefreshPeriodExpired returns true if the refresh period has expired -func (s *SessionState) RefreshPeriodExpired() bool { - return isExpired(s.RefreshDeadline) +// Valid returns an error if the users's session state is not valid. +func (s *State) Valid() error { + if s.Expired() { + return ErrExpired + } + return nil +} + +// ForceRefresh sets the refresh deadline to now. +func (s *State) ForceRefresh() { + s.RefreshDeadline = time.Now().Truncate(time.Second) +} + +// Expired returns true if the refresh period has expired +func (s *State) Expired() bool { + return s.RefreshDeadline.Before(time.Now()) } // Impersonating returns if the request is impersonating. -func (s *SessionState) Impersonating() bool { +func (s *State) Impersonating() bool { return s.ImpersonateEmail != "" || len(s.ImpersonateGroups) != 0 } // RequestEmail is the email to make the request as. -func (s *SessionState) RequestEmail() string { +func (s *State) RequestEmail() string { if s.ImpersonateEmail != "" { return s.ImpersonateEmail } @@ -51,7 +61,7 @@ func (s *SessionState) RequestEmail() string { // RequestGroups returns the groups of the Groups making the request; uses // impersonating user if set. -func (s *SessionState) RequestGroups() string { +func (s *State) RequestGroups() string { if len(s.ImpersonateGroups) != 0 { return strings.Join(s.ImpersonateGroups, ",") } @@ -68,7 +78,7 @@ type idToken struct { } // IssuedAt parses the IDToken's issue date and returns a valid go time.Time. -func (s *SessionState) IssuedAt() (time.Time, error) { +func (s *State) IssuedAt() (time.Time, error) { payload, err := parseJWT(s.IDToken) if err != nil { return time.Time{}, fmt.Errorf("internal/sessions: malformed jwt: %v", err) @@ -80,13 +90,9 @@ func (s *SessionState) IssuedAt() (time.Time, error) { return time.Time(token.IssuedAt), nil } -func isExpired(t time.Time) bool { - return t.Before(time.Now()) -} - // MarshalSession marshals the session state as JSON, encrypts the JSON using the // given cipher, and base64-encodes the result -func MarshalSession(s *SessionState, c cryptutil.Cipher) (string, error) { +func MarshalSession(s *State, c cryptutil.Cipher) (string, error) { v, err := c.Marshal(s) if err != nil { return "", err @@ -96,8 +102,8 @@ func MarshalSession(s *SessionState, c cryptutil.Cipher) (string, error) { // UnmarshalSession takes the marshaled string, base64-decodes into a byte slice, decrypts the // byte slice using the passed cipher, and unmarshals the resulting JSON into a session state struct -func UnmarshalSession(value string, c cryptutil.Cipher) (*SessionState, error) { - s := &SessionState{} +func UnmarshalSession(value string, c cryptutil.Cipher) (*State, error) { + s := &State{} err := c.Unmarshal(value, s) if err != nil { return nil, err @@ -105,11 +111,6 @@ func UnmarshalSession(value string, c cryptutil.Cipher) (*SessionState, error) { return s, nil } -// ExtendDeadline returns the time extended by a given duration, truncated by second -func ExtendDeadline(ttl time.Duration) time.Time { - return time.Now().Add(ttl).Truncate(time.Second) -} - func parseJWT(p string) ([]byte, error) { parts := strings.Split(p, ".") if len(parts) < 2 { diff --git a/internal/sessions/session_state_test.go b/internal/sessions/state_test.go similarity index 75% rename from internal/sessions/session_state_test.go rename to internal/sessions/state_test.go index 6d4f55e72..eaca3ff46 100644 --- a/internal/sessions/session_state_test.go +++ b/internal/sessions/state_test.go @@ -11,14 +11,14 @@ import ( "github.com/pomerium/pomerium/internal/cryptutil" ) -func TestSessionStateSerialization(t *testing.T) { +func TestStateSerialization(t *testing.T) { secret := cryptutil.GenerateKey() c, err := cryptutil.NewCipher(secret) if err != nil { t.Fatalf("expected to be able to create cipher: %v", err) } - want := &SessionState{ + want := &State{ AccessToken: "token1234", RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), @@ -43,41 +43,21 @@ func TestSessionStateSerialization(t *testing.T) { } } -func TestSessionStateExpirations(t *testing.T) { - session := &SessionState{ +func TestStateExpirations(t *testing.T) { + session := &State{ AccessToken: "token1234", RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(-1 * time.Hour), Email: "user@domain.com", User: "user", } - if !session.RefreshPeriodExpired() { + if !session.Expired() { t.Errorf("expected lifetime period to be expired") } } -func TestExtendDeadline(t *testing.T) { - // tons of wiggle room here - now := time.Now().Truncate(time.Second) - tests := []struct { - name string - ttl time.Duration - want time.Time - }{ - {"Add a few ms", time.Millisecond * 10, now.Truncate(time.Second)}, - {"Add a few microsecs", time.Microsecond * 10, now.Truncate(time.Second)}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := ExtendDeadline(tt.ttl); !reflect.DeepEqual(got, tt.want) { - t.Errorf("ExtendDeadline() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestSessionState_IssuedAt(t *testing.T) { +func TestState_IssuedAt(t *testing.T) { t.Parallel() tests := []struct { name string @@ -91,20 +71,20 @@ func TestSessionState_IssuedAt(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := &SessionState{IDToken: tt.IDToken} + s := &State{IDToken: tt.IDToken} got, err := s.IssuedAt() if (err != nil) != tt.wantErr { - t.Errorf("SessionState.IssuedAt() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("State.IssuedAt() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { - t.Errorf("SessionState.IssuedAt() = %v, want %v", got.Format(time.RFC3339), tt.want.Format(time.RFC3339)) + t.Errorf("State.IssuedAt() = %v, want %v", got.Format(time.RFC3339), tt.want.Format(time.RFC3339)) } }) } } -func TestSessionState_Impersonating(t *testing.T) { +func TestState_Impersonating(t *testing.T) { t.Parallel() tests := []struct { name string @@ -123,20 +103,20 @@ func TestSessionState_Impersonating(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := &SessionState{ + s := &State{ Email: tt.Email, Groups: tt.Groups, ImpersonateEmail: tt.ImpersonateEmail, ImpersonateGroups: tt.ImpersonateGroups, } if got := s.Impersonating(); got != tt.want { - t.Errorf("SessionState.Impersonating() = %v, want %v", got, tt.want) + t.Errorf("State.Impersonating() = %v, want %v", got, tt.want) } if gotEmail := s.RequestEmail(); gotEmail != tt.wantResponseEmail { - t.Errorf("SessionState.RequestEmail() = %v, want %v", gotEmail, tt.wantResponseEmail) + t.Errorf("State.RequestEmail() = %v, want %v", gotEmail, tt.wantResponseEmail) } if gotGroups := s.RequestGroups(); gotGroups != tt.wantResponseGroups { - t.Errorf("SessionState.v() = %v, want %v", gotGroups, tt.wantResponseGroups) + t.Errorf("State.v() = %v, want %v", gotGroups, tt.wantResponseGroups) } }) } @@ -154,11 +134,11 @@ func TestMarshalSession(t *testing.T) { } tests := []struct { name string - s *SessionState + s *State wantErr bool }{ - {"simple", &SessionState{}, false}, - {"too big", &SessionState{AccessToken: fmt.Sprintf("%x", hugeString)}, false}, + {"simple", &State{}, false}, + {"too big", &State{AccessToken: fmt.Sprintf("%x", hugeString)}, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -179,3 +159,45 @@ func TestMarshalSession(t *testing.T) { }) } } + +func TestState_Valid(t *testing.T) { + + tests := []struct { + name string + RefreshDeadline time.Time + wantErr bool + }{ + {" good", time.Now().Add(10 * time.Second), false}, + {" expired", time.Now().Add(-10 * time.Second), true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &State{ + RefreshDeadline: tt.RefreshDeadline, + } + if err := s.Valid(); (err != nil) != tt.wantErr { + t.Errorf("State.Valid() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestState_ForceRefresh(t *testing.T) { + tests := []struct { + name string + RefreshDeadline time.Time + }{ + {"good", time.Now().Truncate(time.Second)}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &State{ + RefreshDeadline: tt.RefreshDeadline, + } + s.ForceRefresh() + if s.RefreshDeadline != tt.RefreshDeadline { + t.Errorf("refresh deadline not updated") + } + }) + } +} diff --git a/internal/sessions/store.go b/internal/sessions/store.go new file mode 100644 index 000000000..9ba2bad8f --- /dev/null +++ b/internal/sessions/store.go @@ -0,0 +1,26 @@ +package sessions // import "github.com/pomerium/pomerium/internal/sessions" + +import ( + "errors" + "net/http" +) + +// ErrEmptySession is an error for an empty sessions. +var ErrEmptySession = errors.New("internal/sessions: empty session") + +// ErrEmptyCSRF is an error for an empty sessions. +var ErrEmptyCSRF = errors.New("internal/sessions: empty csrf") + +// CSRFStore has the functions for setting, getting, and clearing the CSRF cookie +type CSRFStore interface { + SetCSRF(http.ResponseWriter, *http.Request, string) + GetCSRF(*http.Request) (*http.Cookie, error) + ClearCSRF(http.ResponseWriter, *http.Request) +} + +// SessionStore has the functions for setting, getting, and clearing the Session cookie +type SessionStore interface { + ClearSession(http.ResponseWriter, *http.Request) + LoadSession(*http.Request) (*State, error) + SaveSession(http.ResponseWriter, *http.Request, *State) error +} diff --git a/internal/telemetry/metrics/const.go b/internal/telemetry/metrics/const.go index 2cd2bb610..b55197af1 100644 --- a/internal/telemetry/metrics/const.go +++ b/internal/telemetry/metrics/const.go @@ -8,12 +8,12 @@ import ( // The following tags are applied to stats recorded by this package. var ( - TagKeyHTTPMethod tag.Key = tag.MustNewKey("http_method") - TagKeyService tag.Key = tag.MustNewKey("service") - TagKeyGRPCService tag.Key = tag.MustNewKey("grpc_service") - TagKeyGRPCMethod tag.Key = tag.MustNewKey("grpc_method") - TagKeyHost tag.Key = tag.MustNewKey("host") - TagKeyDestination tag.Key = tag.MustNewKey("destination") + TagKeyHTTPMethod = tag.MustNewKey("http_method") + TagKeyService = tag.MustNewKey("service") + TagKeyGRPCService = tag.MustNewKey("grpc_service") + TagKeyGRPCMethod = tag.MustNewKey("grpc_method") + TagKeyHost = tag.MustNewKey("host") + TagKeyDestination = tag.MustNewKey("destination") ) // Default distributions used by views in this package. diff --git a/proto/authenticate/authenticate.pb.go b/proto/authenticate/authenticate.pb.go deleted file mode 100644 index 70e9c2d2a..000000000 --- a/proto/authenticate/authenticate.pb.go +++ /dev/null @@ -1,399 +0,0 @@ -// Code generated by protoc-gen-go. DO NOT EDIT. -// source: authenticate.proto - -package authenticate - -import proto "github.com/golang/protobuf/proto" -import fmt "fmt" -import math "math" -import timestamp "github.com/golang/protobuf/ptypes/timestamp" - -import ( - context "golang.org/x/net/context" - grpc "google.golang.org/grpc" -) - -// Reference imports to suppress errors if they are not otherwise used. -var _ = proto.Marshal -var _ = fmt.Errorf -var _ = math.Inf - -// This is a compile-time assertion to ensure that this generated file -// is compatible with the proto package it is being compiled against. -// A compilation error at this line likely means your copy of the -// proto package needs to be updated. -const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package - -type AuthenticateRequest struct { - Code string `protobuf:"bytes,1,opt,name=code,proto3" json:"code,omitempty"` - XXX_NoUnkeyedLiteral struct{} `json:"-"` - XXX_unrecognized []byte `json:"-"` - XXX_sizecache int32 `json:"-"` -} - -func (m *AuthenticateRequest) Reset() { *m = AuthenticateRequest{} } -func (m *AuthenticateRequest) String() string { return proto.CompactTextString(m) } -func (*AuthenticateRequest) ProtoMessage() {} -func (*AuthenticateRequest) Descriptor() ([]byte, []int) { - return fileDescriptor_authenticate_d9796afa57ba1f78, []int{0} -} -func (m *AuthenticateRequest) XXX_Unmarshal(b []byte) error { - return xxx_messageInfo_AuthenticateRequest.Unmarshal(m, b) -} -func (m *AuthenticateRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - return xxx_messageInfo_AuthenticateRequest.Marshal(b, m, deterministic) -} -func (dst *AuthenticateRequest) XXX_Merge(src proto.Message) { - xxx_messageInfo_AuthenticateRequest.Merge(dst, src) -} -func (m *AuthenticateRequest) XXX_Size() int { - return xxx_messageInfo_AuthenticateRequest.Size(m) -} -func (m *AuthenticateRequest) XXX_DiscardUnknown() { - xxx_messageInfo_AuthenticateRequest.DiscardUnknown(m) -} - -var xxx_messageInfo_AuthenticateRequest proto.InternalMessageInfo - -func (m *AuthenticateRequest) GetCode() string { - if m != nil { - return m.Code - } - return "" -} - -type ValidateRequest struct { - IdToken string `protobuf:"bytes,1,opt,name=id_token,json=idToken,proto3" json:"id_token,omitempty"` - XXX_NoUnkeyedLiteral struct{} `json:"-"` - XXX_unrecognized []byte `json:"-"` - XXX_sizecache int32 `json:"-"` -} - -func (m *ValidateRequest) Reset() { *m = ValidateRequest{} } -func (m *ValidateRequest) String() string { return proto.CompactTextString(m) } -func (*ValidateRequest) ProtoMessage() {} -func (*ValidateRequest) Descriptor() ([]byte, []int) { - return fileDescriptor_authenticate_d9796afa57ba1f78, []int{1} -} -func (m *ValidateRequest) XXX_Unmarshal(b []byte) error { - return xxx_messageInfo_ValidateRequest.Unmarshal(m, b) -} -func (m *ValidateRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - return xxx_messageInfo_ValidateRequest.Marshal(b, m, deterministic) -} -func (dst *ValidateRequest) XXX_Merge(src proto.Message) { - xxx_messageInfo_ValidateRequest.Merge(dst, src) -} -func (m *ValidateRequest) XXX_Size() int { - return xxx_messageInfo_ValidateRequest.Size(m) -} -func (m *ValidateRequest) XXX_DiscardUnknown() { - xxx_messageInfo_ValidateRequest.DiscardUnknown(m) -} - -var xxx_messageInfo_ValidateRequest proto.InternalMessageInfo - -func (m *ValidateRequest) GetIdToken() string { - if m != nil { - return m.IdToken - } - return "" -} - -type ValidateReply struct { - IsValid bool `protobuf:"varint,1,opt,name=is_valid,json=isValid,proto3" json:"is_valid,omitempty"` - XXX_NoUnkeyedLiteral struct{} `json:"-"` - XXX_unrecognized []byte `json:"-"` - XXX_sizecache int32 `json:"-"` -} - -func (m *ValidateReply) Reset() { *m = ValidateReply{} } -func (m *ValidateReply) String() string { return proto.CompactTextString(m) } -func (*ValidateReply) ProtoMessage() {} -func (*ValidateReply) Descriptor() ([]byte, []int) { - return fileDescriptor_authenticate_d9796afa57ba1f78, []int{2} -} -func (m *ValidateReply) XXX_Unmarshal(b []byte) error { - return xxx_messageInfo_ValidateReply.Unmarshal(m, b) -} -func (m *ValidateReply) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - return xxx_messageInfo_ValidateReply.Marshal(b, m, deterministic) -} -func (dst *ValidateReply) XXX_Merge(src proto.Message) { - xxx_messageInfo_ValidateReply.Merge(dst, src) -} -func (m *ValidateReply) XXX_Size() int { - return xxx_messageInfo_ValidateReply.Size(m) -} -func (m *ValidateReply) XXX_DiscardUnknown() { - xxx_messageInfo_ValidateReply.DiscardUnknown(m) -} - -var xxx_messageInfo_ValidateReply proto.InternalMessageInfo - -func (m *ValidateReply) GetIsValid() bool { - if m != nil { - return m.IsValid - } - return false -} - -type Session struct { - AccessToken string `protobuf:"bytes,1,opt,name=access_token,json=accessToken,proto3" json:"access_token,omitempty"` - RefreshToken string `protobuf:"bytes,2,opt,name=refresh_token,json=refreshToken,proto3" json:"refresh_token,omitempty"` - IdToken string `protobuf:"bytes,3,opt,name=id_token,json=idToken,proto3" json:"id_token,omitempty"` - User string `protobuf:"bytes,4,opt,name=user,proto3" json:"user,omitempty"` - Email string `protobuf:"bytes,5,opt,name=email,proto3" json:"email,omitempty"` - Groups []string `protobuf:"bytes,6,rep,name=groups,proto3" json:"groups,omitempty"` - RefreshDeadline *timestamp.Timestamp `protobuf:"bytes,7,opt,name=refresh_deadline,json=refreshDeadline,proto3" json:"refresh_deadline,omitempty"` - XXX_NoUnkeyedLiteral struct{} `json:"-"` - XXX_unrecognized []byte `json:"-"` - XXX_sizecache int32 `json:"-"` -} - -func (m *Session) Reset() { *m = Session{} } -func (m *Session) String() string { return proto.CompactTextString(m) } -func (*Session) ProtoMessage() {} -func (*Session) Descriptor() ([]byte, []int) { - return fileDescriptor_authenticate_d9796afa57ba1f78, []int{3} -} -func (m *Session) XXX_Unmarshal(b []byte) error { - return xxx_messageInfo_Session.Unmarshal(m, b) -} -func (m *Session) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - return xxx_messageInfo_Session.Marshal(b, m, deterministic) -} -func (dst *Session) XXX_Merge(src proto.Message) { - xxx_messageInfo_Session.Merge(dst, src) -} -func (m *Session) XXX_Size() int { - return xxx_messageInfo_Session.Size(m) -} -func (m *Session) XXX_DiscardUnknown() { - xxx_messageInfo_Session.DiscardUnknown(m) -} - -var xxx_messageInfo_Session proto.InternalMessageInfo - -func (m *Session) GetAccessToken() string { - if m != nil { - return m.AccessToken - } - return "" -} - -func (m *Session) GetRefreshToken() string { - if m != nil { - return m.RefreshToken - } - return "" -} - -func (m *Session) GetIdToken() string { - if m != nil { - return m.IdToken - } - return "" -} - -func (m *Session) GetUser() string { - if m != nil { - return m.User - } - return "" -} - -func (m *Session) GetEmail() string { - if m != nil { - return m.Email - } - return "" -} - -func (m *Session) GetGroups() []string { - if m != nil { - return m.Groups - } - return nil -} - -func (m *Session) GetRefreshDeadline() *timestamp.Timestamp { - if m != nil { - return m.RefreshDeadline - } - return nil -} - -func init() { - proto.RegisterType((*AuthenticateRequest)(nil), "authenticate.AuthenticateRequest") - proto.RegisterType((*ValidateRequest)(nil), "authenticate.ValidateRequest") - proto.RegisterType((*ValidateReply)(nil), "authenticate.ValidateReply") - proto.RegisterType((*Session)(nil), "authenticate.Session") -} - -// Reference imports to suppress errors if they are not otherwise used. -var _ context.Context -var _ grpc.ClientConn - -// This is a compile-time assertion to ensure that this generated file -// is compatible with the grpc package it is being compiled against. -const _ = grpc.SupportPackageIsVersion4 - -// AuthenticatorClient is the client API for Authenticator service. -// -// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream. -type AuthenticatorClient interface { - Authenticate(ctx context.Context, in *AuthenticateRequest, opts ...grpc.CallOption) (*Session, error) - Validate(ctx context.Context, in *ValidateRequest, opts ...grpc.CallOption) (*ValidateReply, error) - Refresh(ctx context.Context, in *Session, opts ...grpc.CallOption) (*Session, error) -} - -type authenticatorClient struct { - cc *grpc.ClientConn -} - -func NewAuthenticatorClient(cc *grpc.ClientConn) AuthenticatorClient { - return &authenticatorClient{cc} -} - -func (c *authenticatorClient) Authenticate(ctx context.Context, in *AuthenticateRequest, opts ...grpc.CallOption) (*Session, error) { - out := new(Session) - err := c.cc.Invoke(ctx, "/authenticate.Authenticator/Authenticate", in, out, opts...) - if err != nil { - return nil, err - } - return out, nil -} - -func (c *authenticatorClient) Validate(ctx context.Context, in *ValidateRequest, opts ...grpc.CallOption) (*ValidateReply, error) { - out := new(ValidateReply) - err := c.cc.Invoke(ctx, "/authenticate.Authenticator/Validate", in, out, opts...) - if err != nil { - return nil, err - } - return out, nil -} - -func (c *authenticatorClient) Refresh(ctx context.Context, in *Session, opts ...grpc.CallOption) (*Session, error) { - out := new(Session) - err := c.cc.Invoke(ctx, "/authenticate.Authenticator/Refresh", in, out, opts...) - if err != nil { - return nil, err - } - return out, nil -} - -// AuthenticatorServer is the server API for Authenticator service. -type AuthenticatorServer interface { - Authenticate(context.Context, *AuthenticateRequest) (*Session, error) - Validate(context.Context, *ValidateRequest) (*ValidateReply, error) - Refresh(context.Context, *Session) (*Session, error) -} - -func RegisterAuthenticatorServer(s *grpc.Server, srv AuthenticatorServer) { - s.RegisterService(&_Authenticator_serviceDesc, srv) -} - -func _Authenticator_Authenticate_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(AuthenticateRequest) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(AuthenticatorServer).Authenticate(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: "/authenticate.Authenticator/Authenticate", - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(AuthenticatorServer).Authenticate(ctx, req.(*AuthenticateRequest)) - } - return interceptor(ctx, in, info, handler) -} - -func _Authenticator_Validate_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(ValidateRequest) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(AuthenticatorServer).Validate(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: "/authenticate.Authenticator/Validate", - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(AuthenticatorServer).Validate(ctx, req.(*ValidateRequest)) - } - return interceptor(ctx, in, info, handler) -} - -func _Authenticator_Refresh_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(Session) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(AuthenticatorServer).Refresh(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: "/authenticate.Authenticator/Refresh", - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(AuthenticatorServer).Refresh(ctx, req.(*Session)) - } - return interceptor(ctx, in, info, handler) -} - -var _Authenticator_serviceDesc = grpc.ServiceDesc{ - ServiceName: "authenticate.Authenticator", - HandlerType: (*AuthenticatorServer)(nil), - Methods: []grpc.MethodDesc{ - { - MethodName: "Authenticate", - Handler: _Authenticator_Authenticate_Handler, - }, - { - MethodName: "Validate", - Handler: _Authenticator_Validate_Handler, - }, - { - MethodName: "Refresh", - Handler: _Authenticator_Refresh_Handler, - }, - }, - Streams: []grpc.StreamDesc{}, - Metadata: "authenticate.proto", -} - -func init() { proto.RegisterFile("authenticate.proto", fileDescriptor_authenticate_d9796afa57ba1f78) } - -var fileDescriptor_authenticate_d9796afa57ba1f78 = []byte{ - // 354 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x74, 0x91, 0x4d, 0x4f, 0xb3, 0x40, - 0x14, 0x85, 0xcb, 0xdb, 0x0f, 0xda, 0x5b, 0x9a, 0xbe, 0xb9, 0x7e, 0x04, 0x31, 0xc6, 0x16, 0x37, - 0xd5, 0x18, 0x9a, 0xd4, 0x95, 0x4b, 0x13, 0x4d, 0x8c, 0x4b, 0x6c, 0xdc, 0x36, 0x14, 0x6e, 0xdb, - 0x89, 0x94, 0x41, 0x66, 0x30, 0xe9, 0xbf, 0xf5, 0x4f, 0xb8, 0x37, 0x0c, 0x10, 0xc1, 0xb4, 0x3b, - 0xee, 0x99, 0xe7, 0x0e, 0x67, 0xce, 0x01, 0xf4, 0x52, 0xb9, 0xa1, 0x48, 0x32, 0xdf, 0x93, 0xe4, - 0xc4, 0x09, 0x97, 0x1c, 0x8d, 0xaa, 0x66, 0x5d, 0xae, 0x39, 0x5f, 0x87, 0x34, 0x55, 0x67, 0xcb, - 0x74, 0x35, 0x95, 0x6c, 0x4b, 0x42, 0x7a, 0xdb, 0x38, 0xc7, 0xed, 0x6b, 0x38, 0x7a, 0xa8, 0x2c, - 0xb8, 0xf4, 0x91, 0x92, 0x90, 0x88, 0xd0, 0xf2, 0x79, 0x40, 0xa6, 0x36, 0xd2, 0x26, 0x3d, 0x57, - 0x7d, 0xdb, 0xb7, 0x30, 0x7c, 0xf3, 0x42, 0x16, 0x54, 0xb0, 0x33, 0xe8, 0xb2, 0x60, 0x21, 0xf9, - 0x3b, 0x45, 0x05, 0xaa, 0xb3, 0x60, 0x9e, 0x8d, 0xf6, 0x0d, 0x0c, 0x7e, 0xe9, 0x38, 0xdc, 0x29, - 0x56, 0x2c, 0x3e, 0x33, 0x4d, 0xb1, 0x5d, 0x57, 0x67, 0x42, 0x21, 0xf6, 0xb7, 0x06, 0xfa, 0x2b, - 0x09, 0xc1, 0x78, 0x84, 0x63, 0x30, 0x3c, 0xdf, 0x27, 0x21, 0x6a, 0xd7, 0xf6, 0x73, 0x4d, 0x5d, - 0x8d, 0x57, 0x30, 0x48, 0x68, 0x95, 0x90, 0xd8, 0x14, 0xcc, 0x3f, 0xc5, 0x18, 0x85, 0x98, 0x43, - 0x55, 0x6b, 0xcd, 0x9a, 0xb5, 0xec, 0x71, 0xa9, 0xa0, 0xc4, 0x6c, 0xe5, 0x8f, 0xcb, 0xbe, 0xf1, - 0x18, 0xda, 0xb4, 0xf5, 0x58, 0x68, 0xb6, 0x95, 0x98, 0x0f, 0x78, 0x0a, 0x9d, 0x75, 0xc2, 0xd3, - 0x58, 0x98, 0x9d, 0x51, 0x73, 0xd2, 0x73, 0x8b, 0x09, 0x9f, 0xe0, 0x7f, 0xe9, 0x20, 0x20, 0x2f, - 0x08, 0x59, 0x44, 0xa6, 0x3e, 0xd2, 0x26, 0xfd, 0x99, 0xe5, 0xe4, 0x89, 0x3b, 0x65, 0xe2, 0xce, - 0xbc, 0x4c, 0xdc, 0x1d, 0x16, 0x3b, 0x8f, 0xc5, 0xca, 0xec, 0x4b, 0x83, 0x41, 0x25, 0x7d, 0x9e, - 0xe0, 0x0b, 0x18, 0xd5, 0x3a, 0x70, 0xec, 0xd4, 0x2a, 0xde, 0x53, 0x95, 0x75, 0x52, 0x47, 0x8a, - 0x1c, 0xed, 0x06, 0x3e, 0x43, 0xb7, 0x6c, 0x00, 0x2f, 0xea, 0xd0, 0x9f, 0x1e, 0xad, 0xf3, 0x43, - 0xc7, 0x71, 0xb8, 0xb3, 0x1b, 0x78, 0x0f, 0xba, 0x9b, 0x5b, 0xc7, 0xfd, 0x7f, 0x3b, 0x68, 0x62, - 0xd9, 0x51, 0x39, 0xdc, 0xfd, 0x04, 0x00, 0x00, 0xff, 0xff, 0xdc, 0x47, 0xff, 0x7e, 0xab, 0x02, - 0x00, 0x00, -} diff --git a/proto/authenticate/authenticate.proto b/proto/authenticate/authenticate.proto deleted file mode 100644 index 7d6826cb5..000000000 --- a/proto/authenticate/authenticate.proto +++ /dev/null @@ -1,26 +0,0 @@ -syntax = "proto3"; -import "google/protobuf/timestamp.proto"; - -package authenticate; - -service Authenticator { - rpc Authenticate(AuthenticateRequest) returns (Session) {} - rpc Validate(ValidateRequest) returns (ValidateReply) {} - rpc Refresh(Session) returns (Session) {} -} - -message AuthenticateRequest { string code = 1; } - -message ValidateRequest { string id_token = 1; } - -message ValidateReply { bool is_valid = 1; } - -message Session { - string access_token = 1; - string refresh_token = 2; - string id_token = 3; - string user = 4; - string email = 5; - repeated string groups = 6; - google.protobuf.Timestamp refresh_deadline = 7; -} diff --git a/proto/authenticate/convert.go b/proto/authenticate/convert.go deleted file mode 100644 index 621d1c589..000000000 --- a/proto/authenticate/convert.go +++ /dev/null @@ -1,49 +0,0 @@ -package authenticate - -import ( - fmt "fmt" - - "github.com/golang/protobuf/ptypes" - "github.com/pomerium/pomerium/internal/sessions" -) - -// SessionFromProto converts a converts a protocol buffer session into a pomerium session state. -func SessionFromProto(p *Session) (*sessions.SessionState, error) { - if p == nil { - return nil, fmt.Errorf("proto/authenticate: SessionFromProto session cannot be nil") - } - - refreshDeadline, err := ptypes.Timestamp(p.RefreshDeadline) - if err != nil { - return nil, fmt.Errorf("proto/authenticate: couldn't parse refresh deadline %v", err) - } - return &sessions.SessionState{ - AccessToken: p.AccessToken, - RefreshToken: p.RefreshToken, - IDToken: p.IdToken, - Email: p.Email, - User: p.User, - Groups: p.Groups, - RefreshDeadline: refreshDeadline, - }, nil -} - -// ProtoFromSession converts a pomerium user session into a protocol buffer struct. -func ProtoFromSession(s *sessions.SessionState) (*Session, error) { - if s == nil { - return nil, fmt.Errorf("proto/authenticate: ProtoFromSession session cannot be nil") - } - refreshDeadline, err := ptypes.TimestampProto(s.RefreshDeadline) - if err != nil { - return nil, fmt.Errorf("proto/authenticate: couldn't parse refresh deadline %v", err) - } - return &Session{ - AccessToken: s.AccessToken, - RefreshToken: s.RefreshToken, - IdToken: s.IDToken, - Email: s.Email, - User: s.User, - Groups: s.Groups, - RefreshDeadline: refreshDeadline, - }, nil -} diff --git a/proto/authenticate/mock_authenticate/mock_authenticate.go b/proto/authenticate/mock_authenticate/mock_authenticate.go deleted file mode 100644 index 6315b0b33..000000000 --- a/proto/authenticate/mock_authenticate/mock_authenticate.go +++ /dev/null @@ -1,165 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: proto/authenticate/authenticate.pb.go - -// Package mock_authenticate is a generated GoMock package. -package mock_authenticate - -import ( - "context" - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - "github.com/pomerium/pomerium/proto/authenticate" - grpc "google.golang.org/grpc" -) - -// MockAuthenticatorClient is a mock of AuthenticatorClient interface -type MockAuthenticatorClient struct { - ctrl *gomock.Controller - recorder *MockAuthenticatorClientMockRecorder -} - -// MockAuthenticatorClientMockRecorder is the mock recorder for MockAuthenticatorClient -type MockAuthenticatorClientMockRecorder struct { - mock *MockAuthenticatorClient -} - -// NewMockAuthenticatorClient creates a new mock instance -func NewMockAuthenticatorClient(ctrl *gomock.Controller) *MockAuthenticatorClient { - mock := &MockAuthenticatorClient{ctrl: ctrl} - mock.recorder = &MockAuthenticatorClientMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use -func (m *MockAuthenticatorClient) EXPECT() *MockAuthenticatorClientMockRecorder { - return m.recorder -} - -// Authenticate mocks base method -func (m *MockAuthenticatorClient) Authenticate(ctx context.Context, in *authenticate.AuthenticateRequest, opts ...grpc.CallOption) (*authenticate.Session, error) { - m.ctrl.T.Helper() - varargs := []interface{}{ctx, in} - for _, a := range opts { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "Authenticate", varargs...) - ret0, _ := ret[0].(*authenticate.Session) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Authenticate indicates an expected call of Authenticate -func (mr *MockAuthenticatorClientMockRecorder) Authenticate(ctx, in interface{}, opts ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{ctx, in}, opts...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Authenticate", reflect.TypeOf((*MockAuthenticatorClient)(nil).Authenticate), varargs...) -} - -// Validate mocks base method -func (m *MockAuthenticatorClient) Validate(ctx context.Context, in *authenticate.ValidateRequest, opts ...grpc.CallOption) (*authenticate.ValidateReply, error) { - m.ctrl.T.Helper() - varargs := []interface{}{ctx, in} - for _, a := range opts { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "Validate", varargs...) - ret0, _ := ret[0].(*authenticate.ValidateReply) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Validate indicates an expected call of Validate -func (mr *MockAuthenticatorClientMockRecorder) Validate(ctx, in interface{}, opts ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{ctx, in}, opts...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Validate", reflect.TypeOf((*MockAuthenticatorClient)(nil).Validate), varargs...) -} - -// Refresh mocks base method -func (m *MockAuthenticatorClient) Refresh(ctx context.Context, in *authenticate.Session, opts ...grpc.CallOption) (*authenticate.Session, error) { - m.ctrl.T.Helper() - varargs := []interface{}{ctx, in} - for _, a := range opts { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "Refresh", varargs...) - ret0, _ := ret[0].(*authenticate.Session) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Refresh indicates an expected call of Refresh -func (mr *MockAuthenticatorClientMockRecorder) Refresh(ctx, in interface{}, opts ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{ctx, in}, opts...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Refresh", reflect.TypeOf((*MockAuthenticatorClient)(nil).Refresh), varargs...) -} - -// MockAuthenticatorServer is a mock of AuthenticatorServer interface -type MockAuthenticatorServer struct { - ctrl *gomock.Controller - recorder *MockAuthenticatorServerMockRecorder -} - -// MockAuthenticatorServerMockRecorder is the mock recorder for MockAuthenticatorServer -type MockAuthenticatorServerMockRecorder struct { - mock *MockAuthenticatorServer -} - -// NewMockAuthenticatorServer creates a new mock instance -func NewMockAuthenticatorServer(ctrl *gomock.Controller) *MockAuthenticatorServer { - mock := &MockAuthenticatorServer{ctrl: ctrl} - mock.recorder = &MockAuthenticatorServerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use -func (m *MockAuthenticatorServer) EXPECT() *MockAuthenticatorServerMockRecorder { - return m.recorder -} - -// Authenticate mocks base method -func (m *MockAuthenticatorServer) Authenticate(arg0 context.Context, arg1 *authenticate.AuthenticateRequest) (*authenticate.Session, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Authenticate", arg0, arg1) - ret0, _ := ret[0].(*authenticate.Session) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Authenticate indicates an expected call of Authenticate -func (mr *MockAuthenticatorServerMockRecorder) Authenticate(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Authenticate", reflect.TypeOf((*MockAuthenticatorServer)(nil).Authenticate), arg0, arg1) -} - -// Validate mocks base method -func (m *MockAuthenticatorServer) Validate(arg0 context.Context, arg1 *authenticate.ValidateRequest) (*authenticate.ValidateReply, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Validate", arg0, arg1) - ret0, _ := ret[0].(*authenticate.ValidateReply) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Validate indicates an expected call of Validate -func (mr *MockAuthenticatorServerMockRecorder) Validate(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Validate", reflect.TypeOf((*MockAuthenticatorServer)(nil).Validate), arg0, arg1) -} - -// Refresh mocks base method -func (m *MockAuthenticatorServer) Refresh(arg0 context.Context, arg1 *authenticate.Session) (*authenticate.Session, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Refresh", arg0, arg1) - ret0, _ := ret[0].(*authenticate.Session) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Refresh indicates an expected call of Refresh -func (mr *MockAuthenticatorServerMockRecorder) Refresh(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Refresh", reflect.TypeOf((*MockAuthenticatorServer)(nil).Refresh), arg0, arg1) -} diff --git a/proxy/clients/authenticate_client.go b/proxy/clients/authenticate_client.go deleted file mode 100644 index 572828526..000000000 --- a/proxy/clients/authenticate_client.go +++ /dev/null @@ -1,117 +0,0 @@ -package clients // import "github.com/pomerium/pomerium/proxy/clients" - -import ( - "context" - "errors" - - "github.com/pomerium/pomerium/internal/sessions" - "github.com/pomerium/pomerium/internal/telemetry/trace" - pb "github.com/pomerium/pomerium/proto/authenticate" - - "google.golang.org/grpc" -) - -// Authenticator provides the authenticate service interface -type Authenticator interface { - // Redeem takes a code and returns a validated session or an error - Redeem(context.Context, string) (*sessions.SessionState, error) - // Refresh attempts to refresh a valid session with a refresh token. Returns a refreshed session. - Refresh(context.Context, *sessions.SessionState) (*sessions.SessionState, error) - // Validate evaluates a given oidc id_token for validity. Returns validity and any error. - Validate(context.Context, string) (bool, error) - // Close closes the authenticator connection if any. - Close() error -} - -// NewAuthenticateClient returns a new authenticate service client. Presently, -// only gRPC is supported and is always returned so name is ignored. -func NewAuthenticateClient(name string, opts *Options) (a Authenticator, err error) { - return NewGRPCAuthenticateClient(opts) -} - -// NewGRPCAuthenticateClient returns a new authenticate service client. -func NewGRPCAuthenticateClient(opts *Options) (p *AuthenticateGRPC, err error) { - conn, err := NewGRPCClientConn(opts) - if err != nil { - return nil, err - } - authClient := pb.NewAuthenticatorClient(conn) - return &AuthenticateGRPC{Conn: conn, client: authClient}, nil -} - -// AuthenticateGRPC is a gRPC implementation of an authenticator (authenticate client) -type AuthenticateGRPC struct { - Conn *grpc.ClientConn - client pb.AuthenticatorClient -} - -// Redeem makes an RPC call to the authenticate service to creates a session state -// from an encrypted code provided as a result of an oauth2 callback process. -func (a *AuthenticateGRPC) Redeem(ctx context.Context, code string) (*sessions.SessionState, error) { - ctx, span := trace.StartSpan(ctx, "proxy.client.grpc.Redeem") - defer span.End() - - if code == "" { - return nil, errors.New("missing code") - } - protoSession, err := a.client.Authenticate(ctx, &pb.AuthenticateRequest{Code: code}) - if err != nil { - return nil, err - } - session, err := pb.SessionFromProto(protoSession) - if err != nil { - return nil, err - } - return session, nil -} - -// Refresh makes an RPC call to the authenticate service to attempt to refresh the -// user's session. Requires a valid refresh token. Will return an error if the identity provider -// has revoked the session or if the refresh token is no longer valid in this context. -func (a *AuthenticateGRPC) Refresh(ctx context.Context, s *sessions.SessionState) (*sessions.SessionState, error) { - ctx, span := trace.StartSpan(ctx, "proxy.client.grpc.Refresh") - defer span.End() - - if s.RefreshToken == "" { - return nil, errors.New("missing refresh token") - } - req, err := pb.ProtoFromSession(s) - if err != nil { - return nil, err - } - - // todo(bdd): add grpc specific timeouts to main options - // todo(bdd): handle request id (metadata!?) in grpc receiver and add to ctx logger - reply, err := a.client.Refresh(ctx, req) - if err != nil { - return nil, err - } - newSession, err := pb.SessionFromProto(reply) - if err != nil { - return nil, err - } - return newSession, nil -} - -// Validate makes an RPC call to the authenticate service to validate the JWT id token; -// does NOT do nonce or revokation validation. -// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation -func (a *AuthenticateGRPC) Validate(ctx context.Context, idToken string) (bool, error) { - ctx, span := trace.StartSpan(ctx, "proxy.client.grpc.Validate") - defer span.End() - - if idToken == "" { - return false, errors.New("missing id token") - } - - r, err := a.client.Validate(ctx, &pb.ValidateRequest{IdToken: idToken}) - if err != nil { - return false, err - } - return r.IsValid, nil -} - -// Close tears down the ClientConn and all underlying connections. -func (a *AuthenticateGRPC) Close() error { - return a.Conn.Close() -} diff --git a/proxy/clients/authenticate_client_test.go b/proxy/clients/authenticate_client_test.go deleted file mode 100644 index d28b6e8f4..000000000 --- a/proxy/clients/authenticate_client_test.go +++ /dev/null @@ -1,242 +0,0 @@ -package clients // import "github.com/pomerium/pomerium/proxy/clients" - -import ( - "context" - "fmt" - "net/url" - "reflect" - "strings" - "testing" - "time" - - "github.com/golang/mock/gomock" - "github.com/golang/protobuf/proto" - "github.com/golang/protobuf/ptypes" - "github.com/pomerium/pomerium/internal/sessions" - pb "github.com/pomerium/pomerium/proto/authenticate" - mock "github.com/pomerium/pomerium/proto/authenticate/mock_authenticate" -) - -func TestNew(t *testing.T) { - tests := []struct { - name string - serviceName string - opts *Options - wantErr bool - }{ - {"grpc good", "grpc", &Options{Addr: &url.URL{Scheme: "https", Host: "localhost.example"}, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, SharedSecret: "secret"}, false}, - {"grpc missing shared secret", "grpc", &Options{Addr: &url.URL{Scheme: "https", Host: "localhost.example"}, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, SharedSecret: ""}, true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - _, err := NewAuthenticateClient(tt.serviceName, tt.opts) - if (err != nil) != tt.wantErr { - t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) - return - } - }) - } -} - -var fixedDate = time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC) - -// rpcMsg implements the gomock.Matcher interface -type rpcMsg struct { - msg proto.Message -} - -func (r *rpcMsg) Matches(msg interface{}) bool { - m, ok := msg.(proto.Message) - if !ok { - return false - } - return proto.Equal(m, r.msg) -} - -func (r *rpcMsg) String() string { - return fmt.Sprintf("is %s", r.msg) -} - -func TestProxy_Redeem(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - mockAuthenticateClient := mock.NewMockAuthenticatorClient(ctrl) - req := &pb.AuthenticateRequest{Code: "unit_test"} - mockExpire, err := ptypes.TimestampProto(fixedDate) - if err != nil { - t.Fatalf("%v failed converting timestamp", err) - } - - mockAuthenticateClient.EXPECT().Authenticate( - gomock.Any(), - &rpcMsg{msg: req}, - ).Return(&pb.Session{ - AccessToken: "mocked access token", - RefreshToken: "mocked refresh token", - IdToken: "mocked id token", - User: "user1", - Email: "test@email.com", - RefreshDeadline: mockExpire, - }, nil) - tests := []struct { - name string - idToken string - want *sessions.SessionState - wantErr bool - }{ - {"good", "unit_test", &sessions.SessionState{ - AccessToken: "mocked access token", - RefreshToken: "mocked refresh token", - IDToken: "mocked id token", - User: "user1", - Email: "test@email.com", - RefreshDeadline: (fixedDate), - }, false}, - {"empty code", "", nil, true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - a := AuthenticateGRPC{client: mockAuthenticateClient} - got, err := a.Redeem(context.Background(), tt.idToken) - if (err != nil) != tt.wantErr { - t.Errorf("Proxy.AuthenticateValidate() error = %v,\n wantErr %v", err, tt.wantErr) - return - } - if got != nil { - if got.AccessToken != "mocked access token" { - t.Errorf("authenticate: invalid access token") - } - if got.RefreshToken != "mocked refresh token" { - t.Errorf("authenticate: invalid refresh token") - } - if got.IDToken != "mocked id token" { - t.Errorf("authenticate: invalid id token") - } - if got.User != "user1" { - t.Errorf("authenticate: invalid user") - } - if got.Email != "test@email.com" { - t.Errorf("authenticate: invalid email") - } - } - }) - } -} -func TestProxy_AuthenticateValidate(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - mockAuthenticateClient := mock.NewMockAuthenticatorClient(ctrl) - req := &pb.ValidateRequest{IdToken: "unit_test"} - - mockAuthenticateClient.EXPECT().Validate( - gomock.Any(), - &rpcMsg{msg: req}, - ).Return(&pb.ValidateReply{IsValid: false}, nil) - - ac := mockAuthenticateClient - tests := []struct { - name string - idToken string - want bool - wantErr bool - }{ - {"good", "unit_test", false, false}, - {"empty id token", "", false, true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - a := AuthenticateGRPC{client: ac} - - got, err := a.Validate(context.Background(), tt.idToken) - if (err != nil) != tt.wantErr { - t.Errorf("Proxy.AuthenticateValidate() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("Proxy.AuthenticateValidate() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestProxy_AuthenticateRefresh(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - mockRefreshClient := mock.NewMockAuthenticatorClient(ctrl) - mockExpire, _ := ptypes.TimestampProto(fixedDate) - - mockRefreshClient.EXPECT().Refresh( - gomock.Any(), - gomock.Not(sessions.SessionState{RefreshToken: "fail"}), - ).Return(&pb.Session{ - AccessToken: "new access token", - RefreshDeadline: mockExpire, - }, nil).AnyTimes() - - tests := []struct { - name string - session *sessions.SessionState - want *sessions.SessionState - wantErr bool - }{ - {"good", - &sessions.SessionState{RefreshToken: "unit_test"}, - &sessions.SessionState{ - AccessToken: "new access token", - RefreshDeadline: fixedDate, - }, false}, - {"empty refresh token", &sessions.SessionState{RefreshToken: ""}, nil, true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - a := AuthenticateGRPC{client: mockRefreshClient} - - got, err := a.Refresh(context.Background(), tt.session) - if (err != nil) != tt.wantErr { - t.Errorf("Proxy.AuthenticateRefresh() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("Proxy.AuthenticateRefresh() got = \n%#v\nwant \n%#v", got, tt.want) - } - }) - } -} - -func TestNewGRPC(t *testing.T) { - tests := []struct { - name string - opts *Options - wantErr bool - wantErrStr string - wantTarget string - }{ - {"no shared secret", &Options{}, true, "proxy/authenticator: grpc client requires shared secret", ""}, - {"empty connection", &Options{Addr: nil, SharedSecret: "shh"}, true, "proxy/authenticator: connection address required", ""}, - {"both internal and addr empty", &Options{Addr: nil, InternalAddr: nil, SharedSecret: "shh"}, true, "proxy/authenticator: connection address required", ""}, - {"addr with port", &Options{Addr: &url.URL{Scheme: "https", Host: "localhost.example:8443"}, SharedSecret: "shh"}, false, "", "localhost.example:8443"}, - {"addr without port", &Options{Addr: &url.URL{Scheme: "https", Host: "localhost.example"}, SharedSecret: "shh"}, false, "", "localhost.example:443"}, - {"internal addr with port", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example:8443"}, SharedSecret: "shh"}, false, "", "localhost.example:8443"}, - {"internal addr without port", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, SharedSecret: "shh"}, false, "", "localhost.example:443"}, - {"cert override", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, OverrideCertificateName: "*.local", SharedSecret: "shh"}, false, "", "localhost.example:443"}, - {"custom ca", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, OverrideCertificateName: "*.local", SharedSecret: "shh", CA: "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSURFVENDQWZrQ0ZBWHhneFg5K0hjWlBVVVBEK0laV0NGNUEvVTdNQTBHQ1NxR1NJYjNEUUVCQ3dVQU1FVXgKQ3pBSkJnTlZCQVlUQWtGVk1STXdFUVlEVlFRSURBcFRiMjFsTFZOMFlYUmxNU0V3SHdZRFZRUUtEQmhKYm5SbApjbTVsZENCWGFXUm5hWFJ6SUZCMGVTQk1kR1F3SGhjTk1Ua3dNakk0TVRnMU1EQTNXaGNOTWprd01qSTFNVGcxCk1EQTNXakJGTVFzd0NRWURWUVFHRXdKQlZURVRNQkVHQTFVRUNBd0tVMjl0WlMxVGRHRjBaVEVoTUI4R0ExVUUKQ2d3WVNXNTBaWEp1WlhRZ1YybGtaMmwwY3lCUWRIa2dUSFJrTUlJQklqQU5CZ2txaGtpRzl3MEJBUUVGQUFPQwpBUThBTUlJQkNnS0NBUUVBOVRFMEFiaTdnMHhYeURkVUtEbDViNTBCT05ZVVVSc3F2THQrSWkwdlpjMzRRTHhOClJrT0hrOFZEVUgzcUt1N2UrNGVubUdLVVNUdzRPNFlkQktiSWRJTFpnb3o0YitNL3FVOG5adVpiN2pBVTdOYWkKajMzVDVrbXB3L2d4WHNNUzNzdUpXUE1EUDB3Z1BUZUVRK2J1bUxVWmpLdUVIaWNTL0l5dmtaVlBzRlE4NWlaUwpkNXE2a0ZGUUdjWnFXeFg0dlhDV25Sd3E3cHY3TThJd1RYc1pYSVRuNXB5Z3VTczNKb29GQkg5U3ZNTjRKU25GCmJMK0t6ekduMy9ScXFrTXpMN3FUdkMrNWxVT3UxUmNES21mZXBuVGVaN1IyVnJUQm42NndWMjVHRnBkSDIzN00KOXhJVkJrWEd1U2NvWHVPN1lDcWFrZkt6aXdoRTV4UmRaa3gweXdJREFRQUJNQTBHQ1NxR1NJYjNEUUVCQ3dVQQpBNElCQVFCaHRWUEI0OCs4eFZyVmRxM1BIY3k5QkxtVEtrRFl6N2Q0ODJzTG1HczBuVUdGSTFZUDdmaFJPV3ZxCktCTlpkNEI5MUpwU1NoRGUrMHpoNno4WG5Ha01mYnRSYWx0NHEwZ3lKdk9hUWhqQ3ZCcSswTFk5d2NLbXpFdnMKcTRiNUZ5NXNpRUZSekJLTmZtTGwxTTF2cW1hNmFCVnNYUUhPREdzYS83dE5MalZ2ay9PYm52cFg3UFhLa0E3cQpLMTQvV0tBRFBJWm9mb00xMzB4Q1RTYXVpeXROajlnWkx1WU9leEZhblVwNCt2MHBYWS81OFFSNTk2U0ROVTlKClJaeDhwTzBTaUYvZXkxVUZXbmpzdHBjbTQzTFVQKzFwU1hFeVhZOFJrRTI2QzNvdjNaTFNKc2pMbC90aXVqUlgKZUJPOWorWDdzS0R4amdtajBPbWdpVkpIM0YrUAotLS0tLUVORCBDRVJUSUZJQ0FURS0tLS0tCg=="}, false, "", "localhost.example:443"}, - {"bad ca encoding", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, OverrideCertificateName: "*.local", SharedSecret: "shh", CA: "^"}, true, "", "localhost.example:443"}, - {"custom ca file", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, OverrideCertificateName: "*.local", SharedSecret: "shh", CAFile: "testdata/example.crt"}, false, "", "localhost.example:443"}, - {"bad custom ca file", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, OverrideCertificateName: "*.local", SharedSecret: "shh", CAFile: "testdata/example.crt2"}, true, "", "localhost.example:443"}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := NewGRPCAuthenticateClient(tt.opts) - if (err != nil) != tt.wantErr { - t.Errorf("NewGRPCAuthenticateClient() error = %v, wantErr %v", err, tt.wantErr) - if !strings.EqualFold(err.Error(), tt.wantErrStr) { - t.Errorf("NewGRPCAuthenticateClient() error = %v did not contain wantErr %v", err, tt.wantErrStr) - } - } - if got != nil && got.Conn.Target() != tt.wantTarget { - t.Errorf("NewGRPCAuthenticateClient() target = %v expected %v", got.Conn.Target(), tt.wantTarget) - - } - }) - } -} diff --git a/proxy/clients/authorize_client.go b/proxy/clients/authorize_client.go index 229fa87b5..85470099d 100644 --- a/proxy/clients/authorize_client.go +++ b/proxy/clients/authorize_client.go @@ -15,9 +15,9 @@ import ( type Authorizer interface { // Authorize takes a route and user session and returns whether the // request is valid per access policy - Authorize(context.Context, string, *sessions.SessionState) (bool, error) + Authorize(context.Context, string, *sessions.State) (bool, error) // IsAdmin takes a session and returns whether the user is an administrator - IsAdmin(context.Context, *sessions.SessionState) (bool, error) + IsAdmin(context.Context, *sessions.State) (bool, error) // Close closes the auth connection if any. Close() error } @@ -46,7 +46,7 @@ type AuthorizeGRPC struct { // Authorize takes a route and user session and returns whether the // request is valid per access policy -func (a *AuthorizeGRPC) Authorize(ctx context.Context, route string, s *sessions.SessionState) (bool, error) { +func (a *AuthorizeGRPC) Authorize(ctx context.Context, route string, s *sessions.State) (bool, error) { ctx, span := trace.StartSpan(ctx, "proxy.client.grpc.Authorize") defer span.End() @@ -65,7 +65,7 @@ func (a *AuthorizeGRPC) Authorize(ctx context.Context, route string, s *sessions } // IsAdmin takes a session and returns whether the user is an administrator -func (a *AuthorizeGRPC) IsAdmin(ctx context.Context, s *sessions.SessionState) (bool, error) { +func (a *AuthorizeGRPC) IsAdmin(ctx context.Context, s *sessions.State) (bool, error) { ctx, span := trace.StartSpan(ctx, "proxy.client.grpc.IsAdmin") defer span.End() diff --git a/proxy/clients/authorize_client_test.go b/proxy/clients/authorize_client_test.go index 66b6ca349..ea9393ee7 100644 --- a/proxy/clients/authorize_client_test.go +++ b/proxy/clients/authorize_client_test.go @@ -2,6 +2,8 @@ package clients import ( "context" + "net/url" + "strings" "testing" "github.com/golang/mock/gomock" @@ -23,12 +25,12 @@ func TestAuthorizeGRPC_Authorize(t *testing.T) { tests := []struct { name string route string - s *sessions.SessionState + s *sessions.State want bool wantErr bool }{ - {"good", "hello.pomerium.io", &sessions.SessionState{User: "admin@pomerium.io", Email: "admin@pomerium.io"}, true, false}, - {"impersonate request", "hello.pomerium.io", &sessions.SessionState{User: "admin@pomerium.io", Email: "admin@pomerium.io", ImpersonateEmail: "other@other.example"}, true, false}, + {"good", "hello.pomerium.io", &sessions.State{User: "admin@pomerium.io", Email: "admin@pomerium.io"}, true, false}, + {"impersonate request", "hello.pomerium.io", &sessions.State{User: "admin@pomerium.io", Email: "admin@pomerium.io", ImpersonateEmail: "other@other.example"}, true, false}, {"session cannot be nil", "hello.pomerium.io", nil, false, true}, } for _, tt := range tests { @@ -56,11 +58,11 @@ func TestAuthorizeGRPC_IsAdmin(t *testing.T) { tests := []struct { name string - s *sessions.SessionState + s *sessions.State want bool wantErr bool }{ - {"good", &sessions.SessionState{User: "admin@pomerium.io", Email: "admin@pomerium.io"}, true, false}, + {"good", &sessions.State{User: "admin@pomerium.io", Email: "admin@pomerium.io"}, true, false}, {"session cannot be nil", nil, false, true}, } for _, tt := range tests { @@ -77,3 +79,41 @@ func TestAuthorizeGRPC_IsAdmin(t *testing.T) { }) } } + +func TestNewGRPC(t *testing.T) { + tests := []struct { + name string + opts *Options + wantErr bool + wantErrStr string + wantTarget string + }{ + {"no shared secret", &Options{}, true, "proxy/authenticator: grpc client requires shared secret", ""}, + {"empty connection", &Options{Addr: nil, SharedSecret: "shh"}, true, "proxy/authenticator: connection address required", ""}, + {"both internal and addr empty", &Options{Addr: nil, InternalAddr: nil, SharedSecret: "shh"}, true, "proxy/authenticator: connection address required", ""}, + {"addr with port", &Options{Addr: &url.URL{Scheme: "https", Host: "localhost.example:8443"}, SharedSecret: "shh"}, false, "", "localhost.example:8443"}, + {"addr without port", &Options{Addr: &url.URL{Scheme: "https", Host: "localhost.example"}, SharedSecret: "shh"}, false, "", "localhost.example:443"}, + {"internal addr with port", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example:8443"}, SharedSecret: "shh"}, false, "", "localhost.example:8443"}, + {"internal addr without port", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, SharedSecret: "shh"}, false, "", "localhost.example:443"}, + {"cert override", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, OverrideCertificateName: "*.local", SharedSecret: "shh"}, false, "", "localhost.example:443"}, + {"custom ca", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, OverrideCertificateName: "*.local", SharedSecret: "shh", CA: "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSURFVENDQWZrQ0ZBWHhneFg5K0hjWlBVVVBEK0laV0NGNUEvVTdNQTBHQ1NxR1NJYjNEUUVCQ3dVQU1FVXgKQ3pBSkJnTlZCQVlUQWtGVk1STXdFUVlEVlFRSURBcFRiMjFsTFZOMFlYUmxNU0V3SHdZRFZRUUtEQmhKYm5SbApjbTVsZENCWGFXUm5hWFJ6SUZCMGVTQk1kR1F3SGhjTk1Ua3dNakk0TVRnMU1EQTNXaGNOTWprd01qSTFNVGcxCk1EQTNXakJGTVFzd0NRWURWUVFHRXdKQlZURVRNQkVHQTFVRUNBd0tVMjl0WlMxVGRHRjBaVEVoTUI4R0ExVUUKQ2d3WVNXNTBaWEp1WlhRZ1YybGtaMmwwY3lCUWRIa2dUSFJrTUlJQklqQU5CZ2txaGtpRzl3MEJBUUVGQUFPQwpBUThBTUlJQkNnS0NBUUVBOVRFMEFiaTdnMHhYeURkVUtEbDViNTBCT05ZVVVSc3F2THQrSWkwdlpjMzRRTHhOClJrT0hrOFZEVUgzcUt1N2UrNGVubUdLVVNUdzRPNFlkQktiSWRJTFpnb3o0YitNL3FVOG5adVpiN2pBVTdOYWkKajMzVDVrbXB3L2d4WHNNUzNzdUpXUE1EUDB3Z1BUZUVRK2J1bUxVWmpLdUVIaWNTL0l5dmtaVlBzRlE4NWlaUwpkNXE2a0ZGUUdjWnFXeFg0dlhDV25Sd3E3cHY3TThJd1RYc1pYSVRuNXB5Z3VTczNKb29GQkg5U3ZNTjRKU25GCmJMK0t6ekduMy9ScXFrTXpMN3FUdkMrNWxVT3UxUmNES21mZXBuVGVaN1IyVnJUQm42NndWMjVHRnBkSDIzN00KOXhJVkJrWEd1U2NvWHVPN1lDcWFrZkt6aXdoRTV4UmRaa3gweXdJREFRQUJNQTBHQ1NxR1NJYjNEUUVCQ3dVQQpBNElCQVFCaHRWUEI0OCs4eFZyVmRxM1BIY3k5QkxtVEtrRFl6N2Q0ODJzTG1HczBuVUdGSTFZUDdmaFJPV3ZxCktCTlpkNEI5MUpwU1NoRGUrMHpoNno4WG5Ha01mYnRSYWx0NHEwZ3lKdk9hUWhqQ3ZCcSswTFk5d2NLbXpFdnMKcTRiNUZ5NXNpRUZSekJLTmZtTGwxTTF2cW1hNmFCVnNYUUhPREdzYS83dE5MalZ2ay9PYm52cFg3UFhLa0E3cQpLMTQvV0tBRFBJWm9mb00xMzB4Q1RTYXVpeXROajlnWkx1WU9leEZhblVwNCt2MHBYWS81OFFSNTk2U0ROVTlKClJaeDhwTzBTaUYvZXkxVUZXbmpzdHBjbTQzTFVQKzFwU1hFeVhZOFJrRTI2QzNvdjNaTFNKc2pMbC90aXVqUlgKZUJPOWorWDdzS0R4amdtajBPbWdpVkpIM0YrUAotLS0tLUVORCBDRVJUSUZJQ0FURS0tLS0tCg=="}, false, "", "localhost.example:443"}, + {"bad ca encoding", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, OverrideCertificateName: "*.local", SharedSecret: "shh", CA: "^"}, true, "", "localhost.example:443"}, + {"custom ca file", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, OverrideCertificateName: "*.local", SharedSecret: "shh", CAFile: "testdata/example.crt"}, false, "", "localhost.example:443"}, + {"bad custom ca file", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, OverrideCertificateName: "*.local", SharedSecret: "shh", CAFile: "testdata/example.crt2"}, true, "", "localhost.example:443"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewGRPCAuthorizeClient(tt.opts) + if (err != nil) != tt.wantErr { + t.Errorf("NewGRPCAuthorizeClient() error = %v, wantErr %v", err, tt.wantErr) + if !strings.EqualFold(err.Error(), tt.wantErrStr) { + t.Errorf("NewGRPCAuthorizeClient() error = %v did not contain wantErr %v", err, tt.wantErrStr) + } + } + if got != nil && got.Conn.Target() != tt.wantTarget { + t.Errorf("NewGRPCAuthorizeClient() target = %v expected %v", got.Conn.Target(), tt.wantTarget) + + } + }) + } +} diff --git a/proxy/clients/clients.go b/proxy/clients/clients.go index 52a8416e6..511d77383 100644 --- a/proxy/clients/clients.go +++ b/proxy/clients/clients.go @@ -15,6 +15,7 @@ import ( "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/middleware" "github.com/pomerium/pomerium/internal/telemetry/metrics" + "go.opencensus.io/plugin/ocgrpc" "google.golang.org/grpc" "google.golang.org/grpc/balancer/roundrobin" @@ -25,7 +26,7 @@ const defaultGRPCPort = 443 // Options contains options for connecting to a pomerium rpc service. type Options struct { - // Addr is the location of the authenticate service. e.g. "service.corp.example:8443" + // Addr is the location of the service. e.g. "service.corp.example:8443" Addr *url.URL // InternalAddr is the internal (behind the ingress) address to use when // making a connection. If empty, Addr is used. @@ -34,7 +35,7 @@ type Options struct { // returned certificates from the server. gRPC internals also use it to override the virtual // hosting name if it is set. OverrideCertificateName string - // Shared secret is used to authenticate a authenticate-client with a authenticate-server. + // Shared secret is used to mutually authenticate a client and server. SharedSecret string // CA specifies the base64 encoded TLS certificate authority to use. CA string diff --git a/proxy/clients/mock_clients.go b/proxy/clients/mock_clients.go index f4d03e7b8..9acac4594 100644 --- a/proxy/clients/mock_clients.go +++ b/proxy/clients/mock_clients.go @@ -6,35 +6,6 @@ import ( "github.com/pomerium/pomerium/internal/sessions" ) -// MockAuthenticate provides a mocked implementation of the authenticator interface. -type MockAuthenticate struct { - RedeemError error - RedeemResponse *sessions.SessionState - RefreshResponse *sessions.SessionState - RefreshError error - ValidateResponse bool - ValidateError error - CloseError error -} - -// Redeem is a mocked authenticator client function. -func (a MockAuthenticate) Redeem(ctx context.Context, code string) (*sessions.SessionState, error) { - return a.RedeemResponse, a.RedeemError -} - -// Refresh is a mocked authenticator client function. -func (a MockAuthenticate) Refresh(ctx context.Context, s *sessions.SessionState) (*sessions.SessionState, error) { - return a.RefreshResponse, a.RefreshError -} - -// Validate is a mocked authenticator client function. -func (a MockAuthenticate) Validate(ctx context.Context, idToken string) (bool, error) { - return a.ValidateResponse, a.ValidateError -} - -// Close is a mocked authenticator client function. -func (a MockAuthenticate) Close() error { return a.CloseError } - // MockAuthorize provides a mocked implementation of the authorizer interface. type MockAuthorize struct { AuthorizeResponse bool @@ -48,11 +19,11 @@ type MockAuthorize struct { func (a MockAuthorize) Close() error { return a.CloseError } // Authorize is a mocked authorizer client function. -func (a MockAuthorize) Authorize(ctx context.Context, route string, s *sessions.SessionState) (bool, error) { +func (a MockAuthorize) Authorize(ctx context.Context, route string, s *sessions.State) (bool, error) { return a.AuthorizeResponse, a.AuthorizeError } // IsAdmin is a mocked IsAdmin function. -func (a MockAuthorize) IsAdmin(ctx context.Context, s *sessions.SessionState) (bool, error) { +func (a MockAuthorize) IsAdmin(ctx context.Context, s *sessions.State) (bool, error) { return a.IsAdminResponse, a.IsAdminError } diff --git a/proxy/clients/mock_clients_test.go b/proxy/clients/mock_clients_test.go deleted file mode 100644 index 8cd5330de..000000000 --- a/proxy/clients/mock_clients_test.go +++ /dev/null @@ -1,57 +0,0 @@ -package clients - -import ( - "context" - "errors" - "reflect" - "testing" - - "github.com/pomerium/pomerium/internal/sessions" -) - -func TestMockAuthenticate(t *testing.T) { - // Absurd, but I caught a typo this way. - redeemResponse := &sessions.SessionState{ - AccessToken: "AccessToken", - RefreshToken: "RefreshToken", - } - ma := &MockAuthenticate{ - RedeemError: errors.New("redeem error"), - RedeemResponse: redeemResponse, - RefreshResponse: &sessions.SessionState{ - AccessToken: "AccessToken", - RefreshToken: "RefreshToken", - }, - RefreshError: errors.New("refresh error"), - ValidateResponse: true, - ValidateError: errors.New("validate error"), - CloseError: errors.New("close error"), - } - got, gotErr := ma.Redeem(context.Background(), "a") - if gotErr.Error() != "redeem error" { - t.Errorf("unexpected value for gotErr %s", gotErr) - } - if !reflect.DeepEqual(redeemResponse, got) { - t.Errorf("unexpected value for redeemResponse %s", got) - } - newSession, gotErr := ma.Refresh(context.Background(), nil) - if gotErr.Error() != "refresh error" { - t.Errorf("unexpected value for gotErr %s", gotErr) - } - if !reflect.DeepEqual(newSession, redeemResponse) { - t.Errorf("unexpected value for newSession %s", newSession) - } - - ok, gotErr := ma.Validate(context.Background(), "a") - if !ok { - t.Errorf("unexpected value for ok : %t", ok) - } - if gotErr.Error() != "validate error" { - t.Errorf("unexpected value for gotErr %s", gotErr) - } - gotErr = ma.Close() - if gotErr.Error() != "close error" { - t.Errorf("unexpected value for ma.CloseError %s", gotErr) - } - -} diff --git a/proxy/handlers.go b/proxy/handlers.go index 791770d0e..1108dfd97 100644 --- a/proxy/handlers.go +++ b/proxy/handlers.go @@ -15,6 +15,7 @@ import ( "github.com/pomerium/pomerium/internal/middleware" "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/templates" + "github.com/pomerium/pomerium/internal/urlutil" ) // StateParameter holds the redirect id along with the session id. @@ -36,9 +37,9 @@ func (p *Proxy) Handler() http.Handler { mux.HandleFunc("/.pomerium", p.UserDashboard) mux.HandleFunc("/.pomerium/impersonate", p.Impersonate) // POST mux.HandleFunc("/.pomerium/sign_out", p.SignOut) - // handlers handlers with validation - mux.Handle("/.pomerium/callback", validate.ThenFunc(p.OAuthCallback)) - mux.Handle("/.pomerium/refresh", validate.ThenFunc(p.Refresh)) + // handlers with validation + mux.Handle("/.pomerium/callback", validate.ThenFunc(p.AuthenticateCallback)) + mux.Handle("/.pomerium/refresh", validate.ThenFunc(p.ForceRefresh)) mux.Handle("/", validate.ThenFunc(p.Proxy)) return mux } @@ -60,12 +61,12 @@ func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) { httputil.ErrorResponse(w, r, err) return } - uri, err := url.Parse(r.Form.Get("redirect_uri")) + uri, err := urlutil.ParseAndValidateURL(r.Form.Get("redirect_uri")) if err == nil && uri.String() != "" { redirectURL = uri } default: - uri, err := url.Parse(r.URL.Query().Get("redirect_uri")) + uri, err := urlutil.ParseAndValidateURL(r.URL.Query().Get("redirect_uri")) if err == nil && uri.String() != "" { redirectURL = uri } @@ -76,24 +77,20 @@ func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) { // OAuthStart begins the authenticate flow, encrypting the redirect url // in a request to the provider's sign in endpoint. func (p *Proxy) OAuthStart(w http.ResponseWriter, r *http.Request) { - - // create a CSRF value used to mitigate replay attacks. state := &StateParameter{ SessionID: fmt.Sprintf("%x", cryptutil.GenerateKey()), RedirectURI: r.URL.String(), } - // Encrypt, and save CSRF state. Will be checked on callback. - localState, err := p.cipher.Marshal(state) + // Encrypt CSRF + redirect_uri and store in csrf session. Validated on callback. + csrfState, err := p.cipher.Marshal(state) if err != nil { httputil.ErrorResponse(w, r, err) return } - p.csrfStore.SetCSRF(w, r, localState) + p.csrfStore.SetCSRF(w, r, csrfState) - // Though the plaintext payload is identical, we re-encrypt which will - // create a different cipher text using another nonce - remoteState, err := p.cipher.Marshal(state) + paramState, err := p.cipher.Marshal(state) if err != nil { httputil.ErrorResponse(w, r, err) return @@ -101,68 +98,55 @@ func (p *Proxy) OAuthStart(w http.ResponseWriter, r *http.Request) { // Sanity check. The encrypted payload of local and remote state should // never match as each encryption round uses a cryptographic nonce. - // - // todo(bdd): since this should nearly (1/(2^32*2^32)) never happen should - // we panic as a failure most likely means the rands entropy source is failing? - if remoteState == localState { - p.sessionStore.ClearSession(w, r) - httputil.ErrorResponse(w, r, httputil.Error("encrypted state should not match", http.StatusBadRequest, nil)) - return - } + // if paramState == csrfState { + // httputil.ErrorResponse(w, r, httputil.Error("encrypted state should not match", http.StatusBadRequest, nil)) + // return + // } - signinURL := p.GetSignInURL(p.authenticateURL, p.GetRedirectURL(r.Host), remoteState) - log.FromRequest(r).Debug().Str("SigninURL", signinURL.String()).Msg("proxy: oauth start") + signinURL := p.GetSignInURL(p.authenticateURL, p.GetRedirectURL(r.Host), paramState) // Redirect the user to the authenticate service along with the encrypted // state which contains a redirect uri back to the proxy and a nonce http.Redirect(w, r, signinURL.String(), http.StatusFound) } -// OAuthCallback validates the cookie sent back from the authenticate service. This function will -// contain an error, or it will contain a `code`; the code can be used to fetch an access token, and -// other metadata, from the authenticator. -// finish the oauth cycle -func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) { +// AuthenticateCallback checks the state parameter to make sure it matches the +// local csrf state then redirects the user back to the original intended route. +func (p *Proxy) AuthenticateCallback(w http.ResponseWriter, r *http.Request) { if err := r.ParseForm(); err != nil { httputil.ErrorResponse(w, r, err) return } - if callbackError := r.Form.Get("error"); callbackError != "" { - httputil.ErrorResponse(w, r, httputil.Error(callbackError, http.StatusBadRequest, nil)) - return - } - // Encrypted CSRF passed from authenticate service remoteStateEncrypted := r.Form.Get("state") - remoteStatePlain := new(StateParameter) - if err := p.cipher.Unmarshal(remoteStateEncrypted, remoteStatePlain); err != nil { + var remoteStatePlain StateParameter + if err := p.cipher.Unmarshal(remoteStateEncrypted, &remoteStatePlain); err != nil { httputil.ErrorResponse(w, r, err) return } - // Encrypted CSRF from session storage c, err := p.csrfStore.GetCSRF(r) if err != nil { httputil.ErrorResponse(w, r, err) return } p.csrfStore.ClearCSRF(w, r) + localStateEncrypted := c.Value - localStatePlain := new(StateParameter) - err = p.cipher.Unmarshal(localStateEncrypted, localStatePlain) + var localStatePlain StateParameter + err = p.cipher.Unmarshal(localStateEncrypted, &localStatePlain) if err != nil { httputil.ErrorResponse(w, r, err) return } - // If the encrypted value of local and remote state match, reject. - // Likely a replay attack or nonce-reuse. + // assert no nonce reuse if remoteStateEncrypted == localStateEncrypted { p.sessionStore.ClearSession(w, r) - - httputil.ErrorResponse(w, r, httputil.Error("local and remote state should not match!", http.StatusBadRequest, nil)) - + httputil.ErrorResponse(w, r, + httputil.Error("local and remote state", http.StatusBadRequest, + fmt.Errorf("possible nonce-reuse / replay attack"))) return } @@ -205,13 +189,23 @@ func isCORSPreflight(r *http.Request) bool { r.Header.Get("Origin") != "" } +func (p *Proxy) loadExistingSession(r *http.Request) (*sessions.State, error) { + s, err := p.sessionStore.LoadSession(r) + if err != nil { + return nil, fmt.Errorf("proxy: invalid session: %w", err) + } + if err := s.Valid(); err != nil { + return nil, fmt.Errorf("proxy: invalid state: %w", err) + } + return s, nil +} + // Proxy authenticates a request, either proxying the request if it is authenticated, // or starting the authenticate service for validation if not. func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) { - // does a route exist for this request? route, ok := p.router(r) if !ok { - httputil.ErrorResponse(w, r, httputil.Error(fmt.Sprintf("%s is not a managed route.", r.Host), http.StatusNotFound, nil)) + httputil.ErrorResponse(w, r, httputil.Error("", http.StatusNotFound, nil)) return } @@ -221,30 +215,17 @@ func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) { return } - s, err := p.restStore.LoadSession(r) - // if authorization bearer token does not exist or fails, use cookie store - if err != nil || s == nil { - s, err = p.sessionStore.LoadSession(r) - if err != nil { - log.FromRequest(r).Debug().Str("cause", err.Error()).Msg("proxy: invalid session, re-authenticating") - p.sessionStore.ClearSession(w, r) - p.OAuthStart(w, r) - return - } - } - - if err = p.authenticate(w, r, s); err != nil { - p.sessionStore.ClearSession(w, r) - httputil.ErrorResponse(w, r, httputil.Error("User unauthenticated", http.StatusUnauthorized, err)) + s, err := p.loadExistingSession(r) + if err != nil { + log.Debug().Str("cause", err.Error()).Msg("proxy: bad authN session, redirecting") + p.OAuthStart(w, r) return } authorized, err := p.AuthorizeClient.Authorize(r.Context(), r.Host, s) if err != nil { httputil.ErrorResponse(w, r, err) return - } - - if !authorized { + } else if !authorized { httputil.ErrorResponse(w, r, httputil.Error(fmt.Sprintf("%s is not authorized for this route", s.Email), http.StatusForbidden, nil)) return } @@ -259,20 +240,13 @@ func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) { // It also contains certain administrative actions like user impersonation. // Nota bene: This endpoint does authentication, not authorization. func (p *Proxy) UserDashboard(w http.ResponseWriter, r *http.Request) { - session, err := p.sessionStore.LoadSession(r) + session, err := p.loadExistingSession(r) if err != nil { - log.FromRequest(r).Debug().Str("cause", err.Error()).Msg("proxy: no session, redirecting to auth") - p.sessionStore.ClearSession(w, r) + log.Debug().Str("cause", err.Error()).Msg("proxy: bad authN session, redirecting") p.OAuthStart(w, r) return } - if err := p.authenticate(w, r, session); err != nil { - p.sessionStore.ClearSession(w, r) - httputil.ErrorResponse(w, r, httputil.Error("User unauthenticated", http.StatusUnauthorized, err)) - return - } - redirectURL := &url.URL{Scheme: "https", Host: r.Host, Path: "/.pomerium/sign_out"} isAdmin, err := p.AuthorizeClient.IsAdmin(r.Context(), session) if err != nil { @@ -314,13 +288,14 @@ func (p *Proxy) UserDashboard(w http.ResponseWriter, r *http.Request) { templates.New().ExecuteTemplate(w, "dashboard.html", t) } -// Refresh redeems and extends an existing authenticated oidc session with +// ForceRefresh redeems and extends an existing authenticated oidc session with // the underlying identity provider. All session details including groups, // timeouts, will be renewed. -func (p *Proxy) Refresh(w http.ResponseWriter, r *http.Request) { - session, err := p.sessionStore.LoadSession(r) +func (p *Proxy) ForceRefresh(w http.ResponseWriter, r *http.Request) { + session, err := p.loadExistingSession(r) if err != nil { - httputil.ErrorResponse(w, r, err) + log.Debug().Str("cause", err.Error()).Msg("proxy: bad authN session, redirecting") + p.OAuthStart(w, r) return } iss, err := session.IssuedAt() @@ -332,16 +307,13 @@ func (p *Proxy) Refresh(w http.ResponseWriter, r *http.Request) { // reject a refresh if it's been less than the refresh cooldown to prevent abuse if time.Since(iss) < p.refreshCooldown { httputil.ErrorResponse(w, r, - httputil.Error(fmt.Sprintf("Session must be %s old before refreshing", p.refreshCooldown), http.StatusBadRequest, nil)) + httputil.Error( + fmt.Sprintf("Session must be %s old before refreshing", p.refreshCooldown), + http.StatusBadRequest, nil)) return } - - newSession, err := p.AuthenticateClient.Refresh(r.Context(), session) - if err != nil { - httputil.ErrorResponse(w, r, err) - return - } - if err = p.sessionStore.SaveSession(w, r, newSession); err != nil { + session.ForceRefresh() + if err = p.sessionStore.SaveSession(w, r, session); err != nil { httputil.ErrorResponse(w, r, err) return } @@ -357,12 +329,12 @@ func (p *Proxy) Impersonate(w http.ResponseWriter, r *http.Request) { httputil.ErrorResponse(w, r, err) return } - session, err := p.sessionStore.LoadSession(r) + session, err := p.loadExistingSession(r) if err != nil { - httputil.ErrorResponse(w, r, err) + log.Debug().Str("cause", err.Error()).Msg("proxy: bad authN session, redirecting") + p.OAuthStart(w, r) return } - // authorization check -- is this user an admin? isAdmin, err := p.AuthorizeClient.IsAdmin(r.Context(), session) if err != nil || !isAdmin { httputil.ErrorResponse(w, r, httputil.Error(fmt.Sprintf("%s is not an administrator", session.Email), http.StatusForbidden, err)) @@ -376,7 +348,7 @@ func (p *Proxy) Impersonate(w http.ResponseWriter, r *http.Request) { } p.csrfStore.ClearCSRF(w, r) encryptedCSRF := c.Value - decryptedCSRF := new(StateParameter) + var decryptedCSRF StateParameter if err = p.cipher.Unmarshal(encryptedCSRF, decryptedCSRF); err != nil { httputil.ErrorResponse(w, r, err) return @@ -398,26 +370,6 @@ func (p *Proxy) Impersonate(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, "/.pomerium", http.StatusFound) } -// Authenticate authenticates a request by checking for a session cookie, and validating its expiration, -// clearing the session cookie if it's invalid and returning an error if necessary.. -func (p *Proxy) authenticate(w http.ResponseWriter, r *http.Request, s *sessions.SessionState) error { - if s.RefreshPeriodExpired() { - s, err := p.AuthenticateClient.Refresh(r.Context(), s) - if err != nil { - return fmt.Errorf("proxy: session refresh failed : %v", err) - } - if err := p.sessionStore.SaveSession(w, r, s); err != nil { - return fmt.Errorf("proxy: refresh failed : %v", err) - } - } else { - valid, err := p.AuthenticateClient.Validate(r.Context(), s.IDToken) - if err != nil || !valid { - return fmt.Errorf("proxy: session validate failed: %v : %v", valid, err) - } - } - return nil -} - // router attempts to find a route for a request. If a route is successfully matched, // it returns the route information and a bool value of `true`. If a route can // not be matched, a nil value for the route and false bool value is returned. @@ -461,7 +413,7 @@ func (p *Proxy) GetSignInURL(authenticateURL, redirectURL *url.URL, state string a := authenticateURL.ResolveReference(&url.URL{Path: "/sign_in"}) now := time.Now() rawRedirect := redirectURL.String() - params, _ := url.ParseQuery(a.RawQuery) + params, _ := url.ParseQuery(a.RawQuery) // handled by ServeMux params.Set("redirect_uri", rawRedirect) params.Set("shared_secret", p.SharedKey) params.Set("response_type", "code") @@ -477,7 +429,7 @@ func (p *Proxy) GetSignOutURL(authenticateURL, redirectURL *url.URL) *url.URL { a := authenticateURL.ResolveReference(&url.URL{Path: "/sign_out"}) now := time.Now() rawRedirect := redirectURL.String() - params, _ := url.ParseQuery(a.RawQuery) + params, _ := url.ParseQuery(a.RawQuery) // handled by ServeMux params.Add("redirect_uri", rawRedirect) params.Set("ts", fmt.Sprint(now.Unix())) params.Set("sig", p.signRedirectURL(rawRedirect, now)) diff --git a/proxy/handlers_test.go b/proxy/handlers_test.go index 4d14c42d8..7bc4d1154 100644 --- a/proxy/handlers_test.go +++ b/proxy/handlers_test.go @@ -72,7 +72,6 @@ func TestProxy_GetRedirectURL(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { p := &Proxy{redirectURL: &url.URL{Path: "/.pomerium/callback"}} - if got := p.GetRedirectURL(tt.host); !reflect.DeepEqual(got, tt.want) { t.Errorf("Proxy.GetRedirectURL() = %v, want %v", got, tt.want) } @@ -240,8 +239,7 @@ func TestProxy_router(t *testing.T) { if err != nil { t.Fatal(err) } - p.AuthenticateClient = clients.MockAuthenticate{} - p.cipher = mockCipher{} + p.cipher = &cryptutil.MockCipher{MarshalResponse: "foo"} req := httptest.NewRequest(http.MethodGet, tt.host, nil) _, ok := p.router(req) @@ -253,7 +251,7 @@ func TestProxy_router(t *testing.T) { } func TestProxy_Proxy(t *testing.T) { - goodSession := &sessions.SessionState{ + goodSession := &sessions.State{ AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second), @@ -278,39 +276,34 @@ func TestProxy_Proxy(t *testing.T) { headersWs.Set("Upgrade", "websocket") tests := []struct { - name string - options config.Options - method string - header http.Header - host string - session sessions.SessionStore - authenticator clients.Authenticator - authorizer clients.Authorizer - wantStatus int + name string + options config.Options + method string + header http.Header + host string + session sessions.SessionStore + authorizer clients.Authorizer + wantStatus int }{ - {"good", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK}, - {"good cors preflight", optsCORS, http.MethodOptions, goodCORSHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusOK}, - {"good email impersonation", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second), ImpersonateEmail: "test@user.example"}}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK}, - {"good group impersonation", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second), ImpersonateGroups: []string{"group1", "group2"}}}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK}, + {"good", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK}, + {"good cors preflight", optsCORS, http.MethodOptions, goodCORSHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusOK}, + {"good email impersonation", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second), ImpersonateEmail: "test@user.example"}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK}, + {"good group impersonation", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second), ImpersonateGroups: []string{"group1", "group2"}}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK}, // same request as above, but with cors_allow_preflight=false in the policy - {"valid cors, but not allowed", opts, http.MethodOptions, goodCORSHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden}, + {"valid cors, but not allowed", opts, http.MethodOptions, goodCORSHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden}, // cors allowed, but the request is missing proper headers - {"invalid cors headers", optsCORS, http.MethodOptions, badCORSHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden}, - {"unexpected error", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{LoadError: errors.New("ok")}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest}, + {"invalid cors headers", optsCORS, http.MethodOptions, badCORSHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden}, // redirect to start auth process - {"unknown host", opts, http.MethodGet, defaultHeaders, "https://nothttpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound}, - {"user not authorized", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden}, - {"authorization call failed", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeError: errors.New("error")}, http.StatusInternalServerError}, + {"unknown host", opts, http.MethodGet, defaultHeaders, "https://nothttpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound}, + {"user not authorized", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden}, + {"authorization call failed", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeError: errors.New("error")}, http.StatusInternalServerError}, // authenticate errors - {"weird load session error", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{LoadError: errors.New("weird"), Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest}, - {"failed refreshed session", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: &sessions.SessionState{RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthenticate{RefreshError: errors.New("refresh error")}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized}, - {"cannot resave refreshed session", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{SaveError: errors.New("weird"), Session: &sessions.SessionState{RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized}, - {"authenticate validation error", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: false}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized}, - {"public access", optsPublic, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusOK}, - {"public access, but unknown host", optsPublic, http.MethodGet, defaultHeaders, "https://nothttpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusNotFound}, - // no session, redirect to login - {"no http found (no session)", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{LoadError: http.ErrNoCookie}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest}, - {"No policies", optsNoPolicies, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound}, + {"session error, redirect to authn", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{LoadError: errors.New("weird"), Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound}, + {"session expired,redirect to authn", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{LoadError: sessions.ErrExpired}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound}, + {"public access", optsPublic, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusOK}, + {"public access, but unknown host", optsPublic, http.MethodGet, defaultHeaders, "https://nothttpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusNotFound}, + {"no http found (no session),redirect to authn", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{LoadError: http.ErrNoCookie}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound}, + {"No policies", optsNoPolicies, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound}, } for _, tt := range tests { @@ -323,13 +316,13 @@ func TestProxy_Proxy(t *testing.T) { if err != nil { t.Fatal(err) } - p.cipher = mockCipher{} + p.cipher = &cryptutil.MockCipher{MarshalResponse: "foo"} p.sessionStore = tt.session - p.AuthenticateClient = tt.authenticator p.AuthorizeClient = tt.authorizer r := httptest.NewRequest(tt.method, tt.host, nil) r.Header = tt.header + r.Header.Set("Accept", "application/json") w := httptest.NewRecorder() p.Proxy(w, r) if status := w.Code; status != tt.wantStatus { @@ -348,23 +341,21 @@ func TestProxy_Proxy(t *testing.T) { func TestProxy_UserDashboard(t *testing.T) { opts := testOptions(t) tests := []struct { - name string - options config.Options - method string - cipher cryptutil.Cipher - session sessions.SessionStore - authenticator clients.Authenticator - authorizer clients.Authorizer + name string + options config.Options + method string + cipher cryptutil.Cipher + session sessions.SessionStore + authorizer clients.Authorizer wantAdminForm bool wantStatus int }{ - {"good", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example"}}, clients.MockAuthenticate{}, clients.MockAuthorize{}, false, http.StatusOK}, - {"cannot load session", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{LoadError: errors.New("load error")}, clients.MockAuthenticate{}, clients.MockAuthorize{}, false, http.StatusBadRequest}, - {"auth failure, validation error", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthenticate{ValidateError: errors.New("not valid anymore"), ValidateResponse: false}, clients.MockAuthorize{}, false, http.StatusUnauthorized}, - {"can't save csrf", opts, http.MethodGet, &cryptutil.MockCipher{MarshalError: errors.New("err")}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example"}}, clients.MockAuthenticate{}, clients.MockAuthorize{}, false, http.StatusInternalServerError}, - {"want admin form good admin authorization", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, true, http.StatusOK}, - {"is admin but authorization fails", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminError: errors.New("err")}, false, http.StatusInternalServerError}, + {"good", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{}, false, http.StatusOK}, + {"cannot load session", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{LoadError: errors.New("load error")}, clients.MockAuthorize{}, false, http.StatusFound}, + {"can't save csrf", opts, http.MethodGet, &cryptutil.MockCipher{MarshalError: errors.New("err")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{}, false, http.StatusInternalServerError}, + {"want admin form good admin authorization", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{IsAdminResponse: true}, true, http.StatusOK}, + {"is admin but authorization fails", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{IsAdminError: errors.New("err")}, false, http.StatusInternalServerError}, } for _, tt := range tests { @@ -375,15 +366,18 @@ func TestProxy_UserDashboard(t *testing.T) { } p.cipher = tt.cipher p.sessionStore = tt.session - p.AuthenticateClient = tt.authenticator p.AuthorizeClient = tt.authorizer r := httptest.NewRequest(tt.method, "/", nil) + r.Header.Set("Accept", "application/json") + w := httptest.NewRecorder() p.UserDashboard(w, r) if status := w.Code; status != tt.wantStatus { t.Errorf("status code: got %v want %v", status, tt.wantStatus) t.Errorf("\n%+v", opts) + t.Errorf("\n%+v", w.Body.String()) + } if adminForm := strings.Contains(w.Body.String(), "impersonate"); adminForm != tt.wantAdminForm { t.Errorf("wanted admin form got %v want %v", adminForm, tt.wantAdminForm) @@ -393,28 +387,27 @@ func TestProxy_UserDashboard(t *testing.T) { } } -func TestProxy_Refresh(t *testing.T) { +func TestProxy_ForceRefresh(t *testing.T) { opts := testOptions(t) opts.RefreshCooldown = 0 timeSinceError := testOptions(t) timeSinceError.RefreshCooldown = time.Duration(int(^uint(0) >> 1)) tests := []struct { - name string - options config.Options - method string - cipher cryptutil.Cipher - session sessions.SessionStore - authenticator clients.Authenticator - authorizer clients.Authorizer - wantStatus int + name string + options config.Options + method string + cipher cryptutil.Cipher + session sessions.SessionStore + authorizer clients.Authorizer + wantStatus int }{ - {"good", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthenticate{}, clients.MockAuthorize{}, http.StatusFound}, - {"cannot load session", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{LoadError: errors.New("load error")}, clients.MockAuthenticate{}, clients.MockAuthorize{}, http.StatusInternalServerError}, - {"bad id token", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: "bad"}}, clients.MockAuthenticate{}, clients.MockAuthorize{}, http.StatusInternalServerError}, - {"issue date too soon", timeSinceError, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthenticate{}, clients.MockAuthorize{}, http.StatusBadRequest}, - {"refresh failure", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthenticate{RefreshError: errors.New("err")}, clients.MockAuthorize{}, http.StatusInternalServerError}, - {"can't save refreshed session", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.SessionState{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthenticate{}, clients.MockAuthorize{}, http.StatusInternalServerError}, + {"good", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusFound}, + {"cannot load session", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{LoadError: errors.New("load error")}, clients.MockAuthorize{}, http.StatusFound}, + {"bad id token", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: "bad"}}, clients.MockAuthorize{}, http.StatusInternalServerError}, + {"issue date too soon", timeSinceError, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusBadRequest}, + {"refresh failure", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusFound}, + {"can't save refreshed session", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusFound}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -424,12 +417,11 @@ func TestProxy_Refresh(t *testing.T) { } p.cipher = tt.cipher p.sessionStore = tt.session - p.AuthenticateClient = tt.authenticator p.AuthorizeClient = tt.authorizer r := httptest.NewRequest(tt.method, "/", nil) w := httptest.NewRecorder() - p.Refresh(w, r) + p.ForceRefresh(w, r) if status := w.Code; status != tt.wantStatus { t.Errorf("status code: got %v want %v", status, tt.wantStatus) t.Errorf("\n%+v", opts) @@ -442,30 +434,29 @@ func TestProxy_Impersonate(t *testing.T) { opts := testOptions(t) tests := []struct { - name string - malformed bool - options config.Options - method string - email string - groups string - csrf string - cipher cryptutil.Cipher - sessionStore sessions.SessionStore - csrfStore sessions.CSRFStore - authenticator clients.Authenticator - authorizer clients.Authorizer - wantStatus int + name string + malformed bool + options config.Options + method string + email string + groups string + csrf string + cipher cryptutil.Cipher + sessionStore sessions.SessionStore + csrfStore sessions.CSRFStore + authorizer clients.Authorizer + wantStatus int }{ - {"good", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound}, - {"session load error", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{LoadError: errors.New("err"), Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError}, - {"non admin users rejected", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: false}, http.StatusForbidden}, - {"non admin users rejected on error", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true, IsAdminError: errors.New("err")}, http.StatusForbidden}, - {"csrf from store retrieve failure", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}, GetError: errors.New("err")}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError}, - {"can't decrypt csrf value", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{UnmarshalError: errors.New("err")}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError}, - {"decrypted csrf mismatch", false, opts, http.MethodPost, "user@blah.com", "", "CSRF!", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusBadRequest}, - {"save session failure", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError}, - {"malformed", true, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError}, - {"groups", false, opts, http.MethodPost, "user@blah.com", "group1,group2", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound}, + {"good", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound}, + {"session load error", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{LoadError: errors.New("err"), Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound}, + // {"non admin users rejected", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: false}, http.StatusForbidden}, + {"non admin users rejected on error", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true, IsAdminError: errors.New("err")}, http.StatusForbidden}, + {"csrf from store retrieve failure", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}, GetError: errors.New("err")}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError}, + {"can't decrypt csrf value", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{UnmarshalError: errors.New("err")}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError}, + {"decrypted csrf mismatch", false, opts, http.MethodPost, "user@blah.com", "", "CSRF!", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusBadRequest}, + {"save session failure", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError}, + {"malformed", true, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError}, + {"groups", false, opts, http.MethodPost, "user@blah.com", "group1,group2", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -476,7 +467,6 @@ func TestProxy_Impersonate(t *testing.T) { p.cipher = tt.cipher p.sessionStore = tt.sessionStore p.csrfStore = tt.csrfStore - p.AuthenticateClient = tt.authenticator p.AuthorizeClient = tt.authorizer postForm := url.Values{} postForm.Add("email", tt.email) @@ -501,19 +491,17 @@ func TestProxy_Impersonate(t *testing.T) { func TestProxy_OAuthCallback(t *testing.T) { tests := []struct { - name string - csrf sessions.MockCSRFStore - session sessions.MockSessionStore - authenticator clients.MockAuthenticate - params map[string]string - wantCode int + name string + csrf sessions.MockCSRFStore + session sessions.MockSessionStore + params map[string]string + wantCode int }{ - {"good", sessions.MockCSRFStore{ResponseCSRF: "ok", GetError: nil, Cookie: &http.Cookie{Name: "something_csrf", Value: "csrf_state"}}, sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthenticate{RedeemResponse: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken"}}, map[string]string{"code": "code", "state": "state"}, http.StatusFound}, - {"error", sessions.MockCSRFStore{ResponseCSRF: "ok", GetError: nil, Cookie: &http.Cookie{Name: "something_csrf", Value: "csrf_state"}}, sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthenticate{RedeemResponse: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken"}}, map[string]string{"error": "some error"}, http.StatusBadRequest}, - {"state err", sessions.MockCSRFStore{ResponseCSRF: "ok", GetError: nil, Cookie: &http.Cookie{Name: "something_csrf", Value: "csrf_state"}}, sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthenticate{RedeemResponse: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken"}}, map[string]string{"code": "code", "state": "error"}, http.StatusInternalServerError}, - {"csrf err", sessions.MockCSRFStore{GetError: errors.New("error")}, sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthenticate{RedeemResponse: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken"}}, map[string]string{"code": "code", "state": "state"}, http.StatusInternalServerError}, - {"unmarshal err", sessions.MockCSRFStore{Cookie: &http.Cookie{Name: "something_csrf", Value: "unmarshal error"}}, sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthenticate{RedeemResponse: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken"}}, map[string]string{"code": "code", "state": "state"}, http.StatusInternalServerError}, - {"malformed", sessions.MockCSRFStore{ResponseCSRF: "ok", GetError: nil, Cookie: &http.Cookie{Name: "something_csrf", Value: "csrf_state"}}, sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthenticate{RedeemResponse: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken"}}, map[string]string{"code": "code", "state": "state"}, http.StatusInternalServerError}, + {"good", sessions.MockCSRFStore{ResponseCSRF: "ok", GetError: nil, Cookie: &http.Cookie{Name: "something_csrf", Value: "csrf_state"}}, sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, map[string]string{"code": "code", "state": "state"}, http.StatusFound}, + {"state err", sessions.MockCSRFStore{ResponseCSRF: "ok", GetError: nil, Cookie: &http.Cookie{Name: "something_csrf", Value: "csrf_state"}}, sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, map[string]string{"code": "code", "state": "error"}, http.StatusInternalServerError}, + {"csrf err", sessions.MockCSRFStore{GetError: errors.New("error")}, sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, map[string]string{"code": "code", "state": "state"}, http.StatusInternalServerError}, + {"unmarshal err", sessions.MockCSRFStore{Cookie: &http.Cookie{Name: "something_csrf", Value: "unmarshal error"}}, sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, map[string]string{"code": "code", "state": "state"}, http.StatusInternalServerError}, + {"malformed", sessions.MockCSRFStore{ResponseCSRF: "ok", GetError: nil, Cookie: &http.Cookie{Name: "something_csrf", Value: "csrf_state"}}, sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, map[string]string{"code": "code", "state": "state"}, http.StatusInternalServerError}, } for _, tt := range tests { @@ -524,7 +512,6 @@ func TestProxy_OAuthCallback(t *testing.T) { } proxy.sessionStore = &tt.session proxy.csrfStore = tt.csrf - proxy.AuthenticateClient = tt.authenticator proxy.cipher = mockCipher{} // proxy.Csrf req := httptest.NewRequest(http.MethodPost, "/.pomerium/callback", nil) @@ -537,7 +524,7 @@ func TestProxy_OAuthCallback(t *testing.T) { req.URL.RawQuery = "email=%zzzzz" } w := httptest.NewRecorder() - proxy.OAuthCallback(w, req) + proxy.AuthenticateCallback(w, req) if status := w.Code; status != tt.wantCode { t.Errorf("handler returned wrong status code: got %v want %v", status, tt.wantCode) } diff --git a/proxy/proxy.go b/proxy/proxy.go index 4063d1fff..026ab50d3 100755 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -2,11 +2,9 @@ package proxy // import "github.com/pomerium/pomerium/proxy" import ( "crypto/tls" - "encoding/base64" "fmt" "html/template" stdlog "log" - "net" "net/http" "net/http/httputil" "net/url" @@ -39,51 +37,27 @@ const ( // ValidateOptions checks that proper configuration settings are set to create // a proper Proxy instance func ValidateOptions(o config.Options) error { - decoded, err := base64.StdEncoding.DecodeString(o.SharedKey) - if err != nil { - return fmt.Errorf("`SHARED_SECRET` setting is invalid base64: %v", err) + if _, err := cryptutil.NewCipherFromBase64(o.SharedKey); err != nil { + return fmt.Errorf("proxy: invalid 'SHARED_SECRET': %v", err) } - if len(decoded) != 32 { - return fmt.Errorf("`SHARED_SECRET` want 32 but got %d bytes", len(decoded)) + if _, err := cryptutil.NewCipherFromBase64(o.CookieSecret); err != nil { + return fmt.Errorf("proxy: invalid 'COOKIE_SECRET': %v", err) } - if o.AuthenticateURL == nil { - return fmt.Errorf("proxy: missing setting: authenticate-service-url") + return fmt.Errorf("proxy: missing 'AUTHENTICATE_SERVICE_URL'") } if _, err := urlutil.ParseAndValidateURL(o.AuthenticateURL.String()); err != nil { - return fmt.Errorf("proxy: error parsing authenticate url: %v", err) + return fmt.Errorf("proxy: invalid 'AUTHENTICATE_SERVICE_URL': %v", err) } - if o.AuthorizeURL == nil { - return fmt.Errorf("proxy: missing setting: authenticate-service-url") + return fmt.Errorf("proxy: missing 'AUTHORIZE_SERVICE_URL'") } if _, err := urlutil.ParseAndValidateURL(o.AuthorizeURL.String()); err != nil { - return fmt.Errorf("proxy: error parsing authorize url: %v", err) - } - if o.AuthenticateInternalAddr != nil { - if _, err := urlutil.ParseAndValidateURL(o.AuthenticateInternalAddr.String()); err != nil { - return fmt.Errorf("proxy: error parsing authorize url: %v", err) - } - } - - if o.CookieSecret == "" { - return fmt.Errorf("proxy: missing setting: cookie-secret") - } - decodedCookieSecret, err := base64.StdEncoding.DecodeString(o.CookieSecret) - if err != nil { - return fmt.Errorf("proxy: cookie secret is invalid base64: %v", err) - } - if len(decodedCookieSecret) != 32 { - return fmt.Errorf("proxy: cookie secret expects 32 bytes but got %d", len(decodedCookieSecret)) + return fmt.Errorf("proxy: invalid 'AUTHORIZE_SERVICE_URL': %v", err) } if len(o.SigningKey) != 0 { - decodedSigningKey, err := base64.StdEncoding.DecodeString(o.SigningKey) - if err != nil { - return fmt.Errorf("proxy: signing key is invalid base64: %v", err) - } - _, err = cryptutil.NewES256Signer(decodedSigningKey, "localhost") - if err != nil { - return fmt.Errorf("proxy: invalid signing key is : %v", err) + if _, err := cryptutil.NewES256Signer(o.SigningKey, "localhost"); err != nil { + return fmt.Errorf("proxy: invalid 'SIGNING_KEY': %v", err) } } return nil @@ -92,12 +66,11 @@ func ValidateOptions(o config.Options) error { // Proxy stores all the information associated with proxying a request. type Proxy struct { // SharedKey used to mutually authenticate service communication - SharedKey string - authenticateURL *url.URL - authenticateInternalAddr *url.URL - authorizeURL *url.URL - AuthenticateClient clients.Authenticator - AuthorizeClient clients.Authorizer + SharedKey string + authenticateURL *url.URL + authorizeURL *url.URL + + AuthorizeClient clients.Authorizer cipher cryptutil.Cipher cookieName string @@ -105,7 +78,6 @@ type Proxy struct { defaultUpstreamTimeout time.Duration redirectURL *url.URL refreshCooldown time.Duration - restStore sessions.SessionStore routeConfigs map[string]*routeConfig sessionStore sessions.SessionStore signingKey string @@ -123,11 +95,9 @@ func New(opts config.Options) (*Proxy, error) { if err := ValidateOptions(opts); err != nil { return nil, err } - // error explicitly handled by validate - decodedSecret, _ := base64.StdEncoding.DecodeString(opts.CookieSecret) - cipher, err := cryptutil.NewCipher(decodedSecret) + cipher, err := cryptutil.NewCipherFromBase64(opts.CookieSecret) if err != nil { - return nil, fmt.Errorf("cookie-secret error: %s", err.Error()) + return nil, err } cookieStore, err := sessions.NewCookieStore( @@ -140,10 +110,6 @@ func New(opts config.Options) (*Proxy, error) { CookieCipher: cipher, }) - if err != nil { - return nil, err - } - restStore, err := sessions.NewRestStore(&sessions.RestStoreOptions{Cipher: cipher}) if err != nil { return nil, err } @@ -158,7 +124,6 @@ func New(opts config.Options) (*Proxy, error) { defaultUpstreamTimeout: opts.DefaultUpstreamTimeout, redirectURL: &url.URL{Path: "/.pomerium/callback"}, refreshCooldown: opts.RefreshCooldown, - restStore: restStore, sessionStore: cookieStore, signingKey: opts.SigningKey, templates: templates.New(), @@ -166,7 +131,6 @@ func New(opts config.Options) (*Proxy, error) { // DeepCopy urls to avoid accidental mutation, err checked in validate func p.authenticateURL, _ = urlutil.DeepCopy(opts.AuthenticateURL) p.authorizeURL, _ = urlutil.DeepCopy(opts.AuthorizeURL) - p.authenticateInternalAddr, _ = urlutil.DeepCopy(opts.AuthenticateInternalAddr) if err := p.UpdatePolicies(&opts); err != nil { return nil, err @@ -174,20 +138,6 @@ func New(opts config.Options) (*Proxy, error) { metrics.AddPolicyCountCallback("proxy", func() int64 { return int64(len(p.routeConfigs)) }) - p.AuthenticateClient, err = clients.NewAuthenticateClient("grpc", - &clients.Options{ - Addr: p.authenticateURL, - InternalAddr: p.authenticateInternalAddr, - OverrideCertificateName: opts.OverrideCertificateName, - SharedSecret: opts.SharedKey, - CA: opts.CA, - CAFile: opts.CAFile, - RequestTimeout: opts.GRPCClientTimeout, - ClientDNSRoundRobin: opts.GRPCClientDNSRoundRobin, - }) - if err != nil { - return nil, err - } p.AuthorizeClient, err = clients.NewAuthorizeClient("grpc", &clients.Options{ Addr: p.authorizeURL, @@ -213,19 +163,7 @@ func (p *Proxy) UpdatePolicies(opts *config.Options) error { } proxy := NewReverseProxy(policy.Destination) // build http transport (roundtripper) middleware chain - // todo(bdd): replace with transport.Clone() in go 1.13 - transport := http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - DualStack: true, - }).DialContext, - MaxIdleConns: 100, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - } + transport := http.DefaultTransport.(*http.Transport).Clone() c := tripper.NewChain() c = c.Append(metrics.HTTPMetricsRoundTripper("proxy", policy.Destination.Host)) @@ -253,7 +191,7 @@ func (p *Proxy) UpdatePolicies(opts *config.Options) error { if isCustomClientConfig { transport.TLSClientConfig = &tlsClientConfig } - proxy.Transport = c.Then(&transport) + proxy.Transport = c.Then(transport) handler, err := p.newReverseProxyHandler(proxy, &policy) if err != nil { @@ -298,15 +236,6 @@ func NewReverseProxy(to *url.URL) *httputil.ReverseProxy { return proxy } -// newRouteSigner creates a route specific signer. -func (p *Proxy) newRouteSigner(audience string) (cryptutil.JWTSigner, error) { - decodedSigningKey, err := base64.StdEncoding.DecodeString(p.signingKey) - if err != nil { - return nil, err - } - return cryptutil.NewES256Signer(decodedSigningKey, audience) -} - // newReverseProxyHandler applies handler specific options to a given route. func (p *Proxy) newReverseProxyHandler(rp *httputil.ReverseProxy, route *config.Policy) (handler http.Handler, err error) { handler = &UpstreamProxy{ @@ -318,7 +247,7 @@ func (p *Proxy) newReverseProxyHandler(rp *httputil.ReverseProxy, route *config. // if signing key is set, add signer to middleware if len(p.signingKey) != 0 { - signer, err := p.newRouteSigner(route.Source.Host) + signer, err := cryptutil.NewES256Signer(p.signingKey, route.Source.Host) if err != nil { return nil, err } diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 9b051fd68..5ee67fde6 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -169,9 +169,6 @@ func TestOptions_Validate(t *testing.T) { authurl, _ := url.Parse("authenticate.corp.beyondperimeter.com") authenticateBadScheme := testOptions(t) authenticateBadScheme.AuthenticateURL = authurl - authenticateInternalBadScheme := testOptions(t) - authenticateInternalBadScheme.AuthenticateInternalAddr = authurl - authorizeBadSCheme := testOptions(t) authorizeBadSCheme.AuthorizeURL = authurl authorizeNil := testOptions(t) @@ -200,7 +197,6 @@ func TestOptions_Validate(t *testing.T) { {"nil options", config.Options{}, true}, {"authenticate service url", badAuthURL, true}, {"authenticate service url no scheme", authenticateBadScheme, true}, - {"internal authenticate service url no scheme", authenticateInternalBadScheme, true}, {"authorize service url no scheme", authorizeBadSCheme, true}, {"authorize service cannot be nil", authorizeNil, true}, {"no cookie secret", emptyCookieSecret, true}, @@ -221,7 +217,6 @@ func TestOptions_Validate(t *testing.T) { } func TestNew(t *testing.T) { - good := testOptions(t) shortCookieLength := testOptions(t) shortCookieLength.CookieSecret = "gN3xnvfsAwfCXxnJorGLKUG4l2wC8sS8nfLMhcStPg=="