mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-10 07:37:33 +02:00
authenticate: make service http only
- Rename SessionState to State to avoid stutter. - Simplified option validation to use a wrapper function for base64 secrets. - Removed authenticates grpc code. - Abstracted logic to load and validate a user's authenticate session. - Removed instances of url.Parse in favor of urlutil's version. - proxy: replaces grpc refresh logic with forced deadline advancement. - internal/sessions: remove rest store; parse authorize header as part of session store. - proxy: refactor request signer - sessions: remove extend deadline (fixes #294) - remove AuthenticateInternalAddr - remove AuthenticateInternalAddrString - omit type tag.Key from declaration of vars TagKey* it will be inferred from the right-hand side - remove compatibility package xerrors - use cloned http.DefaultTransport as base transport
This commit is contained in:
parent
bc72d08ad4
commit
380d314404
53 changed files with 718 additions and 2280 deletions
1
Makefile
1
Makefile
|
@ -45,7 +45,6 @@ tag: ## Create a new git tag to prepare to build a release
|
|||
.PHONY: build
|
||||
build: ## Builds dynamic executables and/or packages.
|
||||
@echo "==> $@"
|
||||
@echo Untracked changes? dirty? $(BUILDMETA) files? $(GITUNTRACKEDCHANGES)
|
||||
@CGO_ENABLED=0 GO111MODULE=on go build -tags "$(BUILDTAGS)" ${GO_LDFLAGS} -o $(BINDIR)/$(NAME) ./cmd/"$(NAME)"
|
||||
|
||||
.PHONY: lint
|
||||
|
|
|
@ -15,36 +15,31 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
)
|
||||
|
||||
// ValidateOptions checks to see if configuration values are valid for the authenticate service.
|
||||
// The checks do not modify the internal state of the Option structure. Returns
|
||||
// on first error found.
|
||||
// ValidateOptions checks that configuration are complete and valid.
|
||||
// Returns on first error found.
|
||||
func ValidateOptions(o config.Options) error {
|
||||
if _, err := cryptutil.NewCipherFromBase64(o.SharedKey); err != nil {
|
||||
return fmt.Errorf("authenticate: 'SHARED_SECRET' invalid: %v", err)
|
||||
}
|
||||
if _, err := cryptutil.NewCipherFromBase64(o.CookieSecret); err != nil {
|
||||
return fmt.Errorf("authenticate: 'COOKIE_SECRET' invalid %v", err)
|
||||
}
|
||||
if o.AuthenticateURL == nil {
|
||||
return errors.New("authenticate: missing setting: authenticate-service-url")
|
||||
return errors.New("authenticate: 'AUTHENTICATE_SERVICE_URL' is required")
|
||||
}
|
||||
if _, err := urlutil.ParseAndValidateURL(o.AuthenticateURL.String()); err != nil {
|
||||
return fmt.Errorf("authenticate: error parsing authenticate url: %v", err)
|
||||
return fmt.Errorf("authenticate: couldn't parse 'AUTHENTICATE_SERVICE_URL': %v", err)
|
||||
}
|
||||
if o.ClientID == "" {
|
||||
return errors.New("authenticate: 'IDP_CLIENT_ID' missing")
|
||||
return errors.New("authenticate: 'IDP_CLIENT_ID' is required")
|
||||
}
|
||||
if o.ClientSecret == "" {
|
||||
return errors.New("authenticate: 'IDP_CLIENT_SECRET' missing")
|
||||
}
|
||||
if o.SharedKey == "" {
|
||||
return errors.New("authenticate: 'SHARED_SECRET' missing")
|
||||
}
|
||||
decodedCookieSecret, err := base64.StdEncoding.DecodeString(o.CookieSecret)
|
||||
if err != nil {
|
||||
return fmt.Errorf("authenticate: 'COOKIE_SECRET' must be base64 encoded: %v", err)
|
||||
}
|
||||
if len(decodedCookieSecret) != 32 {
|
||||
return fmt.Errorf("authenticate: 'COOKIE_SECRET' %s be 32; got %d", o.CookieSecret, len(decodedCookieSecret))
|
||||
return errors.New("authenticate: 'IDP_CLIENT_SECRET' is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Authenticate validates a user's identity
|
||||
// Authenticate contains data required to run the authenticate service.
|
||||
type Authenticate struct {
|
||||
SharedKey string
|
||||
RedirectURL *url.URL
|
||||
|
@ -52,12 +47,11 @@ type Authenticate struct {
|
|||
templates *template.Template
|
||||
csrfStore sessions.CSRFStore
|
||||
sessionStore sessions.SessionStore
|
||||
restStore sessions.SessionStore
|
||||
cipher cryptutil.Cipher
|
||||
provider identity.Authenticator
|
||||
}
|
||||
|
||||
// New validates and creates a new authenticate service from a set of Options
|
||||
// New validates and creates a new authenticate service from a set of Options.
|
||||
func New(opts config.Options) (*Authenticate, error) {
|
||||
if err := ValidateOptions(opts); err != nil {
|
||||
return nil, err
|
||||
|
@ -95,17 +89,13 @@ func New(opts config.Options) (*Authenticate, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
restStore, err := sessions.NewRestStore(&sessions.RestStoreOptions{Cipher: cipher})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Authenticate{
|
||||
SharedKey: opts.SharedKey,
|
||||
RedirectURL: redirectURL,
|
||||
templates: templates.New(),
|
||||
csrfStore: cookieStore,
|
||||
sessionStore: cookieStore,
|
||||
restStore: restStore,
|
||||
cipher: cipher,
|
||||
provider: provider,
|
||||
}, nil
|
||||
|
|
|
@ -1,62 +0,0 @@
|
|||
//go:generate protoc -I ../proto/authenticate --go_out=plugins=grpc:../proto/authenticate ../proto/authenticate/authenticate.proto
|
||||
|
||||
package authenticate // import "github.com/pomerium/pomerium/authenticate"
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
pb "github.com/pomerium/pomerium/proto/authenticate"
|
||||
)
|
||||
|
||||
// Authenticate takes an encrypted code, and returns the authentication result.
|
||||
func (p *Authenticate) Authenticate(ctx context.Context, in *pb.AuthenticateRequest) (*pb.Session, error) {
|
||||
_, span := trace.StartSpan(ctx, "authenticate.grpc.Validate")
|
||||
defer span.End()
|
||||
session, err := sessions.UnmarshalSession(in.Code, p.cipher)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authenticate/grpc: authenticate %v", err)
|
||||
}
|
||||
newSessionProto, err := pb.ProtoFromSession(session)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newSessionProto, nil
|
||||
}
|
||||
|
||||
// Validate locally validates a JWT id_token; does NOT do nonce or revokation validation.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
||||
func (p *Authenticate) Validate(ctx context.Context, in *pb.ValidateRequest) (*pb.ValidateReply, error) {
|
||||
ctx, span := trace.StartSpan(ctx, "authenticate.grpc.Validate")
|
||||
defer span.End()
|
||||
|
||||
isValid, err := p.provider.Validate(ctx, in.IdToken)
|
||||
if err != nil {
|
||||
return &pb.ValidateReply{IsValid: false}, fmt.Errorf("authenticate/grpc: validate %v", err)
|
||||
}
|
||||
return &pb.ValidateReply{IsValid: isValid}, nil
|
||||
}
|
||||
|
||||
// Refresh renews a user's session checks if the session has been revoked using an access token
|
||||
// without reprompting the user.
|
||||
func (p *Authenticate) Refresh(ctx context.Context, in *pb.Session) (*pb.Session, error) {
|
||||
ctx, span := trace.StartSpan(ctx, "authenticate.grpc.Refresh")
|
||||
defer span.End()
|
||||
if in == nil {
|
||||
return nil, fmt.Errorf("authenticate/grpc: session cannot be nil")
|
||||
}
|
||||
oldSession, err := pb.SessionFromProto(in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newSession, err := p.provider.Refresh(ctx, oldSession)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authenticate/grpc: refresh failed %v", err)
|
||||
}
|
||||
newSessionProto, err := pb.ProtoFromSession(newSession)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newSessionProto, nil
|
||||
}
|
|
@ -1,153 +0,0 @@
|
|||
package authenticate
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/identity"
|
||||
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
pb "github.com/pomerium/pomerium/proto/authenticate"
|
||||
)
|
||||
|
||||
var fixedDate = time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC)
|
||||
|
||||
func TestAuthenticate_Validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
idToken string
|
||||
mp *identity.MockProvider
|
||||
want bool
|
||||
wantErr bool
|
||||
}{
|
||||
{"good", "example", &identity.MockProvider{}, false, false},
|
||||
{"error", "error", &identity.MockProvider{ValidateError: errors.New("err")}, false, true},
|
||||
{"not error", "not error", &identity.MockProvider{ValidateError: nil}, false, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := &Authenticate{provider: tt.mp}
|
||||
got, err := p.Validate(context.Background(), &pb.ValidateRequest{IdToken: tt.idToken})
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Authenticate.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got.IsValid, tt.want) {
|
||||
t.Errorf("Authenticate.Validate() = %v, want %v", got.IsValid, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthenticate_Refresh(t *testing.T) {
|
||||
fixedProtoTime, err := ptypes.TimestampProto(fixedDate)
|
||||
if err != nil {
|
||||
t.Fatal("failed to parse timestamp")
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mock *identity.MockProvider
|
||||
originalSession *pb.Session
|
||||
want *pb.Session
|
||||
wantErr bool
|
||||
}{
|
||||
{"good",
|
||||
&identity.MockProvider{
|
||||
RefreshResponse: &sessions.SessionState{
|
||||
AccessToken: "updated",
|
||||
RefreshDeadline: fixedDate,
|
||||
}},
|
||||
&pb.Session{
|
||||
AccessToken: "original",
|
||||
RefreshDeadline: fixedProtoTime,
|
||||
},
|
||||
&pb.Session{
|
||||
AccessToken: "updated",
|
||||
RefreshDeadline: fixedProtoTime,
|
||||
},
|
||||
false},
|
||||
{"test error", &identity.MockProvider{RefreshError: errors.New("hi")}, &pb.Session{RefreshToken: "refresh token", RefreshDeadline: fixedProtoTime}, nil, true},
|
||||
{"test catch nil", nil, nil, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := &Authenticate{provider: tt.mock}
|
||||
|
||||
got, err := p.Refresh(context.Background(), tt.originalSession)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Authenticate.Refresh() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Authenticate.Refresh() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthenticate_Authenticate(t *testing.T) {
|
||||
secret := cryptutil.GenerateKey()
|
||||
c, err := cryptutil.NewCipher(secret)
|
||||
if err != nil {
|
||||
t.Fatalf("expected to be able to create cipher: %v", err)
|
||||
}
|
||||
newSecret := cryptutil.GenerateKey()
|
||||
c2, err := cryptutil.NewCipher(newSecret)
|
||||
if err != nil {
|
||||
t.Fatalf("expected to be able to create cipher: %v", err)
|
||||
}
|
||||
rt := time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC()
|
||||
vtProto, err := ptypes.TimestampProto(rt)
|
||||
if err != nil {
|
||||
t.Fatal("failed to parse timestamp")
|
||||
}
|
||||
|
||||
want := &sessions.SessionState{
|
||||
AccessToken: "token1234",
|
||||
RefreshToken: "refresh4321",
|
||||
RefreshDeadline: rt,
|
||||
Email: "user@domain.com",
|
||||
User: "user",
|
||||
}
|
||||
|
||||
goodReply := &pb.Session{
|
||||
AccessToken: "token1234",
|
||||
RefreshToken: "refresh4321",
|
||||
RefreshDeadline: vtProto,
|
||||
Email: "user@domain.com",
|
||||
User: "user"}
|
||||
ciphertext, err := sessions.MarshalSession(want, c)
|
||||
if err != nil {
|
||||
t.Fatalf("expected to be encode session: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cipher cryptutil.Cipher
|
||||
code string
|
||||
want *pb.Session
|
||||
wantErr bool
|
||||
}{
|
||||
{"good", c, ciphertext, goodReply, false},
|
||||
{"bad cipher", c2, ciphertext, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := &Authenticate{cipher: tt.cipher}
|
||||
got, err := p.Authenticate(context.Background(), &pb.AuthenticateRequest{Code: tt.code})
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Authenticate.Authenticate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Authenticate.Authenticate() = got: \n%vwant:\n%v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -2,12 +2,13 @@ package authenticate // import "github.com/pomerium/pomerium/authenticate"
|
|||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
|
@ -18,6 +19,7 @@ import (
|
|||
)
|
||||
|
||||
// CSPHeaders are the content security headers added to the service's handlers
|
||||
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Security-Policy/script-src
|
||||
var CSPHeaders = map[string]string{
|
||||
"Content-Security-Policy": "default-src 'none'; style-src 'self'" +
|
||||
" 'sha256-z9MsgkMbQjRSLxzAfN55jB3a9pP0PQ4OHFH8b4iDP6s=' " +
|
||||
|
@ -27,22 +29,24 @@ var CSPHeaders = map[string]string{
|
|||
"Referrer-Policy": "Same-origin",
|
||||
}
|
||||
|
||||
// Handler returns the authenticate service's HTTP request multiplexer, and routes.
|
||||
// Handler returns the authenticate service's HTTP multiplexer, and routes.
|
||||
func (a *Authenticate) Handler() http.Handler {
|
||||
// validation middleware chain
|
||||
c := middleware.NewChain()
|
||||
c = c.Append(middleware.SetHeaders(CSPHeaders))
|
||||
validate := c.Append(middleware.ValidateSignature(a.SharedKey))
|
||||
validate = validate.Append(middleware.ValidateRedirectURI(a.RedirectURL))
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/robots.txt", c.ThenFunc(a.RobotsTxt))
|
||||
// Identity Provider (IdP) callback endpoints and callbacks
|
||||
mux.Handle("/start", c.ThenFunc(a.OAuthStart))
|
||||
// Identity Provider (IdP) endpoints
|
||||
mux.Handle("/oauth2", c.ThenFunc(a.OAuthStart))
|
||||
mux.Handle("/oauth2/callback", c.ThenFunc(a.OAuthCallback))
|
||||
// authenticate-server endpoints
|
||||
mux.Handle("/sign_in", validate.ThenFunc(a.SignIn))
|
||||
mux.Handle("/sign_out", validate.ThenFunc(a.SignOut)) // POST
|
||||
// programmatic authentication endpoints
|
||||
// Proxy service endpoints
|
||||
validationMiddlewares := c.Append(
|
||||
middleware.ValidateSignature(a.SharedKey),
|
||||
middleware.ValidateRedirectURI(a.RedirectURL),
|
||||
)
|
||||
mux.Handle("/sign_in", validationMiddlewares.ThenFunc(a.SignIn))
|
||||
mux.Handle("/sign_out", validationMiddlewares.ThenFunc(a.SignOut)) // POST
|
||||
// Direct user access endpoints
|
||||
mux.Handle("/api/v1/token", c.ThenFunc(a.ExchangeToken))
|
||||
return mux
|
||||
}
|
||||
|
@ -55,43 +59,46 @@ func (a *Authenticate) RobotsTxt(w http.ResponseWriter, r *http.Request) {
|
|||
fmt.Fprintf(w, "User-agent: *\nDisallow: /")
|
||||
}
|
||||
|
||||
func (a *Authenticate) authenticate(w http.ResponseWriter, r *http.Request, session *sessions.SessionState) error {
|
||||
if session.RefreshPeriodExpired() {
|
||||
session, err := a.provider.Refresh(r.Context(), session)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("session refresh failed : %w", err)
|
||||
}
|
||||
if err = a.sessionStore.SaveSession(w, r, session); err != nil {
|
||||
return xerrors.Errorf("failed saving refreshed session : %w", err)
|
||||
}
|
||||
} else {
|
||||
valid, err := a.provider.Validate(r.Context(), session.IDToken)
|
||||
if err != nil || !valid {
|
||||
return xerrors.Errorf("session valid: %v : %w", valid, err)
|
||||
}
|
||||
func (a *Authenticate) loadExisting(w http.ResponseWriter, r *http.Request) (*sessions.State, error) {
|
||||
session, err := a.sessionStore.LoadSession(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil
|
||||
err = session.Valid()
|
||||
if err == nil {
|
||||
return session, nil
|
||||
} else if !errors.Is(err, sessions.ErrExpired) {
|
||||
return nil, fmt.Errorf("authenticate: non-refreshable error: %w", err)
|
||||
} else {
|
||||
return a.refresh(w, r, session)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Authenticate) refresh(w http.ResponseWriter, r *http.Request, s *sessions.State) (*sessions.State, error) {
|
||||
newSession, err := a.provider.Refresh(r.Context(), s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authenticate: refresh failed: %w", err)
|
||||
}
|
||||
if err := a.sessionStore.SaveSession(w, r, newSession); err != nil {
|
||||
return nil, fmt.Errorf("authenticate: refresh save failed: %w", err)
|
||||
}
|
||||
return newSession, nil
|
||||
|
||||
}
|
||||
|
||||
// SignIn handles to authenticating a user.
|
||||
func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
|
||||
session, err := a.sessionStore.LoadSession(r)
|
||||
session, err := a.loadExisting(w, r)
|
||||
if err != nil {
|
||||
log.FromRequest(r).Debug().Err(err).Msg("no session loaded, restart auth")
|
||||
log.FromRequest(r).Debug().Err(err).Msg("authenticate: need new session")
|
||||
a.sessionStore.ClearSession(w, r)
|
||||
a.OAuthStart(w, r)
|
||||
return
|
||||
}
|
||||
// if a session already exists, authenticate it
|
||||
if err := a.authenticate(w, r, session); err != nil {
|
||||
httputil.ErrorResponse(w, r, err)
|
||||
return
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
httputil.ErrorResponse(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
state := r.Form.Get("state")
|
||||
if state == "" {
|
||||
httputil.ErrorResponse(w, r, httputil.Error("sign in state empty", http.StatusBadRequest, nil))
|
||||
|
@ -100,21 +107,20 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
redirectURL, err := urlutil.ParseAndValidateURL(r.Form.Get("redirect_uri"))
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri parameter passed", http.StatusBadRequest, err))
|
||||
httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err))
|
||||
return
|
||||
}
|
||||
// encrypt session state as json blob
|
||||
encrypted, err := sessions.MarshalSession(session, a.cipher)
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, httputil.Error("couldn't marshall session", http.StatusInternalServerError, err))
|
||||
httputil.ErrorResponse(w, r, httputil.Error("couldn't marshal session", http.StatusInternalServerError, err))
|
||||
return
|
||||
}
|
||||
|
||||
http.Redirect(w, r, getAuthCodeRedirectURL(redirectURL, state, encrypted), http.StatusFound)
|
||||
}
|
||||
|
||||
func getAuthCodeRedirectURL(redirectURL *url.URL, state, authCode string) string {
|
||||
// error handled by go's mux stack
|
||||
// ParseQuery err handled by go's mux stack
|
||||
params, _ := url.ParseQuery(redirectURL.RawQuery)
|
||||
params.Set("code", authCode)
|
||||
params.Set("state", state)
|
||||
|
@ -122,8 +128,8 @@ func getAuthCodeRedirectURL(redirectURL *url.URL, state, authCode string) string
|
|||
return redirectURL.String()
|
||||
}
|
||||
|
||||
// SignOut signs the user out by trying to revoke the user's remote identity session along with
|
||||
// the associated local session state. Handles both GET and POST.
|
||||
// SignOut signs the user out and attempts to revoke the user's identity session
|
||||
// Handles both GET and POST.
|
||||
func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
httputil.ErrorResponse(w, r, err)
|
||||
|
@ -156,7 +162,6 @@ func (a *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) {
|
|||
// OIDC : 3.1.2.1. Authentication Request
|
||||
nonce := fmt.Sprintf("%x", cryptutil.GenerateKey())
|
||||
a.csrfStore.SetCSRF(w, r, nonce)
|
||||
|
||||
// Redirection URI to which the response will be sent. This URI MUST exactly
|
||||
// match one of the Redirection URI values for the Client pre-registered at
|
||||
// at your identity provider
|
||||
|
@ -173,7 +178,6 @@ func (a *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) {
|
|||
httputil.ErrorResponse(w, r, httputil.Error("invalid signature", http.StatusBadRequest, nil))
|
||||
return
|
||||
}
|
||||
|
||||
// State is the opaque value used to maintain state between the request and
|
||||
// the callback; contains both the nonce and redirect URI
|
||||
state := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%v:%v", nonce, authRedirectURL.String())))
|
||||
|
@ -183,74 +187,69 @@ func (a *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) {
|
|||
http.Redirect(w, r, signInURL, http.StatusFound)
|
||||
}
|
||||
|
||||
// OAuthCallback handles the callback from the identity provider. Displays an error page if there
|
||||
// was an error. If successful, the user is redirected back to the proxy-service.
|
||||
// OAuthCallback handles the callback from the identity provider.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#AuthResponse
|
||||
func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) {
|
||||
redirect, err := a.getOAuthCallback(w, r)
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, xerrors.Errorf("oauth callback : %w", err))
|
||||
httputil.ErrorResponse(w, r, fmt.Errorf("oauth callback : %w", err))
|
||||
return
|
||||
}
|
||||
// redirect back to the proxy-service via sign_in
|
||||
http.Redirect(w, r, redirect, http.StatusFound)
|
||||
http.Redirect(w, r, redirect.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (string, error) {
|
||||
func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (*url.URL, error) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
return "", httputil.Error("invalid signature", http.StatusBadRequest, err)
|
||||
return nil, httputil.Error("invalid signature", http.StatusBadRequest, err)
|
||||
}
|
||||
// OIDC : 3.1.2.6. Authentication Error Response
|
||||
// https://openid.net/specs/openid-connect-core-1_0-final.html#AuthError
|
||||
if errorString := r.Form.Get("error"); errorString != "" {
|
||||
return "", httputil.Error("provider returned an error", http.StatusBadRequest, fmt.Errorf("provider returned error: %v", errorString))
|
||||
if idpError := r.Form.Get("error"); idpError != "" {
|
||||
return nil, httputil.Error("provider returned an error", http.StatusBadRequest, fmt.Errorf("provider error: %v", idpError))
|
||||
}
|
||||
// OIDC : 3.1.2.5. Successful Authentication Response
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#AuthResponse
|
||||
code := r.Form.Get("code")
|
||||
if code == "" {
|
||||
return "", httputil.Error("provider didn't reply with code", http.StatusBadRequest, nil)
|
||||
return nil, httputil.Error("provider didn't reply with code", http.StatusBadRequest, nil)
|
||||
}
|
||||
|
||||
// validate the returned code with the identity provider
|
||||
session, err := a.provider.Authenticate(r.Context(), code)
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("error redeeming authenticate code: %w", err)
|
||||
return nil, fmt.Errorf("error redeeming authenticate code: %w", err)
|
||||
}
|
||||
|
||||
// Opaque value used to maintain state between the request and the callback.
|
||||
// OIDC : 3.1.2.5. Successful Authentication Response
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#AuthResponse
|
||||
// Opaque value used to maintain state between the request and the callback.
|
||||
bytes, err := base64.URLEncoding.DecodeString(r.Form.Get("state"))
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("failed decoding state: %w", err)
|
||||
return nil, fmt.Errorf("failed decoding state: %w", err)
|
||||
}
|
||||
s := strings.SplitN(string(bytes), ":", 2)
|
||||
if len(s) != 2 {
|
||||
return "", xerrors.Errorf("invalid state size: %v", len(s))
|
||||
return nil, fmt.Errorf("invalid state size: %d", len(s))
|
||||
}
|
||||
// state contains both our csrf nonce and the redirect uri
|
||||
// state contains the csrf nonce and redirect uri
|
||||
nonce := s[0]
|
||||
redirect := s[1]
|
||||
c, err := a.csrfStore.GetCSRF(r)
|
||||
defer a.csrfStore.ClearCSRF(w, r)
|
||||
if err != nil || c.Value != nonce {
|
||||
return "", xerrors.Errorf("csrf failure: %w", err)
|
||||
|
||||
return nil, fmt.Errorf("csrf failure: %w", err)
|
||||
}
|
||||
redirectURL, err := urlutil.ParseAndValidateURL(redirect)
|
||||
if err != nil {
|
||||
return "", httputil.Error(fmt.Sprintf("invalid redirect uri %s", redirect), http.StatusBadRequest, err)
|
||||
return nil, httputil.Error(fmt.Sprintf("invalid redirect uri %s", redirect), http.StatusBadRequest, err)
|
||||
}
|
||||
// sanity check, we are redirecting back to the same subdomain right?
|
||||
if !middleware.SameDomain(redirectURL, a.RedirectURL) {
|
||||
return "", httputil.Error(fmt.Sprintf("invalid redirect domain %v, %v", redirectURL, a.RedirectURL), http.StatusBadRequest, nil)
|
||||
return nil, httputil.Error(fmt.Sprintf("invalid redirect domain %v, %v", redirectURL, a.RedirectURL), http.StatusBadRequest, nil)
|
||||
}
|
||||
|
||||
if err := a.sessionStore.SaveSession(w, r, session); err != nil {
|
||||
return "", xerrors.Errorf("failed saving new session: %w", err)
|
||||
return nil, fmt.Errorf("failed saving new session: %w", err)
|
||||
}
|
||||
return redirect, nil
|
||||
return redirectURL, nil
|
||||
}
|
||||
|
||||
// ExchangeToken takes an identity provider issued JWT as input ('id_token)
|
||||
|
@ -263,16 +262,32 @@ func (a *Authenticate) ExchangeToken(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
code := r.Form.Get("id_token")
|
||||
if code == "" {
|
||||
httputil.ErrorResponse(w, r, httputil.Error("provider missing id token", http.StatusBadRequest, nil))
|
||||
httputil.ErrorResponse(w, r, httputil.Error("missing id token", http.StatusBadRequest, nil))
|
||||
return
|
||||
}
|
||||
session, err := a.provider.IDTokenToSession(r.Context(), code)
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, httputil.Error("could not exchange identity for session", http.StatusInternalServerError, err))
|
||||
httputil.ErrorResponse(w, r, err)
|
||||
return
|
||||
}
|
||||
if err := a.restStore.SaveSession(w, r, session); err != nil {
|
||||
httputil.ErrorResponse(w, r, httputil.Error("failed returning new session", http.StatusInternalServerError, err))
|
||||
encToken, err := sessions.MarshalSession(session, a.cipher)
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err))
|
||||
return
|
||||
}
|
||||
restSession := struct {
|
||||
Token string
|
||||
Expiry time.Time `json:",omitempty"`
|
||||
}{
|
||||
Token: encToken,
|
||||
Expiry: session.RefreshDeadline,
|
||||
}
|
||||
|
||||
jsonBytes, err := json.Marshal(restSession)
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, err)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write(jsonBytes)
|
||||
}
|
||||
|
|
|
@ -68,22 +68,25 @@ func TestAuthenticate_SignIn(t *testing.T) {
|
|||
state string
|
||||
redirectURI string
|
||||
session sessions.SessionStore
|
||||
restStore sessions.SessionStore
|
||||
provider identity.MockProvider
|
||||
cipher cryptutil.Cipher
|
||||
wantCode int
|
||||
}{
|
||||
{"good", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusFound},
|
||||
{"session not valid", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: false}, &cryptutil.MockCipher{}, http.StatusInternalServerError},
|
||||
{"session refresh error", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshError: errors.New("error")}, &cryptutil.MockCipher{}, http.StatusInternalServerError},
|
||||
{"session save after refresh error", "state=example", "https://some.example", &sessions.MockSessionStore{SaveError: errors.New("error"), Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusInternalServerError},
|
||||
{"no cookie found trying to load", "state=example", "https://some.example", &sessions.MockSessionStore{LoadError: http.ErrNoCookie, Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest},
|
||||
{"unexpected error trying to load session", "state=example", "https://some.example", &sessions.MockSessionStore{LoadError: errors.New("error"), Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest},
|
||||
{"malformed form", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusInternalServerError},
|
||||
{"empty state", "state=", "https://some.example", &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest},
|
||||
{"malformed redirect uri", "state=example", "https://accounts.google.^", &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest},
|
||||
{"good", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusFound},
|
||||
{"session not valid", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: false}, &cryptutil.MockCipher{}, http.StatusFound},
|
||||
{"session expired good refresh", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshResponse: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &cryptutil.MockCipher{}, http.StatusFound},
|
||||
{"session expired bad refresh", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshError: errors.New("error")}, &cryptutil.MockCipher{}, http.StatusBadRequest}, // mocking hmac is meh
|
||||
{"session expired bad refresh save", "state=example", "https://some.example", &sessions.MockSessionStore{SaveError: errors.New("ruh roh"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshResponse: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &cryptutil.MockCipher{}, http.StatusBadRequest},
|
||||
|
||||
// {"no cookie found trying to load", "state=example", "https://some.example", &sessions.MockSessionStore{LoadError: http.ErrNoCookie, Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusInternalServerError},
|
||||
{"unexpected error trying to load session", "state=example", "https://some.example", &sessions.MockSessionStore{LoadError: errors.New("error"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest},
|
||||
{"malformed form", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusInternalServerError},
|
||||
{"empty state", "state=", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest},
|
||||
{"malformed redirect uri", "state=example", "https://accounts.google.^", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest},
|
||||
// actually caught by go's handler, but we should keep the test.
|
||||
{"bad redirect uri query", "state=nonce", "%gh&%ij", &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusInternalServerError},
|
||||
{"marshal session failure", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{MarshalError: errors.New("error")}, http.StatusInternalServerError},
|
||||
{"bad redirect uri query", "state=nonce", "%gh&%ij", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusInternalServerError},
|
||||
{"marshal session failure", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{MarshalError: errors.New("error")}, http.StatusInternalServerError},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -178,10 +181,10 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
|||
wantCode int
|
||||
wantBody string
|
||||
}{
|
||||
{"good post", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusFound, ""},
|
||||
{"failed revoke", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusBadRequest, "could not revoke"},
|
||||
{"malformed form", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusInternalServerError, ""},
|
||||
{"load session error", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{LoadError: errors.New("hi"), Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusFound, ""},
|
||||
{"good post", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusFound, ""},
|
||||
{"failed revoke", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusBadRequest, "could not revoke"},
|
||||
{"malformed form", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusInternalServerError, ""},
|
||||
{"load session error", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{LoadError: errors.New("hi"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusFound, ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -288,19 +291,19 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
|
|||
want string
|
||||
wantCode int
|
||||
}{
|
||||
{"good", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "https://corp.pomerium.io", http.StatusFound},
|
||||
{"get csrf error", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", GetError: errors.New("error"), Cookie: &http.Cookie{Value: "not nonce"}}, "", http.StatusInternalServerError},
|
||||
{"csrf nonce error", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "not nonce"}}, "", http.StatusInternalServerError},
|
||||
{"good", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "https://corp.pomerium.io", http.StatusFound},
|
||||
{"get csrf error", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", GetError: errors.New("error"), Cookie: &http.Cookie{Value: "not nonce"}}, "", http.StatusInternalServerError},
|
||||
{"csrf nonce error", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "not nonce"}}, "", http.StatusInternalServerError},
|
||||
{"failed authenticate", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateError: errors.New("error")}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusInternalServerError},
|
||||
{"failed save session", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{SaveError: errors.New("error")}, identity.MockProvider{AuthenticateResponse: sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusInternalServerError},
|
||||
{"provider returned error", http.MethodGet, "idp error", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest},
|
||||
{"empty code", http.MethodGet, "", "", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest},
|
||||
{"invalid state string", http.MethodGet, "", "code", "nonce:https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusInternalServerError},
|
||||
{"malformed state", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusInternalServerError},
|
||||
{"invalid redirect uri", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest},
|
||||
{"malformed form", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest},
|
||||
{"bad redirect uri", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:http://^^^")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"different domains", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:http://some.example.notpomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"failed save session", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{SaveError: errors.New("error")}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusInternalServerError},
|
||||
{"provider returned error", http.MethodGet, "idp error", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest},
|
||||
{"empty code", http.MethodGet, "", "", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest},
|
||||
{"invalid state string", http.MethodGet, "", "code", "nonce:https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusInternalServerError},
|
||||
{"malformed state", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusInternalServerError},
|
||||
{"invalid redirect uri", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest},
|
||||
{"malformed form", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest},
|
||||
{"bad redirect uri", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:http://^^^")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"different domains", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:http://some.example.notpomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -336,7 +339,6 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestAuthenticate_ExchangeToken(t *testing.T) {
|
||||
cipher := &cryptutil.MockCipher{}
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
|
@ -346,18 +348,18 @@ func TestAuthenticate_ExchangeToken(t *testing.T) {
|
|||
provider identity.MockProvider
|
||||
want string
|
||||
}{
|
||||
{"good", http.MethodPost, "token", &sessions.RestStore{Cipher: cipher}, cipher, identity.MockProvider{IDTokenToSessionResponse: sessions.SessionState{IDToken: "ok"}}, ""},
|
||||
{"could not exchange identity for session", http.MethodPost, "token", &sessions.RestStore{Cipher: cipher}, cipher, identity.MockProvider{IDTokenToSessionError: errors.New("error")}, "could not exchange identity for session"},
|
||||
{"missing token", http.MethodPost, "", &sessions.RestStore{Cipher: cipher}, cipher, identity.MockProvider{IDTokenToSessionResponse: sessions.SessionState{IDToken: "ok"}}, "missing id token"},
|
||||
{"save error", http.MethodPost, "token", &sessions.MockSessionStore{SaveError: errors.New("error")}, cipher, identity.MockProvider{IDTokenToSessionResponse: sessions.SessionState{IDToken: "ok"}}, "failed returning new session"},
|
||||
{"malformed form", http.MethodPost, "token", &sessions.RestStore{Cipher: cipher}, cipher, identity.MockProvider{IDTokenToSessionResponse: sessions.SessionState{IDToken: "ok"}}, ""},
|
||||
{"good", http.MethodPost, "token", &sessions.MockSessionStore{}, &cryptutil.MockCipher{}, identity.MockProvider{IDTokenToSessionResponse: sessions.State{IDToken: "ok"}}, ""},
|
||||
{"could not exchange identity for session", http.MethodPost, "token", &sessions.MockSessionStore{}, &cryptutil.MockCipher{}, identity.MockProvider{IDTokenToSessionError: errors.New("error")}, ""},
|
||||
{"missing token", http.MethodPost, "", &sessions.MockSessionStore{}, &cryptutil.MockCipher{}, identity.MockProvider{IDTokenToSessionResponse: sessions.State{IDToken: "ok"}}, "missing id token"},
|
||||
{"malformed form", http.MethodPost, "token", &sessions.MockSessionStore{}, &cryptutil.MockCipher{}, identity.MockProvider{IDTokenToSessionResponse: sessions.State{IDToken: "ok"}}, ""},
|
||||
{"can't marshal token", http.MethodPost, "token", &sessions.MockSessionStore{}, &cryptutil.MockCipher{MarshalError: errors.New("can't marshal token")}, identity.MockProvider{IDTokenToSessionResponse: sessions.State{IDToken: "ok"}}, "can't marshal token"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := &Authenticate{
|
||||
restStore: tt.restStore,
|
||||
cipher: tt.cipher,
|
||||
provider: tt.provider,
|
||||
cipher: tt.cipher,
|
||||
provider: tt.provider,
|
||||
sessionStore: tt.restStore,
|
||||
}
|
||||
form := url.Values{}
|
||||
if tt.idToken != "" {
|
||||
|
@ -370,6 +372,7 @@ func TestAuthenticate_ExchangeToken(t *testing.T) {
|
|||
}
|
||||
r := httptest.NewRequest(tt.method, "/", strings.NewReader(rawForm))
|
||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
r.Header.Set("Accept", "application/json")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
|
|
|
@ -21,7 +21,6 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/internal/version"
|
||||
pbAuthenticate "github.com/pomerium/pomerium/proto/authenticate"
|
||||
pbAuthorize "github.com/pomerium/pomerium/proto/authorize"
|
||||
"github.com/pomerium/pomerium/proxy"
|
||||
)
|
||||
|
@ -47,7 +46,7 @@ func main() {
|
|||
|
||||
mux := http.NewServeMux()
|
||||
grpcServer := setupGRPCServer(opt)
|
||||
_, err = newAuthenticateService(*opt, mux, grpcServer)
|
||||
_, err = newAuthenticateService(*opt, mux)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("cmd/pomerium: authenticate")
|
||||
}
|
||||
|
@ -62,7 +61,6 @@ func main() {
|
|||
log.Fatal().Err(err).Msg("cmd/pomerium: proxy")
|
||||
}
|
||||
if proxy != nil {
|
||||
defer proxy.AuthenticateClient.Close()
|
||||
defer proxy.AuthorizeClient.Close()
|
||||
}
|
||||
|
||||
|
@ -82,7 +80,7 @@ func main() {
|
|||
os.Exit(0)
|
||||
}
|
||||
|
||||
func newAuthenticateService(opt config.Options, mux *http.ServeMux, rpc *grpc.Server) (*authenticate.Authenticate, error) {
|
||||
func newAuthenticateService(opt config.Options, mux *http.ServeMux) (*authenticate.Authenticate, error) {
|
||||
if !config.IsAuthenticate(opt.Services) {
|
||||
return nil, nil
|
||||
}
|
||||
|
@ -90,7 +88,6 @@ func newAuthenticateService(opt config.Options, mux *http.ServeMux, rpc *grpc.Se
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pbAuthenticate.RegisterAuthenticatorServer(rpc, service)
|
||||
mux.Handle(urlutil.StripPort(opt.AuthenticateURL.Host)+"/", service.Handler())
|
||||
return service, nil
|
||||
}
|
||||
|
@ -164,7 +161,7 @@ func configToServerOptions(opt *config.Options) *httputil.ServerOptions {
|
|||
func setupMetrics(opt *config.Options) {
|
||||
if opt.MetricsAddr != "" {
|
||||
if handler, err := metrics.PrometheusHandler(); err != nil {
|
||||
log.Error().Err(err).Msg("cmd/pomerium: couldn't start metrics server")
|
||||
log.Error().Err(err).Msg("cmd/pomerium: metrics failed to start")
|
||||
} else {
|
||||
metrics.SetBuildInfo(opt.Services)
|
||||
metrics.RegisterInfoMetrics()
|
||||
|
|
|
@ -21,9 +21,6 @@ import (
|
|||
)
|
||||
|
||||
func Test_newAuthenticateService(t *testing.T) {
|
||||
grpcAuth := middleware.NewSharedSecretCred("test")
|
||||
grpcOpts := []grpc.ServerOption{grpc.UnaryInterceptor(grpcAuth.ValidateRequest)}
|
||||
grpcServer := grpc.NewServer(grpcOpts...)
|
||||
mux := http.NewServeMux()
|
||||
|
||||
tests := []struct {
|
||||
|
@ -56,7 +53,7 @@ func Test_newAuthenticateService(t *testing.T) {
|
|||
testOptsField.Set(reflect.ValueOf(tt).FieldByName("Value"))
|
||||
}
|
||||
|
||||
_, err = newAuthenticateService(*testOpts, mux, grpcServer)
|
||||
_, err = newAuthenticateService(*testOpts, mux)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("newAuthenticateService() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
|
|
|
@ -176,7 +176,6 @@ Go to **Environment** tab.
|
|||
| SHARED_SECRET | output of `head -c32 /dev/urandom | base64` |
|
||||
| AUTHORIZE_SERVICE_URL | `https://localhost` |
|
||||
| AUTHENTICATE_SERVICE_URL | `https://authenticate.int.nas.example` |
|
||||
| AUTHENTICATE_INTERNAL_URL | `https://localhost` |
|
||||
|
||||
For a detailed explanation, and additional options, please refer to the [configuration variable docs]. Also note, though not covered in this guide, settings can be made via a mounted configuration file.
|
||||
|
||||
|
|
|
@ -48,7 +48,6 @@ services:
|
|||
- SERVICES=proxy
|
||||
# IMPORTANT! If you are running pomerium behind another ingress (loadbalancer/firewall/etc)
|
||||
# you must tell pomerium proxy how to communicate using an internal hostname for RPC
|
||||
- AUTHENTICATE_INTERNAL_URL=https://pomerium-authenticate
|
||||
- AUTHORIZE_SERVICE_URL=https://pomerium-authorize
|
||||
# When communicating internally, rPC is going to get a name conflict expecting an external
|
||||
# facing certificate name (i.e. authenticate-service.local vs *.corp.example.com).
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
# Main configuration flags : https://www.pomerium.io/reference/
|
||||
authenticate_service_url: https://authenticate.corp.beyondperimeter.com
|
||||
authenticate_internal_url: https://pomerium-authenticate-service.default.svc.cluster.local
|
||||
authorize_service_url: https://pomerium-authorize-service.default.svc.cluster.local
|
||||
|
||||
override_certificate_name: "*.corp.beyondperimeter.com"
|
||||
|
|
|
@ -146,7 +146,7 @@ Timeouts set the global server timeouts. For route-specific timeouts, see [polic
|
|||
|
||||
## GRPC Options
|
||||
|
||||
These settings control upstream connections to the Authorize and Authenticate services.
|
||||
These settings control upstream connections to the Authorize service.
|
||||
|
||||
### GRPC Client Timeout
|
||||
|
||||
|
@ -228,8 +228,8 @@ Each unit work is called a Span in a trace. Spans include metadata about the wor
|
|||
|
||||
| Config Key | Description | Required |
|
||||
| :--------------- | :---------------------------------------------------------------- | -------- |
|
||||
| tracing_provider | The name of the tracing provider. (e.g. jaeger) | ✅ |
|
||||
| tracing_debug | Will disable [sampling](https://opencensus.io/tracing/sampling/). | ❌ |
|
||||
| tracing_provider | The name of the tracing provider. (e.g. jaeger) | ✅ |
|
||||
| tracing_debug | Will disable [sampling](https://opencensus.io/tracing/sampling/). | ❌ |
|
||||
|
||||
### Jaeger
|
||||
|
||||
|
@ -243,8 +243,8 @@ Each unit work is called a Span in a trace. Spans include metadata about the wor
|
|||
|
||||
| Config Key | Description | Required |
|
||||
| :-------------------------------- | :------------------------------------------ | -------- |
|
||||
| tracing_jaeger_collector_endpoint | Url to the Jaeger HTTP Thrift collector. | ✅ |
|
||||
| tracing_jaeger_agent_endpoint | Send spans to jaeger-agent at this address. | ✅ |
|
||||
| tracing_jaeger_collector_endpoint | Url to the Jaeger HTTP Thrift collector. | ✅ |
|
||||
| tracing_jaeger_agent_endpoint | Send spans to jaeger-agent at this address. | ✅ |
|
||||
|
||||
#### Example
|
||||
|
||||
|
@ -464,16 +464,6 @@ Signing key is the base64 encoded key used to sign outbound requests. For more i
|
|||
|
||||
Authenticate Service URL is the externally accessible URL for the authenticate service.
|
||||
|
||||
## Authenticate Internal Service URL
|
||||
|
||||
- Environmental Variable: `AUTHENTICATE_INTERNAL_URL`
|
||||
- Config File Key: `authenticate_internal_url`
|
||||
- Type: `URL`
|
||||
- Optional
|
||||
- Example: `https://pomerium-authenticate-service.default.svc.cluster.local`
|
||||
|
||||
Authenticate Internal Service URL is the internally routed dns name of the authenticate service. This setting is typically used with load balancers that do not gRPC, thus allowing you to specify an internally accessible name.
|
||||
|
||||
## Authorize Service URL
|
||||
|
||||
- Environmental Variable: `AUTHORIZE_SERVICE_URL`
|
||||
|
|
1
go.mod
1
go.mod
|
@ -26,7 +26,6 @@ require (
|
|||
golang.org/x/net v0.0.0-20190611141213-3f473d35a33a
|
||||
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45
|
||||
golang.org/x/sys v0.0.0-20190610200419-93c9922d18ae // indirect
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7
|
||||
google.golang.org/api v0.6.0
|
||||
google.golang.org/appengine v1.6.1 // indirect
|
||||
google.golang.org/genproto v0.0.0-20190611190212-a7e196e89fd3 // indirect
|
||||
|
|
2
go.sum
2
go.sum
|
@ -257,8 +257,6 @@ golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBn
|
|||
golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
|
||||
golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
|
||||
golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/api v0.3.1/go.mod h1:6wY9I6uQWHQ8EM57III9mq/AjF+i8G65rmVagqKMtkk=
|
||||
google.golang.org/api v0.3.2/go.mod h1:6wY9I6uQWHQ8EM57III9mq/AjF+i8G65rmVagqKMtkk=
|
||||
google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=
|
||||
|
|
|
@ -97,13 +97,6 @@ type Options struct {
|
|||
// (sudo) access including the ability to impersonate other users' access
|
||||
Administrators []string `mapstructure:"administrators"`
|
||||
|
||||
// AuthenticateInternalAddr is used override the routable destination of
|
||||
// authenticate service's GRPC endpoint.
|
||||
// NOTE: As many load balancers do not support externally routed gRPC so
|
||||
// this may be an internal location.
|
||||
AuthenticateInternalAddrString string `mapstructure:"authenticate_internal_url"`
|
||||
AuthenticateInternalAddr *url.URL
|
||||
|
||||
// AuthorizeURL is the routable destination of the authorize service's
|
||||
// gRPC endpoint. NOTE: As many load balancers do not support
|
||||
// externally routed gRPC so this may be an internal location.
|
||||
|
@ -246,13 +239,6 @@ func (o *Options) Validate() error {
|
|||
o.AuthorizeURL = u
|
||||
}
|
||||
|
||||
if o.AuthenticateInternalAddrString != "" {
|
||||
u, err := urlutil.ParseAndValidateURL(o.AuthenticateInternalAddrString)
|
||||
if err != nil {
|
||||
return fmt.Errorf("bad authenticate-internal-addr %s : %v", o.AuthenticateInternalAddrString, err)
|
||||
}
|
||||
o.AuthenticateInternalAddr = u
|
||||
}
|
||||
if o.PolicyFile != "" {
|
||||
return errors.New("policy file setting is deprecated")
|
||||
}
|
||||
|
|
|
@ -337,7 +337,7 @@ func TestNewOptions(t *testing.T) {
|
|||
|
||||
func TestOptionsFromViper(t *testing.T) {
|
||||
opts := []cmp.Option{
|
||||
cmpopts.IgnoreFields(Options{}, "AuthenticateInternalAddr", "DefaultUpstreamTimeout", "CookieRefresh", "CookieExpire", "Services", "Addr", "RefreshCooldown", "LogLevel", "KeyFile", "CertFile", "SharedKey", "ReadTimeout", "ReadHeaderTimeout", "IdleTimeout", "GRPCClientTimeout", "GRPCClientDNSRoundRobin"),
|
||||
cmpopts.IgnoreFields(Options{}, "DefaultUpstreamTimeout", "CookieRefresh", "CookieExpire", "Services", "Addr", "RefreshCooldown", "LogLevel", "KeyFile", "CertFile", "SharedKey", "ReadTimeout", "ReadHeaderTimeout", "IdleTimeout", "GRPCClientTimeout", "GRPCClientDNSRoundRobin"),
|
||||
cmpopts.IgnoreFields(Policy{}, "Source", "Destination"),
|
||||
}
|
||||
|
||||
|
@ -361,21 +361,6 @@ func TestOptionsFromViper(t *testing.T) {
|
|||
"X-XSS-Protection": "1; mode=block",
|
||||
}},
|
||||
false},
|
||||
{"good with authenticate internal url",
|
||||
[]byte(`{"authenticate_internal_url": "https://internal.example","policy":[{"from": "https://from.example","to":"https://to.example"}]}`),
|
||||
&Options{
|
||||
AuthenticateInternalAddrString: "https://internal.example",
|
||||
Policies: []Policy{{From: "https://from.example", To: "https://to.example"}},
|
||||
CookieName: "_pomerium",
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
Headers: map[string]string{
|
||||
"Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload",
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
"X-Frame-Options": "SAMEORIGIN",
|
||||
"X-XSS-Protection": "1; mode=block",
|
||||
}},
|
||||
false},
|
||||
{"good disable header",
|
||||
[]byte(`{"headers": {"disable":"true"},"policy":[{"from": "https://from.example","to":"https://to.example"}]}`),
|
||||
&Options{
|
||||
|
@ -385,7 +370,6 @@ func TestOptionsFromViper(t *testing.T) {
|
|||
CookieHTTPOnly: true,
|
||||
Headers: map[string]string{}},
|
||||
false},
|
||||
{"bad authenticate internal url", []byte(`{"authenticate_internal_url": "internal.example","policy":[{"from": "https://from.example","to":"https://to.example"}]}`), nil, true},
|
||||
{"bad url", []byte(`{"policy":[{"from": "https://","to":"https://to.example"}]}`), nil, true},
|
||||
{"bad policy", []byte(`{"policy":[{"allow_public_unauthenticated_access": "dog","to":"https://to.example"}]}`), nil, true},
|
||||
|
||||
|
|
|
@ -67,6 +67,18 @@ func NewCipher(secret []byte) (*XChaCha20Cipher, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
// NewCipherFromBase64 takes a base64 encoded secret key and returns a new XChacha20poly1305 cipher.
|
||||
func NewCipherFromBase64(s string) (*XChaCha20Cipher, error) {
|
||||
decoded, err := base64.StdEncoding.DecodeString(s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cryptutil: invalid base64: %v", err)
|
||||
}
|
||||
if len(decoded) != 32 {
|
||||
return nil, fmt.Errorf("cryptutil: got %d bytes but want 32", len(decoded))
|
||||
}
|
||||
return NewCipher(decoded)
|
||||
}
|
||||
|
||||
// GenerateNonce generates a random nonce.
|
||||
// Panics if source of randomness fails.
|
||||
func (c *XChaCha20Cipher) GenerateNonce() []byte {
|
||||
|
|
|
@ -259,3 +259,26 @@ func TestNewCipher(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCipherFromBase64(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
s string
|
||||
wantErr bool
|
||||
}{
|
||||
{"simple 32 byte key", base64.StdEncoding.EncodeToString(GenerateKey()), false},
|
||||
{"key too short", base64.StdEncoding.EncodeToString([]byte("what is entropy")), true},
|
||||
{"key too long", GenerateRandomString(33), true},
|
||||
{"bad base 64", string(GenerateKey()), true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := NewCipherFromBase64(tt.s)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewCipherFromBase64() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -48,15 +49,20 @@ type ES256Signer struct {
|
|||
NotBefore jwt.NumericDate `json:"nbf,omitempty"`
|
||||
}
|
||||
|
||||
// NewES256Signer creates an Elliptic Curve, NIST P-256 (aka secp256r1 aka prime256v1) JWT signer.
|
||||
// NewES256Signer creates a NIST P-256 (aka secp256r1 aka prime256v1) JWT signer
|
||||
// from a base64 encoded private key.
|
||||
//
|
||||
// RSA is not supported due to performance considerations of needing to sign each request.
|
||||
// Go's P-256 is constant-time and SHA-256 is faster on 64-bit machines and immune
|
||||
// to length extension attacks.
|
||||
// See also:
|
||||
// - https://cloud.google.com/iot/docs/how-tos/credentials/keys
|
||||
func NewES256Signer(privKey []byte, audience string) (*ES256Signer, error) {
|
||||
key, err := DecodePrivateKey(privKey)
|
||||
func NewES256Signer(privKey, audience string) (*ES256Signer, error) {
|
||||
decodedSigningKey, err := base64.StdEncoding.DecodeString(privKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
key, err := DecodePrivateKey(decodedSigningKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cryptutil: parsing key failed %v", err)
|
||||
}
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestES256Signer(t *testing.T) {
|
||||
signer, err := NewES256Signer([]byte(pemECPrivateKeyP256), "destination-url")
|
||||
signer, err := NewES256Signer(base64.StdEncoding.EncodeToString([]byte(pemECPrivateKeyP256)), "destination-url")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -25,12 +26,13 @@ func TestNewES256Signer(t *testing.T) {
|
|||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
privKey []byte
|
||||
privKey string
|
||||
audience string
|
||||
wantErr bool
|
||||
}{
|
||||
{"working example", []byte(pemECPrivateKeyP256), "some-domain.com", false},
|
||||
{"bad private key", []byte(garbagePEM), "some-domain.com", true},
|
||||
{"working example", base64.StdEncoding.EncodeToString([]byte(pemECPrivateKeyP256)), "some-domain.com", false},
|
||||
{"bad private key", base64.StdEncoding.EncodeToString([]byte(garbagePEM)), "some-domain.com", true},
|
||||
{"bad base64 key", garbagePEM, "some-domain.com", true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
|
|
@ -2,20 +2,18 @@ package httputil // import "github.com/pomerium/pomerium/internal/httputil"
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/templates"
|
||||
)
|
||||
|
||||
// Error formats creates a HTTP error with code, user friendly (and safe) error
|
||||
// message. If nil or empty:
|
||||
// HTTP status code defaults to 500.
|
||||
// Message defaults to the text of the status code.
|
||||
// message. If nil or empty, HTTP status code defaults to 500 and message
|
||||
// defaults to the text of the status code.
|
||||
func Error(message string, code int, err error) error {
|
||||
if code == 0 {
|
||||
code = http.StatusInternalServerError
|
||||
|
@ -45,7 +43,9 @@ func (e *httpError) Error() string {
|
|||
func (e *httpError) Unwrap() error { return e.Err }
|
||||
|
||||
// Timeout reports whether this error represents a user debuggable error.
|
||||
func (e *httpError) Debugable() bool { return e.Code == http.StatusUnauthorized }
|
||||
func (e *httpError) Debugable() bool {
|
||||
return e.Code == http.StatusUnauthorized || e.Code == http.StatusForbidden
|
||||
}
|
||||
|
||||
// ErrorResponse renders an error page given an error. If the error is a
|
||||
// http error from this package, a user friendly message is set, http status code,
|
||||
|
@ -57,11 +57,12 @@ func ErrorResponse(rw http.ResponseWriter, r *http.Request, e error) {
|
|||
var requestID string
|
||||
var httpError *httpError
|
||||
// if this is an HTTPError, we can add some additional useful information
|
||||
if xerrors.As(e, &httpError) {
|
||||
if errors.As(e, &httpError) {
|
||||
canDebug = httpError.Debugable()
|
||||
statusCode = httpError.Code
|
||||
errorString = httpError.Message
|
||||
}
|
||||
|
||||
log.FromRequest(r).Error().Err(e).Str("http-message", errorString).Int("http-code", statusCode).Msg("http-error")
|
||||
|
||||
if id, ok := log.IDFromRequest(r); ok {
|
||||
|
@ -71,7 +72,7 @@ func ErrorResponse(rw http.ResponseWriter, r *http.Request, e error) {
|
|||
var response struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
response.Error = e.Error()
|
||||
response.Error = errorString
|
||||
writeJSONResponse(rw, statusCode, response)
|
||||
} else {
|
||||
rw.WriteHeader(statusCode)
|
||||
|
|
|
@ -129,8 +129,7 @@ func (p *GoogleProvider) GetSignInURL(state string) string {
|
|||
|
||||
// Authenticate creates an identity session with google from a authorization code, and follows up
|
||||
// call to the admin/group api to check what groups the user is in.
|
||||
func (p *GoogleProvider) Authenticate(ctx context.Context, code string) (*sessions.SessionState, error) {
|
||||
// convert authorization code into a token
|
||||
func (p *GoogleProvider) Authenticate(ctx context.Context, code string) (*sessions.State, error) {
|
||||
oauth2Token, err := p.oauth.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("identity/google: token exchange failed %v", err)
|
||||
|
@ -153,7 +152,7 @@ func (p *GoogleProvider) Authenticate(ctx context.Context, code string) (*sessio
|
|||
// Refresh renews a user's session using an oidc refresh token withoutreprompting the user.
|
||||
// Group membership is also refreshed.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens
|
||||
func (p *GoogleProvider) Refresh(ctx context.Context, s *sessions.SessionState) (*sessions.SessionState, error) {
|
||||
func (p *GoogleProvider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) {
|
||||
if s.RefreshToken == "" {
|
||||
return nil, errors.New("identity: missing refresh token")
|
||||
}
|
||||
|
@ -180,7 +179,7 @@ func (p *GoogleProvider) Refresh(ctx context.Context, s *sessions.SessionState)
|
|||
// IDTokenToSession takes an identity provider issued JWT as input ('id_token')
|
||||
// and returns a session state. The provided token's audience ('aud') must
|
||||
// match Pomerium's client_id.
|
||||
func (p *GoogleProvider) IDTokenToSession(ctx context.Context, rawIDToken string) (*sessions.SessionState, error) {
|
||||
func (p *GoogleProvider) IDTokenToSession(ctx context.Context, rawIDToken string) (*sessions.State, error) {
|
||||
idToken, err := p.verifier.Verify(ctx, rawIDToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("identity/google: could not verify id_token %v", err)
|
||||
|
@ -200,7 +199,7 @@ func (p *GoogleProvider) IDTokenToSession(ctx context.Context, rawIDToken string
|
|||
return nil, fmt.Errorf("identity/google: could not retrieve groups %v", err)
|
||||
}
|
||||
|
||||
return &sessions.SessionState{
|
||||
return &sessions.State{
|
||||
IDToken: rawIDToken,
|
||||
RefreshDeadline: idToken.Expiry.Truncate(time.Second),
|
||||
Email: claims.Email,
|
||||
|
|
|
@ -74,7 +74,7 @@ func NewAzureProvider(p *Provider) (*AzureProvider, error) {
|
|||
|
||||
// Authenticate creates an identity session with azure from a authorization code, and follows up
|
||||
// call to the groups api to check what groups the user is in.
|
||||
func (p *AzureProvider) Authenticate(ctx context.Context, code string) (*sessions.SessionState, error) {
|
||||
func (p *AzureProvider) Authenticate(ctx context.Context, code string) (*sessions.State, error) {
|
||||
// convert authorization code into a token
|
||||
oauth2Token, err := p.oauth.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
|
@ -104,7 +104,7 @@ func (p *AzureProvider) Authenticate(ctx context.Context, code string) (*session
|
|||
// IDTokenToSession takes an identity provider issued JWT as input ('id_token')
|
||||
// and returns a session state. The provided token's audience ('aud') must
|
||||
// match Pomerium's client_id.
|
||||
func (p *AzureProvider) IDTokenToSession(ctx context.Context, rawIDToken string) (*sessions.SessionState, error) {
|
||||
func (p *AzureProvider) IDTokenToSession(ctx context.Context, rawIDToken string) (*sessions.State, error) {
|
||||
idToken, err := p.verifier.Verify(ctx, rawIDToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("identity/microsoft: could not verify id_token %v", err)
|
||||
|
@ -118,7 +118,7 @@ func (p *AzureProvider) IDTokenToSession(ctx context.Context, rawIDToken string)
|
|||
return nil, fmt.Errorf("identity/microsoft: failed to parse id_token claims %v", err)
|
||||
}
|
||||
|
||||
return &sessions.SessionState{
|
||||
return &sessions.State{
|
||||
IDToken: rawIDToken,
|
||||
RefreshDeadline: idToken.Expiry.Truncate(time.Second),
|
||||
Email: claims.Email,
|
||||
|
@ -146,7 +146,7 @@ func (p *AzureProvider) GetSignInURL(state string) string {
|
|||
// Refresh renews a user's session using an oid refresh token without reprompting the user.
|
||||
// Group membership is also refreshed.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens
|
||||
func (p *AzureProvider) Refresh(ctx context.Context, s *sessions.SessionState) (*sessions.SessionState, error) {
|
||||
func (p *AzureProvider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) {
|
||||
if s.RefreshToken == "" {
|
||||
return nil, errors.New("identity/microsoft: missing refresh token")
|
||||
}
|
||||
|
|
|
@ -8,25 +8,25 @@ import (
|
|||
|
||||
// MockProvider provides a mocked implementation of the providers interface.
|
||||
type MockProvider struct {
|
||||
AuthenticateResponse sessions.SessionState
|
||||
AuthenticateResponse sessions.State
|
||||
AuthenticateError error
|
||||
IDTokenToSessionResponse sessions.SessionState
|
||||
IDTokenToSessionResponse sessions.State
|
||||
IDTokenToSessionError error
|
||||
ValidateResponse bool
|
||||
ValidateError error
|
||||
RefreshResponse *sessions.SessionState
|
||||
RefreshResponse *sessions.State
|
||||
RefreshError error
|
||||
RevokeError error
|
||||
GetSignInURLResponse string
|
||||
}
|
||||
|
||||
// Authenticate is a mocked providers function.
|
||||
func (mp MockProvider) Authenticate(ctx context.Context, code string) (*sessions.SessionState, error) {
|
||||
func (mp MockProvider) Authenticate(ctx context.Context, code string) (*sessions.State, error) {
|
||||
return &mp.AuthenticateResponse, mp.AuthenticateError
|
||||
}
|
||||
|
||||
// IDTokenToSession is a mocked providers function.
|
||||
func (mp MockProvider) IDTokenToSession(ctx context.Context, code string) (*sessions.SessionState, error) {
|
||||
func (mp MockProvider) IDTokenToSession(ctx context.Context, code string) (*sessions.State, error) {
|
||||
return &mp.IDTokenToSessionResponse, mp.IDTokenToSessionError
|
||||
}
|
||||
|
||||
|
@ -36,7 +36,7 @@ func (mp MockProvider) Validate(ctx context.Context, s string) (bool, error) {
|
|||
}
|
||||
|
||||
// Refresh is a mocked providers function.
|
||||
func (mp MockProvider) Refresh(ctx context.Context, s *sessions.SessionState) (*sessions.SessionState, error) {
|
||||
func (mp MockProvider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) {
|
||||
return mp.RefreshResponse, mp.RefreshError
|
||||
}
|
||||
|
||||
|
|
|
@ -91,7 +91,7 @@ type accessToken struct {
|
|||
// Refresh renews a user's session using an oid refresh token without reprompting the user.
|
||||
// Group membership is also refreshed. If configured properly, Okta is we can configure the access token
|
||||
// to include group membership claims which allows us to avoid a follow up oauth2 call.
|
||||
func (p *OktaProvider) Refresh(ctx context.Context, s *sessions.SessionState) (*sessions.SessionState, error) {
|
||||
func (p *OktaProvider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) {
|
||||
if s.RefreshToken == "" {
|
||||
return nil, errors.New("identity/okta: missing refresh token")
|
||||
}
|
||||
|
|
|
@ -93,7 +93,7 @@ func (p *OneLoginProvider) GetSignInURL(state string) string {
|
|||
// Refresh renews a user's session using an oid refresh token without reprompting the user.
|
||||
// Group membership is also refreshed.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens
|
||||
func (p *OneLoginProvider) Refresh(ctx context.Context, s *sessions.SessionState) (*sessions.SessionState, error) {
|
||||
func (p *OneLoginProvider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) {
|
||||
if s.RefreshToken == "" {
|
||||
return nil, errors.New("identity/microsoft: missing refresh token")
|
||||
}
|
||||
|
|
|
@ -45,10 +45,10 @@ type UserGrouper interface {
|
|||
|
||||
// Authenticator is an interface representing the ability to authenticate with an identity provider.
|
||||
type Authenticator interface {
|
||||
Authenticate(context.Context, string) (*sessions.SessionState, error)
|
||||
IDTokenToSession(context.Context, string) (*sessions.SessionState, error)
|
||||
Authenticate(context.Context, string) (*sessions.State, error)
|
||||
IDTokenToSession(context.Context, string) (*sessions.State, error)
|
||||
Validate(context.Context, string) (bool, error)
|
||||
Refresh(context.Context, *sessions.SessionState) (*sessions.SessionState, error)
|
||||
Refresh(context.Context, *sessions.State) (*sessions.State, error)
|
||||
Revoke(string) error
|
||||
GetSignInURL(state string) string
|
||||
}
|
||||
|
@ -131,7 +131,7 @@ func (p *Provider) Validate(ctx context.Context, idToken string) (bool, error) {
|
|||
// IDTokenToSession takes an identity provider issued JWT as input ('id_token')
|
||||
// and returns a session state. The provided token's audience ('aud') must
|
||||
// match Pomerium's client_id.
|
||||
func (p *Provider) IDTokenToSession(ctx context.Context, rawIDToken string) (*sessions.SessionState, error) {
|
||||
func (p *Provider) IDTokenToSession(ctx context.Context, rawIDToken string) (*sessions.State, error) {
|
||||
idToken, err := p.verifier.Verify(ctx, rawIDToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("identity: could not verify id_token: %v", err)
|
||||
|
@ -146,7 +146,7 @@ func (p *Provider) IDTokenToSession(ctx context.Context, rawIDToken string) (*se
|
|||
return nil, fmt.Errorf("identity: failed to parse id_token claims: %v", err)
|
||||
}
|
||||
|
||||
return &sessions.SessionState{
|
||||
return &sessions.State{
|
||||
IDToken: rawIDToken,
|
||||
User: idToken.Subject,
|
||||
RefreshDeadline: idToken.Expiry.Truncate(time.Second),
|
||||
|
@ -157,7 +157,7 @@ func (p *Provider) IDTokenToSession(ctx context.Context, rawIDToken string) (*se
|
|||
}
|
||||
|
||||
// Authenticate creates a session with an identity provider from a authorization code
|
||||
func (p *Provider) Authenticate(ctx context.Context, code string) (*sessions.SessionState, error) {
|
||||
func (p *Provider) Authenticate(ctx context.Context, code string) (*sessions.State, error) {
|
||||
// exchange authorization for a oidc token
|
||||
oauth2Token, err := p.oauth.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
|
@ -181,7 +181,7 @@ func (p *Provider) Authenticate(ctx context.Context, code string) (*sessions.Ses
|
|||
// Refresh renews a user's session using therefresh_token without reprompting
|
||||
// the user. If supported, group membership is also refreshed.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens
|
||||
func (p *Provider) Refresh(ctx context.Context, s *sessions.SessionState) (*sessions.SessionState, error) {
|
||||
func (p *Provider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) {
|
||||
if s.RefreshToken == "" {
|
||||
return nil, errors.New("identity: missing refresh token")
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
|
||||
"golang.org/x/net/publicsuffix"
|
||||
)
|
||||
|
@ -70,7 +71,7 @@ func ValidateRedirectURI(rootDomain *url.URL) func(next http.Handler) http.Handl
|
|||
httputil.ErrorResponse(w, r, httputil.Error("couldn't parse form", http.StatusBadRequest, err))
|
||||
return
|
||||
}
|
||||
redirectURI, err := url.Parse(r.Form.Get("redirect_uri"))
|
||||
redirectURI, err := urlutil.ParseAndValidateURL(r.Form.Get("redirect_uri"))
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, httputil.Error("bad redirect_uri", http.StatusBadRequest, err))
|
||||
return
|
||||
|
@ -131,7 +132,7 @@ func ValidateHost(validHost func(host string) bool) func(next http.Handler) http
|
|||
defer span.End()
|
||||
|
||||
if !validHost(r.Host) {
|
||||
httputil.ErrorResponse(w, r, httputil.Error(fmt.Sprintf("%s is not a known route.", r.Host), http.StatusNotFound, nil))
|
||||
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusNotFound, nil))
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
|
@ -168,7 +169,7 @@ func ValidSignature(redirectURI, sigVal, timestamp, secret string) bool {
|
|||
if redirectURI == "" || sigVal == "" || timestamp == "" || secret == "" {
|
||||
return false
|
||||
}
|
||||
_, err := url.Parse(redirectURI)
|
||||
_, err := urlutil.ParseAndValidateURL(redirectURI)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
@ -40,7 +41,7 @@ func TestSignRequest(t *testing.T) {
|
|||
|
||||
})
|
||||
rr := httptest.NewRecorder()
|
||||
signer, err := cryptutil.NewES256Signer([]byte(exampleKey), "audience")
|
||||
signer, err := cryptutil.NewES256Signer(base64.StdEncoding.EncodeToString([]byte(exampleKey)), "audience")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
|
@ -11,15 +10,17 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
)
|
||||
|
||||
// ErrInvalidSession is an error for invalid sessions.
|
||||
var ErrInvalidSession = errors.New("internal/sessions: invalid session")
|
||||
|
||||
// ChunkedCanaryByte is the byte value used as a canary prefix to distinguish if
|
||||
// the cookie is multi-part or not. This constant *should not* be valid
|
||||
// base64. It's important this byte is ASCII to avoid UTF-8 variable sized runes.
|
||||
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie#Directives
|
||||
const ChunkedCanaryByte byte = '%'
|
||||
|
||||
// DefaultBearerTokenHeader is default header name for the authorization bearer
|
||||
// token header as defined in rfc2617
|
||||
// https://tools.ietf.org/html/rfc6750#section-2.1
|
||||
const DefaultBearerTokenHeader = "Authorization"
|
||||
|
||||
// MaxChunkSize sets the upper bound on a cookie chunks payload value.
|
||||
// Note, this should be lower than the actual cookie's max size (4096 bytes)
|
||||
// which includes metadata.
|
||||
|
@ -29,39 +30,27 @@ const MaxChunkSize = 3800
|
|||
// set to prevent any abuse.
|
||||
const MaxNumChunks = 5
|
||||
|
||||
// CSRFStore has the functions for setting, getting, and clearing the CSRF cookie
|
||||
type CSRFStore interface {
|
||||
SetCSRF(http.ResponseWriter, *http.Request, string)
|
||||
GetCSRF(*http.Request) (*http.Cookie, error)
|
||||
ClearCSRF(http.ResponseWriter, *http.Request)
|
||||
}
|
||||
|
||||
// SessionStore has the functions for setting, getting, and clearing the Session cookie
|
||||
type SessionStore interface {
|
||||
ClearSession(http.ResponseWriter, *http.Request)
|
||||
LoadSession(*http.Request) (*SessionState, error)
|
||||
SaveSession(http.ResponseWriter, *http.Request, *SessionState) error
|
||||
}
|
||||
|
||||
// CookieStore represents all the cookie related configurations
|
||||
type CookieStore struct {
|
||||
Name string
|
||||
CookieCipher cryptutil.Cipher
|
||||
CookieExpire time.Duration
|
||||
CookieRefresh time.Duration
|
||||
CookieSecure bool
|
||||
CookieHTTPOnly bool
|
||||
CookieDomain string
|
||||
Name string
|
||||
CookieCipher cryptutil.Cipher
|
||||
CookieExpire time.Duration
|
||||
CookieRefresh time.Duration
|
||||
CookieSecure bool
|
||||
CookieHTTPOnly bool
|
||||
CookieDomain string
|
||||
BearerTokenHeader string
|
||||
}
|
||||
|
||||
// CookieStoreOptions holds options for CookieStore
|
||||
type CookieStoreOptions struct {
|
||||
Name string
|
||||
CookieSecure bool
|
||||
CookieHTTPOnly bool
|
||||
CookieDomain string
|
||||
CookieExpire time.Duration
|
||||
CookieCipher cryptutil.Cipher
|
||||
Name string
|
||||
CookieSecure bool
|
||||
CookieHTTPOnly bool
|
||||
CookieDomain string
|
||||
BearerTokenHeader string
|
||||
CookieExpire time.Duration
|
||||
CookieCipher cryptutil.Cipher
|
||||
}
|
||||
|
||||
// NewCookieStore returns a new session with ciphers for each of the cookie secrets
|
||||
|
@ -72,23 +61,28 @@ func NewCookieStore(opts *CookieStoreOptions) (*CookieStore, error) {
|
|||
if opts.CookieCipher == nil {
|
||||
return nil, fmt.Errorf("internal/sessions: cipher cannot be nil")
|
||||
}
|
||||
if opts.BearerTokenHeader == "" {
|
||||
opts.BearerTokenHeader = DefaultBearerTokenHeader
|
||||
}
|
||||
|
||||
return &CookieStore{
|
||||
Name: opts.Name,
|
||||
CookieSecure: opts.CookieSecure,
|
||||
CookieHTTPOnly: opts.CookieHTTPOnly,
|
||||
CookieDomain: opts.CookieDomain,
|
||||
CookieExpire: opts.CookieExpire,
|
||||
CookieCipher: opts.CookieCipher,
|
||||
Name: opts.Name,
|
||||
CookieSecure: opts.CookieSecure,
|
||||
CookieHTTPOnly: opts.CookieHTTPOnly,
|
||||
CookieDomain: opts.CookieDomain,
|
||||
CookieExpire: opts.CookieExpire,
|
||||
CookieCipher: opts.CookieCipher,
|
||||
BearerTokenHeader: opts.BearerTokenHeader,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *CookieStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||
func (cs *CookieStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||
domain := req.Host
|
||||
|
||||
if name == s.csrfName() {
|
||||
if name == cs.csrfName() {
|
||||
domain = req.Host
|
||||
} else if s.CookieDomain != "" {
|
||||
domain = s.CookieDomain
|
||||
} else if cs.CookieDomain != "" {
|
||||
domain = cs.CookieDomain
|
||||
} else {
|
||||
domain = splitDomain(domain)
|
||||
}
|
||||
|
@ -101,8 +95,8 @@ func (s *CookieStore) makeCookie(req *http.Request, name string, value string, e
|
|||
Value: value,
|
||||
Path: "/",
|
||||
Domain: domain,
|
||||
HttpOnly: s.CookieHTTPOnly,
|
||||
Secure: s.CookieSecure,
|
||||
HttpOnly: cs.CookieHTTPOnly,
|
||||
Secure: cs.CookieSecure,
|
||||
}
|
||||
// only set an expiration if we want one, otherwise default to non perm session based
|
||||
if expiration != 0 {
|
||||
|
@ -111,22 +105,20 @@ func (s *CookieStore) makeCookie(req *http.Request, name string, value string, e
|
|||
return c
|
||||
}
|
||||
|
||||
func (s *CookieStore) csrfName() string {
|
||||
return fmt.Sprintf("%s_csrf", s.Name)
|
||||
func (cs *CookieStore) csrfName() string {
|
||||
return fmt.Sprintf("%s_csrf", cs.Name)
|
||||
}
|
||||
|
||||
// makeSessionCookie constructs a session cookie given the request, an expiration time and the current time.
|
||||
func (s *CookieStore) makeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||
return s.makeCookie(req, s.Name, value, expiration, now)
|
||||
func (cs *CookieStore) makeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||
return cs.makeCookie(req, cs.Name, value, expiration, now)
|
||||
}
|
||||
|
||||
// makeCSRFCookie creates a CSRF cookie given the request, an expiration time, and the current time.
|
||||
// CSRF cookies should be scoped to the actual domain
|
||||
func (s *CookieStore) makeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||
return s.makeCookie(req, s.csrfName(), value, expiration, now)
|
||||
func (cs *CookieStore) makeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||
return cs.makeCookie(req, cs.csrfName(), value, expiration, now)
|
||||
}
|
||||
|
||||
func (s *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) {
|
||||
func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) {
|
||||
if len(cookie.String()) <= MaxChunkSize {
|
||||
http.SetCookie(w, cookie)
|
||||
return
|
||||
|
@ -142,9 +134,9 @@ func (s *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) {
|
|||
nc.Name = fmt.Sprintf("%s_%d", cookie.Name, i)
|
||||
nc.Value = c
|
||||
}
|
||||
fmt.Println(i)
|
||||
http.SetCookie(w, &nc)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func chunk(s string, size int) []string {
|
||||
|
@ -159,43 +151,54 @@ func chunk(s string, size int) []string {
|
|||
}
|
||||
|
||||
// ClearCSRF clears the CSRF cookie from the request
|
||||
func (s *CookieStore) ClearCSRF(w http.ResponseWriter, req *http.Request) {
|
||||
http.SetCookie(w, s.makeCSRFCookie(req, "", time.Hour*-1, time.Now()))
|
||||
func (cs *CookieStore) ClearCSRF(w http.ResponseWriter, req *http.Request) {
|
||||
http.SetCookie(w, cs.makeCSRFCookie(req, "", time.Hour*-1, time.Now()))
|
||||
}
|
||||
|
||||
// SetCSRF sets the CSRFCookie creates a CSRF cookie in a given request
|
||||
func (s *CookieStore) SetCSRF(w http.ResponseWriter, req *http.Request, val string) {
|
||||
http.SetCookie(w, s.makeCSRFCookie(req, val, s.CookieExpire, time.Now()))
|
||||
func (cs *CookieStore) SetCSRF(w http.ResponseWriter, req *http.Request, val string) {
|
||||
http.SetCookie(w, cs.makeCSRFCookie(req, val, cs.CookieExpire, time.Now()))
|
||||
}
|
||||
|
||||
// GetCSRF gets the CSRFCookie creates a CSRF cookie in a given request
|
||||
func (s *CookieStore) GetCSRF(req *http.Request) (*http.Cookie, error) {
|
||||
return req.Cookie(s.csrfName())
|
||||
func (cs *CookieStore) GetCSRF(req *http.Request) (*http.Cookie, error) {
|
||||
c, err := req.Cookie(cs.csrfName())
|
||||
if err != nil {
|
||||
return nil, ErrEmptyCSRF // ErrNoCookie is confusing in this context
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// ClearSession clears the session cookie from a request
|
||||
func (s *CookieStore) ClearSession(w http.ResponseWriter, req *http.Request) {
|
||||
http.SetCookie(w, s.makeSessionCookie(req, "", time.Hour*-1, time.Now()))
|
||||
func (cs *CookieStore) ClearSession(w http.ResponseWriter, req *http.Request) {
|
||||
http.SetCookie(w, cs.makeCookie(req, cs.Name, "", time.Hour*-1, time.Now()))
|
||||
}
|
||||
|
||||
func (s *CookieStore) setSessionCookie(w http.ResponseWriter, req *http.Request, val string) {
|
||||
s.setCookie(w, s.makeSessionCookie(req, val, s.CookieExpire, time.Now()))
|
||||
func (cs *CookieStore) setSessionCookie(w http.ResponseWriter, req *http.Request, val string) {
|
||||
cs.setCookie(w, cs.makeSessionCookie(req, val, cs.CookieExpire, time.Now()))
|
||||
}
|
||||
|
||||
// LoadSession returns a SessionState from the cookie in the request.
|
||||
func (s *CookieStore) LoadSession(req *http.Request) (*SessionState, error) {
|
||||
c, err := req.Cookie(s.Name)
|
||||
func loadBearerToken(r *http.Request, headerKey string) string {
|
||||
authHeader := r.Header.Get(headerKey)
|
||||
split := strings.Split(authHeader, "Bearer")
|
||||
if authHeader == "" || len(split) != 2 {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(split[1])
|
||||
}
|
||||
|
||||
func loadChunkedCookie(r *http.Request, cookieName string) string {
|
||||
c, err := r.Cookie(cookieName)
|
||||
if err != nil {
|
||||
return nil, err // http.ErrNoCookie
|
||||
return ""
|
||||
}
|
||||
cipherText := c.Value
|
||||
|
||||
// if the first byte is our canary byte, we need to handle the multipart bit
|
||||
if []byte(c.Value)[0] == ChunkedCanaryByte {
|
||||
var b strings.Builder
|
||||
fmt.Fprintf(&b, "%s", cipherText[1:])
|
||||
for i := 1; i < MaxNumChunks; i++ {
|
||||
next, err := req.Cookie(fmt.Sprintf("%s_%d", s.Name, i))
|
||||
for i := 1; i <= MaxNumChunks; i++ {
|
||||
next, err := r.Cookie(fmt.Sprintf("%s_%d", cookieName, i))
|
||||
if err != nil {
|
||||
break // break if we can't find the next cookie
|
||||
}
|
||||
|
@ -203,20 +206,32 @@ func (s *CookieStore) LoadSession(req *http.Request) (*SessionState, error) {
|
|||
}
|
||||
cipherText = b.String()
|
||||
}
|
||||
session, err := UnmarshalSession(cipherText, s.CookieCipher)
|
||||
return cipherText
|
||||
}
|
||||
|
||||
// LoadSession returns a State from the cookie in the request.
|
||||
func (cs *CookieStore) LoadSession(req *http.Request) (*State, error) {
|
||||
cipherText := loadChunkedCookie(req, cs.Name)
|
||||
if cipherText == "" {
|
||||
cipherText = loadBearerToken(req, cs.BearerTokenHeader)
|
||||
}
|
||||
if cipherText == "" {
|
||||
return nil, ErrEmptySession
|
||||
}
|
||||
session, err := UnmarshalSession(cipherText, cs.CookieCipher)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidSession
|
||||
return nil, err
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// SaveSession saves a session state to a request sessions.
|
||||
func (s *CookieStore) SaveSession(w http.ResponseWriter, req *http.Request, sessionState *SessionState) error {
|
||||
value, err := MarshalSession(sessionState, s.CookieCipher)
|
||||
func (cs *CookieStore) SaveSession(w http.ResponseWriter, req *http.Request, s *State) error {
|
||||
value, err := MarshalSession(s, cs.CookieCipher)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.setSessionCookie(w, req, value)
|
||||
cs.setSessionCookie(w, req, value)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
)
|
||||
|
||||
|
@ -49,30 +50,33 @@ func TestNewCookieStore(t *testing.T) {
|
|||
}{
|
||||
{"good",
|
||||
&CookieStoreOptions{
|
||||
Name: "_cookie",
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
CookieDomain: "pomerium.io",
|
||||
CookieExpire: 10 * time.Second,
|
||||
CookieCipher: cipher,
|
||||
Name: "_cookie",
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
CookieDomain: "pomerium.io",
|
||||
CookieExpire: 10 * time.Second,
|
||||
CookieCipher: cipher,
|
||||
BearerTokenHeader: "Authorization",
|
||||
},
|
||||
&CookieStore{
|
||||
Name: "_cookie",
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
CookieDomain: "pomerium.io",
|
||||
CookieExpire: 10 * time.Second,
|
||||
CookieCipher: cipher,
|
||||
Name: "_cookie",
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
CookieDomain: "pomerium.io",
|
||||
CookieExpire: 10 * time.Second,
|
||||
CookieCipher: cipher,
|
||||
BearerTokenHeader: "Authorization",
|
||||
},
|
||||
false},
|
||||
{"missing name",
|
||||
&CookieStoreOptions{
|
||||
Name: "",
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
CookieDomain: "pomerium.io",
|
||||
CookieExpire: 10 * time.Second,
|
||||
CookieCipher: cipher,
|
||||
Name: "",
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
CookieDomain: "pomerium.io",
|
||||
CookieExpire: 10 * time.Second,
|
||||
CookieCipher: cipher,
|
||||
BearerTokenHeader: "Authorization",
|
||||
},
|
||||
nil,
|
||||
true},
|
||||
|
@ -95,8 +99,12 @@ func TestNewCookieStore(t *testing.T) {
|
|||
t.Errorf("NewCookieStore() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("NewCookieStore() = %#v, want %#v", got, tt.want)
|
||||
cmpOpts := []cmp.Option{
|
||||
cmpopts.IgnoreUnexported(cryptutil.XChaCha20Cipher{}),
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(got, tt.want, cmpOpts...); diff != "" {
|
||||
t.Errorf("NewCookieStore() = %s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -211,15 +219,15 @@ func TestCookieStore_SaveSession(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
sessionState *SessionState
|
||||
cipher cryptutil.Cipher
|
||||
wantErr bool
|
||||
wantLoadErr bool
|
||||
name string
|
||||
State *State
|
||||
cipher cryptutil.Cipher
|
||||
wantErr bool
|
||||
wantLoadErr bool
|
||||
}{
|
||||
{"good", &SessionState{AccessToken: "token1234", RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, cipher, false, false},
|
||||
{"bad cipher", &SessionState{AccessToken: "token1234", RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, mockCipher{}, true, true},
|
||||
{"huge cookie", &SessionState{AccessToken: fmt.Sprintf("%x", hugeString), RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, cipher, false, false},
|
||||
{"good", &State{AccessToken: "token1234", RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, cipher, false, false},
|
||||
{"bad cipher", &State{AccessToken: "token1234", RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, mockCipher{}, true, true},
|
||||
{"huge cookie", &State{AccessToken: fmt.Sprintf("%x", hugeString), RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, cipher, false, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -234,12 +242,12 @@ func TestCookieStore_SaveSession(t *testing.T) {
|
|||
r := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
if err := s.SaveSession(w, r, tt.sessionState); (err != nil) != tt.wantErr {
|
||||
if err := s.SaveSession(w, r, tt.State); (err != nil) != tt.wantErr {
|
||||
t.Errorf("CookieStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
r = httptest.NewRequest("GET", "/", nil)
|
||||
for _, cookie := range w.Result().Cookies() {
|
||||
t.Log(cookie)
|
||||
// t.Log(cookie)
|
||||
r.AddCookie(cookie)
|
||||
}
|
||||
|
||||
|
@ -248,8 +256,10 @@ func TestCookieStore_SaveSession(t *testing.T) {
|
|||
t.Errorf("LoadSession() error = %v, wantErr %v", err, tt.wantLoadErr)
|
||||
return
|
||||
}
|
||||
if err == nil && !reflect.DeepEqual(state, tt.sessionState) {
|
||||
t.Errorf("CookieStore.LoadSession() got = \n%v, want \n%v", state, tt.sessionState)
|
||||
if err == nil {
|
||||
if diff := cmp.Diff(state, tt.State); diff != "" {
|
||||
t.Errorf("CookieStore.LoadSession() got = %s", diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -291,18 +301,18 @@ func TestMockSessionStore(t *testing.T) {
|
|||
tests := []struct {
|
||||
name string
|
||||
mockCSRF *MockSessionStore
|
||||
saveSession *SessionState
|
||||
saveSession *State
|
||||
wantLoadErr bool
|
||||
wantSaveErr bool
|
||||
}{
|
||||
{"basic",
|
||||
&MockSessionStore{
|
||||
ResponseSession: "test",
|
||||
Session: &SessionState{AccessToken: "AccessToken"},
|
||||
Session: &State{AccessToken: "AccessToken"},
|
||||
SaveError: nil,
|
||||
LoadError: nil,
|
||||
},
|
||||
&SessionState{AccessToken: "AccessToken"},
|
||||
&State{AccessToken: "AccessToken"},
|
||||
false,
|
||||
false},
|
||||
}
|
||||
|
|
|
@ -29,7 +29,7 @@ func (ms MockCSRFStore) GetCSRF(*http.Request) (*http.Cookie, error) {
|
|||
// MockSessionStore is a mock implementation of the SessionStore interface
|
||||
type MockSessionStore struct {
|
||||
ResponseSession string
|
||||
Session *SessionState
|
||||
Session *State
|
||||
SaveError error
|
||||
LoadError error
|
||||
}
|
||||
|
@ -40,11 +40,11 @@ func (ms *MockSessionStore) ClearSession(http.ResponseWriter, *http.Request) {
|
|||
}
|
||||
|
||||
// LoadSession returns the session and a error
|
||||
func (ms MockSessionStore) LoadSession(*http.Request) (*SessionState, error) {
|
||||
func (ms MockSessionStore) LoadSession(*http.Request) (*State, error) {
|
||||
return ms.Session, ms.LoadError
|
||||
}
|
||||
|
||||
// SaveSession returns a save error.
|
||||
func (ms MockSessionStore) SaveSession(http.ResponseWriter, *http.Request, *SessionState) error {
|
||||
func (ms MockSessionStore) SaveSession(http.ResponseWriter, *http.Request, *State) error {
|
||||
return ms.SaveError
|
||||
}
|
||||
|
|
|
@ -1,106 +0,0 @@
|
|||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
)
|
||||
|
||||
// DefaultBearerTokenHeader is default header name for the authorization bearer
|
||||
// token header as defined in rfc2617
|
||||
// https://tools.ietf.org/html/rfc6750#section-2.1
|
||||
const DefaultBearerTokenHeader = "Authorization"
|
||||
|
||||
// RestStore is a session store suitable for REST
|
||||
type RestStore struct {
|
||||
Name string
|
||||
Cipher cryptutil.Cipher
|
||||
// Expire time.Duration
|
||||
}
|
||||
|
||||
// RestStoreOptions contains the options required to build a new RestStore.
|
||||
type RestStoreOptions struct {
|
||||
Name string
|
||||
Cipher cryptutil.Cipher
|
||||
// Expire time.Duration
|
||||
}
|
||||
|
||||
// NewRestStore creates a new RestStore from a set of RestStoreOptions.
|
||||
func NewRestStore(opts *RestStoreOptions) (*RestStore, error) {
|
||||
if opts.Name == "" {
|
||||
opts.Name = DefaultBearerTokenHeader
|
||||
}
|
||||
if opts.Cipher == nil {
|
||||
return nil, fmt.Errorf("internal/sessions: cipher cannot be nil")
|
||||
}
|
||||
return &RestStore{
|
||||
Name: opts.Name,
|
||||
// Expire: opts.Expire,
|
||||
Cipher: opts.Cipher,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ClearSession functions differently because REST is stateless, we instead
|
||||
// inform the client that this token is no longer valid.
|
||||
// https://tools.ietf.org/html/rfc6750
|
||||
func (s *RestStore) ClearSession(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
errMsg := `
|
||||
{
|
||||
"error": "invalid_token",
|
||||
"token_type": "Bearer",
|
||||
"error_description": "The token has expired."
|
||||
}`
|
||||
w.Write([]byte(errMsg))
|
||||
}
|
||||
|
||||
// LoadSession attempts to load a pomerium session from a Bearer Token set
|
||||
// in the authorization header.
|
||||
func (s *RestStore) LoadSession(r *http.Request) (*SessionState, error) {
|
||||
authHeader := r.Header.Get(s.Name)
|
||||
split := strings.Split(authHeader, "Bearer")
|
||||
if authHeader == "" || len(split) != 2 {
|
||||
return nil, errors.New("internal/sessions: no bearer token header found")
|
||||
}
|
||||
token := strings.TrimSpace(split[1])
|
||||
session, err := UnmarshalSession(token, s.Cipher)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// RestStoreResponse is the JSON struct returned to the client.
|
||||
type RestStoreResponse struct {
|
||||
// Token is the encrypted pomerium session that can be used to
|
||||
// programmatically authenticate with pomerium.
|
||||
Token string
|
||||
// In addition to the token, non-sensitive meta data is returned to help
|
||||
// the client manage token renewals.
|
||||
Expiry time.Time
|
||||
}
|
||||
|
||||
// SaveSession returns an encrypted pomerium session as a JSON object with
|
||||
// associated, non sensitive meta-data like
|
||||
func (s *RestStore) SaveSession(w http.ResponseWriter, r *http.Request, sessionState *SessionState) error {
|
||||
encToken, err := MarshalSession(sessionState, s.Cipher)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
jsonBytes, err := json.Marshal(
|
||||
&RestStoreResponse{
|
||||
Token: encToken,
|
||||
Expiry: sessionState.RefreshDeadline,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("internal/sessions: couldn't marshal token struct: %v", err)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write(jsonBytes)
|
||||
return nil
|
||||
}
|
|
@ -1,135 +0,0 @@
|
|||
package sessions
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
)
|
||||
|
||||
func TestRestStore_SaveSession(t *testing.T) {
|
||||
now := time.Date(2008, 1, 8, 17, 5, 5, 0, time.UTC)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
optionsName string
|
||||
optionsCipher cryptutil.Cipher
|
||||
sessionState *SessionState
|
||||
wantErr bool
|
||||
wantSaveResponse string
|
||||
}{
|
||||
{"good", "Authenticate", &cryptutil.MockCipher{MarshalResponse: "test"}, &SessionState{RefreshDeadline: now}, false, `{"Token":"test","Expiry":"2008-01-08T17:05:05Z"}`},
|
||||
{"bad session marshal", "Authenticate", &cryptutil.MockCipher{MarshalError: errors.New("error")}, &SessionState{RefreshDeadline: now}, true, ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s, err := NewRestStore(
|
||||
&RestStoreOptions{
|
||||
Name: tt.optionsName,
|
||||
Cipher: tt.optionsCipher,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewRestStore err %v", err)
|
||||
}
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
if err := s.SaveSession(w, r, tt.sessionState); (err != nil) != tt.wantErr {
|
||||
t.Errorf("RestStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
resp := w.Result()
|
||||
body, _ := ioutil.ReadAll(resp.Body)
|
||||
if diff := cmp.Diff(string(body), tt.wantSaveResponse); diff != "" {
|
||||
t.Errorf("RestStore.SaveSession() got / want diff \n%s\n", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRestStore(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
optionsName string
|
||||
optionsCipher cryptutil.Cipher
|
||||
wantErr bool
|
||||
}{
|
||||
{"good", "Authenticate", &cryptutil.MockCipher{}, false},
|
||||
{"good default to authenticate", "", &cryptutil.MockCipher{}, false},
|
||||
{"empty cipher", "Authenticate", nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := NewRestStore(
|
||||
&RestStoreOptions{
|
||||
Name: tt.optionsName,
|
||||
Cipher: tt.optionsCipher,
|
||||
})
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewRestStore() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRestStore_ClearSession(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
expectedStatus int
|
||||
}{
|
||||
{"always returns reset!", http.StatusUnauthorized},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &RestStore{Name: "Authenticate", Cipher: &cryptutil.MockCipher{}}
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
s.ClearSession(w, r)
|
||||
resp := w.Result()
|
||||
if diff := cmp.Diff(resp.StatusCode, tt.expectedStatus); diff != "" {
|
||||
t.Errorf("RestStore.ClearSession() got / want diff \n%s\n", diff)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRestStore_LoadSession(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
optionsName string
|
||||
optionsCipher cryptutil.Cipher
|
||||
token string
|
||||
wantErr bool
|
||||
}{
|
||||
{"good", "Authorization", &cryptutil.MockCipher{}, "test", false},
|
||||
{"empty auth header", "", &cryptutil.MockCipher{}, "", true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &RestStore{
|
||||
Name: tt.optionsName,
|
||||
Cipher: tt.optionsCipher,
|
||||
}
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
if tt.optionsName != "" {
|
||||
r.Header.Set(tt.optionsName, fmt.Sprintf(("Bearer %s"), tt.token))
|
||||
|
||||
}
|
||||
_, err := s.LoadSession(r)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("RestStore.LoadSession() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -3,7 +3,6 @@ package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
|||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -11,13 +10,11 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrLifetimeExpired is an error for the lifetime deadline expiring
|
||||
ErrLifetimeExpired = errors.New("user lifetime expired")
|
||||
)
|
||||
// ErrExpired is an error for a expired sessions.
|
||||
var ErrExpired = fmt.Errorf("internal/sessions: expired session")
|
||||
|
||||
// SessionState is our object that keeps track of a user's session state
|
||||
type SessionState struct {
|
||||
// State is our object that keeps track of a user's session state
|
||||
type State struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
|
@ -31,18 +28,31 @@ type SessionState struct {
|
|||
ImpersonateGroups []string
|
||||
}
|
||||
|
||||
// RefreshPeriodExpired returns true if the refresh period has expired
|
||||
func (s *SessionState) RefreshPeriodExpired() bool {
|
||||
return isExpired(s.RefreshDeadline)
|
||||
// Valid returns an error if the users's session state is not valid.
|
||||
func (s *State) Valid() error {
|
||||
if s.Expired() {
|
||||
return ErrExpired
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ForceRefresh sets the refresh deadline to now.
|
||||
func (s *State) ForceRefresh() {
|
||||
s.RefreshDeadline = time.Now().Truncate(time.Second)
|
||||
}
|
||||
|
||||
// Expired returns true if the refresh period has expired
|
||||
func (s *State) Expired() bool {
|
||||
return s.RefreshDeadline.Before(time.Now())
|
||||
}
|
||||
|
||||
// Impersonating returns if the request is impersonating.
|
||||
func (s *SessionState) Impersonating() bool {
|
||||
func (s *State) Impersonating() bool {
|
||||
return s.ImpersonateEmail != "" || len(s.ImpersonateGroups) != 0
|
||||
}
|
||||
|
||||
// RequestEmail is the email to make the request as.
|
||||
func (s *SessionState) RequestEmail() string {
|
||||
func (s *State) RequestEmail() string {
|
||||
if s.ImpersonateEmail != "" {
|
||||
return s.ImpersonateEmail
|
||||
}
|
||||
|
@ -51,7 +61,7 @@ func (s *SessionState) RequestEmail() string {
|
|||
|
||||
// RequestGroups returns the groups of the Groups making the request; uses
|
||||
// impersonating user if set.
|
||||
func (s *SessionState) RequestGroups() string {
|
||||
func (s *State) RequestGroups() string {
|
||||
if len(s.ImpersonateGroups) != 0 {
|
||||
return strings.Join(s.ImpersonateGroups, ",")
|
||||
}
|
||||
|
@ -68,7 +78,7 @@ type idToken struct {
|
|||
}
|
||||
|
||||
// IssuedAt parses the IDToken's issue date and returns a valid go time.Time.
|
||||
func (s *SessionState) IssuedAt() (time.Time, error) {
|
||||
func (s *State) IssuedAt() (time.Time, error) {
|
||||
payload, err := parseJWT(s.IDToken)
|
||||
if err != nil {
|
||||
return time.Time{}, fmt.Errorf("internal/sessions: malformed jwt: %v", err)
|
||||
|
@ -80,13 +90,9 @@ func (s *SessionState) IssuedAt() (time.Time, error) {
|
|||
return time.Time(token.IssuedAt), nil
|
||||
}
|
||||
|
||||
func isExpired(t time.Time) bool {
|
||||
return t.Before(time.Now())
|
||||
}
|
||||
|
||||
// MarshalSession marshals the session state as JSON, encrypts the JSON using the
|
||||
// given cipher, and base64-encodes the result
|
||||
func MarshalSession(s *SessionState, c cryptutil.Cipher) (string, error) {
|
||||
func MarshalSession(s *State, c cryptutil.Cipher) (string, error) {
|
||||
v, err := c.Marshal(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
@ -96,8 +102,8 @@ func MarshalSession(s *SessionState, c cryptutil.Cipher) (string, error) {
|
|||
|
||||
// UnmarshalSession takes the marshaled string, base64-decodes into a byte slice, decrypts the
|
||||
// byte slice using the passed cipher, and unmarshals the resulting JSON into a session state struct
|
||||
func UnmarshalSession(value string, c cryptutil.Cipher) (*SessionState, error) {
|
||||
s := &SessionState{}
|
||||
func UnmarshalSession(value string, c cryptutil.Cipher) (*State, error) {
|
||||
s := &State{}
|
||||
err := c.Unmarshal(value, s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -105,11 +111,6 @@ func UnmarshalSession(value string, c cryptutil.Cipher) (*SessionState, error) {
|
|||
return s, nil
|
||||
}
|
||||
|
||||
// ExtendDeadline returns the time extended by a given duration, truncated by second
|
||||
func ExtendDeadline(ttl time.Duration) time.Time {
|
||||
return time.Now().Add(ttl).Truncate(time.Second)
|
||||
}
|
||||
|
||||
func parseJWT(p string) ([]byte, error) {
|
||||
parts := strings.Split(p, ".")
|
||||
if len(parts) < 2 {
|
|
@ -11,14 +11,14 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
)
|
||||
|
||||
func TestSessionStateSerialization(t *testing.T) {
|
||||
func TestStateSerialization(t *testing.T) {
|
||||
secret := cryptutil.GenerateKey()
|
||||
c, err := cryptutil.NewCipher(secret)
|
||||
if err != nil {
|
||||
t.Fatalf("expected to be able to create cipher: %v", err)
|
||||
}
|
||||
|
||||
want := &SessionState{
|
||||
want := &State{
|
||||
AccessToken: "token1234",
|
||||
RefreshToken: "refresh4321",
|
||||
RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
|
||||
|
@ -43,41 +43,21 @@ func TestSessionStateSerialization(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestSessionStateExpirations(t *testing.T) {
|
||||
session := &SessionState{
|
||||
func TestStateExpirations(t *testing.T) {
|
||||
session := &State{
|
||||
AccessToken: "token1234",
|
||||
RefreshToken: "refresh4321",
|
||||
RefreshDeadline: time.Now().Add(-1 * time.Hour),
|
||||
Email: "user@domain.com",
|
||||
User: "user",
|
||||
}
|
||||
if !session.RefreshPeriodExpired() {
|
||||
if !session.Expired() {
|
||||
t.Errorf("expected lifetime period to be expired")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestExtendDeadline(t *testing.T) {
|
||||
// tons of wiggle room here
|
||||
now := time.Now().Truncate(time.Second)
|
||||
tests := []struct {
|
||||
name string
|
||||
ttl time.Duration
|
||||
want time.Time
|
||||
}{
|
||||
{"Add a few ms", time.Millisecond * 10, now.Truncate(time.Second)},
|
||||
{"Add a few microsecs", time.Microsecond * 10, now.Truncate(time.Second)},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := ExtendDeadline(tt.ttl); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("ExtendDeadline() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionState_IssuedAt(t *testing.T) {
|
||||
func TestState_IssuedAt(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -91,20 +71,20 @@ func TestSessionState_IssuedAt(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &SessionState{IDToken: tt.IDToken}
|
||||
s := &State{IDToken: tt.IDToken}
|
||||
got, err := s.IssuedAt()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("SessionState.IssuedAt() error = %v, wantErr %v", err, tt.wantErr)
|
||||
t.Errorf("State.IssuedAt() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("SessionState.IssuedAt() = %v, want %v", got.Format(time.RFC3339), tt.want.Format(time.RFC3339))
|
||||
t.Errorf("State.IssuedAt() = %v, want %v", got.Format(time.RFC3339), tt.want.Format(time.RFC3339))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionState_Impersonating(t *testing.T) {
|
||||
func TestState_Impersonating(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -123,20 +103,20 @@ func TestSessionState_Impersonating(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &SessionState{
|
||||
s := &State{
|
||||
Email: tt.Email,
|
||||
Groups: tt.Groups,
|
||||
ImpersonateEmail: tt.ImpersonateEmail,
|
||||
ImpersonateGroups: tt.ImpersonateGroups,
|
||||
}
|
||||
if got := s.Impersonating(); got != tt.want {
|
||||
t.Errorf("SessionState.Impersonating() = %v, want %v", got, tt.want)
|
||||
t.Errorf("State.Impersonating() = %v, want %v", got, tt.want)
|
||||
}
|
||||
if gotEmail := s.RequestEmail(); gotEmail != tt.wantResponseEmail {
|
||||
t.Errorf("SessionState.RequestEmail() = %v, want %v", gotEmail, tt.wantResponseEmail)
|
||||
t.Errorf("State.RequestEmail() = %v, want %v", gotEmail, tt.wantResponseEmail)
|
||||
}
|
||||
if gotGroups := s.RequestGroups(); gotGroups != tt.wantResponseGroups {
|
||||
t.Errorf("SessionState.v() = %v, want %v", gotGroups, tt.wantResponseGroups)
|
||||
t.Errorf("State.v() = %v, want %v", gotGroups, tt.wantResponseGroups)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -154,11 +134,11 @@ func TestMarshalSession(t *testing.T) {
|
|||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
s *SessionState
|
||||
s *State
|
||||
wantErr bool
|
||||
}{
|
||||
{"simple", &SessionState{}, false},
|
||||
{"too big", &SessionState{AccessToken: fmt.Sprintf("%x", hugeString)}, false},
|
||||
{"simple", &State{}, false},
|
||||
{"too big", &State{AccessToken: fmt.Sprintf("%x", hugeString)}, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -179,3 +159,45 @@ func TestMarshalSession(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestState_Valid(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
RefreshDeadline time.Time
|
||||
wantErr bool
|
||||
}{
|
||||
{" good", time.Now().Add(10 * time.Second), false},
|
||||
{" expired", time.Now().Add(-10 * time.Second), true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &State{
|
||||
RefreshDeadline: tt.RefreshDeadline,
|
||||
}
|
||||
if err := s.Valid(); (err != nil) != tt.wantErr {
|
||||
t.Errorf("State.Valid() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestState_ForceRefresh(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
RefreshDeadline time.Time
|
||||
}{
|
||||
{"good", time.Now().Truncate(time.Second)},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &State{
|
||||
RefreshDeadline: tt.RefreshDeadline,
|
||||
}
|
||||
s.ForceRefresh()
|
||||
if s.RefreshDeadline != tt.RefreshDeadline {
|
||||
t.Errorf("refresh deadline not updated")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
26
internal/sessions/store.go
Normal file
26
internal/sessions/store.go
Normal file
|
@ -0,0 +1,26 @@
|
|||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// ErrEmptySession is an error for an empty sessions.
|
||||
var ErrEmptySession = errors.New("internal/sessions: empty session")
|
||||
|
||||
// ErrEmptyCSRF is an error for an empty sessions.
|
||||
var ErrEmptyCSRF = errors.New("internal/sessions: empty csrf")
|
||||
|
||||
// CSRFStore has the functions for setting, getting, and clearing the CSRF cookie
|
||||
type CSRFStore interface {
|
||||
SetCSRF(http.ResponseWriter, *http.Request, string)
|
||||
GetCSRF(*http.Request) (*http.Cookie, error)
|
||||
ClearCSRF(http.ResponseWriter, *http.Request)
|
||||
}
|
||||
|
||||
// SessionStore has the functions for setting, getting, and clearing the Session cookie
|
||||
type SessionStore interface {
|
||||
ClearSession(http.ResponseWriter, *http.Request)
|
||||
LoadSession(*http.Request) (*State, error)
|
||||
SaveSession(http.ResponseWriter, *http.Request, *State) error
|
||||
}
|
|
@ -8,12 +8,12 @@ import (
|
|||
|
||||
// The following tags are applied to stats recorded by this package.
|
||||
var (
|
||||
TagKeyHTTPMethod tag.Key = tag.MustNewKey("http_method")
|
||||
TagKeyService tag.Key = tag.MustNewKey("service")
|
||||
TagKeyGRPCService tag.Key = tag.MustNewKey("grpc_service")
|
||||
TagKeyGRPCMethod tag.Key = tag.MustNewKey("grpc_method")
|
||||
TagKeyHost tag.Key = tag.MustNewKey("host")
|
||||
TagKeyDestination tag.Key = tag.MustNewKey("destination")
|
||||
TagKeyHTTPMethod = tag.MustNewKey("http_method")
|
||||
TagKeyService = tag.MustNewKey("service")
|
||||
TagKeyGRPCService = tag.MustNewKey("grpc_service")
|
||||
TagKeyGRPCMethod = tag.MustNewKey("grpc_method")
|
||||
TagKeyHost = tag.MustNewKey("host")
|
||||
TagKeyDestination = tag.MustNewKey("destination")
|
||||
)
|
||||
|
||||
// Default distributions used by views in this package.
|
||||
|
|
|
@ -1,399 +0,0 @@
|
|||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// source: authenticate.proto
|
||||
|
||||
package authenticate
|
||||
|
||||
import proto "github.com/golang/protobuf/proto"
|
||||
import fmt "fmt"
|
||||
import math "math"
|
||||
import timestamp "github.com/golang/protobuf/ptypes/timestamp"
|
||||
|
||||
import (
|
||||
context "golang.org/x/net/context"
|
||||
grpc "google.golang.org/grpc"
|
||||
)
|
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ = proto.Marshal
|
||||
var _ = fmt.Errorf
|
||||
var _ = math.Inf
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the proto package it is being compiled against.
|
||||
// A compilation error at this line likely means your copy of the
|
||||
// proto package needs to be updated.
|
||||
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
|
||||
|
||||
type AuthenticateRequest struct {
|
||||
Code string `protobuf:"bytes,1,opt,name=code,proto3" json:"code,omitempty"`
|
||||
XXX_NoUnkeyedLiteral struct{} `json:"-"`
|
||||
XXX_unrecognized []byte `json:"-"`
|
||||
XXX_sizecache int32 `json:"-"`
|
||||
}
|
||||
|
||||
func (m *AuthenticateRequest) Reset() { *m = AuthenticateRequest{} }
|
||||
func (m *AuthenticateRequest) String() string { return proto.CompactTextString(m) }
|
||||
func (*AuthenticateRequest) ProtoMessage() {}
|
||||
func (*AuthenticateRequest) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_authenticate_d9796afa57ba1f78, []int{0}
|
||||
}
|
||||
func (m *AuthenticateRequest) XXX_Unmarshal(b []byte) error {
|
||||
return xxx_messageInfo_AuthenticateRequest.Unmarshal(m, b)
|
||||
}
|
||||
func (m *AuthenticateRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
|
||||
return xxx_messageInfo_AuthenticateRequest.Marshal(b, m, deterministic)
|
||||
}
|
||||
func (dst *AuthenticateRequest) XXX_Merge(src proto.Message) {
|
||||
xxx_messageInfo_AuthenticateRequest.Merge(dst, src)
|
||||
}
|
||||
func (m *AuthenticateRequest) XXX_Size() int {
|
||||
return xxx_messageInfo_AuthenticateRequest.Size(m)
|
||||
}
|
||||
func (m *AuthenticateRequest) XXX_DiscardUnknown() {
|
||||
xxx_messageInfo_AuthenticateRequest.DiscardUnknown(m)
|
||||
}
|
||||
|
||||
var xxx_messageInfo_AuthenticateRequest proto.InternalMessageInfo
|
||||
|
||||
func (m *AuthenticateRequest) GetCode() string {
|
||||
if m != nil {
|
||||
return m.Code
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type ValidateRequest struct {
|
||||
IdToken string `protobuf:"bytes,1,opt,name=id_token,json=idToken,proto3" json:"id_token,omitempty"`
|
||||
XXX_NoUnkeyedLiteral struct{} `json:"-"`
|
||||
XXX_unrecognized []byte `json:"-"`
|
||||
XXX_sizecache int32 `json:"-"`
|
||||
}
|
||||
|
||||
func (m *ValidateRequest) Reset() { *m = ValidateRequest{} }
|
||||
func (m *ValidateRequest) String() string { return proto.CompactTextString(m) }
|
||||
func (*ValidateRequest) ProtoMessage() {}
|
||||
func (*ValidateRequest) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_authenticate_d9796afa57ba1f78, []int{1}
|
||||
}
|
||||
func (m *ValidateRequest) XXX_Unmarshal(b []byte) error {
|
||||
return xxx_messageInfo_ValidateRequest.Unmarshal(m, b)
|
||||
}
|
||||
func (m *ValidateRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
|
||||
return xxx_messageInfo_ValidateRequest.Marshal(b, m, deterministic)
|
||||
}
|
||||
func (dst *ValidateRequest) XXX_Merge(src proto.Message) {
|
||||
xxx_messageInfo_ValidateRequest.Merge(dst, src)
|
||||
}
|
||||
func (m *ValidateRequest) XXX_Size() int {
|
||||
return xxx_messageInfo_ValidateRequest.Size(m)
|
||||
}
|
||||
func (m *ValidateRequest) XXX_DiscardUnknown() {
|
||||
xxx_messageInfo_ValidateRequest.DiscardUnknown(m)
|
||||
}
|
||||
|
||||
var xxx_messageInfo_ValidateRequest proto.InternalMessageInfo
|
||||
|
||||
func (m *ValidateRequest) GetIdToken() string {
|
||||
if m != nil {
|
||||
return m.IdToken
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type ValidateReply struct {
|
||||
IsValid bool `protobuf:"varint,1,opt,name=is_valid,json=isValid,proto3" json:"is_valid,omitempty"`
|
||||
XXX_NoUnkeyedLiteral struct{} `json:"-"`
|
||||
XXX_unrecognized []byte `json:"-"`
|
||||
XXX_sizecache int32 `json:"-"`
|
||||
}
|
||||
|
||||
func (m *ValidateReply) Reset() { *m = ValidateReply{} }
|
||||
func (m *ValidateReply) String() string { return proto.CompactTextString(m) }
|
||||
func (*ValidateReply) ProtoMessage() {}
|
||||
func (*ValidateReply) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_authenticate_d9796afa57ba1f78, []int{2}
|
||||
}
|
||||
func (m *ValidateReply) XXX_Unmarshal(b []byte) error {
|
||||
return xxx_messageInfo_ValidateReply.Unmarshal(m, b)
|
||||
}
|
||||
func (m *ValidateReply) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
|
||||
return xxx_messageInfo_ValidateReply.Marshal(b, m, deterministic)
|
||||
}
|
||||
func (dst *ValidateReply) XXX_Merge(src proto.Message) {
|
||||
xxx_messageInfo_ValidateReply.Merge(dst, src)
|
||||
}
|
||||
func (m *ValidateReply) XXX_Size() int {
|
||||
return xxx_messageInfo_ValidateReply.Size(m)
|
||||
}
|
||||
func (m *ValidateReply) XXX_DiscardUnknown() {
|
||||
xxx_messageInfo_ValidateReply.DiscardUnknown(m)
|
||||
}
|
||||
|
||||
var xxx_messageInfo_ValidateReply proto.InternalMessageInfo
|
||||
|
||||
func (m *ValidateReply) GetIsValid() bool {
|
||||
if m != nil {
|
||||
return m.IsValid
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
AccessToken string `protobuf:"bytes,1,opt,name=access_token,json=accessToken,proto3" json:"access_token,omitempty"`
|
||||
RefreshToken string `protobuf:"bytes,2,opt,name=refresh_token,json=refreshToken,proto3" json:"refresh_token,omitempty"`
|
||||
IdToken string `protobuf:"bytes,3,opt,name=id_token,json=idToken,proto3" json:"id_token,omitempty"`
|
||||
User string `protobuf:"bytes,4,opt,name=user,proto3" json:"user,omitempty"`
|
||||
Email string `protobuf:"bytes,5,opt,name=email,proto3" json:"email,omitempty"`
|
||||
Groups []string `protobuf:"bytes,6,rep,name=groups,proto3" json:"groups,omitempty"`
|
||||
RefreshDeadline *timestamp.Timestamp `protobuf:"bytes,7,opt,name=refresh_deadline,json=refreshDeadline,proto3" json:"refresh_deadline,omitempty"`
|
||||
XXX_NoUnkeyedLiteral struct{} `json:"-"`
|
||||
XXX_unrecognized []byte `json:"-"`
|
||||
XXX_sizecache int32 `json:"-"`
|
||||
}
|
||||
|
||||
func (m *Session) Reset() { *m = Session{} }
|
||||
func (m *Session) String() string { return proto.CompactTextString(m) }
|
||||
func (*Session) ProtoMessage() {}
|
||||
func (*Session) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_authenticate_d9796afa57ba1f78, []int{3}
|
||||
}
|
||||
func (m *Session) XXX_Unmarshal(b []byte) error {
|
||||
return xxx_messageInfo_Session.Unmarshal(m, b)
|
||||
}
|
||||
func (m *Session) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
|
||||
return xxx_messageInfo_Session.Marshal(b, m, deterministic)
|
||||
}
|
||||
func (dst *Session) XXX_Merge(src proto.Message) {
|
||||
xxx_messageInfo_Session.Merge(dst, src)
|
||||
}
|
||||
func (m *Session) XXX_Size() int {
|
||||
return xxx_messageInfo_Session.Size(m)
|
||||
}
|
||||
func (m *Session) XXX_DiscardUnknown() {
|
||||
xxx_messageInfo_Session.DiscardUnknown(m)
|
||||
}
|
||||
|
||||
var xxx_messageInfo_Session proto.InternalMessageInfo
|
||||
|
||||
func (m *Session) GetAccessToken() string {
|
||||
if m != nil {
|
||||
return m.AccessToken
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *Session) GetRefreshToken() string {
|
||||
if m != nil {
|
||||
return m.RefreshToken
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *Session) GetIdToken() string {
|
||||
if m != nil {
|
||||
return m.IdToken
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *Session) GetUser() string {
|
||||
if m != nil {
|
||||
return m.User
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *Session) GetEmail() string {
|
||||
if m != nil {
|
||||
return m.Email
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *Session) GetGroups() []string {
|
||||
if m != nil {
|
||||
return m.Groups
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Session) GetRefreshDeadline() *timestamp.Timestamp {
|
||||
if m != nil {
|
||||
return m.RefreshDeadline
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
proto.RegisterType((*AuthenticateRequest)(nil), "authenticate.AuthenticateRequest")
|
||||
proto.RegisterType((*ValidateRequest)(nil), "authenticate.ValidateRequest")
|
||||
proto.RegisterType((*ValidateReply)(nil), "authenticate.ValidateReply")
|
||||
proto.RegisterType((*Session)(nil), "authenticate.Session")
|
||||
}
|
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ context.Context
|
||||
var _ grpc.ClientConn
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the grpc package it is being compiled against.
|
||||
const _ = grpc.SupportPackageIsVersion4
|
||||
|
||||
// AuthenticatorClient is the client API for Authenticator service.
|
||||
//
|
||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.
|
||||
type AuthenticatorClient interface {
|
||||
Authenticate(ctx context.Context, in *AuthenticateRequest, opts ...grpc.CallOption) (*Session, error)
|
||||
Validate(ctx context.Context, in *ValidateRequest, opts ...grpc.CallOption) (*ValidateReply, error)
|
||||
Refresh(ctx context.Context, in *Session, opts ...grpc.CallOption) (*Session, error)
|
||||
}
|
||||
|
||||
type authenticatorClient struct {
|
||||
cc *grpc.ClientConn
|
||||
}
|
||||
|
||||
func NewAuthenticatorClient(cc *grpc.ClientConn) AuthenticatorClient {
|
||||
return &authenticatorClient{cc}
|
||||
}
|
||||
|
||||
func (c *authenticatorClient) Authenticate(ctx context.Context, in *AuthenticateRequest, opts ...grpc.CallOption) (*Session, error) {
|
||||
out := new(Session)
|
||||
err := c.cc.Invoke(ctx, "/authenticate.Authenticator/Authenticate", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *authenticatorClient) Validate(ctx context.Context, in *ValidateRequest, opts ...grpc.CallOption) (*ValidateReply, error) {
|
||||
out := new(ValidateReply)
|
||||
err := c.cc.Invoke(ctx, "/authenticate.Authenticator/Validate", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *authenticatorClient) Refresh(ctx context.Context, in *Session, opts ...grpc.CallOption) (*Session, error) {
|
||||
out := new(Session)
|
||||
err := c.cc.Invoke(ctx, "/authenticate.Authenticator/Refresh", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// AuthenticatorServer is the server API for Authenticator service.
|
||||
type AuthenticatorServer interface {
|
||||
Authenticate(context.Context, *AuthenticateRequest) (*Session, error)
|
||||
Validate(context.Context, *ValidateRequest) (*ValidateReply, error)
|
||||
Refresh(context.Context, *Session) (*Session, error)
|
||||
}
|
||||
|
||||
func RegisterAuthenticatorServer(s *grpc.Server, srv AuthenticatorServer) {
|
||||
s.RegisterService(&_Authenticator_serviceDesc, srv)
|
||||
}
|
||||
|
||||
func _Authenticator_Authenticate_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(AuthenticateRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(AuthenticatorServer).Authenticate(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/authenticate.Authenticator/Authenticate",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(AuthenticatorServer).Authenticate(ctx, req.(*AuthenticateRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _Authenticator_Validate_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(ValidateRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(AuthenticatorServer).Validate(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/authenticate.Authenticator/Validate",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(AuthenticatorServer).Validate(ctx, req.(*ValidateRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _Authenticator_Refresh_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(Session)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(AuthenticatorServer).Refresh(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/authenticate.Authenticator/Refresh",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(AuthenticatorServer).Refresh(ctx, req.(*Session))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
var _Authenticator_serviceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "authenticate.Authenticator",
|
||||
HandlerType: (*AuthenticatorServer)(nil),
|
||||
Methods: []grpc.MethodDesc{
|
||||
{
|
||||
MethodName: "Authenticate",
|
||||
Handler: _Authenticator_Authenticate_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "Validate",
|
||||
Handler: _Authenticator_Validate_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "Refresh",
|
||||
Handler: _Authenticator_Refresh_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{},
|
||||
Metadata: "authenticate.proto",
|
||||
}
|
||||
|
||||
func init() { proto.RegisterFile("authenticate.proto", fileDescriptor_authenticate_d9796afa57ba1f78) }
|
||||
|
||||
var fileDescriptor_authenticate_d9796afa57ba1f78 = []byte{
|
||||
// 354 bytes of a gzipped FileDescriptorProto
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x74, 0x91, 0x4d, 0x4f, 0xb3, 0x40,
|
||||
0x14, 0x85, 0xcb, 0xdb, 0x0f, 0xda, 0x5b, 0x9a, 0xbe, 0xb9, 0x7e, 0x04, 0x31, 0xc6, 0x16, 0x37,
|
||||
0xd5, 0x18, 0x9a, 0xd4, 0x95, 0x4b, 0x13, 0x4d, 0x8c, 0x4b, 0x6c, 0xdc, 0x36, 0x14, 0x6e, 0xdb,
|
||||
0x89, 0x94, 0x41, 0x66, 0x30, 0xe9, 0xbf, 0xf5, 0x4f, 0xb8, 0x37, 0x0c, 0x10, 0xc1, 0xb4, 0x3b,
|
||||
0xee, 0x99, 0xe7, 0x0e, 0x67, 0xce, 0x01, 0xf4, 0x52, 0xb9, 0xa1, 0x48, 0x32, 0xdf, 0x93, 0xe4,
|
||||
0xc4, 0x09, 0x97, 0x1c, 0x8d, 0xaa, 0x66, 0x5d, 0xae, 0x39, 0x5f, 0x87, 0x34, 0x55, 0x67, 0xcb,
|
||||
0x74, 0x35, 0x95, 0x6c, 0x4b, 0x42, 0x7a, 0xdb, 0x38, 0xc7, 0xed, 0x6b, 0x38, 0x7a, 0xa8, 0x2c,
|
||||
0xb8, 0xf4, 0x91, 0x92, 0x90, 0x88, 0xd0, 0xf2, 0x79, 0x40, 0xa6, 0x36, 0xd2, 0x26, 0x3d, 0x57,
|
||||
0x7d, 0xdb, 0xb7, 0x30, 0x7c, 0xf3, 0x42, 0x16, 0x54, 0xb0, 0x33, 0xe8, 0xb2, 0x60, 0x21, 0xf9,
|
||||
0x3b, 0x45, 0x05, 0xaa, 0xb3, 0x60, 0x9e, 0x8d, 0xf6, 0x0d, 0x0c, 0x7e, 0xe9, 0x38, 0xdc, 0x29,
|
||||
0x56, 0x2c, 0x3e, 0x33, 0x4d, 0xb1, 0x5d, 0x57, 0x67, 0x42, 0x21, 0xf6, 0xb7, 0x06, 0xfa, 0x2b,
|
||||
0x09, 0xc1, 0x78, 0x84, 0x63, 0x30, 0x3c, 0xdf, 0x27, 0x21, 0x6a, 0xd7, 0xf6, 0x73, 0x4d, 0x5d,
|
||||
0x8d, 0x57, 0x30, 0x48, 0x68, 0x95, 0x90, 0xd8, 0x14, 0xcc, 0x3f, 0xc5, 0x18, 0x85, 0x98, 0x43,
|
||||
0x55, 0x6b, 0xcd, 0x9a, 0xb5, 0xec, 0x71, 0xa9, 0xa0, 0xc4, 0x6c, 0xe5, 0x8f, 0xcb, 0xbe, 0xf1,
|
||||
0x18, 0xda, 0xb4, 0xf5, 0x58, 0x68, 0xb6, 0x95, 0x98, 0x0f, 0x78, 0x0a, 0x9d, 0x75, 0xc2, 0xd3,
|
||||
0x58, 0x98, 0x9d, 0x51, 0x73, 0xd2, 0x73, 0x8b, 0x09, 0x9f, 0xe0, 0x7f, 0xe9, 0x20, 0x20, 0x2f,
|
||||
0x08, 0x59, 0x44, 0xa6, 0x3e, 0xd2, 0x26, 0xfd, 0x99, 0xe5, 0xe4, 0x89, 0x3b, 0x65, 0xe2, 0xce,
|
||||
0xbc, 0x4c, 0xdc, 0x1d, 0x16, 0x3b, 0x8f, 0xc5, 0xca, 0xec, 0x4b, 0x83, 0x41, 0x25, 0x7d, 0x9e,
|
||||
0xe0, 0x0b, 0x18, 0xd5, 0x3a, 0x70, 0xec, 0xd4, 0x2a, 0xde, 0x53, 0x95, 0x75, 0x52, 0x47, 0x8a,
|
||||
0x1c, 0xed, 0x06, 0x3e, 0x43, 0xb7, 0x6c, 0x00, 0x2f, 0xea, 0xd0, 0x9f, 0x1e, 0xad, 0xf3, 0x43,
|
||||
0xc7, 0x71, 0xb8, 0xb3, 0x1b, 0x78, 0x0f, 0xba, 0x9b, 0x5b, 0xc7, 0xfd, 0x7f, 0x3b, 0x68, 0x62,
|
||||
0xd9, 0x51, 0x39, 0xdc, 0xfd, 0x04, 0x00, 0x00, 0xff, 0xff, 0xdc, 0x47, 0xff, 0x7e, 0xab, 0x02,
|
||||
0x00, 0x00,
|
||||
}
|
|
@ -1,26 +0,0 @@
|
|||
syntax = "proto3";
|
||||
import "google/protobuf/timestamp.proto";
|
||||
|
||||
package authenticate;
|
||||
|
||||
service Authenticator {
|
||||
rpc Authenticate(AuthenticateRequest) returns (Session) {}
|
||||
rpc Validate(ValidateRequest) returns (ValidateReply) {}
|
||||
rpc Refresh(Session) returns (Session) {}
|
||||
}
|
||||
|
||||
message AuthenticateRequest { string code = 1; }
|
||||
|
||||
message ValidateRequest { string id_token = 1; }
|
||||
|
||||
message ValidateReply { bool is_valid = 1; }
|
||||
|
||||
message Session {
|
||||
string access_token = 1;
|
||||
string refresh_token = 2;
|
||||
string id_token = 3;
|
||||
string user = 4;
|
||||
string email = 5;
|
||||
repeated string groups = 6;
|
||||
google.protobuf.Timestamp refresh_deadline = 7;
|
||||
}
|
|
@ -1,49 +0,0 @@
|
|||
package authenticate
|
||||
|
||||
import (
|
||||
fmt "fmt"
|
||||
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
)
|
||||
|
||||
// SessionFromProto converts a converts a protocol buffer session into a pomerium session state.
|
||||
func SessionFromProto(p *Session) (*sessions.SessionState, error) {
|
||||
if p == nil {
|
||||
return nil, fmt.Errorf("proto/authenticate: SessionFromProto session cannot be nil")
|
||||
}
|
||||
|
||||
refreshDeadline, err := ptypes.Timestamp(p.RefreshDeadline)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("proto/authenticate: couldn't parse refresh deadline %v", err)
|
||||
}
|
||||
return &sessions.SessionState{
|
||||
AccessToken: p.AccessToken,
|
||||
RefreshToken: p.RefreshToken,
|
||||
IDToken: p.IdToken,
|
||||
Email: p.Email,
|
||||
User: p.User,
|
||||
Groups: p.Groups,
|
||||
RefreshDeadline: refreshDeadline,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ProtoFromSession converts a pomerium user session into a protocol buffer struct.
|
||||
func ProtoFromSession(s *sessions.SessionState) (*Session, error) {
|
||||
if s == nil {
|
||||
return nil, fmt.Errorf("proto/authenticate: ProtoFromSession session cannot be nil")
|
||||
}
|
||||
refreshDeadline, err := ptypes.TimestampProto(s.RefreshDeadline)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("proto/authenticate: couldn't parse refresh deadline %v", err)
|
||||
}
|
||||
return &Session{
|
||||
AccessToken: s.AccessToken,
|
||||
RefreshToken: s.RefreshToken,
|
||||
IdToken: s.IDToken,
|
||||
Email: s.Email,
|
||||
User: s.User,
|
||||
Groups: s.Groups,
|
||||
RefreshDeadline: refreshDeadline,
|
||||
}, nil
|
||||
}
|
|
@ -1,165 +0,0 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: proto/authenticate/authenticate.pb.go
|
||||
|
||||
// Package mock_authenticate is a generated GoMock package.
|
||||
package mock_authenticate
|
||||
|
||||
import (
|
||||
"context"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
"github.com/pomerium/pomerium/proto/authenticate"
|
||||
grpc "google.golang.org/grpc"
|
||||
)
|
||||
|
||||
// MockAuthenticatorClient is a mock of AuthenticatorClient interface
|
||||
type MockAuthenticatorClient struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockAuthenticatorClientMockRecorder
|
||||
}
|
||||
|
||||
// MockAuthenticatorClientMockRecorder is the mock recorder for MockAuthenticatorClient
|
||||
type MockAuthenticatorClientMockRecorder struct {
|
||||
mock *MockAuthenticatorClient
|
||||
}
|
||||
|
||||
// NewMockAuthenticatorClient creates a new mock instance
|
||||
func NewMockAuthenticatorClient(ctrl *gomock.Controller) *MockAuthenticatorClient {
|
||||
mock := &MockAuthenticatorClient{ctrl: ctrl}
|
||||
mock.recorder = &MockAuthenticatorClientMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockAuthenticatorClient) EXPECT() *MockAuthenticatorClientMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Authenticate mocks base method
|
||||
func (m *MockAuthenticatorClient) Authenticate(ctx context.Context, in *authenticate.AuthenticateRequest, opts ...grpc.CallOption) (*authenticate.Session, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []interface{}{ctx, in}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Authenticate", varargs...)
|
||||
ret0, _ := ret[0].(*authenticate.Session)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Authenticate indicates an expected call of Authenticate
|
||||
func (mr *MockAuthenticatorClientMockRecorder) Authenticate(ctx, in interface{}, opts ...interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]interface{}{ctx, in}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Authenticate", reflect.TypeOf((*MockAuthenticatorClient)(nil).Authenticate), varargs...)
|
||||
}
|
||||
|
||||
// Validate mocks base method
|
||||
func (m *MockAuthenticatorClient) Validate(ctx context.Context, in *authenticate.ValidateRequest, opts ...grpc.CallOption) (*authenticate.ValidateReply, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []interface{}{ctx, in}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Validate", varargs...)
|
||||
ret0, _ := ret[0].(*authenticate.ValidateReply)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Validate indicates an expected call of Validate
|
||||
func (mr *MockAuthenticatorClientMockRecorder) Validate(ctx, in interface{}, opts ...interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]interface{}{ctx, in}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Validate", reflect.TypeOf((*MockAuthenticatorClient)(nil).Validate), varargs...)
|
||||
}
|
||||
|
||||
// Refresh mocks base method
|
||||
func (m *MockAuthenticatorClient) Refresh(ctx context.Context, in *authenticate.Session, opts ...grpc.CallOption) (*authenticate.Session, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []interface{}{ctx, in}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Refresh", varargs...)
|
||||
ret0, _ := ret[0].(*authenticate.Session)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Refresh indicates an expected call of Refresh
|
||||
func (mr *MockAuthenticatorClientMockRecorder) Refresh(ctx, in interface{}, opts ...interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]interface{}{ctx, in}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Refresh", reflect.TypeOf((*MockAuthenticatorClient)(nil).Refresh), varargs...)
|
||||
}
|
||||
|
||||
// MockAuthenticatorServer is a mock of AuthenticatorServer interface
|
||||
type MockAuthenticatorServer struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockAuthenticatorServerMockRecorder
|
||||
}
|
||||
|
||||
// MockAuthenticatorServerMockRecorder is the mock recorder for MockAuthenticatorServer
|
||||
type MockAuthenticatorServerMockRecorder struct {
|
||||
mock *MockAuthenticatorServer
|
||||
}
|
||||
|
||||
// NewMockAuthenticatorServer creates a new mock instance
|
||||
func NewMockAuthenticatorServer(ctrl *gomock.Controller) *MockAuthenticatorServer {
|
||||
mock := &MockAuthenticatorServer{ctrl: ctrl}
|
||||
mock.recorder = &MockAuthenticatorServerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockAuthenticatorServer) EXPECT() *MockAuthenticatorServerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Authenticate mocks base method
|
||||
func (m *MockAuthenticatorServer) Authenticate(arg0 context.Context, arg1 *authenticate.AuthenticateRequest) (*authenticate.Session, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Authenticate", arg0, arg1)
|
||||
ret0, _ := ret[0].(*authenticate.Session)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Authenticate indicates an expected call of Authenticate
|
||||
func (mr *MockAuthenticatorServerMockRecorder) Authenticate(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Authenticate", reflect.TypeOf((*MockAuthenticatorServer)(nil).Authenticate), arg0, arg1)
|
||||
}
|
||||
|
||||
// Validate mocks base method
|
||||
func (m *MockAuthenticatorServer) Validate(arg0 context.Context, arg1 *authenticate.ValidateRequest) (*authenticate.ValidateReply, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Validate", arg0, arg1)
|
||||
ret0, _ := ret[0].(*authenticate.ValidateReply)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Validate indicates an expected call of Validate
|
||||
func (mr *MockAuthenticatorServerMockRecorder) Validate(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Validate", reflect.TypeOf((*MockAuthenticatorServer)(nil).Validate), arg0, arg1)
|
||||
}
|
||||
|
||||
// Refresh mocks base method
|
||||
func (m *MockAuthenticatorServer) Refresh(arg0 context.Context, arg1 *authenticate.Session) (*authenticate.Session, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Refresh", arg0, arg1)
|
||||
ret0, _ := ret[0].(*authenticate.Session)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Refresh indicates an expected call of Refresh
|
||||
func (mr *MockAuthenticatorServerMockRecorder) Refresh(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Refresh", reflect.TypeOf((*MockAuthenticatorServer)(nil).Refresh), arg0, arg1)
|
||||
}
|
|
@ -1,117 +0,0 @@
|
|||
package clients // import "github.com/pomerium/pomerium/proxy/clients"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
pb "github.com/pomerium/pomerium/proto/authenticate"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
// Authenticator provides the authenticate service interface
|
||||
type Authenticator interface {
|
||||
// Redeem takes a code and returns a validated session or an error
|
||||
Redeem(context.Context, string) (*sessions.SessionState, error)
|
||||
// Refresh attempts to refresh a valid session with a refresh token. Returns a refreshed session.
|
||||
Refresh(context.Context, *sessions.SessionState) (*sessions.SessionState, error)
|
||||
// Validate evaluates a given oidc id_token for validity. Returns validity and any error.
|
||||
Validate(context.Context, string) (bool, error)
|
||||
// Close closes the authenticator connection if any.
|
||||
Close() error
|
||||
}
|
||||
|
||||
// NewAuthenticateClient returns a new authenticate service client. Presently,
|
||||
// only gRPC is supported and is always returned so name is ignored.
|
||||
func NewAuthenticateClient(name string, opts *Options) (a Authenticator, err error) {
|
||||
return NewGRPCAuthenticateClient(opts)
|
||||
}
|
||||
|
||||
// NewGRPCAuthenticateClient returns a new authenticate service client.
|
||||
func NewGRPCAuthenticateClient(opts *Options) (p *AuthenticateGRPC, err error) {
|
||||
conn, err := NewGRPCClientConn(opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authClient := pb.NewAuthenticatorClient(conn)
|
||||
return &AuthenticateGRPC{Conn: conn, client: authClient}, nil
|
||||
}
|
||||
|
||||
// AuthenticateGRPC is a gRPC implementation of an authenticator (authenticate client)
|
||||
type AuthenticateGRPC struct {
|
||||
Conn *grpc.ClientConn
|
||||
client pb.AuthenticatorClient
|
||||
}
|
||||
|
||||
// Redeem makes an RPC call to the authenticate service to creates a session state
|
||||
// from an encrypted code provided as a result of an oauth2 callback process.
|
||||
func (a *AuthenticateGRPC) Redeem(ctx context.Context, code string) (*sessions.SessionState, error) {
|
||||
ctx, span := trace.StartSpan(ctx, "proxy.client.grpc.Redeem")
|
||||
defer span.End()
|
||||
|
||||
if code == "" {
|
||||
return nil, errors.New("missing code")
|
||||
}
|
||||
protoSession, err := a.client.Authenticate(ctx, &pb.AuthenticateRequest{Code: code})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
session, err := pb.SessionFromProto(protoSession)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// Refresh makes an RPC call to the authenticate service to attempt to refresh the
|
||||
// user's session. Requires a valid refresh token. Will return an error if the identity provider
|
||||
// has revoked the session or if the refresh token is no longer valid in this context.
|
||||
func (a *AuthenticateGRPC) Refresh(ctx context.Context, s *sessions.SessionState) (*sessions.SessionState, error) {
|
||||
ctx, span := trace.StartSpan(ctx, "proxy.client.grpc.Refresh")
|
||||
defer span.End()
|
||||
|
||||
if s.RefreshToken == "" {
|
||||
return nil, errors.New("missing refresh token")
|
||||
}
|
||||
req, err := pb.ProtoFromSession(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// todo(bdd): add grpc specific timeouts to main options
|
||||
// todo(bdd): handle request id (metadata!?) in grpc receiver and add to ctx logger
|
||||
reply, err := a.client.Refresh(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newSession, err := pb.SessionFromProto(reply)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newSession, nil
|
||||
}
|
||||
|
||||
// Validate makes an RPC call to the authenticate service to validate the JWT id token;
|
||||
// does NOT do nonce or revokation validation.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
||||
func (a *AuthenticateGRPC) Validate(ctx context.Context, idToken string) (bool, error) {
|
||||
ctx, span := trace.StartSpan(ctx, "proxy.client.grpc.Validate")
|
||||
defer span.End()
|
||||
|
||||
if idToken == "" {
|
||||
return false, errors.New("missing id token")
|
||||
}
|
||||
|
||||
r, err := a.client.Validate(ctx, &pb.ValidateRequest{IdToken: idToken})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return r.IsValid, nil
|
||||
}
|
||||
|
||||
// Close tears down the ClientConn and all underlying connections.
|
||||
func (a *AuthenticateGRPC) Close() error {
|
||||
return a.Conn.Close()
|
||||
}
|
|
@ -1,242 +0,0 @@
|
|||
package clients // import "github.com/pomerium/pomerium/proxy/clients"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
pb "github.com/pomerium/pomerium/proto/authenticate"
|
||||
mock "github.com/pomerium/pomerium/proto/authenticate/mock_authenticate"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serviceName string
|
||||
opts *Options
|
||||
wantErr bool
|
||||
}{
|
||||
{"grpc good", "grpc", &Options{Addr: &url.URL{Scheme: "https", Host: "localhost.example"}, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, SharedSecret: "secret"}, false},
|
||||
{"grpc missing shared secret", "grpc", &Options{Addr: &url.URL{Scheme: "https", Host: "localhost.example"}, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, SharedSecret: ""}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := NewAuthenticateClient(tt.serviceName, tt.opts)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var fixedDate = time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC)
|
||||
|
||||
// rpcMsg implements the gomock.Matcher interface
|
||||
type rpcMsg struct {
|
||||
msg proto.Message
|
||||
}
|
||||
|
||||
func (r *rpcMsg) Matches(msg interface{}) bool {
|
||||
m, ok := msg.(proto.Message)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return proto.Equal(m, r.msg)
|
||||
}
|
||||
|
||||
func (r *rpcMsg) String() string {
|
||||
return fmt.Sprintf("is %s", r.msg)
|
||||
}
|
||||
|
||||
func TestProxy_Redeem(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockAuthenticateClient := mock.NewMockAuthenticatorClient(ctrl)
|
||||
req := &pb.AuthenticateRequest{Code: "unit_test"}
|
||||
mockExpire, err := ptypes.TimestampProto(fixedDate)
|
||||
if err != nil {
|
||||
t.Fatalf("%v failed converting timestamp", err)
|
||||
}
|
||||
|
||||
mockAuthenticateClient.EXPECT().Authenticate(
|
||||
gomock.Any(),
|
||||
&rpcMsg{msg: req},
|
||||
).Return(&pb.Session{
|
||||
AccessToken: "mocked access token",
|
||||
RefreshToken: "mocked refresh token",
|
||||
IdToken: "mocked id token",
|
||||
User: "user1",
|
||||
Email: "test@email.com",
|
||||
RefreshDeadline: mockExpire,
|
||||
}, nil)
|
||||
tests := []struct {
|
||||
name string
|
||||
idToken string
|
||||
want *sessions.SessionState
|
||||
wantErr bool
|
||||
}{
|
||||
{"good", "unit_test", &sessions.SessionState{
|
||||
AccessToken: "mocked access token",
|
||||
RefreshToken: "mocked refresh token",
|
||||
IDToken: "mocked id token",
|
||||
User: "user1",
|
||||
Email: "test@email.com",
|
||||
RefreshDeadline: (fixedDate),
|
||||
}, false},
|
||||
{"empty code", "", nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := AuthenticateGRPC{client: mockAuthenticateClient}
|
||||
got, err := a.Redeem(context.Background(), tt.idToken)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Proxy.AuthenticateValidate() error = %v,\n wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != nil {
|
||||
if got.AccessToken != "mocked access token" {
|
||||
t.Errorf("authenticate: invalid access token")
|
||||
}
|
||||
if got.RefreshToken != "mocked refresh token" {
|
||||
t.Errorf("authenticate: invalid refresh token")
|
||||
}
|
||||
if got.IDToken != "mocked id token" {
|
||||
t.Errorf("authenticate: invalid id token")
|
||||
}
|
||||
if got.User != "user1" {
|
||||
t.Errorf("authenticate: invalid user")
|
||||
}
|
||||
if got.Email != "test@email.com" {
|
||||
t.Errorf("authenticate: invalid email")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
func TestProxy_AuthenticateValidate(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockAuthenticateClient := mock.NewMockAuthenticatorClient(ctrl)
|
||||
req := &pb.ValidateRequest{IdToken: "unit_test"}
|
||||
|
||||
mockAuthenticateClient.EXPECT().Validate(
|
||||
gomock.Any(),
|
||||
&rpcMsg{msg: req},
|
||||
).Return(&pb.ValidateReply{IsValid: false}, nil)
|
||||
|
||||
ac := mockAuthenticateClient
|
||||
tests := []struct {
|
||||
name string
|
||||
idToken string
|
||||
want bool
|
||||
wantErr bool
|
||||
}{
|
||||
{"good", "unit_test", false, false},
|
||||
{"empty id token", "", false, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := AuthenticateGRPC{client: ac}
|
||||
|
||||
got, err := a.Validate(context.Background(), tt.idToken)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Proxy.AuthenticateValidate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("Proxy.AuthenticateValidate() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxy_AuthenticateRefresh(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockRefreshClient := mock.NewMockAuthenticatorClient(ctrl)
|
||||
mockExpire, _ := ptypes.TimestampProto(fixedDate)
|
||||
|
||||
mockRefreshClient.EXPECT().Refresh(
|
||||
gomock.Any(),
|
||||
gomock.Not(sessions.SessionState{RefreshToken: "fail"}),
|
||||
).Return(&pb.Session{
|
||||
AccessToken: "new access token",
|
||||
RefreshDeadline: mockExpire,
|
||||
}, nil).AnyTimes()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
session *sessions.SessionState
|
||||
want *sessions.SessionState
|
||||
wantErr bool
|
||||
}{
|
||||
{"good",
|
||||
&sessions.SessionState{RefreshToken: "unit_test"},
|
||||
&sessions.SessionState{
|
||||
AccessToken: "new access token",
|
||||
RefreshDeadline: fixedDate,
|
||||
}, false},
|
||||
{"empty refresh token", &sessions.SessionState{RefreshToken: ""}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := AuthenticateGRPC{client: mockRefreshClient}
|
||||
|
||||
got, err := a.Refresh(context.Background(), tt.session)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Proxy.AuthenticateRefresh() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Proxy.AuthenticateRefresh() got = \n%#v\nwant \n%#v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewGRPC(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
opts *Options
|
||||
wantErr bool
|
||||
wantErrStr string
|
||||
wantTarget string
|
||||
}{
|
||||
{"no shared secret", &Options{}, true, "proxy/authenticator: grpc client requires shared secret", ""},
|
||||
{"empty connection", &Options{Addr: nil, SharedSecret: "shh"}, true, "proxy/authenticator: connection address required", ""},
|
||||
{"both internal and addr empty", &Options{Addr: nil, InternalAddr: nil, SharedSecret: "shh"}, true, "proxy/authenticator: connection address required", ""},
|
||||
{"addr with port", &Options{Addr: &url.URL{Scheme: "https", Host: "localhost.example:8443"}, SharedSecret: "shh"}, false, "", "localhost.example:8443"},
|
||||
{"addr without port", &Options{Addr: &url.URL{Scheme: "https", Host: "localhost.example"}, SharedSecret: "shh"}, false, "", "localhost.example:443"},
|
||||
{"internal addr with port", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example:8443"}, SharedSecret: "shh"}, false, "", "localhost.example:8443"},
|
||||
{"internal addr without port", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, SharedSecret: "shh"}, false, "", "localhost.example:443"},
|
||||
{"cert override", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, OverrideCertificateName: "*.local", SharedSecret: "shh"}, false, "", "localhost.example:443"},
|
||||
{"custom ca", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, OverrideCertificateName: "*.local", SharedSecret: "shh", CA: "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSURFVENDQWZrQ0ZBWHhneFg5K0hjWlBVVVBEK0laV0NGNUEvVTdNQTBHQ1NxR1NJYjNEUUVCQ3dVQU1FVXgKQ3pBSkJnTlZCQVlUQWtGVk1STXdFUVlEVlFRSURBcFRiMjFsTFZOMFlYUmxNU0V3SHdZRFZRUUtEQmhKYm5SbApjbTVsZENCWGFXUm5hWFJ6SUZCMGVTQk1kR1F3SGhjTk1Ua3dNakk0TVRnMU1EQTNXaGNOTWprd01qSTFNVGcxCk1EQTNXakJGTVFzd0NRWURWUVFHRXdKQlZURVRNQkVHQTFVRUNBd0tVMjl0WlMxVGRHRjBaVEVoTUI4R0ExVUUKQ2d3WVNXNTBaWEp1WlhRZ1YybGtaMmwwY3lCUWRIa2dUSFJrTUlJQklqQU5CZ2txaGtpRzl3MEJBUUVGQUFPQwpBUThBTUlJQkNnS0NBUUVBOVRFMEFiaTdnMHhYeURkVUtEbDViNTBCT05ZVVVSc3F2THQrSWkwdlpjMzRRTHhOClJrT0hrOFZEVUgzcUt1N2UrNGVubUdLVVNUdzRPNFlkQktiSWRJTFpnb3o0YitNL3FVOG5adVpiN2pBVTdOYWkKajMzVDVrbXB3L2d4WHNNUzNzdUpXUE1EUDB3Z1BUZUVRK2J1bUxVWmpLdUVIaWNTL0l5dmtaVlBzRlE4NWlaUwpkNXE2a0ZGUUdjWnFXeFg0dlhDV25Sd3E3cHY3TThJd1RYc1pYSVRuNXB5Z3VTczNKb29GQkg5U3ZNTjRKU25GCmJMK0t6ekduMy9ScXFrTXpMN3FUdkMrNWxVT3UxUmNES21mZXBuVGVaN1IyVnJUQm42NndWMjVHRnBkSDIzN00KOXhJVkJrWEd1U2NvWHVPN1lDcWFrZkt6aXdoRTV4UmRaa3gweXdJREFRQUJNQTBHQ1NxR1NJYjNEUUVCQ3dVQQpBNElCQVFCaHRWUEI0OCs4eFZyVmRxM1BIY3k5QkxtVEtrRFl6N2Q0ODJzTG1HczBuVUdGSTFZUDdmaFJPV3ZxCktCTlpkNEI5MUpwU1NoRGUrMHpoNno4WG5Ha01mYnRSYWx0NHEwZ3lKdk9hUWhqQ3ZCcSswTFk5d2NLbXpFdnMKcTRiNUZ5NXNpRUZSekJLTmZtTGwxTTF2cW1hNmFCVnNYUUhPREdzYS83dE5MalZ2ay9PYm52cFg3UFhLa0E3cQpLMTQvV0tBRFBJWm9mb00xMzB4Q1RTYXVpeXROajlnWkx1WU9leEZhblVwNCt2MHBYWS81OFFSNTk2U0ROVTlKClJaeDhwTzBTaUYvZXkxVUZXbmpzdHBjbTQzTFVQKzFwU1hFeVhZOFJrRTI2QzNvdjNaTFNKc2pMbC90aXVqUlgKZUJPOWorWDdzS0R4amdtajBPbWdpVkpIM0YrUAotLS0tLUVORCBDRVJUSUZJQ0FURS0tLS0tCg=="}, false, "", "localhost.example:443"},
|
||||
{"bad ca encoding", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, OverrideCertificateName: "*.local", SharedSecret: "shh", CA: "^"}, true, "", "localhost.example:443"},
|
||||
{"custom ca file", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, OverrideCertificateName: "*.local", SharedSecret: "shh", CAFile: "testdata/example.crt"}, false, "", "localhost.example:443"},
|
||||
{"bad custom ca file", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, OverrideCertificateName: "*.local", SharedSecret: "shh", CAFile: "testdata/example.crt2"}, true, "", "localhost.example:443"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := NewGRPCAuthenticateClient(tt.opts)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewGRPCAuthenticateClient() error = %v, wantErr %v", err, tt.wantErr)
|
||||
if !strings.EqualFold(err.Error(), tt.wantErrStr) {
|
||||
t.Errorf("NewGRPCAuthenticateClient() error = %v did not contain wantErr %v", err, tt.wantErrStr)
|
||||
}
|
||||
}
|
||||
if got != nil && got.Conn.Target() != tt.wantTarget {
|
||||
t.Errorf("NewGRPCAuthenticateClient() target = %v expected %v", got.Conn.Target(), tt.wantTarget)
|
||||
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -15,9 +15,9 @@ import (
|
|||
type Authorizer interface {
|
||||
// Authorize takes a route and user session and returns whether the
|
||||
// request is valid per access policy
|
||||
Authorize(context.Context, string, *sessions.SessionState) (bool, error)
|
||||
Authorize(context.Context, string, *sessions.State) (bool, error)
|
||||
// IsAdmin takes a session and returns whether the user is an administrator
|
||||
IsAdmin(context.Context, *sessions.SessionState) (bool, error)
|
||||
IsAdmin(context.Context, *sessions.State) (bool, error)
|
||||
// Close closes the auth connection if any.
|
||||
Close() error
|
||||
}
|
||||
|
@ -46,7 +46,7 @@ type AuthorizeGRPC struct {
|
|||
|
||||
// Authorize takes a route and user session and returns whether the
|
||||
// request is valid per access policy
|
||||
func (a *AuthorizeGRPC) Authorize(ctx context.Context, route string, s *sessions.SessionState) (bool, error) {
|
||||
func (a *AuthorizeGRPC) Authorize(ctx context.Context, route string, s *sessions.State) (bool, error) {
|
||||
ctx, span := trace.StartSpan(ctx, "proxy.client.grpc.Authorize")
|
||||
defer span.End()
|
||||
|
||||
|
@ -65,7 +65,7 @@ func (a *AuthorizeGRPC) Authorize(ctx context.Context, route string, s *sessions
|
|||
}
|
||||
|
||||
// IsAdmin takes a session and returns whether the user is an administrator
|
||||
func (a *AuthorizeGRPC) IsAdmin(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||
func (a *AuthorizeGRPC) IsAdmin(ctx context.Context, s *sessions.State) (bool, error) {
|
||||
ctx, span := trace.StartSpan(ctx, "proxy.client.grpc.IsAdmin")
|
||||
defer span.End()
|
||||
|
||||
|
|
|
@ -2,6 +2,8 @@ package clients
|
|||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
|
@ -23,12 +25,12 @@ func TestAuthorizeGRPC_Authorize(t *testing.T) {
|
|||
tests := []struct {
|
||||
name string
|
||||
route string
|
||||
s *sessions.SessionState
|
||||
s *sessions.State
|
||||
want bool
|
||||
wantErr bool
|
||||
}{
|
||||
{"good", "hello.pomerium.io", &sessions.SessionState{User: "admin@pomerium.io", Email: "admin@pomerium.io"}, true, false},
|
||||
{"impersonate request", "hello.pomerium.io", &sessions.SessionState{User: "admin@pomerium.io", Email: "admin@pomerium.io", ImpersonateEmail: "other@other.example"}, true, false},
|
||||
{"good", "hello.pomerium.io", &sessions.State{User: "admin@pomerium.io", Email: "admin@pomerium.io"}, true, false},
|
||||
{"impersonate request", "hello.pomerium.io", &sessions.State{User: "admin@pomerium.io", Email: "admin@pomerium.io", ImpersonateEmail: "other@other.example"}, true, false},
|
||||
{"session cannot be nil", "hello.pomerium.io", nil, false, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
|
@ -56,11 +58,11 @@ func TestAuthorizeGRPC_IsAdmin(t *testing.T) {
|
|||
|
||||
tests := []struct {
|
||||
name string
|
||||
s *sessions.SessionState
|
||||
s *sessions.State
|
||||
want bool
|
||||
wantErr bool
|
||||
}{
|
||||
{"good", &sessions.SessionState{User: "admin@pomerium.io", Email: "admin@pomerium.io"}, true, false},
|
||||
{"good", &sessions.State{User: "admin@pomerium.io", Email: "admin@pomerium.io"}, true, false},
|
||||
{"session cannot be nil", nil, false, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
|
@ -77,3 +79,41 @@ func TestAuthorizeGRPC_IsAdmin(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewGRPC(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
opts *Options
|
||||
wantErr bool
|
||||
wantErrStr string
|
||||
wantTarget string
|
||||
}{
|
||||
{"no shared secret", &Options{}, true, "proxy/authenticator: grpc client requires shared secret", ""},
|
||||
{"empty connection", &Options{Addr: nil, SharedSecret: "shh"}, true, "proxy/authenticator: connection address required", ""},
|
||||
{"both internal and addr empty", &Options{Addr: nil, InternalAddr: nil, SharedSecret: "shh"}, true, "proxy/authenticator: connection address required", ""},
|
||||
{"addr with port", &Options{Addr: &url.URL{Scheme: "https", Host: "localhost.example:8443"}, SharedSecret: "shh"}, false, "", "localhost.example:8443"},
|
||||
{"addr without port", &Options{Addr: &url.URL{Scheme: "https", Host: "localhost.example"}, SharedSecret: "shh"}, false, "", "localhost.example:443"},
|
||||
{"internal addr with port", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example:8443"}, SharedSecret: "shh"}, false, "", "localhost.example:8443"},
|
||||
{"internal addr without port", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, SharedSecret: "shh"}, false, "", "localhost.example:443"},
|
||||
{"cert override", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, OverrideCertificateName: "*.local", SharedSecret: "shh"}, false, "", "localhost.example:443"},
|
||||
{"custom ca", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, OverrideCertificateName: "*.local", SharedSecret: "shh", CA: "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSURFVENDQWZrQ0ZBWHhneFg5K0hjWlBVVVBEK0laV0NGNUEvVTdNQTBHQ1NxR1NJYjNEUUVCQ3dVQU1FVXgKQ3pBSkJnTlZCQVlUQWtGVk1STXdFUVlEVlFRSURBcFRiMjFsTFZOMFlYUmxNU0V3SHdZRFZRUUtEQmhKYm5SbApjbTVsZENCWGFXUm5hWFJ6SUZCMGVTQk1kR1F3SGhjTk1Ua3dNakk0TVRnMU1EQTNXaGNOTWprd01qSTFNVGcxCk1EQTNXakJGTVFzd0NRWURWUVFHRXdKQlZURVRNQkVHQTFVRUNBd0tVMjl0WlMxVGRHRjBaVEVoTUI4R0ExVUUKQ2d3WVNXNTBaWEp1WlhRZ1YybGtaMmwwY3lCUWRIa2dUSFJrTUlJQklqQU5CZ2txaGtpRzl3MEJBUUVGQUFPQwpBUThBTUlJQkNnS0NBUUVBOVRFMEFiaTdnMHhYeURkVUtEbDViNTBCT05ZVVVSc3F2THQrSWkwdlpjMzRRTHhOClJrT0hrOFZEVUgzcUt1N2UrNGVubUdLVVNUdzRPNFlkQktiSWRJTFpnb3o0YitNL3FVOG5adVpiN2pBVTdOYWkKajMzVDVrbXB3L2d4WHNNUzNzdUpXUE1EUDB3Z1BUZUVRK2J1bUxVWmpLdUVIaWNTL0l5dmtaVlBzRlE4NWlaUwpkNXE2a0ZGUUdjWnFXeFg0dlhDV25Sd3E3cHY3TThJd1RYc1pYSVRuNXB5Z3VTczNKb29GQkg5U3ZNTjRKU25GCmJMK0t6ekduMy9ScXFrTXpMN3FUdkMrNWxVT3UxUmNES21mZXBuVGVaN1IyVnJUQm42NndWMjVHRnBkSDIzN00KOXhJVkJrWEd1U2NvWHVPN1lDcWFrZkt6aXdoRTV4UmRaa3gweXdJREFRQUJNQTBHQ1NxR1NJYjNEUUVCQ3dVQQpBNElCQVFCaHRWUEI0OCs4eFZyVmRxM1BIY3k5QkxtVEtrRFl6N2Q0ODJzTG1HczBuVUdGSTFZUDdmaFJPV3ZxCktCTlpkNEI5MUpwU1NoRGUrMHpoNno4WG5Ha01mYnRSYWx0NHEwZ3lKdk9hUWhqQ3ZCcSswTFk5d2NLbXpFdnMKcTRiNUZ5NXNpRUZSekJLTmZtTGwxTTF2cW1hNmFCVnNYUUhPREdzYS83dE5MalZ2ay9PYm52cFg3UFhLa0E3cQpLMTQvV0tBRFBJWm9mb00xMzB4Q1RTYXVpeXROajlnWkx1WU9leEZhblVwNCt2MHBYWS81OFFSNTk2U0ROVTlKClJaeDhwTzBTaUYvZXkxVUZXbmpzdHBjbTQzTFVQKzFwU1hFeVhZOFJrRTI2QzNvdjNaTFNKc2pMbC90aXVqUlgKZUJPOWorWDdzS0R4amdtajBPbWdpVkpIM0YrUAotLS0tLUVORCBDRVJUSUZJQ0FURS0tLS0tCg=="}, false, "", "localhost.example:443"},
|
||||
{"bad ca encoding", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, OverrideCertificateName: "*.local", SharedSecret: "shh", CA: "^"}, true, "", "localhost.example:443"},
|
||||
{"custom ca file", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, OverrideCertificateName: "*.local", SharedSecret: "shh", CAFile: "testdata/example.crt"}, false, "", "localhost.example:443"},
|
||||
{"bad custom ca file", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, OverrideCertificateName: "*.local", SharedSecret: "shh", CAFile: "testdata/example.crt2"}, true, "", "localhost.example:443"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := NewGRPCAuthorizeClient(tt.opts)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewGRPCAuthorizeClient() error = %v, wantErr %v", err, tt.wantErr)
|
||||
if !strings.EqualFold(err.Error(), tt.wantErrStr) {
|
||||
t.Errorf("NewGRPCAuthorizeClient() error = %v did not contain wantErr %v", err, tt.wantErrStr)
|
||||
}
|
||||
}
|
||||
if got != nil && got.Conn.Target() != tt.wantTarget {
|
||||
t.Errorf("NewGRPCAuthorizeClient() target = %v expected %v", got.Conn.Target(), tt.wantTarget)
|
||||
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,6 +15,7 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/middleware"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||
|
||||
"go.opencensus.io/plugin/ocgrpc"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/balancer/roundrobin"
|
||||
|
@ -25,7 +26,7 @@ const defaultGRPCPort = 443
|
|||
|
||||
// Options contains options for connecting to a pomerium rpc service.
|
||||
type Options struct {
|
||||
// Addr is the location of the authenticate service. e.g. "service.corp.example:8443"
|
||||
// Addr is the location of the service. e.g. "service.corp.example:8443"
|
||||
Addr *url.URL
|
||||
// InternalAddr is the internal (behind the ingress) address to use when
|
||||
// making a connection. If empty, Addr is used.
|
||||
|
@ -34,7 +35,7 @@ type Options struct {
|
|||
// returned certificates from the server. gRPC internals also use it to override the virtual
|
||||
// hosting name if it is set.
|
||||
OverrideCertificateName string
|
||||
// Shared secret is used to authenticate a authenticate-client with a authenticate-server.
|
||||
// Shared secret is used to mutually authenticate a client and server.
|
||||
SharedSecret string
|
||||
// CA specifies the base64 encoded TLS certificate authority to use.
|
||||
CA string
|
||||
|
|
|
@ -6,35 +6,6 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
)
|
||||
|
||||
// MockAuthenticate provides a mocked implementation of the authenticator interface.
|
||||
type MockAuthenticate struct {
|
||||
RedeemError error
|
||||
RedeemResponse *sessions.SessionState
|
||||
RefreshResponse *sessions.SessionState
|
||||
RefreshError error
|
||||
ValidateResponse bool
|
||||
ValidateError error
|
||||
CloseError error
|
||||
}
|
||||
|
||||
// Redeem is a mocked authenticator client function.
|
||||
func (a MockAuthenticate) Redeem(ctx context.Context, code string) (*sessions.SessionState, error) {
|
||||
return a.RedeemResponse, a.RedeemError
|
||||
}
|
||||
|
||||
// Refresh is a mocked authenticator client function.
|
||||
func (a MockAuthenticate) Refresh(ctx context.Context, s *sessions.SessionState) (*sessions.SessionState, error) {
|
||||
return a.RefreshResponse, a.RefreshError
|
||||
}
|
||||
|
||||
// Validate is a mocked authenticator client function.
|
||||
func (a MockAuthenticate) Validate(ctx context.Context, idToken string) (bool, error) {
|
||||
return a.ValidateResponse, a.ValidateError
|
||||
}
|
||||
|
||||
// Close is a mocked authenticator client function.
|
||||
func (a MockAuthenticate) Close() error { return a.CloseError }
|
||||
|
||||
// MockAuthorize provides a mocked implementation of the authorizer interface.
|
||||
type MockAuthorize struct {
|
||||
AuthorizeResponse bool
|
||||
|
@ -48,11 +19,11 @@ type MockAuthorize struct {
|
|||
func (a MockAuthorize) Close() error { return a.CloseError }
|
||||
|
||||
// Authorize is a mocked authorizer client function.
|
||||
func (a MockAuthorize) Authorize(ctx context.Context, route string, s *sessions.SessionState) (bool, error) {
|
||||
func (a MockAuthorize) Authorize(ctx context.Context, route string, s *sessions.State) (bool, error) {
|
||||
return a.AuthorizeResponse, a.AuthorizeError
|
||||
}
|
||||
|
||||
// IsAdmin is a mocked IsAdmin function.
|
||||
func (a MockAuthorize) IsAdmin(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||
func (a MockAuthorize) IsAdmin(ctx context.Context, s *sessions.State) (bool, error) {
|
||||
return a.IsAdminResponse, a.IsAdminError
|
||||
}
|
||||
|
|
|
@ -1,57 +0,0 @@
|
|||
package clients
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
)
|
||||
|
||||
func TestMockAuthenticate(t *testing.T) {
|
||||
// Absurd, but I caught a typo this way.
|
||||
redeemResponse := &sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
}
|
||||
ma := &MockAuthenticate{
|
||||
RedeemError: errors.New("redeem error"),
|
||||
RedeemResponse: redeemResponse,
|
||||
RefreshResponse: &sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
},
|
||||
RefreshError: errors.New("refresh error"),
|
||||
ValidateResponse: true,
|
||||
ValidateError: errors.New("validate error"),
|
||||
CloseError: errors.New("close error"),
|
||||
}
|
||||
got, gotErr := ma.Redeem(context.Background(), "a")
|
||||
if gotErr.Error() != "redeem error" {
|
||||
t.Errorf("unexpected value for gotErr %s", gotErr)
|
||||
}
|
||||
if !reflect.DeepEqual(redeemResponse, got) {
|
||||
t.Errorf("unexpected value for redeemResponse %s", got)
|
||||
}
|
||||
newSession, gotErr := ma.Refresh(context.Background(), nil)
|
||||
if gotErr.Error() != "refresh error" {
|
||||
t.Errorf("unexpected value for gotErr %s", gotErr)
|
||||
}
|
||||
if !reflect.DeepEqual(newSession, redeemResponse) {
|
||||
t.Errorf("unexpected value for newSession %s", newSession)
|
||||
}
|
||||
|
||||
ok, gotErr := ma.Validate(context.Background(), "a")
|
||||
if !ok {
|
||||
t.Errorf("unexpected value for ok : %t", ok)
|
||||
}
|
||||
if gotErr.Error() != "validate error" {
|
||||
t.Errorf("unexpected value for gotErr %s", gotErr)
|
||||
}
|
||||
gotErr = ma.Close()
|
||||
if gotErr.Error() != "close error" {
|
||||
t.Errorf("unexpected value for ma.CloseError %s", gotErr)
|
||||
}
|
||||
|
||||
}
|
|
@ -15,6 +15,7 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/middleware"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/templates"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
)
|
||||
|
||||
// StateParameter holds the redirect id along with the session id.
|
||||
|
@ -36,9 +37,9 @@ func (p *Proxy) Handler() http.Handler {
|
|||
mux.HandleFunc("/.pomerium", p.UserDashboard)
|
||||
mux.HandleFunc("/.pomerium/impersonate", p.Impersonate) // POST
|
||||
mux.HandleFunc("/.pomerium/sign_out", p.SignOut)
|
||||
// handlers handlers with validation
|
||||
mux.Handle("/.pomerium/callback", validate.ThenFunc(p.OAuthCallback))
|
||||
mux.Handle("/.pomerium/refresh", validate.ThenFunc(p.Refresh))
|
||||
// handlers with validation
|
||||
mux.Handle("/.pomerium/callback", validate.ThenFunc(p.AuthenticateCallback))
|
||||
mux.Handle("/.pomerium/refresh", validate.ThenFunc(p.ForceRefresh))
|
||||
mux.Handle("/", validate.ThenFunc(p.Proxy))
|
||||
return mux
|
||||
}
|
||||
|
@ -60,12 +61,12 @@ func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) {
|
|||
httputil.ErrorResponse(w, r, err)
|
||||
return
|
||||
}
|
||||
uri, err := url.Parse(r.Form.Get("redirect_uri"))
|
||||
uri, err := urlutil.ParseAndValidateURL(r.Form.Get("redirect_uri"))
|
||||
if err == nil && uri.String() != "" {
|
||||
redirectURL = uri
|
||||
}
|
||||
default:
|
||||
uri, err := url.Parse(r.URL.Query().Get("redirect_uri"))
|
||||
uri, err := urlutil.ParseAndValidateURL(r.URL.Query().Get("redirect_uri"))
|
||||
if err == nil && uri.String() != "" {
|
||||
redirectURL = uri
|
||||
}
|
||||
|
@ -76,24 +77,20 @@ func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) {
|
|||
// OAuthStart begins the authenticate flow, encrypting the redirect url
|
||||
// in a request to the provider's sign in endpoint.
|
||||
func (p *Proxy) OAuthStart(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// create a CSRF value used to mitigate replay attacks.
|
||||
state := &StateParameter{
|
||||
SessionID: fmt.Sprintf("%x", cryptutil.GenerateKey()),
|
||||
RedirectURI: r.URL.String(),
|
||||
}
|
||||
|
||||
// Encrypt, and save CSRF state. Will be checked on callback.
|
||||
localState, err := p.cipher.Marshal(state)
|
||||
// Encrypt CSRF + redirect_uri and store in csrf session. Validated on callback.
|
||||
csrfState, err := p.cipher.Marshal(state)
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, err)
|
||||
return
|
||||
}
|
||||
p.csrfStore.SetCSRF(w, r, localState)
|
||||
p.csrfStore.SetCSRF(w, r, csrfState)
|
||||
|
||||
// Though the plaintext payload is identical, we re-encrypt which will
|
||||
// create a different cipher text using another nonce
|
||||
remoteState, err := p.cipher.Marshal(state)
|
||||
paramState, err := p.cipher.Marshal(state)
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, err)
|
||||
return
|
||||
|
@ -101,68 +98,55 @@ func (p *Proxy) OAuthStart(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// Sanity check. The encrypted payload of local and remote state should
|
||||
// never match as each encryption round uses a cryptographic nonce.
|
||||
//
|
||||
// todo(bdd): since this should nearly (1/(2^32*2^32)) never happen should
|
||||
// we panic as a failure most likely means the rands entropy source is failing?
|
||||
if remoteState == localState {
|
||||
p.sessionStore.ClearSession(w, r)
|
||||
httputil.ErrorResponse(w, r, httputil.Error("encrypted state should not match", http.StatusBadRequest, nil))
|
||||
return
|
||||
}
|
||||
// if paramState == csrfState {
|
||||
// httputil.ErrorResponse(w, r, httputil.Error("encrypted state should not match", http.StatusBadRequest, nil))
|
||||
// return
|
||||
// }
|
||||
|
||||
signinURL := p.GetSignInURL(p.authenticateURL, p.GetRedirectURL(r.Host), remoteState)
|
||||
log.FromRequest(r).Debug().Str("SigninURL", signinURL.String()).Msg("proxy: oauth start")
|
||||
signinURL := p.GetSignInURL(p.authenticateURL, p.GetRedirectURL(r.Host), paramState)
|
||||
|
||||
// Redirect the user to the authenticate service along with the encrypted
|
||||
// state which contains a redirect uri back to the proxy and a nonce
|
||||
http.Redirect(w, r, signinURL.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
// OAuthCallback validates the cookie sent back from the authenticate service. This function will
|
||||
// contain an error, or it will contain a `code`; the code can be used to fetch an access token, and
|
||||
// other metadata, from the authenticator.
|
||||
// finish the oauth cycle
|
||||
func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
|
||||
// AuthenticateCallback checks the state parameter to make sure it matches the
|
||||
// local csrf state then redirects the user back to the original intended route.
|
||||
func (p *Proxy) AuthenticateCallback(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
httputil.ErrorResponse(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
if callbackError := r.Form.Get("error"); callbackError != "" {
|
||||
httputil.ErrorResponse(w, r, httputil.Error(callbackError, http.StatusBadRequest, nil))
|
||||
return
|
||||
}
|
||||
|
||||
// Encrypted CSRF passed from authenticate service
|
||||
remoteStateEncrypted := r.Form.Get("state")
|
||||
remoteStatePlain := new(StateParameter)
|
||||
if err := p.cipher.Unmarshal(remoteStateEncrypted, remoteStatePlain); err != nil {
|
||||
var remoteStatePlain StateParameter
|
||||
if err := p.cipher.Unmarshal(remoteStateEncrypted, &remoteStatePlain); err != nil {
|
||||
httputil.ErrorResponse(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Encrypted CSRF from session storage
|
||||
c, err := p.csrfStore.GetCSRF(r)
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, err)
|
||||
return
|
||||
}
|
||||
p.csrfStore.ClearCSRF(w, r)
|
||||
|
||||
localStateEncrypted := c.Value
|
||||
localStatePlain := new(StateParameter)
|
||||
err = p.cipher.Unmarshal(localStateEncrypted, localStatePlain)
|
||||
var localStatePlain StateParameter
|
||||
err = p.cipher.Unmarshal(localStateEncrypted, &localStatePlain)
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
// If the encrypted value of local and remote state match, reject.
|
||||
// Likely a replay attack or nonce-reuse.
|
||||
// assert no nonce reuse
|
||||
if remoteStateEncrypted == localStateEncrypted {
|
||||
p.sessionStore.ClearSession(w, r)
|
||||
|
||||
httputil.ErrorResponse(w, r, httputil.Error("local and remote state should not match!", http.StatusBadRequest, nil))
|
||||
|
||||
httputil.ErrorResponse(w, r,
|
||||
httputil.Error("local and remote state", http.StatusBadRequest,
|
||||
fmt.Errorf("possible nonce-reuse / replay attack")))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -205,13 +189,23 @@ func isCORSPreflight(r *http.Request) bool {
|
|||
r.Header.Get("Origin") != ""
|
||||
}
|
||||
|
||||
func (p *Proxy) loadExistingSession(r *http.Request) (*sessions.State, error) {
|
||||
s, err := p.sessionStore.LoadSession(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("proxy: invalid session: %w", err)
|
||||
}
|
||||
if err := s.Valid(); err != nil {
|
||||
return nil, fmt.Errorf("proxy: invalid state: %w", err)
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Proxy authenticates a request, either proxying the request if it is authenticated,
|
||||
// or starting the authenticate service for validation if not.
|
||||
func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
|
||||
// does a route exist for this request?
|
||||
route, ok := p.router(r)
|
||||
if !ok {
|
||||
httputil.ErrorResponse(w, r, httputil.Error(fmt.Sprintf("%s is not a managed route.", r.Host), http.StatusNotFound, nil))
|
||||
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusNotFound, nil))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -221,30 +215,17 @@ func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
s, err := p.restStore.LoadSession(r)
|
||||
// if authorization bearer token does not exist or fails, use cookie store
|
||||
if err != nil || s == nil {
|
||||
s, err = p.sessionStore.LoadSession(r)
|
||||
if err != nil {
|
||||
log.FromRequest(r).Debug().Str("cause", err.Error()).Msg("proxy: invalid session, re-authenticating")
|
||||
p.sessionStore.ClearSession(w, r)
|
||||
p.OAuthStart(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err = p.authenticate(w, r, s); err != nil {
|
||||
p.sessionStore.ClearSession(w, r)
|
||||
httputil.ErrorResponse(w, r, httputil.Error("User unauthenticated", http.StatusUnauthorized, err))
|
||||
s, err := p.loadExistingSession(r)
|
||||
if err != nil {
|
||||
log.Debug().Str("cause", err.Error()).Msg("proxy: bad authN session, redirecting")
|
||||
p.OAuthStart(w, r)
|
||||
return
|
||||
}
|
||||
authorized, err := p.AuthorizeClient.Authorize(r.Context(), r.Host, s)
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
if !authorized {
|
||||
} else if !authorized {
|
||||
httputil.ErrorResponse(w, r, httputil.Error(fmt.Sprintf("%s is not authorized for this route", s.Email), http.StatusForbidden, nil))
|
||||
return
|
||||
}
|
||||
|
@ -259,20 +240,13 @@ func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
|
|||
// It also contains certain administrative actions like user impersonation.
|
||||
// Nota bene: This endpoint does authentication, not authorization.
|
||||
func (p *Proxy) UserDashboard(w http.ResponseWriter, r *http.Request) {
|
||||
session, err := p.sessionStore.LoadSession(r)
|
||||
session, err := p.loadExistingSession(r)
|
||||
if err != nil {
|
||||
log.FromRequest(r).Debug().Str("cause", err.Error()).Msg("proxy: no session, redirecting to auth")
|
||||
p.sessionStore.ClearSession(w, r)
|
||||
log.Debug().Str("cause", err.Error()).Msg("proxy: bad authN session, redirecting")
|
||||
p.OAuthStart(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if err := p.authenticate(w, r, session); err != nil {
|
||||
p.sessionStore.ClearSession(w, r)
|
||||
httputil.ErrorResponse(w, r, httputil.Error("User unauthenticated", http.StatusUnauthorized, err))
|
||||
return
|
||||
}
|
||||
|
||||
redirectURL := &url.URL{Scheme: "https", Host: r.Host, Path: "/.pomerium/sign_out"}
|
||||
isAdmin, err := p.AuthorizeClient.IsAdmin(r.Context(), session)
|
||||
if err != nil {
|
||||
|
@ -314,13 +288,14 @@ func (p *Proxy) UserDashboard(w http.ResponseWriter, r *http.Request) {
|
|||
templates.New().ExecuteTemplate(w, "dashboard.html", t)
|
||||
}
|
||||
|
||||
// Refresh redeems and extends an existing authenticated oidc session with
|
||||
// ForceRefresh redeems and extends an existing authenticated oidc session with
|
||||
// the underlying identity provider. All session details including groups,
|
||||
// timeouts, will be renewed.
|
||||
func (p *Proxy) Refresh(w http.ResponseWriter, r *http.Request) {
|
||||
session, err := p.sessionStore.LoadSession(r)
|
||||
func (p *Proxy) ForceRefresh(w http.ResponseWriter, r *http.Request) {
|
||||
session, err := p.loadExistingSession(r)
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, err)
|
||||
log.Debug().Str("cause", err.Error()).Msg("proxy: bad authN session, redirecting")
|
||||
p.OAuthStart(w, r)
|
||||
return
|
||||
}
|
||||
iss, err := session.IssuedAt()
|
||||
|
@ -332,16 +307,13 @@ func (p *Proxy) Refresh(w http.ResponseWriter, r *http.Request) {
|
|||
// reject a refresh if it's been less than the refresh cooldown to prevent abuse
|
||||
if time.Since(iss) < p.refreshCooldown {
|
||||
httputil.ErrorResponse(w, r,
|
||||
httputil.Error(fmt.Sprintf("Session must be %s old before refreshing", p.refreshCooldown), http.StatusBadRequest, nil))
|
||||
httputil.Error(
|
||||
fmt.Sprintf("Session must be %s old before refreshing", p.refreshCooldown),
|
||||
http.StatusBadRequest, nil))
|
||||
return
|
||||
}
|
||||
|
||||
newSession, err := p.AuthenticateClient.Refresh(r.Context(), session)
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, err)
|
||||
return
|
||||
}
|
||||
if err = p.sessionStore.SaveSession(w, r, newSession); err != nil {
|
||||
session.ForceRefresh()
|
||||
if err = p.sessionStore.SaveSession(w, r, session); err != nil {
|
||||
httputil.ErrorResponse(w, r, err)
|
||||
return
|
||||
}
|
||||
|
@ -357,12 +329,12 @@ func (p *Proxy) Impersonate(w http.ResponseWriter, r *http.Request) {
|
|||
httputil.ErrorResponse(w, r, err)
|
||||
return
|
||||
}
|
||||
session, err := p.sessionStore.LoadSession(r)
|
||||
session, err := p.loadExistingSession(r)
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, err)
|
||||
log.Debug().Str("cause", err.Error()).Msg("proxy: bad authN session, redirecting")
|
||||
p.OAuthStart(w, r)
|
||||
return
|
||||
}
|
||||
// authorization check -- is this user an admin?
|
||||
isAdmin, err := p.AuthorizeClient.IsAdmin(r.Context(), session)
|
||||
if err != nil || !isAdmin {
|
||||
httputil.ErrorResponse(w, r, httputil.Error(fmt.Sprintf("%s is not an administrator", session.Email), http.StatusForbidden, err))
|
||||
|
@ -376,7 +348,7 @@ func (p *Proxy) Impersonate(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
p.csrfStore.ClearCSRF(w, r)
|
||||
encryptedCSRF := c.Value
|
||||
decryptedCSRF := new(StateParameter)
|
||||
var decryptedCSRF StateParameter
|
||||
if err = p.cipher.Unmarshal(encryptedCSRF, decryptedCSRF); err != nil {
|
||||
httputil.ErrorResponse(w, r, err)
|
||||
return
|
||||
|
@ -398,26 +370,6 @@ func (p *Proxy) Impersonate(w http.ResponseWriter, r *http.Request) {
|
|||
http.Redirect(w, r, "/.pomerium", http.StatusFound)
|
||||
}
|
||||
|
||||
// Authenticate authenticates a request by checking for a session cookie, and validating its expiration,
|
||||
// clearing the session cookie if it's invalid and returning an error if necessary..
|
||||
func (p *Proxy) authenticate(w http.ResponseWriter, r *http.Request, s *sessions.SessionState) error {
|
||||
if s.RefreshPeriodExpired() {
|
||||
s, err := p.AuthenticateClient.Refresh(r.Context(), s)
|
||||
if err != nil {
|
||||
return fmt.Errorf("proxy: session refresh failed : %v", err)
|
||||
}
|
||||
if err := p.sessionStore.SaveSession(w, r, s); err != nil {
|
||||
return fmt.Errorf("proxy: refresh failed : %v", err)
|
||||
}
|
||||
} else {
|
||||
valid, err := p.AuthenticateClient.Validate(r.Context(), s.IDToken)
|
||||
if err != nil || !valid {
|
||||
return fmt.Errorf("proxy: session validate failed: %v : %v", valid, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// router attempts to find a route for a request. If a route is successfully matched,
|
||||
// it returns the route information and a bool value of `true`. If a route can
|
||||
// not be matched, a nil value for the route and false bool value is returned.
|
||||
|
@ -461,7 +413,7 @@ func (p *Proxy) GetSignInURL(authenticateURL, redirectURL *url.URL, state string
|
|||
a := authenticateURL.ResolveReference(&url.URL{Path: "/sign_in"})
|
||||
now := time.Now()
|
||||
rawRedirect := redirectURL.String()
|
||||
params, _ := url.ParseQuery(a.RawQuery)
|
||||
params, _ := url.ParseQuery(a.RawQuery) // handled by ServeMux
|
||||
params.Set("redirect_uri", rawRedirect)
|
||||
params.Set("shared_secret", p.SharedKey)
|
||||
params.Set("response_type", "code")
|
||||
|
@ -477,7 +429,7 @@ func (p *Proxy) GetSignOutURL(authenticateURL, redirectURL *url.URL) *url.URL {
|
|||
a := authenticateURL.ResolveReference(&url.URL{Path: "/sign_out"})
|
||||
now := time.Now()
|
||||
rawRedirect := redirectURL.String()
|
||||
params, _ := url.ParseQuery(a.RawQuery)
|
||||
params, _ := url.ParseQuery(a.RawQuery) // handled by ServeMux
|
||||
params.Add("redirect_uri", rawRedirect)
|
||||
params.Set("ts", fmt.Sprint(now.Unix()))
|
||||
params.Set("sig", p.signRedirectURL(rawRedirect, now))
|
||||
|
|
|
@ -72,7 +72,6 @@ func TestProxy_GetRedirectURL(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := &Proxy{redirectURL: &url.URL{Path: "/.pomerium/callback"}}
|
||||
|
||||
if got := p.GetRedirectURL(tt.host); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Proxy.GetRedirectURL() = %v, want %v", got, tt.want)
|
||||
}
|
||||
|
@ -240,8 +239,7 @@ func TestProxy_router(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
p.AuthenticateClient = clients.MockAuthenticate{}
|
||||
p.cipher = mockCipher{}
|
||||
p.cipher = &cryptutil.MockCipher{MarshalResponse: "foo"}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, tt.host, nil)
|
||||
_, ok := p.router(req)
|
||||
|
@ -253,7 +251,7 @@ func TestProxy_router(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestProxy_Proxy(t *testing.T) {
|
||||
goodSession := &sessions.SessionState{
|
||||
goodSession := &sessions.State{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
|
@ -278,39 +276,34 @@ func TestProxy_Proxy(t *testing.T) {
|
|||
headersWs.Set("Upgrade", "websocket")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
options config.Options
|
||||
method string
|
||||
header http.Header
|
||||
host string
|
||||
session sessions.SessionStore
|
||||
authenticator clients.Authenticator
|
||||
authorizer clients.Authorizer
|
||||
wantStatus int
|
||||
name string
|
||||
options config.Options
|
||||
method string
|
||||
header http.Header
|
||||
host string
|
||||
session sessions.SessionStore
|
||||
authorizer clients.Authorizer
|
||||
wantStatus int
|
||||
}{
|
||||
{"good", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
|
||||
{"good cors preflight", optsCORS, http.MethodOptions, goodCORSHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusOK},
|
||||
{"good email impersonation", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second), ImpersonateEmail: "test@user.example"}}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
|
||||
{"good group impersonation", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second), ImpersonateGroups: []string{"group1", "group2"}}}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
|
||||
{"good", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
|
||||
{"good cors preflight", optsCORS, http.MethodOptions, goodCORSHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusOK},
|
||||
{"good email impersonation", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second), ImpersonateEmail: "test@user.example"}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
|
||||
{"good group impersonation", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second), ImpersonateGroups: []string{"group1", "group2"}}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
|
||||
// same request as above, but with cors_allow_preflight=false in the policy
|
||||
{"valid cors, but not allowed", opts, http.MethodOptions, goodCORSHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
|
||||
{"valid cors, but not allowed", opts, http.MethodOptions, goodCORSHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
|
||||
// cors allowed, but the request is missing proper headers
|
||||
{"invalid cors headers", optsCORS, http.MethodOptions, badCORSHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
|
||||
{"unexpected error", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{LoadError: errors.New("ok")}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest},
|
||||
{"invalid cors headers", optsCORS, http.MethodOptions, badCORSHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
|
||||
// redirect to start auth process
|
||||
{"unknown host", opts, http.MethodGet, defaultHeaders, "https://nothttpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound},
|
||||
{"user not authorized", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
|
||||
{"authorization call failed", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeError: errors.New("error")}, http.StatusInternalServerError},
|
||||
{"unknown host", opts, http.MethodGet, defaultHeaders, "https://nothttpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound},
|
||||
{"user not authorized", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
|
||||
{"authorization call failed", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeError: errors.New("error")}, http.StatusInternalServerError},
|
||||
// authenticate errors
|
||||
{"weird load session error", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{LoadError: errors.New("weird"), Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest},
|
||||
{"failed refreshed session", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: &sessions.SessionState{RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthenticate{RefreshError: errors.New("refresh error")}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized},
|
||||
{"cannot resave refreshed session", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{SaveError: errors.New("weird"), Session: &sessions.SessionState{RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized},
|
||||
{"authenticate validation error", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: false}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized},
|
||||
{"public access", optsPublic, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusOK},
|
||||
{"public access, but unknown host", optsPublic, http.MethodGet, defaultHeaders, "https://nothttpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusNotFound},
|
||||
// no session, redirect to login
|
||||
{"no http found (no session)", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{LoadError: http.ErrNoCookie}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest},
|
||||
{"No policies", optsNoPolicies, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound},
|
||||
{"session error, redirect to authn", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{LoadError: errors.New("weird"), Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound},
|
||||
{"session expired,redirect to authn", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{LoadError: sessions.ErrExpired}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound},
|
||||
{"public access", optsPublic, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusOK},
|
||||
{"public access, but unknown host", optsPublic, http.MethodGet, defaultHeaders, "https://nothttpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusNotFound},
|
||||
{"no http found (no session),redirect to authn", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{LoadError: http.ErrNoCookie}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound},
|
||||
{"No policies", optsNoPolicies, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
@ -323,13 +316,13 @@ func TestProxy_Proxy(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
p.cipher = mockCipher{}
|
||||
p.cipher = &cryptutil.MockCipher{MarshalResponse: "foo"}
|
||||
p.sessionStore = tt.session
|
||||
p.AuthenticateClient = tt.authenticator
|
||||
p.AuthorizeClient = tt.authorizer
|
||||
|
||||
r := httptest.NewRequest(tt.method, tt.host, nil)
|
||||
r.Header = tt.header
|
||||
r.Header.Set("Accept", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
p.Proxy(w, r)
|
||||
if status := w.Code; status != tt.wantStatus {
|
||||
|
@ -348,23 +341,21 @@ func TestProxy_Proxy(t *testing.T) {
|
|||
func TestProxy_UserDashboard(t *testing.T) {
|
||||
opts := testOptions(t)
|
||||
tests := []struct {
|
||||
name string
|
||||
options config.Options
|
||||
method string
|
||||
cipher cryptutil.Cipher
|
||||
session sessions.SessionStore
|
||||
authenticator clients.Authenticator
|
||||
authorizer clients.Authorizer
|
||||
name string
|
||||
options config.Options
|
||||
method string
|
||||
cipher cryptutil.Cipher
|
||||
session sessions.SessionStore
|
||||
authorizer clients.Authorizer
|
||||
|
||||
wantAdminForm bool
|
||||
wantStatus int
|
||||
}{
|
||||
{"good", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example"}}, clients.MockAuthenticate{}, clients.MockAuthorize{}, false, http.StatusOK},
|
||||
{"cannot load session", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{LoadError: errors.New("load error")}, clients.MockAuthenticate{}, clients.MockAuthorize{}, false, http.StatusBadRequest},
|
||||
{"auth failure, validation error", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthenticate{ValidateError: errors.New("not valid anymore"), ValidateResponse: false}, clients.MockAuthorize{}, false, http.StatusUnauthorized},
|
||||
{"can't save csrf", opts, http.MethodGet, &cryptutil.MockCipher{MarshalError: errors.New("err")}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example"}}, clients.MockAuthenticate{}, clients.MockAuthorize{}, false, http.StatusInternalServerError},
|
||||
{"want admin form good admin authorization", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, true, http.StatusOK},
|
||||
{"is admin but authorization fails", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminError: errors.New("err")}, false, http.StatusInternalServerError},
|
||||
{"good", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{}, false, http.StatusOK},
|
||||
{"cannot load session", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{LoadError: errors.New("load error")}, clients.MockAuthorize{}, false, http.StatusFound},
|
||||
{"can't save csrf", opts, http.MethodGet, &cryptutil.MockCipher{MarshalError: errors.New("err")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{}, false, http.StatusInternalServerError},
|
||||
{"want admin form good admin authorization", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{IsAdminResponse: true}, true, http.StatusOK},
|
||||
{"is admin but authorization fails", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{IsAdminError: errors.New("err")}, false, http.StatusInternalServerError},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
@ -375,15 +366,18 @@ func TestProxy_UserDashboard(t *testing.T) {
|
|||
}
|
||||
p.cipher = tt.cipher
|
||||
p.sessionStore = tt.session
|
||||
p.AuthenticateClient = tt.authenticator
|
||||
p.AuthorizeClient = tt.authorizer
|
||||
|
||||
r := httptest.NewRequest(tt.method, "/", nil)
|
||||
r.Header.Set("Accept", "application/json")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
p.UserDashboard(w, r)
|
||||
if status := w.Code; status != tt.wantStatus {
|
||||
t.Errorf("status code: got %v want %v", status, tt.wantStatus)
|
||||
t.Errorf("\n%+v", opts)
|
||||
t.Errorf("\n%+v", w.Body.String())
|
||||
|
||||
}
|
||||
if adminForm := strings.Contains(w.Body.String(), "impersonate"); adminForm != tt.wantAdminForm {
|
||||
t.Errorf("wanted admin form got %v want %v", adminForm, tt.wantAdminForm)
|
||||
|
@ -393,28 +387,27 @@ func TestProxy_UserDashboard(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestProxy_Refresh(t *testing.T) {
|
||||
func TestProxy_ForceRefresh(t *testing.T) {
|
||||
opts := testOptions(t)
|
||||
opts.RefreshCooldown = 0
|
||||
timeSinceError := testOptions(t)
|
||||
timeSinceError.RefreshCooldown = time.Duration(int(^uint(0) >> 1))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
options config.Options
|
||||
method string
|
||||
cipher cryptutil.Cipher
|
||||
session sessions.SessionStore
|
||||
authenticator clients.Authenticator
|
||||
authorizer clients.Authorizer
|
||||
wantStatus int
|
||||
name string
|
||||
options config.Options
|
||||
method string
|
||||
cipher cryptutil.Cipher
|
||||
session sessions.SessionStore
|
||||
authorizer clients.Authorizer
|
||||
wantStatus int
|
||||
}{
|
||||
{"good", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthenticate{}, clients.MockAuthorize{}, http.StatusFound},
|
||||
{"cannot load session", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{LoadError: errors.New("load error")}, clients.MockAuthenticate{}, clients.MockAuthorize{}, http.StatusInternalServerError},
|
||||
{"bad id token", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: "bad"}}, clients.MockAuthenticate{}, clients.MockAuthorize{}, http.StatusInternalServerError},
|
||||
{"issue date too soon", timeSinceError, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthenticate{}, clients.MockAuthorize{}, http.StatusBadRequest},
|
||||
{"refresh failure", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthenticate{RefreshError: errors.New("err")}, clients.MockAuthorize{}, http.StatusInternalServerError},
|
||||
{"can't save refreshed session", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.SessionState{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthenticate{}, clients.MockAuthorize{}, http.StatusInternalServerError},
|
||||
{"good", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusFound},
|
||||
{"cannot load session", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{LoadError: errors.New("load error")}, clients.MockAuthorize{}, http.StatusFound},
|
||||
{"bad id token", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: "bad"}}, clients.MockAuthorize{}, http.StatusInternalServerError},
|
||||
{"issue date too soon", timeSinceError, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusBadRequest},
|
||||
{"refresh failure", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusFound},
|
||||
{"can't save refreshed session", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusFound},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -424,12 +417,11 @@ func TestProxy_Refresh(t *testing.T) {
|
|||
}
|
||||
p.cipher = tt.cipher
|
||||
p.sessionStore = tt.session
|
||||
p.AuthenticateClient = tt.authenticator
|
||||
p.AuthorizeClient = tt.authorizer
|
||||
|
||||
r := httptest.NewRequest(tt.method, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
p.Refresh(w, r)
|
||||
p.ForceRefresh(w, r)
|
||||
if status := w.Code; status != tt.wantStatus {
|
||||
t.Errorf("status code: got %v want %v", status, tt.wantStatus)
|
||||
t.Errorf("\n%+v", opts)
|
||||
|
@ -442,30 +434,29 @@ func TestProxy_Impersonate(t *testing.T) {
|
|||
opts := testOptions(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
malformed bool
|
||||
options config.Options
|
||||
method string
|
||||
email string
|
||||
groups string
|
||||
csrf string
|
||||
cipher cryptutil.Cipher
|
||||
sessionStore sessions.SessionStore
|
||||
csrfStore sessions.CSRFStore
|
||||
authenticator clients.Authenticator
|
||||
authorizer clients.Authorizer
|
||||
wantStatus int
|
||||
name string
|
||||
malformed bool
|
||||
options config.Options
|
||||
method string
|
||||
email string
|
||||
groups string
|
||||
csrf string
|
||||
cipher cryptutil.Cipher
|
||||
sessionStore sessions.SessionStore
|
||||
csrfStore sessions.CSRFStore
|
||||
authorizer clients.Authorizer
|
||||
wantStatus int
|
||||
}{
|
||||
{"good", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
|
||||
{"session load error", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{LoadError: errors.New("err"), Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
|
||||
{"non admin users rejected", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: false}, http.StatusForbidden},
|
||||
{"non admin users rejected on error", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true, IsAdminError: errors.New("err")}, http.StatusForbidden},
|
||||
{"csrf from store retrieve failure", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}, GetError: errors.New("err")}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
|
||||
{"can't decrypt csrf value", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{UnmarshalError: errors.New("err")}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
|
||||
{"decrypted csrf mismatch", false, opts, http.MethodPost, "user@blah.com", "", "CSRF!", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusBadRequest},
|
||||
{"save session failure", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
|
||||
{"malformed", true, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
|
||||
{"groups", false, opts, http.MethodPost, "user@blah.com", "group1,group2", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
|
||||
{"good", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
|
||||
{"session load error", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{LoadError: errors.New("err"), Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
|
||||
// {"non admin users rejected", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: false}, http.StatusForbidden},
|
||||
{"non admin users rejected on error", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true, IsAdminError: errors.New("err")}, http.StatusForbidden},
|
||||
{"csrf from store retrieve failure", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}, GetError: errors.New("err")}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
|
||||
{"can't decrypt csrf value", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{UnmarshalError: errors.New("err")}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
|
||||
{"decrypted csrf mismatch", false, opts, http.MethodPost, "user@blah.com", "", "CSRF!", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusBadRequest},
|
||||
{"save session failure", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
|
||||
{"malformed", true, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
|
||||
{"groups", false, opts, http.MethodPost, "user@blah.com", "group1,group2", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -476,7 +467,6 @@ func TestProxy_Impersonate(t *testing.T) {
|
|||
p.cipher = tt.cipher
|
||||
p.sessionStore = tt.sessionStore
|
||||
p.csrfStore = tt.csrfStore
|
||||
p.AuthenticateClient = tt.authenticator
|
||||
p.AuthorizeClient = tt.authorizer
|
||||
postForm := url.Values{}
|
||||
postForm.Add("email", tt.email)
|
||||
|
@ -501,19 +491,17 @@ func TestProxy_Impersonate(t *testing.T) {
|
|||
|
||||
func TestProxy_OAuthCallback(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
csrf sessions.MockCSRFStore
|
||||
session sessions.MockSessionStore
|
||||
authenticator clients.MockAuthenticate
|
||||
params map[string]string
|
||||
wantCode int
|
||||
name string
|
||||
csrf sessions.MockCSRFStore
|
||||
session sessions.MockSessionStore
|
||||
params map[string]string
|
||||
wantCode int
|
||||
}{
|
||||
{"good", sessions.MockCSRFStore{ResponseCSRF: "ok", GetError: nil, Cookie: &http.Cookie{Name: "something_csrf", Value: "csrf_state"}}, sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthenticate{RedeemResponse: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken"}}, map[string]string{"code": "code", "state": "state"}, http.StatusFound},
|
||||
{"error", sessions.MockCSRFStore{ResponseCSRF: "ok", GetError: nil, Cookie: &http.Cookie{Name: "something_csrf", Value: "csrf_state"}}, sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthenticate{RedeemResponse: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken"}}, map[string]string{"error": "some error"}, http.StatusBadRequest},
|
||||
{"state err", sessions.MockCSRFStore{ResponseCSRF: "ok", GetError: nil, Cookie: &http.Cookie{Name: "something_csrf", Value: "csrf_state"}}, sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthenticate{RedeemResponse: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken"}}, map[string]string{"code": "code", "state": "error"}, http.StatusInternalServerError},
|
||||
{"csrf err", sessions.MockCSRFStore{GetError: errors.New("error")}, sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthenticate{RedeemResponse: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken"}}, map[string]string{"code": "code", "state": "state"}, http.StatusInternalServerError},
|
||||
{"unmarshal err", sessions.MockCSRFStore{Cookie: &http.Cookie{Name: "something_csrf", Value: "unmarshal error"}}, sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthenticate{RedeemResponse: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken"}}, map[string]string{"code": "code", "state": "state"}, http.StatusInternalServerError},
|
||||
{"malformed", sessions.MockCSRFStore{ResponseCSRF: "ok", GetError: nil, Cookie: &http.Cookie{Name: "something_csrf", Value: "csrf_state"}}, sessions.MockSessionStore{Session: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthenticate{RedeemResponse: &sessions.SessionState{AccessToken: "AccessToken", RefreshToken: "RefreshToken"}}, map[string]string{"code": "code", "state": "state"}, http.StatusInternalServerError},
|
||||
{"good", sessions.MockCSRFStore{ResponseCSRF: "ok", GetError: nil, Cookie: &http.Cookie{Name: "something_csrf", Value: "csrf_state"}}, sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, map[string]string{"code": "code", "state": "state"}, http.StatusFound},
|
||||
{"state err", sessions.MockCSRFStore{ResponseCSRF: "ok", GetError: nil, Cookie: &http.Cookie{Name: "something_csrf", Value: "csrf_state"}}, sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, map[string]string{"code": "code", "state": "error"}, http.StatusInternalServerError},
|
||||
{"csrf err", sessions.MockCSRFStore{GetError: errors.New("error")}, sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, map[string]string{"code": "code", "state": "state"}, http.StatusInternalServerError},
|
||||
{"unmarshal err", sessions.MockCSRFStore{Cookie: &http.Cookie{Name: "something_csrf", Value: "unmarshal error"}}, sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, map[string]string{"code": "code", "state": "state"}, http.StatusInternalServerError},
|
||||
{"malformed", sessions.MockCSRFStore{ResponseCSRF: "ok", GetError: nil, Cookie: &http.Cookie{Name: "something_csrf", Value: "csrf_state"}}, sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, map[string]string{"code": "code", "state": "state"}, http.StatusInternalServerError},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
@ -524,7 +512,6 @@ func TestProxy_OAuthCallback(t *testing.T) {
|
|||
}
|
||||
proxy.sessionStore = &tt.session
|
||||
proxy.csrfStore = tt.csrf
|
||||
proxy.AuthenticateClient = tt.authenticator
|
||||
proxy.cipher = mockCipher{}
|
||||
// proxy.Csrf
|
||||
req := httptest.NewRequest(http.MethodPost, "/.pomerium/callback", nil)
|
||||
|
@ -537,7 +524,7 @@ func TestProxy_OAuthCallback(t *testing.T) {
|
|||
req.URL.RawQuery = "email=%zzzzz"
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
proxy.OAuthCallback(w, req)
|
||||
proxy.AuthenticateCallback(w, req)
|
||||
if status := w.Code; status != tt.wantCode {
|
||||
t.Errorf("handler returned wrong status code: got %v want %v", status, tt.wantCode)
|
||||
}
|
||||
|
|
111
proxy/proxy.go
111
proxy/proxy.go
|
@ -2,11 +2,9 @@ package proxy // import "github.com/pomerium/pomerium/proxy"
|
|||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"html/template"
|
||||
stdlog "log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
|
@ -39,51 +37,27 @@ const (
|
|||
// ValidateOptions checks that proper configuration settings are set to create
|
||||
// a proper Proxy instance
|
||||
func ValidateOptions(o config.Options) error {
|
||||
decoded, err := base64.StdEncoding.DecodeString(o.SharedKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("`SHARED_SECRET` setting is invalid base64: %v", err)
|
||||
if _, err := cryptutil.NewCipherFromBase64(o.SharedKey); err != nil {
|
||||
return fmt.Errorf("proxy: invalid 'SHARED_SECRET': %v", err)
|
||||
}
|
||||
if len(decoded) != 32 {
|
||||
return fmt.Errorf("`SHARED_SECRET` want 32 but got %d bytes", len(decoded))
|
||||
if _, err := cryptutil.NewCipherFromBase64(o.CookieSecret); err != nil {
|
||||
return fmt.Errorf("proxy: invalid 'COOKIE_SECRET': %v", err)
|
||||
}
|
||||
|
||||
if o.AuthenticateURL == nil {
|
||||
return fmt.Errorf("proxy: missing setting: authenticate-service-url")
|
||||
return fmt.Errorf("proxy: missing 'AUTHENTICATE_SERVICE_URL'")
|
||||
}
|
||||
if _, err := urlutil.ParseAndValidateURL(o.AuthenticateURL.String()); err != nil {
|
||||
return fmt.Errorf("proxy: error parsing authenticate url: %v", err)
|
||||
return fmt.Errorf("proxy: invalid 'AUTHENTICATE_SERVICE_URL': %v", err)
|
||||
}
|
||||
|
||||
if o.AuthorizeURL == nil {
|
||||
return fmt.Errorf("proxy: missing setting: authenticate-service-url")
|
||||
return fmt.Errorf("proxy: missing 'AUTHORIZE_SERVICE_URL'")
|
||||
}
|
||||
if _, err := urlutil.ParseAndValidateURL(o.AuthorizeURL.String()); err != nil {
|
||||
return fmt.Errorf("proxy: error parsing authorize url: %v", err)
|
||||
}
|
||||
if o.AuthenticateInternalAddr != nil {
|
||||
if _, err := urlutil.ParseAndValidateURL(o.AuthenticateInternalAddr.String()); err != nil {
|
||||
return fmt.Errorf("proxy: error parsing authorize url: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if o.CookieSecret == "" {
|
||||
return fmt.Errorf("proxy: missing setting: cookie-secret")
|
||||
}
|
||||
decodedCookieSecret, err := base64.StdEncoding.DecodeString(o.CookieSecret)
|
||||
if err != nil {
|
||||
return fmt.Errorf("proxy: cookie secret is invalid base64: %v", err)
|
||||
}
|
||||
if len(decodedCookieSecret) != 32 {
|
||||
return fmt.Errorf("proxy: cookie secret expects 32 bytes but got %d", len(decodedCookieSecret))
|
||||
return fmt.Errorf("proxy: invalid 'AUTHORIZE_SERVICE_URL': %v", err)
|
||||
}
|
||||
if len(o.SigningKey) != 0 {
|
||||
decodedSigningKey, err := base64.StdEncoding.DecodeString(o.SigningKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("proxy: signing key is invalid base64: %v", err)
|
||||
}
|
||||
_, err = cryptutil.NewES256Signer(decodedSigningKey, "localhost")
|
||||
if err != nil {
|
||||
return fmt.Errorf("proxy: invalid signing key is : %v", err)
|
||||
if _, err := cryptutil.NewES256Signer(o.SigningKey, "localhost"); err != nil {
|
||||
return fmt.Errorf("proxy: invalid 'SIGNING_KEY': %v", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
@ -92,12 +66,11 @@ func ValidateOptions(o config.Options) error {
|
|||
// Proxy stores all the information associated with proxying a request.
|
||||
type Proxy struct {
|
||||
// SharedKey used to mutually authenticate service communication
|
||||
SharedKey string
|
||||
authenticateURL *url.URL
|
||||
authenticateInternalAddr *url.URL
|
||||
authorizeURL *url.URL
|
||||
AuthenticateClient clients.Authenticator
|
||||
AuthorizeClient clients.Authorizer
|
||||
SharedKey string
|
||||
authenticateURL *url.URL
|
||||
authorizeURL *url.URL
|
||||
|
||||
AuthorizeClient clients.Authorizer
|
||||
|
||||
cipher cryptutil.Cipher
|
||||
cookieName string
|
||||
|
@ -105,7 +78,6 @@ type Proxy struct {
|
|||
defaultUpstreamTimeout time.Duration
|
||||
redirectURL *url.URL
|
||||
refreshCooldown time.Duration
|
||||
restStore sessions.SessionStore
|
||||
routeConfigs map[string]*routeConfig
|
||||
sessionStore sessions.SessionStore
|
||||
signingKey string
|
||||
|
@ -123,11 +95,9 @@ func New(opts config.Options) (*Proxy, error) {
|
|||
if err := ValidateOptions(opts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// error explicitly handled by validate
|
||||
decodedSecret, _ := base64.StdEncoding.DecodeString(opts.CookieSecret)
|
||||
cipher, err := cryptutil.NewCipher(decodedSecret)
|
||||
cipher, err := cryptutil.NewCipherFromBase64(opts.CookieSecret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cookie-secret error: %s", err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cookieStore, err := sessions.NewCookieStore(
|
||||
|
@ -140,10 +110,6 @@ func New(opts config.Options) (*Proxy, error) {
|
|||
CookieCipher: cipher,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
restStore, err := sessions.NewRestStore(&sessions.RestStoreOptions{Cipher: cipher})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -158,7 +124,6 @@ func New(opts config.Options) (*Proxy, error) {
|
|||
defaultUpstreamTimeout: opts.DefaultUpstreamTimeout,
|
||||
redirectURL: &url.URL{Path: "/.pomerium/callback"},
|
||||
refreshCooldown: opts.RefreshCooldown,
|
||||
restStore: restStore,
|
||||
sessionStore: cookieStore,
|
||||
signingKey: opts.SigningKey,
|
||||
templates: templates.New(),
|
||||
|
@ -166,7 +131,6 @@ func New(opts config.Options) (*Proxy, error) {
|
|||
// DeepCopy urls to avoid accidental mutation, err checked in validate func
|
||||
p.authenticateURL, _ = urlutil.DeepCopy(opts.AuthenticateURL)
|
||||
p.authorizeURL, _ = urlutil.DeepCopy(opts.AuthorizeURL)
|
||||
p.authenticateInternalAddr, _ = urlutil.DeepCopy(opts.AuthenticateInternalAddr)
|
||||
|
||||
if err := p.UpdatePolicies(&opts); err != nil {
|
||||
return nil, err
|
||||
|
@ -174,20 +138,6 @@ func New(opts config.Options) (*Proxy, error) {
|
|||
metrics.AddPolicyCountCallback("proxy", func() int64 {
|
||||
return int64(len(p.routeConfigs))
|
||||
})
|
||||
p.AuthenticateClient, err = clients.NewAuthenticateClient("grpc",
|
||||
&clients.Options{
|
||||
Addr: p.authenticateURL,
|
||||
InternalAddr: p.authenticateInternalAddr,
|
||||
OverrideCertificateName: opts.OverrideCertificateName,
|
||||
SharedSecret: opts.SharedKey,
|
||||
CA: opts.CA,
|
||||
CAFile: opts.CAFile,
|
||||
RequestTimeout: opts.GRPCClientTimeout,
|
||||
ClientDNSRoundRobin: opts.GRPCClientDNSRoundRobin,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.AuthorizeClient, err = clients.NewAuthorizeClient("grpc",
|
||||
&clients.Options{
|
||||
Addr: p.authorizeURL,
|
||||
|
@ -213,19 +163,7 @@ func (p *Proxy) UpdatePolicies(opts *config.Options) error {
|
|||
}
|
||||
proxy := NewReverseProxy(policy.Destination)
|
||||
// build http transport (roundtripper) middleware chain
|
||||
// todo(bdd): replace with transport.Clone() in go 1.13
|
||||
transport := http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
DualStack: true,
|
||||
}).DialContext,
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
c := tripper.NewChain()
|
||||
c = c.Append(metrics.HTTPMetricsRoundTripper("proxy", policy.Destination.Host))
|
||||
|
||||
|
@ -253,7 +191,7 @@ func (p *Proxy) UpdatePolicies(opts *config.Options) error {
|
|||
if isCustomClientConfig {
|
||||
transport.TLSClientConfig = &tlsClientConfig
|
||||
}
|
||||
proxy.Transport = c.Then(&transport)
|
||||
proxy.Transport = c.Then(transport)
|
||||
|
||||
handler, err := p.newReverseProxyHandler(proxy, &policy)
|
||||
if err != nil {
|
||||
|
@ -298,15 +236,6 @@ func NewReverseProxy(to *url.URL) *httputil.ReverseProxy {
|
|||
return proxy
|
||||
}
|
||||
|
||||
// newRouteSigner creates a route specific signer.
|
||||
func (p *Proxy) newRouteSigner(audience string) (cryptutil.JWTSigner, error) {
|
||||
decodedSigningKey, err := base64.StdEncoding.DecodeString(p.signingKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cryptutil.NewES256Signer(decodedSigningKey, audience)
|
||||
}
|
||||
|
||||
// newReverseProxyHandler applies handler specific options to a given route.
|
||||
func (p *Proxy) newReverseProxyHandler(rp *httputil.ReverseProxy, route *config.Policy) (handler http.Handler, err error) {
|
||||
handler = &UpstreamProxy{
|
||||
|
@ -318,7 +247,7 @@ func (p *Proxy) newReverseProxyHandler(rp *httputil.ReverseProxy, route *config.
|
|||
|
||||
// if signing key is set, add signer to middleware
|
||||
if len(p.signingKey) != 0 {
|
||||
signer, err := p.newRouteSigner(route.Source.Host)
|
||||
signer, err := cryptutil.NewES256Signer(p.signingKey, route.Source.Host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -169,9 +169,6 @@ func TestOptions_Validate(t *testing.T) {
|
|||
authurl, _ := url.Parse("authenticate.corp.beyondperimeter.com")
|
||||
authenticateBadScheme := testOptions(t)
|
||||
authenticateBadScheme.AuthenticateURL = authurl
|
||||
authenticateInternalBadScheme := testOptions(t)
|
||||
authenticateInternalBadScheme.AuthenticateInternalAddr = authurl
|
||||
|
||||
authorizeBadSCheme := testOptions(t)
|
||||
authorizeBadSCheme.AuthorizeURL = authurl
|
||||
authorizeNil := testOptions(t)
|
||||
|
@ -200,7 +197,6 @@ func TestOptions_Validate(t *testing.T) {
|
|||
{"nil options", config.Options{}, true},
|
||||
{"authenticate service url", badAuthURL, true},
|
||||
{"authenticate service url no scheme", authenticateBadScheme, true},
|
||||
{"internal authenticate service url no scheme", authenticateInternalBadScheme, true},
|
||||
{"authorize service url no scheme", authorizeBadSCheme, true},
|
||||
{"authorize service cannot be nil", authorizeNil, true},
|
||||
{"no cookie secret", emptyCookieSecret, true},
|
||||
|
@ -221,7 +217,6 @@ func TestOptions_Validate(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
|
||||
good := testOptions(t)
|
||||
shortCookieLength := testOptions(t)
|
||||
shortCookieLength.CookieSecret = "gN3xnvfsAwfCXxnJorGLKUG4l2wC8sS8nfLMhcStPg=="
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue