proxy: add unit tests (#43)

This commit is contained in:
Bobby DeSimone 2019-02-11 20:15:01 -08:00 committed by GitHub
parent cedf9922d3
commit 4f4f3965aa
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 577 additions and 323 deletions

View file

@ -74,7 +74,7 @@ func main() {
if err != nil {
log.Fatal().Err(err).Msg("cmd/pomerium: new proxy")
}
defer proxyService.AuthenticateConn.Close()
defer proxyService.AuthenticateClient.Close()
}
topMux := http.NewServeMux()

View file

@ -12,17 +12,17 @@ type MockCSRFStore struct {
}
// SetCSRF sets the ResponseCSRF string to a val
func (ms *MockCSRFStore) SetCSRF(rw http.ResponseWriter, req *http.Request, val string) {
func (ms MockCSRFStore) SetCSRF(rw http.ResponseWriter, req *http.Request, val string) {
ms.ResponseCSRF = val
}
// ClearCSRF clears the ResponseCSRF string
func (ms *MockCSRFStore) ClearCSRF(http.ResponseWriter, *http.Request) {
func (ms MockCSRFStore) ClearCSRF(http.ResponseWriter, *http.Request) {
ms.ResponseCSRF = ""
}
// GetCSRF returns the cookie and error
func (ms *MockCSRFStore) GetCSRF(*http.Request) (*http.Cookie, error) {
func (ms MockCSRFStore) GetCSRF(*http.Request) (*http.Cookie, error) {
return ms.Cookie, ms.GetError
}
@ -35,16 +35,16 @@ type MockSessionStore struct {
}
// ClearSession clears the ResponseSession
func (ms *MockSessionStore) ClearSession(http.ResponseWriter, *http.Request) {
func (ms MockSessionStore) ClearSession(http.ResponseWriter, *http.Request) {
ms.ResponseSession = ""
}
// LoadSession returns the session and a error
func (ms *MockSessionStore) LoadSession(*http.Request) (*SessionState, error) {
func (ms MockSessionStore) LoadSession(*http.Request) (*SessionState, error) {
return ms.Session, ms.LoadError
}
// SaveSession returns a save error.
func (ms *MockSessionStore) SaveSession(http.ResponseWriter, *http.Request, *SessionState) error {
func (ms MockSessionStore) SaveSession(http.ResponseWriter, *http.Request, *SessionState) error {
return ms.SaveError
}

View file

@ -0,0 +1,69 @@
package authenticator // import "github.com/pomerium/pomerium/proxy/authenticator"
import (
"crypto/tls"
"crypto/x509"
"fmt"
"net/url"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/middleware"
pb "github.com/pomerium/pomerium/proto/authenticate"
)
// Authenticator provides the authenticate service interface
type Authenticator interface {
// Redeem takes a code and returns a validated session or an error
Redeem(string) (*RedeemResponse, error)
// Refresh attempts to refresh a valid session with a refresh token. Returns a new access token
// and expiration, or an error.
Refresh(string) (string, time.Time, error)
// Validate evaluates a given oidc id_token for validity. Returns validity and any error.
Validate(string) (bool, error)
// Close closes the authenticator connection if any.
Close() error
}
// New returns a new identity provider based given its name.
// Returns an error if selected provided not found or if the identity provider is not known.
func New(uri *url.URL, internalURL, OverideCertificateName, key string) (p Authenticator, err error) {
// if no port given, assume https/443
port := uri.Port()
if port == "" {
port = "443"
}
authEndpoint := fmt.Sprintf("%s:%s", uri.Host, port)
cp, err := x509.SystemCertPool()
if err != nil {
return nil, err
}
if internalURL != "" {
authEndpoint = internalURL
}
log.Info().Str("authEndpoint", authEndpoint).Msgf("proxy.New: grpc authenticate connection")
cert := credentials.NewTLS(&tls.Config{RootCAs: cp})
if OverideCertificateName != "" {
err = cert.OverrideServerName(OverideCertificateName)
if err != nil {
return nil, err
}
}
grpcAuth := middleware.NewSharedSecretCred(key)
conn, err := grpc.Dial(
authEndpoint,
grpc.WithTransportCredentials(cert),
grpc.WithPerRPCCredentials(grpcAuth),
)
if err != nil {
return nil, err
}
authClient := pb.NewAuthenticatorClient(conn)
return &AuthenticateGRPC{conn: conn, client: authClient}, nil
}

View file

@ -0,0 +1,36 @@
package authenticator
import (
"net/url"
"reflect"
"testing"
)
func TestNew(t *testing.T) {
type args struct {
uri *url.URL
internalURL string
OverideCertificateName string
key string
}
tests := []struct {
name string
args args
wantP Authenticator
wantErr bool
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotP, err := New(tt.args.uri, tt.args.internalURL, tt.args.OverideCertificateName, tt.args.key)
if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(gotP, tt.wantP) {
t.Errorf("New() = %v, want %v", gotP, tt.wantP)
}
})
}
}

View file

@ -0,0 +1,98 @@
package authenticator // import "github.com/pomerium/pomerium/proxy/authenticator"
import (
"context"
"errors"
"time"
"github.com/golang/protobuf/ptypes"
"google.golang.org/grpc"
pb "github.com/pomerium/pomerium/proto/authenticate"
)
// RedeemResponse contains data from a authenticator redeem request.
type RedeemResponse struct {
AccessToken string
RefreshToken string
IDToken string
User string
Email string
Expiry time.Time
}
// AuthenticateGRPC is a gRPC implementation of an authenticator (authenticate client)
type AuthenticateGRPC struct {
conn *grpc.ClientConn
client pb.AuthenticatorClient
}
// Redeem makes an RPC call to the authenticate service to creates a session state
// from an encrypted code provided as a result of an oauth2 callback process.
func (a *AuthenticateGRPC) Redeem(code string) (*RedeemResponse, error) {
if code == "" {
return nil, errors.New("missing code")
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
r, err := a.client.Authenticate(ctx, &pb.AuthenticateRequest{Code: code})
if err != nil {
return nil, err
}
expiry, err := ptypes.Timestamp(r.Expiry)
if err != nil {
return nil, err
}
return &RedeemResponse{
AccessToken: r.AccessToken,
RefreshToken: r.RefreshToken,
IDToken: r.IdToken,
User: r.User,
Email: r.Email,
Expiry: expiry,
// RefreshDeadline: (expiry).Truncate(time.Second),
// LifetimeDeadline: extendDeadline(p.CookieLifetimeTTL),
// ValidDeadline: extendDeadline(p.CookieExpire),
}, nil
}
// Refresh makes an RPC call to the authenticate service to attempt to refresh the
// user's session. Requires a valid refresh token. Will return an error if the identity provider
// has revoked the session or if the refresh token is no longer valid in this context.
func (a *AuthenticateGRPC) Refresh(refreshToken string) (string, time.Time, error) {
if refreshToken == "" {
return "", time.Time{}, errors.New("missing refresh token")
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
r, err := a.client.Refresh(ctx, &pb.RefreshRequest{RefreshToken: refreshToken})
if err != nil {
return "", time.Time{}, err
}
expiry, err := ptypes.Timestamp(r.Expiry)
if err != nil {
return "", time.Time{}, err
}
return r.AccessToken, expiry, nil
}
// Validate makes an RPC call to the authenticate service to validate the JWT id token;
// does NOT do nonce or revokation validation.
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
func (a *AuthenticateGRPC) Validate(idToken string) (bool, error) {
if idToken == "" {
return false, errors.New("missing id token")
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
r, err := a.client.Validate(ctx, &pb.ValidateRequest{IdToken: idToken})
if err != nil {
return false, err
}
return r.IsValid, nil
}
// Close tears down the ClientConn and all underlying connections.
func (a *AuthenticateGRPC) Close() error {
return a.conn.Close()
}

View file

@ -1,4 +1,4 @@
package proxy
package authenticator // import "github.com/pomerium/pomerium/proxy/authenticator"
import (
"fmt"
@ -6,8 +6,6 @@ import (
"testing"
"time"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/golang/mock/gomock"
"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/ptypes"
@ -34,7 +32,7 @@ func (r *rpcMsg) String() string {
return fmt.Sprintf("is %s", r.msg)
}
func TestProxy_AuthenticateRedeem(t *testing.T) {
func TestProxy_Redeem(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockAuthenticateClient := mock.NewMockAuthenticatorClient(ctrl)
@ -55,29 +53,26 @@ func TestProxy_AuthenticateRedeem(t *testing.T) {
Email: "test@email.com",
Expiry: mockExpire,
}, nil)
p := &Proxy{AuthenticatorClient: mockAuthenticateClient}
tests := []struct {
name string
idToken string
want *sessions.SessionState
want *RedeemResponse
wantErr bool
}{
{"good", "unit_test", &sessions.SessionState{
{"good", "unit_test", &RedeemResponse{
AccessToken: "mocked access token",
RefreshToken: "mocked refresh token",
IDToken: "mocked id token",
User: "user1",
Email: "test@email.com",
RefreshDeadline: (fixedDate).Truncate(time.Second),
LifetimeDeadline: extendDeadline(p.CookieLifetimeTTL),
ValidDeadline: extendDeadline(p.CookieExpire),
Expiry: (fixedDate),
}, false},
{"empty code", "", nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := p.AuthenticateRedeem(tt.idToken)
a := AuthenticateGRPC{client: mockAuthenticateClient}
got, err := a.Redeem(tt.idToken)
if (err != nil) != tt.wantErr {
t.Errorf("Proxy.AuthenticateValidate() error = %v,\n wantErr %v", err, tt.wantErr)
return
@ -113,7 +108,7 @@ func TestProxy_AuthenticateValidate(t *testing.T) {
&rpcMsg{msg: req},
).Return(&pb.ValidateReply{IsValid: false}, nil)
p := &Proxy{AuthenticatorClient: mockAuthenticateClient}
ac := mockAuthenticateClient
tests := []struct {
name string
idToken string
@ -125,8 +120,9 @@ func TestProxy_AuthenticateValidate(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := AuthenticateGRPC{client: ac}
got, err := p.AuthenticateValidate(tt.idToken)
got, err := a.Validate(tt.idToken)
if (err != nil) != tt.wantErr {
t.Errorf("Proxy.AuthenticateValidate() error = %v, wantErr %v", err, tt.wantErr)
return
@ -167,9 +163,9 @@ func TestProxy_AuthenticateRefresh(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &Proxy{AuthenticatorClient: mockRefreshClient}
a := AuthenticateGRPC{client: mockRefreshClient}
got, gotExp, err := p.AuthenticateRefresh(tt.refreshToken)
got, gotExp, err := a.Refresh(tt.refreshToken)
if (err != nil) != tt.wantErr {
t.Errorf("Proxy.AuthenticateRefresh() error = %v, wantErr %v", err, tt.wantErr)
return
@ -183,22 +179,3 @@ func TestProxy_AuthenticateRefresh(t *testing.T) {
})
}
}
func Test_extendDeadline(t *testing.T) {
tests := []struct {
name string
ttl time.Duration
want time.Time
}{
{"good", time.Second, time.Now().Add(time.Second).Truncate(time.Second)},
{"test nanoseconds truncated", 500 * time.Nanosecond, time.Now().Truncate(time.Second)},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := extendDeadline(tt.ttl); !reflect.DeepEqual(got, tt.want) {
t.Errorf("extendDeadline() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -0,0 +1,35 @@
package authenticator // import "github.com/pomerium/pomerium/proxy/authenticator"
import (
"time"
)
// MockAuthenticate is a mock authenticator interface
type MockAuthenticate struct {
RedeemError error
RedeemResponse *RedeemResponse
RefreshResponse string
RefreshTime time.Time
RefreshError error
ValidateResponse bool
ValidateError error
CloseError error
}
// Redeem is a mocked implementation for authenticator testing.
func (a MockAuthenticate) Redeem(code string) (*RedeemResponse, error) {
return a.RedeemResponse, a.RedeemError
}
// Refresh is a mocked implementation for authenticator testing.
func (a MockAuthenticate) Refresh(refreshToken string) (string, time.Time, error) {
return a.RefreshResponse, a.RefreshTime, a.RefreshError
}
// Validate is a mocked implementation for authenticator testing.
func (a MockAuthenticate) Validate(idToken string) (bool, error) {
return a.ValidateResponse, a.ValidateError
}
// Close is a mocked implementation for authenticator testing.
func (a MockAuthenticate) Close() error { return a.ValidateError }

View file

@ -1,80 +0,0 @@
package proxy // import "github.com/pomerium/pomerium/proxy"
import (
"context"
"errors"
"time"
"github.com/golang/protobuf/ptypes"
"github.com/pomerium/pomerium/internal/sessions"
pb "github.com/pomerium/pomerium/proto/authenticate"
)
// AuthenticateRedeem makes an RPC call to the authenticate service to creates a session state
// from an encrypted code provided as a result of an oauth2 callback process.
func (p *Proxy) AuthenticateRedeem(code string) (*sessions.SessionState, error) {
if code == "" {
return nil, errors.New("missing code")
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
r, err := p.AuthenticatorClient.Authenticate(ctx, &pb.AuthenticateRequest{Code: code})
if err != nil {
return nil, err
}
expiry, err := ptypes.Timestamp(r.Expiry)
if err != nil {
return nil, err
}
return &sessions.SessionState{
AccessToken: r.AccessToken,
RefreshToken: r.RefreshToken,
IDToken: r.IdToken,
User: r.User,
Email: r.Email,
RefreshDeadline: (expiry).Truncate(time.Second),
LifetimeDeadline: extendDeadline(p.CookieLifetimeTTL),
ValidDeadline: extendDeadline(p.CookieExpire),
}, nil
}
// AuthenticateRefresh makes an RPC call to the authenticate service to attempt to refresh the
// user's session. Requires a valid refresh token. Will return an error if the identity provider
// has revoked the session or if the refresh token is no longer valid in this context.
func (p *Proxy) AuthenticateRefresh(refreshToken string) (string, time.Time, error) {
if refreshToken == "" {
return "", time.Time{}, errors.New("missing refresh token")
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
r, err := p.AuthenticatorClient.Refresh(ctx, &pb.RefreshRequest{RefreshToken: refreshToken})
if err != nil {
return "", time.Time{}, err
}
expiry, err := ptypes.Timestamp(r.Expiry)
if err != nil {
return "", time.Time{}, err
}
return r.AccessToken, expiry, nil
}
// AuthenticateValidate makes an RPC call to the authenticate service to validate the JWT id token;
// does NOT do nonce or revokation validation.
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
func (p *Proxy) AuthenticateValidate(idToken string) (bool, error) {
if idToken == "" {
return false, errors.New("missing id token")
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
r, err := p.AuthenticatorClient.Validate(ctx, &pb.ValidateRequest{IdToken: idToken})
if err != nil {
return false, err
}
return r.IsValid, nil
}
func extendDeadline(ttl time.Duration) time.Time {
return time.Now().Add(ttl).Truncate(time.Second)
}

View file

@ -160,7 +160,7 @@ func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
return
}
// We begin the process of redeeming the code for an access token.
session, err := p.AuthenticateRedeem(r.Form.Get("code"))
rr, err := p.AuthenticateClient.Redeem(r.Form.Get("code"))
if err != nil {
log.FromRequest(r).Error().Err(err).Msg("error redeeming authorization code")
httputil.ErrorResponse(w, r, "Internal error", http.StatusInternalServerError)
@ -168,6 +168,10 @@ func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
}
encryptedState := r.Form.Get("state")
log.Warn().
Str("encryptedState", encryptedState).
Msg("OK")
stateParameter := &StateParameter{}
err = p.cipher.Unmarshal(encryptedState, stateParameter)
if err != nil {
@ -192,13 +196,11 @@ func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
httputil.ErrorResponse(w, r, "Internal error", http.StatusInternalServerError)
return
}
if encryptedState == encryptedCSRF {
log.FromRequest(r).Error().Msg("encrypted state and CSRF should not be equal")
httputil.ErrorResponse(w, r, "Bad request", http.StatusBadRequest)
return
}
if !reflect.DeepEqual(stateParameter, csrfParameter) {
log.FromRequest(r).Error().Msg("state and CSRF should be equal")
httputil.ErrorResponse(w, r, "Bad request", http.StatusBadRequest)
@ -206,7 +208,16 @@ func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
}
// We store the session in a cookie and redirect the user back to the application
err = p.sessionStore.SaveSession(w, r, session)
err = p.sessionStore.SaveSession(w, r, &sessions.SessionState{
AccessToken: rr.AccessToken,
RefreshToken: rr.RefreshToken,
IDToken: rr.IDToken,
User: rr.User,
Email: rr.Email,
RefreshDeadline: (rr.Expiry).Truncate(time.Second),
LifetimeDeadline: extendDeadline(p.CookieLifetimeTTL),
ValidDeadline: extendDeadline(p.CookieExpire),
})
if err != nil {
log.FromRequest(r).Error().Msg("error saving session")
httputil.ErrorResponse(w, r, "Error saving session", http.StatusInternalServerError)
@ -216,8 +227,8 @@ func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
log.FromRequest(r).Info().
Str("code", r.Form.Get("code")).
Str("state", r.Form.Get("state")).
Str("RefreshToken", session.RefreshToken).
Str("session", session.AccessToken).
Str("RefreshToken", rr.RefreshToken).
Str("session", rr.AccessToken).
Str("RedirectURI", stateParameter.RedirectURI).
Msg("session")
@ -242,10 +253,6 @@ func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
// OAuthStart. If successful, we proceed to proxy to the configured upstream.
if err != nil {
switch err {
case ErrUserNotAuthorized:
log.FromRequest(r).Debug().Err(err).Msg("proxy: user access forbidden")
httputil.ErrorResponse(w, r, "You don't have access", http.StatusForbidden)
return
case http.ErrNoCookie, sessions.ErrLifetimeExpired, sessions.ErrInvalidSession:
log.FromRequest(r).Debug().Err(err).Msg("proxy: starting auth flow")
p.OAuthStart(w, r)
@ -256,8 +263,9 @@ func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
return
}
}
// ! ! !
// todo(bdd): ! Authorization checks will go here !
// todo(bdd): ! Authorization service goes here !
// ! ! !
// We have validated the users request and now proxy their request to the provided upstream.
@ -293,7 +301,7 @@ func (p *Proxy) Authenticate(w http.ResponseWriter, r *http.Request) (err error)
// AccessToken's usually expire after 60 or so minutes. If offline_access scope is set, a
// refresh token (which doesn't change) can be used to request a new access-token. If access
// is revoked by identity provider, or no refresh token is set request will return an error
accessToken, expiry, err := p.AuthenticateRefresh(session.RefreshToken)
accessToken, expiry, err := p.AuthenticateClient.Refresh(session.RefreshToken)
if err != nil {
log.FromRequest(r).Warn().
Str("RefreshToken", session.RefreshToken).
@ -377,3 +385,7 @@ func (p *Proxy) GetSignOutURL(authenticateURL, redirectURL *url.URL) *url.URL {
a.RawQuery = params.Encode()
return a
}
func extendDeadline(ttl time.Duration) time.Time {
return time.Now().Add(ttl).Truncate(time.Second)
}

View file

@ -1,6 +1,7 @@
package proxy
import (
"errors"
"fmt"
"net/http"
"net/http/httptest"
@ -10,9 +11,34 @@ import (
"testing"
"time"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/version"
"github.com/pomerium/pomerium/proxy/authenticator"
)
type mockCipher struct{}
func (a mockCipher) Encrypt(s []byte) ([]byte, error) {
if string(s) == "error" {
return []byte(""), errors.New("error encrypting")
}
return []byte("OK"), nil
}
func (a mockCipher) Decrypt(s []byte) ([]byte, error) {
if string(s) == "error" {
return []byte(""), errors.New("error encrypting")
}
return []byte("OK"), nil
}
func (a mockCipher) Marshal(s interface{}) (string, error) { return "ok", nil }
func (a mockCipher) Unmarshal(s string, i interface{}) error {
if string(s) == "unmarshal error" || string(s) == "error" {
return errors.New("error")
}
return nil
}
func TestProxy_RobotsTxt(t *testing.T) {
proxy := Proxy{}
req := httptest.NewRequest("GET", "/robots.txt", nil)
@ -49,8 +75,6 @@ func TestProxy_GetRedirectURL(t *testing.T) {
}
func TestProxy_signRedirectURL(t *testing.T) {
fixedDate := time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC)
tests := []struct {
name string
rawRedirect string
@ -194,135 +218,71 @@ func TestProxy_Handler(t *testing.T) {
}
}
// func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
// err := r.ParseForm()
// if err != nil {
// log.FromRequest(r).Error().Err(err).Msg("failed parsing request form")
// httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError)
// return
// }
// errorString := r.Form.Get("error")
// if errorString != "" {
// httputil.ErrorResponse(w, r, errorString, http.StatusForbidden)
// return
// }
// // We begin the process of redeeming the code for an access token.
// session, err := p.AuthenticateRedeem(r.Form.Get("code"))
// if err != nil {
// log.FromRequest(r).Error().Err(err).Msg("error redeeming authorization code")
// httputil.ErrorResponse(w, r, "Internal error", http.StatusInternalServerError)
// return
// }
// encryptedState := r.Form.Get("state")
// stateParameter := &StateParameter{}
// err = p.cipher.Unmarshal(encryptedState, stateParameter)
// if err != nil {
// log.FromRequest(r).Error().Err(err).Msg("could not unmarshal state")
// httputil.ErrorResponse(w, r, "Internal error", http.StatusInternalServerError)
// return
// }
// c, err := p.csrfStore.GetCSRF(r)
// if err != nil {
// log.FromRequest(r).Error().Err(err).Msg("failed parsing csrf cookie")
// httputil.ErrorResponse(w, r, err.Error(), http.StatusBadRequest)
// return
// }
// p.csrfStore.ClearCSRF(w, r)
// encryptedCSRF := c.Value
// csrfParameter := &StateParameter{}
// err = p.cipher.Unmarshal(encryptedCSRF, csrfParameter)
// if err != nil {
// log.FromRequest(r).Error().Err(err).Msg("couldn't unmarshal CSRF")
// httputil.ErrorResponse(w, r, "Internal error", http.StatusInternalServerError)
// return
// }
// if encryptedState == encryptedCSRF {
// log.FromRequest(r).Error().Msg("encrypted state and CSRF should not be equal")
// httputil.ErrorResponse(w, r, "Bad request", http.StatusBadRequest)
// return
// }
// if !reflect.DeepEqual(stateParameter, csrfParameter) {
// log.FromRequest(r).Error().Msg("state and CSRF should be equal")
// httputil.ErrorResponse(w, r, "Bad request", http.StatusBadRequest)
// return
// }
// // We store the session in a cookie and redirect the user back to the application
// err = p.sessionStore.SaveSession(w, r, session)
// if err != nil {
// log.FromRequest(r).Error().Msg("error saving session")
// httputil.ErrorResponse(w, r, "Error saving session", http.StatusInternalServerError)
// return
// }
// log.FromRequest(r).Info().
// Str("code", r.Form.Get("code")).
// Str("state", r.Form.Get("state")).
// Str("RefreshToken", session.RefreshToken).
// Str("session", session.AccessToken).
// Str("RedirectURI", stateParameter.RedirectURI).
// Msg("session")
// // This is the redirect back to the original requested application
// http.Redirect(w, r, stateParameter.RedirectURI, http.StatusFound)
// }
// func TestProxy_OAuthCallback2(t *testing.T) {
// proxy, err := New(testOptions())
// if err != nil {
// t.Fatal(err)
// }
// testError := url.Values{"error": []string{"There was a bad error to handle"}}
// req := httptest.NewRequest("GET", "/oauth-callback", strings.NewReader(testError.Encode()))
// if err != nil {
// t.Fatal(err)
// }
// rr := httptest.NewRecorder()
// proxy.OAuthCallback)
// // expect oauth redirect
// if status := rr.Code; status != http.StatusInternalServerError {
// t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusInternalServerError)
// }
// // expected url
// // expected := `<a href="https://sso-auth.corp.beyondperimeter.com/sign_in`
// // body := rr.Body.String()
// // if !strings.HasPrefix(body, expected) {
// // t.Errorf("handler returned unexpected body: got %v want %v", body, expected)
// // }
// }
func TestProxy_OAuthCallback(t *testing.T) {
//todo(bdd): test malformed requests
// https://github.com/golang/go/blob/master/src/net/http/request_test.go#L110
normalSession := sessions.MockSessionStore{
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(-10 * time.Second),
},
}
normalAuth := authenticator.MockAuthenticate{
RedeemResponse: &authenticator.RedeemResponse{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Expiry: time.Now().Add(10 * time.Second),
},
}
normalCsrf := sessions.MockCSRFStore{
ResponseCSRF: "ok",
GetError: nil,
Cookie: &http.Cookie{
Name: "something_csrf",
Value: "csrf_state",
}}
tests := []struct {
name string
csrf sessions.MockCSRFStore
session sessions.MockSessionStore
authenticator authenticator.MockAuthenticate
params map[string]string
wantCode int
}{
{"good", normalCsrf, normalSession, normalAuth, map[string]string{"code": "code", "state": "state"}, http.StatusFound},
{"error", normalCsrf, normalSession, normalAuth, map[string]string{"error": "some error"}, http.StatusForbidden},
{"code err", normalCsrf, normalSession, authenticator.MockAuthenticate{RedeemError: errors.New("error")}, map[string]string{"code": "error"}, http.StatusInternalServerError},
{"state err", normalCsrf, normalSession, normalAuth, map[string]string{"code": "code", "state": "error"}, http.StatusInternalServerError},
{"csrf err", sessions.MockCSRFStore{GetError: errors.New("error")}, normalSession, normalAuth, map[string]string{"code": "code", "state": "state"}, http.StatusBadRequest},
{"unmarshal err", sessions.MockCSRFStore{
Cookie: &http.Cookie{
Name: "something_csrf",
Value: "unmarshal error",
},
}, normalSession, normalAuth, map[string]string{"code": "code", "state": "state"}, http.StatusInternalServerError},
{"encrypted state != CSRF", normalCsrf, normalSession, normalAuth, map[string]string{"code": "code", "state": "csrf_state"}, http.StatusBadRequest},
{"session save err", normalCsrf, sessions.MockSessionStore{SaveError: errors.New("error")}, normalAuth, map[string]string{"code": "code", "state": "state"}, http.StatusInternalServerError},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
proxy, err := New(testOptions())
if err != nil {
t.Fatal(err)
}
tests := []struct {
name string
method string
params map[string]string
wantCode int
}{
{"nil", http.MethodPost, nil, http.StatusInternalServerError},
{"error", http.MethodPost, map[string]string{"error": "some error"}, http.StatusForbidden},
{"state", http.MethodPost, map[string]string{"code": "code"}, http.StatusInternalServerError},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(tt.method, "/.pomerium/callback", nil)
proxy.sessionStore = tt.session
proxy.csrfStore = tt.csrf
proxy.AuthenticateClient = tt.authenticator
proxy.cipher = mockCipher{}
// proxy.Csrf
req := httptest.NewRequest(http.MethodPost, "/.pomerium/callback", nil)
q := req.URL.Query()
for k, v := range tt.params {
q.Add(k, v)
}
req.URL.RawQuery = q.Encode()
fmt.Println("OK OK OK OK")
fmt.Println(req.URL.String())
w := httptest.NewRecorder()
proxy.OAuthCallback(w, req)
if status := w.Code; status != tt.wantCode {
@ -330,4 +290,188 @@ func TestProxy_OAuthCallback(t *testing.T) {
}
})
}
}
func Test_extendDeadline(t *testing.T) {
tests := []struct {
name string
ttl time.Duration
want time.Time
}{
{"good", time.Second, time.Now().Add(time.Second).Truncate(time.Second)},
{"test nanoseconds truncated", 500 * time.Nanosecond, time.Now().Truncate(time.Second)},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := extendDeadline(tt.ttl); !reflect.DeepEqual(got, tt.want) {
t.Errorf("extendDeadline() = %v, want %v", got, tt.want)
}
})
}
}
func TestProxy_router(t *testing.T) {
tests := []struct {
name string
host string
mux map[string]string
route http.Handler
wantOk bool
}{
{"good corp", "https://corp.example.com", map[string]string{"corp.example.com": "example.com"}, nil, true},
{"good with slash", "https://corp.example.com/", map[string]string{"corp.example.com": "example.com"}, nil, true},
{"good with path", "https://corp.example.com/123", map[string]string{"corp.example.com": "example.com"}, nil, true},
{"multiple", "https://corp.example.com/", map[string]string{"corp.unrelated.com": "unrelated.com", "corp.example.com": "example.com"}, nil, true},
{"bad corp", "https://notcorp.example.com/123", map[string]string{"corp.example.com": "example.com"}, nil, false},
{"bad sub-sub", "https://notcorp.corp.example.com/123", map[string]string{"corp.example.com": "example.com"}, nil, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
opts := testOptions()
opts.Routes = tt.mux
p, err := New(opts)
if err != nil {
t.Fatal(err)
}
p.AuthenticateClient = authenticator.MockAuthenticate{}
p.cipher = mockCipher{}
req := httptest.NewRequest("GET", tt.host, nil)
_, ok := p.router(req)
if ok != tt.wantOk {
t.Errorf("Proxy.router() ok = %v, want %v", ok, tt.wantOk)
}
})
}
}
func TestProxy_Proxy(t *testing.T) {
goodSession := &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}
expiredLifetime := &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(-10 * time.Second),
}
// expiredDeadline := &sessions.SessionState{
// AccessToken: "AccessToken",
// RefreshToken: "RefreshToken",
// LifetimeDeadline: time.Now().Add(10 * time.Second),
// RefreshDeadline: time.Now().Add(-10 * time.Second),
// }
tests := []struct {
name string
host string
session sessions.SessionStore
authenticator authenticator.Authenticator
wantStatus int
}{
// weirdly, we want 503 here because that means proxy is trying to route a domain (example.com) that we dont control. Weird. I know.
{"good", "https://corp.example.com/test", sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, http.StatusServiceUnavailable},
{"unexpected error", "https://corp.example.com/test", sessions.MockSessionStore{LoadError: errors.New("ok")}, authenticator.MockAuthenticate{}, http.StatusInternalServerError},
// redirect to start auth process
{"expired lifetime", "https://corp.example.com/test", sessions.MockSessionStore{Session: expiredLifetime}, authenticator.MockAuthenticate{}, http.StatusFound},
{"unknown host", "https://notcorp.example.com/test", sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, http.StatusNotFound},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
opts := testOptions()
p, err := New(opts)
if err != nil {
t.Fatal(err)
}
p.cipher = mockCipher{}
p.sessionStore = tt.session
p.AuthenticateClient = tt.authenticator
r := httptest.NewRequest("GET", tt.host, nil)
w := httptest.NewRecorder()
p.Proxy(w, r)
if status := w.Code; status != tt.wantStatus {
t.Errorf("handler returned wrong status code: got %v want %v \n body %s", status, tt.wantStatus, w.Body.String())
}
})
}
}
func TestProxy_Authenticate(t *testing.T) {
goodSession := &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}
expiredLifetime := &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(-10 * time.Second),
}
expiredDeadline := &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(-10 * time.Second),
}
tests := []struct {
name string
host string
mux map[string]string
session sessions.SessionStore
authenticator authenticator.Authenticator
wantErr bool
}{
{"cannot load session",
"https://corp.example.com/",
map[string]string{"corp.example.com": "example.com"},
sessions.MockSessionStore{LoadError: errors.New("error")}, authenticator.MockAuthenticate{}, true},
{"expired lifetime",
"https://corp.example.com/",
map[string]string{"corp.example.com": "example.com"},
sessions.MockSessionStore{Session: expiredLifetime}, authenticator.MockAuthenticate{}, true},
{"expired session",
"https://corp.example.com/",
map[string]string{"corp.example.com": "example.com"},
sessions.MockSessionStore{Session: expiredDeadline}, authenticator.MockAuthenticate{}, false},
{"bad refresh authenticator",
"https://corp.example.com/",
map[string]string{"corp.example.com": "example.com"},
sessions.MockSessionStore{
Session: expiredDeadline,
},
authenticator.MockAuthenticate{RefreshError: errors.New("error")},
true},
{"good",
"https://corp.example.com/",
map[string]string{"corp.example.com": "example.com"},
sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
opts := testOptions()
opts.Routes = tt.mux
p, err := New(opts)
if err != nil {
t.Fatal(err)
}
p.sessionStore = tt.session
p.AuthenticateClient = tt.authenticator
p.cipher = mockCipher{}
r := httptest.NewRequest("GET", tt.host, nil)
w := httptest.NewRecorder()
if err := p.Authenticate(w, r); (err != nil) != tt.wantErr {
t.Errorf("Proxy.Authenticate() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

View file

@ -1,8 +1,6 @@
package proxy // import "github.com/pomerium/pomerium/proxy"
import (
"crypto/tls"
"crypto/x509"
"encoding/base64"
"errors"
"fmt"
@ -15,15 +13,12 @@ import (
"time"
"github.com/pomerium/envconfig"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/middleware"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/templates"
pb "github.com/pomerium/pomerium/proto/authenticate"
"github.com/pomerium/pomerium/proxy/authenticator"
)
const (
@ -40,7 +35,6 @@ type Options struct {
// Authenticate service settings
AuthenticateURL *url.URL `envconfig:"AUTHENTICATE_SERVICE_URL"`
AuthenticateInternalURL string `envconfig:"AUTHENTICATE_INTERNAL_URL"`
//
OverideCertificateName string `envconfig:"OVERIDE_CERTIFICATE_NAME"`
// SigningKey is a base64 encoded private key used to add a JWT-signature to proxied requests.
@ -131,21 +125,17 @@ func (o *Options) Validate() error {
type Proxy struct {
SharedKey string
// Authenticate Service Configuration
// services
AuthenticateURL *url.URL
AuthenticateInternalURL string
AuthenticatorClient pb.AuthenticatorClient
// AuthenticateConn must be closed by Proxy's caller
AuthenticateConn *grpc.ClientConn
AuthenticateClient authenticator.Authenticator
OverideCertificateName string
// session
cipher cryptutil.Cipher
csrfStore sessions.CSRFStore
sessionStore sessions.SessionStore
CookieExpire time.Duration
CookieRefresh time.Duration
CookieLifetimeTTL time.Duration
cipher cryptutil.Cipher
csrfStore sessions.CSRFStore
sessionStore sessions.SessionStore
redirectURL *url.URL
templates *template.Template
@ -154,6 +144,8 @@ type Proxy struct {
// New takes a Proxy service from options and a validation function.
// Function returns an error if options fail to validate.
//
// Caller responsible for closing AuthenticateConn.
func New(opts *Options) (*Proxy, error) {
if opts == nil {
return nil, errors.New("options cannot be nil")
@ -182,15 +174,13 @@ func New(opts *Options) (*Proxy, error) {
}
p := &Proxy{
// these fields make up the routing mechanism
mux: make(map[string]http.Handler),
// services
AuthenticateURL: opts.AuthenticateURL,
// session state
cipher: cipher,
csrfStore: cookieStore,
sessionStore: cookieStore,
AuthenticateURL: opts.AuthenticateURL,
AuthenticateInternalURL: opts.AuthenticateInternalURL,
OverideCertificateName: opts.OverideCertificateName,
SharedKey: opts.SharedKey,
redirectURL: &url.URL{Path: "/.pomerium/callback"},
templates: templates.New(),
@ -209,41 +199,11 @@ func New(opts *Options) (*Proxy, error) {
p.Handle(fromURL.Host, handler)
log.Info().Str("from", fromURL.Host).Str("to", toURL.String()).Msg("proxy.New: new route")
}
// if no port given, assume https/443
port := p.AuthenticateURL.Port()
if port == "" {
port = "443"
}
authEndpoint := fmt.Sprintf("%s:%s", p.AuthenticateURL.Host, port)
cp, err := x509.SystemCertPool()
if err != nil {
return nil, err
}
if p.AuthenticateInternalURL != "" {
authEndpoint = p.AuthenticateInternalURL
}
log.Info().Str("authEndpoint", authEndpoint).Msgf("proxy.New: grpc authenticate connection")
cert := credentials.NewTLS(&tls.Config{RootCAs: cp})
if p.OverideCertificateName != "" {
err = cert.OverrideServerName(p.OverideCertificateName)
if err != nil {
return nil, err
}
}
grpcAuth := middleware.NewSharedSecretCred(p.SharedKey)
p.AuthenticateConn, err = grpc.Dial(
authEndpoint,
grpc.WithTransportCredentials(cert),
grpc.WithPerRPCCredentials(grpcAuth),
)
if err != nil {
return nil, err
}
p.AuthenticatorClient = pb.NewAuthenticatorClient(p.AuthenticateConn)
p.AuthenticateClient, err = authenticator.New(
opts.AuthenticateURL,
opts.AuthenticateInternalURL,
opts.OverideCertificateName,
opts.SharedKey)
return p, nil
}

View file

@ -9,8 +9,11 @@ import (
"os"
"reflect"
"testing"
"time"
)
var fixedDate = time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC)
func TestOptionsFromEnvConfig(t *testing.T) {
os.Clearenv()