mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-20 20:47:16 +02:00
proxy: add unit tests (#43)
This commit is contained in:
parent
cedf9922d3
commit
4f4f3965aa
12 changed files with 577 additions and 323 deletions
|
@ -74,7 +74,7 @@ func main() {
|
|||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("cmd/pomerium: new proxy")
|
||||
}
|
||||
defer proxyService.AuthenticateConn.Close()
|
||||
defer proxyService.AuthenticateClient.Close()
|
||||
}
|
||||
|
||||
topMux := http.NewServeMux()
|
||||
|
|
|
@ -12,17 +12,17 @@ type MockCSRFStore struct {
|
|||
}
|
||||
|
||||
// SetCSRF sets the ResponseCSRF string to a val
|
||||
func (ms *MockCSRFStore) SetCSRF(rw http.ResponseWriter, req *http.Request, val string) {
|
||||
func (ms MockCSRFStore) SetCSRF(rw http.ResponseWriter, req *http.Request, val string) {
|
||||
ms.ResponseCSRF = val
|
||||
}
|
||||
|
||||
// ClearCSRF clears the ResponseCSRF string
|
||||
func (ms *MockCSRFStore) ClearCSRF(http.ResponseWriter, *http.Request) {
|
||||
func (ms MockCSRFStore) ClearCSRF(http.ResponseWriter, *http.Request) {
|
||||
ms.ResponseCSRF = ""
|
||||
}
|
||||
|
||||
// GetCSRF returns the cookie and error
|
||||
func (ms *MockCSRFStore) GetCSRF(*http.Request) (*http.Cookie, error) {
|
||||
func (ms MockCSRFStore) GetCSRF(*http.Request) (*http.Cookie, error) {
|
||||
return ms.Cookie, ms.GetError
|
||||
}
|
||||
|
||||
|
@ -35,16 +35,16 @@ type MockSessionStore struct {
|
|||
}
|
||||
|
||||
// ClearSession clears the ResponseSession
|
||||
func (ms *MockSessionStore) ClearSession(http.ResponseWriter, *http.Request) {
|
||||
func (ms MockSessionStore) ClearSession(http.ResponseWriter, *http.Request) {
|
||||
ms.ResponseSession = ""
|
||||
}
|
||||
|
||||
// LoadSession returns the session and a error
|
||||
func (ms *MockSessionStore) LoadSession(*http.Request) (*SessionState, error) {
|
||||
func (ms MockSessionStore) LoadSession(*http.Request) (*SessionState, error) {
|
||||
return ms.Session, ms.LoadError
|
||||
}
|
||||
|
||||
// SaveSession returns a save error.
|
||||
func (ms *MockSessionStore) SaveSession(http.ResponseWriter, *http.Request, *SessionState) error {
|
||||
func (ms MockSessionStore) SaveSession(http.ResponseWriter, *http.Request, *SessionState) error {
|
||||
return ms.SaveError
|
||||
}
|
||||
|
|
69
proxy/authenticator/authenticator.go
Normal file
69
proxy/authenticator/authenticator.go
Normal file
|
@ -0,0 +1,69 @@
|
|||
package authenticator // import "github.com/pomerium/pomerium/proxy/authenticator"
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/middleware"
|
||||
pb "github.com/pomerium/pomerium/proto/authenticate"
|
||||
)
|
||||
|
||||
// Authenticator provides the authenticate service interface
|
||||
type Authenticator interface {
|
||||
// Redeem takes a code and returns a validated session or an error
|
||||
Redeem(string) (*RedeemResponse, error)
|
||||
// Refresh attempts to refresh a valid session with a refresh token. Returns a new access token
|
||||
// and expiration, or an error.
|
||||
Refresh(string) (string, time.Time, error)
|
||||
// Validate evaluates a given oidc id_token for validity. Returns validity and any error.
|
||||
Validate(string) (bool, error)
|
||||
// Close closes the authenticator connection if any.
|
||||
Close() error
|
||||
}
|
||||
|
||||
// New returns a new identity provider based given its name.
|
||||
// Returns an error if selected provided not found or if the identity provider is not known.
|
||||
func New(uri *url.URL, internalURL, OverideCertificateName, key string) (p Authenticator, err error) {
|
||||
// if no port given, assume https/443
|
||||
port := uri.Port()
|
||||
if port == "" {
|
||||
port = "443"
|
||||
}
|
||||
authEndpoint := fmt.Sprintf("%s:%s", uri.Host, port)
|
||||
|
||||
cp, err := x509.SystemCertPool()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if internalURL != "" {
|
||||
authEndpoint = internalURL
|
||||
}
|
||||
|
||||
log.Info().Str("authEndpoint", authEndpoint).Msgf("proxy.New: grpc authenticate connection")
|
||||
cert := credentials.NewTLS(&tls.Config{RootCAs: cp})
|
||||
if OverideCertificateName != "" {
|
||||
err = cert.OverrideServerName(OverideCertificateName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
grpcAuth := middleware.NewSharedSecretCred(key)
|
||||
conn, err := grpc.Dial(
|
||||
authEndpoint,
|
||||
grpc.WithTransportCredentials(cert),
|
||||
grpc.WithPerRPCCredentials(grpcAuth),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authClient := pb.NewAuthenticatorClient(conn)
|
||||
return &AuthenticateGRPC{conn: conn, client: authClient}, nil
|
||||
}
|
36
proxy/authenticator/authenticator_test.go
Normal file
36
proxy/authenticator/authenticator_test.go
Normal file
|
@ -0,0 +1,36 @@
|
|||
package authenticator
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
type args struct {
|
||||
uri *url.URL
|
||||
internalURL string
|
||||
OverideCertificateName string
|
||||
key string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantP Authenticator
|
||||
wantErr bool
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotP, err := New(tt.args.uri, tt.args.internalURL, tt.args.OverideCertificateName, tt.args.key)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(gotP, tt.wantP) {
|
||||
t.Errorf("New() = %v, want %v", gotP, tt.wantP)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
98
proxy/authenticator/grpc.go
Normal file
98
proxy/authenticator/grpc.go
Normal file
|
@ -0,0 +1,98 @@
|
|||
package authenticator // import "github.com/pomerium/pomerium/proxy/authenticator"
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
pb "github.com/pomerium/pomerium/proto/authenticate"
|
||||
)
|
||||
|
||||
// RedeemResponse contains data from a authenticator redeem request.
|
||||
type RedeemResponse struct {
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
IDToken string
|
||||
User string
|
||||
Email string
|
||||
Expiry time.Time
|
||||
}
|
||||
|
||||
// AuthenticateGRPC is a gRPC implementation of an authenticator (authenticate client)
|
||||
type AuthenticateGRPC struct {
|
||||
conn *grpc.ClientConn
|
||||
client pb.AuthenticatorClient
|
||||
}
|
||||
|
||||
// Redeem makes an RPC call to the authenticate service to creates a session state
|
||||
// from an encrypted code provided as a result of an oauth2 callback process.
|
||||
func (a *AuthenticateGRPC) Redeem(code string) (*RedeemResponse, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("missing code")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
r, err := a.client.Authenticate(ctx, &pb.AuthenticateRequest{Code: code})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
expiry, err := ptypes.Timestamp(r.Expiry)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &RedeemResponse{
|
||||
AccessToken: r.AccessToken,
|
||||
RefreshToken: r.RefreshToken,
|
||||
IDToken: r.IdToken,
|
||||
User: r.User,
|
||||
Email: r.Email,
|
||||
Expiry: expiry,
|
||||
// RefreshDeadline: (expiry).Truncate(time.Second),
|
||||
// LifetimeDeadline: extendDeadline(p.CookieLifetimeTTL),
|
||||
// ValidDeadline: extendDeadline(p.CookieExpire),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Refresh makes an RPC call to the authenticate service to attempt to refresh the
|
||||
// user's session. Requires a valid refresh token. Will return an error if the identity provider
|
||||
// has revoked the session or if the refresh token is no longer valid in this context.
|
||||
func (a *AuthenticateGRPC) Refresh(refreshToken string) (string, time.Time, error) {
|
||||
if refreshToken == "" {
|
||||
return "", time.Time{}, errors.New("missing refresh token")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
r, err := a.client.Refresh(ctx, &pb.RefreshRequest{RefreshToken: refreshToken})
|
||||
if err != nil {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
|
||||
expiry, err := ptypes.Timestamp(r.Expiry)
|
||||
if err != nil {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
return r.AccessToken, expiry, nil
|
||||
}
|
||||
|
||||
// Validate makes an RPC call to the authenticate service to validate the JWT id token;
|
||||
// does NOT do nonce or revokation validation.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
||||
func (a *AuthenticateGRPC) Validate(idToken string) (bool, error) {
|
||||
if idToken == "" {
|
||||
return false, errors.New("missing id token")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
r, err := a.client.Validate(ctx, &pb.ValidateRequest{IdToken: idToken})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return r.IsValid, nil
|
||||
}
|
||||
|
||||
// Close tears down the ClientConn and all underlying connections.
|
||||
func (a *AuthenticateGRPC) Close() error {
|
||||
return a.conn.Close()
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package proxy
|
||||
package authenticator // import "github.com/pomerium/pomerium/proxy/authenticator"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
@ -6,8 +6,6 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
|
@ -34,7 +32,7 @@ func (r *rpcMsg) String() string {
|
|||
return fmt.Sprintf("is %s", r.msg)
|
||||
}
|
||||
|
||||
func TestProxy_AuthenticateRedeem(t *testing.T) {
|
||||
func TestProxy_Redeem(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockAuthenticateClient := mock.NewMockAuthenticatorClient(ctrl)
|
||||
|
@ -55,29 +53,26 @@ func TestProxy_AuthenticateRedeem(t *testing.T) {
|
|||
Email: "test@email.com",
|
||||
Expiry: mockExpire,
|
||||
}, nil)
|
||||
p := &Proxy{AuthenticatorClient: mockAuthenticateClient}
|
||||
tests := []struct {
|
||||
name string
|
||||
idToken string
|
||||
want *sessions.SessionState
|
||||
want *RedeemResponse
|
||||
wantErr bool
|
||||
}{
|
||||
{"good", "unit_test", &sessions.SessionState{
|
||||
AccessToken: "mocked access token",
|
||||
RefreshToken: "mocked refresh token",
|
||||
IDToken: "mocked id token",
|
||||
User: "user1",
|
||||
Email: "test@email.com",
|
||||
RefreshDeadline: (fixedDate).Truncate(time.Second),
|
||||
LifetimeDeadline: extendDeadline(p.CookieLifetimeTTL),
|
||||
ValidDeadline: extendDeadline(p.CookieExpire),
|
||||
{"good", "unit_test", &RedeemResponse{
|
||||
AccessToken: "mocked access token",
|
||||
RefreshToken: "mocked refresh token",
|
||||
IDToken: "mocked id token",
|
||||
User: "user1",
|
||||
Email: "test@email.com",
|
||||
Expiry: (fixedDate),
|
||||
}, false},
|
||||
{"empty code", "", nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
got, err := p.AuthenticateRedeem(tt.idToken)
|
||||
a := AuthenticateGRPC{client: mockAuthenticateClient}
|
||||
got, err := a.Redeem(tt.idToken)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Proxy.AuthenticateValidate() error = %v,\n wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
|
@ -113,7 +108,7 @@ func TestProxy_AuthenticateValidate(t *testing.T) {
|
|||
&rpcMsg{msg: req},
|
||||
).Return(&pb.ValidateReply{IsValid: false}, nil)
|
||||
|
||||
p := &Proxy{AuthenticatorClient: mockAuthenticateClient}
|
||||
ac := mockAuthenticateClient
|
||||
tests := []struct {
|
||||
name string
|
||||
idToken string
|
||||
|
@ -125,8 +120,9 @@ func TestProxy_AuthenticateValidate(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := AuthenticateGRPC{client: ac}
|
||||
|
||||
got, err := p.AuthenticateValidate(tt.idToken)
|
||||
got, err := a.Validate(tt.idToken)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Proxy.AuthenticateValidate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
|
@ -167,9 +163,9 @@ func TestProxy_AuthenticateRefresh(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := &Proxy{AuthenticatorClient: mockRefreshClient}
|
||||
a := AuthenticateGRPC{client: mockRefreshClient}
|
||||
|
||||
got, gotExp, err := p.AuthenticateRefresh(tt.refreshToken)
|
||||
got, gotExp, err := a.Refresh(tt.refreshToken)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Proxy.AuthenticateRefresh() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
|
@ -183,22 +179,3 @@ func TestProxy_AuthenticateRefresh(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_extendDeadline(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ttl time.Duration
|
||||
want time.Time
|
||||
}{
|
||||
{"good", time.Second, time.Now().Add(time.Second).Truncate(time.Second)},
|
||||
{"test nanoseconds truncated", 500 * time.Nanosecond, time.Now().Truncate(time.Second)},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := extendDeadline(tt.ttl); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("extendDeadline() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
35
proxy/authenticator/mock_authenticator.go
Normal file
35
proxy/authenticator/mock_authenticator.go
Normal file
|
@ -0,0 +1,35 @@
|
|||
package authenticator // import "github.com/pomerium/pomerium/proxy/authenticator"
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// MockAuthenticate is a mock authenticator interface
|
||||
type MockAuthenticate struct {
|
||||
RedeemError error
|
||||
RedeemResponse *RedeemResponse
|
||||
RefreshResponse string
|
||||
RefreshTime time.Time
|
||||
RefreshError error
|
||||
ValidateResponse bool
|
||||
ValidateError error
|
||||
CloseError error
|
||||
}
|
||||
|
||||
// Redeem is a mocked implementation for authenticator testing.
|
||||
func (a MockAuthenticate) Redeem(code string) (*RedeemResponse, error) {
|
||||
return a.RedeemResponse, a.RedeemError
|
||||
}
|
||||
|
||||
// Refresh is a mocked implementation for authenticator testing.
|
||||
func (a MockAuthenticate) Refresh(refreshToken string) (string, time.Time, error) {
|
||||
return a.RefreshResponse, a.RefreshTime, a.RefreshError
|
||||
}
|
||||
|
||||
// Validate is a mocked implementation for authenticator testing.
|
||||
func (a MockAuthenticate) Validate(idToken string) (bool, error) {
|
||||
return a.ValidateResponse, a.ValidateError
|
||||
}
|
||||
|
||||
// Close is a mocked implementation for authenticator testing.
|
||||
func (a MockAuthenticate) Close() error { return a.ValidateError }
|
|
@ -1,80 +0,0 @@
|
|||
package proxy // import "github.com/pomerium/pomerium/proxy"
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
pb "github.com/pomerium/pomerium/proto/authenticate"
|
||||
)
|
||||
|
||||
// AuthenticateRedeem makes an RPC call to the authenticate service to creates a session state
|
||||
// from an encrypted code provided as a result of an oauth2 callback process.
|
||||
func (p *Proxy) AuthenticateRedeem(code string) (*sessions.SessionState, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("missing code")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
r, err := p.AuthenticatorClient.Authenticate(ctx, &pb.AuthenticateRequest{Code: code})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
expiry, err := ptypes.Timestamp(r.Expiry)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &sessions.SessionState{
|
||||
AccessToken: r.AccessToken,
|
||||
RefreshToken: r.RefreshToken,
|
||||
IDToken: r.IdToken,
|
||||
User: r.User,
|
||||
Email: r.Email,
|
||||
RefreshDeadline: (expiry).Truncate(time.Second),
|
||||
LifetimeDeadline: extendDeadline(p.CookieLifetimeTTL),
|
||||
ValidDeadline: extendDeadline(p.CookieExpire),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AuthenticateRefresh makes an RPC call to the authenticate service to attempt to refresh the
|
||||
// user's session. Requires a valid refresh token. Will return an error if the identity provider
|
||||
// has revoked the session or if the refresh token is no longer valid in this context.
|
||||
func (p *Proxy) AuthenticateRefresh(refreshToken string) (string, time.Time, error) {
|
||||
if refreshToken == "" {
|
||||
return "", time.Time{}, errors.New("missing refresh token")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
r, err := p.AuthenticatorClient.Refresh(ctx, &pb.RefreshRequest{RefreshToken: refreshToken})
|
||||
if err != nil {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
|
||||
expiry, err := ptypes.Timestamp(r.Expiry)
|
||||
if err != nil {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
return r.AccessToken, expiry, nil
|
||||
}
|
||||
|
||||
// AuthenticateValidate makes an RPC call to the authenticate service to validate the JWT id token;
|
||||
// does NOT do nonce or revokation validation.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
||||
func (p *Proxy) AuthenticateValidate(idToken string) (bool, error) {
|
||||
if idToken == "" {
|
||||
return false, errors.New("missing id token")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
r, err := p.AuthenticatorClient.Validate(ctx, &pb.ValidateRequest{IdToken: idToken})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return r.IsValid, nil
|
||||
}
|
||||
|
||||
func extendDeadline(ttl time.Duration) time.Time {
|
||||
return time.Now().Add(ttl).Truncate(time.Second)
|
||||
}
|
|
@ -160,7 +160,7 @@ func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
// We begin the process of redeeming the code for an access token.
|
||||
session, err := p.AuthenticateRedeem(r.Form.Get("code"))
|
||||
rr, err := p.AuthenticateClient.Redeem(r.Form.Get("code"))
|
||||
if err != nil {
|
||||
log.FromRequest(r).Error().Err(err).Msg("error redeeming authorization code")
|
||||
httputil.ErrorResponse(w, r, "Internal error", http.StatusInternalServerError)
|
||||
|
@ -168,6 +168,10 @@ func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
encryptedState := r.Form.Get("state")
|
||||
log.Warn().
|
||||
Str("encryptedState", encryptedState).
|
||||
Msg("OK")
|
||||
|
||||
stateParameter := &StateParameter{}
|
||||
err = p.cipher.Unmarshal(encryptedState, stateParameter)
|
||||
if err != nil {
|
||||
|
@ -192,13 +196,11 @@ func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
|
|||
httputil.ErrorResponse(w, r, "Internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if encryptedState == encryptedCSRF {
|
||||
log.FromRequest(r).Error().Msg("encrypted state and CSRF should not be equal")
|
||||
httputil.ErrorResponse(w, r, "Bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(stateParameter, csrfParameter) {
|
||||
log.FromRequest(r).Error().Msg("state and CSRF should be equal")
|
||||
httputil.ErrorResponse(w, r, "Bad request", http.StatusBadRequest)
|
||||
|
@ -206,7 +208,16 @@ func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// We store the session in a cookie and redirect the user back to the application
|
||||
err = p.sessionStore.SaveSession(w, r, session)
|
||||
err = p.sessionStore.SaveSession(w, r, &sessions.SessionState{
|
||||
AccessToken: rr.AccessToken,
|
||||
RefreshToken: rr.RefreshToken,
|
||||
IDToken: rr.IDToken,
|
||||
User: rr.User,
|
||||
Email: rr.Email,
|
||||
RefreshDeadline: (rr.Expiry).Truncate(time.Second),
|
||||
LifetimeDeadline: extendDeadline(p.CookieLifetimeTTL),
|
||||
ValidDeadline: extendDeadline(p.CookieExpire),
|
||||
})
|
||||
if err != nil {
|
||||
log.FromRequest(r).Error().Msg("error saving session")
|
||||
httputil.ErrorResponse(w, r, "Error saving session", http.StatusInternalServerError)
|
||||
|
@ -216,8 +227,8 @@ func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
|
|||
log.FromRequest(r).Info().
|
||||
Str("code", r.Form.Get("code")).
|
||||
Str("state", r.Form.Get("state")).
|
||||
Str("RefreshToken", session.RefreshToken).
|
||||
Str("session", session.AccessToken).
|
||||
Str("RefreshToken", rr.RefreshToken).
|
||||
Str("session", rr.AccessToken).
|
||||
Str("RedirectURI", stateParameter.RedirectURI).
|
||||
Msg("session")
|
||||
|
||||
|
@ -242,10 +253,6 @@ func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
|
|||
// OAuthStart. If successful, we proceed to proxy to the configured upstream.
|
||||
if err != nil {
|
||||
switch err {
|
||||
case ErrUserNotAuthorized:
|
||||
log.FromRequest(r).Debug().Err(err).Msg("proxy: user access forbidden")
|
||||
httputil.ErrorResponse(w, r, "You don't have access", http.StatusForbidden)
|
||||
return
|
||||
case http.ErrNoCookie, sessions.ErrLifetimeExpired, sessions.ErrInvalidSession:
|
||||
log.FromRequest(r).Debug().Err(err).Msg("proxy: starting auth flow")
|
||||
p.OAuthStart(w, r)
|
||||
|
@ -256,8 +263,9 @@ func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
}
|
||||
|
||||
// ! ! !
|
||||
// todo(bdd): ! Authorization checks will go here !
|
||||
// todo(bdd): ! Authorization service goes here !
|
||||
// ! ! !
|
||||
|
||||
// We have validated the users request and now proxy their request to the provided upstream.
|
||||
|
@ -293,7 +301,7 @@ func (p *Proxy) Authenticate(w http.ResponseWriter, r *http.Request) (err error)
|
|||
// AccessToken's usually expire after 60 or so minutes. If offline_access scope is set, a
|
||||
// refresh token (which doesn't change) can be used to request a new access-token. If access
|
||||
// is revoked by identity provider, or no refresh token is set request will return an error
|
||||
accessToken, expiry, err := p.AuthenticateRefresh(session.RefreshToken)
|
||||
accessToken, expiry, err := p.AuthenticateClient.Refresh(session.RefreshToken)
|
||||
if err != nil {
|
||||
log.FromRequest(r).Warn().
|
||||
Str("RefreshToken", session.RefreshToken).
|
||||
|
@ -377,3 +385,7 @@ func (p *Proxy) GetSignOutURL(authenticateURL, redirectURL *url.URL) *url.URL {
|
|||
a.RawQuery = params.Encode()
|
||||
return a
|
||||
}
|
||||
|
||||
func extendDeadline(ttl time.Duration) time.Time {
|
||||
return time.Now().Add(ttl).Truncate(time.Second)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
@ -10,9 +11,34 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/version"
|
||||
"github.com/pomerium/pomerium/proxy/authenticator"
|
||||
)
|
||||
|
||||
type mockCipher struct{}
|
||||
|
||||
func (a mockCipher) Encrypt(s []byte) ([]byte, error) {
|
||||
if string(s) == "error" {
|
||||
return []byte(""), errors.New("error encrypting")
|
||||
}
|
||||
return []byte("OK"), nil
|
||||
}
|
||||
|
||||
func (a mockCipher) Decrypt(s []byte) ([]byte, error) {
|
||||
if string(s) == "error" {
|
||||
return []byte(""), errors.New("error encrypting")
|
||||
}
|
||||
return []byte("OK"), nil
|
||||
}
|
||||
func (a mockCipher) Marshal(s interface{}) (string, error) { return "ok", nil }
|
||||
func (a mockCipher) Unmarshal(s string, i interface{}) error {
|
||||
if string(s) == "unmarshal error" || string(s) == "error" {
|
||||
return errors.New("error")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestProxy_RobotsTxt(t *testing.T) {
|
||||
proxy := Proxy{}
|
||||
req := httptest.NewRequest("GET", "/robots.txt", nil)
|
||||
|
@ -49,8 +75,6 @@ func TestProxy_GetRedirectURL(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestProxy_signRedirectURL(t *testing.T) {
|
||||
fixedDate := time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rawRedirect string
|
||||
|
@ -194,135 +218,71 @@ func TestProxy_Handler(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
|
||||
// err := r.ParseForm()
|
||||
// if err != nil {
|
||||
// log.FromRequest(r).Error().Err(err).Msg("failed parsing request form")
|
||||
// httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError)
|
||||
// return
|
||||
// }
|
||||
// errorString := r.Form.Get("error")
|
||||
// if errorString != "" {
|
||||
// httputil.ErrorResponse(w, r, errorString, http.StatusForbidden)
|
||||
// return
|
||||
// }
|
||||
// // We begin the process of redeeming the code for an access token.
|
||||
// session, err := p.AuthenticateRedeem(r.Form.Get("code"))
|
||||
// if err != nil {
|
||||
// log.FromRequest(r).Error().Err(err).Msg("error redeeming authorization code")
|
||||
// httputil.ErrorResponse(w, r, "Internal error", http.StatusInternalServerError)
|
||||
// return
|
||||
// }
|
||||
|
||||
// encryptedState := r.Form.Get("state")
|
||||
// stateParameter := &StateParameter{}
|
||||
// err = p.cipher.Unmarshal(encryptedState, stateParameter)
|
||||
// if err != nil {
|
||||
// log.FromRequest(r).Error().Err(err).Msg("could not unmarshal state")
|
||||
// httputil.ErrorResponse(w, r, "Internal error", http.StatusInternalServerError)
|
||||
// return
|
||||
// }
|
||||
|
||||
// c, err := p.csrfStore.GetCSRF(r)
|
||||
// if err != nil {
|
||||
// log.FromRequest(r).Error().Err(err).Msg("failed parsing csrf cookie")
|
||||
// httputil.ErrorResponse(w, r, err.Error(), http.StatusBadRequest)
|
||||
// return
|
||||
// }
|
||||
// p.csrfStore.ClearCSRF(w, r)
|
||||
|
||||
// encryptedCSRF := c.Value
|
||||
// csrfParameter := &StateParameter{}
|
||||
// err = p.cipher.Unmarshal(encryptedCSRF, csrfParameter)
|
||||
// if err != nil {
|
||||
// log.FromRequest(r).Error().Err(err).Msg("couldn't unmarshal CSRF")
|
||||
// httputil.ErrorResponse(w, r, "Internal error", http.StatusInternalServerError)
|
||||
// return
|
||||
// }
|
||||
|
||||
// if encryptedState == encryptedCSRF {
|
||||
// log.FromRequest(r).Error().Msg("encrypted state and CSRF should not be equal")
|
||||
// httputil.ErrorResponse(w, r, "Bad request", http.StatusBadRequest)
|
||||
// return
|
||||
// }
|
||||
|
||||
// if !reflect.DeepEqual(stateParameter, csrfParameter) {
|
||||
// log.FromRequest(r).Error().Msg("state and CSRF should be equal")
|
||||
// httputil.ErrorResponse(w, r, "Bad request", http.StatusBadRequest)
|
||||
// return
|
||||
// }
|
||||
|
||||
// // We store the session in a cookie and redirect the user back to the application
|
||||
// err = p.sessionStore.SaveSession(w, r, session)
|
||||
// if err != nil {
|
||||
// log.FromRequest(r).Error().Msg("error saving session")
|
||||
// httputil.ErrorResponse(w, r, "Error saving session", http.StatusInternalServerError)
|
||||
// return
|
||||
// }
|
||||
|
||||
// log.FromRequest(r).Info().
|
||||
// Str("code", r.Form.Get("code")).
|
||||
// Str("state", r.Form.Get("state")).
|
||||
// Str("RefreshToken", session.RefreshToken).
|
||||
// Str("session", session.AccessToken).
|
||||
// Str("RedirectURI", stateParameter.RedirectURI).
|
||||
// Msg("session")
|
||||
|
||||
// // This is the redirect back to the original requested application
|
||||
// http.Redirect(w, r, stateParameter.RedirectURI, http.StatusFound)
|
||||
// }
|
||||
|
||||
// func TestProxy_OAuthCallback2(t *testing.T) {
|
||||
// proxy, err := New(testOptions())
|
||||
// if err != nil {
|
||||
// t.Fatal(err)
|
||||
// }
|
||||
// testError := url.Values{"error": []string{"There was a bad error to handle"}}
|
||||
// req := httptest.NewRequest("GET", "/oauth-callback", strings.NewReader(testError.Encode()))
|
||||
// if err != nil {
|
||||
// t.Fatal(err)
|
||||
// }
|
||||
// rr := httptest.NewRecorder()
|
||||
// proxy.OAuthCallback)
|
||||
// // expect oauth redirect
|
||||
// if status := rr.Code; status != http.StatusInternalServerError {
|
||||
// t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusInternalServerError)
|
||||
// }
|
||||
// // expected url
|
||||
// // expected := `<a href="https://sso-auth.corp.beyondperimeter.com/sign_in`
|
||||
// // body := rr.Body.String()
|
||||
// // if !strings.HasPrefix(body, expected) {
|
||||
// // t.Errorf("handler returned unexpected body: got %v want %v", body, expected)
|
||||
// // }
|
||||
// }
|
||||
|
||||
func TestProxy_OAuthCallback(t *testing.T) {
|
||||
proxy, err := New(testOptions())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
//todo(bdd): test malformed requests
|
||||
// https://github.com/golang/go/blob/master/src/net/http/request_test.go#L110
|
||||
normalSession := sessions.MockSessionStore{
|
||||
Session: &sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||
RefreshDeadline: time.Now().Add(-10 * time.Second),
|
||||
},
|
||||
}
|
||||
normalAuth := authenticator.MockAuthenticate{
|
||||
RedeemResponse: &authenticator.RedeemResponse{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
Expiry: time.Now().Add(10 * time.Second),
|
||||
},
|
||||
}
|
||||
normalCsrf := sessions.MockCSRFStore{
|
||||
ResponseCSRF: "ok",
|
||||
GetError: nil,
|
||||
Cookie: &http.Cookie{
|
||||
Name: "something_csrf",
|
||||
Value: "csrf_state",
|
||||
}}
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
params map[string]string
|
||||
wantCode int
|
||||
name string
|
||||
csrf sessions.MockCSRFStore
|
||||
session sessions.MockSessionStore
|
||||
authenticator authenticator.MockAuthenticate
|
||||
params map[string]string
|
||||
wantCode int
|
||||
}{
|
||||
{"nil", http.MethodPost, nil, http.StatusInternalServerError},
|
||||
{"error", http.MethodPost, map[string]string{"error": "some error"}, http.StatusForbidden},
|
||||
{"state", http.MethodPost, map[string]string{"code": "code"}, http.StatusInternalServerError},
|
||||
{"good", normalCsrf, normalSession, normalAuth, map[string]string{"code": "code", "state": "state"}, http.StatusFound},
|
||||
{"error", normalCsrf, normalSession, normalAuth, map[string]string{"error": "some error"}, http.StatusForbidden},
|
||||
{"code err", normalCsrf, normalSession, authenticator.MockAuthenticate{RedeemError: errors.New("error")}, map[string]string{"code": "error"}, http.StatusInternalServerError},
|
||||
{"state err", normalCsrf, normalSession, normalAuth, map[string]string{"code": "code", "state": "error"}, http.StatusInternalServerError},
|
||||
{"csrf err", sessions.MockCSRFStore{GetError: errors.New("error")}, normalSession, normalAuth, map[string]string{"code": "code", "state": "state"}, http.StatusBadRequest},
|
||||
{"unmarshal err", sessions.MockCSRFStore{
|
||||
Cookie: &http.Cookie{
|
||||
Name: "something_csrf",
|
||||
Value: "unmarshal error",
|
||||
},
|
||||
}, normalSession, normalAuth, map[string]string{"code": "code", "state": "state"}, http.StatusInternalServerError},
|
||||
{"encrypted state != CSRF", normalCsrf, normalSession, normalAuth, map[string]string{"code": "code", "state": "csrf_state"}, http.StatusBadRequest},
|
||||
{"session save err", normalCsrf, sessions.MockSessionStore{SaveError: errors.New("error")}, normalAuth, map[string]string{"code": "code", "state": "state"}, http.StatusInternalServerError},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
req := httptest.NewRequest(tt.method, "/.pomerium/callback", nil)
|
||||
proxy, err := New(testOptions())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
proxy.sessionStore = tt.session
|
||||
proxy.csrfStore = tt.csrf
|
||||
proxy.AuthenticateClient = tt.authenticator
|
||||
proxy.cipher = mockCipher{}
|
||||
// proxy.Csrf
|
||||
req := httptest.NewRequest(http.MethodPost, "/.pomerium/callback", nil)
|
||||
q := req.URL.Query()
|
||||
for k, v := range tt.params {
|
||||
q.Add(k, v)
|
||||
}
|
||||
req.URL.RawQuery = q.Encode()
|
||||
fmt.Println("OK OK OK OK")
|
||||
|
||||
fmt.Println(req.URL.String())
|
||||
w := httptest.NewRecorder()
|
||||
proxy.OAuthCallback(w, req)
|
||||
if status := w.Code; status != tt.wantCode {
|
||||
|
@ -330,4 +290,188 @@ func TestProxy_OAuthCallback(t *testing.T) {
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
func Test_extendDeadline(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ttl time.Duration
|
||||
want time.Time
|
||||
}{
|
||||
{"good", time.Second, time.Now().Add(time.Second).Truncate(time.Second)},
|
||||
{"test nanoseconds truncated", 500 * time.Nanosecond, time.Now().Truncate(time.Second)},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := extendDeadline(tt.ttl); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("extendDeadline() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxy_router(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
mux map[string]string
|
||||
route http.Handler
|
||||
wantOk bool
|
||||
}{
|
||||
{"good corp", "https://corp.example.com", map[string]string{"corp.example.com": "example.com"}, nil, true},
|
||||
{"good with slash", "https://corp.example.com/", map[string]string{"corp.example.com": "example.com"}, nil, true},
|
||||
{"good with path", "https://corp.example.com/123", map[string]string{"corp.example.com": "example.com"}, nil, true},
|
||||
{"multiple", "https://corp.example.com/", map[string]string{"corp.unrelated.com": "unrelated.com", "corp.example.com": "example.com"}, nil, true},
|
||||
{"bad corp", "https://notcorp.example.com/123", map[string]string{"corp.example.com": "example.com"}, nil, false},
|
||||
{"bad sub-sub", "https://notcorp.corp.example.com/123", map[string]string{"corp.example.com": "example.com"}, nil, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
opts := testOptions()
|
||||
opts.Routes = tt.mux
|
||||
p, err := New(opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
p.AuthenticateClient = authenticator.MockAuthenticate{}
|
||||
p.cipher = mockCipher{}
|
||||
|
||||
req := httptest.NewRequest("GET", tt.host, nil)
|
||||
_, ok := p.router(req)
|
||||
if ok != tt.wantOk {
|
||||
t.Errorf("Proxy.router() ok = %v, want %v", ok, tt.wantOk)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxy_Proxy(t *testing.T) {
|
||||
goodSession := &sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||
}
|
||||
expiredLifetime := &sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
LifetimeDeadline: time.Now().Add(-10 * time.Second),
|
||||
}
|
||||
// expiredDeadline := &sessions.SessionState{
|
||||
// AccessToken: "AccessToken",
|
||||
// RefreshToken: "RefreshToken",
|
||||
// LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||
// RefreshDeadline: time.Now().Add(-10 * time.Second),
|
||||
// }
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
session sessions.SessionStore
|
||||
authenticator authenticator.Authenticator
|
||||
wantStatus int
|
||||
}{
|
||||
// weirdly, we want 503 here because that means proxy is trying to route a domain (example.com) that we dont control. Weird. I know.
|
||||
{"good", "https://corp.example.com/test", sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, http.StatusServiceUnavailable},
|
||||
{"unexpected error", "https://corp.example.com/test", sessions.MockSessionStore{LoadError: errors.New("ok")}, authenticator.MockAuthenticate{}, http.StatusInternalServerError},
|
||||
// redirect to start auth process
|
||||
{"expired lifetime", "https://corp.example.com/test", sessions.MockSessionStore{Session: expiredLifetime}, authenticator.MockAuthenticate{}, http.StatusFound},
|
||||
{"unknown host", "https://notcorp.example.com/test", sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, http.StatusNotFound},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
opts := testOptions()
|
||||
p, err := New(opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
p.cipher = mockCipher{}
|
||||
p.sessionStore = tt.session
|
||||
p.AuthenticateClient = tt.authenticator
|
||||
|
||||
r := httptest.NewRequest("GET", tt.host, nil)
|
||||
w := httptest.NewRecorder()
|
||||
p.Proxy(w, r)
|
||||
if status := w.Code; status != tt.wantStatus {
|
||||
t.Errorf("handler returned wrong status code: got %v want %v \n body %s", status, tt.wantStatus, w.Body.String())
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxy_Authenticate(t *testing.T) {
|
||||
goodSession := &sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||
}
|
||||
expiredLifetime := &sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
LifetimeDeadline: time.Now().Add(-10 * time.Second),
|
||||
}
|
||||
expiredDeadline := &sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||
RefreshDeadline: time.Now().Add(-10 * time.Second),
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
mux map[string]string
|
||||
session sessions.SessionStore
|
||||
authenticator authenticator.Authenticator
|
||||
wantErr bool
|
||||
}{
|
||||
{"cannot load session",
|
||||
"https://corp.example.com/",
|
||||
map[string]string{"corp.example.com": "example.com"},
|
||||
sessions.MockSessionStore{LoadError: errors.New("error")}, authenticator.MockAuthenticate{}, true},
|
||||
{"expired lifetime",
|
||||
"https://corp.example.com/",
|
||||
map[string]string{"corp.example.com": "example.com"},
|
||||
sessions.MockSessionStore{Session: expiredLifetime}, authenticator.MockAuthenticate{}, true},
|
||||
{"expired session",
|
||||
"https://corp.example.com/",
|
||||
map[string]string{"corp.example.com": "example.com"},
|
||||
sessions.MockSessionStore{Session: expiredDeadline}, authenticator.MockAuthenticate{}, false},
|
||||
{"bad refresh authenticator",
|
||||
"https://corp.example.com/",
|
||||
map[string]string{"corp.example.com": "example.com"},
|
||||
sessions.MockSessionStore{
|
||||
Session: expiredDeadline,
|
||||
},
|
||||
authenticator.MockAuthenticate{RefreshError: errors.New("error")},
|
||||
true},
|
||||
|
||||
{"good",
|
||||
"https://corp.example.com/",
|
||||
map[string]string{"corp.example.com": "example.com"},
|
||||
sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
opts := testOptions()
|
||||
opts.Routes = tt.mux
|
||||
p, err := New(opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
p.sessionStore = tt.session
|
||||
p.AuthenticateClient = tt.authenticator
|
||||
p.cipher = mockCipher{}
|
||||
r := httptest.NewRequest("GET", tt.host, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
if err := p.Authenticate(w, r); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Proxy.Authenticate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
package proxy // import "github.com/pomerium/pomerium/proxy"
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -15,15 +13,12 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pomerium/envconfig"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/middleware"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/templates"
|
||||
pb "github.com/pomerium/pomerium/proto/authenticate"
|
||||
"github.com/pomerium/pomerium/proxy/authenticator"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -40,8 +35,7 @@ type Options struct {
|
|||
// Authenticate service settings
|
||||
AuthenticateURL *url.URL `envconfig:"AUTHENTICATE_SERVICE_URL"`
|
||||
AuthenticateInternalURL string `envconfig:"AUTHENTICATE_INTERNAL_URL"`
|
||||
//
|
||||
OverideCertificateName string `envconfig:"OVERIDE_CERTIFICATE_NAME"`
|
||||
OverideCertificateName string `envconfig:"OVERIDE_CERTIFICATE_NAME"`
|
||||
|
||||
// SigningKey is a base64 encoded private key used to add a JWT-signature to proxied requests.
|
||||
// See : https://www.pomerium.io/guide/signed-headers.html
|
||||
|
@ -131,21 +125,17 @@ func (o *Options) Validate() error {
|
|||
type Proxy struct {
|
||||
SharedKey string
|
||||
|
||||
// Authenticate Service Configuration
|
||||
AuthenticateURL *url.URL
|
||||
AuthenticateInternalURL string
|
||||
AuthenticatorClient pb.AuthenticatorClient
|
||||
// AuthenticateConn must be closed by Proxy's caller
|
||||
AuthenticateConn *grpc.ClientConn
|
||||
// services
|
||||
AuthenticateURL *url.URL
|
||||
AuthenticateClient authenticator.Authenticator
|
||||
|
||||
OverideCertificateName string
|
||||
// session
|
||||
cipher cryptutil.Cipher
|
||||
csrfStore sessions.CSRFStore
|
||||
sessionStore sessions.SessionStore
|
||||
CookieExpire time.Duration
|
||||
CookieRefresh time.Duration
|
||||
CookieLifetimeTTL time.Duration
|
||||
cipher cryptutil.Cipher
|
||||
csrfStore sessions.CSRFStore
|
||||
sessionStore sessions.SessionStore
|
||||
|
||||
redirectURL *url.URL
|
||||
templates *template.Template
|
||||
|
@ -154,6 +144,8 @@ type Proxy struct {
|
|||
|
||||
// New takes a Proxy service from options and a validation function.
|
||||
// Function returns an error if options fail to validate.
|
||||
//
|
||||
// Caller responsible for closing AuthenticateConn.
|
||||
func New(opts *Options) (*Proxy, error) {
|
||||
if opts == nil {
|
||||
return nil, errors.New("options cannot be nil")
|
||||
|
@ -182,20 +174,18 @@ func New(opts *Options) (*Proxy, error) {
|
|||
}
|
||||
|
||||
p := &Proxy{
|
||||
// these fields make up the routing mechanism
|
||||
mux: make(map[string]http.Handler),
|
||||
// services
|
||||
AuthenticateURL: opts.AuthenticateURL,
|
||||
// session state
|
||||
cipher: cipher,
|
||||
csrfStore: cookieStore,
|
||||
sessionStore: cookieStore,
|
||||
AuthenticateURL: opts.AuthenticateURL,
|
||||
AuthenticateInternalURL: opts.AuthenticateInternalURL,
|
||||
OverideCertificateName: opts.OverideCertificateName,
|
||||
SharedKey: opts.SharedKey,
|
||||
redirectURL: &url.URL{Path: "/.pomerium/callback"},
|
||||
templates: templates.New(),
|
||||
CookieExpire: opts.CookieExpire,
|
||||
CookieLifetimeTTL: opts.CookieLifetimeTTL,
|
||||
cipher: cipher,
|
||||
csrfStore: cookieStore,
|
||||
sessionStore: cookieStore,
|
||||
SharedKey: opts.SharedKey,
|
||||
redirectURL: &url.URL{Path: "/.pomerium/callback"},
|
||||
templates: templates.New(),
|
||||
CookieExpire: opts.CookieExpire,
|
||||
CookieLifetimeTTL: opts.CookieLifetimeTTL,
|
||||
}
|
||||
|
||||
for from, to := range opts.Routes {
|
||||
|
@ -209,41 +199,11 @@ func New(opts *Options) (*Proxy, error) {
|
|||
p.Handle(fromURL.Host, handler)
|
||||
log.Info().Str("from", fromURL.Host).Str("to", toURL.String()).Msg("proxy.New: new route")
|
||||
}
|
||||
// if no port given, assume https/443
|
||||
port := p.AuthenticateURL.Port()
|
||||
if port == "" {
|
||||
port = "443"
|
||||
}
|
||||
authEndpoint := fmt.Sprintf("%s:%s", p.AuthenticateURL.Host, port)
|
||||
|
||||
cp, err := x509.SystemCertPool()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if p.AuthenticateInternalURL != "" {
|
||||
authEndpoint = p.AuthenticateInternalURL
|
||||
}
|
||||
|
||||
log.Info().Str("authEndpoint", authEndpoint).Msgf("proxy.New: grpc authenticate connection")
|
||||
cert := credentials.NewTLS(&tls.Config{RootCAs: cp})
|
||||
if p.OverideCertificateName != "" {
|
||||
err = cert.OverrideServerName(p.OverideCertificateName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
grpcAuth := middleware.NewSharedSecretCred(p.SharedKey)
|
||||
p.AuthenticateConn, err = grpc.Dial(
|
||||
authEndpoint,
|
||||
grpc.WithTransportCredentials(cert),
|
||||
grpc.WithPerRPCCredentials(grpcAuth),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.AuthenticatorClient = pb.NewAuthenticatorClient(p.AuthenticateConn)
|
||||
|
||||
p.AuthenticateClient, err = authenticator.New(
|
||||
opts.AuthenticateURL,
|
||||
opts.AuthenticateInternalURL,
|
||||
opts.OverideCertificateName,
|
||||
opts.SharedKey)
|
||||
return p, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -9,8 +9,11 @@ import (
|
|||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var fixedDate = time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC)
|
||||
|
||||
func TestOptionsFromEnvConfig(t *testing.T) {
|
||||
os.Clearenv()
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue