all: general cleanup readying for tagged release (#48)

- docs: add code coverage to readme
- internal/sessions: refactor sessions to clarify lifetime
- authenticate: simplified signin flow
- deployment: update go mods
- internal/testutil: removed package
- internal/singleflight: removed package
This commit is contained in:
Bobby DeSimone 2019-02-16 12:43:18 -08:00 committed by GitHub
parent 13c03a2b5c
commit dbafc691c3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
25 changed files with 712 additions and 1017 deletions

View file

@ -4,24 +4,24 @@
# Pomerium # Pomerium
[![Travis CI](https://travis-ci.org/pomerium/pomerium.svg?branch=master)](https://travis-ci.org/pomerium/pomerium) [![Go Report Card](https://goreportcard.com/badge/github.com/pomerium/pomerium)](https://goreportcard.com/report/github.com/pomerium/pomerium) [![GoDoc](https://godoc.org/github.com/pomerium/pomerium?status.svg)][godocs] [![LICENSE](https://img.shields.io/github/license/pomerium/pomerium.svg)](https://github.com/pomerium/pomerium/blob/master/LICENSE)[![codecov](https://codecov.io/gh/pomerium/pomerium/branch/master/graph/badge.svg)](https://codecov.io/gh/pomerium/pomerium) [![Travis CI](https://travis-ci.org/pomerium/pomerium.svg?branch=master)](https://travis-ci.org/pomerium/pomerium) [![Go Report Card](https://goreportcard.com/badge/github.com/pomerium/pomerium)](https://goreportcard.com/report/github.com/pomerium/pomerium) [![GoDoc](https://godoc.org/github.com/pomerium/pomerium?status.svg)][godocs] [![LICENSE](https://img.shields.io/github/license/pomerium/pomerium.svg)](https://github.com/pomerium/pomerium/blob/master/LICENSE)[![codecov](https://img.shields.io/codecov/c/github/pomerium/pomerium.svg?style=flat)](https://codecov.io/gh/pomerium/pomerium)
Pomerium is a tool for managing secure access to internal applications and resources. Pomerium is a tool for managing secure access to internal applications and resources.
Use Pomerium to: Use Pomerium to:
- provide a unified gateway (reverse-proxy) to internal corporate applications. - provide a single-sign-on gateway to internal applications.
- enforce dynamic access policy based on context, identity, and device state. - enforce dynamic access policy based on **context**, **identity**, and **device state**.
- deploy mutual authenticated encryption (mTLS). - aggregate access logs and telemetry data.
- aggregate logging and telemetry data. - an alternative to a VPN.
Check out [awesome-zero-trust] to learn more about some problems Pomerium attempts to address. Check out [awesome-zero-trust] to learn more about some of the problems Pomerium attempts to address.
## Docs ## Docs
To get started with pomerium, check out our [quick start guide]. To get started with pomerium, check out our [quick start guide].
For comprehensive docs see our [documentation] and the [godocs]. For comprehensive docs, and tutorials see our [documentation] and the [godocs].
[awesome-zero-trust]: https://github.com/pomerium/awesome-zero-trust [awesome-zero-trust]: https://github.com/pomerium/awesome-zero-trust
[documentation]: https://www.pomerium.io/docs/ [documentation]: https://www.pomerium.io/docs/

View file

@ -23,21 +23,23 @@ var defaultOptions = &Options{
CookieSecure: true, CookieSecure: true,
CookieExpire: time.Duration(168) * time.Hour, CookieExpire: time.Duration(168) * time.Hour,
CookieRefresh: time.Duration(30) * time.Minute, CookieRefresh: time.Duration(30) * time.Minute,
CookieLifetimeTTL: time.Duration(720) * time.Hour,
} }
// Options details the available configuration settings for the authenticate service // Options details the available configuration settings for the authenticate service
type Options struct { type Options struct {
RedirectURL *url.URL `envconfig:"REDIRECT_URL"`
// SharedKey is used to authenticate requests between services // SharedKey is used to authenticate requests between services
SharedKey string `envconfig:"SHARED_SECRET"` SharedKey string `envconfig:"SHARED_SECRET"`
// RedirectURL specifies the callback url following third party authentication
RedirectURL *url.URL `envconfig:"REDIRECT_URL"`
// Coarse authorization based on user email domain // Coarse authorization based on user email domain
// todo(bdd) : to be replaced with authorization module
AllowedDomains []string `envconfig:"ALLOWED_DOMAINS"` AllowedDomains []string `envconfig:"ALLOWED_DOMAINS"`
ProxyRootDomains []string `envconfig:"PROXY_ROOT_DOMAIN"` ProxyRootDomains []string `envconfig:"PROXY_ROOT_DOMAIN"`
// Session/Cookie management // Session/Cookie management
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie
CookieName string CookieName string
CookieSecret string `envconfig:"COOKIE_SECRET"` CookieSecret string `envconfig:"COOKIE_SECRET"`
CookieDomain string `envconfig:"COOKIE_DOMAIN"` CookieDomain string `envconfig:"COOKIE_DOMAIN"`
@ -45,10 +47,9 @@ type Options struct {
CookieHTTPOnly bool `envconfig:"COOKIE_HTTP_ONLY"` CookieHTTPOnly bool `envconfig:"COOKIE_HTTP_ONLY"`
CookieExpire time.Duration `envconfig:"COOKIE_EXPIRE"` CookieExpire time.Duration `envconfig:"COOKIE_EXPIRE"`
CookieRefresh time.Duration `envconfig:"COOKIE_REFRESH"` CookieRefresh time.Duration `envconfig:"COOKIE_REFRESH"`
CookieLifetimeTTL time.Duration `envconfig:"COOKIE_LIFETIME"`
// IdentityProvider provider configuration variables as specified by RFC6749 // IdentityProvider provider configuration variables as specified by RFC6749
// See: https://openid.net/specs/openid-connect-basic-1_0.html#RFC6749 // https://openid.net/specs/openid-connect-basic-1_0.html#RFC6749
ClientID string `envconfig:"IDP_CLIENT_ID"` ClientID string `envconfig:"IDP_CLIENT_ID"`
ClientSecret string `envconfig:"IDP_CLIENT_SECRET"` ClientSecret string `envconfig:"IDP_CLIENT_SECRET"`
Provider string `envconfig:"IDP_PROVIDER"` Provider string `envconfig:"IDP_PROVIDER"`
@ -103,17 +104,13 @@ func (o *Options) Validate() error {
// Authenticate validates a user's identity // Authenticate validates a user's identity
type Authenticate struct { type Authenticate struct {
RedirectURL *url.URL
Validator func(string) bool
AllowedDomains []string
ProxyRootDomains []string
CookieSecure bool
SharedKey string SharedKey string
CookieLifetimeTTL time.Duration RedirectURL *url.URL
AllowedDomains []string
ProxyRootDomains []string
Validator func(string) bool
templates *template.Template templates *template.Template
csrfStore sessions.CSRFStore csrfStore sessions.CSRFStore
@ -137,35 +134,45 @@ func New(opts *Options, optionFuncs ...func(*Authenticate) error) (*Authenticate
if err != nil { if err != nil {
return nil, err return nil, err
} }
cookieStore, err := sessions.NewCookieStore(opts.CookieName, cookieStore, err := sessions.NewCookieStore(
sessions.CreateCookieCipher(decodedCookieSecret), &sessions.CookieStoreOptions{
func(c *sessions.CookieStore) error { Name: opts.CookieName,
c.CookieDomain = opts.CookieDomain CookieSecure: opts.CookieSecure,
c.CookieHTTPOnly = opts.CookieHTTPOnly CookieHTTPOnly: opts.CookieHTTPOnly,
c.CookieExpire = opts.CookieExpire CookieExpire: opts.CookieExpire,
c.CookieSecure = opts.CookieSecure CookieCipher: cipher,
return nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
provider, err := providers.New(
opts.Provider,
&providers.IdentityProvider{
RedirectURL: opts.RedirectURL,
ProviderName: opts.Provider,
ProviderURL: opts.ProviderURL,
ClientID: opts.ClientID,
ClientSecret: opts.ClientSecret,
// SessionLifetimeTTL: opts.CookieLifetimeTTL,
Scopes: opts.Scopes,
})
if err != nil {
return nil, err
}
p := &Authenticate{ p := &Authenticate{
SharedKey: opts.SharedKey, SharedKey: opts.SharedKey,
RedirectURL: opts.RedirectURL,
AllowedDomains: opts.AllowedDomains, AllowedDomains: opts.AllowedDomains,
ProxyRootDomains: dotPrependDomains(opts.ProxyRootDomains), ProxyRootDomains: dotPrependDomains(opts.ProxyRootDomains),
CookieSecure: opts.CookieSecure,
RedirectURL: opts.RedirectURL,
templates: templates.New(), templates: templates.New(),
csrfStore: cookieStore, csrfStore: cookieStore,
sessionStore: cookieStore, sessionStore: cookieStore,
cipher: cipher, cipher: cipher,
} provider: provider,
p.provider, err = newProvider(opts)
if err != nil {
return nil, err
} }
// validation via dependency injected function // validation via dependency injected function
@ -179,20 +186,6 @@ func New(opts *Options, optionFuncs ...func(*Authenticate) error) (*Authenticate
return p, nil return p, nil
} }
func newProvider(opts *Options) (providers.Provider, error) {
pd := &providers.IdentityProvider{
RedirectURL: opts.RedirectURL,
ProviderName: opts.Provider,
ProviderURL: opts.ProviderURL,
ClientID: opts.ClientID,
ClientSecret: opts.ClientSecret,
SessionLifetimeTTL: opts.CookieLifetimeTTL,
Scopes: opts.Scopes,
}
np, err := providers.New(opts.Provider, pd)
return np, err
}
func dotPrependDomains(d []string) []string { func dotPrependDomains(d []string) []string {
for i := range d { for i := range d {
if d[i] != "" && !strings.HasPrefix(d[i], ".") { if d[i] != "" && !strings.HasPrefix(d[i], ".") {

View file

@ -19,8 +19,9 @@ func testOptions() *Options {
ClientSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=", ClientSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=",
CookieSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=", CookieSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=",
CookieRefresh: time.Duration(1) * time.Hour, CookieRefresh: time.Duration(1) * time.Hour,
CookieLifetimeTTL: time.Duration(720) * time.Hour, // CookieLifetimeTTL: time.Duration(720) * time.Hour,
CookieExpire: time.Duration(168) * time.Hour, CookieExpire: time.Duration(168) * time.Hour,
CookieName: "pomerium",
} }
} }
@ -130,37 +131,6 @@ func Test_dotPrependDomains(t *testing.T) {
} }
} }
func Test_newProvider(t *testing.T) {
redirectURL, _ := url.Parse("https://example.com/oauth3/callback")
goodOpts := &Options{
RedirectURL: redirectURL,
Provider: "google",
ProviderURL: "",
ClientID: "cllient-id",
ClientSecret: "client-secret",
}
tests := []struct {
name string
opts *Options
wantErr bool
}{
{"good", goodOpts, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := newProvider(tt.opts)
if (err != nil) != tt.wantErr {
t.Errorf("newProvider() error = %v, wantErr %v", err, tt.wantErr)
return
}
// if !reflect.DeepEqual(got, tt.want) {
// t.Errorf("newProvider() = %v, want %v", got, tt.want)
// }
})
}
}
func TestNew(t *testing.T) { func TestNew(t *testing.T) {
good := testOptions() good := testOptions()
good.Provider = "google" good.Provider = "google"

View file

@ -117,7 +117,6 @@ func TestAuthenticate_Authenticate(t *testing.T) {
} }
lt := time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC() lt := time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC()
rt := time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC() rt := time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC()
vt := time.Now().Add(1 * time.Minute).Truncate(time.Second).UTC()
vtProto, err := ptypes.TimestampProto(rt) vtProto, err := ptypes.TimestampProto(rt)
if err != nil { if err != nil {
t.Fatal("failed to parse timestamp") t.Fatal("failed to parse timestamp")
@ -128,7 +127,7 @@ func TestAuthenticate_Authenticate(t *testing.T) {
RefreshToken: "refresh4321", RefreshToken: "refresh4321",
LifetimeDeadline: lt, LifetimeDeadline: lt,
RefreshDeadline: rt, RefreshDeadline: rt,
ValidDeadline: vt,
Email: "user@domain.com", Email: "user@domain.com",
User: "user", User: "user",
} }

View file

@ -16,7 +16,8 @@ import (
"github.com/pomerium/pomerium/internal/version" "github.com/pomerium/pomerium/internal/version"
) )
// securityHeaders corresponds to HTTP response headers related to security. // securityHeaders corresponds to HTTP response headers that help to protect against protocol
// downgrade attacks and cookie hijacking.
// https://www.owasp.org/index.php/OWASP_Secure_Headers_Project#tab=Headers // https://www.owasp.org/index.php/OWASP_Secure_Headers_Project#tab=Headers
var securityHeaders = map[string]string{ var securityHeaders = map[string]string{
"Strict-Transport-Security": "max-age=31536000", "Strict-Transport-Security": "max-age=31536000",
@ -28,7 +29,7 @@ var securityHeaders = map[string]string{
"Referrer-Policy": "Same-origin", "Referrer-Policy": "Same-origin",
} }
// Handler returns the Http.Handlers for authenticate, callback, and refresh // Handler returns the authenticate service's HTTP request multiplexer, and routes.
func (a *Authenticate) Handler() http.Handler { func (a *Authenticate) Handler() http.Handler {
// set up our standard middlewares // set up our standard middlewares
stdMiddleware := middleware.NewChain() stdMiddleware := middleware.NewChain()
@ -80,12 +81,6 @@ func (a *Authenticate) authenticate(w http.ResponseWriter, r *http.Request) (*se
return nil, err return nil, err
} }
// if long-lived lifetime has expired, clear session
if session.LifetimePeriodExpired() {
log.FromRequest(r).Warn().Msg("authenticate: lifetime expired")
a.sessionStore.ClearSession(w, r)
return nil, sessions.ErrLifetimeExpired
}
// check if session refresh period is up // check if session refresh period is up
if session.RefreshPeriodExpired() { if session.RefreshPeriodExpired() {
newToken, err := a.provider.Refresh(session.RefreshToken) newToken, err := a.provider.Refresh(session.RefreshToken)
@ -130,32 +125,23 @@ func (a *Authenticate) authenticate(w http.ResponseWriter, r *http.Request) (*se
return session, nil return session, nil
} }
// SignIn handles the /sign_in endpoint. It attempts to authenticate the user, // SignIn handles the sign_in endpoint. It attempts to authenticate the user,
// and if the user is not authenticated, it renders a sign in page. // and if the user is not authenticated, it renders a sign in page.
func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) { func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
session, err := a.authenticate(w, r) session, err := a.authenticate(w, r)
switch err { if err != nil {
case nil: log.FromRequest(r).Info().Err(err).Msg("authenticate: authenticate error")
// session good, redirect back to proxy
log.FromRequest(r).Info().Msg("authenticate.SignIn : authenticated")
a.ProxyCallback(w, r, session)
case http.ErrNoCookie, sessions.ErrLifetimeExpired, sessions.ErrInvalidSession:
// session invalid, authenticate
log.FromRequest(r).Info().Err(err).Msg("authenticate.SignIn : expected failure")
if err != http.ErrNoCookie {
a.sessionStore.ClearSession(w, r) a.sessionStore.ClearSession(w, r)
}
a.OAuthStart(w, r) a.OAuthStart(w, r)
default:
log.Error().Err(err).Msg("authenticate: unexpected sign in error")
httputil.ErrorResponse(w, r, err.Error(), httputil.CodeForError(err))
} }
log.FromRequest(r).Info().Msg("authenticate: user authenticated")
a.ProxyCallback(w, r, session)
} }
// ProxyCallback redirects the user back to proxy service along with an encrypted payload, as // ProxyCallback redirects the user back to proxy service along with an encrypted payload, as
// url params, of the user's session state. // url params, of the user's session state as specified in RFC6749 3.1.2.
// See RFC6749 3.1.2 https://tools.ietf.org/html/rfc6749#section-3.1.2 // https://tools.ietf.org/html/rfc6749#section-3.1.2
func (a *Authenticate) ProxyCallback(w http.ResponseWriter, r *http.Request, session *sessions.SessionState) { func (a *Authenticate) ProxyCallback(w http.ResponseWriter, r *http.Request, session *sessions.SessionState) {
err := r.ParseForm() err := r.ParseForm()
if err != nil { if err != nil {
@ -201,9 +187,8 @@ func getAuthCodeRedirectURL(redirectURL *url.URL, state, authCode string) string
return u.String() return u.String()
} }
// SignOut signs the user out by trying to revoke the users remote identity provider session // SignOut signs the user out by trying to revoke the user's remote identity session along with
// then removes the associated local session state. // the associated local session state. Handles both GET and POST.
// Handles both GET and POST of form.
func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) { func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
err := r.ParseForm() err := r.ParseForm()
if err != nil { if err != nil {
@ -256,8 +241,8 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, redirectURI, http.StatusFound) http.Redirect(w, r, redirectURI, http.StatusFound)
} }
// OAuthStart starts the authenticate process by redirecting to the provider. It provides a // OAuthStart starts the authenticate process by redirecting to the identity provider.
// `redirectURI`, allowing the provider to redirect back to the sso proxy after authenticate. // https://tools.ietf.org/html/rfc6749#section-4.2.1
func (a *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) { func (a *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) {
authRedirectURL, err := url.Parse(r.URL.Query().Get("redirect_uri")) authRedirectURL, err := url.Parse(r.URL.Query().Get("redirect_uri"))
if err != nil { if err != nil {
@ -298,7 +283,7 @@ func (a *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) {
} }
// OAuthCallback handles the callback from the identity provider. Displays an error page if there // OAuthCallback handles the callback from the identity provider. Displays an error page if there
// was an error. If successful, redirects back to the proxy-service via the redirect-url. // was an error. If successful, the user is redirected back to the proxy-service.
func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) { func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) {
redirect, err := a.getOAuthCallback(w, r) redirect, err := a.getOAuthCallback(w, r)
switch h := err.(type) { switch h := err.(type) {

View file

@ -71,29 +71,19 @@ func TestAuthenticate_Handler(t *testing.T) {
func TestAuthenticate_authenticate(t *testing.T) { func TestAuthenticate_authenticate(t *testing.T) {
// sessions.MockSessionStore{Session: expiredLifetime} // sessions.MockSessionStore{Session: expiredLifetime}
goodSession := sessions.MockSessionStore{ goodSession := &sessions.MockSessionStore{
Session: &sessions.SessionState{ Session: &sessions.SessionState{
AccessToken: "AccessToken", AccessToken: "AccessToken",
RefreshToken: "RefreshToken", RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second), RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}} }}
expiredSession := sessions.MockSessionStore{
expiredRefresPeriod := &sessions.MockSessionStore{
Session: &sessions.SessionState{ Session: &sessions.SessionState{
AccessToken: "AccessToken", AccessToken: "AccessToken",
RefreshToken: "RefreshToken", RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * -time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}}
expiredRefresPeriod := sessions.MockSessionStore{
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * -time.Second), RefreshDeadline: time.Now().Add(10 * -time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}} }}
tests := []struct { tests := []struct {
@ -106,18 +96,16 @@ func TestAuthenticate_authenticate(t *testing.T) {
}{ }{
{"good", goodSession, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, false}, {"good", goodSession, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, false},
{"good but fails validation", goodSession, providers.MockProvider{ValidateResponse: true}, falseValidator, nil, true}, {"good but fails validation", goodSession, providers.MockProvider{ValidateResponse: true}, falseValidator, nil, true},
{"can't load session", sessions.MockSessionStore{LoadError: errors.New("error")}, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, true}, {"can't load session", &sessions.MockSessionStore{LoadError: errors.New("error")}, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, true},
{"validation fails", goodSession, providers.MockProvider{ValidateResponse: false}, trueValidator, nil, true}, {"validation fails", goodSession, providers.MockProvider{ValidateResponse: false}, trueValidator, nil, true},
{"session fails after good validation", sessions.MockSessionStore{ {"session fails after good validation", &sessions.MockSessionStore{
SaveError: errors.New("error"), SaveError: errors.New("error"),
Session: &sessions.SessionState{ Session: &sessions.SessionState{
AccessToken: "AccessToken", AccessToken: "AccessToken",
RefreshToken: "RefreshToken", RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second), RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second), }}, providers.MockProvider{ValidateResponse: true},
}}, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, true}, trueValidator, nil, true},
{"lifetime expired", expiredSession, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, true},
{"refresh expired", {"refresh expired",
expiredRefresPeriod, expiredRefresPeriod,
providers.MockProvider{ providers.MockProvider{
@ -136,14 +124,13 @@ func TestAuthenticate_authenticate(t *testing.T) {
}, },
trueValidator, nil, true}, trueValidator, nil, true},
{"refresh expired failed save", {"refresh expired failed save",
sessions.MockSessionStore{ &sessions.MockSessionStore{
SaveError: errors.New("error"), SaveError: errors.New("error"),
Session: &sessions.SessionState{ Session: &sessions.SessionState{
AccessToken: "AccessToken", AccessToken: "AccessToken",
RefreshToken: "RefreshToken", RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * -time.Second), RefreshDeadline: time.Now().Add(10 * -time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}}, }},
providers.MockProvider{ providers.MockProvider{
ValidateResponse: true, ValidateResponse: true,
@ -182,29 +169,23 @@ func TestAuthenticate_SignIn(t *testing.T) {
wantCode int wantCode int
}{ }{
{"good", {"good",
sessions.MockSessionStore{ &sessions.MockSessionStore{
Session: &sessions.SessionState{ Session: &sessions.SessionState{
AccessToken: "AccessToken", AccessToken: "AccessToken",
RefreshToken: "RefreshToken", RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second), RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}}, }},
providers.MockProvider{ValidateResponse: true}, providers.MockProvider{ValidateResponse: true},
trueValidator, trueValidator,
403}, http.StatusForbidden},
// {"no session", {"session fails after good validation", &sessions.MockSessionStore{
// sessions.MockSessionStore{ SaveError: errors.New("error"),
// Session: &sessions.SessionState{ Session: &sessions.SessionState{
// AccessToken: "AccessToken", AccessToken: "AccessToken",
// RefreshToken: "RefreshToken", RefreshToken: "RefreshToken",
// LifetimeDeadline: time.Now().Add(-10 * time.Second), RefreshDeadline: time.Now().Add(10 * time.Second),
// RefreshDeadline: time.Now().Add(10 * time.Second), }}, providers.MockProvider{ValidateResponse: true},
// ValidDeadline: time.Now().Add(10 * time.Second), trueValidator, http.StatusBadRequest},
// }},
// providers.MockProvider{ValidateResponse: true},
// trueValidator,
// 200},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -212,6 +193,10 @@ func TestAuthenticate_SignIn(t *testing.T) {
sessionStore: tt.session, sessionStore: tt.session,
provider: tt.provider, provider: tt.provider,
Validator: tt.validator, Validator: tt.validator,
RedirectURL: uriParse("http://www.pomerium.io"),
csrfStore: &sessions.MockCSRFStore{},
SharedKey: "secret",
cipher: mockCipher{},
} }
r := httptest.NewRequest("GET", "/sign-in", nil) r := httptest.NewRequest("GET", "/sign-in", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -264,11 +249,10 @@ func TestAuthenticate_ProxyCallback(t *testing.T) {
&sessions.SessionState{ &sessions.SessionState{
AccessToken: "AccessToken", AccessToken: "AccessToken",
RefreshToken: "RefreshToken", RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second), RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}, },
sessions.MockSessionStore{}, &sessions.MockSessionStore{},
302, 302,
"<a href=\"https://corp.pomerium.io/?code=ok&amp;state=state\">Found</a>."}, "<a href=\"https://corp.pomerium.io/?code=ok&amp;state=state\">Found</a>."},
{"no state", {"no state",
@ -278,11 +262,10 @@ func TestAuthenticate_ProxyCallback(t *testing.T) {
&sessions.SessionState{ &sessions.SessionState{
AccessToken: "AccessToken", AccessToken: "AccessToken",
RefreshToken: "RefreshToken", RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second), RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}, },
sessions.MockSessionStore{}, &sessions.MockSessionStore{},
403, 403,
"no state parameter supplied"}, "no state parameter supplied"},
{"no redirect_url", {"no redirect_url",
@ -292,11 +275,10 @@ func TestAuthenticate_ProxyCallback(t *testing.T) {
&sessions.SessionState{ &sessions.SessionState{
AccessToken: "AccessToken", AccessToken: "AccessToken",
RefreshToken: "RefreshToken", RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second), RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}, },
sessions.MockSessionStore{}, &sessions.MockSessionStore{},
403, 403,
"no redirect_uri parameter"}, "no redirect_uri parameter"},
{"malformed redirect_url", {"malformed redirect_url",
@ -306,11 +288,10 @@ func TestAuthenticate_ProxyCallback(t *testing.T) {
&sessions.SessionState{ &sessions.SessionState{
AccessToken: "AccessToken", AccessToken: "AccessToken",
RefreshToken: "RefreshToken", RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second), RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}, },
sessions.MockSessionStore{}, &sessions.MockSessionStore{},
400, 400,
"malformed redirect_uri"}, "malformed redirect_uri"},
} }
@ -389,14 +370,13 @@ func TestAuthenticate_SignOut(t *testing.T) {
"sig", "sig",
"ts", "ts",
providers.MockProvider{}, providers.MockProvider{},
sessions.MockSessionStore{ &sessions.MockSessionStore{
Session: &sessions.SessionState{ Session: &sessions.SessionState{
AccessToken: "AccessToken", AccessToken: "AccessToken",
RefreshToken: "RefreshToken", RefreshToken: "RefreshToken",
Email: "blah@blah.com", Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second), RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}, },
}, },
http.StatusFound, http.StatusFound,
@ -407,14 +387,13 @@ func TestAuthenticate_SignOut(t *testing.T) {
"sig", "sig",
"ts", "ts",
providers.MockProvider{RevokeError: errors.New("OH NO")}, providers.MockProvider{RevokeError: errors.New("OH NO")},
sessions.MockSessionStore{ &sessions.MockSessionStore{
Session: &sessions.SessionState{ Session: &sessions.SessionState{
AccessToken: "AccessToken", AccessToken: "AccessToken",
RefreshToken: "RefreshToken", RefreshToken: "RefreshToken",
Email: "blah@blah.com", Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second), RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}, },
}, },
http.StatusBadRequest, http.StatusBadRequest,
@ -426,14 +405,13 @@ func TestAuthenticate_SignOut(t *testing.T) {
"sig", "sig",
"ts", "ts",
providers.MockProvider{}, providers.MockProvider{},
sessions.MockSessionStore{ &sessions.MockSessionStore{
Session: &sessions.SessionState{ Session: &sessions.SessionState{
AccessToken: "AccessToken", AccessToken: "AccessToken",
RefreshToken: "RefreshToken", RefreshToken: "RefreshToken",
Email: "blah@blah.com", Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second), RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}, },
}, },
http.StatusOK, http.StatusOK,
@ -444,15 +422,14 @@ func TestAuthenticate_SignOut(t *testing.T) {
"sig", "sig",
"ts", "ts",
providers.MockProvider{}, providers.MockProvider{},
sessions.MockSessionStore{ &sessions.MockSessionStore{
LoadError: errors.New("uh oh"), LoadError: errors.New("uh oh"),
Session: &sessions.SessionState{ Session: &sessions.SessionState{
AccessToken: "AccessToken", AccessToken: "AccessToken",
RefreshToken: "RefreshToken", RefreshToken: "RefreshToken",
Email: "blah@blah.com", Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second), RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}, },
}, },
http.StatusBadRequest, http.StatusBadRequest,
@ -463,14 +440,13 @@ func TestAuthenticate_SignOut(t *testing.T) {
"sig", "sig",
"ts", "ts",
providers.MockProvider{}, providers.MockProvider{},
sessions.MockSessionStore{ &sessions.MockSessionStore{
Session: &sessions.SessionState{ Session: &sessions.SessionState{
AccessToken: "AccessToken", AccessToken: "AccessToken",
RefreshToken: "RefreshToken", RefreshToken: "RefreshToken",
Email: "blah@blah.com", Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second), RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}, },
}, },
http.StatusBadRequest, http.StatusBadRequest,
@ -512,7 +488,6 @@ func redirectURLSignature(rawRedirect string, timestamp time.Time, secret string
} }
func TestAuthenticate_OAuthStart(t *testing.T) { func TestAuthenticate_OAuthStart(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
method string method string
@ -634,15 +609,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
[]string{"pomerium.io"}, []string{"pomerium.io"},
trueValidator, trueValidator,
sessions.MockSessionStore{}, &sessions.MockSessionStore{},
providers.MockProvider{ providers.MockProvider{
AuthenticateResponse: sessions.SessionState{ AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken", AccessToken: "AccessToken",
RefreshToken: "RefreshToken", RefreshToken: "RefreshToken",
Email: "blah@blah.com", Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second), RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}}, }},
sessions.MockCSRFStore{ sessions.MockCSRFStore{
ResponseCSRF: "csrf", ResponseCSRF: "csrf",
@ -657,15 +631,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
[]string{"pomerium.io"}, []string{"pomerium.io"},
trueValidator, trueValidator,
sessions.MockSessionStore{}, &sessions.MockSessionStore{},
providers.MockProvider{ providers.MockProvider{
AuthenticateResponse: sessions.SessionState{ AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken", AccessToken: "AccessToken",
RefreshToken: "RefreshToken", RefreshToken: "RefreshToken",
Email: "blah@blah.com", Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second), RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}}, }},
sessions.MockCSRFStore{ sessions.MockCSRFStore{
ResponseCSRF: "csrf", ResponseCSRF: "csrf",
@ -681,15 +654,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
[]string{"pomerium.io"}, []string{"pomerium.io"},
trueValidator, trueValidator,
sessions.MockSessionStore{}, &sessions.MockSessionStore{},
providers.MockProvider{ providers.MockProvider{
AuthenticateResponse: sessions.SessionState{ AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken", AccessToken: "AccessToken",
RefreshToken: "RefreshToken", RefreshToken: "RefreshToken",
Email: "blah@blah.com", Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second), RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}}, }},
sessions.MockCSRFStore{ sessions.MockCSRFStore{
ResponseCSRF: "csrf", ResponseCSRF: "csrf",
@ -704,7 +676,7 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
[]string{"pomerium.io"}, []string{"pomerium.io"},
trueValidator, trueValidator,
sessions.MockSessionStore{}, &sessions.MockSessionStore{},
providers.MockProvider{ providers.MockProvider{
AuthenticateError: errors.New("error"), AuthenticateError: errors.New("error"),
}, },
@ -721,15 +693,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
[]string{"pomerium.io"}, []string{"pomerium.io"},
trueValidator, trueValidator,
sessions.MockSessionStore{SaveError: errors.New("error")}, &sessions.MockSessionStore{SaveError: errors.New("error")},
providers.MockProvider{ providers.MockProvider{
AuthenticateResponse: sessions.SessionState{ AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken", AccessToken: "AccessToken",
RefreshToken: "RefreshToken", RefreshToken: "RefreshToken",
Email: "blah@blah.com", Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second), RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}}, }},
sessions.MockCSRFStore{ sessions.MockCSRFStore{
ResponseCSRF: "csrf", ResponseCSRF: "csrf",
@ -744,15 +715,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
[]string{"pomerium.io"}, []string{"pomerium.io"},
falseValidator, falseValidator,
sessions.MockSessionStore{}, &sessions.MockSessionStore{},
providers.MockProvider{ providers.MockProvider{
AuthenticateResponse: sessions.SessionState{ AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken", AccessToken: "AccessToken",
RefreshToken: "RefreshToken", RefreshToken: "RefreshToken",
Email: "blah@blah.com", Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second), RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}}, }},
sessions.MockCSRFStore{ sessions.MockCSRFStore{
ResponseCSRF: "csrf", ResponseCSRF: "csrf",
@ -768,15 +738,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
[]string{"pomerium.io"}, []string{"pomerium.io"},
trueValidator, trueValidator,
sessions.MockSessionStore{}, &sessions.MockSessionStore{},
providers.MockProvider{ providers.MockProvider{
AuthenticateResponse: sessions.SessionState{ AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken", AccessToken: "AccessToken",
RefreshToken: "RefreshToken", RefreshToken: "RefreshToken",
Email: "blah@blah.com", Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second), RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}}, }},
sessions.MockCSRFStore{ sessions.MockCSRFStore{
ResponseCSRF: "csrf", ResponseCSRF: "csrf",
@ -791,15 +760,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
[]string{"pomerium.io"}, []string{"pomerium.io"},
trueValidator, trueValidator,
sessions.MockSessionStore{}, &sessions.MockSessionStore{},
providers.MockProvider{ providers.MockProvider{
AuthenticateResponse: sessions.SessionState{ AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken", AccessToken: "AccessToken",
RefreshToken: "RefreshToken", RefreshToken: "RefreshToken",
Email: "blah@blah.com", Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second), RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}}, }},
sessions.MockCSRFStore{ sessions.MockCSRFStore{
ResponseCSRF: "csrf", ResponseCSRF: "csrf",
@ -814,15 +782,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
"nonce:https://corp.pomerium.io", "nonce:https://corp.pomerium.io",
[]string{"pomerium.io"}, []string{"pomerium.io"},
trueValidator, trueValidator,
sessions.MockSessionStore{}, &sessions.MockSessionStore{},
providers.MockProvider{ providers.MockProvider{
AuthenticateResponse: sessions.SessionState{ AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken", AccessToken: "AccessToken",
RefreshToken: "RefreshToken", RefreshToken: "RefreshToken",
Email: "blah@blah.com", Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second), RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}}, }},
sessions.MockCSRFStore{ sessions.MockCSRFStore{
ResponseCSRF: "csrf", ResponseCSRF: "csrf",
@ -837,15 +804,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
base64.URLEncoding.EncodeToString([]byte("nonce")), base64.URLEncoding.EncodeToString([]byte("nonce")),
[]string{"pomerium.io"}, []string{"pomerium.io"},
trueValidator, trueValidator,
sessions.MockSessionStore{}, &sessions.MockSessionStore{},
providers.MockProvider{ providers.MockProvider{
AuthenticateResponse: sessions.SessionState{ AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken", AccessToken: "AccessToken",
RefreshToken: "RefreshToken", RefreshToken: "RefreshToken",
Email: "blah@blah.com", Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second), RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}}, }},
sessions.MockCSRFStore{ sessions.MockCSRFStore{
ResponseCSRF: "csrf", ResponseCSRF: "csrf",
@ -860,15 +826,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
base64.URLEncoding.EncodeToString([]byte("nonce:corp.pomerium.io")), base64.URLEncoding.EncodeToString([]byte("nonce:corp.pomerium.io")),
[]string{"pomerium.io"}, []string{"pomerium.io"},
trueValidator, trueValidator,
sessions.MockSessionStore{}, &sessions.MockSessionStore{},
providers.MockProvider{ providers.MockProvider{
AuthenticateResponse: sessions.SessionState{ AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken", AccessToken: "AccessToken",
RefreshToken: "RefreshToken", RefreshToken: "RefreshToken",
Email: "blah@blah.com", Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second), RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}}, }},
sessions.MockCSRFStore{ sessions.MockCSRFStore{
ResponseCSRF: "csrf", ResponseCSRF: "csrf",

View file

@ -1,5 +1,5 @@
// Package providers implements OpenID Connect client logic for the set of supported identity // Package providers authentication for third party identity providers (IdP) using OpenID
// providers. // Connect, an identity layer on top of the OAuth 2.0 RFC6749 protocol.
// OpenID Connect 1.0 is a simple identity layer on top of the OAuth 2.0 RFC6749 protocol. //
// https://openid.net/specs/openid-connect-core-1_0.html // see: https://openid.net/specs/openid-connect-core-1_0.html
package providers // import "github.com/pomerium/pomerium/internal/providers" package providers // import "github.com/pomerium/pomerium/internal/providers"

View file

@ -2,7 +2,6 @@ package providers // import "github.com/pomerium/pomerium/internal/providers"
import ( import (
"context" "context"
"errors"
oidc "github.com/pomerium/go-oidc" oidc "github.com/pomerium/go-oidc"
"golang.org/x/oauth2" "golang.org/x/oauth2"
@ -19,7 +18,7 @@ type OIDCProvider struct {
func NewOIDCProvider(p *IdentityProvider) (*OIDCProvider, error) { func NewOIDCProvider(p *IdentityProvider) (*OIDCProvider, error) {
ctx := context.Background() ctx := context.Background()
if p.ProviderURL == "" { if p.ProviderURL == "" {
return nil, errors.New("missing required provider url") return nil, ErrMissingProviderURL
} }
var err error var err error
p.provider, err = oidc.NewProvider(ctx, p.ProviderURL) p.provider, err = oidc.NewProvider(ctx, p.ProviderURL)

View file

@ -2,7 +2,6 @@ package providers // import "github.com/pomerium/pomerium/internal/providers"
import ( import (
"context" "context"
"errors"
"net/url" "net/url"
oidc "github.com/pomerium/go-oidc" oidc "github.com/pomerium/go-oidc"
@ -25,7 +24,7 @@ type OktaProvider struct {
func NewOktaProvider(p *IdentityProvider) (*OktaProvider, error) { func NewOktaProvider(p *IdentityProvider) (*OktaProvider, error) {
ctx := context.Background() ctx := context.Background()
if p.ProviderURL == "" { if p.ProviderURL == "" {
return nil, errors.New("missing required provider url") return nil, ErrMissingProviderURL
} }
var err error var err error
p.provider, err = oidc.NewProvider(ctx, p.ProviderURL) p.provider, err = oidc.NewProvider(ctx, p.ProviderURL)

View file

@ -29,6 +29,11 @@ const (
OktaProviderName = "okta" OktaProviderName = "okta"
) )
var (
// ErrMissingProviderURL is returned when the CB state is half open and the requests count is over the cb maxRequests
ErrMissingProviderURL = errors.New("proxy/providers: missing provider url")
)
// Provider is an interface exposing functions necessary to interact with a given provider. // Provider is an interface exposing functions necessary to interact with a given provider.
type Provider interface { type Provider interface {
Authenticate(string) (*sessions.SessionState, error) Authenticate(string) (*sessions.SessionState, error)

17
go.mod
View file

@ -5,18 +5,17 @@ require (
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/golang/mock v1.2.0 github.com/golang/mock v1.2.0
github.com/golang/protobuf v1.2.0 github.com/golang/protobuf v1.2.0
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/pomerium/envconfig v1.3.1-0.20190112072701-14cbcf832d31 github.com/pomerium/envconfig v1.3.1-0.20190112072701-14cbcf832d31
github.com/pomerium/go-oidc v2.0.0+incompatible github.com/pomerium/go-oidc v2.0.0+incompatible
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
github.com/rs/zerolog v1.11.0 github.com/rs/zerolog v1.11.0
github.com/stretchr/testify v1.2.2 // indirect github.com/stretchr/testify v1.3.0 // indirect
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9 golang.org/x/crypto v0.0.0-20190211182817-74369b46fc67
golang.org/x/net v0.0.0-20181220203305-927f97764cc3 golang.org/x/net v0.0.0-20190213061140-3a22650c66bd
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890 golang.org/x/oauth2 v0.0.0-20190212230446-3e8b2be13635
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 // indirect golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a // indirect
golang.org/x/sys v0.0.0-20190116161447-11f53e031339 // indirect golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2 // indirect
google.golang.org/appengine v1.4.0 // indirect google.golang.org/genproto v0.0.0-20190215211957-bd968387e4aa // indirect
google.golang.org/grpc v1.18.0 google.golang.org/grpc v1.18.0
gopkg.in/square/go-jose.v2 v2.2.1 gopkg.in/square/go-jose.v2 v2.2.2
) )

38
go.sum
View file

@ -1,11 +1,14 @@
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
github.com/benbjohnson/clock v0.0.0-20161215174838-7dc76406b6d3 h1:wOysYcIdqv3WnvwqFFzrYCFALPED7qkUGaLXu359GSc= github.com/benbjohnson/clock v0.0.0-20161215174838-7dc76406b6d3 h1:wOysYcIdqv3WnvwqFFzrYCFALPED7qkUGaLXu359GSc=
github.com/benbjohnson/clock v0.0.0-20161215174838-7dc76406b6d3/go.mod h1:UMqtWQTnOe4byzwe7Zhwh8f8s+36uszN51sJrSIZlTE= github.com/benbjohnson/clock v0.0.0-20161215174838-7dc76406b6d3/go.mod h1:UMqtWQTnOe4byzwe7Zhwh8f8s+36uszN51sJrSIZlTE=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E=
github.com/golang/mock v1.1.1 h1:G5FRp8JnTd7RQH5kemVNlMeyXQAztQ3mOWV95KxsXH8= github.com/golang/mock v1.1.1 h1:G5FRp8JnTd7RQH5kemVNlMeyXQAztQ3mOWV95KxsXH8=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/mock v1.2.0 h1:28o5sBqPkBsMGnC6b4MvE2TzSr5/AT4c/1fLqVGIwlk= github.com/golang/mock v1.2.0 h1:28o5sBqPkBsMGnC6b4MvE2TzSr5/AT4c/1fLqVGIwlk=
@ -23,34 +26,45 @@ github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAm
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35/go.mod h1:prYjPmNq4d1NPVmpShWobRqXY3q7Vp+80DqgxxUrUIA= github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35/go.mod h1:prYjPmNq4d1NPVmpShWobRqXY3q7Vp+80DqgxxUrUIA=
github.com/rs/zerolog v1.11.0 h1:DRuq/S+4k52uJzBQciUcofXx45GrMC6yrEbb/CoK6+M= github.com/rs/zerolog v1.11.0 h1:DRuq/S+4k52uJzBQciUcofXx45GrMC6yrEbb/CoK6+M=
github.com/rs/zerolog v1.11.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= github.com/rs/zerolog v1.11.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU=
github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9 h1:mKdxBk7AujPs8kU4m80U72y/zjbZ3UcXC7dClwKbUI0= github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
golang.org/x/crypto v0.0.0-20190211182817-74369b46fc67 h1:ng3VDlRp5/DHpSWl02R4rM9I+8M2rhmsuLwAMmkLQWE=
golang.org/x/crypto v0.0.0-20190211182817-74369b46fc67/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20181220203305-927f97764cc3 h1:eH6Eip3UpmR+yM/qI9Ijluzb1bNv/cAU/n+6l8tRSis= golang.org/x/net v0.0.0-20181106065722-10aee1819953/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd h1:HuTn7WObtcDo9uEEU7rEqL0jYthdXAmZ6PP+meazmaU=
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890 h1:uESlIz09WIHT2I+pasSXcpLYqYK8wHcdCetU3VuMBJE= golang.org/x/oauth2 v0.0.0-20190212230446-3e8b2be13635 h1:dOJmQysgY8iOBECuNp0vlKHWEtfiTnyjisEizRV3/4o=
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190212230446-3e8b2be13635/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190116161447-11f53e031339 h1:g/Jesu8+QLnA0CPzF3E1pURg0Byr7i6jLoX5sqjcAh0= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU=
golang.org/x/sys v0.0.0-20190116161447-11f53e031339/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2 h1:z99zHgr7hKfrUcX/KsoJk5FJfjTceCKIp96+biqP4To=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508= google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8 h1:Nw54tB0rB7hY/N0NQvRW8DG4Yk3Q6T9cu9RcFQDu1tc= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8 h1:Nw54tB0rB7hY/N0NQvRW8DG4Yk3Q6T9cu9RcFQDu1tc=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20190215211957-bd968387e4aa h1:FVL+/MjP2dzG4PxLpCJR7B6esIia88UAbsfYUrCc8U4=
google.golang.org/genproto v0.0.0-20190215211957-bd968387e4aa/go.mod h1:L3J43x8/uS+qIUoksaLKe6OS3nUKxOKuIFz1sl2/jx4=
google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio=
google.golang.org/grpc v1.18.0 h1:IZl7mfBGfbhYx2p2rKRtYgDFw6SBz+kclmxYrCksPPA= google.golang.org/grpc v1.18.0 h1:IZl7mfBGfbhYx2p2rKRtYgDFw6SBz+kclmxYrCksPPA=
google.golang.org/grpc v1.18.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= google.golang.org/grpc v1.18.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs=
gopkg.in/square/go-jose.v2 v2.2.1 h1:uRIz/V7RfMsMgGnCp+YybIdstDIz8wc0H283wHQfwic= gopkg.in/square/go-jose.v2 v2.2.2 h1:orlkJ3myw8CN1nVQHBFfloD+L3egixIa4FvUP6RosSA=
gopkg.in/square/go-jose.v2 v2.2.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= gopkg.in/square/go-jose.v2 v2.2.2/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=

View file

@ -31,45 +31,41 @@ type SessionStore interface {
type CookieStore struct { type CookieStore struct {
Name string Name string
CSRFCookieName string CSRFCookieName string
CookieCipher cryptutil.Cipher
CookieExpire time.Duration CookieExpire time.Duration
CookieRefresh time.Duration CookieRefresh time.Duration
CookieSecure bool CookieSecure bool
CookieHTTPOnly bool CookieHTTPOnly bool
CookieDomain string CookieDomain string
CookieCipher cryptutil.Cipher
SessionLifetimeTTL time.Duration
} }
// CreateCookieCipher creates a new miscreant cipher with the cookie secret // CookieStoreOptions holds options for CookieStore
func CreateCookieCipher(cookieSecret []byte) func(s *CookieStore) error { type CookieStoreOptions struct {
return func(s *CookieStore) error { Name string
cipher, err := cryptutil.NewCipher(cookieSecret) CookieSecure bool
if err != nil { CookieHTTPOnly bool
return fmt.Errorf("cookie-secret error: %s", err.Error()) CookieDomain string
} CookieExpire time.Duration
s.CookieCipher = cipher CookieCipher cryptutil.Cipher
return nil
}
} }
// NewCookieStore returns a new session with ciphers for each of the cookie secrets // NewCookieStore returns a new session with ciphers for each of the cookie secrets
func NewCookieStore(cookieName string, optFuncs ...func(*CookieStore) error) (*CookieStore, error) { func NewCookieStore(opts *CookieStoreOptions) (*CookieStore, error) {
c := &CookieStore{ if opts.Name == "" {
Name: cookieName, return nil, fmt.Errorf("internal/sessions: cookie name cannot be empty")
CookieSecure: true,
CookieHTTPOnly: true,
CookieExpire: 168 * time.Hour,
CSRFCookieName: fmt.Sprintf("%v_%v", cookieName, "csrf"),
} }
if opts.CookieCipher == nil {
for _, f := range optFuncs { return nil, fmt.Errorf("internal/sessions: cipher cannot be nil")
err := f(c)
if err != nil {
return nil, err
} }
} return &CookieStore{
Name: opts.Name,
return c, nil CSRFCookieName: fmt.Sprintf("%v_%v", opts.Name, "csrf"),
CookieSecure: opts.CookieSecure,
CookieHTTPOnly: opts.CookieHTTPOnly,
CookieDomain: opts.CookieDomain,
CookieExpire: opts.CookieExpire,
CookieCipher: opts.CookieCipher,
}, nil
} }
func (s *CookieStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie { func (s *CookieStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie {
@ -80,16 +76,19 @@ func (s *CookieStore) makeCookie(req *http.Request, name string, value string, e
if s.CookieDomain != "" { if s.CookieDomain != "" {
domain = s.CookieDomain domain = s.CookieDomain
} }
c := &http.Cookie{
return &http.Cookie{
Name: name, Name: name,
Value: value, Value: value,
Path: "/", Path: "/",
Domain: domain, Domain: domain,
HttpOnly: s.CookieHTTPOnly, HttpOnly: s.CookieHTTPOnly,
Secure: s.CookieSecure, Secure: s.CookieSecure,
Expires: now.Add(expiration),
} }
// only set an expiration if we want one, otherwise default to non perm session based
if expiration != 0 {
c.Expires = now.Add(expiration)
}
return c
} }
// makeSessionCookie constructs a session cookie given the request, an expiration time and the current time. // makeSessionCookie constructs a session cookie given the request, an expiration time and the current time.
@ -103,13 +102,13 @@ func (s *CookieStore) makeCSRFCookie(req *http.Request, value string, expiration
} }
// ClearCSRF clears the CSRF cookie from the request // ClearCSRF clears the CSRF cookie from the request
func (s *CookieStore) ClearCSRF(rw http.ResponseWriter, req *http.Request) { func (s *CookieStore) ClearCSRF(w http.ResponseWriter, req *http.Request) {
http.SetCookie(rw, s.makeCSRFCookie(req, "", time.Hour*-1, time.Now())) http.SetCookie(w, s.makeCSRFCookie(req, "", time.Hour*-1, time.Now()))
} }
// SetCSRF sets the CSRFCookie creates a CSRF cookie in a given request // SetCSRF sets the CSRFCookie creates a CSRF cookie in a given request
func (s *CookieStore) SetCSRF(rw http.ResponseWriter, req *http.Request, val string) { func (s *CookieStore) SetCSRF(w http.ResponseWriter, req *http.Request, val string) {
http.SetCookie(rw, s.makeCSRFCookie(req, val, s.CookieExpire, time.Now())) http.SetCookie(w, s.makeCSRFCookie(req, val, s.CookieExpire, time.Now()))
} }
// GetCSRF gets the CSRFCookie creates a CSRF cookie in a given request // GetCSRF gets the CSRFCookie creates a CSRF cookie in a given request
@ -118,20 +117,19 @@ func (s *CookieStore) GetCSRF(req *http.Request) (*http.Cookie, error) {
} }
// ClearSession clears the session cookie from a request // ClearSession clears the session cookie from a request
func (s *CookieStore) ClearSession(rw http.ResponseWriter, req *http.Request) { func (s *CookieStore) ClearSession(w http.ResponseWriter, req *http.Request) {
http.SetCookie(rw, s.makeSessionCookie(req, "", time.Hour*-1, time.Now())) http.SetCookie(w, s.makeSessionCookie(req, "", time.Hour*-1, time.Now()))
} }
func (s *CookieStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val string) { func (s *CookieStore) setSessionCookie(w http.ResponseWriter, req *http.Request, val string) {
http.SetCookie(rw, s.makeSessionCookie(req, val, s.CookieExpire, time.Now())) http.SetCookie(w, s.makeSessionCookie(req, val, s.CookieExpire, time.Now()))
} }
// LoadSession returns a SessionState from the cookie in the request. // LoadSession returns a SessionState from the cookie in the request.
func (s *CookieStore) LoadSession(req *http.Request) (*SessionState, error) { func (s *CookieStore) LoadSession(req *http.Request) (*SessionState, error) {
c, err := req.Cookie(s.Name) c, err := req.Cookie(s.Name)
if err != nil { if err != nil {
// always http.ErrNoCookie return nil, err // http.ErrNoCookie
return nil, err
} }
session, err := UnmarshalSession(c.Value, s.CookieCipher) session, err := UnmarshalSession(c.Value, s.CookieCipher)
if err != nil { if err != nil {
@ -141,12 +139,11 @@ func (s *CookieStore) LoadSession(req *http.Request) (*SessionState, error) {
} }
// SaveSession saves a session state to a request sessions. // SaveSession saves a session state to a request sessions.
func (s *CookieStore) SaveSession(rw http.ResponseWriter, req *http.Request, sessionState *SessionState) error { func (s *CookieStore) SaveSession(w http.ResponseWriter, req *http.Request, sessionState *SessionState) error {
value, err := MarshalSession(sessionState, s.CookieCipher) value, err := MarshalSession(sessionState, s.CookieCipher)
if err != nil { if err != nil {
return err return err
} }
s.setSessionCookie(w, req, value)
s.setSessionCookie(rw, req, value)
return nil return nil
} }

View file

@ -1,348 +1,348 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions" package sessions
import ( import (
"encoding/base64" "errors"
"fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect"
"testing" "testing"
"time" "time"
"github.com/pomerium/pomerium/internal/testutil" "github.com/pomerium/pomerium/internal/cryptutil"
) )
var testEncodedCookieSecret, _ = base64.StdEncoding.DecodeString("qICChm3wdjbjcWymm7PefwtPP6/PZv+udkFEubTeE38=") type mockCipher struct{}
func TestCreateCookieCipher(t *testing.T) { func (a mockCipher) Encrypt(s []byte) ([]byte, error) {
testCases := []struct { if string(s) == "error" {
name string return []byte(""), errors.New("error encrypting")
cookieSecret []byte
expectedError bool
}{
{
name: "normal case with base64 encoded secret",
cookieSecret: testEncodedCookieSecret,
},
{
name: "error when not base64 encoded",
cookieSecret: []byte("abcd"),
expectedError: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, err := NewCookieStore("cookieName", CreateCookieCipher(tc.cookieSecret))
if !tc.expectedError {
testutil.Ok(t, err)
} else {
testutil.NotEqual(t, err, nil)
}
})
} }
return []byte("OK"), nil
} }
func TestNewSession(t *testing.T) { func (a mockCipher) Decrypt(s []byte) ([]byte, error) {
testCases := []struct { if string(s) == "error" {
return []byte(""), errors.New("error encrypting")
}
return []byte("OK"), nil
}
func (a mockCipher) Marshal(s interface{}) (string, error) { return "", errors.New("error") }
func (a mockCipher) Unmarshal(s string, i interface{}) error {
if string(s) == "unmarshal error" || string(s) == "error" {
return errors.New("error")
}
return nil
}
func TestNewCookieStore(t *testing.T) {
cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey())
if err != nil {
t.Fatal(err)
}
tests := []struct {
name string name string
optFuncs []func(*CookieStore) error opts *CookieStoreOptions
expectedError bool want *CookieStore
expectedSession *CookieStore wantErr bool
}{ }{
{ {"good",
name: "default with no opt funcs set", &CookieStoreOptions{
expectedSession: &CookieStore{ Name: "_cookie",
Name: "cookieName",
CookieSecure: true, CookieSecure: true,
CookieHTTPOnly: true, CookieHTTPOnly: true,
CookieExpire: 168 * time.Hour, CookieDomain: "pomerium.io",
CSRFCookieName: "cookieName_csrf", CookieExpire: 10 * time.Second,
CookieCipher: cipher,
}, },
}, &CookieStore{
{ Name: "_cookie",
name: "opt func with an error returns an error", CSRFCookieName: "_cookie_csrf",
optFuncs: []func(*CookieStore) error{func(*CookieStore) error { return fmt.Errorf("error") }},
expectedError: true,
},
{
name: "opt func overrides default values",
optFuncs: []func(*CookieStore) error{func(s *CookieStore) error {
s.CookieExpire = time.Hour
return nil
}},
expectedSession: &CookieStore{
Name: "cookieName",
CookieSecure: true, CookieSecure: true,
CookieHTTPOnly: true, CookieHTTPOnly: true,
CookieExpire: time.Hour, CookieDomain: "pomerium.io",
CSRFCookieName: "cookieName_csrf", CookieExpire: 10 * time.Second,
CookieCipher: cipher,
}, },
false},
{"missing name",
&CookieStoreOptions{
Name: "",
CookieSecure: true,
CookieHTTPOnly: true,
CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second,
CookieCipher: cipher,
}, },
nil,
true},
{"missing cipher",
&CookieStoreOptions{
Name: "_pomerium",
CookieSecure: true,
CookieHTTPOnly: true,
CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second,
CookieCipher: nil,
},
nil,
true},
} }
for _, tt := range tests {
for _, tc := range testCases { t.Run(tt.name, func(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { got, err := NewCookieStore(tt.opts)
session, err := NewCookieStore("cookieName", tc.optFuncs...) if (err != nil) != tt.wantErr {
if tc.expectedError { t.Errorf("NewCookieStore() error = %v, wantErr %v", err, tt.wantErr)
testutil.NotEqual(t, err, nil) return
} else { }
testutil.Ok(t, err) if !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewCookieStore() = %#v, want %#v", got, tt.want)
} }
testutil.Equal(t, tc.expectedSession, session)
}) })
} }
} }
func TestMakeSessionCookie(t *testing.T) { func TestCookieStore_makeCookie(t *testing.T) {
cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey())
if err != nil {
t.Fatal(err)
}
type fields struct {
Name string
CSRFCookieName string
CookieCipher cryptutil.Cipher
CookieExpire time.Duration
CookieRefresh time.Duration
CookieSecure bool
CookieHTTPOnly bool
CookieDomain string
}
now := time.Now() now := time.Now()
cookieValue := "cookieValue" tests := []struct {
expiration := time.Hour
cookieName := "cookieName"
testCases := []struct {
name string name string
optFuncs []func(*CookieStore) error domain string
expectedCookie *http.Cookie
cookieName string
value string
expiration time.Duration
want *http.Cookie
}{ }{
{ {"good", "http://pomerium.io", "_pomerium", "value", 0, &http.Cookie{Name: "_pomerium", Value: "value", Path: "/", Domain: "pomerium.io", Secure: true, HttpOnly: true}},
name: "default cookie domain", {"domains with https", "https://pomerium.io", "_pomerium", "value", 0, &http.Cookie{Name: "_pomerium", Value: "value", Path: "/", Domain: "pomerium.io", Secure: true, HttpOnly: true}},
expectedCookie: &http.Cookie{ {"domain with port", "http://pomerium.io:443", "_pomerium", "value", 0, &http.Cookie{Name: "_pomerium", Value: "value", Path: "/", Domain: "pomerium.io", Secure: true, HttpOnly: true}},
Name: cookieName, {"expiration set", "http://pomerium.io:443", "_pomerium", "value", 10 * time.Second, &http.Cookie{Expires: now.Add(10 * time.Second), Name: "_pomerium", Value: "value", Path: "/", Domain: "pomerium.io", Secure: true, HttpOnly: true}},
Value: cookieValue, }
Path: "/", for _, tt := range tests {
Domain: "www.example.com", t.Run(tt.name, func(t *testing.T) {
HttpOnly: true, r := httptest.NewRequest("GET", tt.domain, nil)
Secure: true,
Expires: now.Add(expiration), s := &CookieStore{
}, Name: "_pomerium",
}, CSRFCookieName: "_pomerium_csrf",
{ CookieSecure: true,
name: "custom cookie domain set", CookieHTTPOnly: true,
optFuncs: []func(*CookieStore) error{ CookieDomain: "pomerium.io",
func(s *CookieStore) error { CookieExpire: 10 * time.Second,
s.CookieDomain = "buzzfeed.com" CookieCipher: cipher}
return nil
}, if got := s.makeCookie(r, tt.cookieName, tt.value, tt.expiration, now); !reflect.DeepEqual(got, tt.want) {
}, t.Errorf("CookieStore.makeCookie() = \n%#v, \nwant\n%#v", got, tt.want)
expectedCookie: &http.Cookie{ }
Name: cookieName, if got := s.makeSessionCookie(r, tt.value, tt.expiration, now); !reflect.DeepEqual(got, tt.want) {
Value: cookieValue, t.Errorf("CookieStore.makeCookie() = \n%#v, \nwant\n%#v", got, tt.want)
Path: "/", }
Domain: "buzzfeed.com", got := s.makeCSRFCookie(r, tt.value, tt.expiration, now)
HttpOnly: true, tt.want.Name = "_pomerium_csrf"
Secure: true, if !reflect.DeepEqual(got, tt.want) {
Expires: now.Add(expiration), t.Errorf("CookieStore.makeCookie() = \n%#v, \nwant\n%#v", got, tt.want)
}, }
}, w := httptest.NewRecorder()
want := "new-csrf"
s.SetCSRF(w, r, want)
found := false
for _, cookie := range w.Result().Cookies() {
if cookie.Name == s.CSRFCookieName && cookie.Value == want {
found = true
break
}
}
if !found {
t.Error("SetCSRF failed")
}
w = httptest.NewRecorder()
s.ClearCSRF(w, r)
for _, cookie := range w.Result().Cookies() {
if cookie.Name == s.CSRFCookieName && cookie.Value == want {
t.Error("clear csrf failed")
break
}
}
w = httptest.NewRecorder()
want = "new-session"
s.setSessionCookie(w, r, want)
found = false
for _, cookie := range w.Result().Cookies() {
if cookie.Name == s.Name && cookie.Value == want {
found = true
break
}
}
if !found {
t.Error("SetCSRF failed")
}
w = httptest.NewRecorder()
s.ClearSession(w, r)
for _, cookie := range w.Result().Cookies() {
if cookie.Name == s.Name && cookie.Value == want {
t.Error("clear csrf failed")
break
}
} }
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
session, err := NewCookieStore(cookieName, tc.optFuncs...)
testutil.Ok(t, err)
req := httptest.NewRequest("GET", "http://www.example.com", nil)
cookie := session.makeSessionCookie(req, cookieValue, expiration, now)
testutil.Equal(t, cookie, tc.expectedCookie)
}) })
} }
} }
func TestMakeSessionCSRFCookie(t *testing.T) { func TestCookieStore_SaveSession(t *testing.T) {
now := time.Now() cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey())
cookieValue := "cookieValue" if err != nil {
expiration := time.Hour t.Fatal(err)
cookieName := "cookieName" }
csrfName := "cookieName_csrf" tests := []struct {
testCases := []struct {
name string name string
optFuncs []func(*CookieStore) error
expectedCookie *http.Cookie
}{
{
name: "default cookie domain",
expectedCookie: &http.Cookie{
Name: csrfName,
Value: cookieValue,
Path: "/",
Domain: "www.example.com",
HttpOnly: true,
Secure: true,
Expires: now.Add(expiration),
},
},
{
name: "custom cookie domain set",
optFuncs: []func(*CookieStore) error{
func(s *CookieStore) error {
s.CookieDomain = "buzzfeed.com"
return nil
},
},
expectedCookie: &http.Cookie{
Name: csrfName,
Value: cookieValue,
Path: "/",
Domain: "buzzfeed.com",
HttpOnly: true,
Secure: true,
Expires: now.Add(expiration),
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
session, err := NewCookieStore(cookieName, tc.optFuncs...)
testutil.Ok(t, err)
req := httptest.NewRequest("GET", "http://www.example.com", nil)
cookie := session.makeCSRFCookie(req, cookieValue, expiration, now)
testutil.Equal(t, tc.expectedCookie, cookie)
})
}
}
func TestSetSessionCookie(t *testing.T) {
cookieValue := "cookieValue"
cookieName := "cookieName"
t.Run("set session cookie test", func(t *testing.T) {
session, err := NewCookieStore(cookieName)
testutil.Ok(t, err)
req := httptest.NewRequest("GET", "http://www.example.com", nil)
rw := httptest.NewRecorder()
session.setSessionCookie(rw, req, cookieValue)
var found bool
for _, cookie := range rw.Result().Cookies() {
if cookie.Name == cookieName {
found = true
testutil.Equal(t, cookieValue, cookie.Value)
testutil.Assert(t, cookie.Expires.After(time.Now()), "cookie expires after now")
}
}
testutil.Assert(t, found, "cookie in header")
})
}
func TestSetCSRFSessionCookie(t *testing.T) {
cookieValue := "cookieValue"
cookieName := "cookieName"
t.Run("set csrf cookie test", func(t *testing.T) {
session, err := NewCookieStore(cookieName)
testutil.Ok(t, err)
req := httptest.NewRequest("GET", "http://www.example.com", nil)
rw := httptest.NewRecorder()
session.SetCSRF(rw, req, cookieValue)
var found bool
for _, cookie := range rw.Result().Cookies() {
if cookie.Name == fmt.Sprintf("%s_csrf", cookieName) {
found = true
testutil.Equal(t, cookieValue, cookie.Value)
testutil.Assert(t, cookie.Expires.After(time.Now()), "cookie expires after now")
}
}
testutil.Assert(t, found, "cookie in header")
})
}
func TestClearSessionCookie(t *testing.T) {
cookieValue := "cookieValue"
cookieName := "cookieName"
t.Run("set session cookie test", func(t *testing.T) {
session, err := NewCookieStore(cookieName)
testutil.Ok(t, err)
req := httptest.NewRequest("GET", "http://www.example.com", nil)
req.AddCookie(session.makeSessionCookie(req, cookieValue, time.Hour, time.Now()))
rw := httptest.NewRecorder()
session.ClearSession(rw, req)
var found bool
for _, cookie := range rw.Result().Cookies() {
if cookie.Name == cookieName {
found = true
testutil.Equal(t, "", cookie.Value)
testutil.Assert(t, cookie.Expires.Before(time.Now()), "cookie expires before now")
}
}
testutil.Assert(t, found, "cookie in header")
})
}
func TestClearCSRFSessionCookie(t *testing.T) {
cookieValue := "cookieValue"
cookieName := "cookieName"
t.Run("clear csrf cookie test", func(t *testing.T) {
session, err := NewCookieStore(cookieName)
testutil.Ok(t, err)
req := httptest.NewRequest("GET", "http://www.example.com", nil)
req.AddCookie(session.makeCSRFCookie(req, cookieValue, time.Hour, time.Now()))
rw := httptest.NewRecorder()
session.ClearCSRF(rw, req)
var found bool
for _, cookie := range rw.Result().Cookies() {
if cookie.Name == fmt.Sprintf("%s_csrf", cookieName) {
found = true
testutil.Equal(t, "", cookie.Value)
testutil.Assert(t, cookie.Expires.Before(time.Now()), "cookie expires before now")
}
}
testutil.Assert(t, found, "cookie in header")
})
}
func TestLoadCookiedSession(t *testing.T) {
cookieName := "cookieName"
testCases := []struct {
name string
optFuncs []func(*CookieStore) error
setupCookies func(*testing.T, *http.Request, *CookieStore, *SessionState)
expectedError error
sessionState *SessionState sessionState *SessionState
cipher cryptutil.Cipher
wantErr bool
wantLoadErr bool
}{ }{
{ {"good",
name: "no cookie set returns an error", &SessionState{
setupCookies: func(*testing.T, *http.Request, *CookieStore, *SessionState) {}, AccessToken: "token1234",
expectedError: http.ErrNoCookie, RefreshToken: "refresh4321",
}, LifetimeDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
{ RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
name: "cookie set with cipher set", Email: "user@domain.com",
optFuncs: []func(*CookieStore) error{CreateCookieCipher(testEncodedCookieSecret)}, User: "user",
setupCookies: func(t *testing.T, req *http.Request, s *CookieStore, sessionState *SessionState) { }, cipher, false, false},
value, err := MarshalSession(sessionState, s.CookieCipher) {"bad cipher",
testutil.Ok(t, err) &SessionState{
req.AddCookie(s.makeSessionCookie(req, value, time.Hour, time.Now())) AccessToken: "token1234",
}, RefreshToken: "refresh4321",
sessionState: &SessionState{ LifetimeDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
Email: "example@email.com", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
RefreshToken: "abccdddd", Email: "user@domain.com",
AccessToken: "access", User: "user",
}, }, mockCipher{}, true, true},
}, }
{ for _, tt := range tests {
name: "cookie set with invalid value cipher set", t.Run(tt.name, func(t *testing.T) {
optFuncs: []func(*CookieStore) error{CreateCookieCipher(testEncodedCookieSecret)}, s := &CookieStore{
setupCookies: func(t *testing.T, req *http.Request, s *CookieStore, sessionState *SessionState) { Name: "_pomerium",
value := "574b776a7c934d6b9fc42ec63a389f79" CSRFCookieName: "_pomerium_csrf",
req.AddCookie(s.makeSessionCookie(req, value, time.Hour, time.Now())) CookieSecure: true,
}, CookieHTTPOnly: true,
expectedError: ErrInvalidSession, CookieDomain: "pomerium.io",
}, CookieExpire: 10 * time.Second,
CookieCipher: tt.cipher}
r := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
if err := s.SaveSession(w, r, tt.sessionState); (err != nil) != tt.wantErr {
t.Errorf("CookieStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
}
r = httptest.NewRequest("GET", "/", nil)
for _, cookie := range w.Result().Cookies() {
t.Log(cookie)
r.AddCookie(cookie)
} }
for _, tc := range testCases { state, err := s.LoadSession(r)
t.Run(tc.name, func(t *testing.T) { if (err != nil) != tt.wantLoadErr {
session, err := NewCookieStore(cookieName, tc.optFuncs...) t.Errorf("LoadSession() error = %v, wantErr %v", err, tt.wantLoadErr)
testutil.Ok(t, err) return
req := httptest.NewRequest("GET", "https://www.example.com", nil) }
tc.setupCookies(t, req, session, tc.sessionState) if err == nil && !reflect.DeepEqual(state, tt.sessionState) {
s, err := session.LoadSession(req) t.Errorf("CookieStore.LoadSession() got = \n%v, want \n%v", state, tt.sessionState)
}
})
}
}
testutil.Equal(t, tc.expectedError, err) func TestMockCSRFStore(t *testing.T) {
testutil.Equal(t, tc.sessionState, s) tests := []struct {
name string
mockCSRF *MockCSRFStore
newCSRFValue string
wantErr bool
}{
{"basic",
&MockCSRFStore{
ResponseCSRF: "ok",
Cookie: &http.Cookie{Name: "hi"}},
"newcsrf",
false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ms := tt.mockCSRF
ms.SetCSRF(nil, nil, tt.newCSRFValue)
ms.ClearCSRF(nil, nil)
got, err := ms.GetCSRF(nil)
if (err != nil) != tt.wantErr {
t.Errorf("MockCSRFStore.GetCSRF() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.mockCSRF.Cookie) {
t.Errorf("MockCSRFStore.GetCSRF() = %v, want %v", got, tt.mockCSRF.Cookie)
}
}) })
} }
} }
func TestMockSessionStore(t *testing.T) {
tests := []struct {
name string
mockCSRF *MockSessionStore
saveSession *SessionState
wantLoadErr bool
wantSaveErr bool
}{
{"basic",
&MockSessionStore{
ResponseSession: "test",
Session: &SessionState{AccessToken: "AccessToken"},
SaveError: nil,
LoadError: nil,
},
&SessionState{AccessToken: "AccessToken"},
false,
false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ms := tt.mockCSRF
err := ms.SaveSession(nil, nil, tt.saveSession)
if (err != nil) != tt.wantSaveErr {
t.Errorf("MockCSRFStore.GetCSRF() error = %v, wantSaveErr %v", err, tt.wantSaveErr)
return
}
got, err := ms.LoadSession(nil)
if (err != nil) != tt.wantLoadErr {
t.Errorf("MockCSRFStore.GetCSRF() error = %v, wantLoadErr %v", err, tt.wantLoadErr)
return
}
if !reflect.DeepEqual(got, tt.mockCSRF.Session) {
t.Errorf("MockCSRFStore.GetCSRF() = %v, want %v", got, tt.mockCSRF.Session)
}
ms.ClearSession(nil, nil)
if ms.ResponseSession != "" {
t.Errorf("ResponseSession not empty! %s", ms.ResponseSession)
}
})
}
}

View file

@ -35,7 +35,7 @@ type MockSessionStore struct {
} }
// ClearSession clears the ResponseSession // ClearSession clears the ResponseSession
func (ms MockSessionStore) ClearSession(http.ResponseWriter, *http.Request) { func (ms *MockSessionStore) ClearSession(http.ResponseWriter, *http.Request) {
ms.ResponseSession = "" ms.ResponseSession = ""
} }

View file

@ -20,11 +20,9 @@ type SessionState struct {
RefreshDeadline time.Time `json:"refresh_deadline"` RefreshDeadline time.Time `json:"refresh_deadline"`
LifetimeDeadline time.Time `json:"lifetime_deadline"` LifetimeDeadline time.Time `json:"lifetime_deadline"`
ValidDeadline time.Time `json:"valid_deadline"`
GracePeriodStart time.Time `json:"grace_period_start"`
Email string `json:"email"` Email string `json:"email"`
User string `json:"user"` // 'sub' in jwt parlance User string `json:"user"` // 'sub' in jwt
Groups []string `json:"groups"` Groups []string `json:"groups"`
} }
@ -38,11 +36,6 @@ func (s *SessionState) RefreshPeriodExpired() bool {
return isExpired(s.RefreshDeadline) return isExpired(s.RefreshDeadline)
} }
// ValidationPeriodExpired returns true if the validation period has expired
func (s *SessionState) ValidationPeriodExpired() bool {
return isExpired(s.ValidDeadline)
}
func isExpired(t time.Time) bool { func isExpired(t time.Time) bool {
return t.Before(time.Now()) return t.Before(time.Now())
} }
@ -64,7 +57,7 @@ func UnmarshalSession(value string, c cryptutil.Cipher) (*SessionState, error) {
return s, nil return s, nil
} }
// ExtendDeadline returns the time extended by a given duration // ExtendDeadline returns the time extended by a given duration, truncated by second
func ExtendDeadline(ttl time.Duration) time.Time { func ExtendDeadline(ttl time.Duration) time.Time {
return time.Now().Add(ttl).Truncate(time.Second) return time.Now().Add(ttl).Truncate(time.Second)
} }

View file

@ -1,4 +1,4 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions" package sessions
import ( import (
"reflect" "reflect"
@ -18,11 +18,8 @@ func TestSessionStateSerialization(t *testing.T) {
want := &SessionState{ want := &SessionState{
AccessToken: "token1234", AccessToken: "token1234",
RefreshToken: "refresh4321", RefreshToken: "refresh4321",
LifetimeDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), LifetimeDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
ValidDeadline: time.Now().Add(1 * time.Minute).Truncate(time.Second).UTC(),
Email: "user@domain.com", Email: "user@domain.com",
User: "user", User: "user",
} }
@ -48,24 +45,38 @@ func TestSessionStateExpirations(t *testing.T) {
session := &SessionState{ session := &SessionState{
AccessToken: "token1234", AccessToken: "token1234",
RefreshToken: "refresh4321", RefreshToken: "refresh4321",
LifetimeDeadline: time.Now().Add(-1 * time.Hour), LifetimeDeadline: time.Now().Add(-1 * time.Hour),
RefreshDeadline: time.Now().Add(-1 * time.Hour), RefreshDeadline: time.Now().Add(-1 * time.Hour),
ValidDeadline: time.Now().Add(-1 * time.Minute),
Email: "user@domain.com", Email: "user@domain.com",
User: "user", User: "user",
} }
if !session.LifetimePeriodExpired() { if !session.LifetimePeriodExpired() {
t.Errorf("expcted lifetime period to be expired") t.Errorf("expected lifetime period to be expired")
} }
if !session.RefreshPeriodExpired() { if !session.RefreshPeriodExpired() {
t.Errorf("expcted lifetime period to be expired") t.Errorf("expected lifetime period to be expired")
} }
if !session.ValidationPeriodExpired() { }
t.Errorf("expcted lifetime period to be expired")
func TestExtendDeadline(t *testing.T) {
// tons of wiggle room here
now := time.Now().Truncate(time.Second)
tests := []struct {
name string
ttl time.Duration
want time.Time
}{
{"Add a few ms", time.Millisecond * 10, now.Truncate(time.Second)},
{"Add a few microsecs", time.Microsecond * 10, now.Truncate(time.Second)},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := ExtendDeadline(tt.ttl); !reflect.DeepEqual(got, tt.want) {
t.Errorf("ExtendDeadline() = %v, want %v", got, tt.want)
}
})
} }
} }

View file

@ -1,75 +0,0 @@
// Original Copyright 2013 The Go Authors. All rights reserved.
//
// Modified by BuzzFeed to return duplicate counts.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package singleflight provides a duplicate function call suppression mechanism.
package singleflight // import "github.com/pomerium/pomerium/internal/singleflight"
import "sync"
// call is an in-flight or completed singleflight.Do call
type call struct {
wg sync.WaitGroup
// These fields are written once before the WaitGroup is done
// and are only read after the WaitGroup is done.
val interface{}
err error
// These fields are read and written with the singleflight
// mutex held before the WaitGroup is done, and are read but
// not written after the WaitGroup is done.
dups int
}
// Group represents a class of work and forms a namespace in
// which units of work can be executed with duplicate suppression.
type Group struct {
mu sync.Mutex // protects m
m map[string]*call // lazily initialized
}
// Result holds the results of Do, so they can be passed
// on a channel.
type Result struct {
Val interface{}
Err error
Count bool
}
// Do executes and returns the results of the given function, making
// sure that only one execution is in-flight for a given key at a
// time. If a duplicate comes in, the duplicate caller waits for the
// original to complete and receives the same results.
// The return value of Count indicates how many tiems v was given to multiple callers.
// Count will be zero for requests are shared and only be non-zero for the originating request.
func (g *Group) Do(key string, fn func() (interface{}, error)) (v interface{}, count int, err error) {
g.mu.Lock()
if g.m == nil {
g.m = make(map[string]*call)
}
if c, ok := g.m[key]; ok {
c.dups++
g.mu.Unlock()
c.wg.Wait()
return c.val, 0, c.err
}
c := new(call)
c.wg.Add(1)
g.m[key] = c
g.mu.Unlock()
c.val, c.err = fn()
c.wg.Done()
g.mu.Lock()
delete(g.m, key)
g.mu.Unlock()
return c.val, c.dups, c.err
}

View file

@ -1,87 +0,0 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package singleflight // import "github.com/pomerium/pomerium/internal/singleflight"
import (
"errors"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
)
func TestDo(t *testing.T) {
var g Group
v, _, err := g.Do("key", func() (interface{}, error) {
return "bar", nil
})
if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want {
t.Errorf("Do = %v; want %v", got, want)
}
if err != nil {
t.Errorf("Do error = %v", err)
}
}
func TestDoErr(t *testing.T) {
var g Group
someErr := errors.New("Some error")
v, _, err := g.Do("key", func() (interface{}, error) {
return nil, someErr
})
if err != someErr {
t.Errorf("Do error = %v; want someErr %v", err, someErr)
}
if v != nil {
t.Errorf("unexpected non-nil value %#v", v)
}
}
func TestDoDupSuppress(t *testing.T) {
var g Group
var wg1, wg2 sync.WaitGroup
c := make(chan string, 1)
var calls int32
fn := func() (interface{}, error) {
if atomic.AddInt32(&calls, 1) == 1 {
// First invocation.
wg1.Done()
}
v := <-c
c <- v // pump; make available for any future calls
time.Sleep(10 * time.Millisecond) // let more goroutines enter Do
return v, nil
}
const n = 10
wg1.Add(1)
for i := 0; i < n; i++ {
wg1.Add(1)
wg2.Add(1)
go func() {
defer wg2.Done()
wg1.Done()
v, _, err := g.Do("key", fn)
if err != nil {
t.Errorf("Do error: %v", err)
return
}
if s, _ := v.(string); s != "bar" {
t.Errorf("Do = %T %v; want %q", v, v, "bar")
}
}()
}
wg1.Wait()
// At least one goroutine is in fn now and all of them have at
// least reached the line before the Do.
c <- "bar"
wg2.Wait()
if got := atomic.LoadInt32(&calls); got <= 0 || got >= n {
t.Errorf("number of calls = %d; want over 0 and less than %d", got, n)
}
}

View file

@ -1,46 +0,0 @@
package testutil // import "github.com/pomerium/pomerium/internal/testutil"
// testing util functions copied from https://github.com/benbjohnson/testing
import (
"fmt"
"path/filepath"
"reflect"
"runtime"
"testing"
)
// Assert fails the test if the condition is false.
func Assert(tb testing.TB, condition bool, msg string, v ...interface{}) {
if !condition {
_, file, line, _ := runtime.Caller(1)
fmt.Printf("\033[31m%s:%d: "+msg+"\033[39m\n\n", append([]interface{}{filepath.Base(file), line}, v...)...)
tb.FailNow()
}
}
// Ok fails the test if an err is not nil.
func Ok(tb testing.TB, err error) {
if err != nil {
_, file, line, _ := runtime.Caller(1)
fmt.Printf("\033[31m%s:%d: unexpected error: %s\033[39m\n\n", filepath.Base(file), line, err.Error())
tb.FailNow()
}
}
// Equal fails the test if exp is not equal to act.
func Equal(tb testing.TB, exp, act interface{}) {
if !reflect.DeepEqual(exp, act) {
_, file, line, _ := runtime.Caller(1)
fmt.Printf("\033[31m%s:%d:\n\n\texp: %#v\n\n\tgot: %#v\033[39m\n\n", filepath.Base(file), line, exp, act)
tb.FailNow()
}
}
// NotEqual fails the test if exp is equal to act.
func NotEqual(tb testing.TB, exp, act interface{}) {
if reflect.DeepEqual(exp, act) {
_, file, line, _ := runtime.Caller(1)
fmt.Printf("\033[31m%s:%d:\n\n\texp: %#v\n\n\tgot: %#v\033[39m\n\n", filepath.Base(file), line, exp, act)
tb.FailNow()
}
}

View file

@ -25,7 +25,7 @@ type Options struct {
// InternalAddr is the internal (behind the ingress) address to use when making an // InternalAddr is the internal (behind the ingress) address to use when making an
// authentication connection. If empty, Addr is used. // authentication connection. If empty, Addr is used.
InternalAddr string InternalAddr string
// OverrideServerName overrides the server name used to verify the hostname on the // OverideCertificateName overrides the server name used to verify the hostname on the
// returned certificates from the server. gRPC internals also use it to override the virtual // returned certificates from the server. gRPC internals also use it to override the virtual
// hosting name if it is set. // hosting name if it is set.
OverideCertificateName string OverideCertificateName string

View file

@ -203,15 +203,14 @@ func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
} }
// We store the session in a cookie and redirect the user back to the application // We store the session in a cookie and redirect the user back to the application
err = p.sessionStore.SaveSession(w, r, &sessions.SessionState{ err = p.sessionStore.SaveSession(w, r,
&sessions.SessionState{
AccessToken: rr.AccessToken, AccessToken: rr.AccessToken,
RefreshToken: rr.RefreshToken, RefreshToken: rr.RefreshToken,
IDToken: rr.IDToken, IDToken: rr.IDToken,
User: rr.User, User: rr.User,
Email: rr.Email, Email: rr.Email,
RefreshDeadline: (rr.Expiry).Truncate(time.Second), RefreshDeadline: (rr.Expiry).Truncate(time.Second),
LifetimeDeadline: extendDeadline(p.CookieLifetimeTTL),
ValidDeadline: extendDeadline(p.CookieExpire),
}) })
if err != nil { if err != nil {
log.FromRequest(r).Error().Msg("error saving session") log.FromRequest(r).Error().Msg("error saving session")
@ -250,9 +249,7 @@ func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
} }
} }
// ! ! ! // todo(bdd): add authorization service validation
// todo(bdd): ! Authorization service goes here !
// ! ! !
// We have validated the users request and now proxy their request to the provided upstream. // We have validated the users request and now proxy their request to the provided upstream.
route, ok := p.router(r) route, ok := p.router(r)
@ -278,14 +275,10 @@ func (p *Proxy) Authenticate(w http.ResponseWriter, r *http.Request) (err error)
return err return err
} }
if session.LifetimePeriodExpired() {
log.FromRequest(r).Info().Msg("proxy: lifetime expired")
return sessions.ErrLifetimeExpired
}
if session.RefreshPeriodExpired() { if session.RefreshPeriodExpired() {
// AccessToken's usually expire after 60 or so minutes. If offline_access scope is set, a // AccessToken's usually expire after 60 or so minutes. If offline_access scope is set, a
// refresh token (which doesn't change) can be used to request a new access-token. If access // refresh token (which doesn't change) can be used to request a new access-token. If access
// is revoked by identity provider, or no refresh token is set request will return an error // is revoked by identity provider, or no refresh token is set, request will return an error
accessToken, expiry, err := p.AuthenticateClient.Refresh(session.RefreshToken) accessToken, expiry, err := p.AuthenticateClient.Refresh(session.RefreshToken)
if err != nil { if err != nil {
log.FromRequest(r).Warn(). log.FromRequest(r).Warn().

View file

@ -272,7 +272,7 @@ func TestProxy_OAuthCallback(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
proxy.sessionStore = tt.session proxy.sessionStore = &tt.session
proxy.csrfStore = tt.csrf proxy.csrfStore = tt.csrf
proxy.AuthenticateClient = tt.authenticator proxy.AuthenticateClient = tt.authenticator
proxy.cipher = mockCipher{} proxy.cipher = mockCipher{}
@ -352,12 +352,6 @@ func TestProxy_Proxy(t *testing.T) {
RefreshToken: "RefreshToken", RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second), LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second), RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}
expiredLifetime := &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(-10 * time.Second),
} }
tests := []struct { tests := []struct {
@ -368,11 +362,10 @@ func TestProxy_Proxy(t *testing.T) {
wantStatus int wantStatus int
}{ }{
// weirdly, we want 503 here because that means proxy is trying to route a domain (example.com) that we dont control. Weird. I know. // weirdly, we want 503 here because that means proxy is trying to route a domain (example.com) that we dont control. Weird. I know.
{"good", "https://corp.example.com/test", sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, http.StatusServiceUnavailable}, {"good", "https://corp.example.com/test", &sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, http.StatusServiceUnavailable},
{"unexpected error", "https://corp.example.com/test", sessions.MockSessionStore{LoadError: errors.New("ok")}, authenticator.MockAuthenticate{}, http.StatusInternalServerError}, {"unexpected error", "https://corp.example.com/test", &sessions.MockSessionStore{LoadError: errors.New("ok")}, authenticator.MockAuthenticate{}, http.StatusInternalServerError},
// redirect to start auth process // redirect to start auth process
{"expired lifetime", "https://corp.example.com/test", sessions.MockSessionStore{Session: expiredLifetime}, authenticator.MockAuthenticate{}, http.StatusFound}, {"unknown host", "https://notcorp.example.com/test", &sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, http.StatusNotFound},
{"unknown host", "https://notcorp.example.com/test", sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, http.StatusNotFound},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -402,13 +395,8 @@ func TestProxy_Authenticate(t *testing.T) {
RefreshToken: "RefreshToken", RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second), LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second), RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}
expiredLifetime := &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(-10 * time.Second),
} }
expiredDeadline := &sessions.SessionState{ expiredDeadline := &sessions.SessionState{
AccessToken: "AccessToken", AccessToken: "AccessToken",
RefreshToken: "RefreshToken", RefreshToken: "RefreshToken",
@ -426,25 +414,21 @@ func TestProxy_Authenticate(t *testing.T) {
{"cannot save session", {"cannot save session",
"https://corp.example.com/", "https://corp.example.com/",
map[string]string{"corp.example.com": "example.com"}, map[string]string{"corp.example.com": "example.com"},
sessions.MockSessionStore{Session: goodSession, SaveError: errors.New("error")}, &sessions.MockSessionStore{Session: goodSession, SaveError: errors.New("error")},
authenticator.MockAuthenticate{}, true}, authenticator.MockAuthenticate{}, true},
{"cannot load session", {"cannot load session",
"https://corp.example.com/", "https://corp.example.com/",
map[string]string{"corp.example.com": "example.com"}, map[string]string{"corp.example.com": "example.com"},
sessions.MockSessionStore{LoadError: errors.New("error")}, authenticator.MockAuthenticate{}, true}, &sessions.MockSessionStore{LoadError: errors.New("error")}, authenticator.MockAuthenticate{}, true},
{"expired lifetime",
"https://corp.example.com/",
map[string]string{"corp.example.com": "example.com"},
sessions.MockSessionStore{Session: expiredLifetime}, authenticator.MockAuthenticate{}, true},
{"expired session", {"expired session",
"https://corp.example.com/", "https://corp.example.com/",
map[string]string{"corp.example.com": "example.com"}, map[string]string{"corp.example.com": "example.com"},
sessions.MockSessionStore{Session: expiredDeadline}, authenticator.MockAuthenticate{}, false}, &sessions.MockSessionStore{Session: expiredDeadline}, authenticator.MockAuthenticate{}, false},
{"bad refresh authenticator", {"bad refresh authenticator",
"https://corp.example.com/", "https://corp.example.com/",
map[string]string{"corp.example.com": "example.com"}, map[string]string{"corp.example.com": "example.com"},
sessions.MockSessionStore{ &sessions.MockSessionStore{
Session: expiredDeadline, Session: expiredDeadline,
}, },
authenticator.MockAuthenticate{RefreshError: errors.New("error")}, authenticator.MockAuthenticate{RefreshError: errors.New("error")},
@ -453,7 +437,7 @@ func TestProxy_Authenticate(t *testing.T) {
{"good", {"good",
"https://corp.example.com/", "https://corp.example.com/",
map[string]string{"corp.example.com": "example.com"}, map[string]string{"corp.example.com": "example.com"},
sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, false}, &sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, false},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {

View file

@ -133,9 +133,6 @@ type Proxy struct {
AuthenticateClient authenticator.Authenticator AuthenticateClient authenticator.Authenticator
// session // session
CookieExpire time.Duration
CookieRefresh time.Duration
CookieLifetimeTTL time.Duration
cipher cryptutil.Cipher cipher cryptutil.Cipher
csrfStore sessions.CSRFStore csrfStore sessions.CSRFStore
sessionStore sessions.SessionStore sessionStore sessions.SessionStore
@ -163,13 +160,14 @@ func New(opts *Options) (*Proxy, error) {
return nil, fmt.Errorf("cookie-secret error: %s", err.Error()) return nil, fmt.Errorf("cookie-secret error: %s", err.Error())
} }
cookieStore, err := sessions.NewCookieStore(opts.CookieName, cookieStore, err := sessions.NewCookieStore(
sessions.CreateCookieCipher(decodedSecret), &sessions.CookieStoreOptions{
func(c *sessions.CookieStore) error { Name: opts.CookieName,
c.CookieDomain = opts.CookieDomain CookieDomain: opts.CookieDomain,
c.CookieHTTPOnly = opts.CookieHTTPOnly CookieSecure: opts.CookieSecure,
c.CookieExpire = opts.CookieExpire CookieHTTPOnly: opts.CookieHTTPOnly,
return nil CookieExpire: opts.CookieExpire,
CookieCipher: cipher,
}) })
if err != nil { if err != nil {
@ -187,8 +185,6 @@ func New(opts *Options) (*Proxy, error) {
SharedKey: opts.SharedKey, SharedKey: opts.SharedKey,
redirectURL: &url.URL{Path: "/.pomerium/callback"}, redirectURL: &url.URL{Path: "/.pomerium/callback"},
templates: templates.New(), templates: templates.New(),
CookieExpire: opts.CookieExpire,
CookieLifetimeTTL: opts.CookieLifetimeTTL,
} }
for from, to := range opts.Routes { for from, to := range opts.Routes {
@ -200,7 +196,7 @@ func New(opts *Options) (*Proxy, error) {
return nil, err return nil, err
} }
p.Handle(fromURL.Host, handler) p.Handle(fromURL.Host, handler)
log.Info().Str("from", fromURL.Host).Str("to", toURL.String()).Msg("proxy.New: new route") log.Info().Str("from", fromURL.Host).Str("to", toURL.String()).Msg("proxy: new route")
} }
p.AuthenticateClient, err = authenticator.New( p.AuthenticateClient, err = authenticator.New(

View file

@ -139,6 +139,7 @@ func testOptions() *Options {
AuthenticateURL: authurl, AuthenticateURL: authurl,
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=", SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
CookieSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=", CookieSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=",
CookieName: "pomerium",
} }
} }