authenticate: make service http only

- Rename SessionState to State to avoid stutter.
- Simplified option validation to use a wrapper function for base64 secrets.
- Removed authenticates grpc code.
- Abstracted logic to load and validate a user's authenticate session.
- Removed instances of url.Parse in favor of urlutil's version.
- proxy: replaces grpc refresh logic with forced deadline advancement.
- internal/sessions: remove rest store; parse authorize header as part of session store.
- proxy: refactor request signer
- sessions: remove extend deadline (fixes #294)
- remove AuthenticateInternalAddr
- remove AuthenticateInternalAddrString
- omit type tag.Key from declaration of vars TagKey* it will be inferred
  from the right-hand side
- remove compatibility package xerrors
- use cloned http.DefaultTransport as base transport
This commit is contained in:
Bobby DeSimone 2019-08-29 22:12:29 -07:00
parent bc72d08ad4
commit 380d314404
No known key found for this signature in database
GPG key ID: AEE4CF12FE86D07E
53 changed files with 718 additions and 2280 deletions

View file

@ -45,7 +45,6 @@ tag: ## Create a new git tag to prepare to build a release
.PHONY: build .PHONY: build
build: ## Builds dynamic executables and/or packages. build: ## Builds dynamic executables and/or packages.
@echo "==> $@" @echo "==> $@"
@echo Untracked changes? dirty? $(BUILDMETA) files? $(GITUNTRACKEDCHANGES)
@CGO_ENABLED=0 GO111MODULE=on go build -tags "$(BUILDTAGS)" ${GO_LDFLAGS} -o $(BINDIR)/$(NAME) ./cmd/"$(NAME)" @CGO_ENABLED=0 GO111MODULE=on go build -tags "$(BUILDTAGS)" ${GO_LDFLAGS} -o $(BINDIR)/$(NAME) ./cmd/"$(NAME)"
.PHONY: lint .PHONY: lint

View file

@ -15,36 +15,31 @@ import (
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/urlutil"
) )
// ValidateOptions checks to see if configuration values are valid for the authenticate service. // ValidateOptions checks that configuration are complete and valid.
// The checks do not modify the internal state of the Option structure. Returns // Returns on first error found.
// on first error found.
func ValidateOptions(o config.Options) error { func ValidateOptions(o config.Options) error {
if _, err := cryptutil.NewCipherFromBase64(o.SharedKey); err != nil {
return fmt.Errorf("authenticate: 'SHARED_SECRET' invalid: %v", err)
}
if _, err := cryptutil.NewCipherFromBase64(o.CookieSecret); err != nil {
return fmt.Errorf("authenticate: 'COOKIE_SECRET' invalid %v", err)
}
if o.AuthenticateURL == nil { if o.AuthenticateURL == nil {
return errors.New("authenticate: missing setting: authenticate-service-url") return errors.New("authenticate: 'AUTHENTICATE_SERVICE_URL' is required")
} }
if _, err := urlutil.ParseAndValidateURL(o.AuthenticateURL.String()); err != nil { if _, err := urlutil.ParseAndValidateURL(o.AuthenticateURL.String()); err != nil {
return fmt.Errorf("authenticate: error parsing authenticate url: %v", err) return fmt.Errorf("authenticate: couldn't parse 'AUTHENTICATE_SERVICE_URL': %v", err)
} }
if o.ClientID == "" { if o.ClientID == "" {
return errors.New("authenticate: 'IDP_CLIENT_ID' missing") return errors.New("authenticate: 'IDP_CLIENT_ID' is required")
} }
if o.ClientSecret == "" { if o.ClientSecret == "" {
return errors.New("authenticate: 'IDP_CLIENT_SECRET' missing") return errors.New("authenticate: 'IDP_CLIENT_SECRET' is required")
}
if o.SharedKey == "" {
return errors.New("authenticate: 'SHARED_SECRET' missing")
}
decodedCookieSecret, err := base64.StdEncoding.DecodeString(o.CookieSecret)
if err != nil {
return fmt.Errorf("authenticate: 'COOKIE_SECRET' must be base64 encoded: %v", err)
}
if len(decodedCookieSecret) != 32 {
return fmt.Errorf("authenticate: 'COOKIE_SECRET' %s be 32; got %d", o.CookieSecret, len(decodedCookieSecret))
} }
return nil return nil
} }
// Authenticate validates a user's identity // Authenticate contains data required to run the authenticate service.
type Authenticate struct { type Authenticate struct {
SharedKey string SharedKey string
RedirectURL *url.URL RedirectURL *url.URL
@ -52,12 +47,11 @@ type Authenticate struct {
templates *template.Template templates *template.Template
csrfStore sessions.CSRFStore csrfStore sessions.CSRFStore
sessionStore sessions.SessionStore sessionStore sessions.SessionStore
restStore sessions.SessionStore
cipher cryptutil.Cipher cipher cryptutil.Cipher
provider identity.Authenticator provider identity.Authenticator
} }
// New validates and creates a new authenticate service from a set of Options // New validates and creates a new authenticate service from a set of Options.
func New(opts config.Options) (*Authenticate, error) { func New(opts config.Options) (*Authenticate, error) {
if err := ValidateOptions(opts); err != nil { if err := ValidateOptions(opts); err != nil {
return nil, err return nil, err
@ -95,17 +89,13 @@ func New(opts config.Options) (*Authenticate, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
restStore, err := sessions.NewRestStore(&sessions.RestStoreOptions{Cipher: cipher})
if err != nil {
return nil, err
}
return &Authenticate{ return &Authenticate{
SharedKey: opts.SharedKey, SharedKey: opts.SharedKey,
RedirectURL: redirectURL, RedirectURL: redirectURL,
templates: templates.New(), templates: templates.New(),
csrfStore: cookieStore, csrfStore: cookieStore,
sessionStore: cookieStore, sessionStore: cookieStore,
restStore: restStore,
cipher: cipher, cipher: cipher,
provider: provider, provider: provider,
}, nil }, nil

View file

@ -1,62 +0,0 @@
//go:generate protoc -I ../proto/authenticate --go_out=plugins=grpc:../proto/authenticate ../proto/authenticate/authenticate.proto
package authenticate // import "github.com/pomerium/pomerium/authenticate"
import (
"context"
"fmt"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/telemetry/trace"
pb "github.com/pomerium/pomerium/proto/authenticate"
)
// Authenticate takes an encrypted code, and returns the authentication result.
func (p *Authenticate) Authenticate(ctx context.Context, in *pb.AuthenticateRequest) (*pb.Session, error) {
_, span := trace.StartSpan(ctx, "authenticate.grpc.Validate")
defer span.End()
session, err := sessions.UnmarshalSession(in.Code, p.cipher)
if err != nil {
return nil, fmt.Errorf("authenticate/grpc: authenticate %v", err)
}
newSessionProto, err := pb.ProtoFromSession(session)
if err != nil {
return nil, err
}
return newSessionProto, nil
}
// Validate locally validates a JWT id_token; does NOT do nonce or revokation validation.
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
func (p *Authenticate) Validate(ctx context.Context, in *pb.ValidateRequest) (*pb.ValidateReply, error) {
ctx, span := trace.StartSpan(ctx, "authenticate.grpc.Validate")
defer span.End()
isValid, err := p.provider.Validate(ctx, in.IdToken)
if err != nil {
return &pb.ValidateReply{IsValid: false}, fmt.Errorf("authenticate/grpc: validate %v", err)
}
return &pb.ValidateReply{IsValid: isValid}, nil
}
// Refresh renews a user's session checks if the session has been revoked using an access token
// without reprompting the user.
func (p *Authenticate) Refresh(ctx context.Context, in *pb.Session) (*pb.Session, error) {
ctx, span := trace.StartSpan(ctx, "authenticate.grpc.Refresh")
defer span.End()
if in == nil {
return nil, fmt.Errorf("authenticate/grpc: session cannot be nil")
}
oldSession, err := pb.SessionFromProto(in)
if err != nil {
return nil, err
}
newSession, err := p.provider.Refresh(ctx, oldSession)
if err != nil {
return nil, fmt.Errorf("authenticate/grpc: refresh failed %v", err)
}
newSessionProto, err := pb.ProtoFromSession(newSession)
if err != nil {
return nil, err
}
return newSessionProto, nil
}

View file

@ -1,153 +0,0 @@
package authenticate
import (
"context"
"errors"
"reflect"
"testing"
"time"
"github.com/pomerium/pomerium/internal/identity"
"github.com/golang/protobuf/ptypes"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/sessions"
pb "github.com/pomerium/pomerium/proto/authenticate"
)
var fixedDate = time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC)
func TestAuthenticate_Validate(t *testing.T) {
tests := []struct {
name string
idToken string
mp *identity.MockProvider
want bool
wantErr bool
}{
{"good", "example", &identity.MockProvider{}, false, false},
{"error", "error", &identity.MockProvider{ValidateError: errors.New("err")}, false, true},
{"not error", "not error", &identity.MockProvider{ValidateError: nil}, false, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &Authenticate{provider: tt.mp}
got, err := p.Validate(context.Background(), &pb.ValidateRequest{IdToken: tt.idToken})
if (err != nil) != tt.wantErr {
t.Errorf("Authenticate.Validate() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got.IsValid, tt.want) {
t.Errorf("Authenticate.Validate() = %v, want %v", got.IsValid, tt.want)
}
})
}
}
func TestAuthenticate_Refresh(t *testing.T) {
fixedProtoTime, err := ptypes.TimestampProto(fixedDate)
if err != nil {
t.Fatal("failed to parse timestamp")
}
tests := []struct {
name string
mock *identity.MockProvider
originalSession *pb.Session
want *pb.Session
wantErr bool
}{
{"good",
&identity.MockProvider{
RefreshResponse: &sessions.SessionState{
AccessToken: "updated",
RefreshDeadline: fixedDate,
}},
&pb.Session{
AccessToken: "original",
RefreshDeadline: fixedProtoTime,
},
&pb.Session{
AccessToken: "updated",
RefreshDeadline: fixedProtoTime,
},
false},
{"test error", &identity.MockProvider{RefreshError: errors.New("hi")}, &pb.Session{RefreshToken: "refresh token", RefreshDeadline: fixedProtoTime}, nil, true},
{"test catch nil", nil, nil, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &Authenticate{provider: tt.mock}
got, err := p.Refresh(context.Background(), tt.originalSession)
if (err != nil) != tt.wantErr {
t.Errorf("Authenticate.Refresh() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Authenticate.Refresh() = %v, want %v", got, tt.want)
}
})
}
}
func TestAuthenticate_Authenticate(t *testing.T) {
secret := cryptutil.GenerateKey()
c, err := cryptutil.NewCipher(secret)
if err != nil {
t.Fatalf("expected to be able to create cipher: %v", err)
}
newSecret := cryptutil.GenerateKey()
c2, err := cryptutil.NewCipher(newSecret)
if err != nil {
t.Fatalf("expected to be able to create cipher: %v", err)
}
rt := time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC()
vtProto, err := ptypes.TimestampProto(rt)
if err != nil {
t.Fatal("failed to parse timestamp")
}
want := &sessions.SessionState{
AccessToken: "token1234",
RefreshToken: "refresh4321",
RefreshDeadline: rt,
Email: "user@domain.com",
User: "user",
}
goodReply := &pb.Session{
AccessToken: "token1234",
RefreshToken: "refresh4321",
RefreshDeadline: vtProto,
Email: "user@domain.com",
User: "user"}
ciphertext, err := sessions.MarshalSession(want, c)
if err != nil {
t.Fatalf("expected to be encode session: %v", err)
}
tests := []struct {
name string
cipher cryptutil.Cipher
code string
want *pb.Session
wantErr bool
}{
{"good", c, ciphertext, goodReply, false},
{"bad cipher", c2, ciphertext, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &Authenticate{cipher: tt.cipher}
got, err := p.Authenticate(context.Background(), &pb.AuthenticateRequest{Code: tt.code})
if (err != nil) != tt.wantErr {
t.Errorf("Authenticate.Authenticate() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Authenticate.Authenticate() = got: \n%vwant:\n%v", got, tt.want)
}
})
}
}

View file

@ -2,12 +2,13 @@ package authenticate // import "github.com/pomerium/pomerium/authenticate"
import ( import (
"encoding/base64" "encoding/base64"
"encoding/json"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"time"
"golang.org/x/xerrors"
"github.com/pomerium/pomerium/internal/cryptutil" "github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
@ -18,6 +19,7 @@ import (
) )
// CSPHeaders are the content security headers added to the service's handlers // CSPHeaders are the content security headers added to the service's handlers
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Security-Policy/script-src
var CSPHeaders = map[string]string{ var CSPHeaders = map[string]string{
"Content-Security-Policy": "default-src 'none'; style-src 'self'" + "Content-Security-Policy": "default-src 'none'; style-src 'self'" +
" 'sha256-z9MsgkMbQjRSLxzAfN55jB3a9pP0PQ4OHFH8b4iDP6s=' " + " 'sha256-z9MsgkMbQjRSLxzAfN55jB3a9pP0PQ4OHFH8b4iDP6s=' " +
@ -27,22 +29,24 @@ var CSPHeaders = map[string]string{
"Referrer-Policy": "Same-origin", "Referrer-Policy": "Same-origin",
} }
// Handler returns the authenticate service's HTTP request multiplexer, and routes. // Handler returns the authenticate service's HTTP multiplexer, and routes.
func (a *Authenticate) Handler() http.Handler { func (a *Authenticate) Handler() http.Handler {
// validation middleware chain // validation middleware chain
c := middleware.NewChain() c := middleware.NewChain()
c = c.Append(middleware.SetHeaders(CSPHeaders)) c = c.Append(middleware.SetHeaders(CSPHeaders))
validate := c.Append(middleware.ValidateSignature(a.SharedKey))
validate = validate.Append(middleware.ValidateRedirectURI(a.RedirectURL))
mux := http.NewServeMux() mux := http.NewServeMux()
mux.Handle("/robots.txt", c.ThenFunc(a.RobotsTxt)) mux.Handle("/robots.txt", c.ThenFunc(a.RobotsTxt))
// Identity Provider (IdP) callback endpoints and callbacks // Identity Provider (IdP) endpoints
mux.Handle("/start", c.ThenFunc(a.OAuthStart)) mux.Handle("/oauth2", c.ThenFunc(a.OAuthStart))
mux.Handle("/oauth2/callback", c.ThenFunc(a.OAuthCallback)) mux.Handle("/oauth2/callback", c.ThenFunc(a.OAuthCallback))
// authenticate-server endpoints // Proxy service endpoints
mux.Handle("/sign_in", validate.ThenFunc(a.SignIn)) validationMiddlewares := c.Append(
mux.Handle("/sign_out", validate.ThenFunc(a.SignOut)) // POST middleware.ValidateSignature(a.SharedKey),
// programmatic authentication endpoints middleware.ValidateRedirectURI(a.RedirectURL),
)
mux.Handle("/sign_in", validationMiddlewares.ThenFunc(a.SignIn))
mux.Handle("/sign_out", validationMiddlewares.ThenFunc(a.SignOut)) // POST
// Direct user access endpoints
mux.Handle("/api/v1/token", c.ThenFunc(a.ExchangeToken)) mux.Handle("/api/v1/token", c.ThenFunc(a.ExchangeToken))
return mux return mux
} }
@ -55,43 +59,46 @@ func (a *Authenticate) RobotsTxt(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "User-agent: *\nDisallow: /") fmt.Fprintf(w, "User-agent: *\nDisallow: /")
} }
func (a *Authenticate) authenticate(w http.ResponseWriter, r *http.Request, session *sessions.SessionState) error { func (a *Authenticate) loadExisting(w http.ResponseWriter, r *http.Request) (*sessions.State, error) {
if session.RefreshPeriodExpired() { session, err := a.sessionStore.LoadSession(r)
session, err := a.provider.Refresh(r.Context(), session) if err != nil {
if err != nil { return nil, err
return xerrors.Errorf("session refresh failed : %w", err)
}
if err = a.sessionStore.SaveSession(w, r, session); err != nil {
return xerrors.Errorf("failed saving refreshed session : %w", err)
}
} else {
valid, err := a.provider.Validate(r.Context(), session.IDToken)
if err != nil || !valid {
return xerrors.Errorf("session valid: %v : %w", valid, err)
}
} }
return nil err = session.Valid()
if err == nil {
return session, nil
} else if !errors.Is(err, sessions.ErrExpired) {
return nil, fmt.Errorf("authenticate: non-refreshable error: %w", err)
} else {
return a.refresh(w, r, session)
}
}
func (a *Authenticate) refresh(w http.ResponseWriter, r *http.Request, s *sessions.State) (*sessions.State, error) {
newSession, err := a.provider.Refresh(r.Context(), s)
if err != nil {
return nil, fmt.Errorf("authenticate: refresh failed: %w", err)
}
if err := a.sessionStore.SaveSession(w, r, newSession); err != nil {
return nil, fmt.Errorf("authenticate: refresh save failed: %w", err)
}
return newSession, nil
} }
// SignIn handles to authenticating a user. // SignIn handles to authenticating a user.
func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) { func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
session, err := a.sessionStore.LoadSession(r) session, err := a.loadExisting(w, r)
if err != nil { if err != nil {
log.FromRequest(r).Debug().Err(err).Msg("no session loaded, restart auth") log.FromRequest(r).Debug().Err(err).Msg("authenticate: need new session")
a.sessionStore.ClearSession(w, r) a.sessionStore.ClearSession(w, r)
a.OAuthStart(w, r) a.OAuthStart(w, r)
return return
} }
// if a session already exists, authenticate it
if err := a.authenticate(w, r, session); err != nil {
httputil.ErrorResponse(w, r, err)
return
}
if err := r.ParseForm(); err != nil { if err := r.ParseForm(); err != nil {
httputil.ErrorResponse(w, r, err) httputil.ErrorResponse(w, r, err)
return return
} }
state := r.Form.Get("state") state := r.Form.Get("state")
if state == "" { if state == "" {
httputil.ErrorResponse(w, r, httputil.Error("sign in state empty", http.StatusBadRequest, nil)) httputil.ErrorResponse(w, r, httputil.Error("sign in state empty", http.StatusBadRequest, nil))
@ -100,21 +107,20 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
redirectURL, err := urlutil.ParseAndValidateURL(r.Form.Get("redirect_uri")) redirectURL, err := urlutil.ParseAndValidateURL(r.Form.Get("redirect_uri"))
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri parameter passed", http.StatusBadRequest, err)) httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err))
return return
} }
// encrypt session state as json blob // encrypt session state as json blob
encrypted, err := sessions.MarshalSession(session, a.cipher) encrypted, err := sessions.MarshalSession(session, a.cipher)
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("couldn't marshall session", http.StatusInternalServerError, err)) httputil.ErrorResponse(w, r, httputil.Error("couldn't marshal session", http.StatusInternalServerError, err))
return return
} }
http.Redirect(w, r, getAuthCodeRedirectURL(redirectURL, state, encrypted), http.StatusFound) http.Redirect(w, r, getAuthCodeRedirectURL(redirectURL, state, encrypted), http.StatusFound)
} }
func getAuthCodeRedirectURL(redirectURL *url.URL, state, authCode string) string { func getAuthCodeRedirectURL(redirectURL *url.URL, state, authCode string) string {
// error handled by go's mux stack // ParseQuery err handled by go's mux stack
params, _ := url.ParseQuery(redirectURL.RawQuery) params, _ := url.ParseQuery(redirectURL.RawQuery)
params.Set("code", authCode) params.Set("code", authCode)
params.Set("state", state) params.Set("state", state)
@ -122,8 +128,8 @@ func getAuthCodeRedirectURL(redirectURL *url.URL, state, authCode string) string
return redirectURL.String() return redirectURL.String()
} }
// SignOut signs the user out by trying to revoke the user's remote identity session along with // SignOut signs the user out and attempts to revoke the user's identity session
// the associated local session state. Handles both GET and POST. // Handles both GET and POST.
func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) { func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil { if err := r.ParseForm(); err != nil {
httputil.ErrorResponse(w, r, err) httputil.ErrorResponse(w, r, err)
@ -156,7 +162,6 @@ func (a *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) {
// OIDC : 3.1.2.1. Authentication Request // OIDC : 3.1.2.1. Authentication Request
nonce := fmt.Sprintf("%x", cryptutil.GenerateKey()) nonce := fmt.Sprintf("%x", cryptutil.GenerateKey())
a.csrfStore.SetCSRF(w, r, nonce) a.csrfStore.SetCSRF(w, r, nonce)
// Redirection URI to which the response will be sent. This URI MUST exactly // Redirection URI to which the response will be sent. This URI MUST exactly
// match one of the Redirection URI values for the Client pre-registered at // match one of the Redirection URI values for the Client pre-registered at
// at your identity provider // at your identity provider
@ -173,7 +178,6 @@ func (a *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) {
httputil.ErrorResponse(w, r, httputil.Error("invalid signature", http.StatusBadRequest, nil)) httputil.ErrorResponse(w, r, httputil.Error("invalid signature", http.StatusBadRequest, nil))
return return
} }
// State is the opaque value used to maintain state between the request and // State is the opaque value used to maintain state between the request and
// the callback; contains both the nonce and redirect URI // the callback; contains both the nonce and redirect URI
state := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%v:%v", nonce, authRedirectURL.String()))) state := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%v:%v", nonce, authRedirectURL.String())))
@ -183,74 +187,69 @@ func (a *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, signInURL, http.StatusFound) http.Redirect(w, r, signInURL, http.StatusFound)
} }
// OAuthCallback handles the callback from the identity provider. Displays an error page if there // OAuthCallback handles the callback from the identity provider.
// was an error. If successful, the user is redirected back to the proxy-service.
// https://openid.net/specs/openid-connect-core-1_0.html#AuthResponse // https://openid.net/specs/openid-connect-core-1_0.html#AuthResponse
func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) { func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) {
redirect, err := a.getOAuthCallback(w, r) redirect, err := a.getOAuthCallback(w, r)
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, xerrors.Errorf("oauth callback : %w", err)) httputil.ErrorResponse(w, r, fmt.Errorf("oauth callback : %w", err))
return return
} }
// redirect back to the proxy-service via sign_in // redirect back to the proxy-service via sign_in
http.Redirect(w, r, redirect, http.StatusFound) http.Redirect(w, r, redirect.String(), http.StatusFound)
} }
func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (string, error) { func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (*url.URL, error) {
if err := r.ParseForm(); err != nil { if err := r.ParseForm(); err != nil {
return "", httputil.Error("invalid signature", http.StatusBadRequest, err) return nil, httputil.Error("invalid signature", http.StatusBadRequest, err)
} }
// OIDC : 3.1.2.6. Authentication Error Response // OIDC : 3.1.2.6. Authentication Error Response
// https://openid.net/specs/openid-connect-core-1_0-final.html#AuthError // https://openid.net/specs/openid-connect-core-1_0-final.html#AuthError
if errorString := r.Form.Get("error"); errorString != "" { if idpError := r.Form.Get("error"); idpError != "" {
return "", httputil.Error("provider returned an error", http.StatusBadRequest, fmt.Errorf("provider returned error: %v", errorString)) return nil, httputil.Error("provider returned an error", http.StatusBadRequest, fmt.Errorf("provider error: %v", idpError))
} }
// OIDC : 3.1.2.5. Successful Authentication Response
// https://openid.net/specs/openid-connect-core-1_0.html#AuthResponse
code := r.Form.Get("code") code := r.Form.Get("code")
if code == "" { if code == "" {
return "", httputil.Error("provider didn't reply with code", http.StatusBadRequest, nil) return nil, httputil.Error("provider didn't reply with code", http.StatusBadRequest, nil)
} }
// validate the returned code with the identity provider // validate the returned code with the identity provider
session, err := a.provider.Authenticate(r.Context(), code) session, err := a.provider.Authenticate(r.Context(), code)
if err != nil { if err != nil {
return "", xerrors.Errorf("error redeeming authenticate code: %w", err) return nil, fmt.Errorf("error redeeming authenticate code: %w", err)
} }
// Opaque value used to maintain state between the request and the callback.
// OIDC : 3.1.2.5. Successful Authentication Response // OIDC : 3.1.2.5. Successful Authentication Response
// https://openid.net/specs/openid-connect-core-1_0.html#AuthResponse // Opaque value used to maintain state between the request and the callback.
bytes, err := base64.URLEncoding.DecodeString(r.Form.Get("state")) bytes, err := base64.URLEncoding.DecodeString(r.Form.Get("state"))
if err != nil { if err != nil {
return "", xerrors.Errorf("failed decoding state: %w", err) return nil, fmt.Errorf("failed decoding state: %w", err)
} }
s := strings.SplitN(string(bytes), ":", 2) s := strings.SplitN(string(bytes), ":", 2)
if len(s) != 2 { if len(s) != 2 {
return "", xerrors.Errorf("invalid state size: %v", len(s)) return nil, fmt.Errorf("invalid state size: %d", len(s))
} }
// state contains both our csrf nonce and the redirect uri // state contains the csrf nonce and redirect uri
nonce := s[0] nonce := s[0]
redirect := s[1] redirect := s[1]
c, err := a.csrfStore.GetCSRF(r) c, err := a.csrfStore.GetCSRF(r)
defer a.csrfStore.ClearCSRF(w, r) defer a.csrfStore.ClearCSRF(w, r)
if err != nil || c.Value != nonce { if err != nil || c.Value != nonce {
return "", xerrors.Errorf("csrf failure: %w", err) return nil, fmt.Errorf("csrf failure: %w", err)
} }
redirectURL, err := urlutil.ParseAndValidateURL(redirect) redirectURL, err := urlutil.ParseAndValidateURL(redirect)
if err != nil { if err != nil {
return "", httputil.Error(fmt.Sprintf("invalid redirect uri %s", redirect), http.StatusBadRequest, err) return nil, httputil.Error(fmt.Sprintf("invalid redirect uri %s", redirect), http.StatusBadRequest, err)
} }
// sanity check, we are redirecting back to the same subdomain right? // sanity check, we are redirecting back to the same subdomain right?
if !middleware.SameDomain(redirectURL, a.RedirectURL) { if !middleware.SameDomain(redirectURL, a.RedirectURL) {
return "", httputil.Error(fmt.Sprintf("invalid redirect domain %v, %v", redirectURL, a.RedirectURL), http.StatusBadRequest, nil) return nil, httputil.Error(fmt.Sprintf("invalid redirect domain %v, %v", redirectURL, a.RedirectURL), http.StatusBadRequest, nil)
} }
if err := a.sessionStore.SaveSession(w, r, session); err != nil { if err := a.sessionStore.SaveSession(w, r, session); err != nil {
return "", xerrors.Errorf("failed saving new session: %w", err) return nil, fmt.Errorf("failed saving new session: %w", err)
} }
return redirect, nil return redirectURL, nil
} }
// ExchangeToken takes an identity provider issued JWT as input ('id_token) // ExchangeToken takes an identity provider issued JWT as input ('id_token)
@ -263,16 +262,32 @@ func (a *Authenticate) ExchangeToken(w http.ResponseWriter, r *http.Request) {
} }
code := r.Form.Get("id_token") code := r.Form.Get("id_token")
if code == "" { if code == "" {
httputil.ErrorResponse(w, r, httputil.Error("provider missing id token", http.StatusBadRequest, nil)) httputil.ErrorResponse(w, r, httputil.Error("missing id token", http.StatusBadRequest, nil))
return return
} }
session, err := a.provider.IDTokenToSession(r.Context(), code) session, err := a.provider.IDTokenToSession(r.Context(), code)
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("could not exchange identity for session", http.StatusInternalServerError, err)) httputil.ErrorResponse(w, r, err)
return return
} }
if err := a.restStore.SaveSession(w, r, session); err != nil { encToken, err := sessions.MarshalSession(session, a.cipher)
httputil.ErrorResponse(w, r, httputil.Error("failed returning new session", http.StatusInternalServerError, err)) if err != nil {
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err))
return return
} }
restSession := struct {
Token string
Expiry time.Time `json:",omitempty"`
}{
Token: encToken,
Expiry: session.RefreshDeadline,
}
jsonBytes, err := json.Marshal(restSession)
if err != nil {
httputil.ErrorResponse(w, r, err)
return
}
w.Header().Set("Content-Type", "application/json")
w.Write(jsonBytes)
} }

