mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-21 21:17:13 +02:00
all: general cleanup readying for tagged release (#48)
- docs: add code coverage to readme - internal/sessions: refactor sessions to clarify lifetime - authenticate: simplified signin flow - deployment: update go mods - internal/testutil: removed package - internal/singleflight: removed package
This commit is contained in:
parent
13c03a2b5c
commit
dbafc691c3
25 changed files with 712 additions and 1017 deletions
14
README.md
14
README.md
|
@ -4,24 +4,24 @@
|
|||
|
||||
# Pomerium
|
||||
|
||||
[](https://travis-ci.org/pomerium/pomerium) [](https://goreportcard.com/report/github.com/pomerium/pomerium) [][godocs] [](https://github.com/pomerium/pomerium/blob/master/LICENSE)[](https://codecov.io/gh/pomerium/pomerium)
|
||||
[](https://travis-ci.org/pomerium/pomerium) [](https://goreportcard.com/report/github.com/pomerium/pomerium) [][godocs] [](https://github.com/pomerium/pomerium/blob/master/LICENSE)[](https://codecov.io/gh/pomerium/pomerium)
|
||||
|
||||
Pomerium is a tool for managing secure access to internal applications and resources.
|
||||
|
||||
Use Pomerium to:
|
||||
|
||||
- provide a unified gateway (reverse-proxy) to internal corporate applications.
|
||||
- enforce dynamic access policy based on context, identity, and device state.
|
||||
- deploy mutual authenticated encryption (mTLS).
|
||||
- aggregate logging and telemetry data.
|
||||
- provide a single-sign-on gateway to internal applications.
|
||||
- enforce dynamic access policy based on **context**, **identity**, and **device state**.
|
||||
- aggregate access logs and telemetry data.
|
||||
- an alternative to a VPN.
|
||||
|
||||
Check out [awesome-zero-trust] to learn more about some problems Pomerium attempts to address.
|
||||
Check out [awesome-zero-trust] to learn more about some of the problems Pomerium attempts to address.
|
||||
|
||||
## Docs
|
||||
|
||||
To get started with pomerium, check out our [quick start guide].
|
||||
|
||||
For comprehensive docs see our [documentation] and the [godocs].
|
||||
For comprehensive docs, and tutorials see our [documentation] and the [godocs].
|
||||
|
||||
[awesome-zero-trust]: https://github.com/pomerium/awesome-zero-trust
|
||||
[documentation]: https://www.pomerium.io/docs/
|
||||
|
|
|
@ -18,37 +18,38 @@ import (
|
|||
)
|
||||
|
||||
var defaultOptions = &Options{
|
||||
CookieName: "_pomerium_authenticate",
|
||||
CookieHTTPOnly: true,
|
||||
CookieSecure: true,
|
||||
CookieExpire: time.Duration(168) * time.Hour,
|
||||
CookieRefresh: time.Duration(30) * time.Minute,
|
||||
CookieLifetimeTTL: time.Duration(720) * time.Hour,
|
||||
CookieName: "_pomerium_authenticate",
|
||||
CookieHTTPOnly: true,
|
||||
CookieSecure: true,
|
||||
CookieExpire: time.Duration(168) * time.Hour,
|
||||
CookieRefresh: time.Duration(30) * time.Minute,
|
||||
}
|
||||
|
||||
// Options details the available configuration settings for the authenticate service
|
||||
type Options struct {
|
||||
RedirectURL *url.URL `envconfig:"REDIRECT_URL"`
|
||||
|
||||
// SharedKey is used to authenticate requests between services
|
||||
SharedKey string `envconfig:"SHARED_SECRET"`
|
||||
|
||||
// RedirectURL specifies the callback url following third party authentication
|
||||
RedirectURL *url.URL `envconfig:"REDIRECT_URL"`
|
||||
|
||||
// Coarse authorization based on user email domain
|
||||
// todo(bdd) : to be replaced with authorization module
|
||||
AllowedDomains []string `envconfig:"ALLOWED_DOMAINS"`
|
||||
ProxyRootDomains []string `envconfig:"PROXY_ROOT_DOMAIN"`
|
||||
|
||||
// Session/Cookie management
|
||||
CookieName string
|
||||
CookieSecret string `envconfig:"COOKIE_SECRET"`
|
||||
CookieDomain string `envconfig:"COOKIE_DOMAIN"`
|
||||
CookieSecure bool `envconfig:"COOKIE_SECURE"`
|
||||
CookieHTTPOnly bool `envconfig:"COOKIE_HTTP_ONLY"`
|
||||
CookieExpire time.Duration `envconfig:"COOKIE_EXPIRE"`
|
||||
CookieRefresh time.Duration `envconfig:"COOKIE_REFRESH"`
|
||||
CookieLifetimeTTL time.Duration `envconfig:"COOKIE_LIFETIME"`
|
||||
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie
|
||||
CookieName string
|
||||
CookieSecret string `envconfig:"COOKIE_SECRET"`
|
||||
CookieDomain string `envconfig:"COOKIE_DOMAIN"`
|
||||
CookieSecure bool `envconfig:"COOKIE_SECURE"`
|
||||
CookieHTTPOnly bool `envconfig:"COOKIE_HTTP_ONLY"`
|
||||
CookieExpire time.Duration `envconfig:"COOKIE_EXPIRE"`
|
||||
CookieRefresh time.Duration `envconfig:"COOKIE_REFRESH"`
|
||||
|
||||
// IdentityProvider provider configuration variables as specified by RFC6749
|
||||
// See: https://openid.net/specs/openid-connect-basic-1_0.html#RFC6749
|
||||
// https://openid.net/specs/openid-connect-basic-1_0.html#RFC6749
|
||||
ClientID string `envconfig:"IDP_CLIENT_ID"`
|
||||
ClientSecret string `envconfig:"IDP_CLIENT_SECRET"`
|
||||
Provider string `envconfig:"IDP_PROVIDER"`
|
||||
|
@ -103,17 +104,13 @@ func (o *Options) Validate() error {
|
|||
|
||||
// Authenticate validates a user's identity
|
||||
type Authenticate struct {
|
||||
RedirectURL *url.URL
|
||||
|
||||
Validator func(string) bool
|
||||
|
||||
AllowedDomains []string
|
||||
ProxyRootDomains []string
|
||||
CookieSecure bool
|
||||
|
||||
SharedKey string
|
||||
|
||||
CookieLifetimeTTL time.Duration
|
||||
RedirectURL *url.URL
|
||||
AllowedDomains []string
|
||||
ProxyRootDomains []string
|
||||
|
||||
Validator func(string) bool
|
||||
|
||||
templates *template.Template
|
||||
csrfStore sessions.CSRFStore
|
||||
|
@ -137,37 +134,47 @@ func New(opts *Options, optionFuncs ...func(*Authenticate) error) (*Authenticate
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cookieStore, err := sessions.NewCookieStore(opts.CookieName,
|
||||
sessions.CreateCookieCipher(decodedCookieSecret),
|
||||
func(c *sessions.CookieStore) error {
|
||||
c.CookieDomain = opts.CookieDomain
|
||||
c.CookieHTTPOnly = opts.CookieHTTPOnly
|
||||
c.CookieExpire = opts.CookieExpire
|
||||
c.CookieSecure = opts.CookieSecure
|
||||
return nil
|
||||
cookieStore, err := sessions.NewCookieStore(
|
||||
&sessions.CookieStoreOptions{
|
||||
Name: opts.CookieName,
|
||||
CookieSecure: opts.CookieSecure,
|
||||
CookieHTTPOnly: opts.CookieHTTPOnly,
|
||||
CookieExpire: opts.CookieExpire,
|
||||
CookieCipher: cipher,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p := &Authenticate{
|
||||
SharedKey: opts.SharedKey,
|
||||
AllowedDomains: opts.AllowedDomains,
|
||||
ProxyRootDomains: dotPrependDomains(opts.ProxyRootDomains),
|
||||
CookieSecure: opts.CookieSecure,
|
||||
RedirectURL: opts.RedirectURL,
|
||||
templates: templates.New(),
|
||||
csrfStore: cookieStore,
|
||||
sessionStore: cookieStore,
|
||||
cipher: cipher,
|
||||
}
|
||||
|
||||
p.provider, err = newProvider(opts)
|
||||
provider, err := providers.New(
|
||||
opts.Provider,
|
||||
&providers.IdentityProvider{
|
||||
RedirectURL: opts.RedirectURL,
|
||||
ProviderName: opts.Provider,
|
||||
ProviderURL: opts.ProviderURL,
|
||||
ClientID: opts.ClientID,
|
||||
ClientSecret: opts.ClientSecret,
|
||||
// SessionLifetimeTTL: opts.CookieLifetimeTTL,
|
||||
Scopes: opts.Scopes,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p := &Authenticate{
|
||||
SharedKey: opts.SharedKey,
|
||||
RedirectURL: opts.RedirectURL,
|
||||
AllowedDomains: opts.AllowedDomains,
|
||||
ProxyRootDomains: dotPrependDomains(opts.ProxyRootDomains),
|
||||
|
||||
templates: templates.New(),
|
||||
csrfStore: cookieStore,
|
||||
sessionStore: cookieStore,
|
||||
cipher: cipher,
|
||||
provider: provider,
|
||||
}
|
||||
|
||||
// validation via dependency injected function
|
||||
for _, optFunc := range optionFuncs {
|
||||
err := optFunc(p)
|
||||
|
@ -179,20 +186,6 @@ func New(opts *Options, optionFuncs ...func(*Authenticate) error) (*Authenticate
|
|||
return p, nil
|
||||
}
|
||||
|
||||
func newProvider(opts *Options) (providers.Provider, error) {
|
||||
pd := &providers.IdentityProvider{
|
||||
RedirectURL: opts.RedirectURL,
|
||||
ProviderName: opts.Provider,
|
||||
ProviderURL: opts.ProviderURL,
|
||||
ClientID: opts.ClientID,
|
||||
ClientSecret: opts.ClientSecret,
|
||||
SessionLifetimeTTL: opts.CookieLifetimeTTL,
|
||||
Scopes: opts.Scopes,
|
||||
}
|
||||
np, err := providers.New(opts.Provider, pd)
|
||||
return np, err
|
||||
}
|
||||
|
||||
func dotPrependDomains(d []string) []string {
|
||||
for i := range d {
|
||||
if d[i] != "" && !strings.HasPrefix(d[i], ".") {
|
||||
|
|
|
@ -11,16 +11,17 @@ import (
|
|||
func testOptions() *Options {
|
||||
redirectURL, _ := url.Parse("https://example.com/oauth2/callback")
|
||||
return &Options{
|
||||
ProxyRootDomains: []string{"example.com"},
|
||||
AllowedDomains: []string{"example.com"},
|
||||
RedirectURL: redirectURL,
|
||||
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=",
|
||||
CookieSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=",
|
||||
CookieRefresh: time.Duration(1) * time.Hour,
|
||||
CookieLifetimeTTL: time.Duration(720) * time.Hour,
|
||||
CookieExpire: time.Duration(168) * time.Hour,
|
||||
ProxyRootDomains: []string{"example.com"},
|
||||
AllowedDomains: []string{"example.com"},
|
||||
RedirectURL: redirectURL,
|
||||
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=",
|
||||
CookieSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=",
|
||||
CookieRefresh: time.Duration(1) * time.Hour,
|
||||
// CookieLifetimeTTL: time.Duration(720) * time.Hour,
|
||||
CookieExpire: time.Duration(168) * time.Hour,
|
||||
CookieName: "pomerium",
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -130,37 +131,6 @@ func Test_dotPrependDomains(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_newProvider(t *testing.T) {
|
||||
redirectURL, _ := url.Parse("https://example.com/oauth3/callback")
|
||||
|
||||
goodOpts := &Options{
|
||||
RedirectURL: redirectURL,
|
||||
Provider: "google",
|
||||
ProviderURL: "",
|
||||
ClientID: "cllient-id",
|
||||
ClientSecret: "client-secret",
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
opts *Options
|
||||
wantErr bool
|
||||
}{
|
||||
{"good", goodOpts, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := newProvider(tt.opts)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("newProvider() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
// if !reflect.DeepEqual(got, tt.want) {
|
||||
// t.Errorf("newProvider() = %v, want %v", got, tt.want)
|
||||
// }
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
good := testOptions()
|
||||
good.Provider = "google"
|
||||
|
|
|
@ -117,7 +117,6 @@ func TestAuthenticate_Authenticate(t *testing.T) {
|
|||
}
|
||||
lt := time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC()
|
||||
rt := time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC()
|
||||
vt := time.Now().Add(1 * time.Minute).Truncate(time.Second).UTC()
|
||||
vtProto, err := ptypes.TimestampProto(rt)
|
||||
if err != nil {
|
||||
t.Fatal("failed to parse timestamp")
|
||||
|
@ -128,9 +127,9 @@ func TestAuthenticate_Authenticate(t *testing.T) {
|
|||
RefreshToken: "refresh4321",
|
||||
LifetimeDeadline: lt,
|
||||
RefreshDeadline: rt,
|
||||
ValidDeadline: vt,
|
||||
Email: "user@domain.com",
|
||||
User: "user",
|
||||
|
||||
Email: "user@domain.com",
|
||||
User: "user",
|
||||
}
|
||||
|
||||
goodReply := &pb.AuthenticateReply{
|
||||
|
|
|
@ -16,7 +16,8 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/version"
|
||||
)
|
||||
|
||||
// securityHeaders corresponds to HTTP response headers related to security.
|
||||
// securityHeaders corresponds to HTTP response headers that help to protect against protocol
|
||||
// downgrade attacks and cookie hijacking.
|
||||
// https://www.owasp.org/index.php/OWASP_Secure_Headers_Project#tab=Headers
|
||||
var securityHeaders = map[string]string{
|
||||
"Strict-Transport-Security": "max-age=31536000",
|
||||
|
@ -28,7 +29,7 @@ var securityHeaders = map[string]string{
|
|||
"Referrer-Policy": "Same-origin",
|
||||
}
|
||||
|
||||
// Handler returns the Http.Handlers for authenticate, callback, and refresh
|
||||
// Handler returns the authenticate service's HTTP request multiplexer, and routes.
|
||||
func (a *Authenticate) Handler() http.Handler {
|
||||
// set up our standard middlewares
|
||||
stdMiddleware := middleware.NewChain()
|
||||
|
@ -80,12 +81,6 @@ func (a *Authenticate) authenticate(w http.ResponseWriter, r *http.Request) (*se
|
|||
return nil, err
|
||||
}
|
||||
|
||||
// if long-lived lifetime has expired, clear session
|
||||
if session.LifetimePeriodExpired() {
|
||||
log.FromRequest(r).Warn().Msg("authenticate: lifetime expired")
|
||||
a.sessionStore.ClearSession(w, r)
|
||||
return nil, sessions.ErrLifetimeExpired
|
||||
}
|
||||
// check if session refresh period is up
|
||||
if session.RefreshPeriodExpired() {
|
||||
newToken, err := a.provider.Refresh(session.RefreshToken)
|
||||
|
@ -130,32 +125,23 @@ func (a *Authenticate) authenticate(w http.ResponseWriter, r *http.Request) (*se
|
|||
return session, nil
|
||||
}
|
||||
|
||||
// SignIn handles the /sign_in endpoint. It attempts to authenticate the user,
|
||||
// SignIn handles the sign_in endpoint. It attempts to authenticate the user,
|
||||
// and if the user is not authenticated, it renders a sign in page.
|
||||
func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
|
||||
session, err := a.authenticate(w, r)
|
||||
switch err {
|
||||
case nil:
|
||||
// session good, redirect back to proxy
|
||||
log.FromRequest(r).Info().Msg("authenticate.SignIn : authenticated")
|
||||
a.ProxyCallback(w, r, session)
|
||||
case http.ErrNoCookie, sessions.ErrLifetimeExpired, sessions.ErrInvalidSession:
|
||||
// session invalid, authenticate
|
||||
log.FromRequest(r).Info().Err(err).Msg("authenticate.SignIn : expected failure")
|
||||
if err != http.ErrNoCookie {
|
||||
a.sessionStore.ClearSession(w, r)
|
||||
}
|
||||
if err != nil {
|
||||
log.FromRequest(r).Info().Err(err).Msg("authenticate: authenticate error")
|
||||
a.sessionStore.ClearSession(w, r)
|
||||
a.OAuthStart(w, r)
|
||||
|
||||
default:
|
||||
log.Error().Err(err).Msg("authenticate: unexpected sign in error")
|
||||
httputil.ErrorResponse(w, r, err.Error(), httputil.CodeForError(err))
|
||||
}
|
||||
log.FromRequest(r).Info().Msg("authenticate: user authenticated")
|
||||
a.ProxyCallback(w, r, session)
|
||||
|
||||
}
|
||||
|
||||
// ProxyCallback redirects the user back to proxy service along with an encrypted payload, as
|
||||
// url params, of the user's session state.
|
||||
// See RFC6749 3.1.2 https://tools.ietf.org/html/rfc6749#section-3.1.2
|
||||
// url params, of the user's session state as specified in RFC6749 3.1.2.
|
||||
// https://tools.ietf.org/html/rfc6749#section-3.1.2
|
||||
func (a *Authenticate) ProxyCallback(w http.ResponseWriter, r *http.Request, session *sessions.SessionState) {
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
|
@ -201,9 +187,8 @@ func getAuthCodeRedirectURL(redirectURL *url.URL, state, authCode string) string
|
|||
return u.String()
|
||||
}
|
||||
|
||||
// SignOut signs the user out by trying to revoke the users remote identity provider session
|
||||
// then removes the associated local session state.
|
||||
// Handles both GET and POST of form.
|
||||
// SignOut signs the user out by trying to revoke the user's remote identity session along with
|
||||
// the associated local session state. Handles both GET and POST.
|
||||
func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
|
@ -256,8 +241,8 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
|
|||
http.Redirect(w, r, redirectURI, http.StatusFound)
|
||||
}
|
||||
|
||||
// OAuthStart starts the authenticate process by redirecting to the provider. It provides a
|
||||
// `redirectURI`, allowing the provider to redirect back to the sso proxy after authenticate.
|
||||
// OAuthStart starts the authenticate process by redirecting to the identity provider.
|
||||
// https://tools.ietf.org/html/rfc6749#section-4.2.1
|
||||
func (a *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) {
|
||||
authRedirectURL, err := url.Parse(r.URL.Query().Get("redirect_uri"))
|
||||
if err != nil {
|
||||
|
@ -298,7 +283,7 @@ func (a *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// OAuthCallback handles the callback from the identity provider. Displays an error page if there
|
||||
// was an error. If successful, redirects back to the proxy-service via the redirect-url.
|
||||
// was an error. If successful, the user is redirected back to the proxy-service.
|
||||
func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) {
|
||||
redirect, err := a.getOAuthCallback(w, r)
|
||||
switch h := err.(type) {
|
||||
|
|
|
@ -71,29 +71,19 @@ func TestAuthenticate_Handler(t *testing.T) {
|
|||
|
||||
func TestAuthenticate_authenticate(t *testing.T) {
|
||||
// sessions.MockSessionStore{Session: expiredLifetime}
|
||||
goodSession := sessions.MockSessionStore{
|
||||
goodSession := &sessions.MockSessionStore{
|
||||
Session: &sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
}}
|
||||
expiredSession := sessions.MockSessionStore{
|
||||
|
||||
expiredRefresPeriod := &sessions.MockSessionStore{
|
||||
Session: &sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
LifetimeDeadline: time.Now().Add(10 * -time.Second),
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||
}}
|
||||
expiredRefresPeriod := sessions.MockSessionStore{
|
||||
Session: &sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||
RefreshDeadline: time.Now().Add(10 * -time.Second),
|
||||
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
|
||||
RefreshDeadline: time.Now().Add(10 * -time.Second),
|
||||
}}
|
||||
|
||||
tests := []struct {
|
||||
|
@ -106,18 +96,16 @@ func TestAuthenticate_authenticate(t *testing.T) {
|
|||
}{
|
||||
{"good", goodSession, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, false},
|
||||
{"good but fails validation", goodSession, providers.MockProvider{ValidateResponse: true}, falseValidator, nil, true},
|
||||
{"can't load session", sessions.MockSessionStore{LoadError: errors.New("error")}, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, true},
|
||||
{"can't load session", &sessions.MockSessionStore{LoadError: errors.New("error")}, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, true},
|
||||
{"validation fails", goodSession, providers.MockProvider{ValidateResponse: false}, trueValidator, nil, true},
|
||||
{"session fails after good validation", sessions.MockSessionStore{
|
||||
{"session fails after good validation", &sessions.MockSessionStore{
|
||||
SaveError: errors.New("error"),
|
||||
Session: &sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||
}}, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, true},
|
||||
{"lifetime expired", expiredSession, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, true},
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
}}, providers.MockProvider{ValidateResponse: true},
|
||||
trueValidator, nil, true},
|
||||
{"refresh expired",
|
||||
expiredRefresPeriod,
|
||||
providers.MockProvider{
|
||||
|
@ -136,14 +124,13 @@ func TestAuthenticate_authenticate(t *testing.T) {
|
|||
},
|
||||
trueValidator, nil, true},
|
||||
{"refresh expired failed save",
|
||||
sessions.MockSessionStore{
|
||||
&sessions.MockSessionStore{
|
||||
SaveError: errors.New("error"),
|
||||
Session: &sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||
RefreshDeadline: time.Now().Add(10 * -time.Second),
|
||||
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
|
||||
RefreshDeadline: time.Now().Add(10 * -time.Second),
|
||||
}},
|
||||
providers.MockProvider{
|
||||
ValidateResponse: true,
|
||||
|
@ -182,29 +169,23 @@ func TestAuthenticate_SignIn(t *testing.T) {
|
|||
wantCode int
|
||||
}{
|
||||
{"good",
|
||||
sessions.MockSessionStore{
|
||||
&sessions.MockSessionStore{
|
||||
Session: &sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
}},
|
||||
providers.MockProvider{ValidateResponse: true},
|
||||
trueValidator,
|
||||
403},
|
||||
// {"no session",
|
||||
// sessions.MockSessionStore{
|
||||
// Session: &sessions.SessionState{
|
||||
// AccessToken: "AccessToken",
|
||||
// RefreshToken: "RefreshToken",
|
||||
// LifetimeDeadline: time.Now().Add(-10 * time.Second),
|
||||
// RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
// ValidDeadline: time.Now().Add(10 * time.Second),
|
||||
// }},
|
||||
// providers.MockProvider{ValidateResponse: true},
|
||||
// trueValidator,
|
||||
// 200},
|
||||
http.StatusForbidden},
|
||||
{"session fails after good validation", &sessions.MockSessionStore{
|
||||
SaveError: errors.New("error"),
|
||||
Session: &sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
}}, providers.MockProvider{ValidateResponse: true},
|
||||
trueValidator, http.StatusBadRequest},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -212,6 +193,10 @@ func TestAuthenticate_SignIn(t *testing.T) {
|
|||
sessionStore: tt.session,
|
||||
provider: tt.provider,
|
||||
Validator: tt.validator,
|
||||
RedirectURL: uriParse("http://www.pomerium.io"),
|
||||
csrfStore: &sessions.MockCSRFStore{},
|
||||
SharedKey: "secret",
|
||||
cipher: mockCipher{},
|
||||
}
|
||||
r := httptest.NewRequest("GET", "/sign-in", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
@ -262,13 +247,12 @@ func TestAuthenticate_ProxyCallback(t *testing.T) {
|
|||
}{
|
||||
{"good", "https://corp.pomerium.io/", "state", "code",
|
||||
&sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
},
|
||||
sessions.MockSessionStore{},
|
||||
&sessions.MockSessionStore{},
|
||||
302,
|
||||
"<a href=\"https://corp.pomerium.io/?code=ok&state=state\">Found</a>."},
|
||||
{"no state",
|
||||
|
@ -276,13 +260,12 @@ func TestAuthenticate_ProxyCallback(t *testing.T) {
|
|||
"",
|
||||
"code",
|
||||
&sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
},
|
||||
sessions.MockSessionStore{},
|
||||
&sessions.MockSessionStore{},
|
||||
403,
|
||||
"no state parameter supplied"},
|
||||
{"no redirect_url",
|
||||
|
@ -290,13 +273,12 @@ func TestAuthenticate_ProxyCallback(t *testing.T) {
|
|||
"state",
|
||||
"code",
|
||||
&sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
},
|
||||
sessions.MockSessionStore{},
|
||||
&sessions.MockSessionStore{},
|
||||
403,
|
||||
"no redirect_uri parameter"},
|
||||
{"malformed redirect_url",
|
||||
|
@ -304,13 +286,12 @@ func TestAuthenticate_ProxyCallback(t *testing.T) {
|
|||
"state",
|
||||
"code",
|
||||
&sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
},
|
||||
sessions.MockSessionStore{},
|
||||
&sessions.MockSessionStore{},
|
||||
400,
|
||||
"malformed redirect_uri"},
|
||||
}
|
||||
|
@ -389,14 +370,13 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
|||
"sig",
|
||||
"ts",
|
||||
providers.MockProvider{},
|
||||
sessions.MockSessionStore{
|
||||
&sessions.MockSessionStore{
|
||||
Session: &sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
Email: "blah@blah.com",
|
||||
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
Email: "blah@blah.com",
|
||||
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
},
|
||||
},
|
||||
http.StatusFound,
|
||||
|
@ -407,14 +387,13 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
|||
"sig",
|
||||
"ts",
|
||||
providers.MockProvider{RevokeError: errors.New("OH NO")},
|
||||
sessions.MockSessionStore{
|
||||
&sessions.MockSessionStore{
|
||||
Session: &sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
Email: "blah@blah.com",
|
||||
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
Email: "blah@blah.com",
|
||||
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
},
|
||||
},
|
||||
http.StatusBadRequest,
|
||||
|
@ -426,14 +405,13 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
|||
"sig",
|
||||
"ts",
|
||||
providers.MockProvider{},
|
||||
sessions.MockSessionStore{
|
||||
&sessions.MockSessionStore{
|
||||
Session: &sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
Email: "blah@blah.com",
|
||||
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
Email: "blah@blah.com",
|
||||
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
},
|
||||
},
|
||||
http.StatusOK,
|
||||
|
@ -444,15 +422,14 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
|||
"sig",
|
||||
"ts",
|
||||
providers.MockProvider{},
|
||||
sessions.MockSessionStore{
|
||||
&sessions.MockSessionStore{
|
||||
LoadError: errors.New("uh oh"),
|
||||
Session: &sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
Email: "blah@blah.com",
|
||||
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
Email: "blah@blah.com",
|
||||
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
},
|
||||
},
|
||||
http.StatusBadRequest,
|
||||
|
@ -463,14 +440,13 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
|||
"sig",
|
||||
"ts",
|
||||
providers.MockProvider{},
|
||||
sessions.MockSessionStore{
|
||||
&sessions.MockSessionStore{
|
||||
Session: &sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
Email: "blah@blah.com",
|
||||
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
Email: "blah@blah.com",
|
||||
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
},
|
||||
},
|
||||
http.StatusBadRequest,
|
||||
|
@ -512,7 +488,6 @@ func redirectURLSignature(rawRedirect string, timestamp time.Time, secret string
|
|||
}
|
||||
|
||||
func TestAuthenticate_OAuthStart(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
|
@ -634,15 +609,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
|
|||
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
|
||||
[]string{"pomerium.io"},
|
||||
trueValidator,
|
||||
sessions.MockSessionStore{},
|
||||
&sessions.MockSessionStore{},
|
||||
providers.MockProvider{
|
||||
AuthenticateResponse: sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
Email: "blah@blah.com",
|
||||
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
Email: "blah@blah.com",
|
||||
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
}},
|
||||
sessions.MockCSRFStore{
|
||||
ResponseCSRF: "csrf",
|
||||
|
@ -657,15 +631,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
|
|||
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
|
||||
[]string{"pomerium.io"},
|
||||
trueValidator,
|
||||
sessions.MockSessionStore{},
|
||||
&sessions.MockSessionStore{},
|
||||
providers.MockProvider{
|
||||
AuthenticateResponse: sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
Email: "blah@blah.com",
|
||||
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
Email: "blah@blah.com",
|
||||
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
}},
|
||||
sessions.MockCSRFStore{
|
||||
ResponseCSRF: "csrf",
|
||||
|
@ -681,15 +654,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
|
|||
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
|
||||
[]string{"pomerium.io"},
|
||||
trueValidator,
|
||||
sessions.MockSessionStore{},
|
||||
&sessions.MockSessionStore{},
|
||||
providers.MockProvider{
|
||||
AuthenticateResponse: sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
Email: "blah@blah.com",
|
||||
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
Email: "blah@blah.com",
|
||||
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
}},
|
||||
sessions.MockCSRFStore{
|
||||
ResponseCSRF: "csrf",
|
||||
|
@ -704,7 +676,7 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
|
|||
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
|
||||
[]string{"pomerium.io"},
|
||||
trueValidator,
|
||||
sessions.MockSessionStore{},
|
||||
&sessions.MockSessionStore{},
|
||||
providers.MockProvider{
|
||||
AuthenticateError: errors.New("error"),
|
||||
},
|
||||
|
@ -721,15 +693,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
|
|||
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
|
||||
[]string{"pomerium.io"},
|
||||
trueValidator,
|
||||
sessions.MockSessionStore{SaveError: errors.New("error")},
|
||||
&sessions.MockSessionStore{SaveError: errors.New("error")},
|
||||
providers.MockProvider{
|
||||
AuthenticateResponse: sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
Email: "blah@blah.com",
|
||||
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
Email: "blah@blah.com",
|
||||
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
}},
|
||||
sessions.MockCSRFStore{
|
||||
ResponseCSRF: "csrf",
|
||||
|
@ -744,15 +715,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
|
|||
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
|
||||
[]string{"pomerium.io"},
|
||||
falseValidator,
|
||||
sessions.MockSessionStore{},
|
||||
&sessions.MockSessionStore{},
|
||||
providers.MockProvider{
|
||||
AuthenticateResponse: sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
Email: "blah@blah.com",
|
||||
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
Email: "blah@blah.com",
|
||||
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
}},
|
||||
sessions.MockCSRFStore{
|
||||
ResponseCSRF: "csrf",
|
||||
|
@ -768,15 +738,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
|
|||
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
|
||||
[]string{"pomerium.io"},
|
||||
trueValidator,
|
||||
sessions.MockSessionStore{},
|
||||
&sessions.MockSessionStore{},
|
||||
providers.MockProvider{
|
||||
AuthenticateResponse: sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
Email: "blah@blah.com",
|
||||
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
Email: "blah@blah.com",
|
||||
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
}},
|
||||
sessions.MockCSRFStore{
|
||||
ResponseCSRF: "csrf",
|
||||
|
@ -791,15 +760,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
|
|||
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
|
||||
[]string{"pomerium.io"},
|
||||
trueValidator,
|
||||
sessions.MockSessionStore{},
|
||||
&sessions.MockSessionStore{},
|
||||
providers.MockProvider{
|
||||
AuthenticateResponse: sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
Email: "blah@blah.com",
|
||||
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
Email: "blah@blah.com",
|
||||
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
}},
|
||||
sessions.MockCSRFStore{
|
||||
ResponseCSRF: "csrf",
|
||||
|
@ -814,15 +782,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
|
|||
"nonce:https://corp.pomerium.io",
|
||||
[]string{"pomerium.io"},
|
||||
trueValidator,
|
||||
sessions.MockSessionStore{},
|
||||
&sessions.MockSessionStore{},
|
||||
providers.MockProvider{
|
||||
AuthenticateResponse: sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
Email: "blah@blah.com",
|
||||
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
Email: "blah@blah.com",
|
||||
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
}},
|
||||
sessions.MockCSRFStore{
|
||||
ResponseCSRF: "csrf",
|
||||
|
@ -837,15 +804,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
|
|||
base64.URLEncoding.EncodeToString([]byte("nonce")),
|
||||
[]string{"pomerium.io"},
|
||||
trueValidator,
|
||||
sessions.MockSessionStore{},
|
||||
&sessions.MockSessionStore{},
|
||||
providers.MockProvider{
|
||||
AuthenticateResponse: sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
Email: "blah@blah.com",
|
||||
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
Email: "blah@blah.com",
|
||||
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
}},
|
||||
sessions.MockCSRFStore{
|
||||
ResponseCSRF: "csrf",
|
||||
|
@ -860,15 +826,14 @@ func TestAuthenticate_getOAuthCallback(t *testing.T) {
|
|||
base64.URLEncoding.EncodeToString([]byte("nonce:corp.pomerium.io")),
|
||||
[]string{"pomerium.io"},
|
||||
trueValidator,
|
||||
sessions.MockSessionStore{},
|
||||
&sessions.MockSessionStore{},
|
||||
providers.MockProvider{
|
||||
AuthenticateResponse: sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
Email: "blah@blah.com",
|
||||
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
Email: "blah@blah.com",
|
||||
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
}},
|
||||
sessions.MockCSRFStore{
|
||||
ResponseCSRF: "csrf",
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
// Package providers implements OpenID Connect client logic for the set of supported identity
|
||||
// providers.
|
||||
// OpenID Connect 1.0 is a simple identity layer on top of the OAuth 2.0 RFC6749 protocol.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html
|
||||
// Package providers authentication for third party identity providers (IdP) using OpenID
|
||||
// Connect, an identity layer on top of the OAuth 2.0 RFC6749 protocol.
|
||||
//
|
||||
// see: https://openid.net/specs/openid-connect-core-1_0.html
|
||||
package providers // import "github.com/pomerium/pomerium/internal/providers"
|
||||
|
|
|
@ -2,7 +2,6 @@ package providers // import "github.com/pomerium/pomerium/internal/providers"
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
oidc "github.com/pomerium/go-oidc"
|
||||
"golang.org/x/oauth2"
|
||||
|
@ -19,7 +18,7 @@ type OIDCProvider struct {
|
|||
func NewOIDCProvider(p *IdentityProvider) (*OIDCProvider, error) {
|
||||
ctx := context.Background()
|
||||
if p.ProviderURL == "" {
|
||||
return nil, errors.New("missing required provider url")
|
||||
return nil, ErrMissingProviderURL
|
||||
}
|
||||
var err error
|
||||
p.provider, err = oidc.NewProvider(ctx, p.ProviderURL)
|
||||
|
|
|
@ -2,7 +2,6 @@ package providers // import "github.com/pomerium/pomerium/internal/providers"
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/url"
|
||||
|
||||
oidc "github.com/pomerium/go-oidc"
|
||||
|
@ -25,7 +24,7 @@ type OktaProvider struct {
|
|||
func NewOktaProvider(p *IdentityProvider) (*OktaProvider, error) {
|
||||
ctx := context.Background()
|
||||
if p.ProviderURL == "" {
|
||||
return nil, errors.New("missing required provider url")
|
||||
return nil, ErrMissingProviderURL
|
||||
}
|
||||
var err error
|
||||
p.provider, err = oidc.NewProvider(ctx, p.ProviderURL)
|
||||
|
|
|
@ -29,6 +29,11 @@ const (
|
|||
OktaProviderName = "okta"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrMissingProviderURL is returned when the CB state is half open and the requests count is over the cb maxRequests
|
||||
ErrMissingProviderURL = errors.New("proxy/providers: missing provider url")
|
||||
)
|
||||
|
||||
// Provider is an interface exposing functions necessary to interact with a given provider.
|
||||
type Provider interface {
|
||||
Authenticate(string) (*sessions.SessionState, error)
|
||||
|
|
17
go.mod
17
go.mod
|
@ -5,18 +5,17 @@ require (
|
|||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/golang/mock v1.2.0
|
||||
github.com/golang/protobuf v1.2.0
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/pomerium/envconfig v1.3.1-0.20190112072701-14cbcf832d31
|
||||
github.com/pomerium/go-oidc v2.0.0+incompatible
|
||||
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
|
||||
github.com/rs/zerolog v1.11.0
|
||||
github.com/stretchr/testify v1.2.2 // indirect
|
||||
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9
|
||||
golang.org/x/net v0.0.0-20181220203305-927f97764cc3
|
||||
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890
|
||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 // indirect
|
||||
golang.org/x/sys v0.0.0-20190116161447-11f53e031339 // indirect
|
||||
google.golang.org/appengine v1.4.0 // indirect
|
||||
github.com/stretchr/testify v1.3.0 // indirect
|
||||
golang.org/x/crypto v0.0.0-20190211182817-74369b46fc67
|
||||
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd
|
||||
golang.org/x/oauth2 v0.0.0-20190212230446-3e8b2be13635
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a // indirect
|
||||
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2 // indirect
|
||||
google.golang.org/genproto v0.0.0-20190215211957-bd968387e4aa // indirect
|
||||
google.golang.org/grpc v1.18.0
|
||||
gopkg.in/square/go-jose.v2 v2.2.1
|
||||
gopkg.in/square/go-jose.v2 v2.2.2
|
||||
)
|
||||
|
|
38
go.sum
38
go.sum
|
@ -1,11 +1,14 @@
|
|||
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
|
||||
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
|
||||
github.com/benbjohnson/clock v0.0.0-20161215174838-7dc76406b6d3 h1:wOysYcIdqv3WnvwqFFzrYCFALPED7qkUGaLXu359GSc=
|
||||
github.com/benbjohnson/clock v0.0.0-20161215174838-7dc76406b6d3/go.mod h1:UMqtWQTnOe4byzwe7Zhwh8f8s+36uszN51sJrSIZlTE=
|
||||
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58=
|
||||
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
|
||||
github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E=
|
||||
github.com/golang/mock v1.1.1 h1:G5FRp8JnTd7RQH5kemVNlMeyXQAztQ3mOWV95KxsXH8=
|
||||
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
|
||||
github.com/golang/mock v1.2.0 h1:28o5sBqPkBsMGnC6b4MvE2TzSr5/AT4c/1fLqVGIwlk=
|
||||
|
@ -23,34 +26,45 @@ github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAm
|
|||
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35/go.mod h1:prYjPmNq4d1NPVmpShWobRqXY3q7Vp+80DqgxxUrUIA=
|
||||
github.com/rs/zerolog v1.11.0 h1:DRuq/S+4k52uJzBQciUcofXx45GrMC6yrEbb/CoK6+M=
|
||||
github.com/rs/zerolog v1.11.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU=
|
||||
github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w=
|
||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9 h1:mKdxBk7AujPs8kU4m80U72y/zjbZ3UcXC7dClwKbUI0=
|
||||
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||
github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
golang.org/x/crypto v0.0.0-20190211182817-74369b46fc67 h1:ng3VDlRp5/DHpSWl02R4rM9I+8M2rhmsuLwAMmkLQWE=
|
||||
golang.org/x/crypto v0.0.0-20190211182817-74369b46fc67/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
|
||||
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20181220203305-927f97764cc3 h1:eH6Eip3UpmR+yM/qI9Ijluzb1bNv/cAU/n+6l8tRSis=
|
||||
golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20181106065722-10aee1819953/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd h1:HuTn7WObtcDo9uEEU7rEqL0jYthdXAmZ6PP+meazmaU=
|
||||
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890 h1:uESlIz09WIHT2I+pasSXcpLYqYK8wHcdCetU3VuMBJE=
|
||||
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20190212230446-3e8b2be13635 h1:dOJmQysgY8iOBECuNp0vlKHWEtfiTnyjisEizRV3/4o=
|
||||
golang.org/x/oauth2 v0.0.0-20190212230446-3e8b2be13635/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw=
|
||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190116161447-11f53e031339 h1:g/Jesu8+QLnA0CPzF3E1pURg0Byr7i6jLoX5sqjcAh0=
|
||||
golang.org/x/sys v0.0.0-20190116161447-11f53e031339/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2 h1:z99zHgr7hKfrUcX/KsoJk5FJfjTceCKIp96+biqP4To=
|
||||
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
|
||||
google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508=
|
||||
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8 h1:Nw54tB0rB7hY/N0NQvRW8DG4Yk3Q6T9cu9RcFQDu1tc=
|
||||
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
|
||||
google.golang.org/genproto v0.0.0-20190215211957-bd968387e4aa h1:FVL+/MjP2dzG4PxLpCJR7B6esIia88UAbsfYUrCc8U4=
|
||||
google.golang.org/genproto v0.0.0-20190215211957-bd968387e4aa/go.mod h1:L3J43x8/uS+qIUoksaLKe6OS3nUKxOKuIFz1sl2/jx4=
|
||||
google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio=
|
||||
google.golang.org/grpc v1.18.0 h1:IZl7mfBGfbhYx2p2rKRtYgDFw6SBz+kclmxYrCksPPA=
|
||||
google.golang.org/grpc v1.18.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs=
|
||||
gopkg.in/square/go-jose.v2 v2.2.1 h1:uRIz/V7RfMsMgGnCp+YybIdstDIz8wc0H283wHQfwic=
|
||||
gopkg.in/square/go-jose.v2 v2.2.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
|
||||
gopkg.in/square/go-jose.v2 v2.2.2 h1:orlkJ3myw8CN1nVQHBFfloD+L3egixIa4FvUP6RosSA=
|
||||
gopkg.in/square/go-jose.v2 v2.2.2/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
|
||||
honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
|
|
|
@ -29,47 +29,43 @@ type SessionStore interface {
|
|||
|
||||
// CookieStore represents all the cookie related configurations
|
||||
type CookieStore struct {
|
||||
Name string
|
||||
CSRFCookieName string
|
||||
CookieExpire time.Duration
|
||||
CookieRefresh time.Duration
|
||||
CookieSecure bool
|
||||
CookieHTTPOnly bool
|
||||
CookieDomain string
|
||||
CookieCipher cryptutil.Cipher
|
||||
SessionLifetimeTTL time.Duration
|
||||
Name string
|
||||
CSRFCookieName string
|
||||
CookieCipher cryptutil.Cipher
|
||||
CookieExpire time.Duration
|
||||
CookieRefresh time.Duration
|
||||
CookieSecure bool
|
||||
CookieHTTPOnly bool
|
||||
CookieDomain string
|
||||
}
|
||||
|
||||
// CreateCookieCipher creates a new miscreant cipher with the cookie secret
|
||||
func CreateCookieCipher(cookieSecret []byte) func(s *CookieStore) error {
|
||||
return func(s *CookieStore) error {
|
||||
cipher, err := cryptutil.NewCipher(cookieSecret)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cookie-secret error: %s", err.Error())
|
||||
}
|
||||
s.CookieCipher = cipher
|
||||
return nil
|
||||
}
|
||||
// CookieStoreOptions holds options for CookieStore
|
||||
type CookieStoreOptions struct {
|
||||
Name string
|
||||
CookieSecure bool
|
||||
CookieHTTPOnly bool
|
||||
CookieDomain string
|
||||
CookieExpire time.Duration
|
||||
CookieCipher cryptutil.Cipher
|
||||
}
|
||||
|
||||
// NewCookieStore returns a new session with ciphers for each of the cookie secrets
|
||||
func NewCookieStore(cookieName string, optFuncs ...func(*CookieStore) error) (*CookieStore, error) {
|
||||
c := &CookieStore{
|
||||
Name: cookieName,
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
CookieExpire: 168 * time.Hour,
|
||||
CSRFCookieName: fmt.Sprintf("%v_%v", cookieName, "csrf"),
|
||||
func NewCookieStore(opts *CookieStoreOptions) (*CookieStore, error) {
|
||||
if opts.Name == "" {
|
||||
return nil, fmt.Errorf("internal/sessions: cookie name cannot be empty")
|
||||
}
|
||||
|
||||
for _, f := range optFuncs {
|
||||
err := f(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if opts.CookieCipher == nil {
|
||||
return nil, fmt.Errorf("internal/sessions: cipher cannot be nil")
|
||||
}
|
||||
|
||||
return c, nil
|
||||
return &CookieStore{
|
||||
Name: opts.Name,
|
||||
CSRFCookieName: fmt.Sprintf("%v_%v", opts.Name, "csrf"),
|
||||
CookieSecure: opts.CookieSecure,
|
||||
CookieHTTPOnly: opts.CookieHTTPOnly,
|
||||
CookieDomain: opts.CookieDomain,
|
||||
CookieExpire: opts.CookieExpire,
|
||||
CookieCipher: opts.CookieCipher,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *CookieStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||
|
@ -80,16 +76,19 @@ func (s *CookieStore) makeCookie(req *http.Request, name string, value string, e
|
|||
if s.CookieDomain != "" {
|
||||
domain = s.CookieDomain
|
||||
}
|
||||
|
||||
return &http.Cookie{
|
||||
c := &http.Cookie{
|
||||
Name: name,
|
||||
Value: value,
|
||||
Path: "/",
|
||||
Domain: domain,
|
||||
HttpOnly: s.CookieHTTPOnly,
|
||||
Secure: s.CookieSecure,
|
||||
Expires: now.Add(expiration),
|
||||
}
|
||||
// only set an expiration if we want one, otherwise default to non perm session based
|
||||
if expiration != 0 {
|
||||
c.Expires = now.Add(expiration)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// makeSessionCookie constructs a session cookie given the request, an expiration time and the current time.
|
||||
|
@ -103,13 +102,13 @@ func (s *CookieStore) makeCSRFCookie(req *http.Request, value string, expiration
|
|||
}
|
||||
|
||||
// ClearCSRF clears the CSRF cookie from the request
|
||||
func (s *CookieStore) ClearCSRF(rw http.ResponseWriter, req *http.Request) {
|
||||
http.SetCookie(rw, s.makeCSRFCookie(req, "", time.Hour*-1, time.Now()))
|
||||
func (s *CookieStore) ClearCSRF(w http.ResponseWriter, req *http.Request) {
|
||||
http.SetCookie(w, s.makeCSRFCookie(req, "", time.Hour*-1, time.Now()))
|
||||
}
|
||||
|
||||
// SetCSRF sets the CSRFCookie creates a CSRF cookie in a given request
|
||||
func (s *CookieStore) SetCSRF(rw http.ResponseWriter, req *http.Request, val string) {
|
||||
http.SetCookie(rw, s.makeCSRFCookie(req, val, s.CookieExpire, time.Now()))
|
||||
func (s *CookieStore) SetCSRF(w http.ResponseWriter, req *http.Request, val string) {
|
||||
http.SetCookie(w, s.makeCSRFCookie(req, val, s.CookieExpire, time.Now()))
|
||||
}
|
||||
|
||||
// GetCSRF gets the CSRFCookie creates a CSRF cookie in a given request
|
||||
|
@ -118,20 +117,19 @@ func (s *CookieStore) GetCSRF(req *http.Request) (*http.Cookie, error) {
|
|||
}
|
||||
|
||||
// ClearSession clears the session cookie from a request
|
||||
func (s *CookieStore) ClearSession(rw http.ResponseWriter, req *http.Request) {
|
||||
http.SetCookie(rw, s.makeSessionCookie(req, "", time.Hour*-1, time.Now()))
|
||||
func (s *CookieStore) ClearSession(w http.ResponseWriter, req *http.Request) {
|
||||
http.SetCookie(w, s.makeSessionCookie(req, "", time.Hour*-1, time.Now()))
|
||||
}
|
||||
|
||||
func (s *CookieStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val string) {
|
||||
http.SetCookie(rw, s.makeSessionCookie(req, val, s.CookieExpire, time.Now()))
|
||||
func (s *CookieStore) setSessionCookie(w http.ResponseWriter, req *http.Request, val string) {
|
||||
http.SetCookie(w, s.makeSessionCookie(req, val, s.CookieExpire, time.Now()))
|
||||
}
|
||||
|
||||
// LoadSession returns a SessionState from the cookie in the request.
|
||||
func (s *CookieStore) LoadSession(req *http.Request) (*SessionState, error) {
|
||||
c, err := req.Cookie(s.Name)
|
||||
if err != nil {
|
||||
// always http.ErrNoCookie
|
||||
return nil, err
|
||||
return nil, err // http.ErrNoCookie
|
||||
}
|
||||
session, err := UnmarshalSession(c.Value, s.CookieCipher)
|
||||
if err != nil {
|
||||
|
@ -141,12 +139,11 @@ func (s *CookieStore) LoadSession(req *http.Request) (*SessionState, error) {
|
|||
}
|
||||
|
||||
// SaveSession saves a session state to a request sessions.
|
||||
func (s *CookieStore) SaveSession(rw http.ResponseWriter, req *http.Request, sessionState *SessionState) error {
|
||||
func (s *CookieStore) SaveSession(w http.ResponseWriter, req *http.Request, sessionState *SessionState) error {
|
||||
value, err := MarshalSession(sessionState, s.CookieCipher)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.setSessionCookie(rw, req, value)
|
||||
s.setSessionCookie(w, req, value)
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,348 +1,348 @@
|
|||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||
package sessions
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
)
|
||||
|
||||
var testEncodedCookieSecret, _ = base64.StdEncoding.DecodeString("qICChm3wdjbjcWymm7PefwtPP6/PZv+udkFEubTeE38=")
|
||||
type mockCipher struct{}
|
||||
|
||||
func TestCreateCookieCipher(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
cookieSecret []byte
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "normal case with base64 encoded secret",
|
||||
cookieSecret: testEncodedCookieSecret,
|
||||
},
|
||||
|
||||
{
|
||||
name: "error when not base64 encoded",
|
||||
cookieSecret: []byte("abcd"),
|
||||
expectedError: true,
|
||||
},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := NewCookieStore("cookieName", CreateCookieCipher(tc.cookieSecret))
|
||||
if !tc.expectedError {
|
||||
testutil.Ok(t, err)
|
||||
} else {
|
||||
testutil.NotEqual(t, err, nil)
|
||||
}
|
||||
})
|
||||
func (a mockCipher) Encrypt(s []byte) ([]byte, error) {
|
||||
if string(s) == "error" {
|
||||
return []byte(""), errors.New("error encrypting")
|
||||
}
|
||||
return []byte("OK"), nil
|
||||
}
|
||||
|
||||
func TestNewSession(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
optFuncs []func(*CookieStore) error
|
||||
expectedError bool
|
||||
expectedSession *CookieStore
|
||||
func (a mockCipher) Decrypt(s []byte) ([]byte, error) {
|
||||
if string(s) == "error" {
|
||||
return []byte(""), errors.New("error encrypting")
|
||||
}
|
||||
return []byte("OK"), nil
|
||||
}
|
||||
func (a mockCipher) Marshal(s interface{}) (string, error) { return "", errors.New("error") }
|
||||
func (a mockCipher) Unmarshal(s string, i interface{}) error {
|
||||
if string(s) == "unmarshal error" || string(s) == "error" {
|
||||
return errors.New("error")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func TestNewCookieStore(t *testing.T) {
|
||||
cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
opts *CookieStoreOptions
|
||||
want *CookieStore
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "default with no opt funcs set",
|
||||
expectedSession: &CookieStore{
|
||||
Name: "cookieName",
|
||||
{"good",
|
||||
&CookieStoreOptions{
|
||||
Name: "_cookie",
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
CookieExpire: 168 * time.Hour,
|
||||
CSRFCookieName: "cookieName_csrf",
|
||||
CookieDomain: "pomerium.io",
|
||||
CookieExpire: 10 * time.Second,
|
||||
CookieCipher: cipher,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "opt func with an error returns an error",
|
||||
optFuncs: []func(*CookieStore) error{func(*CookieStore) error { return fmt.Errorf("error") }},
|
||||
expectedError: true,
|
||||
},
|
||||
{
|
||||
name: "opt func overrides default values",
|
||||
optFuncs: []func(*CookieStore) error{func(s *CookieStore) error {
|
||||
s.CookieExpire = time.Hour
|
||||
return nil
|
||||
}},
|
||||
expectedSession: &CookieStore{
|
||||
Name: "cookieName",
|
||||
&CookieStore{
|
||||
Name: "_cookie",
|
||||
CSRFCookieName: "_cookie_csrf",
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
CookieExpire: time.Hour,
|
||||
CSRFCookieName: "cookieName_csrf",
|
||||
CookieDomain: "pomerium.io",
|
||||
CookieExpire: 10 * time.Second,
|
||||
CookieCipher: cipher,
|
||||
},
|
||||
},
|
||||
false},
|
||||
{"missing name",
|
||||
&CookieStoreOptions{
|
||||
Name: "",
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
CookieDomain: "pomerium.io",
|
||||
CookieExpire: 10 * time.Second,
|
||||
CookieCipher: cipher,
|
||||
},
|
||||
nil,
|
||||
true},
|
||||
{"missing cipher",
|
||||
&CookieStoreOptions{
|
||||
Name: "_pomerium",
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
CookieDomain: "pomerium.io",
|
||||
CookieExpire: 10 * time.Second,
|
||||
CookieCipher: nil,
|
||||
},
|
||||
nil,
|
||||
true},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
session, err := NewCookieStore("cookieName", tc.optFuncs...)
|
||||
if tc.expectedError {
|
||||
testutil.NotEqual(t, err, nil)
|
||||
} else {
|
||||
testutil.Ok(t, err)
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := NewCookieStore(tt.opts)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewCookieStore() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("NewCookieStore() = %#v, want %#v", got, tt.want)
|
||||
}
|
||||
testutil.Equal(t, tc.expectedSession, session)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMakeSessionCookie(t *testing.T) {
|
||||
func TestCookieStore_makeCookie(t *testing.T) {
|
||||
cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
type fields struct {
|
||||
Name string
|
||||
CSRFCookieName string
|
||||
CookieCipher cryptutil.Cipher
|
||||
CookieExpire time.Duration
|
||||
CookieRefresh time.Duration
|
||||
CookieSecure bool
|
||||
CookieHTTPOnly bool
|
||||
CookieDomain string
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
cookieValue := "cookieValue"
|
||||
expiration := time.Hour
|
||||
cookieName := "cookieName"
|
||||
testCases := []struct {
|
||||
name string
|
||||
optFuncs []func(*CookieStore) error
|
||||
expectedCookie *http.Cookie
|
||||
}{
|
||||
{
|
||||
name: "default cookie domain",
|
||||
expectedCookie: &http.Cookie{
|
||||
Name: cookieName,
|
||||
Value: cookieValue,
|
||||
Path: "/",
|
||||
Domain: "www.example.com",
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
Expires: now.Add(expiration),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "custom cookie domain set",
|
||||
optFuncs: []func(*CookieStore) error{
|
||||
func(s *CookieStore) error {
|
||||
s.CookieDomain = "buzzfeed.com"
|
||||
return nil
|
||||
},
|
||||
},
|
||||
expectedCookie: &http.Cookie{
|
||||
Name: cookieName,
|
||||
Value: cookieValue,
|
||||
Path: "/",
|
||||
Domain: "buzzfeed.com",
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
Expires: now.Add(expiration),
|
||||
},
|
||||
},
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
domain string
|
||||
|
||||
cookieName string
|
||||
value string
|
||||
expiration time.Duration
|
||||
want *http.Cookie
|
||||
}{
|
||||
{"good", "http://pomerium.io", "_pomerium", "value", 0, &http.Cookie{Name: "_pomerium", Value: "value", Path: "/", Domain: "pomerium.io", Secure: true, HttpOnly: true}},
|
||||
{"domains with https", "https://pomerium.io", "_pomerium", "value", 0, &http.Cookie{Name: "_pomerium", Value: "value", Path: "/", Domain: "pomerium.io", Secure: true, HttpOnly: true}},
|
||||
{"domain with port", "http://pomerium.io:443", "_pomerium", "value", 0, &http.Cookie{Name: "_pomerium", Value: "value", Path: "/", Domain: "pomerium.io", Secure: true, HttpOnly: true}},
|
||||
{"expiration set", "http://pomerium.io:443", "_pomerium", "value", 10 * time.Second, &http.Cookie{Expires: now.Add(10 * time.Second), Name: "_pomerium", Value: "value", Path: "/", Domain: "pomerium.io", Secure: true, HttpOnly: true}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := httptest.NewRequest("GET", tt.domain, nil)
|
||||
|
||||
s := &CookieStore{
|
||||
Name: "_pomerium",
|
||||
CSRFCookieName: "_pomerium_csrf",
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
CookieDomain: "pomerium.io",
|
||||
CookieExpire: 10 * time.Second,
|
||||
CookieCipher: cipher}
|
||||
|
||||
if got := s.makeCookie(r, tt.cookieName, tt.value, tt.expiration, now); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("CookieStore.makeCookie() = \n%#v, \nwant\n%#v", got, tt.want)
|
||||
}
|
||||
if got := s.makeSessionCookie(r, tt.value, tt.expiration, now); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("CookieStore.makeCookie() = \n%#v, \nwant\n%#v", got, tt.want)
|
||||
}
|
||||
got := s.makeCSRFCookie(r, tt.value, tt.expiration, now)
|
||||
tt.want.Name = "_pomerium_csrf"
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("CookieStore.makeCookie() = \n%#v, \nwant\n%#v", got, tt.want)
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
want := "new-csrf"
|
||||
s.SetCSRF(w, r, want)
|
||||
found := false
|
||||
for _, cookie := range w.Result().Cookies() {
|
||||
if cookie.Name == s.CSRFCookieName && cookie.Value == want {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("SetCSRF failed")
|
||||
}
|
||||
|
||||
w = httptest.NewRecorder()
|
||||
s.ClearCSRF(w, r)
|
||||
for _, cookie := range w.Result().Cookies() {
|
||||
if cookie.Name == s.CSRFCookieName && cookie.Value == want {
|
||||
t.Error("clear csrf failed")
|
||||
break
|
||||
|
||||
}
|
||||
}
|
||||
w = httptest.NewRecorder()
|
||||
want = "new-session"
|
||||
s.setSessionCookie(w, r, want)
|
||||
found = false
|
||||
for _, cookie := range w.Result().Cookies() {
|
||||
if cookie.Name == s.Name && cookie.Value == want {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("SetCSRF failed")
|
||||
}
|
||||
w = httptest.NewRecorder()
|
||||
s.ClearSession(w, r)
|
||||
for _, cookie := range w.Result().Cookies() {
|
||||
if cookie.Name == s.Name && cookie.Value == want {
|
||||
t.Error("clear csrf failed")
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
session, err := NewCookieStore(cookieName, tc.optFuncs...)
|
||||
testutil.Ok(t, err)
|
||||
req := httptest.NewRequest("GET", "http://www.example.com", nil)
|
||||
cookie := session.makeSessionCookie(req, cookieValue, expiration, now)
|
||||
testutil.Equal(t, cookie, tc.expectedCookie)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMakeSessionCSRFCookie(t *testing.T) {
|
||||
now := time.Now()
|
||||
cookieValue := "cookieValue"
|
||||
expiration := time.Hour
|
||||
cookieName := "cookieName"
|
||||
csrfName := "cookieName_csrf"
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
optFuncs []func(*CookieStore) error
|
||||
expectedCookie *http.Cookie
|
||||
}{
|
||||
{
|
||||
name: "default cookie domain",
|
||||
expectedCookie: &http.Cookie{
|
||||
Name: csrfName,
|
||||
Value: cookieValue,
|
||||
Path: "/",
|
||||
Domain: "www.example.com",
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
Expires: now.Add(expiration),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "custom cookie domain set",
|
||||
optFuncs: []func(*CookieStore) error{
|
||||
func(s *CookieStore) error {
|
||||
s.CookieDomain = "buzzfeed.com"
|
||||
return nil
|
||||
},
|
||||
},
|
||||
expectedCookie: &http.Cookie{
|
||||
Name: csrfName,
|
||||
Value: cookieValue,
|
||||
Path: "/",
|
||||
Domain: "buzzfeed.com",
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
Expires: now.Add(expiration),
|
||||
},
|
||||
},
|
||||
func TestCookieStore_SaveSession(t *testing.T) {
|
||||
cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
sessionState *SessionState
|
||||
cipher cryptutil.Cipher
|
||||
wantErr bool
|
||||
wantLoadErr bool
|
||||
}{
|
||||
{"good",
|
||||
&SessionState{
|
||||
AccessToken: "token1234",
|
||||
RefreshToken: "refresh4321",
|
||||
LifetimeDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
|
||||
RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
|
||||
Email: "user@domain.com",
|
||||
User: "user",
|
||||
}, cipher, false, false},
|
||||
{"bad cipher",
|
||||
&SessionState{
|
||||
AccessToken: "token1234",
|
||||
RefreshToken: "refresh4321",
|
||||
LifetimeDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
|
||||
RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
|
||||
Email: "user@domain.com",
|
||||
User: "user",
|
||||
}, mockCipher{}, true, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &CookieStore{
|
||||
Name: "_pomerium",
|
||||
CSRFCookieName: "_pomerium_csrf",
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
CookieDomain: "pomerium.io",
|
||||
CookieExpire: 10 * time.Second,
|
||||
CookieCipher: tt.cipher}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
session, err := NewCookieStore(cookieName, tc.optFuncs...)
|
||||
testutil.Ok(t, err)
|
||||
req := httptest.NewRequest("GET", "http://www.example.com", nil)
|
||||
cookie := session.makeCSRFCookie(req, cookieValue, expiration, now)
|
||||
testutil.Equal(t, tc.expectedCookie, cookie)
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
if err := s.SaveSession(w, r, tt.sessionState); (err != nil) != tt.wantErr {
|
||||
t.Errorf("CookieStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
r = httptest.NewRequest("GET", "/", nil)
|
||||
for _, cookie := range w.Result().Cookies() {
|
||||
t.Log(cookie)
|
||||
r.AddCookie(cookie)
|
||||
}
|
||||
|
||||
state, err := s.LoadSession(r)
|
||||
if (err != nil) != tt.wantLoadErr {
|
||||
t.Errorf("LoadSession() error = %v, wantErr %v", err, tt.wantLoadErr)
|
||||
return
|
||||
}
|
||||
if err == nil && !reflect.DeepEqual(state, tt.sessionState) {
|
||||
t.Errorf("CookieStore.LoadSession() got = \n%v, want \n%v", state, tt.sessionState)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetSessionCookie(t *testing.T) {
|
||||
cookieValue := "cookieValue"
|
||||
cookieName := "cookieName"
|
||||
|
||||
t.Run("set session cookie test", func(t *testing.T) {
|
||||
session, err := NewCookieStore(cookieName)
|
||||
testutil.Ok(t, err)
|
||||
req := httptest.NewRequest("GET", "http://www.example.com", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
session.setSessionCookie(rw, req, cookieValue)
|
||||
var found bool
|
||||
for _, cookie := range rw.Result().Cookies() {
|
||||
if cookie.Name == cookieName {
|
||||
found = true
|
||||
testutil.Equal(t, cookieValue, cookie.Value)
|
||||
testutil.Assert(t, cookie.Expires.After(time.Now()), "cookie expires after now")
|
||||
}
|
||||
}
|
||||
testutil.Assert(t, found, "cookie in header")
|
||||
})
|
||||
}
|
||||
func TestSetCSRFSessionCookie(t *testing.T) {
|
||||
cookieValue := "cookieValue"
|
||||
cookieName := "cookieName"
|
||||
|
||||
t.Run("set csrf cookie test", func(t *testing.T) {
|
||||
session, err := NewCookieStore(cookieName)
|
||||
testutil.Ok(t, err)
|
||||
req := httptest.NewRequest("GET", "http://www.example.com", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
session.SetCSRF(rw, req, cookieValue)
|
||||
var found bool
|
||||
for _, cookie := range rw.Result().Cookies() {
|
||||
if cookie.Name == fmt.Sprintf("%s_csrf", cookieName) {
|
||||
found = true
|
||||
testutil.Equal(t, cookieValue, cookie.Value)
|
||||
testutil.Assert(t, cookie.Expires.After(time.Now()), "cookie expires after now")
|
||||
}
|
||||
}
|
||||
testutil.Assert(t, found, "cookie in header")
|
||||
})
|
||||
}
|
||||
|
||||
func TestClearSessionCookie(t *testing.T) {
|
||||
cookieValue := "cookieValue"
|
||||
cookieName := "cookieName"
|
||||
|
||||
t.Run("set session cookie test", func(t *testing.T) {
|
||||
session, err := NewCookieStore(cookieName)
|
||||
testutil.Ok(t, err)
|
||||
req := httptest.NewRequest("GET", "http://www.example.com", nil)
|
||||
req.AddCookie(session.makeSessionCookie(req, cookieValue, time.Hour, time.Now()))
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
session.ClearSession(rw, req)
|
||||
var found bool
|
||||
for _, cookie := range rw.Result().Cookies() {
|
||||
if cookie.Name == cookieName {
|
||||
found = true
|
||||
testutil.Equal(t, "", cookie.Value)
|
||||
testutil.Assert(t, cookie.Expires.Before(time.Now()), "cookie expires before now")
|
||||
}
|
||||
}
|
||||
testutil.Assert(t, found, "cookie in header")
|
||||
})
|
||||
}
|
||||
|
||||
func TestClearCSRFSessionCookie(t *testing.T) {
|
||||
cookieValue := "cookieValue"
|
||||
cookieName := "cookieName"
|
||||
|
||||
t.Run("clear csrf cookie test", func(t *testing.T) {
|
||||
session, err := NewCookieStore(cookieName)
|
||||
testutil.Ok(t, err)
|
||||
req := httptest.NewRequest("GET", "http://www.example.com", nil)
|
||||
req.AddCookie(session.makeCSRFCookie(req, cookieValue, time.Hour, time.Now()))
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
session.ClearCSRF(rw, req)
|
||||
var found bool
|
||||
for _, cookie := range rw.Result().Cookies() {
|
||||
if cookie.Name == fmt.Sprintf("%s_csrf", cookieName) {
|
||||
found = true
|
||||
testutil.Equal(t, "", cookie.Value)
|
||||
testutil.Assert(t, cookie.Expires.Before(time.Now()), "cookie expires before now")
|
||||
}
|
||||
}
|
||||
testutil.Assert(t, found, "cookie in header")
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoadCookiedSession(t *testing.T) {
|
||||
cookieName := "cookieName"
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
optFuncs []func(*CookieStore) error
|
||||
setupCookies func(*testing.T, *http.Request, *CookieStore, *SessionState)
|
||||
expectedError error
|
||||
sessionState *SessionState
|
||||
func TestMockCSRFStore(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mockCSRF *MockCSRFStore
|
||||
newCSRFValue string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "no cookie set returns an error",
|
||||
setupCookies: func(*testing.T, *http.Request, *CookieStore, *SessionState) {},
|
||||
expectedError: http.ErrNoCookie,
|
||||
},
|
||||
{
|
||||
name: "cookie set with cipher set",
|
||||
optFuncs: []func(*CookieStore) error{CreateCookieCipher(testEncodedCookieSecret)},
|
||||
setupCookies: func(t *testing.T, req *http.Request, s *CookieStore, sessionState *SessionState) {
|
||||
value, err := MarshalSession(sessionState, s.CookieCipher)
|
||||
testutil.Ok(t, err)
|
||||
req.AddCookie(s.makeSessionCookie(req, value, time.Hour, time.Now()))
|
||||
},
|
||||
sessionState: &SessionState{
|
||||
Email: "example@email.com",
|
||||
RefreshToken: "abccdddd",
|
||||
AccessToken: "access",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "cookie set with invalid value cipher set",
|
||||
optFuncs: []func(*CookieStore) error{CreateCookieCipher(testEncodedCookieSecret)},
|
||||
setupCookies: func(t *testing.T, req *http.Request, s *CookieStore, sessionState *SessionState) {
|
||||
value := "574b776a7c934d6b9fc42ec63a389f79"
|
||||
req.AddCookie(s.makeSessionCookie(req, value, time.Hour, time.Now()))
|
||||
},
|
||||
expectedError: ErrInvalidSession,
|
||||
},
|
||||
{"basic",
|
||||
&MockCSRFStore{
|
||||
ResponseCSRF: "ok",
|
||||
Cookie: &http.Cookie{Name: "hi"}},
|
||||
"newcsrf",
|
||||
false},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
session, err := NewCookieStore(cookieName, tc.optFuncs...)
|
||||
testutil.Ok(t, err)
|
||||
req := httptest.NewRequest("GET", "https://www.example.com", nil)
|
||||
tc.setupCookies(t, req, session, tc.sessionState)
|
||||
s, err := session.LoadSession(req)
|
||||
|
||||
testutil.Equal(t, tc.expectedError, err)
|
||||
testutil.Equal(t, tc.sessionState, s)
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ms := tt.mockCSRF
|
||||
ms.SetCSRF(nil, nil, tt.newCSRFValue)
|
||||
ms.ClearCSRF(nil, nil)
|
||||
got, err := ms.GetCSRF(nil)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("MockCSRFStore.GetCSRF() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.mockCSRF.Cookie) {
|
||||
t.Errorf("MockCSRFStore.GetCSRF() = %v, want %v", got, tt.mockCSRF.Cookie)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMockSessionStore(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mockCSRF *MockSessionStore
|
||||
saveSession *SessionState
|
||||
wantLoadErr bool
|
||||
wantSaveErr bool
|
||||
}{
|
||||
{"basic",
|
||||
&MockSessionStore{
|
||||
ResponseSession: "test",
|
||||
Session: &SessionState{AccessToken: "AccessToken"},
|
||||
SaveError: nil,
|
||||
LoadError: nil,
|
||||
},
|
||||
&SessionState{AccessToken: "AccessToken"},
|
||||
false,
|
||||
false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ms := tt.mockCSRF
|
||||
|
||||
err := ms.SaveSession(nil, nil, tt.saveSession)
|
||||
if (err != nil) != tt.wantSaveErr {
|
||||
t.Errorf("MockCSRFStore.GetCSRF() error = %v, wantSaveErr %v", err, tt.wantSaveErr)
|
||||
return
|
||||
}
|
||||
got, err := ms.LoadSession(nil)
|
||||
if (err != nil) != tt.wantLoadErr {
|
||||
t.Errorf("MockCSRFStore.GetCSRF() error = %v, wantLoadErr %v", err, tt.wantLoadErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.mockCSRF.Session) {
|
||||
t.Errorf("MockCSRFStore.GetCSRF() = %v, want %v", got, tt.mockCSRF.Session)
|
||||
}
|
||||
ms.ClearSession(nil, nil)
|
||||
if ms.ResponseSession != "" {
|
||||
t.Errorf("ResponseSession not empty! %s", ms.ResponseSession)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -35,7 +35,7 @@ type MockSessionStore struct {
|
|||
}
|
||||
|
||||
// ClearSession clears the ResponseSession
|
||||
func (ms MockSessionStore) ClearSession(http.ResponseWriter, *http.Request) {
|
||||
func (ms *MockSessionStore) ClearSession(http.ResponseWriter, *http.Request) {
|
||||
ms.ResponseSession = ""
|
||||
}
|
||||
|
||||
|
|
|
@ -20,11 +20,9 @@ type SessionState struct {
|
|||
|
||||
RefreshDeadline time.Time `json:"refresh_deadline"`
|
||||
LifetimeDeadline time.Time `json:"lifetime_deadline"`
|
||||
ValidDeadline time.Time `json:"valid_deadline"`
|
||||
GracePeriodStart time.Time `json:"grace_period_start"`
|
||||
|
||||
Email string `json:"email"`
|
||||
User string `json:"user"` // 'sub' in jwt parlance
|
||||
User string `json:"user"` // 'sub' in jwt
|
||||
Groups []string `json:"groups"`
|
||||
}
|
||||
|
||||
|
@ -38,11 +36,6 @@ func (s *SessionState) RefreshPeriodExpired() bool {
|
|||
return isExpired(s.RefreshDeadline)
|
||||
}
|
||||
|
||||
// ValidationPeriodExpired returns true if the validation period has expired
|
||||
func (s *SessionState) ValidationPeriodExpired() bool {
|
||||
return isExpired(s.ValidDeadline)
|
||||
}
|
||||
|
||||
func isExpired(t time.Time) bool {
|
||||
return t.Before(time.Now())
|
||||
}
|
||||
|
@ -64,7 +57,7 @@ func UnmarshalSession(value string, c cryptutil.Cipher) (*SessionState, error) {
|
|||
return s, nil
|
||||
}
|
||||
|
||||
// ExtendDeadline returns the time extended by a given duration
|
||||
// ExtendDeadline returns the time extended by a given duration, truncated by second
|
||||
func ExtendDeadline(ttl time.Duration) time.Time {
|
||||
return time.Now().Add(ttl).Truncate(time.Second)
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||
package sessions
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
@ -16,15 +16,12 @@ func TestSessionStateSerialization(t *testing.T) {
|
|||
}
|
||||
|
||||
want := &SessionState{
|
||||
AccessToken: "token1234",
|
||||
RefreshToken: "refresh4321",
|
||||
|
||||
AccessToken: "token1234",
|
||||
RefreshToken: "refresh4321",
|
||||
LifetimeDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
|
||||
RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
|
||||
ValidDeadline: time.Now().Add(1 * time.Minute).Truncate(time.Second).UTC(),
|
||||
|
||||
Email: "user@domain.com",
|
||||
User: "user",
|
||||
Email: "user@domain.com",
|
||||
User: "user",
|
||||
}
|
||||
|
||||
ciphertext, err := MarshalSession(want, c)
|
||||
|
@ -46,26 +43,40 @@ func TestSessionStateSerialization(t *testing.T) {
|
|||
|
||||
func TestSessionStateExpirations(t *testing.T) {
|
||||
session := &SessionState{
|
||||
AccessToken: "token1234",
|
||||
RefreshToken: "refresh4321",
|
||||
|
||||
AccessToken: "token1234",
|
||||
RefreshToken: "refresh4321",
|
||||
LifetimeDeadline: time.Now().Add(-1 * time.Hour),
|
||||
RefreshDeadline: time.Now().Add(-1 * time.Hour),
|
||||
ValidDeadline: time.Now().Add(-1 * time.Minute),
|
||||
|
||||
Email: "user@domain.com",
|
||||
User: "user",
|
||||
Email: "user@domain.com",
|
||||
User: "user",
|
||||
}
|
||||
|
||||
if !session.LifetimePeriodExpired() {
|
||||
t.Errorf("expcted lifetime period to be expired")
|
||||
t.Errorf("expected lifetime period to be expired")
|
||||
}
|
||||
|
||||
if !session.RefreshPeriodExpired() {
|
||||
t.Errorf("expcted lifetime period to be expired")
|
||||
t.Errorf("expected lifetime period to be expired")
|
||||
}
|
||||
|
||||
if !session.ValidationPeriodExpired() {
|
||||
t.Errorf("expcted lifetime period to be expired")
|
||||
}
|
||||
|
||||
func TestExtendDeadline(t *testing.T) {
|
||||
// tons of wiggle room here
|
||||
now := time.Now().Truncate(time.Second)
|
||||
tests := []struct {
|
||||
name string
|
||||
ttl time.Duration
|
||||
want time.Time
|
||||
}{
|
||||
{"Add a few ms", time.Millisecond * 10, now.Truncate(time.Second)},
|
||||
{"Add a few microsecs", time.Microsecond * 10, now.Truncate(time.Second)},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := ExtendDeadline(tt.ttl); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("ExtendDeadline() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,75 +0,0 @@
|
|||
// Original Copyright 2013 The Go Authors. All rights reserved.
|
||||
//
|
||||
// Modified by BuzzFeed to return duplicate counts.
|
||||
//
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package singleflight provides a duplicate function call suppression mechanism.
|
||||
package singleflight // import "github.com/pomerium/pomerium/internal/singleflight"
|
||||
|
||||
import "sync"
|
||||
|
||||
// call is an in-flight or completed singleflight.Do call
|
||||
type call struct {
|
||||
wg sync.WaitGroup
|
||||
|
||||
// These fields are written once before the WaitGroup is done
|
||||
// and are only read after the WaitGroup is done.
|
||||
val interface{}
|
||||
err error
|
||||
|
||||
// These fields are read and written with the singleflight
|
||||
// mutex held before the WaitGroup is done, and are read but
|
||||
// not written after the WaitGroup is done.
|
||||
dups int
|
||||
}
|
||||
|
||||
// Group represents a class of work and forms a namespace in
|
||||
// which units of work can be executed with duplicate suppression.
|
||||
type Group struct {
|
||||
mu sync.Mutex // protects m
|
||||
m map[string]*call // lazily initialized
|
||||
}
|
||||
|
||||
// Result holds the results of Do, so they can be passed
|
||||
// on a channel.
|
||||
type Result struct {
|
||||
Val interface{}
|
||||
Err error
|
||||
Count bool
|
||||
}
|
||||
|
||||
// Do executes and returns the results of the given function, making
|
||||
// sure that only one execution is in-flight for a given key at a
|
||||
// time. If a duplicate comes in, the duplicate caller waits for the
|
||||
// original to complete and receives the same results.
|
||||
// The return value of Count indicates how many tiems v was given to multiple callers.
|
||||
// Count will be zero for requests are shared and only be non-zero for the originating request.
|
||||
func (g *Group) Do(key string, fn func() (interface{}, error)) (v interface{}, count int, err error) {
|
||||
g.mu.Lock()
|
||||
|
||||
if g.m == nil {
|
||||
g.m = make(map[string]*call)
|
||||
}
|
||||
if c, ok := g.m[key]; ok {
|
||||
c.dups++
|
||||
g.mu.Unlock()
|
||||
c.wg.Wait()
|
||||
return c.val, 0, c.err
|
||||
}
|
||||
c := new(call)
|
||||
c.wg.Add(1)
|
||||
g.m[key] = c
|
||||
|
||||
g.mu.Unlock()
|
||||
|
||||
c.val, c.err = fn()
|
||||
c.wg.Done()
|
||||
|
||||
g.mu.Lock()
|
||||
delete(g.m, key)
|
||||
g.mu.Unlock()
|
||||
|
||||
return c.val, c.dups, c.err
|
||||
}
|
|
@ -1,87 +0,0 @@
|
|||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package singleflight // import "github.com/pomerium/pomerium/internal/singleflight"
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestDo(t *testing.T) {
|
||||
var g Group
|
||||
v, _, err := g.Do("key", func() (interface{}, error) {
|
||||
return "bar", nil
|
||||
})
|
||||
if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want {
|
||||
t.Errorf("Do = %v; want %v", got, want)
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("Do error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDoErr(t *testing.T) {
|
||||
var g Group
|
||||
someErr := errors.New("Some error")
|
||||
v, _, err := g.Do("key", func() (interface{}, error) {
|
||||
return nil, someErr
|
||||
})
|
||||
if err != someErr {
|
||||
t.Errorf("Do error = %v; want someErr %v", err, someErr)
|
||||
}
|
||||
if v != nil {
|
||||
t.Errorf("unexpected non-nil value %#v", v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDoDupSuppress(t *testing.T) {
|
||||
var g Group
|
||||
var wg1, wg2 sync.WaitGroup
|
||||
c := make(chan string, 1)
|
||||
var calls int32
|
||||
fn := func() (interface{}, error) {
|
||||
if atomic.AddInt32(&calls, 1) == 1 {
|
||||
// First invocation.
|
||||
wg1.Done()
|
||||
}
|
||||
v := <-c
|
||||
c <- v // pump; make available for any future calls
|
||||
|
||||
time.Sleep(10 * time.Millisecond) // let more goroutines enter Do
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
const n = 10
|
||||
wg1.Add(1)
|
||||
for i := 0; i < n; i++ {
|
||||
wg1.Add(1)
|
||||
wg2.Add(1)
|
||||
go func() {
|
||||
defer wg2.Done()
|
||||
wg1.Done()
|
||||
v, _, err := g.Do("key", fn)
|
||||
if err != nil {
|
||||
t.Errorf("Do error: %v", err)
|
||||
return
|
||||
}
|
||||
if s, _ := v.(string); s != "bar" {
|
||||
t.Errorf("Do = %T %v; want %q", v, v, "bar")
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg1.Wait()
|
||||
// At least one goroutine is in fn now and all of them have at
|
||||
// least reached the line before the Do.
|
||||
c <- "bar"
|
||||
wg2.Wait()
|
||||
if got := atomic.LoadInt32(&calls); got <= 0 || got >= n {
|
||||
t.Errorf("number of calls = %d; want over 0 and less than %d", got, n)
|
||||
}
|
||||
}
|
|
@ -1,46 +0,0 @@
|
|||
package testutil // import "github.com/pomerium/pomerium/internal/testutil"
|
||||
|
||||
// testing util functions copied from https://github.com/benbjohnson/testing
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Assert fails the test if the condition is false.
|
||||
func Assert(tb testing.TB, condition bool, msg string, v ...interface{}) {
|
||||
if !condition {
|
||||
_, file, line, _ := runtime.Caller(1)
|
||||
fmt.Printf("\033[31m%s:%d: "+msg+"\033[39m\n\n", append([]interface{}{filepath.Base(file), line}, v...)...)
|
||||
tb.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
// Ok fails the test if an err is not nil.
|
||||
func Ok(tb testing.TB, err error) {
|
||||
if err != nil {
|
||||
_, file, line, _ := runtime.Caller(1)
|
||||
fmt.Printf("\033[31m%s:%d: unexpected error: %s\033[39m\n\n", filepath.Base(file), line, err.Error())
|
||||
tb.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
// Equal fails the test if exp is not equal to act.
|
||||
func Equal(tb testing.TB, exp, act interface{}) {
|
||||
if !reflect.DeepEqual(exp, act) {
|
||||
_, file, line, _ := runtime.Caller(1)
|
||||
fmt.Printf("\033[31m%s:%d:\n\n\texp: %#v\n\n\tgot: %#v\033[39m\n\n", filepath.Base(file), line, exp, act)
|
||||
tb.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
// NotEqual fails the test if exp is equal to act.
|
||||
func NotEqual(tb testing.TB, exp, act interface{}) {
|
||||
if reflect.DeepEqual(exp, act) {
|
||||
_, file, line, _ := runtime.Caller(1)
|
||||
fmt.Printf("\033[31m%s:%d:\n\n\texp: %#v\n\n\tgot: %#v\033[39m\n\n", filepath.Base(file), line, exp, act)
|
||||
tb.FailNow()
|
||||
}
|
||||
}
|
|
@ -25,7 +25,7 @@ type Options struct {
|
|||
// InternalAddr is the internal (behind the ingress) address to use when making an
|
||||
// authentication connection. If empty, Addr is used.
|
||||
InternalAddr string
|
||||
// OverrideServerName overrides the server name used to verify the hostname on the
|
||||
// OverideCertificateName overrides the server name used to verify the hostname on the
|
||||
// returned certificates from the server. gRPC internals also use it to override the virtual
|
||||
// hosting name if it is set.
|
||||
OverideCertificateName string
|
||||
|
|
|
@ -203,16 +203,15 @@ func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// We store the session in a cookie and redirect the user back to the application
|
||||
err = p.sessionStore.SaveSession(w, r, &sessions.SessionState{
|
||||
AccessToken: rr.AccessToken,
|
||||
RefreshToken: rr.RefreshToken,
|
||||
IDToken: rr.IDToken,
|
||||
User: rr.User,
|
||||
Email: rr.Email,
|
||||
RefreshDeadline: (rr.Expiry).Truncate(time.Second),
|
||||
LifetimeDeadline: extendDeadline(p.CookieLifetimeTTL),
|
||||
ValidDeadline: extendDeadline(p.CookieExpire),
|
||||
})
|
||||
err = p.sessionStore.SaveSession(w, r,
|
||||
&sessions.SessionState{
|
||||
AccessToken: rr.AccessToken,
|
||||
RefreshToken: rr.RefreshToken,
|
||||
IDToken: rr.IDToken,
|
||||
User: rr.User,
|
||||
Email: rr.Email,
|
||||
RefreshDeadline: (rr.Expiry).Truncate(time.Second),
|
||||
})
|
||||
if err != nil {
|
||||
log.FromRequest(r).Error().Msg("error saving session")
|
||||
httputil.ErrorResponse(w, r, "Error saving session", http.StatusInternalServerError)
|
||||
|
@ -250,9 +249,7 @@ func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
// ! ! !
|
||||
// todo(bdd): ! Authorization service goes here !
|
||||
// ! ! !
|
||||
// todo(bdd): add authorization service validation
|
||||
|
||||
// We have validated the users request and now proxy their request to the provided upstream.
|
||||
route, ok := p.router(r)
|
||||
|
@ -278,14 +275,10 @@ func (p *Proxy) Authenticate(w http.ResponseWriter, r *http.Request) (err error)
|
|||
return err
|
||||
}
|
||||
|
||||
if session.LifetimePeriodExpired() {
|
||||
log.FromRequest(r).Info().Msg("proxy: lifetime expired")
|
||||
return sessions.ErrLifetimeExpired
|
||||
}
|
||||
if session.RefreshPeriodExpired() {
|
||||
// AccessToken's usually expire after 60 or so minutes. If offline_access scope is set, a
|
||||
// refresh token (which doesn't change) can be used to request a new access-token. If access
|
||||
// is revoked by identity provider, or no refresh token is set request will return an error
|
||||
// is revoked by identity provider, or no refresh token is set, request will return an error
|
||||
accessToken, expiry, err := p.AuthenticateClient.Refresh(session.RefreshToken)
|
||||
if err != nil {
|
||||
log.FromRequest(r).Warn().
|
||||
|
|
|
@ -272,7 +272,7 @@ func TestProxy_OAuthCallback(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
proxy.sessionStore = tt.session
|
||||
proxy.sessionStore = &tt.session
|
||||
proxy.csrfStore = tt.csrf
|
||||
proxy.AuthenticateClient = tt.authenticator
|
||||
proxy.cipher = mockCipher{}
|
||||
|
@ -352,12 +352,6 @@ func TestProxy_Proxy(t *testing.T) {
|
|||
RefreshToken: "RefreshToken",
|
||||
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||
}
|
||||
expiredLifetime := &sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
LifetimeDeadline: time.Now().Add(-10 * time.Second),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
|
@ -368,11 +362,10 @@ func TestProxy_Proxy(t *testing.T) {
|
|||
wantStatus int
|
||||
}{
|
||||
// weirdly, we want 503 here because that means proxy is trying to route a domain (example.com) that we dont control. Weird. I know.
|
||||
{"good", "https://corp.example.com/test", sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, http.StatusServiceUnavailable},
|
||||
{"unexpected error", "https://corp.example.com/test", sessions.MockSessionStore{LoadError: errors.New("ok")}, authenticator.MockAuthenticate{}, http.StatusInternalServerError},
|
||||
{"good", "https://corp.example.com/test", &sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, http.StatusServiceUnavailable},
|
||||
{"unexpected error", "https://corp.example.com/test", &sessions.MockSessionStore{LoadError: errors.New("ok")}, authenticator.MockAuthenticate{}, http.StatusInternalServerError},
|
||||
// redirect to start auth process
|
||||
{"expired lifetime", "https://corp.example.com/test", sessions.MockSessionStore{Session: expiredLifetime}, authenticator.MockAuthenticate{}, http.StatusFound},
|
||||
{"unknown host", "https://notcorp.example.com/test", sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, http.StatusNotFound},
|
||||
{"unknown host", "https://notcorp.example.com/test", &sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, http.StatusNotFound},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -402,13 +395,8 @@ func TestProxy_Authenticate(t *testing.T) {
|
|||
RefreshToken: "RefreshToken",
|
||||
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||
}
|
||||
expiredLifetime := &sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
LifetimeDeadline: time.Now().Add(-10 * time.Second),
|
||||
}
|
||||
|
||||
expiredDeadline := &sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
|
@ -426,25 +414,21 @@ func TestProxy_Authenticate(t *testing.T) {
|
|||
{"cannot save session",
|
||||
"https://corp.example.com/",
|
||||
map[string]string{"corp.example.com": "example.com"},
|
||||
sessions.MockSessionStore{Session: goodSession, SaveError: errors.New("error")},
|
||||
&sessions.MockSessionStore{Session: goodSession, SaveError: errors.New("error")},
|
||||
authenticator.MockAuthenticate{}, true},
|
||||
|
||||
{"cannot load session",
|
||||
"https://corp.example.com/",
|
||||
map[string]string{"corp.example.com": "example.com"},
|
||||
sessions.MockSessionStore{LoadError: errors.New("error")}, authenticator.MockAuthenticate{}, true},
|
||||
{"expired lifetime",
|
||||
"https://corp.example.com/",
|
||||
map[string]string{"corp.example.com": "example.com"},
|
||||
sessions.MockSessionStore{Session: expiredLifetime}, authenticator.MockAuthenticate{}, true},
|
||||
&sessions.MockSessionStore{LoadError: errors.New("error")}, authenticator.MockAuthenticate{}, true},
|
||||
{"expired session",
|
||||
"https://corp.example.com/",
|
||||
map[string]string{"corp.example.com": "example.com"},
|
||||
sessions.MockSessionStore{Session: expiredDeadline}, authenticator.MockAuthenticate{}, false},
|
||||
&sessions.MockSessionStore{Session: expiredDeadline}, authenticator.MockAuthenticate{}, false},
|
||||
{"bad refresh authenticator",
|
||||
"https://corp.example.com/",
|
||||
map[string]string{"corp.example.com": "example.com"},
|
||||
sessions.MockSessionStore{
|
||||
&sessions.MockSessionStore{
|
||||
Session: expiredDeadline,
|
||||
},
|
||||
authenticator.MockAuthenticate{RefreshError: errors.New("error")},
|
||||
|
@ -453,7 +437,7 @@ func TestProxy_Authenticate(t *testing.T) {
|
|||
{"good",
|
||||
"https://corp.example.com/",
|
||||
map[string]string{"corp.example.com": "example.com"},
|
||||
sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, false},
|
||||
&sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
|
|
@ -133,12 +133,9 @@ type Proxy struct {
|
|||
AuthenticateClient authenticator.Authenticator
|
||||
|
||||
// session
|
||||
CookieExpire time.Duration
|
||||
CookieRefresh time.Duration
|
||||
CookieLifetimeTTL time.Duration
|
||||
cipher cryptutil.Cipher
|
||||
csrfStore sessions.CSRFStore
|
||||
sessionStore sessions.SessionStore
|
||||
cipher cryptutil.Cipher
|
||||
csrfStore sessions.CSRFStore
|
||||
sessionStore sessions.SessionStore
|
||||
|
||||
redirectURL *url.URL
|
||||
templates *template.Template
|
||||
|
@ -163,13 +160,14 @@ func New(opts *Options) (*Proxy, error) {
|
|||
return nil, fmt.Errorf("cookie-secret error: %s", err.Error())
|
||||
}
|
||||
|
||||
cookieStore, err := sessions.NewCookieStore(opts.CookieName,
|
||||
sessions.CreateCookieCipher(decodedSecret),
|
||||
func(c *sessions.CookieStore) error {
|
||||
c.CookieDomain = opts.CookieDomain
|
||||
c.CookieHTTPOnly = opts.CookieHTTPOnly
|
||||
c.CookieExpire = opts.CookieExpire
|
||||
return nil
|
||||
cookieStore, err := sessions.NewCookieStore(
|
||||
&sessions.CookieStoreOptions{
|
||||
Name: opts.CookieName,
|
||||
CookieDomain: opts.CookieDomain,
|
||||
CookieSecure: opts.CookieSecure,
|
||||
CookieHTTPOnly: opts.CookieHTTPOnly,
|
||||
CookieExpire: opts.CookieExpire,
|
||||
CookieCipher: cipher,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
|
@ -181,14 +179,12 @@ func New(opts *Options) (*Proxy, error) {
|
|||
// services
|
||||
AuthenticateURL: opts.AuthenticateURL,
|
||||
// session state
|
||||
cipher: cipher,
|
||||
csrfStore: cookieStore,
|
||||
sessionStore: cookieStore,
|
||||
SharedKey: opts.SharedKey,
|
||||
redirectURL: &url.URL{Path: "/.pomerium/callback"},
|
||||
templates: templates.New(),
|
||||
CookieExpire: opts.CookieExpire,
|
||||
CookieLifetimeTTL: opts.CookieLifetimeTTL,
|
||||
cipher: cipher,
|
||||
csrfStore: cookieStore,
|
||||
sessionStore: cookieStore,
|
||||
SharedKey: opts.SharedKey,
|
||||
redirectURL: &url.URL{Path: "/.pomerium/callback"},
|
||||
templates: templates.New(),
|
||||
}
|
||||
|
||||
for from, to := range opts.Routes {
|
||||
|
@ -200,7 +196,7 @@ func New(opts *Options) (*Proxy, error) {
|
|||
return nil, err
|
||||
}
|
||||
p.Handle(fromURL.Host, handler)
|
||||
log.Info().Str("from", fromURL.Host).Str("to", toURL.String()).Msg("proxy.New: new route")
|
||||
log.Info().Str("from", fromURL.Host).Str("to", toURL.String()).Msg("proxy: new route")
|
||||
}
|
||||
|
||||
p.AuthenticateClient, err = authenticator.New(
|
||||
|
|
|
@ -139,6 +139,7 @@ func testOptions() *Options {
|
|||
AuthenticateURL: authurl,
|
||||
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
|
||||
CookieSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=",
|
||||
CookieName: "pomerium",
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue