all: general cleanup readying for tagged release (#48)

- docs: add code coverage to readme
- internal/sessions: refactor sessions to clarify lifetime
- authenticate: simplified signin flow
- deployment: update go mods
- internal/testutil: removed package
- internal/singleflight: removed package
This commit is contained in:
Bobby DeSimone 2019-02-16 12:43:18 -08:00 committed by GitHub
parent 13c03a2b5c
commit dbafc691c3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
25 changed files with 712 additions and 1017 deletions

View file

@ -4,24 +4,24 @@
# Pomerium
[![Travis CI](https://travis-ci.org/pomerium/pomerium.svg?branch=master)](https://travis-ci.org/pomerium/pomerium) [![Go Report Card](https://goreportcard.com/badge/github.com/pomerium/pomerium)](https://goreportcard.com/report/github.com/pomerium/pomerium) [![GoDoc](https://godoc.org/github.com/pomerium/pomerium?status.svg)][godocs] [![LICENSE](https://img.shields.io/github/license/pomerium/pomerium.svg)](https://github.com/pomerium/pomerium/blob/master/LICENSE)[![codecov](https://codecov.io/gh/pomerium/pomerium/branch/master/graph/badge.svg)](https://codecov.io/gh/pomerium/pomerium)
[![Travis CI](https://travis-ci.org/pomerium/pomerium.svg?branch=master)](https://travis-ci.org/pomerium/pomerium) [![Go Report Card](https://goreportcard.com/badge/github.com/pomerium/pomerium)](https://goreportcard.com/report/github.com/pomerium/pomerium) [![GoDoc](https://godoc.org/github.com/pomerium/pomerium?status.svg)][godocs] [![LICENSE](https://img.shields.io/github/license/pomerium/pomerium.svg)](https://github.com/pomerium/pomerium/blob/master/LICENSE)[![codecov](https://img.shields.io/codecov/c/github/pomerium/pomerium.svg?style=flat)](https://codecov.io/gh/pomerium/pomerium)
Pomerium is a tool for managing secure access to internal applications and resources.
Use Pomerium to:
- provide a unified gateway (reverse-proxy) to internal corporate applications.
- enforce dynamic access policy based on context, identity, and device state.
- deploy mutual authenticated encryption (mTLS).
- aggregate logging and telemetry data.
- provide a single-sign-on gateway to internal applications.
- enforce dynamic access policy based on **context**, **identity**, and **device state**.
- aggregate access logs and telemetry data.
- an alternative to a VPN.
Check out [awesome-zero-trust] to learn more about some problems Pomerium attempts to address.
Check out [awesome-zero-trust] to learn more about some of the problems Pomerium attempts to address.
## Docs
To get started with pomerium, check out our [quick start guide].
For comprehensive docs see our [documentation] and the [godocs].
For comprehensive docs, and tutorials see our [documentation] and the [godocs].
[awesome-zero-trust]: https://github.com/pomerium/awesome-zero-trust
[documentation]: https://www.pomerium.io/docs/

View file

@ -18,37 +18,38 @@ import (
)
var defaultOptions = &Options{
CookieName: "_pomerium_authenticate",
CookieHTTPOnly: true,
CookieSecure: true,
CookieExpire: time.Duration(168) * time.Hour,
CookieRefresh: time.Duration(30) * time.Minute,
CookieLifetimeTTL: time.Duration(720) * time.Hour,
CookieName: "_pomerium_authenticate",
CookieHTTPOnly: true,
CookieSecure: true,
CookieExpire: time.Duration(168) * time.Hour,
CookieRefresh: time.Duration(30) * time.Minute,
}
// Options details the available configuration settings for the authenticate service
type Options struct {
RedirectURL *url.URL `envconfig:"REDIRECT_URL"`
// SharedKey is used to authenticate requests between services
SharedKey string `envconfig:"SHARED_SECRET"`
// RedirectURL specifies the callback url following third party authentication
RedirectURL *url.URL `envconfig:"REDIRECT_URL"`
// Coarse authorization based on user email domain
// todo(bdd) : to be replaced with authorization module
AllowedDomains []string `envconfig:"ALLOWED_DOMAINS"`
ProxyRootDomains []string `envconfig:"PROXY_ROOT_DOMAIN"`
// Session/Cookie management
CookieName string
CookieSecret string `envconfig:"COOKIE_SECRET"`
CookieDomain string `envconfig:"COOKIE_DOMAIN"`
CookieSecure bool `envconfig:"COOKIE_SECURE"`
CookieHTTPOnly bool `envconfig:"COOKIE_HTTP_ONLY"`
CookieExpire time.Duration `envconfig:"COOKIE_EXPIRE"`
CookieRefresh time.Duration `envconfig:"COOKIE_REFRESH"`
CookieLifetimeTTL time.Duration `envconfig:"COOKIE_LIFETIME"`
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie
CookieName string
CookieSecret string `envconfig:"COOKIE_SECRET"`
CookieDomain string `envconfig:"COOKIE_DOMAIN"`
CookieSecure bool `envconfig:"COOKIE_SECURE"`
CookieHTTPOnly bool `envconfig:"COOKIE_HTTP_ONLY"`
CookieExpire time.Duration `envconfig:"COOKIE_EXPIRE"`
CookieRefresh time.Duration `envconfig:"COOKIE_REFRESH"`
// IdentityProvider provider configuration variables as specified by RFC6749
// See: https://openid.net/specs/openid-connect-basic-1_0.html#RFC6749
// https://openid.net/specs/openid-connect-basic-1_0.html#RFC6749
ClientID string `envconfig:"IDP_CLIENT_ID"`
ClientSecret string `envconfig:"IDP_CLIENT_SECRET"`
Provider string `envconfig:"IDP_PROVIDER"`
@ -103,17 +104,13 @@ func (o *Options) Validate() error {
// Authenticate validates a user's identity
type Authenticate struct {
RedirectURL *url.URL
Validator func(string) bool
AllowedDomains []string
ProxyRootDomains []string
CookieSecure bool
SharedKey string
CookieLifetimeTTL time.Duration
RedirectURL *url.URL
AllowedDomains []string
ProxyRootDomains []string
Validator func(string) bool
templates *template.Template
csrfStore sessions.CSRFStore
@ -137,37 +134,47 @@ func New(opts *Options, optionFuncs ...func(*Authenticate) error) (*Authenticate
if err != nil {
return nil, err
}
cookieStore, err := sessions.NewCookieStore(opts.CookieName,
sessions.CreateCookieCipher(decodedCookieSecret),
func(c *sessions.CookieStore) error {
c.CookieDomain = opts.CookieDomain
c.CookieHTTPOnly = opts.CookieHTTPOnly
c.CookieExpire = opts.CookieExpire
c.CookieSecure = opts.CookieSecure
return nil
cookieStore, err := sessions.NewCookieStore(
&sessions.CookieStoreOptions{
Name: opts.CookieName,
CookieSecure: opts.CookieSecure,
CookieHTTPOnly: opts.CookieHTTPOnly,
CookieExpire: opts.CookieExpire,
CookieCipher: cipher,
})
if err != nil {
return nil, err
}
p := &Authenticate{
SharedKey: opts.SharedKey,
AllowedDomains: opts.AllowedDomains,
ProxyRootDomains: dotPrependDomains(opts.ProxyRootDomains),
CookieSecure: opts.CookieSecure,
RedirectURL: opts.RedirectURL,
templates: templates.New(),
csrfStore: cookieStore,
sessionStore: cookieStore,
cipher: cipher,
}
p.provider, err = newProvider(opts)
provider, err := providers.New(
opts.Provider,
&providers.IdentityProvider{
RedirectURL: opts.RedirectURL,
ProviderName: opts.Provider,
ProviderURL: opts.ProviderURL,
ClientID: opts.ClientID,
ClientSecret: opts.ClientSecret,
// SessionLifetimeTTL: opts.CookieLifetimeTTL,
Scopes: opts.Scopes,
})
if err != nil {
return nil, err
}
p := &Authenticate{
SharedKey: opts.SharedKey,
RedirectURL: opts.RedirectURL,
AllowedDomains: opts.AllowedDomains,
ProxyRootDomains: dotPrependDomains(opts.ProxyRootDomains),
templates: templates.New(),
csrfStore: cookieStore,
sessionStore: cookieStore,
cipher: cipher,
provider: provider,
}
// validation via dependency injected function
for _, optFunc := range optionFuncs {
err := optFunc(p)
@ -179,20 +186,6 @@ func New(opts *Options, optionFuncs ...func(*Authenticate) error) (*Authenticate
return p, nil
}
func newProvider(opts *Options) (providers.Provider, error) {
pd := &providers.IdentityProvider{
RedirectURL: opts.RedirectURL,
ProviderName: opts.Provider,
ProviderURL: opts.ProviderURL,
ClientID: opts.ClientID,
ClientSecret: opts.ClientSecret,
SessionLifetimeTTL: opts.CookieLifetimeTTL,
Scopes: opts.Scopes,
}
np, err := providers.New(opts.Provider, pd)
return np, err
}
func dotPrependDomains(d []string) []string {
for i := range d {
if d[i] != "" && !strings.HasPrefix(d[i], ".") {

View file

@ -11,16 +11,17 @@ import (
func testOptions() *Options {
redirectURL, _ := url.Parse("https://example.com/oauth2/callback")
return &Options{
ProxyRootDomains: []string{"example.com"},
AllowedDomains: []string{"example.com"},
RedirectURL: redirectURL,
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
ClientID: "test-client-id",
ClientSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=",
CookieSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=",
CookieRefresh: time.Duration(1) * time.Hour,
CookieLifetimeTTL: time.Duration(720) * time.Hour,
CookieExpire: time.Duration(168) * time.Hour,
ProxyRootDomains: []string{"example.com"},
AllowedDomains: []string{"example.com"},
RedirectURL: redirectURL,
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
ClientID: "test-client-id",
ClientSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=",
CookieSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=",
CookieRefresh: time.Duration(1) * time.Hour,
// CookieLifetimeTTL: time.Duration(720) * time.Hour,
CookieExpire: time.Duration(168) * time.Hour,
CookieName: "pomerium",
}
}
@ -130,37 +131,6 @@ func Test_dotPrependDomains(t *testing.T) {
}
}
func Test_newProvider(t *testing.T) {
redirectURL, _ := url.Parse("https://example.com/oauth3/callback")
goodOpts := &Options{
RedirectURL: redirectURL,
Provider: "google",
ProviderURL: "",
ClientID: "cllient-id",
ClientSecret: "client-secret",
}
tests := []struct {
name string
opts *Options
wantErr bool
}{
{"good", goodOpts, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := newProvider(tt.opts)
if (err != nil) != tt.wantErr {
t.Errorf("newProvider() error = %v, wantErr %v", err, tt.wantErr)
return
}
// if !reflect.DeepEqual(got, tt.want) {
// t.Errorf("newProvider() = %v, want %v", got, tt.want)
// }
})
}
}
func TestNew(t *testing.T) {
good := testOptions()
good.Provider = "google"

View file

@ -117,7 +117,6 @@ func TestAuthenticate_Authenticate(t *testing.T) {
}
lt := time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC()
rt := time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC()
vt := time.Now().Add(1 * time.Minute).Truncate(time.Second).UTC()
vtProto, err := ptypes.TimestampProto(rt)
if err != nil {
t.Fatal("failed to parse timestamp")
@ -128,9 +127,9 @@ func TestAuthenticate_Authenticate(t *testing.T) {
RefreshToken: "refresh4321",
LifetimeDeadline: lt,
RefreshDeadline: rt,
ValidDeadline: vt,
Email: "user@domain.com",
User: "user",
Email: "user@domain.com",
User: "user",
}
goodReply := &pb.AuthenticateReply{

View file

@ -16,7 +16,8 @@ import (
"github.com/pomerium/pomerium/internal/version"
)
// securityHeaders corresponds to HTTP response headers related to security.
// securityHeaders corresponds to HTTP response headers that help to protect against protocol
// downgrade attacks and cookie hijacking.
// https://www.owasp.org/index.php/OWASP_Secure_Headers_Project#tab=Headers
var securityHeaders = map[string]string{
"Strict-Transport-Security": "max-age=31536000",
@ -28,7 +29,7 @@ var securityHeaders = map[string]string{
"Referrer-Policy": "Same-origin",
}
// Handler returns the Http.Handlers for authenticate, callback, and refresh
// Handler returns the authenticate service's HTTP request multiplexer, and routes.
func (a *Authenticate) Handler() http.Handler {
// set up our standard middlewares
stdMiddleware := middleware.NewChain()
@ -80,12 +81,6 @@ func (a *Authenticate) authenticate(w http.ResponseWriter, r *http.Request) (*se
return nil, err
}
// if long-lived lifetime has expired, clear session
if session.LifetimePeriodExpired() {
log.FromRequest(r).Warn().Msg("authenticate: lifetime expired")
a.sessionStore.ClearSession(w, r)
return nil, sessions.ErrLifetimeExpired
}
// check if session refresh period is up
if session.RefreshPeriodExpired() {
newToken, err := a.provider.Refresh(session.RefreshToken)
@ -130,32 +125,23 @@ func (a *Authenticate) authenticate(w http.ResponseWriter, r *http.Request) (*se
return session, nil
}
// SignIn handles the /sign_in endpoint. It attempts to authenticate the user,
// SignIn handles the sign_in endpoint. It attempts to authenticate the user,
// and if the user is not authenticated, it renders a sign in page.
func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
session, err := a.authenticate(w, r)
switch err {
case nil:
// session good, redirect back to proxy
log.FromRequest(r).Info().Msg("authenticate.SignIn : authenticated")
a.ProxyCallback(w, r, session)
case http.ErrNoCookie, sessions.ErrLifetimeExpired, sessions.ErrInvalidSession:
// session invalid, authenticate
log.FromRequest(r).Info().Err(err).Msg("authenticate.SignIn : expected failure")
if err != http.ErrNoCookie {
a.sessionStore.ClearSession(w, r)
}
if err != nil {
log.FromRequest(r).Info().Err(err).Msg("authenticate: authenticate error")
a.sessionStore.ClearSession(w, r)
a.OAuthStart(w, r)
default:
log.Error().Err(err).Msg("authenticate: unexpected sign in error")
httputil.ErrorResponse(w, r, err.Error(), httputil.CodeForError(err))
}
log.FromRequest(r).Info().Msg("authenticate: user authenticated")
a.ProxyCallback(w, r, session)
}
// ProxyCallback redirects the user back to proxy service along with an encrypted payload, as
// url params, of the user's session state.
// See RFC6749 3.1.2 https://tools.ietf.org/html/rfc6749#section-3.1.2
// url params, of the user's session state as specified in RFC6749 3.1.2.
// https://tools.ietf.org/html/rfc6749#section-3.1.2
func (a *Authenticate) ProxyCallback(w http.ResponseWriter, r *http.Request, session *sessions.SessionState) {
err := r.ParseForm()
if err != nil {
@ -201,9 +187,8 @@ func getAuthCodeRedirectURL(redirectURL *url.URL, state, authCode string) string
return u.String()
}
// SignOut signs the user out by trying to revoke the users remote identity provider session
// then removes the associated local session state.
// Handles both GET and POST of form.
// SignOut signs the user out by trying to revoke the user's remote identity session along with
// the associated local session state. Handles both GET and POST.
func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
err := r.ParseForm()
if err != nil {
@ -256,8 +241,8 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, redirectURI, http.StatusFound)
}
// OAuthStart starts the authenticate process by redirecting to the provider. It provides a
// `redirectURI`, allowing the provider to redirect back to the sso proxy after authenticate.
// OAuthStart starts the authenticate process by redirecting to the identity provider.
// https://tools.ietf.org/html/rfc6749#section-4.2.1
func (a *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) {
authRedirectURL, err := url.Parse(r.URL.Query().Get("redirect_uri"))
if err != nil {
@ -298,7 +283,7 @@ func (a *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) {
}
// OAuthCallback handles the callback from the identity provider. Displays an error page if there
// was an error. If successful, redirects back to the proxy-service via the redirect-url.
// was an error. If successful, the user is redirected back to the proxy-service.
func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) {
redirect, err := a.getOAuthCallback(w, r)
switch h := err.(type) {

View file

@ -71,29 +71,19 @@ func TestAuthenticate_Handler(t *testing.T) {
func TestAuthenticate_authenticate(t *testing.T) {
// sessions.MockSessionStore{Session: expiredLifetime}
goodSession := sessions.MockSessionStore{
goodSession := &sessions.MockSessionStore{
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
RefreshDeadline: time.Now().Add(10 * time.Second),
}}
expiredSession := sessions.MockSessionStore{
expiredRefresPeriod := &sessions.MockSessionStore{
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * -time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}}
expiredRefresPeriod := sessions.MockSessionStore{
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * -time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
RefreshDeadline: time.Now().Add(10 * -time.Second),
}}
tests := []struct {
@ -106,18 +96,16 @@ func TestAuthenticate_authenticate(t *testing.T) {
}{
{"good", goodSession, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, false},
{"good but fails validation", goodSession, providers.MockProvider{ValidateResponse: true}, falseValidator, nil, true},
{"can't load session", sessions.MockSessionStore{LoadError: errors.New("error")}, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, true},
{"can't load session", &sessions.MockSessionStore{LoadError: errors.New("error")}, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, true},
{"validation fails", goodSession, providers.MockProvider{ValidateResponse: false}, trueValidator, nil, true},
{"session fails after good validation", sessions.MockSessionStore{
{"session fails after good validation", &sessions.MockSessionStore{
SaveError: errors.New("error"),
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}}, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, true},
{"lifetime expired", expiredSession, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, true},
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
RefreshDeadline: time.Now().Add(10 * time.Second),
}}, providers.MockProvider{ValidateResponse: true},
trueValidator, nil, true},
{"refresh expired",
expiredRefresPeriod,
providers.MockProvider{
@ -136,14 +124,13 @@ func TestAuthenticate_authenticate(t *testing.T) {
},
trueValidator, nil, true},
{"refresh expired failed save",
sessions.MockSessionStore{
&sessions.MockSessionStore{
SaveError: errors.New("error"),
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * -time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
RefreshDeadline: time.Now().Add(10 * -time.Second),
}},
providers.MockProvider{
ValidateResponse: true,
@ -182,29 +169,23 @@ func TestAuthenticate_SignIn(t *testing.T) {
wantCode int
}{
{"good",
sessions.MockSessionStore{
&sessions.MockSessionStore{
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
RefreshDeadline: time.Now().Add(10 * time.Second),
}},
providers.MockProvider{ValidateResponse: true},
trueValidator,
403},
// {"no session",
// sessions.MockSessionStore{
// Session: &sessions.SessionState{
// AccessToken: "AccessToken",
// RefreshToken: "RefreshToken",
// LifetimeDeadline: time.Now().Add(-10 * time.Second),
// RefreshDeadline: time.Now().Add(10 * time.Second),
// ValidDeadline: time.Now().Add(10 * time.Second),
// }},
// providers.MockProvider{ValidateResponse: true},
// trueValidator,
// 200},
http.StatusForbidden},
{"session fails after good validation", &sessions.MockSessionStore{
SaveError: errors.New("error"),
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
RefreshDeadline: time.Now().Add(10 * time.Second),
}}, providers.MockProvider{ValidateResponse: true},
trueValidator, http.StatusBadRequest},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -212,6 +193,10 @@ func TestAuthenticate_SignIn(t *testing.T) {
sessionStore: tt.session,
provider: tt.provider,
Validator: tt.validator,
RedirectURL: uriParse("http://www.pomerium.io"),
csrfStore: &sessions.MockCSRFStore{},
SharedKey: "secret",
cipher: mockCipher{},
}
r := httptest.NewRequest("GET", "/sign-in", nil)
w := httptest.NewRecorder()
@ -262,13 +247,12 @@ func TestAuthenticate_ProxyCallback(t *testing.T) {
}{
{"good", "https://corp.pomerium.io/", "state", "code",
&sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
RefreshDeadline: time.Now().Add(10 * time.Second),
},
sessions.MockSessionStore{},
&sessions.MockSessionStore{},
302,
"<a href=\"https://corp.pomerium.io/?code=ok&amp;state=state\">Found</a>."},
{"no state",
@ -276,13 +260,12 @@ func TestAuthenticate_ProxyCallback(t *testing.T) {
"",
"code",
&sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
RefreshDeadline: time.Now().Add(10 * time.Second),
},
sessions.MockSessionStore{},
&sessions.MockSessionStore{},
403,
"no state parameter supplied"},
{"no redirect_url",
@ -290,13 +273,12 @@ func TestAuthenticate_ProxyCallback(t *testing.T) {
"state",
"code",
&sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
RefreshDeadline: time.Now().Add(10 * time.Second),
},
sessions.MockSessionStore{},
&sessions.MockSessionStore{},
403,
"no redirect_uri parameter"},
{"malformed redirect_url",
@ -304,13 +286,12 @@ func TestAuthenticate_ProxyCallback(t *testing.T) {
"state",
"code",
&sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
RefreshDeadline: time.Now().Add(10 * time.Second),
},
sessions.MockSessionStore{},
&sessions.MockSessionStore{},
400,
"malformed redirect_uri"},
}
@ -389,14 +370,13 @@ func TestAuthenticate_SignOut(t *testing.T) {
"sig",
"ts",
providers.MockProvider{},
sessions.MockSessionStore{
&sessions.MockSessionStore{
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
RefreshDeadline: time.Now().Add(10 * time.Second),
},
},
http.StatusFound,
@ -407,14 +387,13 @@ func TestAuthenticate_SignOut(t *testing.T) {
"sig",
"ts",
providers.MockProvider{RevokeError: errors.New("OH NO")},
sessions.MockSessionStore{
&sessions.MockSessionStore{
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
RefreshDeadline: time.Now().Add(10 * time.Second),
},
},
http.StatusBadRequest,
@ -426,14 +405,13 @@ func TestAuthenticate_SignOut(t *testing.T) {
"sig",
"ts",
providers.MockProvider{},
sessions.MockSessionStore{
&sessions.MockSessionStore{
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
RefreshDeadline: time.Now().Add(10 * time.Second),
},
},
http.StatusOK,
@ -444,15 +422,14 @@ func TestAuthenticate_SignOut(t *testing.T) {
"sig",
"ts",
providers.MockProvider{},
sessions.MockSessionStore{
&sessions.MockSessionStore{
LoadError: errors.New("uh oh"),
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
RefreshDeadline: time.Now().Add(10 * time.Second),
},
},
http.StatusBadRequest,
@ -463,14 +440,13 @@ func TestAuthenticate_SignOut(t *testing.T) {
"sig",
"ts",
providers.MockProvider{},
sessions.MockSessionStore{
&sessions.MockSessionStore{
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
RefreshDeadline: time.Now().Add(10 * time.Second),
},
},
http.StatusBadRequest,
@ -512,7 +488,6 @@ func redirectURLSignature(rawRedirect string, timestamp time.Time, secret string
}
func TestAuthenticate_OAuthStart(t *testing.T) {
tests := []struct {
name string
method string
@ -634,15 +609,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
[]string{"pomerium.io"},
trueValidator,
sessions.MockSessionStore{},
&sessions.MockSessionStore{},
providers.MockProvider{
AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
RefreshDeadline: time.Now().Add(10 * time.Second),
}},
sessions.MockCSRFStore{
ResponseCSRF: "csrf",
@ -657,15 +631,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
[]string{"pomerium.io"},
trueValidator,
sessions.MockSessionStore{},
&sessions.MockSessionStore{},
providers.MockProvider{
AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
RefreshDeadline: time.Now().Add(10 * time.Second),
}},
sessions.MockCSRFStore{
ResponseCSRF: "csrf",
@ -681,15 +654,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
[]string{"pomerium.io"},
trueValidator,
sessions.MockSessionStore{},
&sessions.MockSessionStore{},
providers.MockProvider{
AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
RefreshDeadline: time.Now().Add(10 * time.Second),
}},
sessions.MockCSRFStore{
ResponseCSRF: "csrf",
@ -704,7 +676,7 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
[]string{"pomerium.io"},
trueValidator,
sessions.MockSessionStore{},
&sessions.MockSessionStore{},
providers.MockProvider{
AuthenticateError: errors.New("error"),
},
@ -721,15 +693,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
[]string{"pomerium.io"},
trueValidator,
sessions.MockSessionStore{SaveError: errors.New("error")},
&sessions.MockSessionStore{SaveError: errors.New("error")},
providers.MockProvider{
AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
RefreshDeadline: time.Now().Add(10 * time.Second),
}},
sessions.MockCSRFStore{
ResponseCSRF: "csrf",
@ -744,15 +715,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
[]string{"pomerium.io"},
falseValidator,
sessions.MockSessionStore{},
&sessions.MockSessionStore{},
providers.MockProvider{
AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
RefreshDeadline: time.Now().Add(10 * time.Second),
}},
sessions.MockCSRFStore{
ResponseCSRF: "csrf",
@ -768,15 +738,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
[]string{"pomerium.io"},
trueValidator,
sessions.MockSessionStore{},
&sessions.MockSessionStore{},
providers.MockProvider{
AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
RefreshDeadline: time.Now().Add(10 * time.Second),
}},
sessions.MockCSRFStore{
ResponseCSRF: "csrf",
@ -791,15 +760,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
[]string{"pomerium.io"},
trueValidator,
sessions.MockSessionStore{},
&sessions.MockSessionStore{},
providers.MockProvider{
AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
RefreshDeadline: time.Now().Add(10 * time.Second),
}},
sessions.MockCSRFStore{
ResponseCSRF: "csrf",
@ -814,15 +782,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
"nonce:https://corp.pomerium.io",
[]string{"pomerium.io"},
trueValidator,
sessions.MockSessionStore{},
&sessions.MockSessionStore{},
providers.MockProvider{
AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
RefreshDeadline: time.Now().Add(10 * time.Second),
}},
sessions.MockCSRFStore{
ResponseCSRF: "csrf",
@ -837,15 +804,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
base64.URLEncoding.EncodeToString([]byte("nonce")),
[]string{"pomerium.io"},
trueValidator,
sessions.MockSessionStore{},
&sessions.MockSessionStore{},
providers.MockProvider{
AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
RefreshDeadline: time.Now().Add(10 * time.Second),
}},
sessions.MockCSRFStore{
ResponseCSRF: "csrf",
@ -860,15 +826,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
base64.URLEncoding.EncodeToString([]byte("nonce:corp.pomerium.io")),
[]string{"pomerium.io"},
trueValidator,
sessions.MockSessionStore{},
&sessions.MockSessionStore{},
providers.MockProvider{
AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
RefreshDeadline: time.Now().Add(10 * time.Second),
}},
sessions.MockCSRFStore{
ResponseCSRF: "csrf",

View file

@ -1,5 +1,5 @@
// Package providers implements OpenID Connect client logic for the set of supported identity
// providers.
// OpenID Connect 1.0 is a simple identity layer on top of the OAuth 2.0 RFC6749 protocol.
// https://openid.net/specs/openid-connect-core-1_0.html
// Package providers authentication for third party identity providers (IdP) using OpenID
// Connect, an identity layer on top of the OAuth 2.0 RFC6749 protocol.
//
// see: https://openid.net/specs/openid-connect-core-1_0.html
package providers // import "github.com/pomerium/pomerium/internal/providers"

View file

@ -2,7 +2,6 @@ package providers // import "github.com/pomerium/pomerium/internal/providers"
import (
"context"
"errors"
oidc "github.com/pomerium/go-oidc"
"golang.org/x/oauth2"
@ -19,7 +18,7 @@ type OIDCProvider struct {
func NewOIDCProvider(p *IdentityProvider) (*OIDCProvider, error) {
ctx := context.Background()
if p.ProviderURL == "" {
return nil, errors.New("missing required provider url")
return nil, ErrMissingProviderURL
}
var err error
p.provider, err = oidc.NewProvider(ctx, p.ProviderURL)

View file

@ -2,7 +2,6 @@ package providers // import "github.com/pomerium/pomerium/internal/providers"
import (
"context"
"errors"
"net/url"
oidc "github.com/pomerium/go-oidc"
@ -25,7 +24,7 @@ type OktaProvider struct {
func NewOktaProvider(p *IdentityProvider) (*OktaProvider, error) {
ctx := context.Background()
if p.ProviderURL == "" {
return nil, errors.New("missing required provider url")
return nil, ErrMissingProviderURL
}
var err error
p.provider, err = oidc.NewProvider(ctx, p.ProviderURL)

View file

@ -29,6 +29,11 @@ const (
OktaProviderName = "okta"
)
var (
// ErrMissingProviderURL is returned when the CB state is half open and the requests count is over the cb maxRequests
ErrMissingProviderURL = errors.New("proxy/providers: missing provider url")
)
// Provider is an interface exposing functions necessary to interact with a given provider.
type Provider interface {
Authenticate(string) (*sessions.SessionState, error)

17
go.mod
View file

@ -5,18 +5,17 @@ require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/golang/mock v1.2.0
github.com/golang/protobuf v1.2.0
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/pomerium/envconfig v1.3.1-0.20190112072701-14cbcf832d31
github.com/pomerium/go-oidc v2.0.0+incompatible
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
github.com/rs/zerolog v1.11.0
github.com/stretchr/testify v1.2.2 // indirect
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9
golang.org/x/net v0.0.0-20181220203305-927f97764cc3
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 // indirect
golang.org/x/sys v0.0.0-20190116161447-11f53e031339 // indirect
google.golang.org/appengine v1.4.0 // indirect
github.com/stretchr/testify v1.3.0 // indirect
golang.org/x/crypto v0.0.0-20190211182817-74369b46fc67
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd
golang.org/x/oauth2 v0.0.0-20190212230446-3e8b2be13635
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a // indirect
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2 // indirect
google.golang.org/genproto v0.0.0-20190215211957-bd968387e4aa // indirect
google.golang.org/grpc v1.18.0
gopkg.in/square/go-jose.v2 v2.2.1
gopkg.in/square/go-jose.v2 v2.2.2
)

38
go.sum
View file

@ -1,11 +1,14 @@
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
github.com/benbjohnson/clock v0.0.0-20161215174838-7dc76406b6d3 h1:wOysYcIdqv3WnvwqFFzrYCFALPED7qkUGaLXu359GSc=
github.com/benbjohnson/clock v0.0.0-20161215174838-7dc76406b6d3/go.mod h1:UMqtWQTnOe4byzwe7Zhwh8f8s+36uszN51sJrSIZlTE=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E=
github.com/golang/mock v1.1.1 h1:G5FRp8JnTd7RQH5kemVNlMeyXQAztQ3mOWV95KxsXH8=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/mock v1.2.0 h1:28o5sBqPkBsMGnC6b4MvE2TzSr5/AT4c/1fLqVGIwlk=
@ -23,34 +26,45 @@ github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAm
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35/go.mod h1:prYjPmNq4d1NPVmpShWobRqXY3q7Vp+80DqgxxUrUIA=
github.com/rs/zerolog v1.11.0 h1:DRuq/S+4k52uJzBQciUcofXx45GrMC6yrEbb/CoK6+M=
github.com/rs/zerolog v1.11.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU=
github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9 h1:mKdxBk7AujPs8kU4m80U72y/zjbZ3UcXC7dClwKbUI0=
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
golang.org/x/crypto v0.0.0-20190211182817-74369b46fc67 h1:ng3VDlRp5/DHpSWl02R4rM9I+8M2rhmsuLwAMmkLQWE=
golang.org/x/crypto v0.0.0-20190211182817-74369b46fc67/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20181220203305-927f97764cc3 h1:eH6Eip3UpmR+yM/qI9Ijluzb1bNv/cAU/n+6l8tRSis=
golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20181106065722-10aee1819953/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd h1:HuTn7WObtcDo9uEEU7rEqL0jYthdXAmZ6PP+meazmaU=
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890 h1:uESlIz09WIHT2I+pasSXcpLYqYK8wHcdCetU3VuMBJE=
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190212230446-3e8b2be13635 h1:dOJmQysgY8iOBECuNp0vlKHWEtfiTnyjisEizRV3/4o=
golang.org/x/oauth2 v0.0.0-20190212230446-3e8b2be13635/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190116161447-11f53e031339 h1:g/Jesu8+QLnA0CPzF3E1pURg0Byr7i6jLoX5sqjcAh0=
golang.org/x/sys v0.0.0-20190116161447-11f53e031339/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2 h1:z99zHgr7hKfrUcX/KsoJk5FJfjTceCKIp96+biqP4To=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8 h1:Nw54tB0rB7hY/N0NQvRW8DG4Yk3Q6T9cu9RcFQDu1tc=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20190215211957-bd968387e4aa h1:FVL+/MjP2dzG4PxLpCJR7B6esIia88UAbsfYUrCc8U4=
google.golang.org/genproto v0.0.0-20190215211957-bd968387e4aa/go.mod h1:L3J43x8/uS+qIUoksaLKe6OS3nUKxOKuIFz1sl2/jx4=
google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio=
google.golang.org/grpc v1.18.0 h1:IZl7mfBGfbhYx2p2rKRtYgDFw6SBz+kclmxYrCksPPA=
google.golang.org/grpc v1.18.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs=
gopkg.in/square/go-jose.v2 v2.2.1 h1:uRIz/V7RfMsMgGnCp+YybIdstDIz8wc0H283wHQfwic=
gopkg.in/square/go-jose.v2 v2.2.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
gopkg.in/square/go-jose.v2 v2.2.2 h1:orlkJ3myw8CN1nVQHBFfloD+L3egixIa4FvUP6RosSA=
gopkg.in/square/go-jose.v2 v2.2.2/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=

View file

@ -29,47 +29,43 @@ type SessionStore interface {
// CookieStore represents all the cookie related configurations
type CookieStore struct {
Name string
CSRFCookieName string
CookieExpire time.Duration
CookieRefresh time.Duration
CookieSecure bool
CookieHTTPOnly bool
CookieDomain string
CookieCipher cryptutil.Cipher
SessionLifetimeTTL time.Duration
Name string
CSRFCookieName string
CookieCipher cryptutil.Cipher
CookieExpire time.Duration
CookieRefresh time.Duration
CookieSecure bool
CookieHTTPOnly bool
CookieDomain string
}
// CreateCookieCipher creates a new miscreant cipher with the cookie secret
func CreateCookieCipher(cookieSecret []byte) func(s *CookieStore) error {
return func(s *CookieStore) error {
cipher, err := cryptutil.NewCipher(cookieSecret)
if err != nil {
return fmt.Errorf("cookie-secret error: %s", err.Error())
}
s.CookieCipher = cipher
return nil
}
// CookieStoreOptions holds options for CookieStore
type CookieStoreOptions struct {
Name string
CookieSecure bool
CookieHTTPOnly bool
CookieDomain string
CookieExpire time.Duration
CookieCipher cryptutil.Cipher
}
// NewCookieStore returns a new session with ciphers for each of the cookie secrets
func NewCookieStore(cookieName string, optFuncs ...func(*CookieStore) error) (*CookieStore, error) {
c := &CookieStore{
Name: cookieName,
CookieSecure: true,
CookieHTTPOnly: true,
CookieExpire: 168 * time.Hour,
CSRFCookieName: fmt.Sprintf("%v_%v", cookieName, "csrf"),
func NewCookieStore(opts *CookieStoreOptions) (*CookieStore, error) {
if opts.Name == "" {
return nil, fmt.Errorf("internal/sessions: cookie name cannot be empty")
}
for _, f := range optFuncs {
err := f(c)
if err != nil {
return nil, err
}
if opts.CookieCipher == nil {
return nil, fmt.Errorf("internal/sessions: cipher cannot be nil")
}
return c, nil
return &CookieStore{
Name: opts.Name,
CSRFCookieName: fmt.Sprintf("%v_%v", opts.Name, "csrf"),
CookieSecure: opts.CookieSecure,
CookieHTTPOnly: opts.CookieHTTPOnly,
CookieDomain: opts.CookieDomain,
CookieExpire: opts.CookieExpire,
CookieCipher: opts.CookieCipher,
}, nil
}
func (s *CookieStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie {
@ -80,16 +76,19 @@ func (s *CookieStore) makeCookie(req *http.Request, name string, value string, e
if s.CookieDomain != "" {
domain = s.CookieDomain
}
return &http.Cookie{
c := &http.Cookie{
Name: name,
Value: value,
Path: "/",
Domain: domain,
HttpOnly: s.CookieHTTPOnly,
Secure: s.CookieSecure,
Expires: now.Add(expiration),
}
// only set an expiration if we want one, otherwise default to non perm session based
if expiration != 0 {
c.Expires = now.Add(expiration)
}
return c
}
// makeSessionCookie constructs a session cookie given the request, an expiration time and the current time.
@ -103,13 +102,13 @@ func (s *CookieStore) makeCSRFCookie(req *http.Request, value string, expiration
}
// ClearCSRF clears the CSRF cookie from the request
func (s *CookieStore) ClearCSRF(rw http.ResponseWriter, req *http.Request) {
http.SetCookie(rw, s.makeCSRFCookie(req, "", time.Hour*-1, time.Now()))
func (s *CookieStore) ClearCSRF(w http.ResponseWriter, req *http.Request) {
http.SetCookie(w, s.makeCSRFCookie(req, "", time.Hour*-1, time.Now()))
}
// SetCSRF sets the CSRFCookie creates a CSRF cookie in a given request
func (s *CookieStore) SetCSRF(rw http.ResponseWriter, req *http.Request, val string) {
http.SetCookie(rw, s.makeCSRFCookie(req, val, s.CookieExpire, time.Now()))
func (s *CookieStore) SetCSRF(w http.ResponseWriter, req *http.Request, val string) {
http.SetCookie(w, s.makeCSRFCookie(req, val, s.CookieExpire, time.Now()))
}
// GetCSRF gets the CSRFCookie creates a CSRF cookie in a given request
@ -118,20 +117,19 @@ func (s *CookieStore) GetCSRF(req *http.Request) (*http.Cookie, error) {
}
// ClearSession clears the session cookie from a request
func (s *CookieStore) ClearSession(rw http.ResponseWriter, req *http.Request) {
http.SetCookie(rw, s.makeSessionCookie(req, "", time.Hour*-1, time.Now()))
func (s *CookieStore) ClearSession(w http.ResponseWriter, req *http.Request) {
http.SetCookie(w, s.makeSessionCookie(req, "", time.Hour*-1, time.Now()))
}
func (s *CookieStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val string) {
http.SetCookie(rw, s.makeSessionCookie(req, val, s.CookieExpire, time.Now()))
func (s *CookieStore) setSessionCookie(w http.ResponseWriter, req *http.Request, val string) {
http.SetCookie(w, s.makeSessionCookie(req, val, s.CookieExpire, time.Now()))
}
// LoadSession returns a SessionState from the cookie in the request.
func (s *CookieStore) LoadSession(req *http.Request) (*SessionState, error) {
c, err := req.Cookie(s.Name)
if err != nil {
// always http.ErrNoCookie
return nil, err
return nil, err // http.ErrNoCookie
}
session, err := UnmarshalSession(c.Value, s.CookieCipher)
if err != nil {
@ -141,12 +139,11 @@ func (s *CookieStore) LoadSession(req *http.Request) (*SessionState, error) {
}
// SaveSession saves a session state to a request sessions.
func (s *CookieStore) SaveSession(rw http.ResponseWriter, req *http.Request, sessionState *SessionState) error {
func (s *CookieStore) SaveSession(w http.ResponseWriter, req *http.Request, sessionState *SessionState) error {
value, err := MarshalSession(sessionState, s.CookieCipher)
if err != nil {
return err
}
s.setSessionCookie(rw, req, value)
s.setSessionCookie(w, req, value)
return nil
}

View file

@ -1,348 +1,348 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
package sessions
import (
"encoding/base64"
"fmt"
"errors"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"time"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/internal/cryptutil"
)
var testEncodedCookieSecret, _ = base64.StdEncoding.DecodeString("qICChm3wdjbjcWymm7PefwtPP6/PZv+udkFEubTeE38=")
type mockCipher struct{}
func TestCreateCookieCipher(t *testing.T) {
testCases := []struct {
name string
cookieSecret []byte
expectedError bool
}{
{
name: "normal case with base64 encoded secret",
cookieSecret: testEncodedCookieSecret,
},
{
name: "error when not base64 encoded",
cookieSecret: []byte("abcd"),
expectedError: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, err := NewCookieStore("cookieName", CreateCookieCipher(tc.cookieSecret))
if !tc.expectedError {
testutil.Ok(t, err)
} else {
testutil.NotEqual(t, err, nil)
}
})
func (a mockCipher) Encrypt(s []byte) ([]byte, error) {
if string(s) == "error" {
return []byte(""), errors.New("error encrypting")
}
return []byte("OK"), nil
}
func TestNewSession(t *testing.T) {
testCases := []struct {
name string
optFuncs []func(*CookieStore) error
expectedError bool
expectedSession *CookieStore
func (a mockCipher) Decrypt(s []byte) ([]byte, error) {
if string(s) == "error" {
return []byte(""), errors.New("error encrypting")
}
return []byte("OK"), nil
}
func (a mockCipher) Marshal(s interface{}) (string, error) { return "", errors.New("error") }
func (a mockCipher) Unmarshal(s string, i interface{}) error {
if string(s) == "unmarshal error" || string(s) == "error" {
return errors.New("error")
}
return nil
}
func TestNewCookieStore(t *testing.T) {
cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey())
if err != nil {
t.Fatal(err)
}
tests := []struct {
name string
opts *CookieStoreOptions
want *CookieStore
wantErr bool
}{
{
name: "default with no opt funcs set",
expectedSession: &CookieStore{
Name: "cookieName",
{"good",
&CookieStoreOptions{
Name: "_cookie",
CookieSecure: true,
CookieHTTPOnly: true,
CookieExpire: 168 * time.Hour,
CSRFCookieName: "cookieName_csrf",
CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second,
CookieCipher: cipher,
},
},
{
name: "opt func with an error returns an error",
optFuncs: []func(*CookieStore) error{func(*CookieStore) error { return fmt.Errorf("error") }},
expectedError: true,
},
{
name: "opt func overrides default values",
optFuncs: []func(*CookieStore) error{func(s *CookieStore) error {
s.CookieExpire = time.Hour
return nil
}},
expectedSession: &CookieStore{
Name: "cookieName",
&CookieStore{
Name: "_cookie",
CSRFCookieName: "_cookie_csrf",
CookieSecure: true,
CookieHTTPOnly: true,
CookieExpire: time.Hour,
CSRFCookieName: "cookieName_csrf",
CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second,
CookieCipher: cipher,
},
},
false},
{"missing name",
&CookieStoreOptions{
Name: "",
CookieSecure: true,
CookieHTTPOnly: true,
CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second,
CookieCipher: cipher,
},
nil,
true},
{"missing cipher",
&CookieStoreOptions{
Name: "_pomerium",
CookieSecure: true,
CookieHTTPOnly: true,
CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second,
CookieCipher: nil,
},
nil,
true},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
session, err := NewCookieStore("cookieName", tc.optFuncs...)
if tc.expectedError {
testutil.NotEqual(t, err, nil)
} else {
testutil.Ok(t, err)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewCookieStore(tt.opts)
if (err != nil) != tt.wantErr {
t.Errorf("NewCookieStore() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewCookieStore() = %#v, want %#v", got, tt.want)
}
testutil.Equal(t, tc.expectedSession, session)
})
}
}
func TestMakeSessionCookie(t *testing.T) {
func TestCookieStore_makeCookie(t *testing.T) {
cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey())
if err != nil {
t.Fatal(err)
}
type fields struct {
Name string
CSRFCookieName string
CookieCipher cryptutil.Cipher
CookieExpire time.Duration
CookieRefresh time.Duration
CookieSecure bool
CookieHTTPOnly bool
CookieDomain string
}
now := time.Now()
cookieValue := "cookieValue"
expiration := time.Hour
cookieName := "cookieName"
testCases := []struct {
name string
optFuncs []func(*CookieStore) error
expectedCookie *http.Cookie
}{
{
name: "default cookie domain",
expectedCookie: &http.Cookie{
Name: cookieName,
Value: cookieValue,
Path: "/",
Domain: "www.example.com",
HttpOnly: true,
Secure: true,
Expires: now.Add(expiration),
},
},
{
name: "custom cookie domain set",
optFuncs: []func(*CookieStore) error{
func(s *CookieStore) error {
s.CookieDomain = "buzzfeed.com"
return nil
},
},
expectedCookie: &http.Cookie{
Name: cookieName,
Value: cookieValue,
Path: "/",
Domain: "buzzfeed.com",
HttpOnly: true,
Secure: true,
Expires: now.Add(expiration),
},
},
}
tests := []struct {
name string
domain string
cookieName string
value string
expiration time.Duration
want *http.Cookie
}{
{"good", "http://pomerium.io", "_pomerium", "value", 0, &http.Cookie{Name: "_pomerium", Value: "value", Path: "/", Domain: "pomerium.io", Secure: true, HttpOnly: true}},
{"domains with https", "https://pomerium.io", "_pomerium", "value", 0, &http.Cookie{Name: "_pomerium", Value: "value", Path: "/", Domain: "pomerium.io", Secure: true, HttpOnly: true}},
{"domain with port", "http://pomerium.io:443", "_pomerium", "value", 0, &http.Cookie{Name: "_pomerium", Value: "value", Path: "/", Domain: "pomerium.io", Secure: true, HttpOnly: true}},
{"expiration set", "http://pomerium.io:443", "_pomerium", "value", 10 * time.Second, &http.Cookie{Expires: now.Add(10 * time.Second), Name: "_pomerium", Value: "value", Path: "/", Domain: "pomerium.io", Secure: true, HttpOnly: true}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := httptest.NewRequest("GET", tt.domain, nil)
s := &CookieStore{
Name: "_pomerium",
CSRFCookieName: "_pomerium_csrf",
CookieSecure: true,
CookieHTTPOnly: true,
CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second,
CookieCipher: cipher}
if got := s.makeCookie(r, tt.cookieName, tt.value, tt.expiration, now); !reflect.DeepEqual(got, tt.want) {
t.Errorf("CookieStore.makeCookie() = \n%#v, \nwant\n%#v", got, tt.want)
}
if got := s.makeSessionCookie(r, tt.value, tt.expiration, now); !reflect.DeepEqual(got, tt.want) {
t.Errorf("CookieStore.makeCookie() = \n%#v, \nwant\n%#v", got, tt.want)
}
got := s.makeCSRFCookie(r, tt.value, tt.expiration, now)
tt.want.Name = "_pomerium_csrf"
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("CookieStore.makeCookie() = \n%#v, \nwant\n%#v", got, tt.want)
}
w := httptest.NewRecorder()
want := "new-csrf"
s.SetCSRF(w, r, want)
found := false
for _, cookie := range w.Result().Cookies() {
if cookie.Name == s.CSRFCookieName && cookie.Value == want {
found = true
break
}
}
if !found {
t.Error("SetCSRF failed")
}
w = httptest.NewRecorder()
s.ClearCSRF(w, r)
for _, cookie := range w.Result().Cookies() {
if cookie.Name == s.CSRFCookieName && cookie.Value == want {
t.Error("clear csrf failed")
break
}
}
w = httptest.NewRecorder()
want = "new-session"
s.setSessionCookie(w, r, want)
found = false
for _, cookie := range w.Result().Cookies() {
if cookie.Name == s.Name && cookie.Value == want {
found = true
break
}
}
if !found {
t.Error("SetCSRF failed")
}
w = httptest.NewRecorder()
s.ClearSession(w, r)
for _, cookie := range w.Result().Cookies() {
if cookie.Name == s.Name && cookie.Value == want {
t.Error("clear csrf failed")
break
}
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
session, err := NewCookieStore(cookieName, tc.optFuncs...)
testutil.Ok(t, err)
req := httptest.NewRequest("GET", "http://www.example.com", nil)
cookie := session.makeSessionCookie(req, cookieValue, expiration, now)
testutil.Equal(t, cookie, tc.expectedCookie)
})
}
}
func TestMakeSessionCSRFCookie(t *testing.T) {
now := time.Now()
cookieValue := "cookieValue"
expiration := time.Hour
cookieName := "cookieName"
csrfName := "cookieName_csrf"
testCases := []struct {
name string
optFuncs []func(*CookieStore) error
expectedCookie *http.Cookie
}{
{
name: "default cookie domain",
expectedCookie: &http.Cookie{
Name: csrfName,
Value: cookieValue,
Path: "/",
Domain: "www.example.com",
HttpOnly: true,
Secure: true,
Expires: now.Add(expiration),
},
},
{
name: "custom cookie domain set",
optFuncs: []func(*CookieStore) error{
func(s *CookieStore) error {
s.CookieDomain = "buzzfeed.com"
return nil
},
},
expectedCookie: &http.Cookie{
Name: csrfName,
Value: cookieValue,
Path: "/",
Domain: "buzzfeed.com",
HttpOnly: true,
Secure: true,
Expires: now.Add(expiration),
},
},
func TestCookieStore_SaveSession(t *testing.T) {
cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey())
if err != nil {
t.Fatal(err)
}
tests := []struct {
name string
sessionState *SessionState
cipher cryptutil.Cipher
wantErr bool
wantLoadErr bool
}{
{"good",
&SessionState{
AccessToken: "token1234",
RefreshToken: "refresh4321",
LifetimeDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
Email: "user@domain.com",
User: "user",
}, cipher, false, false},
{"bad cipher",
&SessionState{
AccessToken: "token1234",
RefreshToken: "refresh4321",
LifetimeDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
Email: "user@domain.com",
User: "user",
}, mockCipher{}, true, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &CookieStore{
Name: "_pomerium",
CSRFCookieName: "_pomerium_csrf",
CookieSecure: true,
CookieHTTPOnly: true,
CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second,
CookieCipher: tt.cipher}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
session, err := NewCookieStore(cookieName, tc.optFuncs...)
testutil.Ok(t, err)
req := httptest.NewRequest("GET", "http://www.example.com", nil)
cookie := session.makeCSRFCookie(req, cookieValue, expiration, now)
testutil.Equal(t, tc.expectedCookie, cookie)
r := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
if err := s.SaveSession(w, r, tt.sessionState); (err != nil) != tt.wantErr {
t.Errorf("CookieStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
}
r = httptest.NewRequest("GET", "/", nil)
for _, cookie := range w.Result().Cookies() {
t.Log(cookie)
r.AddCookie(cookie)
}
state, err := s.LoadSession(r)
if (err != nil) != tt.wantLoadErr {
t.Errorf("LoadSession() error = %v, wantErr %v", err, tt.wantLoadErr)
return
}
if err == nil && !reflect.DeepEqual(state, tt.sessionState) {
t.Errorf("CookieStore.LoadSession() got = \n%v, want \n%v", state, tt.sessionState)
}
})
}
}
func TestSetSessionCookie(t *testing.T) {
cookieValue := "cookieValue"
cookieName := "cookieName"
t.Run("set session cookie test", func(t *testing.T) {
session, err := NewCookieStore(cookieName)
testutil.Ok(t, err)
req := httptest.NewRequest("GET", "http://www.example.com", nil)
rw := httptest.NewRecorder()
session.setSessionCookie(rw, req, cookieValue)
var found bool
for _, cookie := range rw.Result().Cookies() {
if cookie.Name == cookieName {
found = true
testutil.Equal(t, cookieValue, cookie.Value)
testutil.Assert(t, cookie.Expires.After(time.Now()), "cookie expires after now")
}
}
testutil.Assert(t, found, "cookie in header")
})
}
func TestSetCSRFSessionCookie(t *testing.T) {
cookieValue := "cookieValue"
cookieName := "cookieName"
t.Run("set csrf cookie test", func(t *testing.T) {
session, err := NewCookieStore(cookieName)
testutil.Ok(t, err)
req := httptest.NewRequest("GET", "http://www.example.com", nil)
rw := httptest.NewRecorder()
session.SetCSRF(rw, req, cookieValue)
var found bool
for _, cookie := range rw.Result().Cookies() {
if cookie.Name == fmt.Sprintf("%s_csrf", cookieName) {
found = true
testutil.Equal(t, cookieValue, cookie.Value)
testutil.Assert(t, cookie.Expires.After(time.Now()), "cookie expires after now")
}
}
testutil.Assert(t, found, "cookie in header")
})
}
func TestClearSessionCookie(t *testing.T) {
cookieValue := "cookieValue"
cookieName := "cookieName"
t.Run("set session cookie test", func(t *testing.T) {
session, err := NewCookieStore(cookieName)
testutil.Ok(t, err)
req := httptest.NewRequest("GET", "http://www.example.com", nil)
req.AddCookie(session.makeSessionCookie(req, cookieValue, time.Hour, time.Now()))
rw := httptest.NewRecorder()
session.ClearSession(rw, req)
var found bool
for _, cookie := range rw.Result().Cookies() {
if cookie.Name == cookieName {
found = true
testutil.Equal(t, "", cookie.Value)
testutil.Assert(t, cookie.Expires.Before(time.Now()), "cookie expires before now")
}
}
testutil.Assert(t, found, "cookie in header")
})
}
func TestClearCSRFSessionCookie(t *testing.T) {
cookieValue := "cookieValue"
cookieName := "cookieName"
t.Run("clear csrf cookie test", func(t *testing.T) {
session, err := NewCookieStore(cookieName)
testutil.Ok(t, err)
req := httptest.NewRequest("GET", "http://www.example.com", nil)
req.AddCookie(session.makeCSRFCookie(req, cookieValue, time.Hour, time.Now()))
rw := httptest.NewRecorder()
session.ClearCSRF(rw, req)
var found bool
for _, cookie := range rw.Result().Cookies() {
if cookie.Name == fmt.Sprintf("%s_csrf", cookieName) {
found = true
testutil.Equal(t, "", cookie.Value)
testutil.Assert(t, cookie.Expires.Before(time.Now()), "cookie expires before now")
}
}
testutil.Assert(t, found, "cookie in header")
})
}
func TestLoadCookiedSession(t *testing.T) {
cookieName := "cookieName"
testCases := []struct {
name string
optFuncs []func(*CookieStore) error
setupCookies func(*testing.T, *http.Request, *CookieStore, *SessionState)
expectedError error
sessionState *SessionState
func TestMockCSRFStore(t *testing.T) {
tests := []struct {
name string
mockCSRF *MockCSRFStore
newCSRFValue string
wantErr bool
}{
{
name: "no cookie set returns an error",
setupCookies: func(*testing.T, *http.Request, *CookieStore, *SessionState) {},
expectedError: http.ErrNoCookie,
},
{
name: "cookie set with cipher set",
optFuncs: []func(*CookieStore) error{CreateCookieCipher(testEncodedCookieSecret)},
setupCookies: func(t *testing.T, req *http.Request, s *CookieStore, sessionState *SessionState) {
value, err := MarshalSession(sessionState, s.CookieCipher)
testutil.Ok(t, err)
req.AddCookie(s.makeSessionCookie(req, value, time.Hour, time.Now()))
},
sessionState: &SessionState{
Email: "example@email.com",
RefreshToken: "abccdddd",
AccessToken: "access",
},
},
{
name: "cookie set with invalid value cipher set",
optFuncs: []func(*CookieStore) error{CreateCookieCipher(testEncodedCookieSecret)},
setupCookies: func(t *testing.T, req *http.Request, s *CookieStore, sessionState *SessionState) {
value := "574b776a7c934d6b9fc42ec63a389f79"
req.AddCookie(s.makeSessionCookie(req, value, time.Hour, time.Now()))
},
expectedError: ErrInvalidSession,
},
{"basic",
&MockCSRFStore{
ResponseCSRF: "ok",
Cookie: &http.Cookie{Name: "hi"}},
"newcsrf",
false},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
session, err := NewCookieStore(cookieName, tc.optFuncs...)
testutil.Ok(t, err)
req := httptest.NewRequest("GET", "https://www.example.com", nil)
tc.setupCookies(t, req, session, tc.sessionState)
s, err := session.LoadSession(req)
testutil.Equal(t, tc.expectedError, err)
testutil.Equal(t, tc.sessionState, s)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ms := tt.mockCSRF
ms.SetCSRF(nil, nil, tt.newCSRFValue)
ms.ClearCSRF(nil, nil)
got, err := ms.GetCSRF(nil)
if (err != nil) != tt.wantErr {
t.Errorf("MockCSRFStore.GetCSRF() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.mockCSRF.Cookie) {
t.Errorf("MockCSRFStore.GetCSRF() = %v, want %v", got, tt.mockCSRF.Cookie)
}
})
}
}
func TestMockSessionStore(t *testing.T) {
tests := []struct {
name string
mockCSRF *MockSessionStore
saveSession *SessionState
wantLoadErr bool
wantSaveErr bool
}{
{"basic",
&MockSessionStore{
ResponseSession: "test",
Session: &SessionState{AccessToken: "AccessToken"},
SaveError: nil,
LoadError: nil,
},
&SessionState{AccessToken: "AccessToken"},
false,
false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ms := tt.mockCSRF
err := ms.SaveSession(nil, nil, tt.saveSession)
if (err != nil) != tt.wantSaveErr {
t.Errorf("MockCSRFStore.GetCSRF() error = %v, wantSaveErr %v", err, tt.wantSaveErr)
return
}
got, err := ms.LoadSession(nil)
if (err != nil) != tt.wantLoadErr {
t.Errorf("MockCSRFStore.GetCSRF() error = %v, wantLoadErr %v", err, tt.wantLoadErr)
return
}
if !reflect.DeepEqual(got, tt.mockCSRF.Session) {
t.Errorf("MockCSRFStore.GetCSRF() = %v, want %v", got, tt.mockCSRF.Session)
}
ms.ClearSession(nil, nil)
if ms.ResponseSession != "" {
t.Errorf("ResponseSession not empty! %s", ms.ResponseSession)
}
})
}
}

View file

@ -35,7 +35,7 @@ type MockSessionStore struct {
}
// ClearSession clears the ResponseSession
func (ms MockSessionStore) ClearSession(http.ResponseWriter, *http.Request) {
func (ms *MockSessionStore) ClearSession(http.ResponseWriter, *http.Request) {
ms.ResponseSession = ""
}

View file

@ -20,11 +20,9 @@ type SessionState struct {
RefreshDeadline time.Time `json:"refresh_deadline"`
LifetimeDeadline time.Time `json:"lifetime_deadline"`
ValidDeadline time.Time `json:"valid_deadline"`
GracePeriodStart time.Time `json:"grace_period_start"`
Email string `json:"email"`
User string `json:"user"` // 'sub' in jwt parlance
User string `json:"user"` // 'sub' in jwt
Groups []string `json:"groups"`
}
@ -38,11 +36,6 @@ func (s *SessionState) RefreshPeriodExpired() bool {
return isExpired(s.RefreshDeadline)
}
// ValidationPeriodExpired returns true if the validation period has expired
func (s *SessionState) ValidationPeriodExpired() bool {
return isExpired(s.ValidDeadline)
}
func isExpired(t time.Time) bool {
return t.Before(time.Now())
}
@ -64,7 +57,7 @@ func UnmarshalSession(value string, c cryptutil.Cipher) (*SessionState, error) {
return s, nil
}
// ExtendDeadline returns the time extended by a given duration
// ExtendDeadline returns the time extended by a given duration, truncated by second
func ExtendDeadline(ttl time.Duration) time.Time {
return time.Now().Add(ttl).Truncate(time.Second)
}

View file

@ -1,4 +1,4 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
package sessions
import (
"reflect"
@ -16,15 +16,12 @@ func TestSessionStateSerialization(t *testing.T) {
}
want := &SessionState{
AccessToken: "token1234",
RefreshToken: "refresh4321",
AccessToken: "token1234",
RefreshToken: "refresh4321",
LifetimeDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
ValidDeadline: time.Now().Add(1 * time.Minute).Truncate(time.Second).UTC(),
Email: "user@domain.com",
User: "user",
Email: "user@domain.com",
User: "user",
}
ciphertext, err := MarshalSession(want, c)
@ -46,26 +43,40 @@ func TestSessionStateSerialization(t *testing.T) {
func TestSessionStateExpirations(t *testing.T) {
session := &SessionState{
AccessToken: "token1234",
RefreshToken: "refresh4321",
AccessToken: "token1234",
RefreshToken: "refresh4321",
LifetimeDeadline: time.Now().Add(-1 * time.Hour),
RefreshDeadline: time.Now().Add(-1 * time.Hour),
ValidDeadline: time.Now().Add(-1 * time.Minute),
Email: "user@domain.com",
User: "user",
Email: "user@domain.com",
User: "user",
}
if !session.LifetimePeriodExpired() {
t.Errorf("expcted lifetime period to be expired")
t.Errorf("expected lifetime period to be expired")
}
if !session.RefreshPeriodExpired() {
t.Errorf("expcted lifetime period to be expired")
t.Errorf("expected lifetime period to be expired")
}
if !session.ValidationPeriodExpired() {
t.Errorf("expcted lifetime period to be expired")
}
func TestExtendDeadline(t *testing.T) {
// tons of wiggle room here
now := time.Now().Truncate(time.Second)
tests := []struct {
name string
ttl time.Duration
want time.Time
}{
{"Add a few ms", time.Millisecond * 10, now.Truncate(time.Second)},
{"Add a few microsecs", time.Microsecond * 10, now.Truncate(time.Second)},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := ExtendDeadline(tt.ttl); !reflect.DeepEqual(got, tt.want) {
t.Errorf("ExtendDeadline() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -1,75 +0,0 @@
// Original Copyright 2013 The Go Authors. All rights reserved.
//
// Modified by BuzzFeed to return duplicate counts.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package singleflight provides a duplicate function call suppression mechanism.
package singleflight // import "github.com/pomerium/pomerium/internal/singleflight"
import "sync"
// call is an in-flight or completed singleflight.Do call
type call struct {
wg sync.WaitGroup
// These fields are written once before the WaitGroup is done
// and are only read after the WaitGroup is done.
val interface{}
err error
// These fields are read and written with the singleflight
// mutex held before the WaitGroup is done, and are read but
// not written after the WaitGroup is done.
dups int
}
// Group represents a class of work and forms a namespace in
// which units of work can be executed with duplicate suppression.
type Group struct {
mu sync.Mutex // protects m
m map[string]*call // lazily initialized
}
// Result holds the results of Do, so they can be passed
// on a channel.
type Result struct {
Val interface{}
Err error
Count bool
}
// Do executes and returns the results of the given function, making
// sure that only one execution is in-flight for a given key at a
// time. If a duplicate comes in, the duplicate caller waits for the
// original to complete and receives the same results.
// The return value of Count indicates how many tiems v was given to multiple callers.
// Count will be zero for requests are shared and only be non-zero for the originating request.
func (g *Group) Do(key string, fn func() (interface{}, error)) (v interface{}, count int, err error) {
g.mu.Lock()
if g.m == nil {
g.m = make(map[string]*call)
}
if c, ok := g.m[key]; ok {
c.dups++
g.mu.Unlock()
c.wg.Wait()
return c.val, 0, c.err
}
c := new(call)
c.wg.Add(1)
g.m[key] = c
g.mu.Unlock()
c.val, c.err = fn()
c.wg.Done()
g.mu.Lock()
delete(g.m, key)
g.mu.Unlock()
return c.val, c.dups, c.err
}

View file

@ -1,87 +0,0 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package singleflight // import "github.com/pomerium/pomerium/internal/singleflight"
import (
"errors"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
)
func TestDo(t *testing.T) {
var g Group
v, _, err := g.Do("key", func() (interface{}, error) {
return "bar", nil
})
if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want {
t.Errorf("Do = %v; want %v", got, want)
}
if err != nil {
t.Errorf("Do error = %v", err)
}
}
func TestDoErr(t *testing.T) {
var g Group
someErr := errors.New("Some error")
v, _, err := g.Do("key", func() (interface{}, error) {
return nil, someErr
})
if err != someErr {
t.Errorf("Do error = %v; want someErr %v", err, someErr)
}
if v != nil {
t.Errorf("unexpected non-nil value %#v", v)
}
}
func TestDoDupSuppress(t *testing.T) {
var g Group
var wg1, wg2 sync.WaitGroup
c := make(chan string, 1)
var calls int32
fn := func() (interface{}, error) {
if atomic.AddInt32(&calls, 1) == 1 {
// First invocation.
wg1.Done()
}
v := <-c
c <- v // pump; make available for any future calls
time.Sleep(10 * time.Millisecond) // let more goroutines enter Do
return v, nil
}
const n = 10
wg1.Add(1)
for i := 0; i < n; i++ {
wg1.Add(1)
wg2.Add(1)
go func() {
defer wg2.Done()
wg1.Done()
v, _, err := g.Do("key", fn)
if err != nil {
t.Errorf("Do error: %v", err)
return
}
if s, _ := v.(string); s != "bar" {
t.Errorf("Do = %T %v; want %q", v, v, "bar")
}
}()
}
wg1.Wait()
// At least one goroutine is in fn now and all of them have at
// least reached the line before the Do.
c <- "bar"
wg2.Wait()
if got := atomic.LoadInt32(&calls); got <= 0 || got >= n {
t.Errorf("number of calls = %d; want over 0 and less than %d", got, n)
}
}

View file

@ -1,46 +0,0 @@
package testutil // import "github.com/pomerium/pomerium/internal/testutil"
// testing util functions copied from https://github.com/benbjohnson/testing
import (
"fmt"
"path/filepath"
"reflect"
"runtime"
"testing"
)
// Assert fails the test if the condition is false.
func Assert(tb testing.TB, condition bool, msg string, v ...interface{}) {
if !condition {
_, file, line, _ := runtime.Caller(1)
fmt.Printf("\033[31m%s:%d: "+msg+"\033[39m\n\n", append([]interface{}{filepath.Base(file), line}, v...)...)
tb.FailNow()
}
}
// Ok fails the test if an err is not nil.
func Ok(tb testing.TB, err error) {
if err != nil {
_, file, line, _ := runtime.Caller(1)
fmt.Printf("\033[31m%s:%d: unexpected error: %s\033[39m\n\n", filepath.Base(file), line, err.Error())
tb.FailNow()
}
}
// Equal fails the test if exp is not equal to act.
func Equal(tb testing.TB, exp, act interface{}) {
if !reflect.DeepEqual(exp, act) {
_, file, line, _ := runtime.Caller(1)
fmt.Printf("\033[31m%s:%d:\n\n\texp: %#v\n\n\tgot: %#v\033[39m\n\n", filepath.Base(file), line, exp, act)
tb.FailNow()
}
}
// NotEqual fails the test if exp is equal to act.
func NotEqual(tb testing.TB, exp, act interface{}) {
if reflect.DeepEqual(exp, act) {
_, file, line, _ := runtime.Caller(1)
fmt.Printf("\033[31m%s:%d:\n\n\texp: %#v\n\n\tgot: %#v\033[39m\n\n", filepath.Base(file), line, exp, act)
tb.FailNow()
}
}

View file

@ -25,7 +25,7 @@ type Options struct {
// InternalAddr is the internal (behind the ingress) address to use when making an
// authentication connection. If empty, Addr is used.
InternalAddr string
// OverrideServerName overrides the server name used to verify the hostname on the
// OverideCertificateName overrides the server name used to verify the hostname on the
// returned certificates from the server. gRPC internals also use it to override the virtual
// hosting name if it is set.
OverideCertificateName string

View file

@ -203,16 +203,15 @@ func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
}
// We store the session in a cookie and redirect the user back to the application
err = p.sessionStore.SaveSession(w, r, &sessions.SessionState{
AccessToken: rr.AccessToken,
RefreshToken: rr.RefreshToken,
IDToken: rr.IDToken,
User: rr.User,
Email: rr.Email,
RefreshDeadline: (rr.Expiry).Truncate(time.Second),
LifetimeDeadline: extendDeadline(p.CookieLifetimeTTL),
ValidDeadline: extendDeadline(p.CookieExpire),
})
err = p.sessionStore.SaveSession(w, r,
&sessions.SessionState{
AccessToken: rr.AccessToken,
RefreshToken: rr.RefreshToken,
IDToken: rr.IDToken,
User: rr.User,
Email: rr.Email,
RefreshDeadline: (rr.Expiry).Truncate(time.Second),
})
if err != nil {
log.FromRequest(r).Error().Msg("error saving session")
httputil.ErrorResponse(w, r, "Error saving session", http.StatusInternalServerError)
@ -250,9 +249,7 @@ func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
}
}
// ! ! !
// todo(bdd): ! Authorization service goes here !
// ! ! !
// todo(bdd): add authorization service validation
// We have validated the users request and now proxy their request to the provided upstream.
route, ok := p.router(r)
@ -278,14 +275,10 @@ func (p *Proxy) Authenticate(w http.ResponseWriter, r *http.Request) (err error)
return err
}
if session.LifetimePeriodExpired() {
log.FromRequest(r).Info().Msg("proxy: lifetime expired")
return sessions.ErrLifetimeExpired
}
if session.RefreshPeriodExpired() {
// AccessToken's usually expire after 60 or so minutes. If offline_access scope is set, a
// refresh token (which doesn't change) can be used to request a new access-token. If access
// is revoked by identity provider, or no refresh token is set request will return an error
// is revoked by identity provider, or no refresh token is set, request will return an error
accessToken, expiry, err := p.AuthenticateClient.Refresh(session.RefreshToken)
if err != nil {
log.FromRequest(r).Warn().

View file

@ -272,7 +272,7 @@ func TestProxy_OAuthCallback(t *testing.T) {
if err != nil {
t.Fatal(err)
}
proxy.sessionStore = tt.session
proxy.sessionStore = &tt.session
proxy.csrfStore = tt.csrf
proxy.AuthenticateClient = tt.authenticator
proxy.cipher = mockCipher{}
@ -352,12 +352,6 @@ func TestProxy_Proxy(t *testing.T) {
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}
expiredLifetime := &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(-10 * time.Second),
}
tests := []struct {
@ -368,11 +362,10 @@ func TestProxy_Proxy(t *testing.T) {
wantStatus int
}{
// weirdly, we want 503 here because that means proxy is trying to route a domain (example.com) that we dont control. Weird. I know.
{"good", "https://corp.example.com/test", sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, http.StatusServiceUnavailable},
{"unexpected error", "https://corp.example.com/test", sessions.MockSessionStore{LoadError: errors.New("ok")}, authenticator.MockAuthenticate{}, http.StatusInternalServerError},
{"good", "https://corp.example.com/test", &sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, http.StatusServiceUnavailable},
{"unexpected error", "https://corp.example.com/test", &sessions.MockSessionStore{LoadError: errors.New("ok")}, authenticator.MockAuthenticate{}, http.StatusInternalServerError},
// redirect to start auth process
{"expired lifetime", "https://corp.example.com/test", sessions.MockSessionStore{Session: expiredLifetime}, authenticator.MockAuthenticate{}, http.StatusFound},
{"unknown host", "https://notcorp.example.com/test", sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, http.StatusNotFound},
{"unknown host", "https://notcorp.example.com/test", &sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, http.StatusNotFound},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -402,13 +395,8 @@ func TestProxy_Authenticate(t *testing.T) {
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}
expiredLifetime := &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(-10 * time.Second),
}
expiredDeadline := &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
@ -426,25 +414,21 @@ func TestProxy_Authenticate(t *testing.T) {
{"cannot save session",
"https://corp.example.com/",
map[string]string{"corp.example.com": "example.com"},
sessions.MockSessionStore{Session: goodSession, SaveError: errors.New("error")},
&sessions.MockSessionStore{Session: goodSession, SaveError: errors.New("error")},
authenticator.MockAuthenticate{}, true},
{"cannot load session",
"https://corp.example.com/",
map[string]string{"corp.example.com": "example.com"},
sessions.MockSessionStore{LoadError: errors.New("error")}, authenticator.MockAuthenticate{}, true},
{"expired lifetime",
"https://corp.example.com/",
map[string]string{"corp.example.com": "example.com"},
sessions.MockSessionStore{Session: expiredLifetime}, authenticator.MockAuthenticate{}, true},
&sessions.MockSessionStore{LoadError: errors.New("error")}, authenticator.MockAuthenticate{}, true},
{"expired session",
"https://corp.example.com/",
map[string]string{"corp.example.com": "example.com"},
sessions.MockSessionStore{Session: expiredDeadline}, authenticator.MockAuthenticate{}, false},
&sessions.MockSessionStore{Session: expiredDeadline}, authenticator.MockAuthenticate{}, false},
{"bad refresh authenticator",
"https://corp.example.com/",
map[string]string{"corp.example.com": "example.com"},
sessions.MockSessionStore{
&sessions.MockSessionStore{
Session: expiredDeadline,
},
authenticator.MockAuthenticate{RefreshError: errors.New("error")},
@ -453,7 +437,7 @@ func TestProxy_Authenticate(t *testing.T) {
{"good",
"https://corp.example.com/",
map[string]string{"corp.example.com": "example.com"},
sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, false},
&sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

View file

@ -133,12 +133,9 @@ type Proxy struct {
AuthenticateClient authenticator.Authenticator
// session
CookieExpire time.Duration
CookieRefresh time.Duration
CookieLifetimeTTL time.Duration
cipher cryptutil.Cipher
csrfStore sessions.CSRFStore
sessionStore sessions.SessionStore
cipher cryptutil.Cipher
csrfStore sessions.CSRFStore
sessionStore sessions.SessionStore
redirectURL *url.URL
templates *template.Template
@ -163,13 +160,14 @@ func New(opts *Options) (*Proxy, error) {
return nil, fmt.Errorf("cookie-secret error: %s", err.Error())
}
cookieStore, err := sessions.NewCookieStore(opts.CookieName,
sessions.CreateCookieCipher(decodedSecret),
func(c *sessions.CookieStore) error {
c.CookieDomain = opts.CookieDomain
c.CookieHTTPOnly = opts.CookieHTTPOnly
c.CookieExpire = opts.CookieExpire
return nil
cookieStore, err := sessions.NewCookieStore(
&sessions.CookieStoreOptions{
Name: opts.CookieName,
CookieDomain: opts.CookieDomain,
CookieSecure: opts.CookieSecure,
CookieHTTPOnly: opts.CookieHTTPOnly,
CookieExpire: opts.CookieExpire,
CookieCipher: cipher,
})
if err != nil {
@ -181,14 +179,12 @@ func New(opts *Options) (*Proxy, error) {
// services
AuthenticateURL: opts.AuthenticateURL,
// session state
cipher: cipher,
csrfStore: cookieStore,
sessionStore: cookieStore,
SharedKey: opts.SharedKey,
redirectURL: &url.URL{Path: "/.pomerium/callback"},
templates: templates.New(),
CookieExpire: opts.CookieExpire,
CookieLifetimeTTL: opts.CookieLifetimeTTL,
cipher: cipher,
csrfStore: cookieStore,
sessionStore: cookieStore,
SharedKey: opts.SharedKey,
redirectURL: &url.URL{Path: "/.pomerium/callback"},
templates: templates.New(),
}
for from, to := range opts.Routes {
@ -200,7 +196,7 @@ func New(opts *Options) (*Proxy, error) {
return nil, err
}
p.Handle(fromURL.Host, handler)
log.Info().Str("from", fromURL.Host).Str("to", toURL.String()).Msg("proxy.New: new route")
log.Info().Str("from", fromURL.Host).Str("to", toURL.String()).Msg("proxy: new route")
}
p.AuthenticateClient, err = authenticator.New(

View file

@ -139,6 +139,7 @@ func testOptions() *Options {
AuthenticateURL: authurl,
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
CookieSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=",
CookieName: "pomerium",
}
}