View file

@ -68,22 +68,25 @@ func TestAuthenticate_SignIn(t *testing.T) {
state string state string
redirectURI string redirectURI string
session sessions.SessionStore session sessions.SessionStore
restStore sessions.SessionStore
provider identity.MockProvider provider identity.MockProvider
cipher cryptutil.Cipher cipher cryptutil.Cipher
wantCode int wantCode int
}{ }{
{"good", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusFound}, {"good", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusFound},
{"session not valid", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: false}, &cryptutil.MockCipher{}, http.StatusInternalServerError}, {"session not valid", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: false}, &cryptutil.MockCipher{}, http.StatusFound},
{"session refresh error", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshError: errors.New("error")}, &cryptutil.MockCipher{}, http.StatusInternalServerError}, {"session expired good refresh", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshResponse: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &cryptutil.MockCipher{}, http.StatusFound},
{"session save after refresh error", "state=example", "https://some.example", &sessions.MockSessionStore{SaveError: errors.New("error"), Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusInternalServerError}, {"session expired bad refresh", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshError: errors.New("error")}, &cryptutil.MockCipher{}, http.StatusBadRequest}, // mocking hmac is meh
{"no cookie found trying to load", "state=example", "https://some.example", &sessions.MockSessionStore{LoadError: http.ErrNoCookie, Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest}, {"session expired bad refresh save", "state=example", "https://some.example", &sessions.MockSessionStore{SaveError: errors.New("ruh roh"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshResponse: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &cryptutil.MockCipher{}, http.StatusBadRequest},
{"unexpected error trying to load session", "state=example", "https://some.example", &sessions.MockSessionStore{LoadError: errors.New("error"), Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest},
{"malformed form", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusInternalServerError}, // {"no cookie found trying to load", "state=example", "https://some.example", &sessions.MockSessionStore{LoadError: http.ErrNoCookie, Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusInternalServerError},
{"empty state", "state=", "https://some.example", &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest}, {"unexpected error trying to load session", "state=example", "https://some.example", &sessions.MockSessionStore{LoadError: errors.New("error"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest},
{"malformed redirect uri", "state=example", "https://accounts.google.^", &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest}, {"malformed form", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusInternalServerError},
{"empty state", "state=", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest},
{"malformed redirect uri", "state=example", "https://accounts.google.^", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest},
// actually caught by go's handler, but we should keep the test. // actually caught by go's handler, but we should keep the test.
{"bad redirect uri query", "state=nonce", "%gh&%ij", &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusInternalServerError}, {"bad redirect uri query", "state=nonce", "%gh&%ij", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusInternalServerError},
{"marshal session failure", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{MarshalError: errors.New("error")}, http.StatusInternalServerError}, {"marshal session failure", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{MarshalError: errors.New("error")}, http.StatusInternalServerError},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -178,10 +181,10 @@ func TestAuthenticate_SignOut(t *testing.T) {
wantCode int wantCode int
wantBody string wantBody string
}{ }{
{"good post", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusFound, ""}, {"good post", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusFound, ""},
{"failed revoke", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusBadRequest, "could not revoke"}, {"failed revoke", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusBadRequest, "could not revoke"},
{"malformed form", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusInternalServerError, ""}, {"malformed form", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusInternalServerError, ""},
{"load session error", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{LoadError: errors.New("hi"), Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusFound, ""}, {"load session error", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{LoadError: errors.New("hi"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusFound, ""},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -288,19 +291,19 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
want string want string
wantCode int wantCode int
}{ }{
{"good", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "https://corp.pomerium.io", http.StatusFound}, {"good", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "https://corp.pomerium.io", http.StatusFound},
{"get csrf error", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", GetError: errors.New("error"), Cookie: &http.Cookie{Value: "not nonce"}}, "", http.StatusInternalServerError}, {"get csrf error", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", GetError: errors.New("error"), Cookie: &http.Cookie{Value: "not nonce"}}, "", http.StatusInternalServerError},
{"csrf nonce error", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "not nonce"}}, "", http.StatusInternalServerError}, {"csrf nonce error", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "not nonce"}}, "", http.StatusInternalServerError},
{"failed authenticate", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateError: errors.New("error")}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusInternalServerError}, {"failed authenticate", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateError: errors.New("error")}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusInternalServerError},
{"failed save session", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{SaveError: errors.New("error")}, identity.MockProvider{AuthenticateResponse: sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusInternalServerError}, {"failed save session", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{SaveError: errors.New("error")}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusInternalServerError},
{"provider returned error", http.MethodGet, "idp error", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest}, {"provider returned error", http.MethodGet, "idp error", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest},
{"empty code", http.MethodGet, "", "", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest}, {"empty code", http.MethodGet, "", "", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest},
{"invalid state string", http.MethodGet, "", "code", "nonce:https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusInternalServerError}, {"invalid state string", http.MethodGet, "", "code", "nonce:https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusInternalServerError},
{"malformed state", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusInternalServerError}, {"malformed state", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusInternalServerError},
{"invalid redirect uri", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest}, {"invalid redirect uri", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest},
{"malformed form", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest}, {"malformed form", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest},
{"bad redirect uri", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:http://^^^")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "https://corp.pomerium.io", http.StatusBadRequest}, {"bad redirect uri", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:http://^^^")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"different domains", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:http://some.example.notpomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "https://corp.pomerium.io", http.StatusBadRequest}, {"different domains", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:http://some.example.notpomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "https://corp.pomerium.io", http.StatusBadRequest},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -336,7 +339,6 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
} }
func TestAuthenticate_ExchangeToken(t *testing.T) { func TestAuthenticate_ExchangeToken(t *testing.T) {
cipher := &cryptutil.MockCipher{}
tests := []struct { tests := []struct {
name string name string
method string method string
@ -346,18 +348,18 @@ func TestAuthenticate_ExchangeToken(t *testing.T) {
provider identity.MockProvider provider identity.MockProvider
want string want string
}{ }{
{"good", http.MethodPost, "token", &sessions.RestStore{Cipher: cipher}, cipher, identity.MockProvider{IDTokenToSessionResponse: sessions.SessionState{IDToken: "ok"}}, ""}, {"good", http.MethodPost, "token", &sessions.MockSessionStore{}, &cryptutil.MockCipher{}, identity.MockProvider{IDTokenToSessionResponse: sessions.State{IDToken: "ok"}}, ""},
{"could not exchange identity for session", http.MethodPost, "token", &sessions.RestStore{Cipher: cipher}, cipher, identity.MockProvider{IDTokenToSessionError: errors.New("error")}, "could not exchange identity for session"}, {"could not exchange identity for session", http.MethodPost, "token", &sessions.MockSessionStore{}, &cryptutil.MockCipher{}, identity.MockProvider{IDTokenToSessionError: errors.New("error")}, ""},
{"missing token", http.MethodPost, "", &sessions.RestStore{Cipher: cipher}, cipher, identity.MockProvider{IDTokenToSessionResponse: sessions.SessionState{IDToken: "ok"}}, "missing id token"}, {"missing token", http.MethodPost, "", &sessions.MockSessionStore{}, &cryptutil.MockCipher{}, identity.MockProvider{IDTokenToSessionResponse: sessions.State{IDToken: "ok"}}, "missing id token"},
{"save error", http.MethodPost, "token", &sessions.MockSessionStore{SaveError: errors.New("error")}, cipher, identity.MockProvider{IDTokenToSessionResponse: sessions.SessionState{IDToken: "ok"}}, "failed returning new session"}, {"malformed form", http.MethodPost, "token", &sessions.MockSessionStore{}, &cryptutil.MockCipher{}, identity.MockProvider{IDTokenToSessionResponse: sessions.State{IDToken: "ok"}}, ""},
{"malformed form", http.MethodPost, "token", &sessions.RestStore{Cipher: cipher}, cipher, identity.MockProvider{IDTokenToSessionResponse: sessions.SessionState{IDToken: "ok"}}, ""}, {"can't marshal token", http.MethodPost, "token", &sessions.MockSessionStore{}, &cryptutil.MockCipher{MarshalError: errors.New("can't marshal token")}, identity.MockProvider{IDTokenToSessionResponse: sessions.State{IDToken: "ok"}}, "can't marshal token"},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
a := &Authenticate{ a := &Authenticate{
restStore: tt.restStore, cipher: tt.cipher,
cipher: tt.cipher, provider: tt.provider,
provider: tt.provider, sessionStore: tt.restStore,
} }
form := url.Values{} form := url.Values{}
if tt.idToken != "" { if tt.idToken != "" {
@ -370,6 +372,7 @@ func TestAuthenticate_ExchangeToken(t *testing.T) {
} }
r := httptest.NewRequest(tt.method, "/", strings.NewReader(rawForm)) r := httptest.NewRequest(tt.method, "/", strings.NewReader(rawForm))
r.Header.Set("Content-Type", "application/x-www-form-urlencoded") r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()

View file

@ -21,7 +21,6 @@ import (
"github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/internal/version" "github.com/pomerium/pomerium/internal/version"
pbAuthenticate "github.com/pomerium/pomerium/proto/authenticate"
pbAuthorize "github.com/pomerium/pomerium/proto/authorize" pbAuthorize "github.com/pomerium/pomerium/proto/authorize"
"github.com/pomerium/pomerium/proxy" "github.com/pomerium/pomerium/proxy"
) )
@ -47,7 +46,7 @@ func main() {
mux := http.NewServeMux() mux := http.NewServeMux()
grpcServer := setupGRPCServer(opt) grpcServer := setupGRPCServer(opt)
_, err = newAuthenticateService(*opt, mux, grpcServer) _, err = newAuthenticateService(*opt, mux)
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("cmd/pomerium: authenticate") log.Fatal().Err(err).Msg("cmd/pomerium: authenticate")
} }
@ -62,7 +61,6 @@ func main() {
log.Fatal().Err(err).Msg("cmd/pomerium: proxy") log.Fatal().Err(err).Msg("cmd/pomerium: proxy")
} }
if proxy != nil { if proxy != nil {
defer proxy.AuthenticateClient.Close()
defer proxy.AuthorizeClient.Close() defer proxy.AuthorizeClient.Close()
} }
@ -82,7 +80,7 @@ func main() {
os.Exit(0) os.Exit(0)
} }
func newAuthenticateService(opt config.Options, mux *http.ServeMux, rpc *grpc.Server) (*authenticate.Authenticate, error) { func newAuthenticateService(opt config.Options, mux *http.ServeMux) (*authenticate.Authenticate, error) {
if !config.IsAuthenticate(opt.Services) { if !config.IsAuthenticate(opt.Services) {
return nil, nil return nil, nil
} }
@ -90,7 +88,6 @@ func newAuthenticateService(opt config.Options, mux *http.ServeMux, rpc *grpc.Se
if err != nil { if err != nil {
return nil, err return nil, err
} }
pbAuthenticate.RegisterAuthenticatorServer(rpc, service)
mux.Handle(urlutil.StripPort(opt.AuthenticateURL.Host)+"/", service.Handler()) mux.Handle(urlutil.StripPort(opt.AuthenticateURL.Host)+"/", service.Handler())
return service, nil return service, nil
} }
@ -164,7 +161,7 @@ func configToServerOptions(opt *config.Options) *httputil.ServerOptions {
func setupMetrics(opt *config.Options) { func setupMetrics(opt *config.Options) {
if opt.MetricsAddr != "" { if opt.MetricsAddr != "" {
if handler, err := metrics.PrometheusHandler(); err != nil { if handler, err := metrics.PrometheusHandler(); err != nil {
log.Error().Err(err).Msg("cmd/pomerium: couldn't start metrics server") log.Error().Err(err).Msg("cmd/pomerium: metrics failed to start")
} else { } else {
metrics.SetBuildInfo(opt.Services) metrics.SetBuildInfo(opt.Services)
metrics.RegisterInfoMetrics() metrics.RegisterInfoMetrics()

View file

@ -21,9 +21,6 @@ import (
) )
func Test_newAuthenticateService(t *testing.T) { func Test_newAuthenticateService(t *testing.T) {
grpcAuth := middleware.NewSharedSecretCred("test")
grpcOpts := []grpc.ServerOption{grpc.UnaryInterceptor(grpcAuth.ValidateRequest)}
grpcServer := grpc.NewServer(grpcOpts...)
mux := http.NewServeMux() mux := http.NewServeMux()
tests := []struct { tests := []struct {
@ -56,7 +53,7 @@ func Test_newAuthenticateService(t *testing.T) {
testOptsField.Set(reflect.ValueOf(tt).FieldByName("Value")) testOptsField.Set(reflect.ValueOf(tt).FieldByName("Value"))
} }
_, err = newAuthenticateService(*testOpts, mux, grpcServer) _, err = newAuthenticateService(*testOpts, mux)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("newAuthenticateService() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("newAuthenticateService() error = %v, wantErr %v", err, tt.wantErr)
return return

View file

@ -176,7 +176,6 @@ Go to **Environment** tab.
| SHARED_SECRET | output of `head -c32 /dev/urandom | base64` | | SHARED_SECRET | output of `head -c32 /dev/urandom | base64` |
| AUTHORIZE_SERVICE_URL | `https://localhost` | | AUTHORIZE_SERVICE_URL | `https://localhost` |
| AUTHENTICATE_SERVICE_URL | `https://authenticate.int.nas.example` | | AUTHENTICATE_SERVICE_URL | `https://authenticate.int.nas.example` |
| AUTHENTICATE_INTERNAL_URL | `https://localhost` |
For a detailed explanation, and additional options, please refer to the [configuration variable docs]. Also note, though not covered in this guide, settings can be made via a mounted configuration file. For a detailed explanation, and additional options, please refer to the [configuration variable docs]. Also note, though not covered in this guide, settings can be made via a mounted configuration file.

View file

@ -48,7 +48,6 @@ services:
- SERVICES=proxy - SERVICES=proxy
# IMPORTANT! If you are running pomerium behind another ingress (loadbalancer/firewall/etc) # IMPORTANT! If you are running pomerium behind another ingress (loadbalancer/firewall/etc)
# you must tell pomerium proxy how to communicate using an internal hostname for RPC # you must tell pomerium proxy how to communicate using an internal hostname for RPC
- AUTHENTICATE_INTERNAL_URL=https://pomerium-authenticate
- AUTHORIZE_SERVICE_URL=https://pomerium-authorize - AUTHORIZE_SERVICE_URL=https://pomerium-authorize
# When communicating internally, rPC is going to get a name conflict expecting an external # When communicating internally, rPC is going to get a name conflict expecting an external
# facing certificate name (i.e. authenticate-service.local vs *.corp.example.com). # facing certificate name (i.e. authenticate-service.local vs *.corp.example.com).

View file

@ -1,6 +1,5 @@
# Main configuration flags : https://www.pomerium.io/reference/ # Main configuration flags : https://www.pomerium.io/reference/
authenticate_service_url: https://authenticate.corp.beyondperimeter.com authenticate_service_url: https://authenticate.corp.beyondperimeter.com
authenticate_internal_url: https://pomerium-authenticate-service.default.svc.cluster.local
authorize_service_url: https://pomerium-authorize-service.default.svc.cluster.local authorize_service_url: https://pomerium-authorize-service.default.svc.cluster.local
override_certificate_name: "*.corp.beyondperimeter.com" override_certificate_name: "*.corp.beyondperimeter.com"

View file

@ -146,7 +146,7 @@ Timeouts set the global server timeouts. For route-specific timeouts, see [polic
## GRPC Options ## GRPC Options
These settings control upstream connections to the Authorize and Authenticate services. These settings control upstream connections to the Authorize service.
### GRPC Client Timeout ### GRPC Client Timeout
@ -228,8 +228,8 @@ Each unit work is called a Span in a trace. Spans include metadata about the wor
| Config Key | Description | Required | | Config Key | Description | Required |
| :--------------- | :---------------------------------------------------------------- | -------- | | :--------------- | :---------------------------------------------------------------- | -------- |
| tracing_provider | The name of the tracing provider. (e.g. jaeger) | ✅ | | tracing_provider | The name of the tracing provider. (e.g. jaeger) | ✅ |
| tracing_debug | Will disable [sampling](https://opencensus.io/tracing/sampling/). | ❌ | | tracing_debug | Will disable [sampling](https://opencensus.io/tracing/sampling/). | ❌ |
### Jaeger ### Jaeger
@ -243,8 +243,8 @@ Each unit work is called a Span in a trace. Spans include metadata about the wor
| Config Key | Description | Required | | Config Key | Description | Required |
| :-------------------------------- | :------------------------------------------ | -------- | | :-------------------------------- | :------------------------------------------ | -------- |
| tracing_jaeger_collector_endpoint | Url to the Jaeger HTTP Thrift collector. | ✅ | | tracing_jaeger_collector_endpoint | Url to the Jaeger HTTP Thrift collector. | ✅ |
| tracing_jaeger_agent_endpoint | Send spans to jaeger-agent at this address. | ✅ | | tracing_jaeger_agent_endpoint | Send spans to jaeger-agent at this address. | ✅ |
#### Example #### Example
@ -464,16 +464,6 @@ Signing key is the base64 encoded key used to sign outbound requests. For more i
Authenticate Service URL is the externally accessible URL for the authenticate service. Authenticate Service URL is the externally accessible URL for the authenticate service.
## Authenticate Internal Service URL
- Environmental Variable: `AUTHENTICATE_INTERNAL_URL`
- Config File Key: `authenticate_internal_url`
- Type: `URL`
- Optional
- Example: `https://pomerium-authenticate-service.default.svc.cluster.local`
Authenticate Internal Service URL is the internally routed dns name of the authenticate service. This setting is typically used with load balancers that do not gRPC, thus allowing you to specify an internally accessible name.
## Authorize Service URL ## Authorize Service URL
- Environmental Variable: `AUTHORIZE_SERVICE_URL` - Environmental Variable: `AUTHORIZE_SERVICE_URL`

1
go.mod
View file

@ -26,7 +26,6 @@ require (
golang.org/x/net v0.0.0-20190611141213-3f473d35a33a golang.org/x/net v0.0.0-20190611141213-3f473d35a33a
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45 golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45
golang.org/x/sys v0.0.0-20190610200419-93c9922d18ae // indirect golang.org/x/sys v0.0.0-20190610200419-93c9922d18ae // indirect
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7
google.golang.org/api v0.6.0 google.golang.org/api v0.6.0
google.golang.org/appengine v1.6.1 // indirect google.golang.org/appengine v1.6.1 // indirect
google.golang.org/genproto v0.0.0-20190611190212-a7e196e89fd3 // indirect google.golang.org/genproto v0.0.0-20190611190212-a7e196e89fd3 // indirect

2
go.sum
View file

@ -257,8 +257,6 @@ golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBn
golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/api v0.3.1/go.mod h1:6wY9I6uQWHQ8EM57III9mq/AjF+i8G65rmVagqKMtkk= google.golang.org/api v0.3.1/go.mod h1:6wY9I6uQWHQ8EM57III9mq/AjF+i8G65rmVagqKMtkk=
google.golang.org/api v0.3.2/go.mod h1:6wY9I6uQWHQ8EM57III9mq/AjF+i8G65rmVagqKMtkk= google.golang.org/api v0.3.2/go.mod h1:6wY9I6uQWHQ8EM57III9mq/AjF+i8G65rmVagqKMtkk=
google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=

View file

@ -97,13 +97,6 @@ type Options struct {
// (sudo) access including the ability to impersonate other users' access // (sudo) access including the ability to impersonate other users' access
Administrators []string `mapstructure:"administrators"` Administrators []string `mapstructure:"administrators"`
// AuthenticateInternalAddr is used override the routable destination of
// authenticate service's GRPC endpoint.
// NOTE: As many load balancers do not support externally routed gRPC so
// this may be an internal location.
AuthenticateInternalAddrString string `mapstructure:"authenticate_internal_url"`
AuthenticateInternalAddr *url.URL
// AuthorizeURL is the routable destination of the authorize service's // AuthorizeURL is the routable destination of the authorize service's
// gRPC endpoint. NOTE: As many load balancers do not support // gRPC endpoint. NOTE: As many load balancers do not support
// externally routed gRPC so this may be an internal location. // externally routed gRPC so this may be an internal location.
@ -246,13 +239,6 @@ func (o *Options) Validate() error {
o.AuthorizeURL = u o.AuthorizeURL = u
} }
if o.AuthenticateInternalAddrString != "" {
u, err := urlutil.ParseAndValidateURL(o.AuthenticateInternalAddrString)
if err != nil {
return fmt.Errorf("bad authenticate-internal-addr %s : %v", o.AuthenticateInternalAddrString, err)
}
o.AuthenticateInternalAddr = u
}
if o.PolicyFile != "" { if o.PolicyFile != "" {
return errors.New("policy file setting is deprecated") return errors.New("policy file setting is deprecated")
} }

View file

@ -337,7 +337,7 @@ func TestNewOptions(t *testing.T) {
func TestOptionsFromViper(t *testing.T) { func TestOptionsFromViper(t *testing.T) {
opts := []cmp.Option{ opts := []cmp.Option{
cmpopts.IgnoreFields(Options{}, "AuthenticateInternalAddr", "DefaultUpstreamTimeout", "CookieRefresh", "CookieExpire", "Services", "Addr", "RefreshCooldown", "LogLevel", "KeyFile", "CertFile", "SharedKey", "ReadTimeout", "ReadHeaderTimeout", "IdleTimeout", "GRPCClientTimeout", "GRPCClientDNSRoundRobin"), cmpopts.IgnoreFields(Options{}, "DefaultUpstreamTimeout", "CookieRefresh", "CookieExpire", "Services", "Addr", "RefreshCooldown", "LogLevel", "KeyFile", "CertFile", "SharedKey", "ReadTimeout", "ReadHeaderTimeout", "IdleTimeout", "GRPCClientTimeout", "GRPCClientDNSRoundRobin"),
cmpopts.IgnoreFields(Policy{}, "Source", "Destination"), cmpopts.IgnoreFields(Policy{}, "Source", "Destination"),
} }
@ -361,21 +361,6 @@ func TestOptionsFromViper(t *testing.T) {
"X-XSS-Protection": "1; mode=block", "X-XSS-Protection": "1; mode=block",
}}, }},
false}, false},
{"good with authenticate internal url",
[]byte(`{"authenticate_internal_url": "https://internal.example","policy":[{"from": "https://from.example","to":"https://to.example"}]}`),
&Options{
AuthenticateInternalAddrString: "https://internal.example",
Policies: []Policy{{From: "https://from.example", To: "https://to.example"}},
CookieName: "_pomerium",
CookieSecure: true,
CookieHTTPOnly: true,
Headers: map[string]string{
"Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload",
"X-Content-Type-Options": "nosniff",
"X-Frame-Options": "SAMEORIGIN",
"X-XSS-Protection": "1; mode=block",
}},
false},
{"good disable header", {"good disable header",
[]byte(`{"headers": {"disable":"true"},"policy":[{"from": "https://from.example","to":"https://to.example"}]}`), []byte(`{"headers": {"disable":"true"},"policy":[{"from": "https://from.example","to":"https://to.example"}]}`),
&Options{ &Options{
@ -385,7 +370,6 @@ func TestOptionsFromViper(t *testing.T) {
CookieHTTPOnly: true, CookieHTTPOnly: true,
Headers: map[string]string{}}, Headers: map[string]string{}},
false}, false},
{"bad authenticate internal url", []byte(`{"authenticate_internal_url": "internal.example","policy":[{"from": "https://from.example","to":"https://to.example"}]}`), nil, true},
{"bad url", []byte(`{"policy":[{"from": "https://","to":"https://to.example"}]}`), nil, true}, {"bad url", []byte(`{"policy":[{"from": "https://","to":"https://to.example"}]}`), nil, true},
{"bad policy", []byte(`{"policy":[{"allow_public_unauthenticated_access": "dog","to":"https://to.example"}]}`), nil, true}, {"bad policy", []byte(`{"policy":[{"allow_public_unauthenticated_access": "dog","to":"https://to.example"}]}`), nil, true},

View file

@ -67,6 +67,18 @@ func NewCipher(secret []byte) (*XChaCha20Cipher, error) {
}, nil }, nil
} }
// NewCipherFromBase64 takes a base64 encoded secret key and returns a new XChacha20poly1305 cipher.
func NewCipherFromBase64(s string) (*XChaCha20Cipher, error) {
decoded, err := base64.StdEncoding.DecodeString(s)
if err != nil {
return nil, fmt.Errorf("cryptutil: invalid base64: %v", err)
}
if len(decoded) != 32 {
return nil, fmt.Errorf("cryptutil: got %d bytes but want 32", len(decoded))
}
return NewCipher(decoded)
}
// GenerateNonce generates a random nonce. // GenerateNonce generates a random nonce.
// Panics if source of randomness fails. // Panics if source of randomness fails.
func (c *XChaCha20Cipher) GenerateNonce() []byte { func (c *XChaCha20Cipher) GenerateNonce() []byte {

View file

@ -259,3 +259,26 @@ func TestNewCipher(t *testing.T) {
}) })
} }
} }
func TestNewCipherFromBase64(t *testing.T) {
tests := []struct {
name string
s string
wantErr bool
}{
{"simple 32 byte key", base64.StdEncoding.EncodeToString(GenerateKey()), false},
{"key too short", base64.StdEncoding.EncodeToString([]byte("what is entropy")), true},
{"key too long", GenerateRandomString(33), true},
{"bad base 64", string(GenerateKey()), true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := NewCipherFromBase64(tt.s)
if (err != nil) != tt.wantErr {
t.Errorf("NewCipherFromBase64() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}

View file

@ -1,5 +1,6 @@
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil" package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
import ( import (
"encoding/base64"
"fmt" "fmt"
"sync" "sync"
"time" "time"
@ -48,15 +49,20 @@ type ES256Signer struct {
NotBefore jwt.NumericDate `json:"nbf,omitempty"` NotBefore jwt.NumericDate `json:"nbf,omitempty"`
} }
// NewES256Signer creates an Elliptic Curve, NIST P-256 (aka secp256r1 aka prime256v1) JWT signer. // NewES256Signer creates a NIST P-256 (aka secp256r1 aka prime256v1) JWT signer
// from a base64 encoded private key.
// //
// RSA is not supported due to performance considerations of needing to sign each request. // RSA is not supported due to performance considerations of needing to sign each request.
// Go's P-256 is constant-time and SHA-256 is faster on 64-bit machines and immune // Go's P-256 is constant-time and SHA-256 is faster on 64-bit machines and immune
// to length extension attacks. // to length extension attacks.
// See also: // See also:
// - https://cloud.google.com/iot/docs/how-tos/credentials/keys // - https://cloud.google.com/iot/docs/how-tos/credentials/keys
func NewES256Signer(privKey []byte, audience string) (*ES256Signer, error) { func NewES256Signer(privKey, audience string) (*ES256Signer, error) {
key, err := DecodePrivateKey(privKey) decodedSigningKey, err := base64.StdEncoding.DecodeString(privKey)
if err != nil {
return nil, err
}
key, err := DecodePrivateKey(decodedSigningKey)
if err != nil { if err != nil {
return nil, fmt.Errorf("cryptutil: parsing key failed %v", err) return nil, fmt.Errorf("cryptutil: parsing key failed %v", err)
} }

View file

@ -1,11 +1,12 @@
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil" package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
import ( import (
"encoding/base64"
"testing" "testing"
) )
func TestES256Signer(t *testing.T) { func TestES256Signer(t *testing.T) {
signer, err := NewES256Signer([]byte(pemECPrivateKeyP256), "destination-url") signer, err := NewES256Signer(base64.StdEncoding.EncodeToString([]byte(pemECPrivateKeyP256)), "destination-url")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -25,12 +26,13 @@ func TestNewES256Signer(t *testing.T) {
t.Parallel() t.Parallel()
tests := []struct { tests := []struct {
name string name string
privKey []byte privKey string
audience string audience string
wantErr bool wantErr bool
}{ }{
{"working example", []byte(pemECPrivateKeyP256), "some-domain.com", false}, {"working example", base64.StdEncoding.EncodeToString([]byte(pemECPrivateKeyP256)), "some-domain.com", false},
{"bad private key", []byte(garbagePEM), "some-domain.com", true}, {"bad private key", base64.StdEncoding.EncodeToString([]byte(garbagePEM)), "some-domain.com", true},
{"bad base64 key", garbagePEM, "some-domain.com", true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {

View file

@ -2,20 +2,18 @@ package httputil // import "github.com/pomerium/pomerium/internal/httputil"
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"golang.org/x/xerrors"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/templates" "github.com/pomerium/pomerium/internal/templates"
) )
// Error formats creates a HTTP error with code, user friendly (and safe) error // Error formats creates a HTTP error with code, user friendly (and safe) error
// message. If nil or empty: // message. If nil or empty, HTTP status code defaults to 500 and message
// HTTP status code defaults to 500. // defaults to the text of the status code.
// Message defaults to the text of the status code.
func Error(message string, code int, err error) error { func Error(message string, code int, err error) error {
if code == 0 { if code == 0 {
code = http.StatusInternalServerError code = http.StatusInternalServerError
@ -45,7 +43,9 @@ func (e *httpError) Error() string {
func (e *httpError) Unwrap() error { return e.Err } func (e *httpError) Unwrap() error { return e.Err }
// Timeout reports whether this error represents a user debuggable error. // Timeout reports whether this error represents a user debuggable error.
func (e *httpError) Debugable() bool { return e.Code == http.StatusUnauthorized } func (e *httpError) Debugable() bool {
return e.Code == http.StatusUnauthorized || e.Code == http.StatusForbidden
}
// ErrorResponse renders an error page given an error. If the error is a // ErrorResponse renders an error page given an error. If the error is a
// http error from this package, a user friendly message is set, http status code, // http error from this package, a user friendly message is set, http status code,
@ -57,11 +57,12 @@ func ErrorResponse(rw http.ResponseWriter, r *http.Request, e error) {
var requestID string var requestID string
var httpError *httpError var httpError *httpError
// if this is an HTTPError, we can add some additional useful information // if this is an HTTPError, we can add some additional useful information
if xerrors.As(e, &httpError) { if errors.As(e, &httpError) {
canDebug = httpError.Debugable() canDebug = httpError.Debugable()
statusCode = httpError.Code statusCode = httpError.Code
errorString = httpError.Message errorString = httpError.Message
} }
log.FromRequest(r).Error().Err(e).Str("http-message", errorString).Int("http-code", statusCode).Msg("http-error") log.FromRequest(r).Error().Err(e).Str("http-message", errorString).Int("http-code", statusCode).Msg("http-error")
if id, ok := log.IDFromRequest(r); ok { if id, ok := log.IDFromRequest(r); ok {
@ -71,7 +72,7 @@ func ErrorResponse(rw http.ResponseWriter, r *http.Request, e error) {
var response struct { var response struct {
Error string `json:"error"` Error string `json:"error"`
} }
response.Error = e.Error() response.Error = errorString
writeJSONResponse(rw, statusCode, response) writeJSONResponse(rw, statusCode, response)
} else { } else {
rw.WriteHeader(statusCode) rw.WriteHeader(statusCode)

View file

@ -129,8 +129,7 @@ func (p *GoogleProvider) GetSignInURL(state string) string {
// Authenticate creates an identity session with google from a authorization code, and follows up // Authenticate creates an identity session with google from a authorization code, and follows up
// call to the admin/group api to check what groups the user is in. // call to the admin/group api to check what groups the user is in.
func (p *GoogleProvider) Authenticate(ctx context.Context, code string) (*sessions.SessionState, error) { func (p *GoogleProvider) Authenticate(ctx context.Context, code string) (*sessions.State, error) {
// convert authorization code into a token
oauth2Token, err := p.oauth.Exchange(ctx, code) oauth2Token, err := p.oauth.Exchange(ctx, code)
if err != nil { if err != nil {
return nil, fmt.Errorf("identity/google: token exchange failed %v", err) return nil, fmt.Errorf("identity/google: token exchange failed %v", err)
@ -153,7 +152,7 @@ func (p *GoogleProvider) Authenticate(ctx context.Context, code string) (*sessio
// Refresh renews a user's session using an oidc refresh token withoutreprompting the user. // Refresh renews a user's session using an oidc refresh token withoutreprompting the user.
// Group membership is also refreshed. // Group membership is also refreshed.
// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens // https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens
func (p *GoogleProvider) Refresh(ctx context.Context, s *sessions.SessionState) (*sessions.SessionState, error) { func (p *GoogleProvider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) {
if s.RefreshToken == "" { if s.RefreshToken == "" {
return nil, errors.New("identity: missing refresh token") return nil, errors.New("identity: missing refresh token")
} }
@ -180,7 +179,7 @@ func (p *GoogleProvider) Refresh(ctx context.Context, s *sessions.SessionState)
// IDTokenToSession takes an identity provider issued JWT as input ('id_token') // IDTokenToSession takes an identity provider issued JWT as input ('id_token')
// and returns a session state. The provided token's audience ('aud') must // and returns a session state. The provided token's audience ('aud') must
// match Pomerium's client_id. // match Pomerium's client_id.
func (p *GoogleProvider) IDTokenToSession(ctx context.Context, rawIDToken string) (*sessions.SessionState, error) { func (p *GoogleProvider) IDTokenToSession(ctx context.Context, rawIDToken string) (*sessions.State, error) {
idToken, err := p.verifier.Verify(ctx, rawIDToken) idToken, err := p.verifier.Verify(ctx, rawIDToken)
if err != nil { if err != nil {
return nil, fmt.Errorf("identity/google: could not verify id_token %v", err) return nil, fmt.Errorf("identity/google: could not verify id_token %v", err)
@ -200,7 +199,7 @@ func (p *GoogleProvider) IDTokenToSession(ctx context.Context, rawIDToken string
return nil, fmt.Errorf("identity/google: could not retrieve groups %v", err) return nil, fmt.Errorf("identity/google: could not retrieve groups %v", err)
} }
return &sessions.SessionState{ return &sessions.State{
IDToken: rawIDToken, IDToken: rawIDToken,
RefreshDeadline: idToken.Expiry.Truncate(time.Second), RefreshDeadline: idToken.Expiry.Truncate(time.Second),
Email: claims.Email, Email: claims.Email,

View file

@ -74,7 +74,7 @@ func NewAzureProvider(p *Provider) (*AzureProvider, error) {
// Authenticate creates an identity session with azure from a authorization code, and follows up // Authenticate creates an identity session with azure from a authorization code, and follows up
// call to the groups api to check what groups the user is in. // call to the groups api to check what groups the user is in.
func (p *AzureProvider) Authenticate(ctx context.Context, code string) (*sessions.SessionState, error) { func (p *AzureProvider) Authenticate(ctx context.Context, code string) (*sessions.State, error) {
// convert authorization code into a token // convert authorization code into a token
oauth2Token, err := p.oauth.Exchange(ctx, code) oauth2Token, err := p.oauth.Exchange(ctx, code)
if err != nil { if err != nil {
@ -104,7 +104,7 @@ func (p *AzureProvider) Authenticate(ctx context.Context, code string) (*session
// IDTokenToSession takes an identity provider issued JWT as input ('id_token') // IDTokenToSession takes an identity provider issued JWT as input ('id_token')
// and returns a session state. The provided token's audience ('aud') must // and returns a session state. The provided token's audience ('aud') must
// match Pomerium's client_id. // match Pomerium's client_id.
func (p *AzureProvider) IDTokenToSession(ctx context.Context, rawIDToken string) (*sessions.SessionState, error) { func (p *AzureProvider) IDTokenToSession(ctx context.Context, rawIDToken string) (*sessions.State, error) {
idToken, err := p.verifier.Verify(ctx, rawIDToken) idToken, err := p.verifier.Verify(ctx, rawIDToken)
if err != nil { if err != nil {
return nil, fmt.Errorf("identity/microsoft: could not verify id_token %v", err) return nil, fmt.Errorf("identity/microsoft: could not verify id_token %v", err)
@ -118,7 +118,7 @@ func (p *AzureProvider) IDTokenToSession(ctx context.Context, rawIDToken string)
return nil, fmt.Errorf("identity/microsoft: failed to parse id_token claims %v", err) return nil, fmt.Errorf("identity/microsoft: failed to parse id_token claims %v", err)
} }
return &sessions.SessionState{ return &sessions.State{
IDToken: rawIDToken, IDToken: rawIDToken,
RefreshDeadline: idToken.Expiry.Truncate(time.Second), RefreshDeadline: idToken.Expiry.Truncate(time.Second),
Email: claims.Email, Email: claims.Email,
@ -146,7 +146,7 @@ func (p *AzureProvider) GetSignInURL(state string) string {
// Refresh renews a user's session using an oid refresh token without reprompting the user. // Refresh renews a user's session using an oid refresh token without reprompting the user.
// Group membership is also refreshed. // Group membership is also refreshed.
// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens // https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens
func (p *AzureProvider) Refresh(ctx context.Context, s *sessions.SessionState) (*sessions.SessionState, error) { func (p *AzureProvider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) {
if s.RefreshToken == "" { if s.RefreshToken == "" {
return nil, errors.New("identity/microsoft: missing refresh token") return nil, errors.New("identity/microsoft: missing refresh token")
} }

View file

@ -8,25 +8,25 @@ import (
// MockProvider provides a mocked implementation of the providers interface. // MockProvider provides a mocked implementation of the providers interface.
type MockProvider struct { type MockProvider struct {
AuthenticateResponse sessions.SessionState AuthenticateResponse sessions.State
AuthenticateError error AuthenticateError error
IDTokenToSessionResponse sessions.SessionState IDTokenToSessionResponse sessions.State
IDTokenToSessionError error IDTokenToSessionError error
ValidateResponse bool ValidateResponse bool
ValidateError error ValidateError error
RefreshResponse *sessions.SessionState RefreshResponse *sessions.State
RefreshError error RefreshError error
RevokeError error RevokeError error
GetSignInURLResponse string GetSignInURLResponse string
} }
// Authenticate is a mocked providers function. // Authenticate is a mocked providers function.
func (mp MockProvider) Authenticate(ctx context.Context, code string) (*sessions.SessionState, error) { func (mp MockProvider) Authenticate(ctx context.Context, code string) (*sessions.State, error) {
return &mp.AuthenticateResponse, mp.AuthenticateError return &mp.AuthenticateResponse, mp.AuthenticateError
} }
// IDTokenToSession is a mocked providers function. // IDTokenToSession is a mocked providers function.
func (mp MockProvider) IDTokenToSession(ctx context.Context, code string) (*sessions.SessionState, error) { func (mp MockProvider) IDTokenToSession(ctx context.Context, code string) (*sessions.State, error) {
return &mp.IDTokenToSessionResponse, mp.IDTokenToSessionError return &mp.IDTokenToSessionResponse, mp.IDTokenToSessionError
} }
@ -36,7 +36,7 @@ func (mp MockProvider) Validate(ctx context.Context, s string) (bool, error) {
} }
// Refresh is a mocked providers function. // Refresh is a mocked providers function.
func (mp MockProvider) Refresh(ctx context.Context, s *sessions.SessionState) (*sessions.SessionState, error) { func (mp MockProvider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) {
return mp.RefreshResponse, mp.RefreshError return mp.RefreshResponse, mp.RefreshError
} }

View file

@ -91,7 +91,7 @@ type accessToken struct {
// Refresh renews a user's session using an oid refresh token without reprompting the user. // Refresh renews a user's session using an oid refresh token without reprompting the user.
// Group membership is also refreshed. If configured properly, Okta is we can configure the access token // Group membership is also refreshed. If configured properly, Okta is we can configure the access token
// to include group membership claims which allows us to avoid a follow up oauth2 call. // to include group membership claims which allows us to avoid a follow up oauth2 call.
func (p *OktaProvider) Refresh(ctx context.Context, s *sessions.SessionState) (*sessions.SessionState, error) { func (p *OktaProvider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) {
if s.RefreshToken == "" { if s.RefreshToken == "" {
return nil, errors.New("identity/okta: missing refresh token") return nil, errors.New("identity/okta: missing refresh token")
} }

View file

@ -93,7 +93,7 @@ func (p *OneLoginProvider) GetSignInURL(state string) string {
// Refresh renews a user's session using an oid refresh token without reprompting the user. // Refresh renews a user's session using an oid refresh token without reprompting the user.
// Group membership is also refreshed. // Group membership is also refreshed.
// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens // https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens
func (p *OneLoginProvider) Refresh(ctx context.Context, s *sessions.SessionState) (*sessions.SessionState, error) { func (p *OneLoginProvider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) {
if s.RefreshToken == "" { if s.RefreshToken == "" {
return nil, errors.New("identity/microsoft: missing refresh token") return nil, errors.New("identity/microsoft: missing refresh token")
} }

View file

@ -45,10 +45,10 @@ type UserGrouper interface {
// Authenticator is an interface representing the ability to authenticate with an identity provider. // Authenticator is an interface representing the ability to authenticate with an identity provider.
type Authenticator interface { type Authenticator interface {
Authenticate(context.Context, string) (*sessions.SessionState, error) Authenticate(context.Context, string) (*sessions.State, error)
IDTokenToSession(context.Context, string) (*sessions.SessionState, error) IDTokenToSession(context.Context, string) (*sessions.State, error)
Validate(context.Context, string) (bool, error) Validate(context.Context, string) (bool, error)
Refresh(context.Context, *sessions.SessionState) (*sessions.SessionState, error) Refresh(context.Context, *sessions.State) (*sessions.State, error)
Revoke(string) error Revoke(string) error
GetSignInURL(state string) string GetSignInURL(state string) string
} }
@ -131,7 +131,7 @@ func (p *Provider) Validate(ctx context.Context, idToken string) (bool, error) {
// IDTokenToSession takes an identity provider issued JWT as input ('id_token') // IDTokenToSession takes an identity provider issued JWT as input ('id_token')
// and returns a session state. The provided token's audience ('aud') must // and returns a session state. The provided token's audience ('aud') must
// match Pomerium's client_id. // match Pomerium's client_id.
func (p *Provider) IDTokenToSession(ctx context.Context, rawIDToken string) (*sessions.SessionState, error) { func (p *Provider) IDTokenToSession(ctx context.Context, rawIDToken string) (*sessions.State, error) {
idToken, err := p.verifier.Verify(ctx, rawIDToken) idToken, err := p.verifier.Verify(ctx, rawIDToken)
if err != nil { if err != nil {
return nil, fmt.Errorf("identity: could not verify id_token: %v", err) return nil, fmt.Errorf("identity: could not verify id_token: %v", err)
@ -146,7 +146,7 @@ func (p *Provider) IDTokenToSession(ctx context.Context, rawIDToken string) (*se
return nil, fmt.Errorf("identity: failed to parse id_token claims: %v", err) return nil, fmt.Errorf("identity: failed to parse id_token claims: %v", err)
} }
return &sessions.SessionState{ return &sessions.State{
IDToken: rawIDToken, IDToken: rawIDToken,
User: idToken.Subject, User: idToken.Subject,
RefreshDeadline: idToken.Expiry.Truncate(time.Second), RefreshDeadline: idToken.Expiry.Truncate(time.Second),
@ -157,7 +157,7 @@ func (p *Provider) IDTokenToSession(ctx context.Context, rawIDToken string) (*se
} }
// Authenticate creates a session with an identity provider from a authorization code // Authenticate creates a session with an identity provider from a authorization code
func (p *Provider) Authenticate(ctx context.Context, code string) (*sessions.SessionState, error) { func (p *Provider) Authenticate(ctx context.Context, code string) (*sessions.State, error) {
// exchange authorization for a oidc token // exchange authorization for a oidc token
oauth2Token, err := p.oauth.Exchange(ctx, code) oauth2Token, err := p.oauth.Exchange(ctx, code)
if err != nil { if err != nil {
@ -181,7 +181,7 @@ func (p *Provider) Authenticate(ctx context.Context, code string) (*sessions.Ses
// Refresh renews a user's session using therefresh_token without reprompting // Refresh renews a user's session using therefresh_token without reprompting
// the user. If supported, group membership is also refreshed. // the user. If supported, group membership is also refreshed.
// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens // https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens
func (p *Provider) Refresh(ctx context.Context, s *sessions.SessionState) (*sessions.SessionState, error) { func (p *Provider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) {
if s.RefreshToken == "" { if s.RefreshToken == "" {
return nil, errors.New("identity: missing refresh token") return nil, errors.New("identity: missing refresh token")
} }

View file

@ -13,6 +13,7 @@ import (
"github.com/pomerium/pomerium/internal/cryptutil" "github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/internal/urlutil"
"golang.org/x/net/publicsuffix" "golang.org/x/net/publicsuffix"
) )
@ -70,7 +71,7 @@ func ValidateRedirectURI(rootDomain *url.URL) func(next http.Handler) http.Handl
httputil.ErrorResponse(w, r, httputil.Error("couldn't parse form", http.StatusBadRequest, err)) httputil.ErrorResponse(w, r, httputil.Error("couldn't parse form", http.StatusBadRequest, err))
return return
} }
redirectURI, err := url.Parse(r.Form.Get("redirect_uri")) redirectURI, err := urlutil.ParseAndValidateURL(r.Form.Get("redirect_uri"))
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("bad redirect_uri", http.StatusBadRequest, err)) httputil.ErrorResponse(w, r, httputil.Error("bad redirect_uri", http.StatusBadRequest, err))
return return
@ -131,7 +132,7 @@ func ValidateHost(validHost func(host string) bool) func(next http.Handler) http
defer span.End() defer span.End()
if !validHost(r.Host) { if !validHost(r.Host) {
httputil.ErrorResponse(w, r, httputil.Error(fmt.Sprintf("%s is not a known route.", r.Host), http.StatusNotFound, nil)) httputil.ErrorResponse(w, r, httputil.Error("", http.StatusNotFound, nil))
return return
} }
next.ServeHTTP(w, r.WithContext(ctx)) next.ServeHTTP(w, r.WithContext(ctx))
@ -168,7 +169,7 @@ func ValidSignature(redirectURI, sigVal, timestamp, secret string) bool {
if redirectURI == "" || sigVal == "" || timestamp == "" || secret == "" { if redirectURI == "" || sigVal == "" || timestamp == "" || secret == "" {
return false return false
} }
_, err := url.Parse(redirectURI) _, err := urlutil.ParseAndValidateURL(redirectURI)
if err != nil { if err != nil {
return false return false
} }

View file

@ -1,6 +1,7 @@
package middleware // import "github.com/pomerium/pomerium/internal/middleware" package middleware // import "github.com/pomerium/pomerium/internal/middleware"
import ( import (
"encoding/base64"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -40,7 +41,7 @@ func TestSignRequest(t *testing.T) {
}) })
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
signer, err := cryptutil.NewES256Signer([]byte(exampleKey), "audience") signer, err := cryptutil.NewES256Signer(base64.StdEncoding.EncodeToString([]byte(exampleKey)), "audience")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -1,7 +1,6 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions" package sessions // import "github.com/pomerium/pomerium/internal/sessions"
import ( import (
"errors"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@ -11,15 +10,17 @@ import (
"github.com/pomerium/pomerium/internal/cryptutil" "github.com/pomerium/pomerium/internal/cryptutil"
) )
// ErrInvalidSession is an error for invalid sessions.
var ErrInvalidSession = errors.New("internal/sessions: invalid session")
// ChunkedCanaryByte is the byte value used as a canary prefix to distinguish if // ChunkedCanaryByte is the byte value used as a canary prefix to distinguish if
// the cookie is multi-part or not. This constant *should not* be valid // the cookie is multi-part or not. This constant *should not* be valid
// base64. It's important this byte is ASCII to avoid UTF-8 variable sized runes. // base64. It's important this byte is ASCII to avoid UTF-8 variable sized runes.
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie#Directives // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie#Directives
const ChunkedCanaryByte byte = '%' const ChunkedCanaryByte byte = '%'
// DefaultBearerTokenHeader is default header name for the authorization bearer
// token header as defined in rfc2617
// https://tools.ietf.org/html/rfc6750#section-2.1
const DefaultBearerTokenHeader = "Authorization"
// MaxChunkSize sets the upper bound on a cookie chunks payload value. // MaxChunkSize sets the upper bound on a cookie chunks payload value.
// Note, this should be lower than the actual cookie's max size (4096 bytes) // Note, this should be lower than the actual cookie's max size (4096 bytes)
// which includes metadata. // which includes metadata.
@ -29,39 +30,27 @@ const MaxChunkSize = 3800
// set to prevent any abuse. // set to prevent any abuse.
const MaxNumChunks = 5 const MaxNumChunks = 5
// CSRFStore has the functions for setting, getting, and clearing the CSRF cookie
type CSRFStore interface {
SetCSRF(http.ResponseWriter, *http.Request, string)
GetCSRF(*http.Request) (*http.Cookie, error)
ClearCSRF(http.ResponseWriter, *http.Request)
}
// SessionStore has the functions for setting, getting, and clearing the Session cookie
type SessionStore interface {
ClearSession(http.ResponseWriter, *http.Request)
LoadSession(*http.Request) (*SessionState, error)
SaveSession(http.ResponseWriter, *http.Request, *SessionState) error
}
// CookieStore represents all the cookie related configurations // CookieStore represents all the cookie related configurations
type CookieStore struct { type CookieStore struct {
Name string Name string
CookieCipher cryptutil.Cipher CookieCipher cryptutil.Cipher
CookieExpire time.Duration CookieExpire time.Duration
CookieRefresh time.Duration CookieRefresh time.Duration
CookieSecure bool CookieSecure bool
CookieHTTPOnly bool CookieHTTPOnly bool
CookieDomain string CookieDomain string
BearerTokenHeader string
} }
// CookieStoreOptions holds options for CookieStore // CookieStoreOptions holds options for CookieStore
type CookieStoreOptions struct { type CookieStoreOptions struct {
Name string Name string
CookieSecure bool CookieSecure bool
CookieHTTPOnly bool CookieHTTPOnly bool
CookieDomain string CookieDomain string
CookieExpire time.Duration BearerTokenHeader string
CookieCipher cryptutil.Cipher CookieExpire time.Duration
CookieCipher cryptutil.Cipher
} }
// NewCookieStore returns a new session with ciphers for each of the cookie secrets // NewCookieStore returns a new session with ciphers for each of the cookie secrets
@ -72,23 +61,28 @@ func NewCookieStore(opts *CookieStoreOptions) (*CookieStore, error) {
if opts.CookieCipher == nil { if opts.CookieCipher == nil {
return nil, fmt.Errorf("internal/sessions: cipher cannot be nil") return nil, fmt.Errorf("internal/sessions: cipher cannot be nil")
} }
if opts.BearerTokenHeader == "" {
opts.BearerTokenHeader = DefaultBearerTokenHeader
}
return &CookieStore{ return &CookieStore{
Name: opts.Name, Name: opts.Name,
CookieSecure: opts.CookieSecure, CookieSecure: opts.CookieSecure,
CookieHTTPOnly: opts.CookieHTTPOnly, CookieHTTPOnly: opts.CookieHTTPOnly,
CookieDomain: opts.CookieDomain, CookieDomain: opts.CookieDomain,
CookieExpire: opts.CookieExpire, CookieExpire: opts.CookieExpire,
CookieCipher: opts.CookieCipher, CookieCipher: opts.CookieCipher,
BearerTokenHeader: opts.BearerTokenHeader,
}, nil }, nil
} }
func (s *CookieStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie { func (cs *CookieStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie {
domain := req.Host domain := req.Host
if name == s.csrfName() { if name == cs.csrfName() {
domain = req.Host domain = req.Host
} else if s.CookieDomain != "" { } else if cs.CookieDomain != "" {
domain = s.CookieDomain domain = cs.CookieDomain
} else { } else {
domain = splitDomain(domain) domain = splitDomain(domain)
} }
@ -101,8 +95,8 @@ func (s *CookieStore) makeCookie(req *http.Request, name string, value string, e
Value: value, Value: value,
Path: "/", Path: "/",
Domain: domain, Domain: domain,
HttpOnly: s.CookieHTTPOnly, HttpOnly: cs.CookieHTTPOnly,
Secure: s.CookieSecure, Secure: cs.CookieSecure,
} }
// only set an expiration if we want one, otherwise default to non perm session based // only set an expiration if we want one, otherwise default to non perm session based
if expiration != 0 { if expiration != 0 {
@ -111,22 +105,20 @@ func (s *CookieStore) makeCookie(req *http.Request, name string, value string, e
return c return c
} }
func (s *CookieStore) csrfName() string { func (cs *CookieStore) csrfName() string {
return fmt.Sprintf("%s_csrf", s.Name) return fmt.Sprintf("%s_csrf", cs.Name)
} }
// makeSessionCookie constructs a session cookie given the request, an expiration time and the current time. // makeSessionCookie constructs a session cookie given the request, an expiration time and the current time.
func (s *CookieStore) makeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { func (cs *CookieStore) makeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
return s.makeCookie(req, s.Name, value, expiration, now) return cs.makeCookie(req, cs.Name, value, expiration, now)
} }
// makeCSRFCookie creates a CSRF cookie given the request, an expiration time, and the current time. func (cs *CookieStore) makeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
// CSRF cookies should be scoped to the actual domain return cs.makeCookie(req, cs.csrfName(), value, expiration, now)
func (s *CookieStore) makeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
return s.makeCookie(req, s.csrfName(), value, expiration, now)
} }
func (s *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) { func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) {
if len(cookie.String()) <= MaxChunkSize { if len(cookie.String()) <= MaxChunkSize {
http.SetCookie(w, cookie) http.SetCookie(w, cookie)
return return
@ -142,9 +134,9 @@ func (s *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) {
nc.Name = fmt.Sprintf("%s_%d", cookie.Name, i) nc.Name = fmt.Sprintf("%s_%d", cookie.Name, i)
nc.Value = c nc.Value = c
} }
fmt.Println(i)
http.SetCookie(w, &nc) http.SetCookie(w, &nc)
} }
} }
func chunk(s string, size int) []string { func chunk(s string, size int) []string {
@ -159,43 +151,54 @@ func chunk(s string, size int) []string {
} }
// ClearCSRF clears the CSRF cookie from the request // ClearCSRF clears the CSRF cookie from the request
func (s *CookieStore) ClearCSRF(w http.ResponseWriter, req *http.Request) { func (cs *CookieStore) ClearCSRF(w http.ResponseWriter, req *http.Request) {
http.SetCookie(w, s.makeCSRFCookie(req, "", time.Hour*-1, time.Now())) http.SetCookie(w, cs.makeCSRFCookie(req, "", time.Hour*-1, time.Now()))
} }
// SetCSRF sets the CSRFCookie creates a CSRF cookie in a given request // SetCSRF sets the CSRFCookie creates a CSRF cookie in a given request
func (s *CookieStore) SetCSRF(w http.ResponseWriter, req *http.Request, val string) { func (cs *CookieStore) SetCSRF(w http.ResponseWriter, req *http.Request, val string) {
http.SetCookie(w, s.makeCSRFCookie(req, val, s.CookieExpire, time.Now())) http.SetCookie(w, cs.makeCSRFCookie(req, val, cs.CookieExpire, time.Now()))
} }
// GetCSRF gets the CSRFCookie creates a CSRF cookie in a given request // GetCSRF gets the CSRFCookie creates a CSRF cookie in a given request
func (s *CookieStore) GetCSRF(req *http.Request) (*http.Cookie, error) { func (cs *CookieStore) GetCSRF(req *http.Request) (*http.Cookie, error) {
return req.Cookie(s.csrfName()) c, err := req.Cookie(cs.csrfName())
if err != nil {
return nil, ErrEmptyCSRF // ErrNoCookie is confusing in this context
}
return c, nil
} }
// ClearSession clears the session cookie from a request // ClearSession clears the session cookie from a request
func (s *CookieStore) ClearSession(w http.ResponseWriter, req *http.Request) { func (cs *CookieStore) ClearSession(w http.ResponseWriter, req *http.Request) {
http.SetCookie(w, s.makeSessionCookie(req, "", time.Hour*-1, time.Now())) http.SetCookie(w, cs.makeCookie(req, cs.Name, "", time.Hour*-1, time.Now()))
} }
func (s *CookieStore) setSessionCookie(w http.ResponseWriter, req *http.Request, val string) { func (cs *CookieStore) setSessionCookie(w http.ResponseWriter, req *http.Request, val string) {
s.setCookie(w, s.makeSessionCookie(req, val, s.CookieExpire, time.Now())) cs.setCookie(w, cs.makeSessionCookie(req, val, cs.CookieExpire, time.Now()))
} }
// LoadSession returns a SessionState from the cookie in the request. func loadBearerToken(r *http.Request, headerKey string) string {
func (s *CookieStore) LoadSession(req *http.Request) (*SessionState, error) { authHeader := r.Header.Get(headerKey)
c, err := req.Cookie(s.Name) split := strings.Split(authHeader, "Bearer")
if authHeader == "" || len(split) != 2 {
return ""
}
return strings.TrimSpace(split[1])
}
func loadChunkedCookie(r *http.Request, cookieName string) string {
c, err := r.Cookie(cookieName)
if err != nil { if err != nil {
return nil, err // http.ErrNoCookie return ""
} }
cipherText := c.Value cipherText := c.Value
// if the first byte is our canary byte, we need to handle the multipart bit // if the first byte is our canary byte, we need to handle the multipart bit
if []byte(c.Value)[0] == ChunkedCanaryByte { if []byte(c.Value)[0] == ChunkedCanaryByte {
var b strings.Builder var b strings.Builder
fmt.Fprintf(&b, "%s", cipherText[1:]) fmt.Fprintf(&b, "%s", cipherText[1:])
for i := 1; i < MaxNumChunks; i++ { for i := 1; i <= MaxNumChunks; i++ {
next, err := req.Cookie(fmt.Sprintf("%s_%d", s.Name, i)) next, err := r.Cookie(fmt.Sprintf("%s_%d", cookieName, i))
if err != nil { if err != nil {
break // break if we can't find the next cookie break // break if we can't find the next cookie
} }
@ -203,20 +206,32 @@ func (s *CookieStore) LoadSession(req *http.Request) (*SessionState, error) {
} }
cipherText = b.String() cipherText = b.String()
} }
session, err := UnmarshalSession(cipherText, s.CookieCipher) return cipherText
}
// LoadSession returns a State from the cookie in the request.
func (cs *CookieStore) LoadSession(req *http.Request) (*State, error) {
cipherText := loadChunkedCookie(req, cs.Name)
if cipherText == "" {
cipherText = loadBearerToken(req, cs.BearerTokenHeader)
}
if cipherText == "" {
return nil, ErrEmptySession
}
session, err := UnmarshalSession(cipherText, cs.CookieCipher)
if err != nil { if err != nil {
return nil, ErrInvalidSession return nil, err
} }
return session, nil return session, nil
} }
// SaveSession saves a session state to a request sessions. // SaveSession saves a session state to a request sessions.
func (s *CookieStore) SaveSession(w http.ResponseWriter, req *http.Request, sessionState *SessionState) error { func (cs *CookieStore) SaveSession(w http.ResponseWriter, req *http.Request, s *State) error {
value, err := MarshalSession(sessionState, s.CookieCipher) value, err := MarshalSession(s, cs.CookieCipher)
if err != nil { if err != nil {
return err return err
} }
s.setSessionCookie(w, req, value) cs.setSessionCookie(w, req, value)
return nil return nil
} }

View file

@ -11,6 +11,7 @@ import (
"time" "time"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/pomerium/pomerium/internal/cryptutil" "github.com/pomerium/pomerium/internal/cryptutil"
) )
@ -49,30 +50,33 @@ func TestNewCookieStore(t *testing.T) {
}{ }{
{"good", {"good",
&CookieStoreOptions{ &CookieStoreOptions{
Name: "_cookie", Name: "_cookie",
CookieSecure: true, CookieSecure: true,
CookieHTTPOnly: true, CookieHTTPOnly: true,
CookieDomain: "pomerium.io", CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second, CookieExpire: 10 * time.Second,
CookieCipher: cipher, CookieCipher: cipher,
BearerTokenHeader: "Authorization",
}, },
&CookieStore{ &CookieStore{
Name: "_cookie", Name: "_cookie",
CookieSecure: true, CookieSecure: true,
CookieHTTPOnly: true, CookieHTTPOnly: true,
CookieDomain: "pomerium.io", CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second, CookieExpire: 10 * time.Second,
CookieCipher: cipher, CookieCipher: cipher,
BearerTokenHeader: "Authorization",
}, },
false}, false},
{"missing name", {"missing name",
&CookieStoreOptions{ &CookieStoreOptions{
Name: "", Name: "",
CookieSecure: true, CookieSecure: true,
CookieHTTPOnly: true, CookieHTTPOnly: true,
CookieDomain: "pomerium.io", CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second, CookieExpire: 10 * time.Second,
CookieCipher: cipher, CookieCipher: cipher,
BearerTokenHeader: "Authorization",
}, },
nil, nil,
true}, true},
@ -95,8 +99,12 @@ func TestNewCookieStore(t *testing.T) {
t.Errorf("NewCookieStore() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("NewCookieStore() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if !reflect.DeepEqual(got, tt.want) { cmpOpts := []cmp.Option{
t.Errorf("NewCookieStore() = %#v, want %#v", got, tt.want) cmpopts.IgnoreUnexported(cryptutil.XChaCha20Cipher{}),
}
if diff := cmp.Diff(got, tt.want, cmpOpts...); diff != "" {
t.Errorf("NewCookieStore() = %s", diff)
} }
}) })
} }
@ -211,15 +219,15 @@ func TestCookieStore_SaveSession(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
tests := []struct { tests := []struct {
name string name string
sessionState *SessionState State *State
cipher cryptutil.Cipher cipher cryptutil.Cipher
wantErr bool wantErr bool
wantLoadErr bool wantLoadErr bool
}{ }{
{"good", &SessionState{AccessToken: "token1234", RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, cipher, false, false}, {"good", &State{AccessToken: "token1234", RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, cipher, false, false},
{"bad cipher", &SessionState{AccessToken: "token1234", RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, mockCipher{}, true, true}, {"bad cipher", &State{AccessToken: "token1234", RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, mockCipher{}, true, true},
{"huge cookie", &SessionState{AccessToken: fmt.Sprintf("%x", hugeString), RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, cipher, false, false}, {"huge cookie", &State{AccessToken: fmt.Sprintf("%x", hugeString), RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, cipher, false, false},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -234,12 +242,12 @@ func TestCookieStore_SaveSession(t *testing.T) {
r := httptest.NewRequest("GET", "/", nil) r := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
if err := s.SaveSession(w, r, tt.sessionState); (err != nil) != tt.wantErr { if err := s.SaveSession(w, r, tt.State); (err != nil) != tt.wantErr {
t.Errorf("CookieStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("CookieStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
} }
r = httptest.NewRequest("GET", "/", nil) r = httptest.NewRequest("GET", "/", nil)
for _, cookie := range w.Result().Cookies() { for _, cookie := range w.Result().Cookies() {
t.Log(cookie) // t.Log(cookie)
r.AddCookie(cookie) r.AddCookie(cookie)
} }
@ -248,8 +256,10 @@ func TestCookieStore_SaveSession(t *testing.T) {
t.Errorf("LoadSession() error = %v, wantErr %v", err, tt.wantLoadErr) t.Errorf("LoadSession() error = %v, wantErr %v", err, tt.wantLoadErr)
return return
} }
if err == nil && !reflect.DeepEqual(state, tt.sessionState) { if err == nil {
t.Errorf("CookieStore.LoadSession() got = \n%v, want \n%v", state, tt.sessionState) if diff := cmp.Diff(state, tt.State); diff != "" {
t.Errorf("CookieStore.LoadSession() got = %s", diff)
}
} }
}) })
} }
@ -291,18 +301,18 @@ func TestMockSessionStore(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
mockCSRF *MockSessionStore mockCSRF *MockSessionStore
saveSession *SessionState saveSession *State
wantLoadErr bool wantLoadErr bool
wantSaveErr bool wantSaveErr bool
}{ }{
{"basic", {"basic",
&MockSessionStore{ &MockSessionStore{
ResponseSession: "test", ResponseSession: "test",
Session: &SessionState{AccessToken: "AccessToken"}, Session: &State{AccessToken: "AccessToken"},
SaveError: nil, SaveError: nil,
LoadError: nil, LoadError: nil,
}, },
&SessionState{AccessToken: "AccessToken"}, &State{AccessToken: "AccessToken"},
false, false,
false}, false},
} }

View file

@ -29,7 +29,7 @@ func (ms MockCSRFStore) GetCSRF(*http.Request) (*http.Cookie, error) {
// MockSessionStore is a mock implementation of the SessionStore interface // MockSessionStore is a mock implementation of the SessionStore interface
type MockSessionStore struct { type MockSessionStore struct {
ResponseSession string ResponseSession string
Session *SessionState Session *State
SaveError error SaveError error
LoadError error LoadError error
} }
@ -40,11 +40,11 @@ func (ms *MockSessionStore) ClearSession(http.ResponseWriter, *http.Request) {
} }
// LoadSession returns the session and a error // LoadSession returns the session and a error
func (ms MockSessionStore) LoadSession(*http.Request) (*SessionState, error) { func (ms MockSessionStore) LoadSession(*http.Request) (*State, error) {
return ms.Session, ms.LoadError return ms.Session, ms.LoadError
} }
// SaveSession returns a save error. // SaveSession returns a save error.
func (ms MockSessionStore) SaveSession(http.ResponseWriter, *http.Request, *SessionState) error { func (ms MockSessionStore) SaveSession(http.ResponseWriter, *http.Request, *State) error {
return ms.SaveError return ms.SaveError
} }

View file

@ -1,106 +0,0 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"time"
"github.com/pomerium/pomerium/internal/cryptutil"
)
// DefaultBearerTokenHeader is default header name for the authorization bearer
// token header as defined in rfc2617
// https://tools.ietf.org/html/rfc6750#section-2.1
const DefaultBearerTokenHeader = "Authorization"
// RestStore is a session store suitable for REST
type RestStore struct {
Name string
Cipher cryptutil.Cipher
// Expire time.Duration
}
// RestStoreOptions contains the options required to build a new RestStore.
type RestStoreOptions struct {
Name string
Cipher cryptutil.Cipher
// Expire time.Duration
}
// NewRestStore creates a new RestStore from a set of RestStoreOptions.
func NewRestStore(opts *RestStoreOptions) (*RestStore, error) {
if opts.Name == "" {
opts.Name = DefaultBearerTokenHeader
}
if opts.Cipher == nil {
return nil, fmt.Errorf("internal/sessions: cipher cannot be nil")
}
return &RestStore{
Name: opts.Name,
// Expire: opts.Expire,
Cipher: opts.Cipher,
}, nil
}
// ClearSession functions differently because REST is stateless, we instead
// inform the client that this token is no longer valid.
// https://tools.ietf.org/html/rfc6750
func (s *RestStore) ClearSession(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
errMsg := `
{
"error": "invalid_token",
"token_type": "Bearer",
"error_description": "The token has expired."
}`
w.Write([]byte(errMsg))
}
// LoadSession attempts to load a pomerium session from a Bearer Token set
// in the authorization header.
func (s *RestStore) LoadSession(r *http.Request) (*SessionState, error) {
authHeader := r.Header.Get(s.Name)
split := strings.Split(authHeader, "Bearer")
if authHeader == "" || len(split) != 2 {
return nil, errors.New("internal/sessions: no bearer token header found")
}
token := strings.TrimSpace(split[1])
session, err := UnmarshalSession(token, s.Cipher)
if err != nil {
return nil, err
}
return session, nil
}
// RestStoreResponse is the JSON struct returned to the client.
type RestStoreResponse struct {
// Token is the encrypted pomerium session that can be used to
// programmatically authenticate with pomerium.
Token string
// In addition to the token, non-sensitive meta data is returned to help
// the client manage token renewals.
Expiry time.Time
}
// SaveSession returns an encrypted pomerium session as a JSON object with
// associated, non sensitive meta-data like
func (s *RestStore) SaveSession(w http.ResponseWriter, r *http.Request, sessionState *SessionState) error {
encToken, err := MarshalSession(sessionState, s.Cipher)
if err != nil {
return err
}
jsonBytes, err := json.Marshal(
&RestStoreResponse{
Token: encToken,
Expiry: sessionState.RefreshDeadline,
})
if err != nil {
return fmt.Errorf("internal/sessions: couldn't marshal token struct: %v", err)
}
w.Header().Set("Content-Type", "application/json")
w.Write(jsonBytes)
return nil
}

View file

@ -1,135 +0,0 @@
package sessions
import (
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/pomerium/pomerium/internal/cryptutil"
)
func TestRestStore_SaveSession(t *testing.T) {
now := time.Date(2008, 1, 8, 17, 5, 5, 0, time.UTC)
tests := []struct {
name string
optionsName string
optionsCipher cryptutil.Cipher
sessionState *SessionState
wantErr bool
wantSaveResponse string
}{
{"good", "Authenticate", &cryptutil.MockCipher{MarshalResponse: "test"}, &SessionState{RefreshDeadline: now}, false, `{"Token":"test","Expiry":"2008-01-08T17:05:05Z"}`},
{"bad session marshal", "Authenticate", &cryptutil.MockCipher{MarshalError: errors.New("error")}, &SessionState{RefreshDeadline: now}, true, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s, err := NewRestStore(
&RestStoreOptions{
Name: tt.optionsName,
Cipher: tt.optionsCipher,
})
if err != nil {
t.Fatalf("NewRestStore err %v", err)
}
r := httptest.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder()
if err := s.SaveSession(w, r, tt.sessionState); (err != nil) != tt.wantErr {
t.Errorf("RestStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
}
resp := w.Result()
body, _ := ioutil.ReadAll(resp.Body)
if diff := cmp.Diff(string(body), tt.wantSaveResponse); diff != "" {
t.Errorf("RestStore.SaveSession() got / want diff \n%s\n", diff)
}
})
}
}
func TestNewRestStore(t *testing.T) {
tests := []struct {
name string
optionsName string
optionsCipher cryptutil.Cipher
wantErr bool
}{
{"good", "Authenticate", &cryptutil.MockCipher{}, false},
{"good default to authenticate", "", &cryptutil.MockCipher{}, false},
{"empty cipher", "Authenticate", nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := NewRestStore(
&RestStoreOptions{
Name: tt.optionsName,
Cipher: tt.optionsCipher,
})
if (err != nil) != tt.wantErr {
t.Errorf("NewRestStore() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}
func TestRestStore_ClearSession(t *testing.T) {
tests := []struct {
name string
expectedStatus int
}{
{"always returns reset!", http.StatusUnauthorized},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &RestStore{Name: "Authenticate", Cipher: &cryptutil.MockCipher{}}
r := httptest.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder()
s.ClearSession(w, r)
resp := w.Result()
if diff := cmp.Diff(resp.StatusCode, tt.expectedStatus); diff != "" {
t.Errorf("RestStore.ClearSession() got / want diff \n%s\n", diff)
}
})
}
}
func TestRestStore_LoadSession(t *testing.T) {
tests := []struct {
name string
optionsName string
optionsCipher cryptutil.Cipher
token string
wantErr bool
}{
{"good", "Authorization", &cryptutil.MockCipher{}, "test", false},
{"empty auth header", "", &cryptutil.MockCipher{}, "", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &RestStore{
Name: tt.optionsName,
Cipher: tt.optionsCipher,
}
r := httptest.NewRequest(http.MethodGet, "/", nil)
if tt.optionsName != "" {
r.Header.Set(tt.optionsName, fmt.Sprintf(("Bearer %s"), tt.token))
}
_, err := s.LoadSession(r)
if (err != nil) != tt.wantErr {
t.Errorf("RestStore.LoadSession() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}

View file

@ -3,7 +3,6 @@ package sessions // import "github.com/pomerium/pomerium/internal/sessions"
import ( import (
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"strings" "strings"
"time" "time"
@ -11,13 +10,11 @@ import (
"github.com/pomerium/pomerium/internal/cryptutil" "github.com/pomerium/pomerium/internal/cryptutil"
) )
var ( // ErrExpired is an error for a expired sessions.
// ErrLifetimeExpired is an error for the lifetime deadline expiring var ErrExpired = fmt.Errorf("internal/sessions: expired session")
ErrLifetimeExpired = errors.New("user lifetime expired")
)
// SessionState is our object that keeps track of a user's session state // State is our object that keeps track of a user's session state
type SessionState struct { type State struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"` RefreshToken string `json:"refresh_token"`
IDToken string `json:"id_token"` IDToken string `json:"id_token"`
@ -31,18 +28,31 @@ type SessionState struct {
ImpersonateGroups []string ImpersonateGroups []string
} }
// RefreshPeriodExpired returns true if the refresh period has expired // Valid returns an error if the users's session state is not valid.
func (s *SessionState) RefreshPeriodExpired() bool { func (s *State) Valid() error {
return isExpired(s.RefreshDeadline) if s.Expired() {
return ErrExpired
}
return nil
}
// ForceRefresh sets the refresh deadline to now.
func (s *State) ForceRefresh() {
s.RefreshDeadline = time.Now().Truncate(time.Second)
}
// Expired returns true if the refresh period has expired
func (s *State) Expired() bool {
return s.RefreshDeadline.Before(time.Now())
} }
// Impersonating returns if the request is impersonating. // Impersonating returns if the request is impersonating.
func (s *SessionState) Impersonating() bool { func (s *State) Impersonating() bool {
return s.ImpersonateEmail != "" || len(s.ImpersonateGroups) != 0 return s.ImpersonateEmail != "" || len(s.ImpersonateGroups) != 0
} }
// RequestEmail is the email to make the request as. // RequestEmail is the email to make the request as.
func (s *SessionState) RequestEmail() string { func (s *State) RequestEmail() string {
if s.ImpersonateEmail != "" { if s.ImpersonateEmail != "" {
return s.ImpersonateEmail return s.ImpersonateEmail
} }
@ -51,7 +61,7 @@ func (s *SessionState) RequestEmail() string {
// RequestGroups returns the groups of the Groups making the request; uses // RequestGroups returns the groups of the Groups making the request; uses
// impersonating user if set. // impersonating user if set.
func (s *SessionState) RequestGroups() string { func (s *State) RequestGroups() string {
if len(s.ImpersonateGroups) != 0 { if len(s.ImpersonateGroups) != 0 {
return strings.Join(s.ImpersonateGroups, ",") return strings.Join(s.ImpersonateGroups, ",")
} }
@ -68,7 +78,7 @@ type idToken struct {
} }
// IssuedAt parses the IDToken's issue date and returns a valid go time.Time. // IssuedAt parses the IDToken's issue date and returns a valid go time.Time.
func (s *SessionState) IssuedAt() (time.Time, error) { func (s *State) IssuedAt() (time.Time, error) {
payload, err := parseJWT(s.IDToken) payload, err := parseJWT(s.IDToken)
if err != nil { if err != nil {
return time.Time{}, fmt.Errorf("internal/sessions: malformed jwt: %v", err) return time.Time{}, fmt.Errorf("internal/sessions: malformed jwt: %v", err)
@ -80,13 +90,9 @@ func (s *SessionState) IssuedAt() (time.Time, error) {
return time.Time(token.IssuedAt), nil return time.Time(token.IssuedAt), nil
} }
func isExpired(t time.Time) bool {
return t.Before(time.Now())
}
// MarshalSession marshals the session state as JSON, encrypts the JSON using the // MarshalSession marshals the session state as JSON, encrypts the JSON using the
// given cipher, and base64-encodes the result // given cipher, and base64-encodes the result
func MarshalSession(s *SessionState, c cryptutil.Cipher) (string, error) { func MarshalSession(s *State, c cryptutil.Cipher) (string, error) {
v, err := c.Marshal(s) v, err := c.Marshal(s)
if err != nil { if err != nil {
return "", err return "", err
@ -96,8 +102,8 @@ func MarshalSession(s *SessionState, c cryptutil.Cipher) (string, error) {
// UnmarshalSession takes the marshaled string, base64-decodes into a byte slice, decrypts the // UnmarshalSession takes the marshaled string, base64-decodes into a byte slice, decrypts the
// byte slice using the passed cipher, and unmarshals the resulting JSON into a session state struct // byte slice using the passed cipher, and unmarshals the resulting JSON into a session state struct
func UnmarshalSession(value string, c cryptutil.Cipher) (*SessionState, error) { func UnmarshalSession(value string, c cryptutil.Cipher) (*State, error) {
s := &SessionState{} s := &State{}
err := c.Unmarshal(value, s) err := c.Unmarshal(value, s)
if err != nil { if err != nil {
return nil, err return nil, err
@ -105,11 +111,6 @@ func UnmarshalSession(value string, c cryptutil.Cipher) (*SessionState, error) {
return s, nil return s, nil
} }
// ExtendDeadline returns the time extended by a given duration, truncated by second
func ExtendDeadline(ttl time.Duration) time.Time {
return time.Now().Add(ttl).Truncate(time.Second)
}
func parseJWT(p string) ([]byte, error) { func parseJWT(p string) ([]byte, error) {
parts := strings.Split(p, ".") parts := strings.Split(p, ".")
if len(parts) < 2 { if len(parts) < 2 {

View file

@ -11,14 +11,14 @@ import (
"github.com/pomerium/pomerium/internal/cryptutil" "github.com/pomerium/pomerium/internal/cryptutil"
) )
func TestSessionStateSerialization(t *testing.T) { func TestStateSerialization(t *testing.T) {
secret := cryptutil.GenerateKey() secret := cryptutil.GenerateKey()
c, err := cryptutil.NewCipher(secret) c, err := cryptutil.NewCipher(secret)
if err != nil { if err != nil {
t.Fatalf("expected to be able to create cipher: %v", err) t.Fatalf("expected to be able to create cipher: %v", err)
} }
want := &SessionState{ want := &State{
AccessToken: "token1234", AccessToken: "token1234",
RefreshToken: "refresh4321", RefreshToken: "refresh4321",
RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
@ -43,41 +43,21 @@ func TestSessionStateSerialization(t *testing.T) {
} }
} }
func TestSessionStateExpirations(t *testing.T) { func TestStateExpirations(t *testing.T) {
session := &SessionState{ session := &State{
AccessToken: "token1234", AccessToken: "token1234",
RefreshToken: "refresh4321", RefreshToken: "refresh4321",
RefreshDeadline: time.Now().Add(-1 * time.Hour), RefreshDeadline: time.Now().Add(-1 * time.Hour),
Email: "user@domain.com", Email: "user@domain.com",
User: "user", User: "user",
} }
if !session.RefreshPeriodExpired() { if !session.Expired() {
t.Errorf("expected lifetime period to be expired") t.Errorf("expected lifetime period to be expired")
} }
} }
func TestExtendDeadline(t *testing.T) { func TestState_IssuedAt(t *testing.T) {
// tons of wiggle room here
now := time.Now().Truncate(time.Second)
tests := []struct {
name string
ttl time.Duration
want time.Time
}{
{"Add a few ms", time.Millisecond * 10, now.Truncate(time.Second)},
{"Add a few microsecs", time.Microsecond * 10, now.Truncate(time.Second)},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := ExtendDeadline(tt.ttl); !reflect.DeepEqual(got, tt.want) {
t.Errorf("ExtendDeadline() = %v, want %v", got, tt.want)
}
})
}
}
func TestSessionState_IssuedAt(t *testing.T) {
t.Parallel() t.Parallel()
tests := []struct { tests := []struct {
name string name string
@ -91,20 +71,20 @@ func TestSessionState_IssuedAt(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
s := &SessionState{IDToken: tt.IDToken} s := &State{IDToken: tt.IDToken}
got, err := s.IssuedAt() got, err := s.IssuedAt()
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("SessionState.IssuedAt() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("State.IssuedAt() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if !reflect.DeepEqual(got, tt.want) { if !reflect.DeepEqual(got, tt.want) {
t.Errorf("SessionState.IssuedAt() = %v, want %v", got.Format(time.RFC3339), tt.want.Format(time.RFC3339)) t.Errorf("State.IssuedAt() = %v, want %v", got.Format(time.RFC3339), tt.want.Format(time.RFC3339))
} }
}) })
} }
} }
func TestSessionState_Impersonating(t *testing.T) { func TestState_Impersonating(t *testing.T) {
t.Parallel() t.Parallel()
tests := []struct { tests := []struct {
name string name string
@ -123,20 +103,20 @@ func TestSessionState_Impersonating(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
s := &SessionState{ s := &State{
Email: tt.Email, Email: tt.Email,
Groups: tt.Groups, Groups: tt.Groups,
ImpersonateEmail: tt.ImpersonateEmail, ImpersonateEmail: tt.ImpersonateEmail,
ImpersonateGroups: tt.ImpersonateGroups, ImpersonateGroups: tt.ImpersonateGroups,
} }
if got := s.Impersonating(); got != tt.want { if got := s.Impersonating(); got != tt.want {
t.Errorf("SessionState.Impersonating() = %v, want %v", got, tt.want) t.Errorf("State.Impersonating() = %v, want %v", got, tt.want)
} }
if gotEmail := s.RequestEmail(); gotEmail != tt.wantResponseEmail { if gotEmail := s.RequestEmail(); gotEmail != tt.wantResponseEmail {
t.Errorf("SessionState.RequestEmail() = %v, want %v", gotEmail, tt.wantResponseEmail) t.Errorf("State.RequestEmail() = %v, want %v", gotEmail, tt.wantResponseEmail)
} }
if gotGroups := s.RequestGroups(); gotGroups != tt.wantResponseGroups { if gotGroups := s.RequestGroups(); gotGroups != tt.wantResponseGroups {
t.Errorf("SessionState.v() = %v, want %v", gotGroups, tt.wantResponseGroups) t.Errorf("State.v() = %v, want %v", gotGroups, tt.wantResponseGroups)
} }
}) })
} }
@ -154,11 +134,11 @@ func TestMarshalSession(t *testing.T) {
} }
tests := []struct { tests := []struct {
name string name string
s *SessionState s *State
wantErr bool wantErr bool
}{ }{
{"simple", &SessionState{}, false}, {"simple", &State{}, false},
{"too big", &SessionState{AccessToken: fmt.Sprintf("%x", hugeString)}, false}, {"too big", &State{AccessToken: fmt.Sprintf("%x", hugeString)}, false},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -179,3 +159,45 @@ func TestMarshalSession(t *testing.T) {
}) })
} }
} }
func TestState_Valid(t *testing.T) {
tests := []struct {
name string
RefreshDeadline time.Time
wantErr bool
}{
{" good", time.Now().Add(10 * time.Second), false},
{" expired", time.Now().Add(-10 * time.Second), true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &State{
RefreshDeadline: tt.RefreshDeadline,
}
if err := s.Valid(); (err != nil) != tt.wantErr {
t.Errorf("State.Valid() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestState_ForceRefresh(t *testing.T) {
tests := []struct {
name string
RefreshDeadline time.Time
}{
{"good", time.Now().Truncate(time.Second)},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &State{
RefreshDeadline: tt.RefreshDeadline,
}
s.ForceRefresh()
if s.RefreshDeadline != tt.RefreshDeadline {
t.Errorf("refresh deadline not updated")
}
})
}
}

View file

@ -0,0 +1,26 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
import (
"errors"
"net/http"
)
// ErrEmptySession is an error for an empty sessions.
var ErrEmptySession = errors.New("internal/sessions: empty session")
// ErrEmptyCSRF is an error for an empty sessions.
var ErrEmptyCSRF = errors.New("internal/sessions: empty csrf")
// CSRFStore has the functions for setting, getting, and clearing the CSRF cookie
type CSRFStore interface {
SetCSRF(http.ResponseWriter, *http.Request, string)
GetCSRF(*http.Request) (*http.Cookie, error)
ClearCSRF(http.ResponseWriter, *http.Request)
}
// SessionStore has the functions for setting, getting, and clearing the Session cookie
type SessionStore interface {
ClearSession(http.ResponseWriter, *http.Request)
LoadSession(*http.Request) (*State, error)
SaveSession(http.ResponseWriter, *http.Request, *State) error
}

View file

@ -8,12 +8,12 @@ import (
// The following tags are applied to stats recorded by this package. // The following tags are applied to stats recorded by this package.
var ( var (
TagKeyHTTPMethod tag.Key = tag.MustNewKey("http_method") TagKeyHTTPMethod = tag.MustNewKey("http_method")
TagKeyService tag.Key = tag.MustNewKey("service") TagKeyService = tag.MustNewKey("service")
TagKeyGRPCService tag.Key = tag.MustNewKey("grpc_service") TagKeyGRPCService = tag.MustNewKey("grpc_service")
TagKeyGRPCMethod tag.Key = tag.MustNewKey("grpc_method") TagKeyGRPCMethod = tag.MustNewKey("grpc_method")
TagKeyHost tag.Key = tag.MustNewKey("host") TagKeyHost = tag.MustNewKey("host")
TagKeyDestination tag.Key = tag.MustNewKey("destination") TagKeyDestination = tag.MustNewKey("destination")
) )
// Default distributions used by views in this package. // Default distributions used by views in this package.

View file

@ -1,399 +0,0 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: authenticate.proto
package authenticate
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
import timestamp "github.com/golang/protobuf/ptypes/timestamp"
import (
context "golang.org/x/net/context"
grpc "google.golang.org/grpc"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
type AuthenticateRequest struct {
Code string `protobuf:"bytes,1,opt,name=code,proto3" json:"code,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *AuthenticateRequest) Reset() { *m = AuthenticateRequest{} }
func (m *AuthenticateRequest) String() string { return proto.CompactTextString(m) }
func (*AuthenticateRequest) ProtoMessage() {}
func (*AuthenticateRequest) Descriptor() ([]byte, []int) {
return fileDescriptor_authenticate_d9796afa57ba1f78, []int{0}
}
func (m *AuthenticateRequest) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_AuthenticateRequest.Unmarshal(m, b)
}
func (m *AuthenticateRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_AuthenticateRequest.Marshal(b, m, deterministic)
}
func (dst *AuthenticateRequest) XXX_Merge(src proto.Message) {
xxx_messageInfo_AuthenticateRequest.Merge(dst, src)
}
func (m *AuthenticateRequest) XXX_Size() int {
return xxx_messageInfo_AuthenticateRequest.Size(m)
}
func (m *AuthenticateRequest) XXX_DiscardUnknown() {
xxx_messageInfo_AuthenticateRequest.DiscardUnknown(m)
}
var xxx_messageInfo_AuthenticateRequest proto.InternalMessageInfo
func (m *AuthenticateRequest) GetCode() string {
if m != nil {
return m.Code
}
return ""
}
type ValidateRequest struct {
IdToken string `protobuf:"bytes,1,opt,name=id_token,json=idToken,proto3" json:"id_token,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ValidateRequest) Reset() { *m = ValidateRequest{} }
func (m *ValidateRequest) String() string { return proto.CompactTextString(m) }
func (*ValidateRequest) ProtoMessage() {}
func (*ValidateRequest) Descriptor() ([]byte, []int) {
return fileDescriptor_authenticate_d9796afa57ba1f78, []int{1}
}
func (m *ValidateRequest) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ValidateRequest.Unmarshal(m, b)
}
func (m *ValidateRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ValidateRequest.Marshal(b, m, deterministic)
}
func (dst *ValidateRequest) XXX_Merge(src proto.Message) {
xxx_messageInfo_ValidateRequest.Merge(dst, src)
}
func (m *ValidateRequest) XXX_Size() int {
return xxx_messageInfo_ValidateRequest.Size(m)
}
func (m *ValidateRequest) XXX_DiscardUnknown() {
xxx_messageInfo_ValidateRequest.DiscardUnknown(m)
}
var xxx_messageInfo_ValidateRequest proto.InternalMessageInfo
func (m *ValidateRequest) GetIdToken() string {
if m != nil {
return m.IdToken
}
return ""
}
type ValidateReply struct {
IsValid bool `protobuf:"varint,1,opt,name=is_valid,json=isValid,proto3" json:"is_valid,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ValidateReply) Reset() { *m = ValidateReply{} }
func (m *ValidateReply) String() string { return proto.CompactTextString(m) }
func (*ValidateReply) ProtoMessage() {}
func (*ValidateReply) Descriptor() ([]byte, []int) {
return fileDescriptor_authenticate_d9796afa57ba1f78, []int{2}
}
func (m *ValidateReply) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ValidateReply.Unmarshal(m, b)
}
func (m *ValidateReply) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ValidateReply.Marshal(b, m, deterministic)
}
func (dst *ValidateReply) XXX_Merge(src proto.Message) {
xxx_messageInfo_ValidateReply.Merge(dst, src)
}
func (m *ValidateReply) XXX_Size() int {
return xxx_messageInfo_ValidateReply.Size(m)
}
func (m *ValidateReply) XXX_DiscardUnknown() {
xxx_messageInfo_ValidateReply.DiscardUnknown(m)
}
var xxx_messageInfo_ValidateReply proto.InternalMessageInfo
func (m *ValidateReply) GetIsValid() bool {
if m != nil {
return m.IsValid
}
return false
}
type Session struct {
AccessToken string `protobuf:"bytes,1,opt,name=access_token,json=accessToken,proto3" json:"access_token,omitempty"`
RefreshToken string `protobuf:"bytes,2,opt,name=refresh_token,json=refreshToken,proto3" json:"refresh_token,omitempty"`
IdToken string `protobuf:"bytes,3,opt,name=id_token,json=idToken,proto3" json:"id_token,omitempty"`
User string `protobuf:"bytes,4,opt,name=user,proto3" json:"user,omitempty"`
Email string `protobuf:"bytes,5,opt,name=email,proto3" json:"email,omitempty"`
Groups []string `protobuf:"bytes,6,rep,name=groups,proto3" json:"groups,omitempty"`
RefreshDeadline *timestamp.Timestamp `protobuf:"bytes,7,opt,name=refresh_deadline,json=refreshDeadline,proto3" json:"refresh_deadline,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *Session) Reset() { *m = Session{} }
func (m *Session) String() string { return proto.CompactTextString(m) }
func (*Session) ProtoMessage() {}
func (*Session) Descriptor() ([]byte, []int) {
return fileDescriptor_authenticate_d9796afa57ba1f78, []int{3}
}
func (m *Session) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_Session.Unmarshal(m, b)
}
func (m *Session) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_Session.Marshal(b, m, deterministic)
}
func (dst *Session) XXX_Merge(src proto.Message) {
xxx_messageInfo_Session.Merge(dst, src)
}
func (m *Session) XXX_Size() int {
return xxx_messageInfo_Session.Size(m)
}
func (m *Session) XXX_DiscardUnknown() {
xxx_messageInfo_Session.DiscardUnknown(m)
}
var xxx_messageInfo_Session proto.InternalMessageInfo
func (m *Session) GetAccessToken() string {
if m != nil {
return m.AccessToken
}
return ""
}
func (m *Session) GetRefreshToken() string {
if m != nil {
return m.RefreshToken
}
return ""
}
func (m *Session) GetIdToken() string {
if m != nil {
return m.IdToken
}
return ""
}
func (m *Session) GetUser() string {
if m != nil {
return m.User
}
return ""
}
func (m *Session) GetEmail() string {
if m != nil {
return m.Email
}
return ""
}
func (m *Session) GetGroups() []string {
if m != nil {
return m.Groups
}
return nil
}
func (m *Session) GetRefreshDeadline() *timestamp.Timestamp {
if m != nil {
return m.RefreshDeadline
}
return nil
}
func init() {
proto.RegisterType((*AuthenticateRequest)(nil), "authenticate.AuthenticateRequest")
proto.RegisterType((*ValidateRequest)(nil), "authenticate.ValidateRequest")
proto.RegisterType((*ValidateReply)(nil), "authenticate.ValidateReply")
proto.RegisterType((*Session)(nil), "authenticate.Session")
}
// Reference imports to suppress errors if they are not otherwise used.
var _ context.Context
var _ grpc.ClientConn
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
const _ = grpc.SupportPackageIsVersion4
// AuthenticatorClient is the client API for Authenticator service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.
type AuthenticatorClient interface {
Authenticate(ctx context.Context, in *AuthenticateRequest, opts ...grpc.CallOption) (*Session, error)
Validate(ctx context.Context, in *ValidateRequest, opts ...grpc.CallOption) (*ValidateReply, error)
Refresh(ctx context.Context, in *Session, opts ...grpc.CallOption) (*Session, error)
}
type authenticatorClient struct {
cc *grpc.ClientConn
}
func NewAuthenticatorClient(cc *grpc.ClientConn) AuthenticatorClient {
return &authenticatorClient{cc}
}
func (c *authenticatorClient) Authenticate(ctx context.Context, in *AuthenticateRequest, opts ...grpc.CallOption) (*Session, error) {
out := new(Session)
err := c.cc.Invoke(ctx, "/authenticate.Authenticator/Authenticate", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *authenticatorClient) Validate(ctx context.Context, in *ValidateRequest, opts ...grpc.CallOption) (*ValidateReply, error) {
out := new(ValidateReply)
err := c.cc.Invoke(ctx, "/authenticate.Authenticator/Validate", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *authenticatorClient) Refresh(ctx context.Context, in *Session, opts ...grpc.CallOption) (*Session, error) {
out := new(Session)
err := c.cc.Invoke(ctx, "/authenticate.Authenticator/Refresh", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// AuthenticatorServer is the server API for Authenticator service.
type AuthenticatorServer interface {
Authenticate(context.Context, *AuthenticateRequest) (*Session, error)
Validate(context.Context, *ValidateRequest) (*ValidateReply, error)
Refresh(context.Context, *Session) (*Session, error)
}
func RegisterAuthenticatorServer(s *grpc.Server, srv AuthenticatorServer) {
s.RegisterService(&_Authenticator_serviceDesc, srv)
}
func _Authenticator_Authenticate_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(AuthenticateRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(AuthenticatorServer).Authenticate(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/authenticate.Authenticator/Authenticate",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(AuthenticatorServer).Authenticate(ctx, req.(*AuthenticateRequest))
}
return interceptor(ctx, in, info, handler)
}
func _Authenticator_Validate_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(ValidateRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(AuthenticatorServer).Validate(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/authenticate.Authenticator/Validate",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(AuthenticatorServer).Validate(ctx, req.(*ValidateRequest))
}
return interceptor(ctx, in, info, handler)
}
func _Authenticator_Refresh_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(Session)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(AuthenticatorServer).Refresh(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/authenticate.Authenticator/Refresh",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(AuthenticatorServer).Refresh(ctx, req.(*Session))
}
return interceptor(ctx, in, info, handler)
}
var _Authenticator_serviceDesc = grpc.ServiceDesc{
ServiceName: "authenticate.Authenticator",
HandlerType: (*AuthenticatorServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "Authenticate",
Handler: _Authenticator_Authenticate_Handler,
},
{
MethodName: "Validate",
Handler: _Authenticator_Validate_Handler,
},
{
MethodName: "Refresh",
Handler: _Authenticator_Refresh_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "authenticate.proto",
}
func init() { proto.RegisterFile("authenticate.proto", fileDescriptor_authenticate_d9796afa57ba1f78) }
var fileDescriptor_authenticate_d9796afa57ba1f78 = []byte{
// 354 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x74, 0x91, 0x4d, 0x4f, 0xb3, 0x40,
0x14, 0x85, 0xcb, 0xdb, 0x0f, 0xda, 0x5b, 0x9a, 0xbe, 0xb9, 0x7e, 0x04, 0x31, 0xc6, 0x16, 0x37,
0xd5, 0x18, 0x9a, 0xd4, 0x95, 0x4b, 0x13, 0x4d, 0x8c, 0x4b, 0x6c, 0xdc, 0x36, 0x14, 0x6e, 0xdb,
0x89, 0x94, 0x41, 0x66, 0x30, 0xe9, 0xbf, 0xf5, 0x4f, 0xb8, 0x37, 0x0c, 0x10, 0xc1, 0xb4, 0x3b,
0xee, 0x99, 0xe7, 0x0e, 0x67, 0xce, 0x01, 0xf4, 0x52, 0xb9, 0xa1, 0x48, 0x32, 0xdf, 0x93, 0xe4,
0xc4, 0x09, 0x97, 0x1c, 0x8d, 0xaa, 0x66, 0x5d, 0xae, 0x39, 0x5f, 0x87, 0x34, 0x55, 0x67, 0xcb,
0x74, 0x35, 0x95, 0x6c, 0x4b, 0x42, 0x7a, 0xdb, 0x38, 0xc7, 0xed, 0x6b, 0x38, 0x7a, 0xa8, 0x2c,
0xb8, 0xf4, 0x91, 0x92, 0x90, 0x88, 0xd0, 0xf2, 0x79, 0x40, 0xa6, 0x36, 0xd2, 0x26, 0x3d, 0x57,
0x7d, 0xdb, 0xb7, 0x30, 0x7c, 0xf3, 0x42, 0x16, 0x54, 0xb0, 0x33, 0xe8, 0xb2, 0x60, 0x21, 0xf9,
0x3b, 0x45, 0x05, 0xaa, 0xb3, 0x60, 0x9e, 0x8d, 0xf6, 0x0d, 0x0c, 0x7e, 0xe9, 0x38, 0xdc, 0x29,
0x56, 0x2c, 0x3e, 0x33, 0x4d, 0xb1, 0x5d, 0x57, 0x67, 0x42, 0x21, 0xf6, 0xb7, 0x06, 0xfa, 0x2b,
0x09, 0xc1, 0x78, 0x84, 0x63, 0x30, 0x3c, 0xdf, 0x27, 0x21, 0x6a, 0xd7, 0xf6, 0x73, 0x4d, 0x5d,
0x8d, 0x57, 0x30, 0x48, 0x68, 0x95, 0x90, 0xd8, 0x14, 0xcc, 0x3f, 0xc5, 0x18, 0x85, 0x98, 0x43,
0x55, 0x6b, 0xcd, 0x9a, 0xb5, 0xec, 0x71, 0xa9, 0xa0, 0xc4, 0x6c, 0xe5, 0x8f, 0xcb, 0xbe, 0xf1,
0x18, 0xda, 0xb4, 0xf5, 0x58, 0x68, 0xb6, 0x95, 0x98, 0x0f, 0x78, 0x0a, 0x9d, 0x75, 0xc2, 0xd3,
0x58, 0x98, 0x9d, 0x51, 0x73, 0xd2, 0x73, 0x8b, 0x09, 0x9f, 0xe0, 0x7f, 0xe9, 0x20, 0x20, 0x2f,
0x08, 0x59, 0x44, 0xa6, 0x3e, 0xd2, 0x26, 0xfd, 0x99, 0xe5, 0xe4, 0x89, 0x3b, 0x65, 0xe2, 0xce,
0xbc, 0x4c, 0xdc, 0x1d, 0x16, 0x3b, 0x8f, 0xc5, 0xca, 0xec, 0x4b, 0x83, 0x41, 0x25, 0x7d, 0x9e,
0xe0, 0x0b, 0x18, 0xd5, 0x3a, 0x70, 0xec, 0xd4, 0x2a, 0xde, 0x53, 0x95, 0x75, 0x52, 0x47, 0x8a,
0x1c, 0xed, 0x06, 0x3e, 0x43, 0xb7, 0x6c, 0x00, 0x2f, 0xea, 0xd0, 0x9f, 0x1e, 0xad, 0xf3, 0x43,
0xc7, 0x71, 0xb8, 0xb3, 0x1b, 0x78, 0x0f, 0xba, 0x9b, 0x5b, 0xc7, 0xfd, 0x7f, 0x3b, 0x68, 0x62,
0xd9, 0x51, 0x39, 0xdc, 0xfd, 0x04, 0x00, 0x00, 0xff, 0xff, 0xdc, 0x47, 0xff, 0x7e, 0xab, 0x02,
0x00, 0x00,
}

View file

@ -1,26 +0,0 @@
syntax = "proto3";
import "google/protobuf/timestamp.proto";
package authenticate;
service Authenticator {
rpc Authenticate(AuthenticateRequest) returns (Session) {}
rpc Validate(ValidateRequest) returns (ValidateReply) {}
rpc Refresh(Session) returns (Session) {}
}
message AuthenticateRequest { string code = 1; }
message ValidateRequest { string id_token = 1; }
message ValidateReply { bool is_valid = 1; }
message Session {
string access_token = 1;
string refresh_token = 2;
string id_token = 3;
string user = 4;
string email = 5;
repeated string groups = 6;
google.protobuf.Timestamp refresh_deadline = 7;
}

View file

@ -1,49 +0,0 @@
package authenticate
import (
fmt "fmt"
"github.com/golang/protobuf/ptypes"
"github.com/pomerium/pomerium/internal/sessions"
)
// SessionFromProto converts a converts a protocol buffer session into a pomerium session state.
func SessionFromProto(p *Session) (*sessions.SessionState, error) {
if p == nil {
return nil, fmt.Errorf("proto/authenticate: SessionFromProto session cannot be nil")
}
refreshDeadline, err := ptypes.Timestamp(p.RefreshDeadline)
if err != nil {
return nil, fmt.Errorf("proto/authenticate: couldn't parse refresh deadline %v", err)
}
return &sessions.SessionState{
AccessToken: p.AccessToken,
RefreshToken: p.RefreshToken,
IDToken: p.IdToken,
Email: p.Email,
User: p.User,
Groups: p.Groups,
RefreshDeadline: refreshDeadline,
}, nil
}
// ProtoFromSession converts a pomerium user session into a protocol buffer struct.
func ProtoFromSession(s *sessions.SessionState) (*Session, error) {
if s == nil {
return nil, fmt.Errorf("proto/authenticate: ProtoFromSession session cannot be nil")
}
refreshDeadline, err := ptypes.TimestampProto(s.RefreshDeadline)
if err != nil {
return nil, fmt.Errorf("proto/authenticate: couldn't parse refresh deadline %v", err)
}
return &Session{
AccessToken: s.AccessToken,
RefreshToken: s.RefreshToken,
IdToken: s.IDToken,
Email: s.Email,
User: s.User,
Groups: s.Groups,
RefreshDeadline: refreshDeadline,
}, nil
}

View file

@ -1,165 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: proto/authenticate/authenticate.pb.go
// Package mock_authenticate is a generated GoMock package.
package mock_authenticate
import (
"context"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
"github.com/pomerium/pomerium/proto/authenticate"
grpc "google.golang.org/grpc"
)
// MockAuthenticatorClient is a mock of AuthenticatorClient interface
type MockAuthenticatorClient struct {
ctrl *gomock.Controller
recorder *MockAuthenticatorClientMockRecorder
}
// MockAuthenticatorClientMockRecorder is the mock recorder for MockAuthenticatorClient
type MockAuthenticatorClientMockRecorder struct {
mock *MockAuthenticatorClient
}
// NewMockAuthenticatorClient creates a new mock instance
func NewMockAuthenticatorClient(ctrl *gomock.Controller) *MockAuthenticatorClient {
mock := &MockAuthenticatorClient{ctrl: ctrl}
mock.recorder = &MockAuthenticatorClientMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockAuthenticatorClient) EXPECT() *MockAuthenticatorClientMockRecorder {
return m.recorder
}
// Authenticate mocks base method
func (m *MockAuthenticatorClient) Authenticate(ctx context.Context, in *authenticate.AuthenticateRequest, opts ...grpc.CallOption) (*authenticate.Session, error) {
m.ctrl.T.Helper()
varargs := []interface{}{ctx, in}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Authenticate", varargs...)
ret0, _ := ret[0].(*authenticate.Session)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Authenticate indicates an expected call of Authenticate
func (mr *MockAuthenticatorClientMockRecorder) Authenticate(ctx, in interface{}, opts ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{ctx, in}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Authenticate", reflect.TypeOf((*MockAuthenticatorClient)(nil).Authenticate), varargs...)
}
// Validate mocks base method
func (m *MockAuthenticatorClient) Validate(ctx context.Context, in *authenticate.ValidateRequest, opts ...grpc.CallOption) (*authenticate.ValidateReply, error) {
m.ctrl.T.Helper()
varargs := []interface{}{ctx, in}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Validate", varargs...)
ret0, _ := ret[0].(*authenticate.ValidateReply)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Validate indicates an expected call of Validate
func (mr *MockAuthenticatorClientMockRecorder) Validate(ctx, in interface{}, opts ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{ctx, in}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Validate", reflect.TypeOf((*MockAuthenticatorClient)(nil).Validate), varargs...)
}
// Refresh mocks base method
func (m *MockAuthenticatorClient) Refresh(ctx context.Context, in *authenticate.Session, opts ...grpc.CallOption) (*authenticate.Session, error) {
m.ctrl.T.Helper()
varargs := []interface{}{ctx, in}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Refresh", varargs...)
ret0, _ := ret[0].(*authenticate.Session)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Refresh indicates an expected call of Refresh
func (mr *MockAuthenticatorClientMockRecorder) Refresh(ctx, in interface{}, opts ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{ctx, in}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Refresh", reflect.TypeOf((*MockAuthenticatorClient)(nil).Refresh), varargs...)
}
// MockAuthenticatorServer is a mock of AuthenticatorServer interface
type MockAuthenticatorServer struct {
ctrl *gomock.Controller
recorder *MockAuthenticatorServerMockRecorder
}
// MockAuthenticatorServerMockRecorder is the mock recorder for MockAuthenticatorServer
type MockAuthenticatorServerMockRecorder struct {
mock *MockAuthenticatorServer
}
// NewMockAuthenticatorServer creates a new mock instance
func NewMockAuthenticatorServer(ctrl *gomock.Controller) *MockAuthenticatorServer {
mock := &MockAuthenticatorServer{ctrl: ctrl}
mock.recorder = &MockAuthenticatorServerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockAuthenticatorServer) EXPECT() *MockAuthenticatorServerMockRecorder {
return m.recorder
}
// Authenticate mocks base method
func (m *MockAuthenticatorServer) Authenticate(arg0 context.Context, arg1 *authenticate.AuthenticateRequest) (*authenticate.Session, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Authenticate", arg0, arg1)
ret0, _ := ret[0].(*authenticate.Session)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Authenticate indicates an expected call of Authenticate
func (mr *MockAuthenticatorServerMockRecorder) Authenticate(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Authenticate", reflect.TypeOf((*MockAuthenticatorServer)(nil).Authenticate), arg0, arg1)
}
// Validate mocks base method
func (m *MockAuthenticatorServer) Validate(arg0 context.Context, arg1 *authenticate.ValidateRequest) (*authenticate.ValidateReply, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Validate", arg0, arg1)
ret0, _ := ret[0].(*authenticate.ValidateReply)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Validate indicates an expected call of Validate
func (mr *MockAuthenticatorServerMockRecorder) Validate(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Validate", reflect.TypeOf((*MockAuthenticatorServer)(nil).Validate), arg0, arg1)
}
// Refresh mocks base method
func (m *MockAuthenticatorServer) Refresh(arg0 context.Context, arg1 *authenticate.Session) (*authenticate.Session, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Refresh", arg0, arg1)
ret0, _ := ret[0].(*authenticate.Session)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Refresh indicates an expected call of Refresh
func (mr *MockAuthenticatorServerMockRecorder) Refresh(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Refresh", reflect.TypeOf((*MockAuthenticatorServer)(nil).Refresh), arg0, arg1)
}

View file

@ -1,117 +0,0 @@
package clients // import "github.com/pomerium/pomerium/proxy/clients"
import (
"context"
"errors"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/telemetry/trace"
pb "github.com/pomerium/pomerium/proto/authenticate"
"google.golang.org/grpc"
)
// Authenticator provides the authenticate service interface
type Authenticator interface {
// Redeem takes a code and returns a validated session or an error
Redeem(context.Context, string) (*sessions.SessionState, error)
// Refresh attempts to refresh a valid session with a refresh token. Returns a refreshed session.
Refresh(context.Context, *sessions.SessionState) (*sessions.SessionState, error)
// Validate evaluates a given oidc id_token for validity. Returns validity and any error.
Validate(context.Context, string) (bool, error)
// Close closes the authenticator connection if any.
Close() error
}
// NewAuthenticateClient returns a new authenticate service client. Presently,
// only gRPC is supported and is always returned so name is ignored.
func NewAuthenticateClient(name string, opts *Options) (a Authenticator, err error) {
return NewGRPCAuthenticateClient(opts)
}
// NewGRPCAuthenticateClient returns a new authenticate service client.
func NewGRPCAuthenticateClient(opts *Options) (p *AuthenticateGRPC, err error) {
conn, err := NewGRPCClientConn(opts)
if err != nil {
return nil, err
}
authClient := pb.NewAuthenticatorClient(conn)
return &AuthenticateGRPC{Conn: conn, client: authClient}, nil
}
// AuthenticateGRPC is a gRPC implementation of an authenticator (authenticate client)
type AuthenticateGRPC struct {
Conn *grpc.ClientConn
client pb.AuthenticatorClient
}
// Redeem makes an RPC call to the authenticate service to creates a session state
// from an encrypted code provided as a result of an oauth2 callback process.
func (a *AuthenticateGRPC) Redeem(ctx context.Context, code string) (*sessions.SessionState, error) {
ctx, span := trace.StartSpan(ctx, "proxy.client.grpc.Redeem")
defer span.End()
if code == "" {
return nil, errors.New("missing code")
}
protoSession, err := a.client.Authenticate(ctx, &pb.AuthenticateRequest{Code: code})
if err != nil {
return nil, err
}
session, err := pb.SessionFromProto(protoSession)
if err != nil {
return nil, err
}
return session, nil
}
// Refresh makes an RPC call to the authenticate service to attempt to refresh the
// user's session. Requires a valid refresh token. Will return an error if the identity provider
// has revoked the session or if the refresh token is no longer valid in this context.
func (a *AuthenticateGRPC) Refresh(ctx context.Context, s *sessions.SessionState) (*sessions.SessionState, error) {
ctx, span := trace.StartSpan(ctx, "proxy.client.grpc.Refresh")
defer span.End()
if s.RefreshToken == "" {
return nil, errors.New("missing refresh token")
}
req, err := pb.ProtoFromSession(s)
if err != nil {
return nil, err
}
// todo(bdd): add grpc specific timeouts to main options
// todo(bdd): handle request id (metadata!?) in grpc receiver and add to ctx logger
reply, err := a.client.Refresh(ctx, req)
if err != nil {
return nil, err
}
newSession, err := pb.SessionFromProto(reply)
if err != nil {
return nil, err
}
return newSession, nil
}
// Validate makes an RPC call to the authenticate service to validate the JWT id token;
// does NOT do nonce or revokation validation.
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
func (a *AuthenticateGRPC) Validate(ctx context.Context, idToken string) (bool, error) {
ctx, span := trace.StartSpan(ctx, "proxy.client.grpc.Validate")
defer span.End()
if idToken == "" {
return false, errors.New("missing id token")
}
r, err := a.client.Validate(ctx, &pb.ValidateRequest{IdToken: idToken})
if err != nil {
return false, err
}
return r.IsValid, nil
}
// Close tears down the ClientConn and all underlying connections.
func (a *AuthenticateGRPC) Close() error {
return a.Conn.Close()
}

View file

@ -1,242 +0,0 @@
package clients // import "github.com/pomerium/pomerium/proxy/clients"
import (
"context"
"fmt"
"net/url"
"reflect"
"strings"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/ptypes"
"github.com/pomerium/pomerium/internal/sessions"
pb "github.com/pomerium/pomerium/proto/authenticate"
mock "github.com/pomerium/pomerium/proto/authenticate/mock_authenticate"
)
func TestNew(t *testing.T) {
tests := []struct {
name string
serviceName string
opts *Options
wantErr bool
}{
{"grpc good", "grpc", &Options{Addr: &url.URL{Scheme: "https", Host: "localhost.example"}, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, SharedSecret: "secret"}, false},
{"grpc missing shared secret", "grpc", &Options{Addr: &url.URL{Scheme: "https", Host: "localhost.example"}, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, SharedSecret: ""}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := NewAuthenticateClient(tt.serviceName, tt.opts)
if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}
var fixedDate = time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC)
// rpcMsg implements the gomock.Matcher interface
type rpcMsg struct {
msg proto.Message
}
func (r *rpcMsg) Matches(msg interface{}) bool {
m, ok := msg.(proto.Message)
if !ok {
return false
}
return proto.Equal(m, r.msg)
}
func (r *rpcMsg) String() string {
return fmt.Sprintf("is %s", r.msg)
}
func TestProxy_Redeem(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockAuthenticateClient := mock.NewMockAuthenticatorClient(ctrl)
req := &pb.AuthenticateRequest{Code: "unit_test"}
mockExpire, err := ptypes.TimestampProto(fixedDate)
if err != nil {
t.Fatalf("%v failed converting timestamp", err)
}
mockAuthenticateClient.EXPECT().Authenticate(
gomock.Any(),
&rpcMsg{msg: req},
).Return(&pb.Session{
AccessToken: "mocked access token",
RefreshToken: "mocked refresh token",
IdToken: "mocked id token",
User: "user1",
Email: "test@email.com",
RefreshDeadline: mockExpire,
}, nil)
tests := []struct {
name string
idToken string
want *sessions.SessionState
wantErr bool
}{
{"good", "unit_test", &sessions.SessionState{
AccessToken: "mocked access token",
RefreshToken: "mocked refresh token",
IDToken: "mocked id token",
User: "user1",
Email: "test@email.com",
RefreshDeadline: (fixedDate),
}, false},
{"empty code", "", nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := AuthenticateGRPC{client: mockAuthenticateClient}
got, err := a.Redeem(context.Background(), tt.idToken)
if (err != nil) != tt.wantErr {
t.Errorf("Proxy.AuthenticateValidate() error = %v,\n wantErr %v", err, tt.wantErr)
return
}
if got != nil {
if got.AccessToken != "mocked access token" {
t.Errorf("authenticate: invalid access token")
}
if got.RefreshToken != "mocked refresh token" {
t.Errorf("authenticate: invalid refresh token")
}
if got.IDToken != "mocked id token" {
t.Errorf("authenticate: invalid id token")
}
if got.User != "user1" {
t.Errorf("authenticate: invalid user")
}
if got.Email != "test@email.com" {
t.Errorf("authenticate: invalid email")
}
}
})
}
}
func TestProxy_AuthenticateValidate(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockAuthenticateClient := mock.NewMockAuthenticatorClient(ctrl)
req := &pb.ValidateRequest{IdToken: "unit_test"}
mockAuthenticateClient.EXPECT().Validate(
gomock.Any(),
&rpcMsg{msg: req},
).Return(&pb.ValidateReply{IsValid: false}, nil)
ac := mockAuthenticateClient
tests := []struct {
name string
idToken string
want bool
wantErr bool
}{
{"good", "unit_test", false, false},
{"empty id token", "", false, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := AuthenticateGRPC{client: ac}
got, err := a.Validate(context.Background(), tt.idToken)
if (err != nil) != tt.wantErr {
t.Errorf("Proxy.AuthenticateValidate() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("Proxy.AuthenticateValidate() = %v, want %v", got, tt.want)
}
})
}
}
func TestProxy_AuthenticateRefresh(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockRefreshClient := mock.NewMockAuthenticatorClient(ctrl)
mockExpire, _ := ptypes.TimestampProto(fixedDate)
mockRefreshClient.EXPECT().Refresh(
gomock.Any(),
gomock.Not(sessions.SessionState{RefreshToken: "fail"}),
).Return(&pb.Session{
AccessToken: "new access token",
RefreshDeadline: mockExpire,
}, nil).AnyTimes()
tests := []struct {
name string
session *sessions.SessionState
want *sessions.SessionState
wantErr bool
}{
{"good",
&sessions.SessionState{RefreshToken: "unit_test"},
&sessions.SessionState{
AccessToken: "new access token",
RefreshDeadline: fixedDate,
}, false},
{"empty refresh token", &sessions.SessionState{RefreshToken: ""}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := AuthenticateGRPC{client: mockRefreshClient}
got, err := a.Refresh(context.Background(), tt.session)
if (err != nil) != tt.wantErr {
t.Errorf("Proxy.AuthenticateRefresh() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Proxy.AuthenticateRefresh() got = \n%#v\nwant \n%#v", got, tt.want)
}
})
}
}
func TestNewGRPC(t *testing.T) {
tests := []struct {
name string
opts *Options
wantErr bool
wantErrStr string
wantTarget string
}{
{"no shared secret", &Options{}, true, "proxy/authenticator: grpc client requires shared secret", ""},
{"empty connection", &Options{Addr: nil, SharedSecret: "shh"}, true, "proxy/authenticator: connection address required", ""},
{"both internal and addr empty", &Options{Addr: nil, InternalAddr: nil, SharedSecret: "shh"}, true, "proxy/authenticator: connection address required", ""},
{"addr with port", &Options{Addr: &url.URL{Scheme: "https", Host: "localhost.example:8443"}, SharedSecret: "shh"}, false, "", "localhost.example:8443"},
{"addr without port", &Options{Addr: &url.URL{Scheme: "https", Host: "localhost.example"}, SharedSecret: "shh"}, false, "", "localhost.example:443"},
{"internal addr with port", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example:8443"}, SharedSecret: "shh"}, false, "", "localhost.example:8443"},
{"internal addr without port", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, SharedSecret: "shh"}, false, "", "localhost.example:443"},
{"cert override", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, OverrideCertificateName: "*.local", SharedSecret: "shh"}, false, "", "localhost.example:443"},
{"custom ca", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, OverrideCertificateName: "*.local", SharedSecret: "shh", CA: "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSURFVENDQWZrQ0ZBWHhneFg5K0hjWlBVVVBEK0laV0NGNUEvVTdNQTBHQ1NxR1NJYjNEUUVCQ3dVQU1FVXgKQ3pBSkJnTlZCQVlUQWtGVk1STXdFUVlEVlFRSURBcFRiMjFsTFZOMFlYUmxNU0V3SHdZRFZRUUtEQmhKYm5SbApjbTVsZENCWGFXUm5hWFJ6SUZCMGVTQk1kR1F3SGhjTk1Ua3dNakk0TVRnMU1EQTNXaGNOTWprd01qSTFNVGcxCk1EQTNXakJGTVFzd0NRWURWUVFHRXdKQlZURVRNQkVHQTFVRUNBd0tVMjl0WlMxVGRHRjBaVEVoTUI4R0ExVUUKQ2d3WVNXNTBaWEp1WlhRZ1YybGtaMmwwY3lCUWRIa2dUSFJrTUlJQklqQU5CZ2txaGtpRzl3MEJBUUVGQUFPQwpBUThBTUlJQkNnS0NBUUVBOVRFMEFiaTdnMHhYeURkVUtEbDViNTBCT05ZVVVSc3F2THQrSWkwdlpjMzRRTHhOClJrT0hrOFZEVUgzcUt1N2UrNGVubUdLVVNUdzRPNFlkQktiSWRJTFpnb3o0YitNL3FVOG5adVpiN2pBVTdOYWkKajMzVDVrbXB3L2d4WHNNUzNzdUpXUE1EUDB3Z1BUZUVRK2J1bUxVWmpLdUVIaWNTL0l5dmtaVlBzRlE4NWlaUwpkNXE2a0ZGUUdjWnFXeFg0dlhDV25Sd3E3cHY3TThJd1RYc1pYSVRuNXB5Z3VTczNKb29GQkg5U3ZNTjRKU25GCmJMK0t6ekduMy9ScXFrTXpMN3FUdkMrNWxVT3UxUmNES21mZXBuVGVaN1IyVnJUQm42NndWMjVHRnBkSDIzN00KOXhJVkJrWEd1U2NvWHVPN1lDcWFrZkt6aXdoRTV4UmRaa3gweXdJREFRQUJNQTBHQ1NxR1NJYjNEUUVCQ3dVQQpBNElCQVFCaHRWUEI0OCs4eFZyVmRxM1BIY3k5QkxtVEtrRFl6N2Q0ODJzTG1HczBuVUdGSTFZUDdmaFJPV3ZxCktCTlpkNEI5MUpwU1NoRGUrMHpoNno4WG5Ha01mYnRSYWx0NHEwZ3lKdk9hUWhqQ3ZCcSswTFk5d2NLbXpFdnMKcTRiNUZ5NXNpRUZSekJLTmZtTGwxTTF2cW1hNmFCVnNYUUhPREdzYS83dE5MalZ2ay9PYm52cFg3UFhLa0E3cQpLMTQvV0tBRFBJWm9mb00xMzB4Q1RTYXVpeXROajlnWkx1WU9leEZhblVwNCt2MHBYWS81OFFSNTk2U0ROVTlKClJaeDhwTzBTaUYvZXkxVUZXbmpzdHBjbTQzTFVQKzFwU1hFeVhZOFJrRTI2QzNvdjNaTFNKc2pMbC90aXVqUlgKZUJPOWorWDdzS0R4amdtajBPbWdpVkpIM0YrUAotLS0tLUVORCBDRVJUSUZJQ0FURS0tLS0tCg=="}, false, "", "localhost.example:443"},
{"bad ca encoding", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, OverrideCertificateName: "*.local", SharedSecret: "shh", CA: "^"}, true, "", "localhost.example:443"},
{"custom ca file", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, OverrideCertificateName: "*.local", SharedSecret: "shh", CAFile: "testdata/example.crt"}, false, "", "localhost.example:443"},
{"bad custom ca file", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, OverrideCertificateName: "*.local", SharedSecret: "shh", CAFile: "testdata/example.crt2"}, true, "", "localhost.example:443"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewGRPCAuthenticateClient(tt.opts)
if (err != nil) != tt.wantErr {
t.Errorf("NewGRPCAuthenticateClient() error = %v, wantErr %v", err, tt.wantErr)
if !strings.EqualFold(err.Error(), tt.wantErrStr) {
t.Errorf("NewGRPCAuthenticateClient() error = %v did not contain wantErr %v", err, tt.wantErrStr)
}
}
if got != nil && got.Conn.Target() != tt.wantTarget {
t.Errorf("NewGRPCAuthenticateClient() target = %v expected %v", got.Conn.Target(), tt.wantTarget)
}
})
}
}

View file

@ -15,9 +15,9 @@ import (
type Authorizer interface { type Authorizer interface {
// Authorize takes a route and user session and returns whether the // Authorize takes a route and user session and returns whether the
// request is valid per access policy // request is valid per access policy
Authorize(context.Context, string, *sessions.SessionState) (bool, error) Authorize(context.Context, string, *sessions.State) (bool, error)
// IsAdmin takes a session and returns whether the user is an administrator // IsAdmin takes a session and returns whether the user is an administrator
IsAdmin(context.Context, *sessions.SessionState) (bool, error) IsAdmin(context.Context, *sessions.State) (bool, error)
// Close closes the auth connection if any. // Close closes the auth connection if any.
Close() error Close() error
} }
@ -46,7 +46,7 @@ type AuthorizeGRPC struct {
// Authorize takes a route and user session and returns whether the // Authorize takes a route and user session and returns whether the
// request is valid per access policy // request is valid per access policy
func (a *AuthorizeGRPC) Authorize(ctx context.Context, route string, s *sessions.SessionState) (bool, error) { func (a *AuthorizeGRPC) Authorize(ctx context.Context, route string, s *sessions.State) (bool, error) {
ctx, span := trace.StartSpan(ctx, "proxy.client.grpc.Authorize") ctx, span := trace.StartSpan(ctx, "proxy.client.grpc.Authorize")
defer span.End() defer span.End()
@ -65,7 +65,7 @@ func (a *AuthorizeGRPC) Authorize(ctx context.Context, route string, s *sessions
} }
// IsAdmin takes a session and returns whether the user is an administrator // IsAdmin takes a session and returns whether the user is an administrator
func (a *AuthorizeGRPC) IsAdmin(ctx context.Context, s *sessions.SessionState) (bool, error) { func (a *AuthorizeGRPC) IsAdmin(ctx context.Context, s *sessions.State) (bool, error) {
ctx, span := trace.StartSpan(ctx, "proxy.client.grpc.IsAdmin") ctx, span := trace.StartSpan(ctx, "proxy.client.grpc.IsAdmin")
defer span.End() defer span.End()

View file

@ -2,6 +2,8 @@ package clients
import ( import (
"context" "context"
"net/url"
"strings"
"testing" "testing"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
@ -23,12 +25,12 @@ func TestAuthorizeGRPC_Authorize(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
route string route string
s *sessions.SessionState s *sessions.State
want bool want bool
wantErr bool wantErr bool
}{ }{
{"good", "hello.pomerium.io", &sessions.SessionState{User: "admin@pomerium.io", Email: "admin@pomerium.io"}, true, false}, {"good", "hello.pomerium.io", &sessions.State{User: "admin@pomerium.io", Email: "admin@pomerium.io"}, true, false},
{"impersonate request", "hello.pomerium.io", &sessions.SessionState{User: "admin@pomerium.io", Email: "admin@pomerium.io", ImpersonateEmail: "other@other.example"}, true, false}, {"impersonate request", "hello.pomerium.io", &sessions.State{User: "admin@pomerium.io", Email: "admin@pomerium.io", ImpersonateEmail: "other@other.example"}, true, false},
{"session cannot be nil", "hello.pomerium.io", nil, false, true}, {"session cannot be nil", "hello.pomerium.io", nil, false, true},
} }
for _, tt := range tests { for _, tt := range tests {
@ -56,11 +58,11 @@ func TestAuthorizeGRPC_IsAdmin(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
s *sessions.SessionState s *sessions.State
want bool want bool
wantErr bool wantErr bool
}{ }{
{"good", &sessions.SessionState{User: "admin@pomerium.io", Email: "admin@pomerium.io"}, true, false}, {"good", &sessions.State{User: "admin@pomerium.io", Email: "admin@pomerium.io"}, true, false},
{"session cannot be nil", nil, false, true}, {"session cannot be nil", nil, false, true},
} }
for _, tt := range tests { for _, tt := range tests {
@ -77,3 +79,41 @@ func TestAuthorizeGRPC_IsAdmin(t *testing.T) {
}) })
} }
} }
func TestNewGRPC(t *testing.T) {
tests := []struct {
name string
opts *Options
wantErr bool
wantErrStr string
wantTarget string
}{
{"no shared secret", &Options{}, true, "proxy/authenticator: grpc client requires shared secret", ""},
{"empty connection", &Options{Addr: nil, SharedSecret: "shh"}, true, "proxy/authenticator: connection address required", ""},
{"both internal and addr empty", &Options{Addr: nil, InternalAddr: nil, SharedSecret: "shh"}, true, "proxy/authenticator: connection address required", ""},
{"addr with port", &Options{Addr: &url.URL{Scheme: "https", Host: "localhost.example:8443"}, SharedSecret: "shh"}, false, "", "localhost.example:8443"},
{"addr without port", &Options{Addr: &url.URL{Scheme: "https", Host: "localhost.example"}, SharedSecret: "shh"}, false, "", "localhost.example:443"},
{"internal addr with port", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example:8443"}, SharedSecret: "shh"}, false, "", "localhost.example:8443"},
{"internal addr without port", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, SharedSecret: "shh"}, false, "", "localhost.example:443"},
{"cert override", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, OverrideCertificateName: "*.local", SharedSecret: "shh"}, false, "", "localhost.example:443"},
{"custom ca", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, OverrideCertificateName: "*.local", SharedSecret: "shh", CA: "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSURFVENDQWZrQ0ZBWHhneFg5K0hjWlBVVVBEK0laV0NGNUEvVTdNQTBHQ1NxR1NJYjNEUUVCQ3dVQU1FVXgKQ3pBSkJnTlZCQVlUQWtGVk1STXdFUVlEVlFRSURBcFRiMjFsTFZOMFlYUmxNU0V3SHdZRFZRUUtEQmhKYm5SbApjbTVsZENCWGFXUm5hWFJ6SUZCMGVTQk1kR1F3SGhjTk1Ua3dNakk0TVRnMU1EQTNXaGNOTWprd01qSTFNVGcxCk1EQTNXakJGTVFzd0NRWURWUVFHRXdKQlZURVRNQkVHQTFVRUNBd0tVMjl0WlMxVGRHRjBaVEVoTUI4R0ExVUUKQ2d3WVNXNTBaWEp1WlhRZ1YybGtaMmwwY3lCUWRIa2dUSFJrTUlJQklqQU5CZ2txaGtpRzl3MEJBUUVGQUFPQwpBUThBTUlJQkNnS0NBUUVBOVRFMEFiaTdnMHhYeURkVUtEbDViNTBCT05ZVVVSc3F2THQrSWkwdlpjMzRRTHhOClJrT0hrOFZEVUgzcUt1N2UrNGVubUdLVVNUdzRPNFlkQktiSWRJTFpnb3o0YitNL3FVOG5adVpiN2pBVTdOYWkKajMzVDVrbXB3L2d4WHNNUzNzdUpXUE1EUDB3Z1BUZUVRK2J1bUxVWmpLdUVIaWNTL0l5dmtaVlBzRlE4NWlaUwpkNXE2a0ZGUUdjWnFXeFg0dlhDV25Sd3E3cHY3TThJd1RYc1pYSVRuNXB5Z3VTczNKb29GQkg5U3ZNTjRKU25GCmJMK0t6ekduMy9ScXFrTXpMN3FUdkMrNWxVT3UxUmNES21mZXBuVGVaN1IyVnJUQm42NndWMjVHRnBkSDIzN00KOXhJVkJrWEd1U2NvWHVPN1lDcWFrZkt6aXdoRTV4UmRaa3gweXdJREFRQUJNQTBHQ1NxR1NJYjNEUUVCQ3dVQQpBNElCQVFCaHRWUEI0OCs4eFZyVmRxM1BIY3k5QkxtVEtrRFl6N2Q0ODJzTG1HczBuVUdGSTFZUDdmaFJPV3ZxCktCTlpkNEI5MUpwU1NoRGUrMHpoNno4WG5Ha01mYnRSYWx0NHEwZ3lKdk9hUWhqQ3ZCcSswTFk5d2NLbXpFdnMKcTRiNUZ5NXNpRUZSekJLTmZtTGwxTTF2cW1hNmFCVnNYUUhPREdzYS83dE5MalZ2ay9PYm52cFg3UFhLa0E3cQpLMTQvV0tBRFBJWm9mb00xMzB4Q1RTYXVpeXROajlnWkx1WU9leEZhblVwNCt2MHBYWS81OFFSNTk2U0ROVTlKClJaeDhwTzBTaUYvZXkxVUZXbmpzdHBjbTQzTFVQKzFwU1hFeVhZOFJrRTI2QzNvdjNaTFNKc2pMbC90aXVqUlgKZUJPOWorWDdzS0R4amdtajBPbWdpVkpIM0YrUAotLS0tLUVORCBDRVJUSUZJQ0FURS0tLS0tCg=="}, false, "", "localhost.example:443"},
{"bad ca encoding", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, OverrideCertificateName: "*.local", SharedSecret: "shh", CA: "^"}, true, "", "localhost.example:443"},
{"custom ca file", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, OverrideCertificateName: "*.local", SharedSecret: "shh", CAFile: "testdata/example.crt"}, false, "", "localhost.example:443"},
{"bad custom ca file", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, OverrideCertificateName: "*.local", SharedSecret: "shh", CAFile: "testdata/example.crt2"}, true, "", "localhost.example:443"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewGRPCAuthorizeClient(tt.opts)
if (err != nil) != tt.wantErr {
t.Errorf("NewGRPCAuthorizeClient() error = %v, wantErr %v", err, tt.wantErr)
if !strings.EqualFold(err.Error(), tt.wantErrStr) {
t.Errorf("NewGRPCAuthorizeClient() error = %v did not contain wantErr %v", err, tt.wantErrStr)
}
}
if got != nil && got.Conn.Target() != tt.wantTarget {
t.Errorf("NewGRPCAuthorizeClient() target = %v expected %v", got.Conn.Target(), tt.wantTarget)
}
})
}
}

View file

@ -15,6 +15,7 @@ import (
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/middleware" "github.com/pomerium/pomerium/internal/middleware"
"github.com/pomerium/pomerium/internal/telemetry/metrics" "github.com/pomerium/pomerium/internal/telemetry/metrics"
"go.opencensus.io/plugin/ocgrpc" "go.opencensus.io/plugin/ocgrpc"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/balancer/roundrobin" "google.golang.org/grpc/balancer/roundrobin"
@ -25,7 +26,7 @@ const defaultGRPCPort = 443
// Options contains options for connecting to a pomerium rpc service. // Options contains options for connecting to a pomerium rpc service.
type Options struct { type Options struct {
// Addr is the location of the authenticate service. e.g. "service.corp.example:8443" // Addr is the location of the service. e.g. "service.corp.example:8443"
Addr *url.URL Addr *url.URL
// InternalAddr is the internal (behind the ingress) address to use when // InternalAddr is the internal (behind the ingress) address to use when
// making a connection. If empty, Addr is used. // making a connection. If empty, Addr is used.
@ -34,7 +35,7 @@ type Options struct {
// returned certificates from the server. gRPC internals also use it to override the virtual // returned certificates from the server. gRPC internals also use it to override the virtual
// hosting name if it is set. // hosting name if it is set.
OverrideCertificateName string OverrideCertificateName string
// Shared secret is used to authenticate a authenticate-client with a authenticate-server. // Shared secret is used to mutually authenticate a client and server.
SharedSecret string SharedSecret string
// CA specifies the base64 encoded TLS certificate authority to use. // CA specifies the base64 encoded TLS certificate authority to use.
CA string CA string

View file

@ -6,35 +6,6 @@ import (
"github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/sessions"
) )
// MockAuthenticate provides a mocked implementation of the authenticator interface.
type MockAuthenticate struct {
RedeemError error
RedeemResponse *sessions.SessionState
RefreshResponse *sessions.SessionState
RefreshError error
ValidateResponse bool
ValidateError error
CloseError error
}
// Redeem is a mocked authenticator client function.
func (a MockAuthenticate) Redeem(ctx context.Context, code string) (*sessions.SessionState, error) {
return a.RedeemResponse, a.RedeemError
}
// Refresh is a mocked authenticator client function.
func (a MockAuthenticate) Refresh(ctx context.Context, s *sessions.SessionState) (*sessions.SessionState, error) {
return a.RefreshResponse, a.RefreshError
}
// Validate is a mocked authenticator client function.
func (a MockAuthenticate) Validate(ctx context.Context, idToken string) (bool, error) {
return a.ValidateResponse, a.ValidateError
}
// Close is a mocked authenticator client function.
func (a MockAuthenticate) Close() error { return a.CloseError }
// MockAuthorize provides a mocked implementation of the authorizer interface. // MockAuthorize provides a mocked implementation of the authorizer interface.
type MockAuthorize struct { type MockAuthorize struct {
AuthorizeResponse bool AuthorizeResponse bool
@ -48,11 +19,11 @@ type MockAuthorize struct {
func (a MockAuthorize) Close() error { return a.CloseError } func (a MockAuthorize) Close() error { return a.CloseError }
// Authorize is a mocked authorizer client function. // Authorize is a mocked authorizer client function.
func (a MockAuthorize) Authorize(ctx context.Context, route string, s *sessions.SessionState) (bool, error) { func (a MockAuthorize) Authorize(ctx context.Context, route string, s *sessions.State) (bool, error) {
return a.AuthorizeResponse, a.AuthorizeError return a.AuthorizeResponse, a.AuthorizeError
} }
// IsAdmin is a mocked IsAdmin function. // IsAdmin is a mocked IsAdmin function.
func (a MockAuthorize) IsAdmin(ctx context.Context, s *sessions.SessionState) (bool, error) { func (a MockAuthorize) IsAdmin(ctx context.Context, s *sessions.State) (bool, error) {
return a.IsAdminResponse, a.IsAdminError return a.IsAdminResponse, a.IsAdminError
} }

View file

@ -1,57 +0,0 @@
package clients
import (
"context"
"errors"
"reflect"
"testing"
"github.com/pomerium/pomerium/internal/sessions"
)
func TestMockAuthenticate(t *testing.T) {
// Absurd, but I caught a typo this way.
redeemResponse := &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
}
ma := &MockAuthenticate{
RedeemError: errors.New("redeem error"),
RedeemResponse: redeemResponse,
RefreshResponse: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
},
RefreshError: errors.New("refresh error"),
ValidateResponse: true,
ValidateError: errors.New("validate error"),
CloseError: errors.New("close error"),
}
got, gotErr := ma.Redeem(context.Background(), "a")
if gotErr.Error() != "redeem error" {
t.Errorf("unexpected value for gotErr %s", gotErr)
}
if !reflect.DeepEqual(redeemResponse, got) {
t.Errorf("unexpected value for redeemResponse %s", got)
}
newSession, gotErr := ma.Refresh(context.Background(), nil)
if gotErr.Error() != "refresh error" {
t.Errorf("unexpected value for gotErr %s", gotErr)
}
if !reflect.DeepEqual(newSession, redeemResponse) {
t.Errorf("unexpected value for newSession %s", newSession)
}
ok, gotErr := ma.Validate(context.Background(), "a")
if !ok {
t.Errorf("unexpected value for ok : %t", ok)
}
if gotErr.Error() != "validate error" {
t.Errorf("unexpected value for gotErr %s", gotErr)
}
gotErr = ma.Close()
if gotErr.Error() != "close error" {
t.Errorf("unexpected value for ma.CloseError %s", gotErr)
}
}

View file

@ -15,6 +15,7 @@ import (
"github.com/pomerium/pomerium/internal/middleware" "github.com/pomerium/pomerium/internal/middleware"
"github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/templates" "github.com/pomerium/pomerium/internal/templates"
"github.com/pomerium/pomerium/internal/urlutil"
) )
// StateParameter holds the redirect id along with the session id. // StateParameter holds the redirect id along with the session id.
@ -36,9 +37,9 @@ func (p *Proxy) Handler() http.Handler {
mux.HandleFunc("/.pomerium", p.UserDashboard) mux.HandleFunc("/.pomerium", p.UserDashboard)
mux.HandleFunc("/.pomerium/impersonate", p.Impersonate) // POST mux.HandleFunc("/.pomerium/impersonate", p.Impersonate) // POST
mux.HandleFunc("/.pomerium/sign_out", p.SignOut) mux.HandleFunc("/.pomerium/sign_out", p.SignOut)
// handlers handlers with validation // handlers with validation
mux.Handle("/.pomerium/callback", validate.ThenFunc(p.OAuthCallback)) mux.Handle("/.pomerium/callback", validate.ThenFunc(p.AuthenticateCallback))
mux.Handle("/.pomerium/refresh", validate.ThenFunc(p.Refresh)) mux.Handle("/.pomerium/refresh", validate.ThenFunc(p.ForceRefresh))
mux.Handle("/", validate.ThenFunc(p.Proxy)) mux.Handle("/", validate.ThenFunc(p.Proxy))
return mux return mux
} }
@ -60,12 +61,12 @@ func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) {
httputil.ErrorResponse(w, r, err) httputil.ErrorResponse(w, r, err)
return return
} }
uri, err := url.Parse(r.Form.Get("redirect_uri")) uri, err := urlutil.ParseAndValidateURL(r.Form.Get("redirect_uri"))
if err == nil && uri.String() != "" { if err == nil && uri.String() != "" {
redirectURL = uri redirectURL = uri
} }
default: default:
uri, err := url.Parse(r.URL.Query().Get("redirect_uri")) uri, err := urlutil.ParseAndValidateURL(r.URL.Query().Get("redirect_uri"))
if err == nil && uri.String() != "" { if err == nil && uri.String() != "" {
redirectURL = uri redirectURL = uri
} }
@ -76,24 +77,20 @@ func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) {
// OAuthStart begins the authenticate flow, encrypting the redirect url // OAuthStart begins the authenticate flow, encrypting the redirect url
// in a request to the provider's sign in endpoint. // in a request to the provider's sign in endpoint.
func (p *Proxy) OAuthStart(w http.ResponseWriter, r *http.Request) { func (p *Proxy) OAuthStart(w http.ResponseWriter, r *http.Request) {
// create a CSRF value used to mitigate replay attacks.
state := &StateParameter{ state := &StateParameter{
SessionID: fmt.Sprintf("%x", cryptutil.GenerateKey()), SessionID: fmt.Sprintf("%x", cryptutil.GenerateKey()),
RedirectURI: r.URL.String(), RedirectURI: r.URL.String(),
} }
// Encrypt, and save CSRF state. Will be checked on callback. // Encrypt CSRF + redirect_uri and store in csrf session. Validated on callback.
localState, err := p.cipher.Marshal(state) csrfState, err := p.cipher.Marshal(state)
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, err) httputil.ErrorResponse(w, r, err)
return return
} }
p.csrfStore.SetCSRF(w, r, localState) p.csrfStore.SetCSRF(w, r, csrfState)
// Though the plaintext payload is identical, we re-encrypt which will paramState, err := p.cipher.Marshal(state)
// create a different cipher text using another nonce
remoteState, err := p.cipher.Marshal(state)
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, err) httputil.ErrorResponse(w, r, err)
return return
@ -101,68 +98,55 @@ func (p *Proxy) OAuthStart(w http.ResponseWriter, r *http.Request) {
// Sanity check. The encrypted payload of local and remote state should // Sanity check. The encrypted payload of local and remote state should
// never match as each encryption round uses a cryptographic nonce. // never match as each encryption round uses a cryptographic nonce.
// // if paramState == csrfState {
// todo(bdd): since this should nearly (1/(2^32*2^32)) never happen should // httputil.ErrorResponse(w, r, httputil.Error("encrypted state should not match", http.StatusBadRequest, nil))
// we panic as a failure most likely means the rands entropy source is failing? // return
if remoteState == localState { // }
p.sessionStore.ClearSession(w, r)
httputil.ErrorResponse(w, r, httputil.Error("encrypted state should not match", http.StatusBadRequest, nil))
return
}
signinURL := p.GetSignInURL(p.authenticateURL, p.GetRedirectURL(r.Host), remoteState) signinURL := p.GetSignInURL(p.authenticateURL, p.GetRedirectURL(r.Host), paramState)
log.FromRequest(r).Debug().Str("SigninURL", signinURL.String()).Msg("proxy: oauth start")
// Redirect the user to the authenticate service along with the encrypted // Redirect the user to the authenticate service along with the encrypted
// state which contains a redirect uri back to the proxy and a nonce // state which contains a redirect uri back to the proxy and a nonce
http.Redirect(w, r, signinURL.String(), http.StatusFound) http.Redirect(w, r, signinURL.String(), http.StatusFound)
} }
// OAuthCallback validates the cookie sent back from the authenticate service. This function will // AuthenticateCallback checks the state parameter to make sure it matches the
// contain an error, or it will contain a `code`; the code can be used to fetch an access token, and // local csrf state then redirects the user back to the original intended route.
// other metadata, from the authenticator. func (p *Proxy) AuthenticateCallback(w http.ResponseWriter, r *http.Request) {
// finish the oauth cycle
func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil { if err := r.ParseForm(); err != nil {
httputil.ErrorResponse(w, r, err) httputil.ErrorResponse(w, r, err)
return return
} }
if callbackError := r.Form.Get("error"); callbackError != "" {
httputil.ErrorResponse(w, r, httputil.Error(callbackError, http.StatusBadRequest, nil))
return
}
// Encrypted CSRF passed from authenticate service // Encrypted CSRF passed from authenticate service
remoteStateEncrypted := r.Form.Get("state") remoteStateEncrypted := r.Form.Get("state")
remoteStatePlain := new(StateParameter) var remoteStatePlain StateParameter
if err := p.cipher.Unmarshal(remoteStateEncrypted, remoteStatePlain); err != nil { if err := p.cipher.Unmarshal(remoteStateEncrypted, &remoteStatePlain); err != nil {
httputil.ErrorResponse(w, r, err) httputil.ErrorResponse(w, r, err)
return return
} }
// Encrypted CSRF from session storage
c, err := p.csrfStore.GetCSRF(r) c, err := p.csrfStore.GetCSRF(r)
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, err) httputil.ErrorResponse(w, r, err)
return return
} }
p.csrfStore.ClearCSRF(w, r) p.csrfStore.ClearCSRF(w, r)
localStateEncrypted := c.Value localStateEncrypted := c.Value
localStatePlain := new(StateParameter) var localStatePlain StateParameter
err = p.cipher.Unmarshal(localStateEncrypted, localStatePlain) err = p.cipher.Unmarshal(localStateEncrypted, &localStatePlain)
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, err) httputil.ErrorResponse(w, r, err)
return return
} }
// If the encrypted value of local and remote state match, reject. // assert no nonce reuse
// Likely a replay attack or nonce-reuse.
if remoteStateEncrypted == localStateEncrypted { if remoteStateEncrypted == localStateEncrypted {
p.sessionStore.ClearSession(w, r) p.sessionStore.ClearSession(w, r)
httputil.ErrorResponse(w, r,
httputil.ErrorResponse(w, r, httputil.Error("local and remote state should not match!", http.StatusBadRequest, nil)) httputil.Error("local and remote state", http.StatusBadRequest,
fmt.Errorf("possible nonce-reuse / replay attack")))
return return
} }
@ -205,13 +189,23 @@ func isCORSPreflight(r *http.Request) bool {
r.Header.Get("Origin") != "" r.Header.Get("Origin") != ""
} }
func (p *Proxy) loadExistingSession(r *http.Request) (*sessions.State, error) {
s, err := p.sessionStore.LoadSession(r)
if err != nil {
return nil, fmt.Errorf("proxy: invalid session: %w", err)
}
if err := s.Valid(); err != nil {
return nil, fmt.Errorf("proxy: invalid state: %w", err)
}
return s, nil
}
// Proxy authenticates a request, either proxying the request if it is authenticated, // Proxy authenticates a request, either proxying the request if it is authenticated,
// or starting the authenticate service for validation if not. // or starting the authenticate service for validation if not.
func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) { func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
// does a route exist for this request?
route, ok := p.router(r) route, ok := p.router(r)
if !ok { if !ok {
httputil.ErrorResponse(w, r, httputil.Error(fmt.Sprintf("%s is not a managed route.", r.Host), http.StatusNotFound, nil)) httputil.ErrorResponse(w, r, httputil.Error("", http.StatusNotFound, nil))
return return
} }
@ -221,30 +215,17 @@ func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
return return
} }
s, err := p.restStore.LoadSession(r) s, err := p.loadExistingSession(r)
// if authorization bearer token does not exist or fails, use cookie store if err != nil {
if err != nil || s == nil { log.Debug().Str("cause", err.Error()).Msg("proxy: bad authN session, redirecting")
s, err = p.sessionStore.LoadSession(r) p.OAuthStart(w, r)
if err != nil {
log.FromRequest(r).Debug().Str("cause", err.Error()).Msg("proxy: invalid session, re-authenticating")
p.sessionStore.ClearSession(w, r)
p.OAuthStart(w, r)
return
}
}
if err = p.authenticate(w, r, s); err != nil {
p.sessionStore.ClearSession(w, r)
httputil.ErrorResponse(w, r, httputil.Error("User unauthenticated", http.StatusUnauthorized, err))
return return
} }
authorized, err := p.AuthorizeClient.Authorize(r.Context(), r.Host, s) authorized, err := p.AuthorizeClient.Authorize(r.Context(), r.Host, s)
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, err) httputil.ErrorResponse(w, r, err)
return return
} } else if !authorized {
if !authorized {
httputil.ErrorResponse(w, r, httputil.Error(fmt.Sprintf("%s is not authorized for this route", s.Email), http.StatusForbidden, nil)) httputil.ErrorResponse(w, r, httputil.Error(fmt.Sprintf("%s is not authorized for this route", s.Email), http.StatusForbidden, nil))
return return
} }
@ -259,20 +240,13 @@ func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
// It also contains certain administrative actions like user impersonation. // It also contains certain administrative actions like user impersonation.
// Nota bene: This endpoint does authentication, not authorization. // Nota bene: This endpoint does authentication, not authorization.
func (p *Proxy) UserDashboard(w http.ResponseWriter, r *http.Request) { func (p *Proxy) UserDashboard(w http.ResponseWriter, r *http.Request) {
session, err := p.sessionStore.LoadSession(r) session, err := p.loadExistingSession(r)
if err != nil { if err != nil {
log.FromRequest(r).Debug().Str("cause", err.Error()).Msg("proxy: no session, redirecting to auth") log.Debug().Str("cause", err.Error()).Msg("proxy: bad authN session, redirecting")
p.sessionStore.ClearSession(w, r)
p.OAuthStart(w, r) p.OAuthStart(w, r)
return return
} }
if err := p.authenticate(w, r, session); err != nil {
p.sessionStore.ClearSession(w, r)
httputil.ErrorResponse(w, r, httputil.Error("User unauthenticated", http.StatusUnauthorized, err))
return
}
redirectURL := &url.URL{Scheme: "https", Host: r.Host, Path: "/.pomerium/sign_out"} redirectURL := &url.URL{Scheme: "https", Host: r.Host, Path: "/.pomerium/sign_out"}
isAdmin, err := p.AuthorizeClient.IsAdmin(r.Context(), session) isAdmin, err := p.AuthorizeClient.IsAdmin(r.Context(), session)
if err != nil { if err != nil {
@ -314,13 +288,14 @@ func (p *Proxy) UserDashboard(w http.ResponseWriter, r *http.Request) {
templates.New().ExecuteTemplate(w, "dashboard.html", t) templates.New().ExecuteTemplate(w, "dashboard.html", t)
} }
// Refresh redeems and extends an existing authenticated oidc session with // ForceRefresh redeems and extends an existing authenticated oidc session with
// the underlying identity provider. All session details including groups, // the underlying identity provider. All session details including groups,
// timeouts, will be renewed. // timeouts, will be renewed.
func (p *Proxy) Refresh(w http.ResponseWriter, r *http.Request) { func (p *Proxy) ForceRefresh(w http.ResponseWriter, r *http.Request) {
session, err := p.sessionStore.LoadSession(r) session, err := p.loadExistingSession(r)
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, err) log.Debug().Str("cause", err.Error()).Msg("proxy: bad authN session, redirecting")
p.OAuthStart(w, r)
return return
} }
iss, err := session.IssuedAt() iss, err := session.IssuedAt()
@ -332,16 +307,13 @@ func (p *Proxy) Refresh(w http.ResponseWriter, r *http.Request) {
// reject a refresh if it's been less than the refresh cooldown to prevent abuse // reject a refresh if it's been less than the refresh cooldown to prevent abuse
if time.Since(iss) < p.refreshCooldown { if time.Since(iss) < p.refreshCooldown {
httputil.ErrorResponse(w, r, httputil.ErrorResponse(w, r,
httputil.Error(fmt.Sprintf("Session must be %s old before refreshing", p.refreshCooldown), http.StatusBadRequest, nil)) httputil.Error(
fmt.Sprintf("Session must be %s old before refreshing", p.refreshCooldown),
http.StatusBadRequest, nil))
return return
} }
session.ForceRefresh()
newSession, err := p.AuthenticateClient.Refresh(r.Context(), session) if err = p.sessionStore.SaveSession(w, r, session); err != nil {
if err != nil {
httputil.ErrorResponse(w, r, err)
return
}
if err = p.sessionStore.SaveSession(w, r, newSession); err != nil {
httputil.ErrorResponse(w, r, err) httputil.ErrorResponse(w, r, err)
return return
} }
@ -357,12 +329,12 @@ func (p *Proxy) Impersonate(w http.ResponseWriter, r *http.Request) {
httputil.ErrorResponse(w, r, err) httputil.ErrorResponse(w, r, err)
return return
} }
session, err := p.sessionStore.LoadSession(r) session, err := p.loadExistingSession(r)
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, err) log.Debug().Str("cause", err.Error()).Msg("proxy: bad authN session, redirecting")
p.OAuthStart(w, r)
return return
} }
// authorization check -- is this user an admin?
isAdmin, err := p.AuthorizeClient.IsAdmin(r.Context(), session) isAdmin, err := p.AuthorizeClient.IsAdmin(r.Context(), session)
if err != nil || !isAdmin { if err != nil || !isAdmin {
httputil.ErrorResponse(w, r, httputil.Error(fmt.Sprintf("%s is not an administrator", session.Email), http.StatusForbidden, err)) httputil.ErrorResponse(w, r, httputil.Error(fmt.Sprintf("%s is not an administrator", session.Email), http.StatusForbidden, err))
@ -376,7 +348,7 @@ func (p *Proxy) Impersonate(w http.ResponseWriter, r *http.Request) {
} }
p.csrfStore.ClearCSRF(w, r) p.csrfStore.ClearCSRF(w, r)
encryptedCSRF := c.Value encryptedCSRF := c.Value
decryptedCSRF := new(StateParameter) var decryptedCSRF StateParameter
if err = p.cipher.Unmarshal(encryptedCSRF, decryptedCSRF); err != nil { if err = p.cipher.Unmarshal(encryptedCSRF, decryptedCSRF); err != nil {
httputil.ErrorResponse(w, r, err) httputil.ErrorResponse(w, r, err)
return return
@ -398,26 +370,6 @@ func (p *Proxy) Impersonate(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/.pomerium", http.StatusFound) http.Redirect(w, r, "/.pomerium", http.StatusFound)
} }
// Authenticate authenticates a request by checking for a session cookie, and validating its expiration,
// clearing the session cookie if it's invalid and returning an error if necessary..
func (p *Proxy) authenticate(w http.ResponseWriter, r *http.Request, s *sessions.SessionState) error {
if s.RefreshPeriodExpired() {
s, err := p.AuthenticateClient.Refresh(r.Context(), s)
if err != nil {
return fmt.Errorf("proxy: session refresh failed : %v", err)
}
if err := p.sessionStore.SaveSession(w, r, s); err != nil {
return fmt.Errorf("proxy: refresh failed : %v", err)
}
} else {
valid, err := p.AuthenticateClient.Validate(r.Context(), s.IDToken)
if err != nil || !valid {
return fmt.Errorf("proxy: session validate failed: %v : %v", valid, err)
}
}
return nil
}
// router attempts to find a route for a request. If a route is successfully matched, // router attempts to find a route for a request. If a route is successfully matched,
// it returns the route information and a bool value of `true`. If a route can // it returns the route information and a bool value of `true`. If a route can
// not be matched, a nil value for the route and false bool value is returned. // not be matched, a nil value for the route and false bool value is returned.
@ -461,7 +413,7 @@ func (p *Proxy) GetSignInURL(authenticateURL, redirectURL *url.URL, state string
a := authenticateURL.ResolveReference(&url.URL{Path: "/sign_in"}) a := authenticateURL.ResolveReference(&url.URL{Path: "/sign_in"})
now := time.Now() now := time.Now()
rawRedirect := redirectURL.String() rawRedirect := redirectURL.String()
params, _ := url.ParseQuery(a.RawQuery) params, _ := url.ParseQuery(a.RawQuery) // handled by ServeMux
params.Set("redirect_uri", rawRedirect) params.Set("redirect_uri", rawRedirect)
params.Set("shared_secret", p.SharedKey) params.Set("shared_secret", p.SharedKey)
params.Set("response_type", "code") params.Set("response_type", "code")
@ -477,7 +429,7 @@ func (p *Proxy) GetSignOutURL(authenticateURL, redirectURL *url.URL) *url.URL {
a := authenticateURL.ResolveReference(&url.URL{Path: "/sign_out"}) a := authenticateURL.ResolveReference(&url.URL{Path: "/sign_out"})
now := time.Now() now := time.Now()
rawRedirect := redirectURL.String() rawRedirect := redirectURL.String()
params, _ := url.ParseQuery(a.RawQuery) params, _ := url.ParseQuery(a.RawQuery) // handled by ServeMux
params.Add("redirect_uri", rawRedirect) params.Add("redirect_uri", rawRedirect)
params.Set("ts", fmt.Sprint(now.Unix())) params.Set("ts", fmt.Sprint(now.Unix()))
params.Set("sig", p.signRedirectURL(rawRedirect, now)) params.Set("sig", p.signRedirectURL(rawRedirect, now))

View file

@ -72,7 +72,6 @@ func TestProxy_GetRedirectURL(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
p := &Proxy{redirectURL: &url.URL{Path: "/.pomerium/callback"}} p := &Proxy{redirectURL: &url.URL{Path: "/.pomerium/callback"}}
if got := p.GetRedirectURL(tt.host); !reflect.DeepEqual(got, tt.want) { if got := p.GetRedirectURL(tt.host); !reflect.DeepEqual(got, tt.want) {
t.Errorf("Proxy.GetRedirectURL() = %v, want %v", got, tt.want) t.Errorf("Proxy.GetRedirectURL() = %v, want %v", got, tt.want)
} }
@ -240,8 +239,7 @@ func TestProxy_router(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
p.AuthenticateClient = clients.MockAuthenticate{} p.cipher = &cryptutil.MockCipher{MarshalResponse: "foo"}
p.cipher = mockCipher{}
req := httptest.NewRequest(http.MethodGet, tt.host, nil) req := httptest.NewRequest(http.MethodGet, tt.host, nil)
_, ok := p.router(req) _, ok := p.router(req)
@ -253,7 +251,7 @@ func TestProxy_router(t *testing.T) {
} }
func TestProxy_Proxy(t *testing.T) { func TestProxy_Proxy(t *testing.T) {
goodSession := &sessions.SessionState{ goodSession := &sessions.State{
AccessToken: "AccessToken", AccessToken: "AccessToken",
RefreshToken: "RefreshToken", RefreshToken: "RefreshToken",
RefreshDeadline: time.Now().Add(10 * time.Second), RefreshDeadline: time.Now().Add(10 * time.Second),
@ -278,39 +276,34 @@ func TestProxy_Proxy(t *testing.T) {
headersWs.Set("Upgrade", "websocket") headersWs.Set("Upgrade", "websocket")
tests := []struct { tests := []struct {
name string name string
options config.Options options config.Options
method string method string
header http.Header header http.Header
host string host string
session sessions.SessionStore session sessions.SessionStore
authenticator clients.Authenticator authorizer clients.Authorizer
authorizer clients.Authorizer wantStatus int
wantStatus int
}{ }{
{"good", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK}, {"good", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
{"good cors preflight", optsCORS, http.MethodOptions, goodCORSHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusOK}, {"good cors preflight", optsCORS, http.MethodOptions, goodCORSHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusOK},
{"good email impersonation", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second), ImpersonateEmail: "test@user.example"}}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK}, {"good email impersonation", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second), ImpersonateEmail: "test@user.example"}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
{"good group impersonation", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second), ImpersonateGroups: []string{"group1", "group2"}}}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK}, {"good group impersonation", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second), ImpersonateGroups: []string{"group1", "group2"}}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
// same request as above, but with cors_allow_preflight=false in the policy // same request as above, but with cors_allow_preflight=false in the policy
{"valid cors, but not allowed", opts, http.MethodOptions, goodCORSHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden}, {"valid cors, but not allowed", opts, http.MethodOptions, goodCORSHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
// cors allowed, but the request is missing proper headers // cors allowed, but the request is missing proper headers
{"invalid cors headers", optsCORS, http.MethodOptions, badCORSHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden}, {"invalid cors headers", optsCORS, http.MethodOptions, badCORSHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
{"unexpected error", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{LoadError: errors.New("ok")}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest},
// redirect to start auth process // redirect to start auth process
{"unknown host", opts, http.MethodGet, defaultHeaders, "https://nothttpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound}, {"unknown host", opts, http.MethodGet, defaultHeaders, "https://nothttpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound},
{"user not authorized", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden}, {"user not authorized", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
{"authorization call failed", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeError: errors.New("error")}, http.StatusInternalServerError}, {"authorization call failed", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeError: errors.New("error")}, http.StatusInternalServerError},
// authenticate errors // authenticate errors
{"weird load session error", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{LoadError: errors.New("weird"), Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest}, {"session error, redirect to authn", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{LoadError: errors.New("weird"), Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound},
{"failed refreshed session", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: &sessions.SessionState{RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthenticate{RefreshError: errors.New("refresh error")}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized}, {"session expired,redirect to authn", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{LoadError: sessions.ErrExpired}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound},
{"cannot resave refreshed session", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{SaveError: errors.New("weird"), Session: &sessions.SessionState{RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized}, {"public access", optsPublic, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusOK},
{"authenticate validation error", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: false}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized}, {"public access, but unknown host", optsPublic, http.MethodGet, defaultHeaders, "https://nothttpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusNotFound},
{"public access", optsPublic, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusOK}, {"no http found (no session),redirect to authn", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{LoadError: http.ErrNoCookie}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound},
{"public access, but unknown host", optsPublic, http.MethodGet, defaultHeaders, "https://nothttpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusNotFound}, {"No policies", optsNoPolicies, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound},
// no session, redirect to login
{"no http found (no session)", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{LoadError: http.ErrNoCookie}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest},
{"No policies", optsNoPolicies, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound},
} }
for _, tt := range tests { for _, tt := range tests {
@ -323,13 +316,13 @@ func TestProxy_Proxy(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
p.cipher = mockCipher{} p.cipher = &cryptutil.MockCipher{MarshalResponse: "foo"}
p.sessionStore = tt.session p.sessionStore = tt.session
p.AuthenticateClient = tt.authenticator
p.AuthorizeClient = tt.authorizer p.AuthorizeClient = tt.authorizer
r := httptest.NewRequest(tt.method, tt.host, nil) r := httptest.NewRequest(tt.method, tt.host, nil)
r.Header = tt.header r.Header = tt.header
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
p.Proxy(w, r) p.Proxy(w, r)
if status := w.Code; status != tt.wantStatus { if status := w.Code; status != tt.wantStatus {
@ -348,23 +341,21 @@ func TestProxy_Proxy(t *testing.T) {
func TestProxy_UserDashboard(t *testing.T) { func TestProxy_UserDashboard(t *testing.T) {
opts := testOptions(t) opts := testOptions(t)
tests := []struct { tests := []struct {
name string name string
options config.Options options config.Options
method string method string
cipher cryptutil.Cipher cipher cryptutil.Cipher
session sessions.SessionStore session sessions.SessionStore
authenticator clients.Authenticator authorizer clients.Authorizer
authorizer clients.Authorizer
wantAdminForm bool wantAdminForm bool
wantStatus int wantStatus int
}{ }{
{"good", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example"}}, clients.MockAuthenticate{}, clients.MockAuthorize{}, false, http.StatusOK}, {"good", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{}, false, http.StatusOK},
{"cannot load session", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{LoadError: errors.New("load error")}, clients.MockAuthenticate{}, clients.MockAuthorize{}, false, http.StatusBadRequest}, {"cannot load session", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{LoadError: errors.New("load error")}, clients.MockAuthorize{}, false, http.StatusFound},
{"auth failure, validation error", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthenticate{ValidateError: errors.New("not valid anymore"), ValidateResponse: false}, clients.MockAuthorize{}, false, http.StatusUnauthorized}, {"can't save csrf", opts, http.MethodGet, &cryptutil.MockCipher{MarshalError: errors.New("err")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{}, false, http.StatusInternalServerError},
{"can't save csrf", opts, http.MethodGet, &cryptutil.MockCipher{MarshalError: errors.New("err")}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example"}}, clients.MockAuthenticate{}, clients.MockAuthorize{}, false, http.StatusInternalServerError}, {"want admin form good admin authorization", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{IsAdminResponse: true}, true, http.StatusOK},
{"want admin form good admin authorization", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, true, http.StatusOK}, {"is admin but authorization fails", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{IsAdminError: errors.New("err")}, false, http.StatusInternalServerError},
{"is admin but authorization fails", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminError: errors.New("err")}, false, http.StatusInternalServerError},
} }
for _, tt := range tests { for _, tt := range tests {
@ -375,15 +366,18 @@ func TestProxy_UserDashboard(t *testing.T) {
} }
p.cipher = tt.cipher p.cipher = tt.cipher
p.sessionStore = tt.session p.sessionStore = tt.session
p.AuthenticateClient = tt.authenticator
p.AuthorizeClient = tt.authorizer p.AuthorizeClient = tt.authorizer
r := httptest.NewRequest(tt.method, "/", nil) r := httptest.NewRequest(tt.method, "/", nil)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
p.UserDashboard(w, r) p.UserDashboard(w, r)
if status := w.Code; status != tt.wantStatus { if status := w.Code; status != tt.wantStatus {
t.Errorf("status code: got %v want %v", status, tt.wantStatus) t.Errorf("status code: got %v want %v", status, tt.wantStatus)
t.Errorf("\n%+v", opts) t.Errorf("\n%+v", opts)
t.Errorf("\n%+v", w.Body.String())
} }
if adminForm := strings.Contains(w.Body.String(), "impersonate"); adminForm != tt.wantAdminForm { if adminForm := strings.Contains(w.Body.String(), "impersonate"); adminForm != tt.wantAdminForm {
t.Errorf("wanted admin form got %v want %v", adminForm, tt.wantAdminForm) t.Errorf("wanted admin form got %v want %v", adminForm, tt.wantAdminForm)
@ -393,28 +387,27 @@ func TestProxy_UserDashboard(t *testing.T) {
} }
} }
func TestProxy_Refresh(t *testing.T) { func TestProxy_ForceRefresh(t *testing.T) {
opts := testOptions(t) opts := testOptions(t)
opts.RefreshCooldown = 0 opts.RefreshCooldown = 0
timeSinceError := testOptions(t) timeSinceError := testOptions(t)
timeSinceError.RefreshCooldown = time.Duration(int(^uint(0) >> 1)) timeSinceError.RefreshCooldown = time.Duration(int(^uint(0) >> 1))
tests := []struct { tests := []struct {
name string name string
options config.Options options config.Options
method string method string
cipher cryptutil.Cipher cipher cryptutil.Cipher
session sessions.SessionStore session sessions.SessionStore
authenticator clients.Authenticator authorizer clients.Authorizer
authorizer clients.Authorizer wantStatus int
wantStatus int
}{ }{
{"good", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthenticate{}, clients.MockAuthorize{}, http.StatusFound}, {"good", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusFound},
{"cannot load session", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{LoadError: errors.New("load error")}, clients.MockAuthenticate{}, clients.MockAuthorize{}, http.StatusInternalServerError}, {"cannot load session", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{LoadError: errors.New("load error")}, clients.MockAuthorize{}, http.StatusFound},
{"bad id token", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: "bad"}}, clients.MockAuthenticate{}, clients.MockAuthorize{}, http.StatusInternalServerError}, {"bad id token", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: "bad"}}, clients.MockAuthorize{}, http.StatusInternalServerError},
{"issue date too soon", timeSinceError, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthenticate{}, clients.MockAuthorize{}, http.StatusBadRequest}, {"issue date too soon", timeSinceError, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusBadRequest},
{"refresh failure", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthenticate{RefreshError: errors.New("err")}, clients.MockAuthorize{}, http.StatusInternalServerError}, {"refresh failure", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusFound},
{"can't save refreshed session", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.SessionState{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthenticate{}, clients.MockAuthorize{}, http.StatusInternalServerError}, {"can't save refreshed session", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusFound},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -424,12 +417,11 @@ func TestProxy_Refresh(t *testing.T) {
} }
p.cipher = tt.cipher p.cipher = tt.cipher
p.sessionStore = tt.session p.sessionStore = tt.session
p.AuthenticateClient = tt.authenticator
p.AuthorizeClient = tt.authorizer p.AuthorizeClient = tt.authorizer
r := httptest.NewRequest(tt.method, "/", nil) r := httptest.NewRequest(tt.method, "/", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
p.Refresh(w, r) p.ForceRefresh(w, r)
if status := w.Code; status != tt.wantStatus { if status := w.Code; status != tt.wantStatus {
t.Errorf("status code: got %v want %v", status, tt.wantStatus) t.Errorf("status code: got %v want %v", status, tt.wantStatus)
t.Errorf("\n%+v", opts) t.Errorf("\n%+v", opts)
@ -442,30 +434,29 @@ func TestProxy_Impersonate(t *testing.T) {
opts := testOptions(t) opts := testOptions(t)
tests := []struct { tests := []struct {
name string name string
malformed bool malformed bool
options config.Options options config.Options
method string method string
email string email string
groups string groups string
csrf string csrf string
cipher cryptutil.Cipher cipher cryptutil.Cipher
sessionStore sessions.SessionStore sessionStore sessions.SessionStore
csrfStore sessions.CSRFStore csrfStore sessions.CSRFStore
authenticator clients.Authenticator authorizer clients.Authorizer
authorizer clients.Authorizer wantStatus int
wantStatus int
}{ }{
{"good", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound}, {"good", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
{"session load error", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{LoadError: errors.New("err"), Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError}, {"session load error", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{LoadError: errors.New("err"), Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
{"non admin users rejected", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: false}, http.StatusForbidden}, // {"non admin users rejected", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: false}, http.StatusForbidden},
{"non admin users rejected on error", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true, IsAdminError: errors.New("err")}, http.StatusForbidden}, {"non admin users rejected on error", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true, IsAdminError: errors.New("err")}, http.StatusForbidden},
{"csrf from store retrieve failure", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}, GetError: errors.New("err")}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError}, {"csrf from store retrieve failure", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}, GetError: errors.New("err")}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
{"can't decrypt csrf value", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{UnmarshalError: errors.New("err")}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError}, {"can't decrypt csrf value", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{UnmarshalError: errors.New("err")}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
{"decrypted csrf mismatch", false, opts, http.MethodPost, "user@blah.com", "", "CSRF!", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusBadRequest}, {"decrypted csrf mismatch", false, opts, http.MethodPost, "user@blah.com", "", "CSRF!", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusBadRequest},
{"save session failure", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError}, {"save session failure", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
{"malformed", true, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError}, {"malformed", true, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
{"groups", false, opts, http.MethodPost, "user@blah.com", "group1,group2", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound}, {"groups", false, opts, http.MethodPost, "user@blah.com", "group1,group2", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -476,7 +467,6 @@ func TestProxy_Impersonate(t *testing.T) {
p.cipher = tt.cipher p.cipher = tt.cipher
p.sessionStore = tt.sessionStore p.sessionStore = tt.sessionStore
p.csrfStore = tt.csrfStore p.csrfStore = tt.csrfStore
p.AuthenticateClient = tt.authenticator
p.AuthorizeClient = tt.authorizer p.AuthorizeClient = tt.authorizer
postForm := url.Values{} postForm := url.Values{}
postForm.Add("email", tt.email) postForm.Add("email", tt.email)
@ -501,19 +491,17 @@ func TestProxy_Impersonate(t *testing.T) {
func TestProxy_OAuthCallback(t *testing.T) { func TestProxy_OAuthCallback(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
csrf sessions.MockCSRFStore csrf sessions.MockCSRFStore
session sessions.MockSessionStore session sessions.MockSessionStore
authenticator clients.MockAuthenticate params map[string]string
params map[string]string wantCode int
wantCode int
}{ }{
{"good", sessions.MockCSRFStore{ResponseCSRF: "ok", GetError: nil, Cookie: &http.Cookie{Name: "something_csrf", Value: "csrf_state"}}, sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthenticate{RedeemResponse: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken"}}, map[string]string{"code": "code", "state": "state"}, http.StatusFound}, {"good", sessions.MockCSRFStore{ResponseCSRF: "ok", GetError: nil, Cookie: &http.Cookie{Name: "something_csrf", Value: "csrf_state"}}, sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, map[string]string{"code": "code", "state": "state"}, http.StatusFound},
{"error", sessions.MockCSRFStore{ResponseCSRF: "ok", GetError: nil, Cookie: &http.Cookie{Name: "something_csrf", Value: "csrf_state"}}, sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthenticate{RedeemResponse: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken"}}, map[string]string{"error": "some error"}, http.StatusBadRequest}, {"state err", sessions.MockCSRFStore{ResponseCSRF: "ok", GetError: nil, Cookie: &http.Cookie{Name: "something_csrf", Value: "csrf_state"}}, sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, map[string]string{"code": "code", "state": "error"}, http.StatusInternalServerError},
{"state err", sessions.MockCSRFStore{ResponseCSRF: "ok", GetError: nil, Cookie: &http.Cookie{Name: "something_csrf", Value: "csrf_state"}}, sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthenticate{RedeemResponse: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken"}}, map[string]string{"code": "code", "state": "error"}, http.StatusInternalServerError}, {"csrf err", sessions.MockCSRFStore{GetError: errors.New("error")}, sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, map[string]string{"code": "code", "state": "state"}, http.StatusInternalServerError},
{"csrf err", sessions.MockCSRFStore{GetError: errors.New("error")}, sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthenticate{RedeemResponse: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken"}}, map[string]string{"code": "code", "state": "state"}, http.StatusInternalServerError}, {"unmarshal err", sessions.MockCSRFStore{Cookie: &http.Cookie{Name: "something_csrf", Value: "unmarshal error"}}, sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, map[string]string{"code": "code", "state": "state"}, http.StatusInternalServerError},
{"unmarshal err", sessions.MockCSRFStore{Cookie: &http.Cookie{Name: "something_csrf", Value: "unmarshal error"}}, sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthenticate{RedeemResponse: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken"}}, map[string]string{"code": "code", "state": "state"}, http.StatusInternalServerError}, {"malformed", sessions.MockCSRFStore{ResponseCSRF: "ok", GetError: nil, Cookie: &http.Cookie{Name: "something_csrf", Value: "csrf_state"}}, sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, map[string]string{"code": "code", "state": "state"}, http.StatusInternalServerError},
{"malformed", sessions.MockCSRFStore{ResponseCSRF: "ok", GetError: nil, Cookie: &http.Cookie{Name: "something_csrf", Value: "csrf_state"}}, sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthenticate{RedeemResponse: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken"}}, map[string]string{"code": "code", "state": "state"}, http.StatusInternalServerError},
} }
for _, tt := range tests { for _, tt := range tests {
@ -524,7 +512,6 @@ func TestProxy_OAuthCallback(t *testing.T) {
} }
proxy.sessionStore = &tt.session proxy.sessionStore = &tt.session
proxy.csrfStore = tt.csrf proxy.csrfStore = tt.csrf
proxy.AuthenticateClient = tt.authenticator
proxy.cipher = mockCipher{} proxy.cipher = mockCipher{}
// proxy.Csrf // proxy.Csrf
req := httptest.NewRequest(http.MethodPost, "/.pomerium/callback", nil) req := httptest.NewRequest(http.MethodPost, "/.pomerium/callback", nil)
@ -537,7 +524,7 @@ func TestProxy_OAuthCallback(t *testing.T) {
req.URL.RawQuery = "email=%zzzzz" req.URL.RawQuery = "email=%zzzzz"
} }
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxy.OAuthCallback(w, req) proxy.AuthenticateCallback(w, req)
if status := w.Code; status != tt.wantCode { if status := w.Code; status != tt.wantCode {
t.Errorf("handler returned wrong status code: got %v want %v", status, tt.wantCode) t.Errorf("handler returned wrong status code: got %v want %v", status, tt.wantCode)
} }

View file

@ -2,11 +2,9 @@ package proxy // import "github.com/pomerium/pomerium/proxy"
import ( import (
"crypto/tls" "crypto/tls"
"encoding/base64"
"fmt" "fmt"
"html/template" "html/template"
stdlog "log" stdlog "log"
"net"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"net/url" "net/url"
@ -39,51 +37,27 @@ const (
// ValidateOptions checks that proper configuration settings are set to create // ValidateOptions checks that proper configuration settings are set to create
// a proper Proxy instance // a proper Proxy instance
func ValidateOptions(o config.Options) error { func ValidateOptions(o config.Options) error {
decoded, err := base64.StdEncoding.DecodeString(o.SharedKey) if _, err := cryptutil.NewCipherFromBase64(o.SharedKey); err != nil {
if err != nil { return fmt.Errorf("proxy: invalid 'SHARED_SECRET': %v", err)
return fmt.Errorf("`SHARED_SECRET` setting is invalid base64: %v", err)
} }
if len(decoded) != 32 { if _, err := cryptutil.NewCipherFromBase64(o.CookieSecret); err != nil {
return fmt.Errorf("`SHARED_SECRET` want 32 but got %d bytes", len(decoded)) return fmt.Errorf("proxy: invalid 'COOKIE_SECRET': %v", err)
} }
if o.AuthenticateURL == nil { if o.AuthenticateURL == nil {
return fmt.Errorf("proxy: missing setting: authenticate-service-url") return fmt.Errorf("proxy: missing 'AUTHENTICATE_SERVICE_URL'")
} }
if _, err := urlutil.ParseAndValidateURL(o.AuthenticateURL.String()); err != nil { if _, err := urlutil.ParseAndValidateURL(o.AuthenticateURL.String()); err != nil {
return fmt.Errorf("proxy: error parsing authenticate url: %v", err) return fmt.Errorf("proxy: invalid 'AUTHENTICATE_SERVICE_URL': %v", err)
} }
if o.AuthorizeURL == nil { if o.AuthorizeURL == nil {
return fmt.Errorf("proxy: missing setting: authenticate-service-url") return fmt.Errorf("proxy: missing 'AUTHORIZE_SERVICE_URL'")
} }
if _, err := urlutil.ParseAndValidateURL(o.AuthorizeURL.String()); err != nil { if _, err := urlutil.ParseAndValidateURL(o.AuthorizeURL.String()); err != nil {
return fmt.Errorf("proxy: error parsing authorize url: %v", err) return fmt.Errorf("proxy: invalid 'AUTHORIZE_SERVICE_URL': %v", err)
}
if o.AuthenticateInternalAddr != nil {
if _, err := urlutil.ParseAndValidateURL(o.AuthenticateInternalAddr.String()); err != nil {
return fmt.Errorf("proxy: error parsing authorize url: %v", err)
}
}
if o.CookieSecret == "" {
return fmt.Errorf("proxy: missing setting: cookie-secret")
}
decodedCookieSecret, err := base64.StdEncoding.DecodeString(o.CookieSecret)
if err != nil {
return fmt.Errorf("proxy: cookie secret is invalid base64: %v", err)
}
if len(decodedCookieSecret) != 32 {
return fmt.Errorf("proxy: cookie secret expects 32 bytes but got %d", len(decodedCookieSecret))
} }
if len(o.SigningKey) != 0 { if len(o.SigningKey) != 0 {
decodedSigningKey, err := base64.StdEncoding.DecodeString(o.SigningKey) if _, err := cryptutil.NewES256Signer(o.SigningKey, "localhost"); err != nil {
if err != nil { return fmt.Errorf("proxy: invalid 'SIGNING_KEY': %v", err)
return fmt.Errorf("proxy: signing key is invalid base64: %v", err)
}
_, err = cryptutil.NewES256Signer(decodedSigningKey, "localhost")
if err != nil {
return fmt.Errorf("proxy: invalid signing key is : %v", err)
} }
} }
return nil return nil
@ -92,12 +66,11 @@ func ValidateOptions(o config.Options) error {
// Proxy stores all the information associated with proxying a request. // Proxy stores all the information associated with proxying a request.
type Proxy struct { type Proxy struct {
// SharedKey used to mutually authenticate service communication // SharedKey used to mutually authenticate service communication
SharedKey string SharedKey string
authenticateURL *url.URL authenticateURL *url.URL
authenticateInternalAddr *url.URL authorizeURL *url.URL
authorizeURL *url.URL
AuthenticateClient clients.Authenticator AuthorizeClient clients.Authorizer
AuthorizeClient clients.Authorizer
cipher cryptutil.Cipher cipher cryptutil.Cipher
cookieName string cookieName string
@ -105,7 +78,6 @@ type Proxy struct {
defaultUpstreamTimeout time.Duration defaultUpstreamTimeout time.Duration
redirectURL *url.URL redirectURL *url.URL
refreshCooldown time.Duration refreshCooldown time.Duration
restStore sessions.SessionStore
routeConfigs map[string]*routeConfig routeConfigs map[string]*routeConfig
sessionStore sessions.SessionStore sessionStore sessions.SessionStore
signingKey string signingKey string
@ -123,11 +95,9 @@ func New(opts config.Options) (*Proxy, error) {
if err := ValidateOptions(opts); err != nil { if err := ValidateOptions(opts); err != nil {
return nil, err return nil, err
} }
// error explicitly handled by validate cipher, err := cryptutil.NewCipherFromBase64(opts.CookieSecret)
decodedSecret, _ := base64.StdEncoding.DecodeString(opts.CookieSecret)
cipher, err := cryptutil.NewCipher(decodedSecret)
if err != nil { if err != nil {
return nil, fmt.Errorf("cookie-secret error: %s", err.Error()) return nil, err
} }
cookieStore, err := sessions.NewCookieStore( cookieStore, err := sessions.NewCookieStore(
@ -140,10 +110,6 @@ func New(opts config.Options) (*Proxy, error) {
CookieCipher: cipher, CookieCipher: cipher,
}) })
if err != nil {
return nil, err
}
restStore, err := sessions.NewRestStore(&sessions.RestStoreOptions{Cipher: cipher})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -158,7 +124,6 @@ func New(opts config.Options) (*Proxy, error) {
defaultUpstreamTimeout: opts.DefaultUpstreamTimeout, defaultUpstreamTimeout: opts.DefaultUpstreamTimeout,
redirectURL: &url.URL{Path: "/.pomerium/callback"}, redirectURL: &url.URL{Path: "/.pomerium/callback"},
refreshCooldown: opts.RefreshCooldown, refreshCooldown: opts.RefreshCooldown,
restStore: restStore,
sessionStore: cookieStore, sessionStore: cookieStore,
signingKey: opts.SigningKey, signingKey: opts.SigningKey,
templates: templates.New(), templates: templates.New(),
@ -166,7 +131,6 @@ func New(opts config.Options) (*Proxy, error) {
// DeepCopy urls to avoid accidental mutation, err checked in validate func // DeepCopy urls to avoid accidental mutation, err checked in validate func
p.authenticateURL, _ = urlutil.DeepCopy(opts.AuthenticateURL) p.authenticateURL, _ = urlutil.DeepCopy(opts.AuthenticateURL)
p.authorizeURL, _ = urlutil.DeepCopy(opts.AuthorizeURL) p.authorizeURL, _ = urlutil.DeepCopy(opts.AuthorizeURL)
p.authenticateInternalAddr, _ = urlutil.DeepCopy(opts.AuthenticateInternalAddr)
if err := p.UpdatePolicies(&opts); err != nil { if err := p.UpdatePolicies(&opts); err != nil {
return nil, err return nil, err
@ -174,20 +138,6 @@ func New(opts config.Options) (*Proxy, error) {
metrics.AddPolicyCountCallback("proxy", func() int64 { metrics.AddPolicyCountCallback("proxy", func() int64 {
return int64(len(p.routeConfigs)) return int64(len(p.routeConfigs))
}) })
p.AuthenticateClient, err = clients.NewAuthenticateClient("grpc",
&clients.Options{
Addr: p.authenticateURL,
InternalAddr: p.authenticateInternalAddr,
OverrideCertificateName: opts.OverrideCertificateName,
SharedSecret: opts.SharedKey,
CA: opts.CA,
CAFile: opts.CAFile,
RequestTimeout: opts.GRPCClientTimeout,
ClientDNSRoundRobin: opts.GRPCClientDNSRoundRobin,
})
if err != nil {
return nil, err
}
p.AuthorizeClient, err = clients.NewAuthorizeClient("grpc", p.AuthorizeClient, err = clients.NewAuthorizeClient("grpc",
&clients.Options{ &clients.Options{
Addr: p.authorizeURL, Addr: p.authorizeURL,
@ -213,19 +163,7 @@ func (p *Proxy) UpdatePolicies(opts *config.Options) error {
} }
proxy := NewReverseProxy(policy.Destination) proxy := NewReverseProxy(policy.Destination)
// build http transport (roundtripper) middleware chain // build http transport (roundtripper) middleware chain
// todo(bdd): replace with transport.Clone() in go 1.13 transport := http.DefaultTransport.(*http.Transport).Clone()
transport := http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
c := tripper.NewChain() c := tripper.NewChain()
c = c.Append(metrics.HTTPMetricsRoundTripper("proxy", policy.Destination.Host)) c = c.Append(metrics.HTTPMetricsRoundTripper("proxy", policy.Destination.Host))
@ -253,7 +191,7 @@ func (p *Proxy) UpdatePolicies(opts *config.Options) error {
if isCustomClientConfig { if isCustomClientConfig {
transport.TLSClientConfig = &tlsClientConfig transport.TLSClientConfig = &tlsClientConfig
} }
proxy.Transport = c.Then(&transport) proxy.Transport = c.Then(transport)
handler, err := p.newReverseProxyHandler(proxy, &policy) handler, err := p.newReverseProxyHandler(proxy, &policy)
if err != nil { if err != nil {
@ -298,15 +236,6 @@ func NewReverseProxy(to *url.URL) *httputil.ReverseProxy {
return proxy return proxy
} }
// newRouteSigner creates a route specific signer.
func (p *Proxy) newRouteSigner(audience string) (cryptutil.JWTSigner, error) {
decodedSigningKey, err := base64.StdEncoding.DecodeString(p.signingKey)
if err != nil {
return nil, err
}
return cryptutil.NewES256Signer(decodedSigningKey, audience)
}
// newReverseProxyHandler applies handler specific options to a given route. // newReverseProxyHandler applies handler specific options to a given route.
func (p *Proxy) newReverseProxyHandler(rp *httputil.ReverseProxy, route *config.Policy) (handler http.Handler, err error) { func (p *Proxy) newReverseProxyHandler(rp *httputil.ReverseProxy, route *config.Policy) (handler http.Handler, err error) {
handler = &UpstreamProxy{ handler = &UpstreamProxy{
@ -318,7 +247,7 @@ func (p *Proxy) newReverseProxyHandler(rp *httputil.ReverseProxy, route *config.
// if signing key is set, add signer to middleware // if signing key is set, add signer to middleware
if len(p.signingKey) != 0 { if len(p.signingKey) != 0 {
signer, err := p.newRouteSigner(route.Source.Host) signer, err := cryptutil.NewES256Signer(p.signingKey, route.Source.Host)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -169,9 +169,6 @@ func TestOptions_Validate(t *testing.T) {
authurl, _ := url.Parse("authenticate.corp.beyondperimeter.com") authurl, _ := url.Parse("authenticate.corp.beyondperimeter.com")
authenticateBadScheme := testOptions(t) authenticateBadScheme := testOptions(t)
authenticateBadScheme.AuthenticateURL = authurl authenticateBadScheme.AuthenticateURL = authurl
authenticateInternalBadScheme := testOptions(t)
authenticateInternalBadScheme.AuthenticateInternalAddr = authurl
authorizeBadSCheme := testOptions(t) authorizeBadSCheme := testOptions(t)
authorizeBadSCheme.AuthorizeURL = authurl authorizeBadSCheme.AuthorizeURL = authurl
authorizeNil := testOptions(t) authorizeNil := testOptions(t)
@ -200,7 +197,6 @@ func TestOptions_Validate(t *testing.T) {
{"nil options", config.Options{}, true}, {"nil options", config.Options{}, true},
{"authenticate service url", badAuthURL, true}, {"authenticate service url", badAuthURL, true},
{"authenticate service url no scheme", authenticateBadScheme, true}, {"authenticate service url no scheme", authenticateBadScheme, true},
{"internal authenticate service url no scheme", authenticateInternalBadScheme, true},
{"authorize service url no scheme", authorizeBadSCheme, true}, {"authorize service url no scheme", authorizeBadSCheme, true},
{"authorize service cannot be nil", authorizeNil, true}, {"authorize service cannot be nil", authorizeNil, true},
{"no cookie secret", emptyCookieSecret, true}, {"no cookie secret", emptyCookieSecret, true},
@ -221,7 +217,6 @@ func TestOptions_Validate(t *testing.T) {
} }
func TestNew(t *testing.T) { func TestNew(t *testing.T) {
good := testOptions(t) good := testOptions(t)
shortCookieLength := testOptions(t) shortCookieLength := testOptions(t)
shortCookieLength.CookieSecret = "gN3xnvfsAwfCXxnJorGLKUG4l2wC8sS8nfLMhcStPg==" shortCookieLength.CookieSecret = "gN3xnvfsAwfCXxnJorGLKUG4l2wC8sS8nfLMhcStPg=="