all: general cleanup readying for tagged release (#48)

- docs: add code coverage to readme
- internal/sessions: refactor sessions to clarify lifetime
- authenticate: simplified signin flow
- deployment: update go mods
- internal/testutil: removed package
- internal/singleflight: removed package
This commit is contained in:
Bobby DeSimone 2019-02-16 12:43:18 -08:00 committed by GitHub
parent 13c03a2b5c
commit dbafc691c3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
25 changed files with 712 additions and 1017 deletions

View file

@ -4,24 +4,24 @@
# Pomerium
[![Travis CI](https://travis-ci.org/pomerium/pomerium.svg?branch=master)](https://travis-ci.org/pomerium/pomerium) [![Go Report Card](https://goreportcard.com/badge/github.com/pomerium/pomerium)](https://goreportcard.com/report/github.com/pomerium/pomerium) [![GoDoc](https://godoc.org/github.com/pomerium/pomerium?status.svg)][godocs] [![LICENSE](https://img.shields.io/github/license/pomerium/pomerium.svg)](https://github.com/pomerium/pomerium/blob/master/LICENSE)[![codecov](https://codecov.io/gh/pomerium/pomerium/branch/master/graph/badge.svg)](https://codecov.io/gh/pomerium/pomerium)
[![Travis CI](https://travis-ci.org/pomerium/pomerium.svg?branch=master)](https://travis-ci.org/pomerium/pomerium) [![Go Report Card](https://goreportcard.com/badge/github.com/pomerium/pomerium)](https://goreportcard.com/report/github.com/pomerium/pomerium) [![GoDoc](https://godoc.org/github.com/pomerium/pomerium?status.svg)][godocs] [![LICENSE](https://img.shields.io/github/license/pomerium/pomerium.svg)](https://github.com/pomerium/pomerium/blob/master/LICENSE)[![codecov](https://img.shields.io/codecov/c/github/pomerium/pomerium.svg?style=flat)](https://codecov.io/gh/pomerium/pomerium)
Pomerium is a tool for managing secure access to internal applications and resources.
Use Pomerium to:
- provide a unified gateway (reverse-proxy) to internal corporate applications.
- enforce dynamic access policy based on context, identity, and device state.
- deploy mutual authenticated encryption (mTLS).
- aggregate logging and telemetry data.
- provide a single-sign-on gateway to internal applications.
- enforce dynamic access policy based on **context**, **identity**, and **device state**.
- aggregate access logs and telemetry data.
- an alternative to a VPN.
Check out [awesome-zero-trust] to learn more about some problems Pomerium attempts to address.
Check out [awesome-zero-trust] to learn more about some of the problems Pomerium attempts to address.
## Docs
To get started with pomerium, check out our [quick start guide].
For comprehensive docs see our [documentation] and the [godocs].
For comprehensive docs, and tutorials see our [documentation] and the [godocs].
[awesome-zero-trust]: https://github.com/pomerium/awesome-zero-trust
[documentation]: https://www.pomerium.io/docs/

View file

@ -23,21 +23,23 @@ var defaultOptions = &Options{
CookieSecure: true,
CookieExpire: time.Duration(168) * time.Hour,
CookieRefresh: time.Duration(30) * time.Minute,
CookieLifetimeTTL: time.Duration(720) * time.Hour,
}
// Options details the available configuration settings for the authenticate service
type Options struct {
RedirectURL *url.URL `envconfig:"REDIRECT_URL"`
// SharedKey is used to authenticate requests between services
SharedKey string `envconfig:"SHARED_SECRET"`
// RedirectURL specifies the callback url following third party authentication
RedirectURL *url.URL `envconfig:"REDIRECT_URL"`
// Coarse authorization based on user email domain
// todo(bdd) : to be replaced with authorization module
AllowedDomains []string `envconfig:"ALLOWED_DOMAINS"`
ProxyRootDomains []string `envconfig:"PROXY_ROOT_DOMAIN"`
// Session/Cookie management
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie
CookieName string
CookieSecret string `envconfig:"COOKIE_SECRET"`
CookieDomain string `envconfig:"COOKIE_DOMAIN"`
@ -45,10 +47,9 @@ type Options struct {
CookieHTTPOnly bool `envconfig:"COOKIE_HTTP_ONLY"`
CookieExpire time.Duration `envconfig:"COOKIE_EXPIRE"`
CookieRefresh time.Duration `envconfig:"COOKIE_REFRESH"`
CookieLifetimeTTL time.Duration `envconfig:"COOKIE_LIFETIME"`
// IdentityProvider provider configuration variables as specified by RFC6749
// See: https://openid.net/specs/openid-connect-basic-1_0.html#RFC6749
// https://openid.net/specs/openid-connect-basic-1_0.html#RFC6749
ClientID string `envconfig:"IDP_CLIENT_ID"`
ClientSecret string `envconfig:"IDP_CLIENT_SECRET"`
Provider string `envconfig:"IDP_PROVIDER"`
@ -103,17 +104,13 @@ func (o *Options) Validate() error {
// Authenticate validates a user's identity
type Authenticate struct {
RedirectURL *url.URL
Validator func(string) bool
AllowedDomains []string
ProxyRootDomains []string
CookieSecure bool
SharedKey string
CookieLifetimeTTL time.Duration
RedirectURL *url.URL
AllowedDomains []string
ProxyRootDomains []string
Validator func(string) bool
templates *template.Template
csrfStore sessions.CSRFStore
@ -137,35 +134,45 @@ func New(opts *Options, optionFuncs ...func(*Authenticate) error) (*Authenticate
if err != nil {
return nil, err
}
cookieStore, err := sessions.NewCookieStore(opts.CookieName,
sessions.CreateCookieCipher(decodedCookieSecret),
func(c *sessions.CookieStore) error {
c.CookieDomain = opts.CookieDomain
c.CookieHTTPOnly = opts.CookieHTTPOnly
c.CookieExpire = opts.CookieExpire
c.CookieSecure = opts.CookieSecure
return nil
cookieStore, err := sessions.NewCookieStore(
&sessions.CookieStoreOptions{
Name: opts.CookieName,
CookieSecure: opts.CookieSecure,
CookieHTTPOnly: opts.CookieHTTPOnly,
CookieExpire: opts.CookieExpire,
CookieCipher: cipher,
})
if err != nil {
return nil, err
}
provider, err := providers.New(
opts.Provider,
&providers.IdentityProvider{
RedirectURL: opts.RedirectURL,
ProviderName: opts.Provider,
ProviderURL: opts.ProviderURL,
ClientID: opts.ClientID,
ClientSecret: opts.ClientSecret,
// SessionLifetimeTTL: opts.CookieLifetimeTTL,
Scopes: opts.Scopes,
})
if err != nil {
return nil, err
}
p := &Authenticate{
SharedKey: opts.SharedKey,
RedirectURL: opts.RedirectURL,
AllowedDomains: opts.AllowedDomains,
ProxyRootDomains: dotPrependDomains(opts.ProxyRootDomains),
CookieSecure: opts.CookieSecure,
RedirectURL: opts.RedirectURL,
templates: templates.New(),
csrfStore: cookieStore,
sessionStore: cookieStore,
cipher: cipher,
}
p.provider, err = newProvider(opts)
if err != nil {
return nil, err
provider: provider,
}
// validation via dependency injected function
@ -179,20 +186,6 @@ func New(opts *Options, optionFuncs ...func(*Authenticate) error) (*Authenticate
return p, nil
}
func newProvider(opts *Options) (providers.Provider, error) {
pd := &providers.IdentityProvider{
RedirectURL: opts.RedirectURL,
ProviderName: opts.Provider,
ProviderURL: opts.ProviderURL,
ClientID: opts.ClientID,
ClientSecret: opts.ClientSecret,
SessionLifetimeTTL: opts.CookieLifetimeTTL,
Scopes: opts.Scopes,
}
np, err := providers.New(opts.Provider, pd)
return np, err
}
func dotPrependDomains(d []string) []string {
for i := range d {
if d[i] != "" && !strings.HasPrefix(d[i], ".") {

View file

@ -19,8 +19,9 @@ func testOptions() *Options {
ClientSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=",
CookieSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=",
CookieRefresh: time.Duration(1) * time.Hour,
CookieLifetimeTTL: time.Duration(720) * time.Hour,
// CookieLifetimeTTL: time.Duration(720) * time.Hour,
CookieExpire: time.Duration(168) * time.Hour,
CookieName: "pomerium",
}
}
@ -130,37 +131,6 @@ func Test_dotPrependDomains(t *testing.T) {
}
}
func Test_newProvider(t *testing.T) {
redirectURL, _ := url.Parse("https://example.com/oauth3/callback")
goodOpts := &Options{
RedirectURL: redirectURL,
Provider: "google",
ProviderURL: "",
ClientID: "cllient-id",
ClientSecret: "client-secret",
}
tests := []struct {
name string
opts *Options
wantErr bool
}{
{"good", goodOpts, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := newProvider(tt.opts)
if (err != nil) != tt.wantErr {
t.Errorf("newProvider() error = %v, wantErr %v", err, tt.wantErr)
return
}
// if !reflect.DeepEqual(got, tt.want) {
// t.Errorf("newProvider() = %v, want %v", got, tt.want)
// }
})
}
}
func TestNew(t *testing.T) {
good := testOptions()
good.Provider = "google"

View file

@ -117,7 +117,6 @@ func TestAuthenticate_Authenticate(t *testing.T) {
}
lt := time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC()
rt := time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC()
vt := time.Now().Add(1 * time.Minute).Truncate(time.Second).UTC()
vtProto, err := ptypes.TimestampProto(rt)
if err != nil {
t.Fatal("failed to parse timestamp")
@ -128,7 +127,7 @@ func TestAuthenticate_Authenticate(t *testing.T) {
RefreshToken: "refresh4321",
LifetimeDeadline: lt,
RefreshDeadline: rt,
ValidDeadline: vt,
Email: "user@domain.com",
User: "user",
}

View file

@ -16,7 +16,8 @@ import (
"github.com/pomerium/pomerium/internal/version"
)
// securityHeaders corresponds to HTTP response headers related to security.
// securityHeaders corresponds to HTTP response headers that help to protect against protocol
// downgrade attacks and cookie hijacking.
// https://www.owasp.org/index.php/OWASP_Secure_Headers_Project#tab=Headers
var securityHeaders = map[string]string{
"Strict-Transport-Security": "max-age=31536000",
@ -28,7 +29,7 @@ var securityHeaders = map[string]string{
"Referrer-Policy": "Same-origin",
}
// Handler returns the Http.Handlers for authenticate, callback, and refresh
// Handler returns the authenticate service's HTTP request multiplexer, and routes.
func (a *Authenticate) Handler() http.Handler {
// set up our standard middlewares
stdMiddleware := middleware.NewChain()
@ -80,12 +81,6 @@ func (a *Authenticate) authenticate(w http.ResponseWriter, r *http.Request) (*se
return nil, err
}
// if long-lived lifetime has expired, clear session
if session.LifetimePeriodExpired() {
log.FromRequest(r).Warn().Msg("authenticate: lifetime expired")
a.sessionStore.ClearSession(w, r)
return nil, sessions.ErrLifetimeExpired
}
// check if session refresh period is up
if session.RefreshPeriodExpired() {
newToken, err := a.provider.Refresh(session.RefreshToken)
@ -130,32 +125,23 @@ func (a *Authenticate) authenticate(w http.ResponseWriter, r *http.Request) (*se
return session, nil
}
// SignIn handles the /sign_in endpoint. It attempts to authenticate the user,
// SignIn handles the sign_in endpoint. It attempts to authenticate the user,
// and if the user is not authenticated, it renders a sign in page.
func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
session, err := a.authenticate(w, r)
switch err {
case nil:
// session good, redirect back to proxy
log.FromRequest(r).Info().Msg("authenticate.SignIn : authenticated")
a.ProxyCallback(w, r, session)
case http.ErrNoCookie, sessions.ErrLifetimeExpired, sessions.ErrInvalidSession:
// session invalid, authenticate
log.FromRequest(r).Info().Err(err).Msg("authenticate.SignIn : expected failure")
if err != http.ErrNoCookie {
if err != nil {
log.FromRequest(r).Info().Err(err).Msg("authenticate: authenticate error")
a.sessionStore.ClearSession(w, r)
}
a.OAuthStart(w, r)
default:
log.Error().Err(err).Msg("authenticate: unexpected sign in error")
httputil.ErrorResponse(w, r, err.Error(), httputil.CodeForError(err))
}
log.FromRequest(r).Info().Msg("authenticate: user authenticated")
a.ProxyCallback(w, r, session)
}
// ProxyCallback redirects the user back to proxy service along with an encrypted payload, as
// url params, of the user's session state.
// See RFC6749 3.1.2 https://tools.ietf.org/html/rfc6749#section-3.1.2
// url params, of the user's session state as specified in RFC6749 3.1.2.
// https://tools.ietf.org/html/rfc6749#section-3.1.2
func (a *Authenticate) ProxyCallback(w http.ResponseWriter, r *http.Request, session *sessions.SessionState) {
err := r.ParseForm()
if err != nil {
@ -201,9 +187,8 @@ func getAuthCodeRedirectURL(redirectURL *url.URL, state, authCode string) string
return u.String()
}
// SignOut signs the user out by trying to revoke the users remote identity provider session
// then removes the associated local session state.
// Handles both GET and POST of form.
// SignOut signs the user out by trying to revoke the user's remote identity session along with
// the associated local session state. Handles both GET and POST.
func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
err := r.ParseForm()
if err != nil {
@ -256,8 +241,8 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, redirectURI, http.StatusFound)
}
// OAuthStart starts the authenticate process by redirecting to the provider. It provides a
// `redirectURI`, allowing the provider to redirect back to the sso proxy after authenticate.
// OAuthStart starts the authenticate process by redirecting to the identity provider.
// https://tools.ietf.org/html/rfc6749#section-4.2.1
func (a *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) {
authRedirectURL, err := url.Parse(r.URL.Query().Get("redirect_uri"))
if err != nil {
@ -298,7 +283,7 @@ func (a *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) {
}
// OAuthCallback handles the callback from the identity provider. Displays an error page if there
// was an error. If successful, redirects back to the proxy-service via the redirect-url.
// was an error. If successful, the user is redirected back to the proxy-service.
func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) {
redirect, err := a.getOAuthCallback(w, r)
switch h := err.(type) {

View file

@ -71,29 +71,19 @@ func TestAuthenticate_Handler(t *testing.T) {
func TestAuthenticate_authenticate(t *testing.T) {
// sessions.MockSessionStore{Session: expiredLifetime}
goodSession := sessions.MockSessionStore{
goodSession := &sessions.MockSessionStore{
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}}
expiredSession := sessions.MockSessionStore{
expiredRefresPeriod := &sessions.MockSessionStore{
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * -time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}}
expiredRefresPeriod := sessions.MockSessionStore{
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * -time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}}
tests := []struct {
@ -106,18 +96,16 @@ func TestAuthenticate_authenticate(t *testing.T) {
}{
{"good", goodSession, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, false},
{"good but fails validation", goodSession, providers.MockProvider{ValidateResponse: true}, falseValidator, nil, true},
{"can't load session", sessions.MockSessionStore{LoadError: errors.New("error")}, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, true},
{"can't load session", &sessions.MockSessionStore{LoadError: errors.New("error")}, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, true},
{"validation fails", goodSession, providers.MockProvider{ValidateResponse: false}, trueValidator, nil, true},
{"session fails after good validation", sessions.MockSessionStore{
{"session fails after good validation", &sessions.MockSessionStore{
SaveError: errors.New("error"),
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}}, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, true},
{"lifetime expired", expiredSession, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, true},
}}, providers.MockProvider{ValidateResponse: true},
trueValidator, nil, true},
{"refresh expired",
expiredRefresPeriod,
providers.MockProvider{
@ -136,14 +124,13 @@ func TestAuthenticate_authenticate(t *testing.T) {
},
trueValidator, nil, true},
{"refresh expired failed save",
sessions.MockSessionStore{
&sessions.MockSessionStore{
SaveError: errors.New("error"),
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * -time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}},
providers.MockProvider{
ValidateResponse: true,
@ -182,29 +169,23 @@ func TestAuthenticate_SignIn(t *testing.T) {
wantCode int
}{
{"good",
sessions.MockSessionStore{
&sessions.MockSessionStore{
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}},
providers.MockProvider{ValidateResponse: true},
trueValidator,
403},
// {"no session",
// sessions.MockSessionStore{
// Session: &sessions.SessionState{
// AccessToken: "AccessToken",
// RefreshToken: "RefreshToken",
// LifetimeDeadline: time.Now().Add(-10 * time.Second),
// RefreshDeadline: time.Now().Add(10 * time.Second),
// ValidDeadline: time.Now().Add(10 * time.Second),
// }},
// providers.MockProvider{ValidateResponse: true},
// trueValidator,
// 200},
http.StatusForbidden},
{"session fails after good validation", &sessions.MockSessionStore{
SaveError: errors.New("error"),
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
RefreshDeadline: time.Now().Add(10 * time.Second),
}}, providers.MockProvider{ValidateResponse: true},
trueValidator, http.StatusBadRequest},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -212,6 +193,10 @@ func TestAuthenticate_SignIn(t *testing.T) {
sessionStore: tt.session,
provider: tt.provider,
Validator: tt.validator,
RedirectURL: uriParse("http://www.pomerium.io"),
csrfStore: &sessions.MockCSRFStore{},
SharedKey: "secret",
cipher: mockCipher{},
}
r := httptest.NewRequest("GET", "/sign-in", nil)
w := httptest.NewRecorder()
@ -264,11 +249,10 @@ func TestAuthenticate_ProxyCallback(t *testing.T) {
&sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
},
sessions.MockSessionStore{},
&sessions.MockSessionStore{},
302,
"<a href=\"https://corp.pomerium.io/?code=ok&amp;state=state\">Found</a>."},
{"no state",
@ -278,11 +262,10 @@ func TestAuthenticate_ProxyCallback(t *testing.T) {
&sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
},
sessions.MockSessionStore{},
&sessions.MockSessionStore{},
403,
"no state parameter supplied"},
{"no redirect_url",
@ -292,11 +275,10 @@ func TestAuthenticate_ProxyCallback(t *testing.T) {
&sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
},
sessions.MockSessionStore{},
&sessions.MockSessionStore{},
403,
"no redirect_uri parameter"},
{"malformed redirect_url",
@ -306,11 +288,10 @@ func TestAuthenticate_ProxyCallback(t *testing.T) {
&sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
},
sessions.MockSessionStore{},
&sessions.MockSessionStore{},
400,
"malformed redirect_uri"},
}
@ -389,14 +370,13 @@ func TestAuthenticate_SignOut(t *testing.T) {
"sig",
"ts",
providers.MockProvider{},
sessions.MockSessionStore{
&sessions.MockSessionStore{
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
},
},
http.StatusFound,
@ -407,14 +387,13 @@ func TestAuthenticate_SignOut(t *testing.T) {
"sig",
"ts",
providers.MockProvider{RevokeError: errors.New("OH NO")},
sessions.MockSessionStore{
&sessions.MockSessionStore{
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
},
},
http.StatusBadRequest,
@ -426,14 +405,13 @@ func TestAuthenticate_SignOut(t *testing.T) {
"sig",
"ts",
providers.MockProvider{},
sessions.MockSessionStore{
&sessions.MockSessionStore{
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
},
},
http.StatusOK,
@ -444,15 +422,14 @@ func TestAuthenticate_SignOut(t *testing.T) {
"sig",
"ts",
providers.MockProvider{},
sessions.MockSessionStore{
&sessions.MockSessionStore{
LoadError: errors.New("uh oh"),
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
},
},
http.StatusBadRequest,
@ -463,14 +440,13 @@ func TestAuthenticate_SignOut(t *testing.T) {
"sig",
"ts",
providers.MockProvider{},
sessions.MockSessionStore{
&sessions.MockSessionStore{
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
},
},
http.StatusBadRequest,
@ -512,7 +488,6 @@ func redirectURLSignature(rawRedirect string, timestamp time.Time, secret string
}
func TestAuthenticate_OAuthStart(t *testing.T) {
tests := []struct {
name string
method string
@ -634,15 +609,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
[]string{"pomerium.io"},
trueValidator,
sessions.MockSessionStore{},
&sessions.MockSessionStore{},
providers.MockProvider{
AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}},
sessions.MockCSRFStore{
ResponseCSRF: "csrf",
@ -657,15 +631,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
[]string{"pomerium.io"},
trueValidator,
sessions.MockSessionStore{},
&sessions.MockSessionStore{},
providers.MockProvider{
AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}},
sessions.MockCSRFStore{
ResponseCSRF: "csrf",
@ -681,15 +654,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
[]string{"pomerium.io"},
trueValidator,
sessions.MockSessionStore{},
&sessions.MockSessionStore{},
providers.MockProvider{
AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}},
sessions.MockCSRFStore{
ResponseCSRF: "csrf",
@ -704,7 +676,7 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
[]string{"pomerium.io"},
trueValidator,
sessions.MockSessionStore{},
&sessions.MockSessionStore{},
providers.MockProvider{
AuthenticateError: errors.New("error"),
},
@ -721,15 +693,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
[]string{"pomerium.io"},
trueValidator,
sessions.MockSessionStore{SaveError: errors.New("error")},
&sessions.MockSessionStore{SaveError: errors.New("error")},
providers.MockProvider{
AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}},
sessions.MockCSRFStore{
ResponseCSRF: "csrf",
@ -744,15 +715,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
[]string{"pomerium.io"},
falseValidator,
sessions.MockSessionStore{},
&sessions.MockSessionStore{},
providers.MockProvider{
AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}},
sessions.MockCSRFStore{
ResponseCSRF: "csrf",
@ -768,15 +738,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
[]string{"pomerium.io"},
trueValidator,
sessions.MockSessionStore{},
&sessions.MockSessionStore{},
providers.MockProvider{
AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}},
sessions.MockCSRFStore{
ResponseCSRF: "csrf",
@ -791,15 +760,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
[]string{"pomerium.io"},
trueValidator,
sessions.MockSessionStore{},
&sessions.MockSessionStore{},
providers.MockProvider{
AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}},
sessions.MockCSRFStore{
ResponseCSRF: "csrf",
@ -814,15 +782,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
"nonce:https://corp.pomerium.io",
[]string{"pomerium.io"},
trueValidator,
sessions.MockSessionStore{},
&sessions.MockSessionStore{},
providers.MockProvider{
AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}},
sessions.MockCSRFStore{
ResponseCSRF: "csrf",
@ -837,15 +804,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
base64.URLEncoding.EncodeToString([]byte("nonce")),
[]string{"pomerium.io"},
trueValidator,
sessions.MockSessionStore{},
&sessions.MockSessionStore{},
providers.MockProvider{
AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}},
sessions.MockCSRFStore{
ResponseCSRF: "csrf",
@ -860,15 +826,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
base64.URLEncoding.EncodeToString([]byte("nonce:corp.pomerium.io")),
[]string{"pomerium.io"},
trueValidator,
sessions.MockSessionStore{},
&sessions.MockSessionStore{},
providers.MockProvider{
AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}},
sessions.MockCSRFStore{
ResponseCSRF: "csrf",

View file

@ -1,5 +1,5 @@
// Package providers implements OpenID Connect client logic for the set of supported identity
// providers.
// OpenID Connect 1.0 is a simple identity layer on top of the OAuth 2.0 RFC6749 protocol.
// https://openid.net/specs/openid-connect-core-1_0.html
// Package providers authentication for third party identity providers (IdP) using OpenID
// Connect, an identity layer on top of the OAuth 2.0 RFC6749 protocol.
//
// see: https://openid.net/specs/openid-connect-core-1_0.html
package providers // import "github.com/pomerium/pomerium/internal/providers"

View file

@ -2,7 +2,6 @@ package providers // import "github.com/pomerium/pomerium/internal/providers"
import (
"context"
"errors"
oidc "github.com/pomerium/go-oidc"
"golang.org/x/oauth2"
@ -19,7 +18,7 @@ type OIDCProvider struct {
func NewOIDCProvider(p *IdentityProvider) (*OIDCProvider, error) {
ctx := context.Background()
if p.ProviderURL == "" {
return nil, errors.New("missing required provider url")
return nil, ErrMissingProviderURL
}
var err error
p.provider, err = oidc.NewProvider(ctx, p.ProviderURL)

View file

@ -2,7 +2,6 @@ package providers // import "github.com/pomerium/pomerium/internal/providers"
import (
"context"
"errors"
"net/url"
oidc "github.com/pomerium/go-oidc"
@ -25,7 +24,7 @@ type OktaProvider struct {
func NewOktaProvider(p *IdentityProvider) (*OktaProvider, error) {
ctx := context.Background()
if p.ProviderURL == "" {
return nil, errors.New("missing required provider url")
return nil, ErrMissingProviderURL
}
var err error
p.provider, err = oidc.NewProvider(ctx, p.ProviderURL)

View file

@ -29,6 +29,11 @@ const (
OktaProviderName = "okta"
)
var (
// ErrMissingProviderURL is returned when the CB state is half open and the requests count is over the cb maxRequests
ErrMissingProviderURL = errors.New("proxy/providers: missing provider url")
)
// Provider is an interface exposing functions necessary to interact with a given provider.
type Provider interface {
Authenticate(string) (*sessions.SessionState, error)

17
go.mod
View file

@ -5,18 +5,17 @@ require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/golang/mock v1.2.0
github.com/golang/protobuf v1.2.0
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/pomerium/envconfig v1.3.1-0.20190112072701-14cbcf832d31
github.com/pomerium/go-oidc v2.0.0+incompatible
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
github.com/rs/zerolog v1.11.0
github.com/stretchr/testify v1.2.2 // indirect
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9
golang.org/x/net v0.0.0-20181220203305-927f97764cc3
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 // indirect
golang.org/x/sys v0.0.0-20190116161447-11f53e031339 // indirect
google.golang.org/appengine v1.4.0 // indirect
github.com/stretchr/testify v1.3.0 // indirect
golang.org/x/crypto v0.0.0-20190211182817-74369b46fc67
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd
golang.org/x/oauth2 v0.0.0-20190212230446-3e8b2be13635
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a // indirect
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2 // indirect
google.golang.org/genproto v0.0.0-20190215211957-bd968387e4aa // indirect
google.golang.org/grpc v1.18.0
gopkg.in/square/go-jose.v2 v2.2.1
gopkg.in/square/go-jose.v2 v2.2.2
)

38
go.sum
View file

@ -1,11 +1,14 @@
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
github.com/benbjohnson/clock v0.0.0-20161215174838-7dc76406b6d3 h1:wOysYcIdqv3WnvwqFFzrYCFALPED7qkUGaLXu359GSc=
github.com/benbjohnson/clock v0.0.0-20161215174838-7dc76406b6d3/go.mod h1:UMqtWQTnOe4byzwe7Zhwh8f8s+36uszN51sJrSIZlTE=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E=
github.com/golang/mock v1.1.1 h1:G5FRp8JnTd7RQH5kemVNlMeyXQAztQ3mOWV95KxsXH8=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/mock v1.2.0 h1:28o5sBqPkBsMGnC6b4MvE2TzSr5/AT4c/1fLqVGIwlk=
@ -23,34 +26,45 @@ github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAm
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35/go.mod h1:prYjPmNq4d1NPVmpShWobRqXY3q7Vp+80DqgxxUrUIA=
github.com/rs/zerolog v1.11.0 h1:DRuq/S+4k52uJzBQciUcofXx45GrMC6yrEbb/CoK6+M=
github.com/rs/zerolog v1.11.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU=
github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9 h1:mKdxBk7AujPs8kU4m80U72y/zjbZ3UcXC7dClwKbUI0=
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
golang.org/x/crypto v0.0.0-20190211182817-74369b46fc67 h1:ng3VDlRp5/DHpSWl02R4rM9I+8M2rhmsuLwAMmkLQWE=
golang.org/x/crypto v0.0.0-20190211182817-74369b46fc67/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20181220203305-927f97764cc3 h1:eH6Eip3UpmR+yM/qI9Ijluzb1bNv/cAU/n+6l8tRSis=
golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20181106065722-10aee1819953/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd h1:HuTn7WObtcDo9uEEU7rEqL0jYthdXAmZ6PP+meazmaU=
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890 h1:uESlIz09WIHT2I+pasSXcpLYqYK8wHcdCetU3VuMBJE=
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190212230446-3e8b2be13635 h1:dOJmQysgY8iOBECuNp0vlKHWEtfiTnyjisEizRV3/4o=
golang.org/x/oauth2 v0.0.0-20190212230446-3e8b2be13635/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190116161447-11f53e031339 h1:g/Jesu8+QLnA0CPzF3E1pURg0Byr7i6jLoX5sqjcAh0=
golang.org/x/sys v0.0.0-20190116161447-11f53e031339/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2 h1:z99zHgr7hKfrUcX/KsoJk5FJfjTceCKIp96+biqP4To=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8 h1:Nw54tB0rB7hY/N0NQvRW8DG4Yk3Q6T9cu9RcFQDu1tc=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20190215211957-bd968387e4aa h1:FVL+/MjP2dzG4PxLpCJR7B6esIia88UAbsfYUrCc8U4=
google.golang.org/genproto v0.0.0-20190215211957-bd968387e4aa/go.mod h1:L3J43x8/uS+qIUoksaLKe6OS3nUKxOKuIFz1sl2/jx4=
google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio=
google.golang.org/grpc v1.18.0 h1:IZl7mfBGfbhYx2p2rKRtYgDFw6SBz+kclmxYrCksPPA=
google.golang.org/grpc v1.18.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs=
gopkg.in/square/go-jose.v2 v2.2.1 h1:uRIz/V7RfMsMgGnCp+YybIdstDIz8wc0H283wHQfwic=
gopkg.in/square/go-jose.v2 v2.2.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
gopkg.in/square/go-jose.v2 v2.2.2 h1:orlkJ3myw8CN1nVQHBFfloD+L3egixIa4FvUP6RosSA=
gopkg.in/square/go-jose.v2 v2.2.2/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=

View file

@ -31,45 +31,41 @@ type SessionStore interface {
type CookieStore struct {
Name string
CSRFCookieName string
CookieCipher cryptutil.Cipher
CookieExpire time.Duration
CookieRefresh time.Duration
CookieSecure bool
CookieHTTPOnly bool
CookieDomain string
CookieCipher cryptutil.Cipher
SessionLifetimeTTL time.Duration
}
// CreateCookieCipher creates a new miscreant cipher with the cookie secret
func CreateCookieCipher(cookieSecret []byte) func(s *CookieStore) error {
return func(s *CookieStore) error {
cipher, err := cryptutil.NewCipher(cookieSecret)
if err != nil {
return fmt.Errorf("cookie-secret error: %s", err.Error())
}
s.CookieCipher = cipher
return nil
}
// CookieStoreOptions holds options for CookieStore
type CookieStoreOptions struct {
Name string
CookieSecure bool
CookieHTTPOnly bool
CookieDomain string
CookieExpire time.Duration
CookieCipher cryptutil.Cipher
}
// NewCookieStore returns a new session with ciphers for each of the cookie secrets
func NewCookieStore(cookieName string, optFuncs ...func(*CookieStore) error) (*CookieStore, error) {
c := &CookieStore{
Name: cookieName,
CookieSecure: true,
CookieHTTPOnly: true,
CookieExpire: 168 * time.Hour,
CSRFCookieName: fmt.Sprintf("%v_%v", cookieName, "csrf"),
func NewCookieStore(opts *CookieStoreOptions) (*CookieStore, error) {
if opts.Name == "" {
return nil, fmt.Errorf("internal/sessions: cookie name cannot be empty")
}
for _, f := range optFuncs {
err := f(c)
if err != nil {
return nil, err
if opts.CookieCipher == nil {
return nil, fmt.Errorf("internal/sessions: cipher cannot be nil")
}
}
return c, nil
return &CookieStore{
Name: opts.Name,
CSRFCookieName: fmt.Sprintf("%v_%v", opts.Name, "csrf"),
CookieSecure: opts.CookieSecure,
CookieHTTPOnly: opts.CookieHTTPOnly,
CookieDomain: opts.CookieDomain,
CookieExpire: opts.CookieExpire,
CookieCipher: opts.CookieCipher,
}, nil
}
func (s *CookieStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie {
@ -80,16 +76,19 @@ func (s *CookieStore) makeCookie(req *http.Request, name string, value string, e
if s.CookieDomain != "" {
domain = s.CookieDomain
}
return &http.Cookie{
c := &http.Cookie{
Name: name,
Value: value,
Path: "/",
Domain: domain,
HttpOnly: s.CookieHTTPOnly,
Secure: s.CookieSecure,
Expires: now.Add(expiration),
}
// only set an expiration if we want one, otherwise default to non perm session based
if expiration != 0 {
c.Expires = now.Add(expiration)
}
return c
}
// makeSessionCookie constructs a session cookie given the request, an expiration time and the current time.
@ -103,13 +102,13 @@ func (s *CookieStore) makeCSRFCookie(req *http.Request, value string, expiration
}
// ClearCSRF clears the CSRF cookie from the request
func (s *CookieStore) ClearCSRF(rw http.ResponseWriter, req *http.Request) {
http.SetCookie(rw, s.makeCSRFCookie(req, "", time.Hour*-1, time.Now()))
func (s *CookieStore) ClearCSRF(w http.ResponseWriter, req *http.Request) {
http.SetCookie(w, s.makeCSRFCookie(req, "", time.Hour*-1, time.Now()))
}
// SetCSRF sets the CSRFCookie creates a CSRF cookie in a given request
func (s *CookieStore) SetCSRF(rw http.ResponseWriter, req *http.Request, val string) {
http.SetCookie(rw, s.makeCSRFCookie(req, val, s.CookieExpire, time.Now()))
func (s *CookieStore) SetCSRF(w http.ResponseWriter, req *http.Request, val string) {
http.SetCookie(w, s.makeCSRFCookie(req, val, s.CookieExpire, time.Now()))
}
// GetCSRF gets the CSRFCookie creates a CSRF cookie in a given request
@ -118,20 +117,19 @@ func (s *CookieStore) GetCSRF(req *http.Request) (*http.Cookie, error) {
}
// ClearSession clears the session cookie from a request
func (s *CookieStore) ClearSession(rw http.ResponseWriter, req *http.Request) {
http.SetCookie(rw, s.makeSessionCookie(req, "", time.Hour*-1, time.Now()))
func (s *CookieStore) ClearSession(w http.ResponseWriter, req *http.Request) {
http.SetCookie(w, s.makeSessionCookie(req, "", time.Hour*-1, time.Now()))
}
func (s *CookieStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val string) {
http.SetCookie(rw, s.makeSessionCookie(req, val, s.CookieExpire, time.Now()))
func (s *CookieStore) setSessionCookie(w http.ResponseWriter, req *http.Request, val string) {
http.SetCookie(w, s.makeSessionCookie(req, val, s.CookieExpire, time.Now()))
}
// LoadSession returns a SessionState from the cookie in the request.
func (s *CookieStore) LoadSession(req *http.Request) (*SessionState, error) {
c, err := req.Cookie(s.Name)
if err != nil {
// always http.ErrNoCookie
return nil, err
return nil, err // http.ErrNoCookie
}
session, err := UnmarshalSession(c.Value, s.CookieCipher)
if err != nil {
@ -141,12 +139,11 @@ func (s *CookieStore) LoadSession(req *http.Request) (*SessionState, error) {
}
// SaveSession saves a session state to a request sessions.
func (s *CookieStore) SaveSession(rw http.ResponseWriter, req *http.Request, sessionState *SessionState) error {
func (s *CookieStore) SaveSession(w http.ResponseWriter, req *http.Request, sessionState *SessionState) error {
value, err := MarshalSession(sessionState, s.CookieCipher)
if err != nil {
return err
}
s.setSessionCookie(rw, req, value)
s.setSessionCookie(w, req, value)
return nil
}

View file

@ -1,348 +1,348 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
package sessions
import (
"encoding/base64"
"fmt"
"errors"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"time"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/internal/cryptutil"
)
var testEncodedCookieSecret, _ = base64.StdEncoding.DecodeString("qICChm3wdjbjcWymm7PefwtPP6/PZv+udkFEubTeE38=")
type mockCipher struct{}
func TestCreateCookieCipher(t *testing.T) {
testCases := []struct {
func (a mockCipher) Encrypt(s []byte) ([]byte, error) {
if string(s) == "error" {
return []byte(""), errors.New("error encrypting")
}
return []byte("OK"), nil
}
func (a mockCipher) Decrypt(s []byte) ([]byte, error) {
if string(s) == "error" {
return []byte(""), errors.New("error encrypting")
}
return []byte("OK"), nil
}
func (a mockCipher) Marshal(s interface{}) (string, error) { return "", errors.New("error") }
func (a mockCipher) Unmarshal(s string, i interface{}) error {
if string(s) == "unmarshal error" || string(s) == "error" {
return errors.New("error")
}
return nil
}
func TestNewCookieStore(t *testing.T) {
cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey())
if err != nil {
t.Fatal(err)
}
tests := []struct {
name string
cookieSecret []byte
expectedError bool
opts *CookieStoreOptions
want *CookieStore
wantErr bool
}{
{
name: "normal case with base64 encoded secret",
cookieSecret: testEncodedCookieSecret,
},
{
name: "error when not base64 encoded",
cookieSecret: []byte("abcd"),
expectedError: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, err := NewCookieStore("cookieName", CreateCookieCipher(tc.cookieSecret))
if !tc.expectedError {
testutil.Ok(t, err)
} else {
testutil.NotEqual(t, err, nil)
}
})
}
}
func TestNewSession(t *testing.T) {
testCases := []struct {
name string
optFuncs []func(*CookieStore) error
expectedError bool
expectedSession *CookieStore
}{
{
name: "default with no opt funcs set",
expectedSession: &CookieStore{
Name: "cookieName",
{"good",
&CookieStoreOptions{
Name: "_cookie",
CookieSecure: true,
CookieHTTPOnly: true,
CookieExpire: 168 * time.Hour,
CSRFCookieName: "cookieName_csrf",
CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second,
CookieCipher: cipher,
},
},
{
name: "opt func with an error returns an error",
optFuncs: []func(*CookieStore) error{func(*CookieStore) error { return fmt.Errorf("error") }},
expectedError: true,
},
{
name: "opt func overrides default values",
optFuncs: []func(*CookieStore) error{func(s *CookieStore) error {
s.CookieExpire = time.Hour
return nil
}},
expectedSession: &CookieStore{
Name: "cookieName",
&CookieStore{
Name: "_cookie",
CSRFCookieName: "_cookie_csrf",
CookieSecure: true,
CookieHTTPOnly: true,
CookieExpire: time.Hour,
CSRFCookieName: "cookieName_csrf",
CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second,
CookieCipher: cipher,
},
false},
{"missing name",
&CookieStoreOptions{
Name: "",
CookieSecure: true,
CookieHTTPOnly: true,
CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second,
CookieCipher: cipher,
},
nil,
true},
{"missing cipher",
&CookieStoreOptions{
Name: "_pomerium",
CookieSecure: true,
CookieHTTPOnly: true,
CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second,
CookieCipher: nil,
},
nil,
true},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
session, err := NewCookieStore("cookieName", tc.optFuncs...)
if tc.expectedError {
testutil.NotEqual(t, err, nil)
} else {
testutil.Ok(t, err)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewCookieStore(tt.opts)
if (err != nil) != tt.wantErr {
t.Errorf("NewCookieStore() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewCookieStore() = %#v, want %#v", got, tt.want)
}
testutil.Equal(t, tc.expectedSession, session)
})
}
}
func TestMakeSessionCookie(t *testing.T) {
func TestCookieStore_makeCookie(t *testing.T) {
cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey())
if err != nil {
t.Fatal(err)
}
type fields struct {
Name string
CSRFCookieName string
CookieCipher cryptutil.Cipher
CookieExpire time.Duration
CookieRefresh time.Duration
CookieSecure bool
CookieHTTPOnly bool
CookieDomain string
}
now := time.Now()
cookieValue := "cookieValue"
expiration := time.Hour
cookieName := "cookieName"
testCases := []struct {
tests := []struct {
name string
optFuncs []func(*CookieStore) error
expectedCookie *http.Cookie
domain string
cookieName string
value string
expiration time.Duration
want *http.Cookie
}{
{
name: "default cookie domain",
expectedCookie: &http.Cookie{
Name: cookieName,
Value: cookieValue,
Path: "/",
Domain: "www.example.com",
HttpOnly: true,
Secure: true,
Expires: now.Add(expiration),
},
},
{
name: "custom cookie domain set",
optFuncs: []func(*CookieStore) error{
func(s *CookieStore) error {
s.CookieDomain = "buzzfeed.com"
return nil
},
},
expectedCookie: &http.Cookie{
Name: cookieName,
Value: cookieValue,
Path: "/",
Domain: "buzzfeed.com",
HttpOnly: true,
Secure: true,
Expires: now.Add(expiration),
},
},
{"good", "http://pomerium.io", "_pomerium", "value", 0, &http.Cookie{Name: "_pomerium", Value: "value", Path: "/", Domain: "pomerium.io", Secure: true, HttpOnly: true}},
{"domains with https", "https://pomerium.io", "_pomerium", "value", 0, &http.Cookie{Name: "_pomerium", Value: "value", Path: "/", Domain: "pomerium.io", Secure: true, HttpOnly: true}},
{"domain with port", "http://pomerium.io:443", "_pomerium", "value", 0, &http.Cookie{Name: "_pomerium", Value: "value", Path: "/", Domain: "pomerium.io", Secure: true, HttpOnly: true}},
{"expiration set", "http://pomerium.io:443", "_pomerium", "value", 10 * time.Second, &http.Cookie{Expires: now.Add(10 * time.Second), Name: "_pomerium", Value: "value", Path: "/", Domain: "pomerium.io", Secure: true, HttpOnly: true}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := httptest.NewRequest("GET", tt.domain, nil)
s := &CookieStore{
Name: "_pomerium",
CSRFCookieName: "_pomerium_csrf",
CookieSecure: true,
CookieHTTPOnly: true,
CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second,
CookieCipher: cipher}
if got := s.makeCookie(r, tt.cookieName, tt.value, tt.expiration, now); !reflect.DeepEqual(got, tt.want) {
t.Errorf("CookieStore.makeCookie() = \n%#v, \nwant\n%#v", got, tt.want)
}
if got := s.makeSessionCookie(r, tt.value, tt.expiration, now); !reflect.DeepEqual(got, tt.want) {
t.Errorf("CookieStore.makeCookie() = \n%#v, \nwant\n%#v", got, tt.want)
}
got := s.makeCSRFCookie(r, tt.value, tt.expiration, now)
tt.want.Name = "_pomerium_csrf"
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("CookieStore.makeCookie() = \n%#v, \nwant\n%#v", got, tt.want)
}
w := httptest.NewRecorder()
want := "new-csrf"
s.SetCSRF(w, r, want)
found := false
for _, cookie := range w.Result().Cookies() {
if cookie.Name == s.CSRFCookieName && cookie.Value == want {
found = true
break
}
}
if !found {
t.Error("SetCSRF failed")
}
w = httptest.NewRecorder()
s.ClearCSRF(w, r)
for _, cookie := range w.Result().Cookies() {
if cookie.Name == s.CSRFCookieName && cookie.Value == want {
t.Error("clear csrf failed")
break
}
}
w = httptest.NewRecorder()
want = "new-session"
s.setSessionCookie(w, r, want)
found = false
for _, cookie := range w.Result().Cookies() {
if cookie.Name == s.Name && cookie.Value == want {
found = true
break
}
}
if !found {
t.Error("SetCSRF failed")
}
w = httptest.NewRecorder()
s.ClearSession(w, r)
for _, cookie := range w.Result().Cookies() {
if cookie.Name == s.Name && cookie.Value == want {
t.Error("clear csrf failed")
break
}
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
session, err := NewCookieStore(cookieName, tc.optFuncs...)
testutil.Ok(t, err)
req := httptest.NewRequest("GET", "http://www.example.com", nil)
cookie := session.makeSessionCookie(req, cookieValue, expiration, now)
testutil.Equal(t, cookie, tc.expectedCookie)
})
}
}
func TestMakeSessionCSRFCookie(t *testing.T) {
now := time.Now()
cookieValue := "cookieValue"
expiration := time.Hour
cookieName := "cookieName"
csrfName := "cookieName_csrf"
testCases := []struct {
func TestCookieStore_SaveSession(t *testing.T) {
cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey())
if err != nil {
t.Fatal(err)
}
tests := []struct {
name string
optFuncs []func(*CookieStore) error
expectedCookie *http.Cookie
}{
{
name: "default cookie domain",
expectedCookie: &http.Cookie{
Name: csrfName,
Value: cookieValue,
Path: "/",
Domain: "www.example.com",
HttpOnly: true,
Secure: true,
Expires: now.Add(expiration),
},
},
{
name: "custom cookie domain set",
optFuncs: []func(*CookieStore) error{
func(s *CookieStore) error {
s.CookieDomain = "buzzfeed.com"
return nil
},
},
expectedCookie: &http.Cookie{
Name: csrfName,
Value: cookieValue,
Path: "/",
Domain: "buzzfeed.com",
HttpOnly: true,
Secure: true,
Expires: now.Add(expiration),
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
session, err := NewCookieStore(cookieName, tc.optFuncs...)
testutil.Ok(t, err)
req := httptest.NewRequest("GET", "http://www.example.com", nil)
cookie := session.makeCSRFCookie(req, cookieValue, expiration, now)
testutil.Equal(t, tc.expectedCookie, cookie)
})
}
}
func TestSetSessionCookie(t *testing.T) {
cookieValue := "cookieValue"
cookieName := "cookieName"
t.Run("set session cookie test", func(t *testing.T) {
session, err := NewCookieStore(cookieName)
testutil.Ok(t, err)
req := httptest.NewRequest("GET", "http://www.example.com", nil)
rw := httptest.NewRecorder()
session.setSessionCookie(rw, req, cookieValue)
var found bool
for _, cookie := range rw.Result().Cookies() {
if cookie.Name == cookieName {
found = true
testutil.Equal(t, cookieValue, cookie.Value)
testutil.Assert(t, cookie.Expires.After(time.Now()), "cookie expires after now")
}
}
testutil.Assert(t, found, "cookie in header")
})
}
func TestSetCSRFSessionCookie(t *testing.T) {
cookieValue := "cookieValue"
cookieName := "cookieName"
t.Run("set csrf cookie test", func(t *testing.T) {
session, err := NewCookieStore(cookieName)
testutil.Ok(t, err)
req := httptest.NewRequest("GET", "http://www.example.com", nil)
rw := httptest.NewRecorder()
session.SetCSRF(rw, req, cookieValue)
var found bool
for _, cookie := range rw.Result().Cookies() {
if cookie.Name == fmt.Sprintf("%s_csrf", cookieName) {
found = true
testutil.Equal(t, cookieValue, cookie.Value)
testutil.Assert(t, cookie.Expires.After(time.Now()), "cookie expires after now")
}
}
testutil.Assert(t, found, "cookie in header")
})
}
func TestClearSessionCookie(t *testing.T) {
cookieValue := "cookieValue"
cookieName := "cookieName"
t.Run("set session cookie test", func(t *testing.T) {
session, err := NewCookieStore(cookieName)
testutil.Ok(t, err)
req := httptest.NewRequest("GET", "http://www.example.com", nil)
req.AddCookie(session.makeSessionCookie(req, cookieValue, time.Hour, time.Now()))
rw := httptest.NewRecorder()
session.ClearSession(rw, req)
var found bool
for _, cookie := range rw.Result().Cookies() {
if cookie.Name == cookieName {
found = true
testutil.Equal(t, "", cookie.Value)
testutil.Assert(t, cookie.Expires.Before(time.Now()), "cookie expires before now")
}
}
testutil.Assert(t, found, "cookie in header")
})
}
func TestClearCSRFSessionCookie(t *testing.T) {
cookieValue := "cookieValue"
cookieName := "cookieName"
t.Run("clear csrf cookie test", func(t *testing.T) {
session, err := NewCookieStore(cookieName)
testutil.Ok(t, err)
req := httptest.NewRequest("GET", "http://www.example.com", nil)
req.AddCookie(session.makeCSRFCookie(req, cookieValue, time.Hour, time.Now()))
rw := httptest.NewRecorder()
session.ClearCSRF(rw, req)
var found bool
for _, cookie := range rw.Result().Cookies() {
if cookie.Name == fmt.Sprintf("%s_csrf", cookieName) {
found = true
testutil.Equal(t, "", cookie.Value)
testutil.Assert(t, cookie.Expires.Before(time.Now()), "cookie expires before now")
}
}
testutil.Assert(t, found, "cookie in header")
})
}
func TestLoadCookiedSession(t *testing.T) {
cookieName := "cookieName"
testCases := []struct {
name string
optFuncs []func(*CookieStore) error
setupCookies func(*testing.T, *http.Request, *CookieStore, *SessionState)
expectedError error
sessionState *SessionState
cipher cryptutil.Cipher
wantErr bool
wantLoadErr bool
}{
{
name: "no cookie set returns an error",
setupCookies: func(*testing.T, *http.Request, *CookieStore, *SessionState) {},
expectedError: http.ErrNoCookie,
},
{
name: "cookie set with cipher set",
optFuncs: []func(*CookieStore) error{CreateCookieCipher(testEncodedCookieSecret)},
setupCookies: func(t *testing.T, req *http.Request, s *CookieStore, sessionState *SessionState) {
value, err := MarshalSession(sessionState, s.CookieCipher)
testutil.Ok(t, err)
req.AddCookie(s.makeSessionCookie(req, value, time.Hour, time.Now()))
},
sessionState: &SessionState{
Email: "example@email.com",
RefreshToken: "abccdddd",
AccessToken: "access",
},
},
{
name: "cookie set with invalid value cipher set",
optFuncs: []func(*CookieStore) error{CreateCookieCipher(testEncodedCookieSecret)},
setupCookies: func(t *testing.T, req *http.Request, s *CookieStore, sessionState *SessionState) {
value := "574b776a7c934d6b9fc42ec63a389f79"
req.AddCookie(s.makeSessionCookie(req, value, time.Hour, time.Now()))
},
expectedError: ErrInvalidSession,
},
{"good",
&SessionState{
AccessToken: "token1234",
RefreshToken: "refresh4321",
LifetimeDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
Email: "user@domain.com",
User: "user",
}, cipher, false, false},
{"bad cipher",
&SessionState{
AccessToken: "token1234",
RefreshToken: "refresh4321",
LifetimeDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
Email: "user@domain.com",
User: "user",
}, mockCipher{}, true, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &CookieStore{
Name: "_pomerium",
CSRFCookieName: "_pomerium_csrf",
CookieSecure: true,
CookieHTTPOnly: true,
CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second,
CookieCipher: tt.cipher}
r := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
if err := s.SaveSession(w, r, tt.sessionState); (err != nil) != tt.wantErr {
t.Errorf("CookieStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
}
r = httptest.NewRequest("GET", "/", nil)
for _, cookie := range w.Result().Cookies() {
t.Log(cookie)
r.AddCookie(cookie)
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
session, err := NewCookieStore(cookieName, tc.optFuncs...)
testutil.Ok(t, err)
req := httptest.NewRequest("GET", "https://www.example.com", nil)
tc.setupCookies(t, req, session, tc.sessionState)
s, err := session.LoadSession(req)
state, err := s.LoadSession(r)
if (err != nil) != tt.wantLoadErr {
t.Errorf("LoadSession() error = %v, wantErr %v", err, tt.wantLoadErr)
return
}
if err == nil && !reflect.DeepEqual(state, tt.sessionState) {
t.Errorf("CookieStore.LoadSession() got = \n%v, want \n%v", state, tt.sessionState)
}
})
}
}
testutil.Equal(t, tc.expectedError, err)
testutil.Equal(t, tc.sessionState, s)
func TestMockCSRFStore(t *testing.T) {
tests := []struct {
name string
mockCSRF *MockCSRFStore
newCSRFValue string
wantErr bool
}{
{"basic",
&MockCSRFStore{
ResponseCSRF: "ok",
Cookie: &http.Cookie{Name: "hi"}},
"newcsrf",
false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ms := tt.mockCSRF
ms.SetCSRF(nil, nil, tt.newCSRFValue)
ms.ClearCSRF(nil, nil)
got, err := ms.GetCSRF(nil)
if (err != nil) != tt.wantErr {
t.Errorf("MockCSRFStore.GetCSRF() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.mockCSRF.Cookie) {
t.Errorf("MockCSRFStore.GetCSRF() = %v, want %v", got, tt.mockCSRF.Cookie)
}
})
}
}
func TestMockSessionStore(t *testing.T) {
tests := []struct {
name string
mockCSRF *MockSessionStore
saveSession *SessionState
wantLoadErr bool
wantSaveErr bool
}{
{"basic",
&MockSessionStore{
ResponseSession: "test",
Session: &SessionState{AccessToken: "AccessToken"},
SaveError: nil,
LoadError: nil,
},
&SessionState{AccessToken: "AccessToken"},
false,
false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ms := tt.mockCSRF
err := ms.SaveSession(nil, nil, tt.saveSession)
if (err != nil) != tt.wantSaveErr {
t.Errorf("MockCSRFStore.GetCSRF() error = %v, wantSaveErr %v", err, tt.wantSaveErr)
return
}
got, err := ms.LoadSession(nil)
if (err != nil) != tt.wantLoadErr {
t.Errorf("MockCSRFStore.GetCSRF() error = %v, wantLoadErr %v", err, tt.wantLoadErr)
return
}
if !reflect.DeepEqual(got, tt.mockCSRF.Session) {
t.Errorf("MockCSRFStore.GetCSRF() = %v, want %v", got, tt.mockCSRF.Session)
}
ms.ClearSession(nil, nil)
if ms.ResponseSession != "" {
t.Errorf("ResponseSession not empty! %s", ms.ResponseSession)
}
})
}
}

View file

@ -35,7 +35,7 @@ type MockSessionStore struct {
}
// ClearSession clears the ResponseSession
func (ms MockSessionStore) ClearSession(http.ResponseWriter, *http.Request) {
func (ms *MockSessionStore) ClearSession(http.ResponseWriter, *http.Request) {
ms.ResponseSession = ""
}

View file

@ -20,11 +20,9 @@ type SessionState struct {
RefreshDeadline time.Time `json:"refresh_deadline"`
LifetimeDeadline time.Time `json:"lifetime_deadline"`
ValidDeadline time.Time `json:"valid_deadline"`
GracePeriodStart time.Time `json:"grace_period_start"`
Email string `json:"email"`
User string `json:"user"` // 'sub' in jwt parlance
User string `json:"user"` // 'sub' in jwt
Groups []string `json:"groups"`
}
@ -38,11 +36,6 @@ func (s *SessionState) RefreshPeriodExpired() bool {
return isExpired(s.RefreshDeadline)
}
// ValidationPeriodExpired returns true if the validation period has expired
func (s *SessionState) ValidationPeriodExpired() bool {
return isExpired(s.ValidDeadline)
}
func isExpired(t time.Time) bool {
return t.Before(time.Now())
}
@ -64,7 +57,7 @@ func UnmarshalSession(value string, c cryptutil.Cipher) (*SessionState, error) {
return s, nil
}
// ExtendDeadline returns the time extended by a given duration
// ExtendDeadline returns the time extended by a given duration, truncated by second
func ExtendDeadline(ttl time.Duration) time.Time {
return time.Now().Add(ttl).Truncate(time.Second)
}

View file

@ -1,4 +1,4 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
package sessions
import (
"reflect"
@ -18,11 +18,8 @@ func TestSessionStateSerialization(t *testing.T) {
want := &SessionState{
AccessToken: "token1234",
RefreshToken: "refresh4321",
LifetimeDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
ValidDeadline: time.Now().Add(1 * time.Minute).Truncate(time.Second).UTC(),
Email: "user@domain.com",
User: "user",
}
@ -48,24 +45,38 @@ func TestSessionStateExpirations(t *testing.T) {
session := &SessionState{
AccessToken: "token1234",
RefreshToken: "refresh4321",
LifetimeDeadline: time.Now().Add(-1 * time.Hour),
RefreshDeadline: time.Now().Add(-1 * time.Hour),
ValidDeadline: time.Now().Add(-1 * time.Minute),
Email: "user@domain.com",
User: "user",
}
if !session.LifetimePeriodExpired() {
t.Errorf("expcted lifetime period to be expired")
t.Errorf("expected lifetime period to be expired")
}
if !session.RefreshPeriodExpired() {
t.Errorf("expcted lifetime period to be expired")
t.Errorf("expected lifetime period to be expired")
}
if !session.ValidationPeriodExpired() {
t.Errorf("expcted lifetime period to be expired")
}
func TestExtendDeadline(t *testing.T) {
// tons of wiggle room here
now := time.Now().Truncate(time.Second)
tests := []struct {
name string
ttl time.Duration
want time.Time
}{
{"Add a few ms", time.Millisecond * 10, now.Truncate(time.Second)},
{"Add a few microsecs", time.Microsecond * 10, now.Truncate(time.Second)},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := ExtendDeadline(tt.ttl); !reflect.DeepEqual(got, tt.want) {
t.Errorf("ExtendDeadline() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -1,75 +0,0 @@
// Original Copyright 2013 The Go Authors. All rights reserved.
//
// Modified by BuzzFeed to return duplicate counts.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package singleflight provides a duplicate function call suppression mechanism.
package singleflight // import "github.com/pomerium/pomerium/internal/singleflight"
import "sync"
// call is an in-flight or completed singleflight.Do call
type call struct {
wg sync.WaitGroup
// These fields are written once before the WaitGroup is done
// and are only read after the WaitGroup is done.
val interface{}
err error
// These fields are read and written with the singleflight
// mutex held before the WaitGroup is done, and are read but
// not written after the WaitGroup is done.
dups int
}
// Group represents a class of work and forms a namespace in
// which units of work can be executed with duplicate suppression.
type Group struct {
mu sync.Mutex // protects m
m map[string]*call // lazily initialized
}
// Result holds the results of Do, so they can be passed
// on a channel.
type Result struct {
Val interface{}
Err error
Count bool
}
// Do executes and returns the results of the given function, making
// sure that only one execution is in-flight for a given key at a
// time. If a duplicate comes in, the duplicate caller waits for the
// original to complete and receives the same results.
// The return value of Count indicates how many tiems v was given to multiple callers.
// Count will be zero for requests are shared and only be non-zero for the originating request.
func (g *Group) Do(key string, fn func() (interface{}, error)) (v interface{}, count int, err error) {
g.mu.Lock()
if g.m == nil {
g.m = make(map[string]*call)
}
if c, ok := g.m[key]; ok {
c.dups++
g.mu.Unlock()
c.wg.Wait()
return c.val, 0, c.err
}
c := new(call)
c.wg.Add(1)
g.m[key] = c
g.mu.Unlock()
c.val, c.err = fn()
c.wg.Done()
g.mu.Lock()
delete(g.m, key)
g.mu.Unlock()
return c.val, c.dups, c.err
}

View file

@ -1,87 +0,0 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package singleflight // import "github.com/pomerium/pomerium/internal/singleflight"
import (
"errors"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
)
func TestDo(t *testing.T) {
var g Group
v, _, err := g.Do("key", func() (interface{}, error) {
return "bar", nil
})
if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want {
t.Errorf("Do = %v; want %v", got, want)
}
if err != nil {
t.Errorf("Do error = %v", err)
}
}
func TestDoErr(t *testing.T) {
var g Group
someErr := errors.New("Some error")
v, _, err := g.Do("key", func() (interface{}, error) {
return nil, someErr
})
if err != someErr {
t.Errorf("Do error = %v; want someErr %v", err, someErr)
}
if v != nil {
t.Errorf("unexpected non-nil value %#v", v)
}
}
func TestDoDupSuppress(t *testing.T) {
var g Group
var wg1, wg2 sync.WaitGroup
c := make(chan string, 1)
var calls int32
fn := func() (interface{}, error) {
if atomic.AddInt32(&calls, 1) == 1 {
// First invocation.
wg1.Done()
}
v := <-c
c <- v // pump; make available for any future calls
time.Sleep(10 * time.Millisecond) // let more goroutines enter Do
return v, nil
}
const n = 10
wg1.Add(1)
for i := 0; i < n; i++ {
wg1.Add(1)
wg2.Add(1)
go func() {
defer wg2.Done()
wg1.Done()
v, _, err := g.Do("key", fn)
if err != nil {
t.Errorf("Do error: %v", err)
return
}
if s, _ := v.(string); s != "bar" {
t.Errorf("Do = %T %v; want %q", v, v, "bar")
}
}()
}
wg1.Wait()
// At least one goroutine is in fn now and all of them have at
// least reached the line before the Do.
c <- "bar"
wg2.Wait()
if got := atomic.LoadInt32(&calls); got <= 0 || got >= n {
t.Errorf("number of calls = %d; want over 0 and less than %d", got, n)
}
}

View file

@ -1,46 +0,0 @@
package testutil // import "github.com/pomerium/pomerium/internal/testutil"
// testing util functions copied from https://github.com/benbjohnson/testing
import (
"fmt"
"path/filepath"
"reflect"
"runtime"
"testing"
)
// Assert fails the test if the condition is false.
func Assert(tb testing.TB, condition bool, msg string, v ...interface{}) {
if !condition {
_, file, line, _ := runtime.Caller(1)
fmt.Printf("\033[31m%s:%d: "+msg+"\033[39m\n\n", append([]interface{}{filepath.Base(file), line}, v...)...)
tb.FailNow()
}
}
// Ok fails the test if an err is not nil.
func Ok(tb testing.TB, err error) {
if err != nil {
_, file, line, _ := runtime.Caller(1)
fmt.Printf("\033[31m%s:%d: unexpected error: %s\033[39m\n\n", filepath.Base(file), line, err.Error())
tb.FailNow()
}
}
// Equal fails the test if exp is not equal to act.
func Equal(tb testing.TB, exp, act interface{}) {
if !reflect.DeepEqual(exp, act) {
_, file, line, _ := runtime.Caller(1)
fmt.Printf("\033[31m%s:%d:\n\n\texp: %#v\n\n\tgot: %#v\033[39m\n\n", filepath.Base(file), line, exp, act)
tb.FailNow()
}
}
// NotEqual fails the test if exp is equal to act.
func NotEqual(tb testing.TB, exp, act interface{}) {
if reflect.DeepEqual(exp, act) {
_, file, line, _ := runtime.Caller(1)
fmt.Printf("\033[31m%s:%d:\n\n\texp: %#v\n\n\tgot: %#v\033[39m\n\n", filepath.Base(file), line, exp, act)
tb.FailNow()
}
}

View file

@ -25,7 +25,7 @@ type Options struct {
// InternalAddr is the internal (behind the ingress) address to use when making an
// authentication connection. If empty, Addr is used.
InternalAddr string
// OverrideServerName overrides the server name used to verify the hostname on the
// OverideCertificateName overrides the server name used to verify the hostname on the
// returned certificates from the server. gRPC internals also use it to override the virtual
// hosting name if it is set.
OverideCertificateName string

View file

@ -203,15 +203,14 @@ func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
}
// We store the session in a cookie and redirect the user back to the application
err = p.sessionStore.SaveSession(w, r, &sessions.SessionState{
err = p.sessionStore.SaveSession(w, r,
&sessions.SessionState{
AccessToken: rr.AccessToken,
RefreshToken: rr.RefreshToken,
IDToken: rr.IDToken,
User: rr.User,
Email: rr.Email,
RefreshDeadline: (rr.Expiry).Truncate(time.Second),
LifetimeDeadline: extendDeadline(p.CookieLifetimeTTL),
ValidDeadline: extendDeadline(p.CookieExpire),
})
if err != nil {
log.FromRequest(r).Error().Msg("error saving session")
@ -250,9 +249,7 @@ func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
}
}
// ! ! !
// todo(bdd): ! Authorization service goes here !
// ! ! !
// todo(bdd): add authorization service validation
// We have validated the users request and now proxy their request to the provided upstream.
route, ok := p.router(r)
@ -278,14 +275,10 @@ func (p *Proxy) Authenticate(w http.ResponseWriter, r *http.Request) (err error)
return err
}
if session.LifetimePeriodExpired() {
log.FromRequest(r).Info().Msg("proxy: lifetime expired")
return sessions.ErrLifetimeExpired
}
if session.RefreshPeriodExpired() {
// AccessToken's usually expire after 60 or so minutes. If offline_access scope is set, a
// refresh token (which doesn't change) can be used to request a new access-token. If access
// is revoked by identity provider, or no refresh token is set request will return an error
// is revoked by identity provider, or no refresh token is set, request will return an error
accessToken, expiry, err := p.AuthenticateClient.Refresh(session.RefreshToken)
if err != nil {
log.FromRequest(r).Warn().

View file

@ -272,7 +272,7 @@ func TestProxy_OAuthCallback(t *testing.T) {
if err != nil {
t.Fatal(err)
}
proxy.sessionStore = tt.session
proxy.sessionStore = &tt.session
proxy.csrfStore = tt.csrf
proxy.AuthenticateClient = tt.authenticator
proxy.cipher = mockCipher{}
@ -352,12 +352,6 @@ func TestProxy_Proxy(t *testing.T) {
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}
expiredLifetime := &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(-10 * time.Second),
}
tests := []struct {
@ -368,11 +362,10 @@ func TestProxy_Proxy(t *testing.T) {
wantStatus int
}{
// weirdly, we want 503 here because that means proxy is trying to route a domain (example.com) that we dont control. Weird. I know.
{"good", "https://corp.example.com/test", sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, http.StatusServiceUnavailable},
{"unexpected error", "https://corp.example.com/test", sessions.MockSessionStore{LoadError: errors.New("ok")}, authenticator.MockAuthenticate{}, http.StatusInternalServerError},
{"good", "https://corp.example.com/test", &sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, http.StatusServiceUnavailable},
{"unexpected error", "https://corp.example.com/test", &sessions.MockSessionStore{LoadError: errors.New("ok")}, authenticator.MockAuthenticate{}, http.StatusInternalServerError},
// redirect to start auth process
{"expired lifetime", "https://corp.example.com/test", sessions.MockSessionStore{Session: expiredLifetime}, authenticator.MockAuthenticate{}, http.StatusFound},
{"unknown host", "https://notcorp.example.com/test", sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, http.StatusNotFound},
{"unknown host", "https://notcorp.example.com/test", &sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, http.StatusNotFound},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -402,13 +395,8 @@ func TestProxy_Authenticate(t *testing.T) {
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}
expiredLifetime := &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(-10 * time.Second),
}
expiredDeadline := &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
@ -426,25 +414,21 @@ func TestProxy_Authenticate(t *testing.T) {
{"cannot save session",
"https://corp.example.com/",
map[string]string{"corp.example.com": "example.com"},
sessions.MockSessionStore{Session: goodSession, SaveError: errors.New("error")},
&sessions.MockSessionStore{Session: goodSession, SaveError: errors.New("error")},
authenticator.MockAuthenticate{}, true},
{"cannot load session",
"https://corp.example.com/",
map[string]string{"corp.example.com": "example.com"},
sessions.MockSessionStore{LoadError: errors.New("error")}, authenticator.MockAuthenticate{}, true},
{"expired lifetime",
"https://corp.example.com/",
map[string]string{"corp.example.com": "example.com"},
sessions.MockSessionStore{Session: expiredLifetime}, authenticator.MockAuthenticate{}, true},
&sessions.MockSessionStore{LoadError: errors.New("error")}, authenticator.MockAuthenticate{}, true},
{"expired session",
"https://corp.example.com/",
map[string]string{"corp.example.com": "example.com"},
sessions.MockSessionStore{Session: expiredDeadline}, authenticator.MockAuthenticate{}, false},
&sessions.MockSessionStore{Session: expiredDeadline}, authenticator.MockAuthenticate{}, false},
{"bad refresh authenticator",
"https://corp.example.com/",
map[string]string{"corp.example.com": "example.com"},
sessions.MockSessionStore{
&sessions.MockSessionStore{
Session: expiredDeadline,
},
authenticator.MockAuthenticate{RefreshError: errors.New("error")},
@ -453,7 +437,7 @@ func TestProxy_Authenticate(t *testing.T) {
{"good",
"https://corp.example.com/",
map[string]string{"corp.example.com": "example.com"},
sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, false},
&sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

View file

@ -133,9 +133,6 @@ type Proxy struct {
AuthenticateClient authenticator.Authenticator
// session
CookieExpire time.Duration
CookieRefresh time.Duration
CookieLifetimeTTL time.Duration
cipher cryptutil.Cipher
csrfStore sessions.CSRFStore
sessionStore sessions.SessionStore
@ -163,13 +160,14 @@ func New(opts *Options) (*Proxy, error) {
return nil, fmt.Errorf("cookie-secret error: %s", err.Error())
}
cookieStore, err := sessions.NewCookieStore(opts.CookieName,
sessions.CreateCookieCipher(decodedSecret),
func(c *sessions.CookieStore) error {
c.CookieDomain = opts.CookieDomain
c.CookieHTTPOnly = opts.CookieHTTPOnly
c.CookieExpire = opts.CookieExpire
return nil
cookieStore, err := sessions.NewCookieStore(
&sessions.CookieStoreOptions{
Name: opts.CookieName,
CookieDomain: opts.CookieDomain,
CookieSecure: opts.CookieSecure,
CookieHTTPOnly: opts.CookieHTTPOnly,
CookieExpire: opts.CookieExpire,
CookieCipher: cipher,
})
if err != nil {
@ -187,8 +185,6 @@ func New(opts *Options) (*Proxy, error) {
SharedKey: opts.SharedKey,
redirectURL: &url.URL{Path: "/.pomerium/callback"},
templates: templates.New(),
CookieExpire: opts.CookieExpire,
CookieLifetimeTTL: opts.CookieLifetimeTTL,
}
for from, to := range opts.Routes {
@ -200,7 +196,7 @@ func New(opts *Options) (*Proxy, error) {
return nil, err
}
p.Handle(fromURL.Host, handler)
log.Info().Str("from", fromURL.Host).Str("to", toURL.String()).Msg("proxy.New: new route")
log.Info().Str("from", fromURL.Host).Str("to", toURL.String()).Msg("proxy: new route")
}
p.AuthenticateClient, err = authenticator.New(

View file

@ -139,6 +139,7 @@ func testOptions() *Options {
AuthenticateURL: authurl,
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
CookieSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=",
CookieName: "pomerium",
}
}