diff --git a/README.md b/README.md index 4ada87078..c3a71f271 100644 --- a/README.md +++ b/README.md @@ -4,24 +4,24 @@ # Pomerium -[![Travis CI](https://travis-ci.org/pomerium/pomerium.svg?branch=master)](https://travis-ci.org/pomerium/pomerium) [![Go Report Card](https://goreportcard.com/badge/github.com/pomerium/pomerium)](https://goreportcard.com/report/github.com/pomerium/pomerium) [![GoDoc](https://godoc.org/github.com/pomerium/pomerium?status.svg)][godocs] [![LICENSE](https://img.shields.io/github/license/pomerium/pomerium.svg)](https://github.com/pomerium/pomerium/blob/master/LICENSE)[![codecov](https://codecov.io/gh/pomerium/pomerium/branch/master/graph/badge.svg)](https://codecov.io/gh/pomerium/pomerium) +[![Travis CI](https://travis-ci.org/pomerium/pomerium.svg?branch=master)](https://travis-ci.org/pomerium/pomerium) [![Go Report Card](https://goreportcard.com/badge/github.com/pomerium/pomerium)](https://goreportcard.com/report/github.com/pomerium/pomerium) [![GoDoc](https://godoc.org/github.com/pomerium/pomerium?status.svg)][godocs] [![LICENSE](https://img.shields.io/github/license/pomerium/pomerium.svg)](https://github.com/pomerium/pomerium/blob/master/LICENSE)[![codecov](https://img.shields.io/codecov/c/github/pomerium/pomerium.svg?style=flat)](https://codecov.io/gh/pomerium/pomerium) Pomerium is a tool for managing secure access to internal applications and resources. Use Pomerium to: -- provide a unified gateway (reverse-proxy) to internal corporate applications. -- enforce dynamic access policy based on context, identity, and device state. -- deploy mutual authenticated encryption (mTLS). -- aggregate logging and telemetry data. +- provide a single-sign-on gateway to internal applications. +- enforce dynamic access policy based on **context**, **identity**, and **device state**. +- aggregate access logs and telemetry data. +- an alternative to a VPN. -Check out [awesome-zero-trust] to learn more about some problems Pomerium attempts to address. +Check out [awesome-zero-trust] to learn more about some of the problems Pomerium attempts to address. ## Docs To get started with pomerium, check out our [quick start guide]. -For comprehensive docs see our [documentation] and the [godocs]. +For comprehensive docs, and tutorials see our [documentation] and the [godocs]. [awesome-zero-trust]: https://github.com/pomerium/awesome-zero-trust [documentation]: https://www.pomerium.io/docs/ diff --git a/authenticate/authenticate.go b/authenticate/authenticate.go index 01ee92df8..93d48091d 100644 --- a/authenticate/authenticate.go +++ b/authenticate/authenticate.go @@ -18,37 +18,38 @@ import ( ) var defaultOptions = &Options{ - CookieName: "_pomerium_authenticate", - CookieHTTPOnly: true, - CookieSecure: true, - CookieExpire: time.Duration(168) * time.Hour, - CookieRefresh: time.Duration(30) * time.Minute, - CookieLifetimeTTL: time.Duration(720) * time.Hour, + CookieName: "_pomerium_authenticate", + CookieHTTPOnly: true, + CookieSecure: true, + CookieExpire: time.Duration(168) * time.Hour, + CookieRefresh: time.Duration(30) * time.Minute, } // Options details the available configuration settings for the authenticate service type Options struct { - RedirectURL *url.URL `envconfig:"REDIRECT_URL"` - // SharedKey is used to authenticate requests between services SharedKey string `envconfig:"SHARED_SECRET"` + // RedirectURL specifies the callback url following third party authentication + RedirectURL *url.URL `envconfig:"REDIRECT_URL"` + // Coarse authorization based on user email domain + // todo(bdd) : to be replaced with authorization module AllowedDomains []string `envconfig:"ALLOWED_DOMAINS"` ProxyRootDomains []string `envconfig:"PROXY_ROOT_DOMAIN"` // Session/Cookie management - CookieName string - CookieSecret string `envconfig:"COOKIE_SECRET"` - CookieDomain string `envconfig:"COOKIE_DOMAIN"` - CookieSecure bool `envconfig:"COOKIE_SECURE"` - CookieHTTPOnly bool `envconfig:"COOKIE_HTTP_ONLY"` - CookieExpire time.Duration `envconfig:"COOKIE_EXPIRE"` - CookieRefresh time.Duration `envconfig:"COOKIE_REFRESH"` - CookieLifetimeTTL time.Duration `envconfig:"COOKIE_LIFETIME"` + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie + CookieName string + CookieSecret string `envconfig:"COOKIE_SECRET"` + CookieDomain string `envconfig:"COOKIE_DOMAIN"` + CookieSecure bool `envconfig:"COOKIE_SECURE"` + CookieHTTPOnly bool `envconfig:"COOKIE_HTTP_ONLY"` + CookieExpire time.Duration `envconfig:"COOKIE_EXPIRE"` + CookieRefresh time.Duration `envconfig:"COOKIE_REFRESH"` // IdentityProvider provider configuration variables as specified by RFC6749 - // See: https://openid.net/specs/openid-connect-basic-1_0.html#RFC6749 + // https://openid.net/specs/openid-connect-basic-1_0.html#RFC6749 ClientID string `envconfig:"IDP_CLIENT_ID"` ClientSecret string `envconfig:"IDP_CLIENT_SECRET"` Provider string `envconfig:"IDP_PROVIDER"` @@ -103,17 +104,13 @@ func (o *Options) Validate() error { // Authenticate validates a user's identity type Authenticate struct { - RedirectURL *url.URL - - Validator func(string) bool - - AllowedDomains []string - ProxyRootDomains []string - CookieSecure bool - SharedKey string - CookieLifetimeTTL time.Duration + RedirectURL *url.URL + AllowedDomains []string + ProxyRootDomains []string + + Validator func(string) bool templates *template.Template csrfStore sessions.CSRFStore @@ -137,37 +134,47 @@ func New(opts *Options, optionFuncs ...func(*Authenticate) error) (*Authenticate if err != nil { return nil, err } - cookieStore, err := sessions.NewCookieStore(opts.CookieName, - sessions.CreateCookieCipher(decodedCookieSecret), - func(c *sessions.CookieStore) error { - c.CookieDomain = opts.CookieDomain - c.CookieHTTPOnly = opts.CookieHTTPOnly - c.CookieExpire = opts.CookieExpire - c.CookieSecure = opts.CookieSecure - return nil + cookieStore, err := sessions.NewCookieStore( + &sessions.CookieStoreOptions{ + Name: opts.CookieName, + CookieSecure: opts.CookieSecure, + CookieHTTPOnly: opts.CookieHTTPOnly, + CookieExpire: opts.CookieExpire, + CookieCipher: cipher, }) if err != nil { return nil, err } - p := &Authenticate{ - SharedKey: opts.SharedKey, - AllowedDomains: opts.AllowedDomains, - ProxyRootDomains: dotPrependDomains(opts.ProxyRootDomains), - CookieSecure: opts.CookieSecure, - RedirectURL: opts.RedirectURL, - templates: templates.New(), - csrfStore: cookieStore, - sessionStore: cookieStore, - cipher: cipher, - } - - p.provider, err = newProvider(opts) + provider, err := providers.New( + opts.Provider, + &providers.IdentityProvider{ + RedirectURL: opts.RedirectURL, + ProviderName: opts.Provider, + ProviderURL: opts.ProviderURL, + ClientID: opts.ClientID, + ClientSecret: opts.ClientSecret, + // SessionLifetimeTTL: opts.CookieLifetimeTTL, + Scopes: opts.Scopes, + }) if err != nil { return nil, err } + p := &Authenticate{ + SharedKey: opts.SharedKey, + RedirectURL: opts.RedirectURL, + AllowedDomains: opts.AllowedDomains, + ProxyRootDomains: dotPrependDomains(opts.ProxyRootDomains), + + templates: templates.New(), + csrfStore: cookieStore, + sessionStore: cookieStore, + cipher: cipher, + provider: provider, + } + // validation via dependency injected function for _, optFunc := range optionFuncs { err := optFunc(p) @@ -179,20 +186,6 @@ func New(opts *Options, optionFuncs ...func(*Authenticate) error) (*Authenticate return p, nil } -func newProvider(opts *Options) (providers.Provider, error) { - pd := &providers.IdentityProvider{ - RedirectURL: opts.RedirectURL, - ProviderName: opts.Provider, - ProviderURL: opts.ProviderURL, - ClientID: opts.ClientID, - ClientSecret: opts.ClientSecret, - SessionLifetimeTTL: opts.CookieLifetimeTTL, - Scopes: opts.Scopes, - } - np, err := providers.New(opts.Provider, pd) - return np, err -} - func dotPrependDomains(d []string) []string { for i := range d { if d[i] != "" && !strings.HasPrefix(d[i], ".") { diff --git a/authenticate/authenticate_test.go b/authenticate/authenticate_test.go index 9c16229f6..9bb26a901 100644 --- a/authenticate/authenticate_test.go +++ b/authenticate/authenticate_test.go @@ -11,16 +11,17 @@ import ( func testOptions() *Options { redirectURL, _ := url.Parse("https://example.com/oauth2/callback") return &Options{ - ProxyRootDomains: []string{"example.com"}, - AllowedDomains: []string{"example.com"}, - RedirectURL: redirectURL, - SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=", - ClientID: "test-client-id", - ClientSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=", - CookieSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=", - CookieRefresh: time.Duration(1) * time.Hour, - CookieLifetimeTTL: time.Duration(720) * time.Hour, - CookieExpire: time.Duration(168) * time.Hour, + ProxyRootDomains: []string{"example.com"}, + AllowedDomains: []string{"example.com"}, + RedirectURL: redirectURL, + SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=", + ClientID: "test-client-id", + ClientSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=", + CookieSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=", + CookieRefresh: time.Duration(1) * time.Hour, + // CookieLifetimeTTL: time.Duration(720) * time.Hour, + CookieExpire: time.Duration(168) * time.Hour, + CookieName: "pomerium", } } @@ -130,37 +131,6 @@ func Test_dotPrependDomains(t *testing.T) { } } -func Test_newProvider(t *testing.T) { - redirectURL, _ := url.Parse("https://example.com/oauth3/callback") - - goodOpts := &Options{ - RedirectURL: redirectURL, - Provider: "google", - ProviderURL: "", - ClientID: "cllient-id", - ClientSecret: "client-secret", - } - tests := []struct { - name string - opts *Options - wantErr bool - }{ - {"good", goodOpts, false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - _, err := newProvider(tt.opts) - if (err != nil) != tt.wantErr { - t.Errorf("newProvider() error = %v, wantErr %v", err, tt.wantErr) - return - } - // if !reflect.DeepEqual(got, tt.want) { - // t.Errorf("newProvider() = %v, want %v", got, tt.want) - // } - }) - } -} - func TestNew(t *testing.T) { good := testOptions() good.Provider = "google" diff --git a/authenticate/grpc_test.go b/authenticate/grpc_test.go index 3551379ff..cc93a162a 100644 --- a/authenticate/grpc_test.go +++ b/authenticate/grpc_test.go @@ -117,7 +117,6 @@ func TestAuthenticate_Authenticate(t *testing.T) { } lt := time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC() rt := time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC() - vt := time.Now().Add(1 * time.Minute).Truncate(time.Second).UTC() vtProto, err := ptypes.TimestampProto(rt) if err != nil { t.Fatal("failed to parse timestamp") @@ -128,9 +127,9 @@ func TestAuthenticate_Authenticate(t *testing.T) { RefreshToken: "refresh4321", LifetimeDeadline: lt, RefreshDeadline: rt, - ValidDeadline: vt, - Email: "user@domain.com", - User: "user", + + Email: "user@domain.com", + User: "user", } goodReply := &pb.AuthenticateReply{ diff --git a/authenticate/handlers.go b/authenticate/handlers.go index 9147f350b..4596ce7eb 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -16,7 +16,8 @@ import ( "github.com/pomerium/pomerium/internal/version" ) -// securityHeaders corresponds to HTTP response headers related to security. +// securityHeaders corresponds to HTTP response headers that help to protect against protocol +// downgrade attacks and cookie hijacking. // https://www.owasp.org/index.php/OWASP_Secure_Headers_Project#tab=Headers var securityHeaders = map[string]string{ "Strict-Transport-Security": "max-age=31536000", @@ -28,7 +29,7 @@ var securityHeaders = map[string]string{ "Referrer-Policy": "Same-origin", } -// Handler returns the Http.Handlers for authenticate, callback, and refresh +// Handler returns the authenticate service's HTTP request multiplexer, and routes. func (a *Authenticate) Handler() http.Handler { // set up our standard middlewares stdMiddleware := middleware.NewChain() @@ -80,12 +81,6 @@ func (a *Authenticate) authenticate(w http.ResponseWriter, r *http.Request) (*se return nil, err } - // if long-lived lifetime has expired, clear session - if session.LifetimePeriodExpired() { - log.FromRequest(r).Warn().Msg("authenticate: lifetime expired") - a.sessionStore.ClearSession(w, r) - return nil, sessions.ErrLifetimeExpired - } // check if session refresh period is up if session.RefreshPeriodExpired() { newToken, err := a.provider.Refresh(session.RefreshToken) @@ -130,32 +125,23 @@ func (a *Authenticate) authenticate(w http.ResponseWriter, r *http.Request) (*se return session, nil } -// SignIn handles the /sign_in endpoint. It attempts to authenticate the user, +// SignIn handles the sign_in endpoint. It attempts to authenticate the user, // and if the user is not authenticated, it renders a sign in page. func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) { session, err := a.authenticate(w, r) - switch err { - case nil: - // session good, redirect back to proxy - log.FromRequest(r).Info().Msg("authenticate.SignIn : authenticated") - a.ProxyCallback(w, r, session) - case http.ErrNoCookie, sessions.ErrLifetimeExpired, sessions.ErrInvalidSession: - // session invalid, authenticate - log.FromRequest(r).Info().Err(err).Msg("authenticate.SignIn : expected failure") - if err != http.ErrNoCookie { - a.sessionStore.ClearSession(w, r) - } + if err != nil { + log.FromRequest(r).Info().Err(err).Msg("authenticate: authenticate error") + a.sessionStore.ClearSession(w, r) a.OAuthStart(w, r) - - default: - log.Error().Err(err).Msg("authenticate: unexpected sign in error") - httputil.ErrorResponse(w, r, err.Error(), httputil.CodeForError(err)) } + log.FromRequest(r).Info().Msg("authenticate: user authenticated") + a.ProxyCallback(w, r, session) + } // ProxyCallback redirects the user back to proxy service along with an encrypted payload, as -// url params, of the user's session state. -// See RFC6749 3.1.2 https://tools.ietf.org/html/rfc6749#section-3.1.2 +// url params, of the user's session state as specified in RFC6749 3.1.2. +// https://tools.ietf.org/html/rfc6749#section-3.1.2 func (a *Authenticate) ProxyCallback(w http.ResponseWriter, r *http.Request, session *sessions.SessionState) { err := r.ParseForm() if err != nil { @@ -201,9 +187,8 @@ func getAuthCodeRedirectURL(redirectURL *url.URL, state, authCode string) string return u.String() } -// SignOut signs the user out by trying to revoke the users remote identity provider session -// then removes the associated local session state. -// Handles both GET and POST of form. +// SignOut signs the user out by trying to revoke the user's remote identity session along with +// the associated local session state. Handles both GET and POST. func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) { err := r.ParseForm() if err != nil { @@ -256,8 +241,8 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, redirectURI, http.StatusFound) } -// OAuthStart starts the authenticate process by redirecting to the provider. It provides a -// `redirectURI`, allowing the provider to redirect back to the sso proxy after authenticate. +// OAuthStart starts the authenticate process by redirecting to the identity provider. +// https://tools.ietf.org/html/rfc6749#section-4.2.1 func (a *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) { authRedirectURL, err := url.Parse(r.URL.Query().Get("redirect_uri")) if err != nil { @@ -298,7 +283,7 @@ func (a *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) { } // OAuthCallback handles the callback from the identity provider. Displays an error page if there -// was an error. If successful, redirects back to the proxy-service via the redirect-url. +// was an error. If successful, the user is redirected back to the proxy-service. func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) { redirect, err := a.getOAuthCallback(w, r) switch h := err.(type) { diff --git a/authenticate/handlers_test.go b/authenticate/handlers_test.go index be4b69a29..f92f70c31 100644 --- a/authenticate/handlers_test.go +++ b/authenticate/handlers_test.go @@ -71,29 +71,19 @@ func TestAuthenticate_Handler(t *testing.T) { func TestAuthenticate_authenticate(t *testing.T) { // sessions.MockSessionStore{Session: expiredLifetime} - goodSession := sessions.MockSessionStore{ + goodSession := &sessions.MockSessionStore{ Session: &sessions.SessionState{ - AccessToken: "AccessToken", - RefreshToken: "RefreshToken", - LifetimeDeadline: time.Now().Add(10 * time.Second), - RefreshDeadline: time.Now().Add(10 * time.Second), - ValidDeadline: time.Now().Add(10 * time.Second), + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + RefreshDeadline: time.Now().Add(10 * time.Second), }} - expiredSession := sessions.MockSessionStore{ + + expiredRefresPeriod := &sessions.MockSessionStore{ Session: &sessions.SessionState{ - AccessToken: "AccessToken", - RefreshToken: "RefreshToken", - LifetimeDeadline: time.Now().Add(10 * -time.Second), - RefreshDeadline: time.Now().Add(10 * time.Second), - ValidDeadline: time.Now().Add(10 * time.Second), - }} - expiredRefresPeriod := sessions.MockSessionStore{ - Session: &sessions.SessionState{ - AccessToken: "AccessToken", - RefreshToken: "RefreshToken", - LifetimeDeadline: time.Now().Add(10 * time.Second), - RefreshDeadline: time.Now().Add(10 * -time.Second), - ValidDeadline: time.Now().Add(10 * time.Second), + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + + RefreshDeadline: time.Now().Add(10 * -time.Second), }} tests := []struct { @@ -106,18 +96,16 @@ func TestAuthenticate_authenticate(t *testing.T) { }{ {"good", goodSession, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, false}, {"good but fails validation", goodSession, providers.MockProvider{ValidateResponse: true}, falseValidator, nil, true}, - {"can't load session", sessions.MockSessionStore{LoadError: errors.New("error")}, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, true}, + {"can't load session", &sessions.MockSessionStore{LoadError: errors.New("error")}, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, true}, {"validation fails", goodSession, providers.MockProvider{ValidateResponse: false}, trueValidator, nil, true}, - {"session fails after good validation", sessions.MockSessionStore{ + {"session fails after good validation", &sessions.MockSessionStore{ SaveError: errors.New("error"), Session: &sessions.SessionState{ - AccessToken: "AccessToken", - RefreshToken: "RefreshToken", - LifetimeDeadline: time.Now().Add(10 * time.Second), - RefreshDeadline: time.Now().Add(10 * time.Second), - ValidDeadline: time.Now().Add(10 * time.Second), - }}, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, true}, - {"lifetime expired", expiredSession, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, true}, + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + RefreshDeadline: time.Now().Add(10 * time.Second), + }}, providers.MockProvider{ValidateResponse: true}, + trueValidator, nil, true}, {"refresh expired", expiredRefresPeriod, providers.MockProvider{ @@ -136,14 +124,13 @@ func TestAuthenticate_authenticate(t *testing.T) { }, trueValidator, nil, true}, {"refresh expired failed save", - sessions.MockSessionStore{ + &sessions.MockSessionStore{ SaveError: errors.New("error"), Session: &sessions.SessionState{ - AccessToken: "AccessToken", - RefreshToken: "RefreshToken", - LifetimeDeadline: time.Now().Add(10 * time.Second), - RefreshDeadline: time.Now().Add(10 * -time.Second), - ValidDeadline: time.Now().Add(10 * time.Second), + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + + RefreshDeadline: time.Now().Add(10 * -time.Second), }}, providers.MockProvider{ ValidateResponse: true, @@ -182,29 +169,23 @@ func TestAuthenticate_SignIn(t *testing.T) { wantCode int }{ {"good", - sessions.MockSessionStore{ + &sessions.MockSessionStore{ Session: &sessions.SessionState{ - AccessToken: "AccessToken", - RefreshToken: "RefreshToken", - LifetimeDeadline: time.Now().Add(10 * time.Second), - RefreshDeadline: time.Now().Add(10 * time.Second), - ValidDeadline: time.Now().Add(10 * time.Second), + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + RefreshDeadline: time.Now().Add(10 * time.Second), }}, providers.MockProvider{ValidateResponse: true}, trueValidator, - 403}, - // {"no session", - // sessions.MockSessionStore{ - // Session: &sessions.SessionState{ - // AccessToken: "AccessToken", - // RefreshToken: "RefreshToken", - // LifetimeDeadline: time.Now().Add(-10 * time.Second), - // RefreshDeadline: time.Now().Add(10 * time.Second), - // ValidDeadline: time.Now().Add(10 * time.Second), - // }}, - // providers.MockProvider{ValidateResponse: true}, - // trueValidator, - // 200}, + http.StatusForbidden}, + {"session fails after good validation", &sessions.MockSessionStore{ + SaveError: errors.New("error"), + Session: &sessions.SessionState{ + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + RefreshDeadline: time.Now().Add(10 * time.Second), + }}, providers.MockProvider{ValidateResponse: true}, + trueValidator, http.StatusBadRequest}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -212,6 +193,10 @@ func TestAuthenticate_SignIn(t *testing.T) { sessionStore: tt.session, provider: tt.provider, Validator: tt.validator, + RedirectURL: uriParse("http://www.pomerium.io"), + csrfStore: &sessions.MockCSRFStore{}, + SharedKey: "secret", + cipher: mockCipher{}, } r := httptest.NewRequest("GET", "/sign-in", nil) w := httptest.NewRecorder() @@ -262,13 +247,12 @@ func TestAuthenticate_ProxyCallback(t *testing.T) { }{ {"good", "https://corp.pomerium.io/", "state", "code", &sessions.SessionState{ - AccessToken: "AccessToken", - RefreshToken: "RefreshToken", - LifetimeDeadline: time.Now().Add(10 * time.Second), - RefreshDeadline: time.Now().Add(10 * time.Second), - ValidDeadline: time.Now().Add(10 * time.Second), + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + + RefreshDeadline: time.Now().Add(10 * time.Second), }, - sessions.MockSessionStore{}, + &sessions.MockSessionStore{}, 302, "Found."}, {"no state", @@ -276,13 +260,12 @@ func TestAuthenticate_ProxyCallback(t *testing.T) { "", "code", &sessions.SessionState{ - AccessToken: "AccessToken", - RefreshToken: "RefreshToken", - LifetimeDeadline: time.Now().Add(10 * time.Second), - RefreshDeadline: time.Now().Add(10 * time.Second), - ValidDeadline: time.Now().Add(10 * time.Second), + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + + RefreshDeadline: time.Now().Add(10 * time.Second), }, - sessions.MockSessionStore{}, + &sessions.MockSessionStore{}, 403, "no state parameter supplied"}, {"no redirect_url", @@ -290,13 +273,12 @@ func TestAuthenticate_ProxyCallback(t *testing.T) { "state", "code", &sessions.SessionState{ - AccessToken: "AccessToken", - RefreshToken: "RefreshToken", - LifetimeDeadline: time.Now().Add(10 * time.Second), - RefreshDeadline: time.Now().Add(10 * time.Second), - ValidDeadline: time.Now().Add(10 * time.Second), + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + + RefreshDeadline: time.Now().Add(10 * time.Second), }, - sessions.MockSessionStore{}, + &sessions.MockSessionStore{}, 403, "no redirect_uri parameter"}, {"malformed redirect_url", @@ -304,13 +286,12 @@ func TestAuthenticate_ProxyCallback(t *testing.T) { "state", "code", &sessions.SessionState{ - AccessToken: "AccessToken", - RefreshToken: "RefreshToken", - LifetimeDeadline: time.Now().Add(10 * time.Second), - RefreshDeadline: time.Now().Add(10 * time.Second), - ValidDeadline: time.Now().Add(10 * time.Second), + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + + RefreshDeadline: time.Now().Add(10 * time.Second), }, - sessions.MockSessionStore{}, + &sessions.MockSessionStore{}, 400, "malformed redirect_uri"}, } @@ -389,14 +370,13 @@ func TestAuthenticate_SignOut(t *testing.T) { "sig", "ts", providers.MockProvider{}, - sessions.MockSessionStore{ + &sessions.MockSessionStore{ Session: &sessions.SessionState{ - AccessToken: "AccessToken", - RefreshToken: "RefreshToken", - Email: "blah@blah.com", - LifetimeDeadline: time.Now().Add(10 * time.Second), - RefreshDeadline: time.Now().Add(10 * time.Second), - ValidDeadline: time.Now().Add(10 * time.Second), + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + Email: "blah@blah.com", + + RefreshDeadline: time.Now().Add(10 * time.Second), }, }, http.StatusFound, @@ -407,14 +387,13 @@ func TestAuthenticate_SignOut(t *testing.T) { "sig", "ts", providers.MockProvider{RevokeError: errors.New("OH NO")}, - sessions.MockSessionStore{ + &sessions.MockSessionStore{ Session: &sessions.SessionState{ - AccessToken: "AccessToken", - RefreshToken: "RefreshToken", - Email: "blah@blah.com", - LifetimeDeadline: time.Now().Add(10 * time.Second), - RefreshDeadline: time.Now().Add(10 * time.Second), - ValidDeadline: time.Now().Add(10 * time.Second), + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + Email: "blah@blah.com", + + RefreshDeadline: time.Now().Add(10 * time.Second), }, }, http.StatusBadRequest, @@ -426,14 +405,13 @@ func TestAuthenticate_SignOut(t *testing.T) { "sig", "ts", providers.MockProvider{}, - sessions.MockSessionStore{ + &sessions.MockSessionStore{ Session: &sessions.SessionState{ - AccessToken: "AccessToken", - RefreshToken: "RefreshToken", - Email: "blah@blah.com", - LifetimeDeadline: time.Now().Add(10 * time.Second), - RefreshDeadline: time.Now().Add(10 * time.Second), - ValidDeadline: time.Now().Add(10 * time.Second), + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + Email: "blah@blah.com", + + RefreshDeadline: time.Now().Add(10 * time.Second), }, }, http.StatusOK, @@ -444,15 +422,14 @@ func TestAuthenticate_SignOut(t *testing.T) { "sig", "ts", providers.MockProvider{}, - sessions.MockSessionStore{ + &sessions.MockSessionStore{ LoadError: errors.New("uh oh"), Session: &sessions.SessionState{ - AccessToken: "AccessToken", - RefreshToken: "RefreshToken", - Email: "blah@blah.com", - LifetimeDeadline: time.Now().Add(10 * time.Second), - RefreshDeadline: time.Now().Add(10 * time.Second), - ValidDeadline: time.Now().Add(10 * time.Second), + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + Email: "blah@blah.com", + + RefreshDeadline: time.Now().Add(10 * time.Second), }, }, http.StatusBadRequest, @@ -463,14 +440,13 @@ func TestAuthenticate_SignOut(t *testing.T) { "sig", "ts", providers.MockProvider{}, - sessions.MockSessionStore{ + &sessions.MockSessionStore{ Session: &sessions.SessionState{ - AccessToken: "AccessToken", - RefreshToken: "RefreshToken", - Email: "blah@blah.com", - LifetimeDeadline: time.Now().Add(10 * time.Second), - RefreshDeadline: time.Now().Add(10 * time.Second), - ValidDeadline: time.Now().Add(10 * time.Second), + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + Email: "blah@blah.com", + + RefreshDeadline: time.Now().Add(10 * time.Second), }, }, http.StatusBadRequest, @@ -512,7 +488,6 @@ func redirectURLSignature(rawRedirect string, timestamp time.Time, secret string } func TestAuthenticate_OAuthStart(t *testing.T) { - tests := []struct { name string method string @@ -634,15 +609,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) { base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), []string{"pomerium.io"}, trueValidator, - sessions.MockSessionStore{}, + &sessions.MockSessionStore{}, providers.MockProvider{ AuthenticateResponse: sessions.SessionState{ - AccessToken: "AccessToken", - RefreshToken: "RefreshToken", - Email: "blah@blah.com", - LifetimeDeadline: time.Now().Add(10 * time.Second), - RefreshDeadline: time.Now().Add(10 * time.Second), - ValidDeadline: time.Now().Add(10 * time.Second), + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + Email: "blah@blah.com", + + RefreshDeadline: time.Now().Add(10 * time.Second), }}, sessions.MockCSRFStore{ ResponseCSRF: "csrf", @@ -657,15 +631,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) { base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), []string{"pomerium.io"}, trueValidator, - sessions.MockSessionStore{}, + &sessions.MockSessionStore{}, providers.MockProvider{ AuthenticateResponse: sessions.SessionState{ - AccessToken: "AccessToken", - RefreshToken: "RefreshToken", - Email: "blah@blah.com", - LifetimeDeadline: time.Now().Add(10 * time.Second), - RefreshDeadline: time.Now().Add(10 * time.Second), - ValidDeadline: time.Now().Add(10 * time.Second), + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + Email: "blah@blah.com", + + RefreshDeadline: time.Now().Add(10 * time.Second), }}, sessions.MockCSRFStore{ ResponseCSRF: "csrf", @@ -681,15 +654,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) { base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), []string{"pomerium.io"}, trueValidator, - sessions.MockSessionStore{}, + &sessions.MockSessionStore{}, providers.MockProvider{ AuthenticateResponse: sessions.SessionState{ - AccessToken: "AccessToken", - RefreshToken: "RefreshToken", - Email: "blah@blah.com", - LifetimeDeadline: time.Now().Add(10 * time.Second), - RefreshDeadline: time.Now().Add(10 * time.Second), - ValidDeadline: time.Now().Add(10 * time.Second), + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + Email: "blah@blah.com", + + RefreshDeadline: time.Now().Add(10 * time.Second), }}, sessions.MockCSRFStore{ ResponseCSRF: "csrf", @@ -704,7 +676,7 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) { base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), []string{"pomerium.io"}, trueValidator, - sessions.MockSessionStore{}, + &sessions.MockSessionStore{}, providers.MockProvider{ AuthenticateError: errors.New("error"), }, @@ -721,15 +693,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) { base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), []string{"pomerium.io"}, trueValidator, - sessions.MockSessionStore{SaveError: errors.New("error")}, + &sessions.MockSessionStore{SaveError: errors.New("error")}, providers.MockProvider{ AuthenticateResponse: sessions.SessionState{ - AccessToken: "AccessToken", - RefreshToken: "RefreshToken", - Email: "blah@blah.com", - LifetimeDeadline: time.Now().Add(10 * time.Second), - RefreshDeadline: time.Now().Add(10 * time.Second), - ValidDeadline: time.Now().Add(10 * time.Second), + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + Email: "blah@blah.com", + + RefreshDeadline: time.Now().Add(10 * time.Second), }}, sessions.MockCSRFStore{ ResponseCSRF: "csrf", @@ -744,15 +715,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) { base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), []string{"pomerium.io"}, falseValidator, - sessions.MockSessionStore{}, + &sessions.MockSessionStore{}, providers.MockProvider{ AuthenticateResponse: sessions.SessionState{ - AccessToken: "AccessToken", - RefreshToken: "RefreshToken", - Email: "blah@blah.com", - LifetimeDeadline: time.Now().Add(10 * time.Second), - RefreshDeadline: time.Now().Add(10 * time.Second), - ValidDeadline: time.Now().Add(10 * time.Second), + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + Email: "blah@blah.com", + + RefreshDeadline: time.Now().Add(10 * time.Second), }}, sessions.MockCSRFStore{ ResponseCSRF: "csrf", @@ -768,15 +738,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) { base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), []string{"pomerium.io"}, trueValidator, - sessions.MockSessionStore{}, + &sessions.MockSessionStore{}, providers.MockProvider{ AuthenticateResponse: sessions.SessionState{ - AccessToken: "AccessToken", - RefreshToken: "RefreshToken", - Email: "blah@blah.com", - LifetimeDeadline: time.Now().Add(10 * time.Second), - RefreshDeadline: time.Now().Add(10 * time.Second), - ValidDeadline: time.Now().Add(10 * time.Second), + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + Email: "blah@blah.com", + + RefreshDeadline: time.Now().Add(10 * time.Second), }}, sessions.MockCSRFStore{ ResponseCSRF: "csrf", @@ -791,15 +760,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) { base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), []string{"pomerium.io"}, trueValidator, - sessions.MockSessionStore{}, + &sessions.MockSessionStore{}, providers.MockProvider{ AuthenticateResponse: sessions.SessionState{ - AccessToken: "AccessToken", - RefreshToken: "RefreshToken", - Email: "blah@blah.com", - LifetimeDeadline: time.Now().Add(10 * time.Second), - RefreshDeadline: time.Now().Add(10 * time.Second), - ValidDeadline: time.Now().Add(10 * time.Second), + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + Email: "blah@blah.com", + + RefreshDeadline: time.Now().Add(10 * time.Second), }}, sessions.MockCSRFStore{ ResponseCSRF: "csrf", @@ -814,15 +782,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) { "nonce:https://corp.pomerium.io", []string{"pomerium.io"}, trueValidator, - sessions.MockSessionStore{}, + &sessions.MockSessionStore{}, providers.MockProvider{ AuthenticateResponse: sessions.SessionState{ - AccessToken: "AccessToken", - RefreshToken: "RefreshToken", - Email: "blah@blah.com", - LifetimeDeadline: time.Now().Add(10 * time.Second), - RefreshDeadline: time.Now().Add(10 * time.Second), - ValidDeadline: time.Now().Add(10 * time.Second), + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + Email: "blah@blah.com", + + RefreshDeadline: time.Now().Add(10 * time.Second), }}, sessions.MockCSRFStore{ ResponseCSRF: "csrf", @@ -837,15 +804,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) { base64.URLEncoding.EncodeToString([]byte("nonce")), []string{"pomerium.io"}, trueValidator, - sessions.MockSessionStore{}, + &sessions.MockSessionStore{}, providers.MockProvider{ AuthenticateResponse: sessions.SessionState{ - AccessToken: "AccessToken", - RefreshToken: "RefreshToken", - Email: "blah@blah.com", - LifetimeDeadline: time.Now().Add(10 * time.Second), - RefreshDeadline: time.Now().Add(10 * time.Second), - ValidDeadline: time.Now().Add(10 * time.Second), + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + Email: "blah@blah.com", + + RefreshDeadline: time.Now().Add(10 * time.Second), }}, sessions.MockCSRFStore{ ResponseCSRF: "csrf", @@ -860,15 +826,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) { base64.URLEncoding.EncodeToString([]byte("nonce:corp.pomerium.io")), []string{"pomerium.io"}, trueValidator, - sessions.MockSessionStore{}, + &sessions.MockSessionStore{}, providers.MockProvider{ AuthenticateResponse: sessions.SessionState{ - AccessToken: "AccessToken", - RefreshToken: "RefreshToken", - Email: "blah@blah.com", - LifetimeDeadline: time.Now().Add(10 * time.Second), - RefreshDeadline: time.Now().Add(10 * time.Second), - ValidDeadline: time.Now().Add(10 * time.Second), + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + Email: "blah@blah.com", + + RefreshDeadline: time.Now().Add(10 * time.Second), }}, sessions.MockCSRFStore{ ResponseCSRF: "csrf", diff --git a/authenticate/providers/doc.go b/authenticate/providers/doc.go index 2f1e9f108..39f18b6b9 100644 --- a/authenticate/providers/doc.go +++ b/authenticate/providers/doc.go @@ -1,5 +1,5 @@ -// Package providers implements OpenID Connect client logic for the set of supported identity -// providers. -// OpenID Connect 1.0 is a simple identity layer on top of the OAuth 2.0 RFC6749 protocol. -// https://openid.net/specs/openid-connect-core-1_0.html +// Package providers authentication for third party identity providers (IdP) using OpenID +// Connect, an identity layer on top of the OAuth 2.0 RFC6749 protocol. +// +// see: https://openid.net/specs/openid-connect-core-1_0.html package providers // import "github.com/pomerium/pomerium/internal/providers" diff --git a/authenticate/providers/oidc.go b/authenticate/providers/oidc.go index 5225be3ee..8721d4b4a 100644 --- a/authenticate/providers/oidc.go +++ b/authenticate/providers/oidc.go @@ -2,7 +2,6 @@ package providers // import "github.com/pomerium/pomerium/internal/providers" import ( "context" - "errors" oidc "github.com/pomerium/go-oidc" "golang.org/x/oauth2" @@ -19,7 +18,7 @@ type OIDCProvider struct { func NewOIDCProvider(p *IdentityProvider) (*OIDCProvider, error) { ctx := context.Background() if p.ProviderURL == "" { - return nil, errors.New("missing required provider url") + return nil, ErrMissingProviderURL } var err error p.provider, err = oidc.NewProvider(ctx, p.ProviderURL) diff --git a/authenticate/providers/okta.go b/authenticate/providers/okta.go index efc390d26..5bb782887 100644 --- a/authenticate/providers/okta.go +++ b/authenticate/providers/okta.go @@ -2,7 +2,6 @@ package providers // import "github.com/pomerium/pomerium/internal/providers" import ( "context" - "errors" "net/url" oidc "github.com/pomerium/go-oidc" @@ -25,7 +24,7 @@ type OktaProvider struct { func NewOktaProvider(p *IdentityProvider) (*OktaProvider, error) { ctx := context.Background() if p.ProviderURL == "" { - return nil, errors.New("missing required provider url") + return nil, ErrMissingProviderURL } var err error p.provider, err = oidc.NewProvider(ctx, p.ProviderURL) diff --git a/authenticate/providers/providers.go b/authenticate/providers/providers.go index eef77f056..90ce78eb6 100644 --- a/authenticate/providers/providers.go +++ b/authenticate/providers/providers.go @@ -29,6 +29,11 @@ const ( OktaProviderName = "okta" ) +var ( + // ErrMissingProviderURL is returned when the CB state is half open and the requests count is over the cb maxRequests + ErrMissingProviderURL = errors.New("proxy/providers: missing provider url") +) + // Provider is an interface exposing functions necessary to interact with a given provider. type Provider interface { Authenticate(string) (*sessions.SessionState, error) diff --git a/go.mod b/go.mod index 9feeea951..edd70c7fa 100644 --- a/go.mod +++ b/go.mod @@ -5,18 +5,17 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/golang/mock v1.2.0 github.com/golang/protobuf v1.2.0 - github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pomerium/envconfig v1.3.1-0.20190112072701-14cbcf832d31 github.com/pomerium/go-oidc v2.0.0+incompatible github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect github.com/rs/zerolog v1.11.0 - github.com/stretchr/testify v1.2.2 // indirect - golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9 - golang.org/x/net v0.0.0-20181220203305-927f97764cc3 - golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890 - golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 // indirect - golang.org/x/sys v0.0.0-20190116161447-11f53e031339 // indirect - google.golang.org/appengine v1.4.0 // indirect + github.com/stretchr/testify v1.3.0 // indirect + golang.org/x/crypto v0.0.0-20190211182817-74369b46fc67 + golang.org/x/net v0.0.0-20190213061140-3a22650c66bd + golang.org/x/oauth2 v0.0.0-20190212230446-3e8b2be13635 + golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a // indirect + golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2 // indirect + google.golang.org/genproto v0.0.0-20190215211957-bd968387e4aa // indirect google.golang.org/grpc v1.18.0 - gopkg.in/square/go-jose.v2 v2.2.1 + gopkg.in/square/go-jose.v2 v2.2.2 ) diff --git a/go.sum b/go.sum index d537ad4e5..5544458aa 100644 --- a/go.sum +++ b/go.sum @@ -1,11 +1,14 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= github.com/benbjohnson/clock v0.0.0-20161215174838-7dc76406b6d3 h1:wOysYcIdqv3WnvwqFFzrYCFALPED7qkUGaLXu359GSc= github.com/benbjohnson/clock v0.0.0-20161215174838-7dc76406b6d3/go.mod h1:UMqtWQTnOe4byzwe7Zhwh8f8s+36uszN51sJrSIZlTE= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= github.com/golang/mock v1.1.1 h1:G5FRp8JnTd7RQH5kemVNlMeyXQAztQ3mOWV95KxsXH8= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.2.0 h1:28o5sBqPkBsMGnC6b4MvE2TzSr5/AT4c/1fLqVGIwlk= @@ -23,34 +26,45 @@ github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAm github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35/go.mod h1:prYjPmNq4d1NPVmpShWobRqXY3q7Vp+80DqgxxUrUIA= github.com/rs/zerolog v1.11.0 h1:DRuq/S+4k52uJzBQciUcofXx45GrMC6yrEbb/CoK6+M= github.com/rs/zerolog v1.11.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= -github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= -github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= -golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9 h1:mKdxBk7AujPs8kU4m80U72y/zjbZ3UcXC7dClwKbUI0= -golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +golang.org/x/crypto v0.0.0-20190211182817-74369b46fc67 h1:ng3VDlRp5/DHpSWl02R4rM9I+8M2rhmsuLwAMmkLQWE= +golang.org/x/crypto v0.0.0-20190211182817-74369b46fc67/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181220203305-927f97764cc3 h1:eH6Eip3UpmR+yM/qI9Ijluzb1bNv/cAU/n+6l8tRSis= -golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181106065722-10aee1819953/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd h1:HuTn7WObtcDo9uEEU7rEqL0jYthdXAmZ6PP+meazmaU= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890 h1:uESlIz09WIHT2I+pasSXcpLYqYK8wHcdCetU3VuMBJE= -golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20190212230446-3e8b2be13635 h1:dOJmQysgY8iOBECuNp0vlKHWEtfiTnyjisEizRV3/4o= +golang.org/x/oauth2 v0.0.0-20190212230446-3e8b2be13635/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190116161447-11f53e031339 h1:g/Jesu8+QLnA0CPzF3E1pURg0Byr7i6jLoX5sqjcAh0= -golang.org/x/sys v0.0.0-20190116161447-11f53e031339/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2 h1:z99zHgr7hKfrUcX/KsoJk5FJfjTceCKIp96+biqP4To= +golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8 h1:Nw54tB0rB7hY/N0NQvRW8DG4Yk3Q6T9cu9RcFQDu1tc= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20190215211957-bd968387e4aa h1:FVL+/MjP2dzG4PxLpCJR7B6esIia88UAbsfYUrCc8U4= +google.golang.org/genproto v0.0.0-20190215211957-bd968387e4aa/go.mod h1:L3J43x8/uS+qIUoksaLKe6OS3nUKxOKuIFz1sl2/jx4= +google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio= google.golang.org/grpc v1.18.0 h1:IZl7mfBGfbhYx2p2rKRtYgDFw6SBz+kclmxYrCksPPA= google.golang.org/grpc v1.18.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= -gopkg.in/square/go-jose.v2 v2.2.1 h1:uRIz/V7RfMsMgGnCp+YybIdstDIz8wc0H283wHQfwic= -gopkg.in/square/go-jose.v2 v2.2.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= +gopkg.in/square/go-jose.v2 v2.2.2 h1:orlkJ3myw8CN1nVQHBFfloD+L3egixIa4FvUP6RosSA= +gopkg.in/square/go-jose.v2 v2.2.2/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/internal/sessions/cookie_store.go b/internal/sessions/cookie_store.go index d7af2be84..8ad4f8a93 100644 --- a/internal/sessions/cookie_store.go +++ b/internal/sessions/cookie_store.go @@ -29,47 +29,43 @@ type SessionStore interface { // CookieStore represents all the cookie related configurations type CookieStore struct { - Name string - CSRFCookieName string - CookieExpire time.Duration - CookieRefresh time.Duration - CookieSecure bool - CookieHTTPOnly bool - CookieDomain string - CookieCipher cryptutil.Cipher - SessionLifetimeTTL time.Duration + Name string + CSRFCookieName string + CookieCipher cryptutil.Cipher + CookieExpire time.Duration + CookieRefresh time.Duration + CookieSecure bool + CookieHTTPOnly bool + CookieDomain string } -// CreateCookieCipher creates a new miscreant cipher with the cookie secret -func CreateCookieCipher(cookieSecret []byte) func(s *CookieStore) error { - return func(s *CookieStore) error { - cipher, err := cryptutil.NewCipher(cookieSecret) - if err != nil { - return fmt.Errorf("cookie-secret error: %s", err.Error()) - } - s.CookieCipher = cipher - return nil - } +// CookieStoreOptions holds options for CookieStore +type CookieStoreOptions struct { + Name string + CookieSecure bool + CookieHTTPOnly bool + CookieDomain string + CookieExpire time.Duration + CookieCipher cryptutil.Cipher } // NewCookieStore returns a new session with ciphers for each of the cookie secrets -func NewCookieStore(cookieName string, optFuncs ...func(*CookieStore) error) (*CookieStore, error) { - c := &CookieStore{ - Name: cookieName, - CookieSecure: true, - CookieHTTPOnly: true, - CookieExpire: 168 * time.Hour, - CSRFCookieName: fmt.Sprintf("%v_%v", cookieName, "csrf"), +func NewCookieStore(opts *CookieStoreOptions) (*CookieStore, error) { + if opts.Name == "" { + return nil, fmt.Errorf("internal/sessions: cookie name cannot be empty") } - - for _, f := range optFuncs { - err := f(c) - if err != nil { - return nil, err - } + if opts.CookieCipher == nil { + return nil, fmt.Errorf("internal/sessions: cipher cannot be nil") } - - return c, nil + return &CookieStore{ + Name: opts.Name, + CSRFCookieName: fmt.Sprintf("%v_%v", opts.Name, "csrf"), + CookieSecure: opts.CookieSecure, + CookieHTTPOnly: opts.CookieHTTPOnly, + CookieDomain: opts.CookieDomain, + CookieExpire: opts.CookieExpire, + CookieCipher: opts.CookieCipher, + }, nil } func (s *CookieStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie { @@ -80,16 +76,19 @@ func (s *CookieStore) makeCookie(req *http.Request, name string, value string, e if s.CookieDomain != "" { domain = s.CookieDomain } - - return &http.Cookie{ + c := &http.Cookie{ Name: name, Value: value, Path: "/", Domain: domain, HttpOnly: s.CookieHTTPOnly, Secure: s.CookieSecure, - Expires: now.Add(expiration), } + // only set an expiration if we want one, otherwise default to non perm session based + if expiration != 0 { + c.Expires = now.Add(expiration) + } + return c } // makeSessionCookie constructs a session cookie given the request, an expiration time and the current time. @@ -103,13 +102,13 @@ func (s *CookieStore) makeCSRFCookie(req *http.Request, value string, expiration } // ClearCSRF clears the CSRF cookie from the request -func (s *CookieStore) ClearCSRF(rw http.ResponseWriter, req *http.Request) { - http.SetCookie(rw, s.makeCSRFCookie(req, "", time.Hour*-1, time.Now())) +func (s *CookieStore) ClearCSRF(w http.ResponseWriter, req *http.Request) { + http.SetCookie(w, s.makeCSRFCookie(req, "", time.Hour*-1, time.Now())) } // SetCSRF sets the CSRFCookie creates a CSRF cookie in a given request -func (s *CookieStore) SetCSRF(rw http.ResponseWriter, req *http.Request, val string) { - http.SetCookie(rw, s.makeCSRFCookie(req, val, s.CookieExpire, time.Now())) +func (s *CookieStore) SetCSRF(w http.ResponseWriter, req *http.Request, val string) { + http.SetCookie(w, s.makeCSRFCookie(req, val, s.CookieExpire, time.Now())) } // GetCSRF gets the CSRFCookie creates a CSRF cookie in a given request @@ -118,20 +117,19 @@ func (s *CookieStore) GetCSRF(req *http.Request) (*http.Cookie, error) { } // ClearSession clears the session cookie from a request -func (s *CookieStore) ClearSession(rw http.ResponseWriter, req *http.Request) { - http.SetCookie(rw, s.makeSessionCookie(req, "", time.Hour*-1, time.Now())) +func (s *CookieStore) ClearSession(w http.ResponseWriter, req *http.Request) { + http.SetCookie(w, s.makeSessionCookie(req, "", time.Hour*-1, time.Now())) } -func (s *CookieStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val string) { - http.SetCookie(rw, s.makeSessionCookie(req, val, s.CookieExpire, time.Now())) +func (s *CookieStore) setSessionCookie(w http.ResponseWriter, req *http.Request, val string) { + http.SetCookie(w, s.makeSessionCookie(req, val, s.CookieExpire, time.Now())) } // LoadSession returns a SessionState from the cookie in the request. func (s *CookieStore) LoadSession(req *http.Request) (*SessionState, error) { c, err := req.Cookie(s.Name) if err != nil { - // always http.ErrNoCookie - return nil, err + return nil, err // http.ErrNoCookie } session, err := UnmarshalSession(c.Value, s.CookieCipher) if err != nil { @@ -141,12 +139,11 @@ func (s *CookieStore) LoadSession(req *http.Request) (*SessionState, error) { } // SaveSession saves a session state to a request sessions. -func (s *CookieStore) SaveSession(rw http.ResponseWriter, req *http.Request, sessionState *SessionState) error { +func (s *CookieStore) SaveSession(w http.ResponseWriter, req *http.Request, sessionState *SessionState) error { value, err := MarshalSession(sessionState, s.CookieCipher) if err != nil { return err } - - s.setSessionCookie(rw, req, value) + s.setSessionCookie(w, req, value) return nil } diff --git a/internal/sessions/cookie_store_test.go b/internal/sessions/cookie_store_test.go index 91da9568d..74754aae5 100644 --- a/internal/sessions/cookie_store_test.go +++ b/internal/sessions/cookie_store_test.go @@ -1,348 +1,348 @@ -package sessions // import "github.com/pomerium/pomerium/internal/sessions" +package sessions import ( - "encoding/base64" - "fmt" + "errors" "net/http" "net/http/httptest" + "reflect" "testing" "time" - "github.com/pomerium/pomerium/internal/testutil" + "github.com/pomerium/pomerium/internal/cryptutil" ) -var testEncodedCookieSecret, _ = base64.StdEncoding.DecodeString("qICChm3wdjbjcWymm7PefwtPP6/PZv+udkFEubTeE38=") +type mockCipher struct{} -func TestCreateCookieCipher(t *testing.T) { - testCases := []struct { - name string - cookieSecret []byte - expectedError bool - }{ - { - name: "normal case with base64 encoded secret", - cookieSecret: testEncodedCookieSecret, - }, - - { - name: "error when not base64 encoded", - cookieSecret: []byte("abcd"), - expectedError: true, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - _, err := NewCookieStore("cookieName", CreateCookieCipher(tc.cookieSecret)) - if !tc.expectedError { - testutil.Ok(t, err) - } else { - testutil.NotEqual(t, err, nil) - } - }) +func (a mockCipher) Encrypt(s []byte) ([]byte, error) { + if string(s) == "error" { + return []byte(""), errors.New("error encrypting") } + return []byte("OK"), nil } -func TestNewSession(t *testing.T) { - testCases := []struct { - name string - optFuncs []func(*CookieStore) error - expectedError bool - expectedSession *CookieStore +func (a mockCipher) Decrypt(s []byte) ([]byte, error) { + if string(s) == "error" { + return []byte(""), errors.New("error encrypting") + } + return []byte("OK"), nil +} +func (a mockCipher) Marshal(s interface{}) (string, error) { return "", errors.New("error") } +func (a mockCipher) Unmarshal(s string, i interface{}) error { + if string(s) == "unmarshal error" || string(s) == "error" { + return errors.New("error") + } + return nil +} +func TestNewCookieStore(t *testing.T) { + cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey()) + if err != nil { + t.Fatal(err) + } + tests := []struct { + name string + opts *CookieStoreOptions + want *CookieStore + wantErr bool }{ - { - name: "default with no opt funcs set", - expectedSession: &CookieStore{ - Name: "cookieName", + {"good", + &CookieStoreOptions{ + Name: "_cookie", CookieSecure: true, CookieHTTPOnly: true, - CookieExpire: 168 * time.Hour, - CSRFCookieName: "cookieName_csrf", + CookieDomain: "pomerium.io", + CookieExpire: 10 * time.Second, + CookieCipher: cipher, }, - }, - { - name: "opt func with an error returns an error", - optFuncs: []func(*CookieStore) error{func(*CookieStore) error { return fmt.Errorf("error") }}, - expectedError: true, - }, - { - name: "opt func overrides default values", - optFuncs: []func(*CookieStore) error{func(s *CookieStore) error { - s.CookieExpire = time.Hour - return nil - }}, - expectedSession: &CookieStore{ - Name: "cookieName", + &CookieStore{ + Name: "_cookie", + CSRFCookieName: "_cookie_csrf", CookieSecure: true, CookieHTTPOnly: true, - CookieExpire: time.Hour, - CSRFCookieName: "cookieName_csrf", + CookieDomain: "pomerium.io", + CookieExpire: 10 * time.Second, + CookieCipher: cipher, }, - }, + false}, + {"missing name", + &CookieStoreOptions{ + Name: "", + CookieSecure: true, + CookieHTTPOnly: true, + CookieDomain: "pomerium.io", + CookieExpire: 10 * time.Second, + CookieCipher: cipher, + }, + nil, + true}, + {"missing cipher", + &CookieStoreOptions{ + Name: "_pomerium", + CookieSecure: true, + CookieHTTPOnly: true, + CookieDomain: "pomerium.io", + CookieExpire: 10 * time.Second, + CookieCipher: nil, + }, + nil, + true}, } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - session, err := NewCookieStore("cookieName", tc.optFuncs...) - if tc.expectedError { - testutil.NotEqual(t, err, nil) - } else { - testutil.Ok(t, err) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewCookieStore(tt.opts) + if (err != nil) != tt.wantErr { + t.Errorf("NewCookieStore() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewCookieStore() = %#v, want %#v", got, tt.want) } - testutil.Equal(t, tc.expectedSession, session) }) } } -func TestMakeSessionCookie(t *testing.T) { +func TestCookieStore_makeCookie(t *testing.T) { + cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey()) + if err != nil { + t.Fatal(err) + } + + type fields struct { + Name string + CSRFCookieName string + CookieCipher cryptutil.Cipher + CookieExpire time.Duration + CookieRefresh time.Duration + CookieSecure bool + CookieHTTPOnly bool + CookieDomain string + } + now := time.Now() - cookieValue := "cookieValue" - expiration := time.Hour - cookieName := "cookieName" - testCases := []struct { - name string - optFuncs []func(*CookieStore) error - expectedCookie *http.Cookie - }{ - { - name: "default cookie domain", - expectedCookie: &http.Cookie{ - Name: cookieName, - Value: cookieValue, - Path: "/", - Domain: "www.example.com", - HttpOnly: true, - Secure: true, - Expires: now.Add(expiration), - }, - }, - { - name: "custom cookie domain set", - optFuncs: []func(*CookieStore) error{ - func(s *CookieStore) error { - s.CookieDomain = "buzzfeed.com" - return nil - }, - }, - expectedCookie: &http.Cookie{ - Name: cookieName, - Value: cookieValue, - Path: "/", - Domain: "buzzfeed.com", - HttpOnly: true, - Secure: true, - Expires: now.Add(expiration), - }, - }, - } + tests := []struct { + name string + domain string + + cookieName string + value string + expiration time.Duration + want *http.Cookie + }{ + {"good", "http://pomerium.io", "_pomerium", "value", 0, &http.Cookie{Name: "_pomerium", Value: "value", Path: "/", Domain: "pomerium.io", Secure: true, HttpOnly: true}}, + {"domains with https", "https://pomerium.io", "_pomerium", "value", 0, &http.Cookie{Name: "_pomerium", Value: "value", Path: "/", Domain: "pomerium.io", Secure: true, HttpOnly: true}}, + {"domain with port", "http://pomerium.io:443", "_pomerium", "value", 0, &http.Cookie{Name: "_pomerium", Value: "value", Path: "/", Domain: "pomerium.io", Secure: true, HttpOnly: true}}, + {"expiration set", "http://pomerium.io:443", "_pomerium", "value", 10 * time.Second, &http.Cookie{Expires: now.Add(10 * time.Second), Name: "_pomerium", Value: "value", Path: "/", Domain: "pomerium.io", Secure: true, HttpOnly: true}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := httptest.NewRequest("GET", tt.domain, nil) + + s := &CookieStore{ + Name: "_pomerium", + CSRFCookieName: "_pomerium_csrf", + CookieSecure: true, + CookieHTTPOnly: true, + CookieDomain: "pomerium.io", + CookieExpire: 10 * time.Second, + CookieCipher: cipher} + + if got := s.makeCookie(r, tt.cookieName, tt.value, tt.expiration, now); !reflect.DeepEqual(got, tt.want) { + t.Errorf("CookieStore.makeCookie() = \n%#v, \nwant\n%#v", got, tt.want) + } + if got := s.makeSessionCookie(r, tt.value, tt.expiration, now); !reflect.DeepEqual(got, tt.want) { + t.Errorf("CookieStore.makeCookie() = \n%#v, \nwant\n%#v", got, tt.want) + } + got := s.makeCSRFCookie(r, tt.value, tt.expiration, now) + tt.want.Name = "_pomerium_csrf" + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("CookieStore.makeCookie() = \n%#v, \nwant\n%#v", got, tt.want) + } + w := httptest.NewRecorder() + want := "new-csrf" + s.SetCSRF(w, r, want) + found := false + for _, cookie := range w.Result().Cookies() { + if cookie.Name == s.CSRFCookieName && cookie.Value == want { + found = true + break + } + } + if !found { + t.Error("SetCSRF failed") + } + + w = httptest.NewRecorder() + s.ClearCSRF(w, r) + for _, cookie := range w.Result().Cookies() { + if cookie.Name == s.CSRFCookieName && cookie.Value == want { + t.Error("clear csrf failed") + break + + } + } + w = httptest.NewRecorder() + want = "new-session" + s.setSessionCookie(w, r, want) + found = false + for _, cookie := range w.Result().Cookies() { + if cookie.Name == s.Name && cookie.Value == want { + found = true + break + } + } + if !found { + t.Error("SetCSRF failed") + } + w = httptest.NewRecorder() + s.ClearSession(w, r) + for _, cookie := range w.Result().Cookies() { + if cookie.Name == s.Name && cookie.Value == want { + t.Error("clear csrf failed") + break + } + } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - session, err := NewCookieStore(cookieName, tc.optFuncs...) - testutil.Ok(t, err) - req := httptest.NewRequest("GET", "http://www.example.com", nil) - cookie := session.makeSessionCookie(req, cookieValue, expiration, now) - testutil.Equal(t, cookie, tc.expectedCookie) }) } } -func TestMakeSessionCSRFCookie(t *testing.T) { - now := time.Now() - cookieValue := "cookieValue" - expiration := time.Hour - cookieName := "cookieName" - csrfName := "cookieName_csrf" - - testCases := []struct { - name string - optFuncs []func(*CookieStore) error - expectedCookie *http.Cookie - }{ - { - name: "default cookie domain", - expectedCookie: &http.Cookie{ - Name: csrfName, - Value: cookieValue, - Path: "/", - Domain: "www.example.com", - HttpOnly: true, - Secure: true, - Expires: now.Add(expiration), - }, - }, - { - name: "custom cookie domain set", - optFuncs: []func(*CookieStore) error{ - func(s *CookieStore) error { - s.CookieDomain = "buzzfeed.com" - return nil - }, - }, - expectedCookie: &http.Cookie{ - Name: csrfName, - Value: cookieValue, - Path: "/", - Domain: "buzzfeed.com", - HttpOnly: true, - Secure: true, - Expires: now.Add(expiration), - }, - }, +func TestCookieStore_SaveSession(t *testing.T) { + cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey()) + if err != nil { + t.Fatal(err) } + tests := []struct { + name string + sessionState *SessionState + cipher cryptutil.Cipher + wantErr bool + wantLoadErr bool + }{ + {"good", + &SessionState{ + AccessToken: "token1234", + RefreshToken: "refresh4321", + LifetimeDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), + RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), + Email: "user@domain.com", + User: "user", + }, cipher, false, false}, + {"bad cipher", + &SessionState{ + AccessToken: "token1234", + RefreshToken: "refresh4321", + LifetimeDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), + RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), + Email: "user@domain.com", + User: "user", + }, mockCipher{}, true, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &CookieStore{ + Name: "_pomerium", + CSRFCookieName: "_pomerium_csrf", + CookieSecure: true, + CookieHTTPOnly: true, + CookieDomain: "pomerium.io", + CookieExpire: 10 * time.Second, + CookieCipher: tt.cipher} - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - session, err := NewCookieStore(cookieName, tc.optFuncs...) - testutil.Ok(t, err) - req := httptest.NewRequest("GET", "http://www.example.com", nil) - cookie := session.makeCSRFCookie(req, cookieValue, expiration, now) - testutil.Equal(t, tc.expectedCookie, cookie) + r := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + if err := s.SaveSession(w, r, tt.sessionState); (err != nil) != tt.wantErr { + t.Errorf("CookieStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr) + } + r = httptest.NewRequest("GET", "/", nil) + for _, cookie := range w.Result().Cookies() { + t.Log(cookie) + r.AddCookie(cookie) + } + + state, err := s.LoadSession(r) + if (err != nil) != tt.wantLoadErr { + t.Errorf("LoadSession() error = %v, wantErr %v", err, tt.wantLoadErr) + return + } + if err == nil && !reflect.DeepEqual(state, tt.sessionState) { + t.Errorf("CookieStore.LoadSession() got = \n%v, want \n%v", state, tt.sessionState) + } }) } } -func TestSetSessionCookie(t *testing.T) { - cookieValue := "cookieValue" - cookieName := "cookieName" - - t.Run("set session cookie test", func(t *testing.T) { - session, err := NewCookieStore(cookieName) - testutil.Ok(t, err) - req := httptest.NewRequest("GET", "http://www.example.com", nil) - rw := httptest.NewRecorder() - session.setSessionCookie(rw, req, cookieValue) - var found bool - for _, cookie := range rw.Result().Cookies() { - if cookie.Name == cookieName { - found = true - testutil.Equal(t, cookieValue, cookie.Value) - testutil.Assert(t, cookie.Expires.After(time.Now()), "cookie expires after now") - } - } - testutil.Assert(t, found, "cookie in header") - }) -} -func TestSetCSRFSessionCookie(t *testing.T) { - cookieValue := "cookieValue" - cookieName := "cookieName" - - t.Run("set csrf cookie test", func(t *testing.T) { - session, err := NewCookieStore(cookieName) - testutil.Ok(t, err) - req := httptest.NewRequest("GET", "http://www.example.com", nil) - rw := httptest.NewRecorder() - session.SetCSRF(rw, req, cookieValue) - var found bool - for _, cookie := range rw.Result().Cookies() { - if cookie.Name == fmt.Sprintf("%s_csrf", cookieName) { - found = true - testutil.Equal(t, cookieValue, cookie.Value) - testutil.Assert(t, cookie.Expires.After(time.Now()), "cookie expires after now") - } - } - testutil.Assert(t, found, "cookie in header") - }) -} - -func TestClearSessionCookie(t *testing.T) { - cookieValue := "cookieValue" - cookieName := "cookieName" - - t.Run("set session cookie test", func(t *testing.T) { - session, err := NewCookieStore(cookieName) - testutil.Ok(t, err) - req := httptest.NewRequest("GET", "http://www.example.com", nil) - req.AddCookie(session.makeSessionCookie(req, cookieValue, time.Hour, time.Now())) - - rw := httptest.NewRecorder() - session.ClearSession(rw, req) - var found bool - for _, cookie := range rw.Result().Cookies() { - if cookie.Name == cookieName { - found = true - testutil.Equal(t, "", cookie.Value) - testutil.Assert(t, cookie.Expires.Before(time.Now()), "cookie expires before now") - } - } - testutil.Assert(t, found, "cookie in header") - }) -} - -func TestClearCSRFSessionCookie(t *testing.T) { - cookieValue := "cookieValue" - cookieName := "cookieName" - - t.Run("clear csrf cookie test", func(t *testing.T) { - session, err := NewCookieStore(cookieName) - testutil.Ok(t, err) - req := httptest.NewRequest("GET", "http://www.example.com", nil) - req.AddCookie(session.makeCSRFCookie(req, cookieValue, time.Hour, time.Now())) - - rw := httptest.NewRecorder() - session.ClearCSRF(rw, req) - var found bool - for _, cookie := range rw.Result().Cookies() { - if cookie.Name == fmt.Sprintf("%s_csrf", cookieName) { - found = true - testutil.Equal(t, "", cookie.Value) - testutil.Assert(t, cookie.Expires.Before(time.Now()), "cookie expires before now") - } - } - testutil.Assert(t, found, "cookie in header") - }) -} - -func TestLoadCookiedSession(t *testing.T) { - cookieName := "cookieName" - - testCases := []struct { - name string - optFuncs []func(*CookieStore) error - setupCookies func(*testing.T, *http.Request, *CookieStore, *SessionState) - expectedError error - sessionState *SessionState +func TestMockCSRFStore(t *testing.T) { + tests := []struct { + name string + mockCSRF *MockCSRFStore + newCSRFValue string + wantErr bool }{ - { - name: "no cookie set returns an error", - setupCookies: func(*testing.T, *http.Request, *CookieStore, *SessionState) {}, - expectedError: http.ErrNoCookie, - }, - { - name: "cookie set with cipher set", - optFuncs: []func(*CookieStore) error{CreateCookieCipher(testEncodedCookieSecret)}, - setupCookies: func(t *testing.T, req *http.Request, s *CookieStore, sessionState *SessionState) { - value, err := MarshalSession(sessionState, s.CookieCipher) - testutil.Ok(t, err) - req.AddCookie(s.makeSessionCookie(req, value, time.Hour, time.Now())) - }, - sessionState: &SessionState{ - Email: "example@email.com", - RefreshToken: "abccdddd", - AccessToken: "access", - }, - }, - { - name: "cookie set with invalid value cipher set", - optFuncs: []func(*CookieStore) error{CreateCookieCipher(testEncodedCookieSecret)}, - setupCookies: func(t *testing.T, req *http.Request, s *CookieStore, sessionState *SessionState) { - value := "574b776a7c934d6b9fc42ec63a389f79" - req.AddCookie(s.makeSessionCookie(req, value, time.Hour, time.Now())) - }, - expectedError: ErrInvalidSession, - }, + {"basic", + &MockCSRFStore{ + ResponseCSRF: "ok", + Cookie: &http.Cookie{Name: "hi"}}, + "newcsrf", + false}, } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - session, err := NewCookieStore(cookieName, tc.optFuncs...) - testutil.Ok(t, err) - req := httptest.NewRequest("GET", "https://www.example.com", nil) - tc.setupCookies(t, req, session, tc.sessionState) - s, err := session.LoadSession(req) - - testutil.Equal(t, tc.expectedError, err) - testutil.Equal(t, tc.sessionState, s) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ms := tt.mockCSRF + ms.SetCSRF(nil, nil, tt.newCSRFValue) + ms.ClearCSRF(nil, nil) + got, err := ms.GetCSRF(nil) + if (err != nil) != tt.wantErr { + t.Errorf("MockCSRFStore.GetCSRF() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.mockCSRF.Cookie) { + t.Errorf("MockCSRFStore.GetCSRF() = %v, want %v", got, tt.mockCSRF.Cookie) + } }) } } + +func TestMockSessionStore(t *testing.T) { + tests := []struct { + name string + mockCSRF *MockSessionStore + saveSession *SessionState + wantLoadErr bool + wantSaveErr bool + }{ + {"basic", + &MockSessionStore{ + ResponseSession: "test", + Session: &SessionState{AccessToken: "AccessToken"}, + SaveError: nil, + LoadError: nil, + }, + &SessionState{AccessToken: "AccessToken"}, + false, + false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ms := tt.mockCSRF + + err := ms.SaveSession(nil, nil, tt.saveSession) + if (err != nil) != tt.wantSaveErr { + t.Errorf("MockCSRFStore.GetCSRF() error = %v, wantSaveErr %v", err, tt.wantSaveErr) + return + } + got, err := ms.LoadSession(nil) + if (err != nil) != tt.wantLoadErr { + t.Errorf("MockCSRFStore.GetCSRF() error = %v, wantLoadErr %v", err, tt.wantLoadErr) + return + } + if !reflect.DeepEqual(got, tt.mockCSRF.Session) { + t.Errorf("MockCSRFStore.GetCSRF() = %v, want %v", got, tt.mockCSRF.Session) + } + ms.ClearSession(nil, nil) + if ms.ResponseSession != "" { + t.Errorf("ResponseSession not empty! %s", ms.ResponseSession) + } + }) + } +} diff --git a/internal/sessions/mock_store.go b/internal/sessions/mock_store.go index b1a968060..8ac0585b5 100644 --- a/internal/sessions/mock_store.go +++ b/internal/sessions/mock_store.go @@ -35,7 +35,7 @@ type MockSessionStore struct { } // ClearSession clears the ResponseSession -func (ms MockSessionStore) ClearSession(http.ResponseWriter, *http.Request) { +func (ms *MockSessionStore) ClearSession(http.ResponseWriter, *http.Request) { ms.ResponseSession = "" } diff --git a/internal/sessions/session_state.go b/internal/sessions/session_state.go index 7fc8bbbe2..513ef7712 100644 --- a/internal/sessions/session_state.go +++ b/internal/sessions/session_state.go @@ -20,11 +20,9 @@ type SessionState struct { RefreshDeadline time.Time `json:"refresh_deadline"` LifetimeDeadline time.Time `json:"lifetime_deadline"` - ValidDeadline time.Time `json:"valid_deadline"` - GracePeriodStart time.Time `json:"grace_period_start"` Email string `json:"email"` - User string `json:"user"` // 'sub' in jwt parlance + User string `json:"user"` // 'sub' in jwt Groups []string `json:"groups"` } @@ -38,11 +36,6 @@ func (s *SessionState) RefreshPeriodExpired() bool { return isExpired(s.RefreshDeadline) } -// ValidationPeriodExpired returns true if the validation period has expired -func (s *SessionState) ValidationPeriodExpired() bool { - return isExpired(s.ValidDeadline) -} - func isExpired(t time.Time) bool { return t.Before(time.Now()) } @@ -64,7 +57,7 @@ func UnmarshalSession(value string, c cryptutil.Cipher) (*SessionState, error) { return s, nil } -// ExtendDeadline returns the time extended by a given duration +// ExtendDeadline returns the time extended by a given duration, truncated by second func ExtendDeadline(ttl time.Duration) time.Time { return time.Now().Add(ttl).Truncate(time.Second) } diff --git a/internal/sessions/session_state_test.go b/internal/sessions/session_state_test.go index 8523611ba..d4f98f77c 100644 --- a/internal/sessions/session_state_test.go +++ b/internal/sessions/session_state_test.go @@ -1,4 +1,4 @@ -package sessions // import "github.com/pomerium/pomerium/internal/sessions" +package sessions import ( "reflect" @@ -16,15 +16,12 @@ func TestSessionStateSerialization(t *testing.T) { } want := &SessionState{ - AccessToken: "token1234", - RefreshToken: "refresh4321", - + AccessToken: "token1234", + RefreshToken: "refresh4321", LifetimeDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), - ValidDeadline: time.Now().Add(1 * time.Minute).Truncate(time.Second).UTC(), - - Email: "user@domain.com", - User: "user", + Email: "user@domain.com", + User: "user", } ciphertext, err := MarshalSession(want, c) @@ -46,26 +43,40 @@ func TestSessionStateSerialization(t *testing.T) { func TestSessionStateExpirations(t *testing.T) { session := &SessionState{ - AccessToken: "token1234", - RefreshToken: "refresh4321", - + AccessToken: "token1234", + RefreshToken: "refresh4321", LifetimeDeadline: time.Now().Add(-1 * time.Hour), RefreshDeadline: time.Now().Add(-1 * time.Hour), - ValidDeadline: time.Now().Add(-1 * time.Minute), - - Email: "user@domain.com", - User: "user", + Email: "user@domain.com", + User: "user", } if !session.LifetimePeriodExpired() { - t.Errorf("expcted lifetime period to be expired") + t.Errorf("expected lifetime period to be expired") } if !session.RefreshPeriodExpired() { - t.Errorf("expcted lifetime period to be expired") + t.Errorf("expected lifetime period to be expired") } - if !session.ValidationPeriodExpired() { - t.Errorf("expcted lifetime period to be expired") +} + +func TestExtendDeadline(t *testing.T) { + // tons of wiggle room here + now := time.Now().Truncate(time.Second) + tests := []struct { + name string + ttl time.Duration + want time.Time + }{ + {"Add a few ms", time.Millisecond * 10, now.Truncate(time.Second)}, + {"Add a few microsecs", time.Microsecond * 10, now.Truncate(time.Second)}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := ExtendDeadline(tt.ttl); !reflect.DeepEqual(got, tt.want) { + t.Errorf("ExtendDeadline() = %v, want %v", got, tt.want) + } + }) } } diff --git a/internal/singleflight/singleflight.go b/internal/singleflight/singleflight.go deleted file mode 100644 index 44755b35b..000000000 --- a/internal/singleflight/singleflight.go +++ /dev/null @@ -1,75 +0,0 @@ -// Original Copyright 2013 The Go Authors. All rights reserved. -// -// Modified by BuzzFeed to return duplicate counts. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package singleflight provides a duplicate function call suppression mechanism. -package singleflight // import "github.com/pomerium/pomerium/internal/singleflight" - -import "sync" - -// call is an in-flight or completed singleflight.Do call -type call struct { - wg sync.WaitGroup - - // These fields are written once before the WaitGroup is done - // and are only read after the WaitGroup is done. - val interface{} - err error - - // These fields are read and written with the singleflight - // mutex held before the WaitGroup is done, and are read but - // not written after the WaitGroup is done. - dups int -} - -// Group represents a class of work and forms a namespace in -// which units of work can be executed with duplicate suppression. -type Group struct { - mu sync.Mutex // protects m - m map[string]*call // lazily initialized -} - -// Result holds the results of Do, so they can be passed -// on a channel. -type Result struct { - Val interface{} - Err error - Count bool -} - -// Do executes and returns the results of the given function, making -// sure that only one execution is in-flight for a given key at a -// time. If a duplicate comes in, the duplicate caller waits for the -// original to complete and receives the same results. -// The return value of Count indicates how many tiems v was given to multiple callers. -// Count will be zero for requests are shared and only be non-zero for the originating request. -func (g *Group) Do(key string, fn func() (interface{}, error)) (v interface{}, count int, err error) { - g.mu.Lock() - - if g.m == nil { - g.m = make(map[string]*call) - } - if c, ok := g.m[key]; ok { - c.dups++ - g.mu.Unlock() - c.wg.Wait() - return c.val, 0, c.err - } - c := new(call) - c.wg.Add(1) - g.m[key] = c - - g.mu.Unlock() - - c.val, c.err = fn() - c.wg.Done() - - g.mu.Lock() - delete(g.m, key) - g.mu.Unlock() - - return c.val, c.dups, c.err -} diff --git a/internal/singleflight/singleflight_test.go b/internal/singleflight/singleflight_test.go deleted file mode 100644 index 585514f8c..000000000 --- a/internal/singleflight/singleflight_test.go +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright 2013 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package singleflight // import "github.com/pomerium/pomerium/internal/singleflight" - -import ( - "errors" - "fmt" - "sync" - "sync/atomic" - "testing" - "time" -) - -func TestDo(t *testing.T) { - var g Group - v, _, err := g.Do("key", func() (interface{}, error) { - return "bar", nil - }) - if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want { - t.Errorf("Do = %v; want %v", got, want) - } - if err != nil { - t.Errorf("Do error = %v", err) - } -} - -func TestDoErr(t *testing.T) { - var g Group - someErr := errors.New("Some error") - v, _, err := g.Do("key", func() (interface{}, error) { - return nil, someErr - }) - if err != someErr { - t.Errorf("Do error = %v; want someErr %v", err, someErr) - } - if v != nil { - t.Errorf("unexpected non-nil value %#v", v) - } -} - -func TestDoDupSuppress(t *testing.T) { - var g Group - var wg1, wg2 sync.WaitGroup - c := make(chan string, 1) - var calls int32 - fn := func() (interface{}, error) { - if atomic.AddInt32(&calls, 1) == 1 { - // First invocation. - wg1.Done() - } - v := <-c - c <- v // pump; make available for any future calls - - time.Sleep(10 * time.Millisecond) // let more goroutines enter Do - - return v, nil - } - - const n = 10 - wg1.Add(1) - for i := 0; i < n; i++ { - wg1.Add(1) - wg2.Add(1) - go func() { - defer wg2.Done() - wg1.Done() - v, _, err := g.Do("key", fn) - if err != nil { - t.Errorf("Do error: %v", err) - return - } - if s, _ := v.(string); s != "bar" { - t.Errorf("Do = %T %v; want %q", v, v, "bar") - } - }() - } - wg1.Wait() - // At least one goroutine is in fn now and all of them have at - // least reached the line before the Do. - c <- "bar" - wg2.Wait() - if got := atomic.LoadInt32(&calls); got <= 0 || got >= n { - t.Errorf("number of calls = %d; want over 0 and less than %d", got, n) - } -} diff --git a/internal/testutil/testutil.go b/internal/testutil/testutil.go deleted file mode 100644 index 3a317e8ed..000000000 --- a/internal/testutil/testutil.go +++ /dev/null @@ -1,46 +0,0 @@ -package testutil // import "github.com/pomerium/pomerium/internal/testutil" - -// testing util functions copied from https://github.com/benbjohnson/testing -import ( - "fmt" - "path/filepath" - "reflect" - "runtime" - "testing" -) - -// Assert fails the test if the condition is false. -func Assert(tb testing.TB, condition bool, msg string, v ...interface{}) { - if !condition { - _, file, line, _ := runtime.Caller(1) - fmt.Printf("\033[31m%s:%d: "+msg+"\033[39m\n\n", append([]interface{}{filepath.Base(file), line}, v...)...) - tb.FailNow() - } -} - -// Ok fails the test if an err is not nil. -func Ok(tb testing.TB, err error) { - if err != nil { - _, file, line, _ := runtime.Caller(1) - fmt.Printf("\033[31m%s:%d: unexpected error: %s\033[39m\n\n", filepath.Base(file), line, err.Error()) - tb.FailNow() - } -} - -// Equal fails the test if exp is not equal to act. -func Equal(tb testing.TB, exp, act interface{}) { - if !reflect.DeepEqual(exp, act) { - _, file, line, _ := runtime.Caller(1) - fmt.Printf("\033[31m%s:%d:\n\n\texp: %#v\n\n\tgot: %#v\033[39m\n\n", filepath.Base(file), line, exp, act) - tb.FailNow() - } -} - -// NotEqual fails the test if exp is equal to act. -func NotEqual(tb testing.TB, exp, act interface{}) { - if reflect.DeepEqual(exp, act) { - _, file, line, _ := runtime.Caller(1) - fmt.Printf("\033[31m%s:%d:\n\n\texp: %#v\n\n\tgot: %#v\033[39m\n\n", filepath.Base(file), line, exp, act) - tb.FailNow() - } -} diff --git a/proxy/authenticator/authenticator.go b/proxy/authenticator/authenticator.go index bf3854baa..a49dbbf9d 100644 --- a/proxy/authenticator/authenticator.go +++ b/proxy/authenticator/authenticator.go @@ -25,7 +25,7 @@ type Options struct { // InternalAddr is the internal (behind the ingress) address to use when making an // authentication connection. If empty, Addr is used. InternalAddr string - // OverrideServerName overrides the server name used to verify the hostname on the + // OverideCertificateName overrides the server name used to verify the hostname on the // returned certificates from the server. gRPC internals also use it to override the virtual // hosting name if it is set. OverideCertificateName string diff --git a/proxy/handlers.go b/proxy/handlers.go index b6655ddec..50b908f62 100644 --- a/proxy/handlers.go +++ b/proxy/handlers.go @@ -203,16 +203,15 @@ func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) { } // We store the session in a cookie and redirect the user back to the application - err = p.sessionStore.SaveSession(w, r, &sessions.SessionState{ - AccessToken: rr.AccessToken, - RefreshToken: rr.RefreshToken, - IDToken: rr.IDToken, - User: rr.User, - Email: rr.Email, - RefreshDeadline: (rr.Expiry).Truncate(time.Second), - LifetimeDeadline: extendDeadline(p.CookieLifetimeTTL), - ValidDeadline: extendDeadline(p.CookieExpire), - }) + err = p.sessionStore.SaveSession(w, r, + &sessions.SessionState{ + AccessToken: rr.AccessToken, + RefreshToken: rr.RefreshToken, + IDToken: rr.IDToken, + User: rr.User, + Email: rr.Email, + RefreshDeadline: (rr.Expiry).Truncate(time.Second), + }) if err != nil { log.FromRequest(r).Error().Msg("error saving session") httputil.ErrorResponse(w, r, "Error saving session", http.StatusInternalServerError) @@ -250,9 +249,7 @@ func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) { } } - // ! ! ! - // todo(bdd): ! Authorization service goes here ! - // ! ! ! + // todo(bdd): add authorization service validation // We have validated the users request and now proxy their request to the provided upstream. route, ok := p.router(r) @@ -278,14 +275,10 @@ func (p *Proxy) Authenticate(w http.ResponseWriter, r *http.Request) (err error) return err } - if session.LifetimePeriodExpired() { - log.FromRequest(r).Info().Msg("proxy: lifetime expired") - return sessions.ErrLifetimeExpired - } if session.RefreshPeriodExpired() { // AccessToken's usually expire after 60 or so minutes. If offline_access scope is set, a // refresh token (which doesn't change) can be used to request a new access-token. If access - // is revoked by identity provider, or no refresh token is set request will return an error + // is revoked by identity provider, or no refresh token is set, request will return an error accessToken, expiry, err := p.AuthenticateClient.Refresh(session.RefreshToken) if err != nil { log.FromRequest(r).Warn(). diff --git a/proxy/handlers_test.go b/proxy/handlers_test.go index 4bb41dfb7..01e32e771 100644 --- a/proxy/handlers_test.go +++ b/proxy/handlers_test.go @@ -272,7 +272,7 @@ func TestProxy_OAuthCallback(t *testing.T) { if err != nil { t.Fatal(err) } - proxy.sessionStore = tt.session + proxy.sessionStore = &tt.session proxy.csrfStore = tt.csrf proxy.AuthenticateClient = tt.authenticator proxy.cipher = mockCipher{} @@ -352,12 +352,6 @@ func TestProxy_Proxy(t *testing.T) { RefreshToken: "RefreshToken", LifetimeDeadline: time.Now().Add(10 * time.Second), RefreshDeadline: time.Now().Add(10 * time.Second), - ValidDeadline: time.Now().Add(10 * time.Second), - } - expiredLifetime := &sessions.SessionState{ - AccessToken: "AccessToken", - RefreshToken: "RefreshToken", - LifetimeDeadline: time.Now().Add(-10 * time.Second), } tests := []struct { @@ -368,11 +362,10 @@ func TestProxy_Proxy(t *testing.T) { wantStatus int }{ // weirdly, we want 503 here because that means proxy is trying to route a domain (example.com) that we dont control. Weird. I know. - {"good", "https://corp.example.com/test", sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, http.StatusServiceUnavailable}, - {"unexpected error", "https://corp.example.com/test", sessions.MockSessionStore{LoadError: errors.New("ok")}, authenticator.MockAuthenticate{}, http.StatusInternalServerError}, + {"good", "https://corp.example.com/test", &sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, http.StatusServiceUnavailable}, + {"unexpected error", "https://corp.example.com/test", &sessions.MockSessionStore{LoadError: errors.New("ok")}, authenticator.MockAuthenticate{}, http.StatusInternalServerError}, // redirect to start auth process - {"expired lifetime", "https://corp.example.com/test", sessions.MockSessionStore{Session: expiredLifetime}, authenticator.MockAuthenticate{}, http.StatusFound}, - {"unknown host", "https://notcorp.example.com/test", sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, http.StatusNotFound}, + {"unknown host", "https://notcorp.example.com/test", &sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, http.StatusNotFound}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -402,13 +395,8 @@ func TestProxy_Authenticate(t *testing.T) { RefreshToken: "RefreshToken", LifetimeDeadline: time.Now().Add(10 * time.Second), RefreshDeadline: time.Now().Add(10 * time.Second), - ValidDeadline: time.Now().Add(10 * time.Second), - } - expiredLifetime := &sessions.SessionState{ - AccessToken: "AccessToken", - RefreshToken: "RefreshToken", - LifetimeDeadline: time.Now().Add(-10 * time.Second), } + expiredDeadline := &sessions.SessionState{ AccessToken: "AccessToken", RefreshToken: "RefreshToken", @@ -426,25 +414,21 @@ func TestProxy_Authenticate(t *testing.T) { {"cannot save session", "https://corp.example.com/", map[string]string{"corp.example.com": "example.com"}, - sessions.MockSessionStore{Session: goodSession, SaveError: errors.New("error")}, + &sessions.MockSessionStore{Session: goodSession, SaveError: errors.New("error")}, authenticator.MockAuthenticate{}, true}, {"cannot load session", "https://corp.example.com/", map[string]string{"corp.example.com": "example.com"}, - sessions.MockSessionStore{LoadError: errors.New("error")}, authenticator.MockAuthenticate{}, true}, - {"expired lifetime", - "https://corp.example.com/", - map[string]string{"corp.example.com": "example.com"}, - sessions.MockSessionStore{Session: expiredLifetime}, authenticator.MockAuthenticate{}, true}, + &sessions.MockSessionStore{LoadError: errors.New("error")}, authenticator.MockAuthenticate{}, true}, {"expired session", "https://corp.example.com/", map[string]string{"corp.example.com": "example.com"}, - sessions.MockSessionStore{Session: expiredDeadline}, authenticator.MockAuthenticate{}, false}, + &sessions.MockSessionStore{Session: expiredDeadline}, authenticator.MockAuthenticate{}, false}, {"bad refresh authenticator", "https://corp.example.com/", map[string]string{"corp.example.com": "example.com"}, - sessions.MockSessionStore{ + &sessions.MockSessionStore{ Session: expiredDeadline, }, authenticator.MockAuthenticate{RefreshError: errors.New("error")}, @@ -453,7 +437,7 @@ func TestProxy_Authenticate(t *testing.T) { {"good", "https://corp.example.com/", map[string]string{"corp.example.com": "example.com"}, - sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, false}, + &sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/proxy/proxy.go b/proxy/proxy.go index b505e0893..35a109cf6 100755 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -133,12 +133,9 @@ type Proxy struct { AuthenticateClient authenticator.Authenticator // session - CookieExpire time.Duration - CookieRefresh time.Duration - CookieLifetimeTTL time.Duration - cipher cryptutil.Cipher - csrfStore sessions.CSRFStore - sessionStore sessions.SessionStore + cipher cryptutil.Cipher + csrfStore sessions.CSRFStore + sessionStore sessions.SessionStore redirectURL *url.URL templates *template.Template @@ -163,13 +160,14 @@ func New(opts *Options) (*Proxy, error) { return nil, fmt.Errorf("cookie-secret error: %s", err.Error()) } - cookieStore, err := sessions.NewCookieStore(opts.CookieName, - sessions.CreateCookieCipher(decodedSecret), - func(c *sessions.CookieStore) error { - c.CookieDomain = opts.CookieDomain - c.CookieHTTPOnly = opts.CookieHTTPOnly - c.CookieExpire = opts.CookieExpire - return nil + cookieStore, err := sessions.NewCookieStore( + &sessions.CookieStoreOptions{ + Name: opts.CookieName, + CookieDomain: opts.CookieDomain, + CookieSecure: opts.CookieSecure, + CookieHTTPOnly: opts.CookieHTTPOnly, + CookieExpire: opts.CookieExpire, + CookieCipher: cipher, }) if err != nil { @@ -181,14 +179,12 @@ func New(opts *Options) (*Proxy, error) { // services AuthenticateURL: opts.AuthenticateURL, // session state - cipher: cipher, - csrfStore: cookieStore, - sessionStore: cookieStore, - SharedKey: opts.SharedKey, - redirectURL: &url.URL{Path: "/.pomerium/callback"}, - templates: templates.New(), - CookieExpire: opts.CookieExpire, - CookieLifetimeTTL: opts.CookieLifetimeTTL, + cipher: cipher, + csrfStore: cookieStore, + sessionStore: cookieStore, + SharedKey: opts.SharedKey, + redirectURL: &url.URL{Path: "/.pomerium/callback"}, + templates: templates.New(), } for from, to := range opts.Routes { @@ -200,7 +196,7 @@ func New(opts *Options) (*Proxy, error) { return nil, err } p.Handle(fromURL.Host, handler) - log.Info().Str("from", fromURL.Host).Str("to", toURL.String()).Msg("proxy.New: new route") + log.Info().Str("from", fromURL.Host).Str("to", toURL.String()).Msg("proxy: new route") } p.AuthenticateClient, err = authenticator.New( diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index d7be621b4..2a2ca2feb 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -139,6 +139,7 @@ func testOptions() *Options { AuthenticateURL: authurl, SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=", CookieSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=", + CookieName: "pomerium", } }