mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-21 13:07:13 +02:00
proxy: add unit tests (#43)
This commit is contained in:
parent
cedf9922d3
commit
4f4f3965aa
12 changed files with 577 additions and 323 deletions
|
@ -74,7 +74,7 @@ func main() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal().Err(err).Msg("cmd/pomerium: new proxy")
|
log.Fatal().Err(err).Msg("cmd/pomerium: new proxy")
|
||||||
}
|
}
|
||||||
defer proxyService.AuthenticateConn.Close()
|
defer proxyService.AuthenticateClient.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
topMux := http.NewServeMux()
|
topMux := http.NewServeMux()
|
||||||
|
|
|
@ -12,17 +12,17 @@ type MockCSRFStore struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetCSRF sets the ResponseCSRF string to a val
|
// SetCSRF sets the ResponseCSRF string to a val
|
||||||
func (ms *MockCSRFStore) SetCSRF(rw http.ResponseWriter, req *http.Request, val string) {
|
func (ms MockCSRFStore) SetCSRF(rw http.ResponseWriter, req *http.Request, val string) {
|
||||||
ms.ResponseCSRF = val
|
ms.ResponseCSRF = val
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClearCSRF clears the ResponseCSRF string
|
// ClearCSRF clears the ResponseCSRF string
|
||||||
func (ms *MockCSRFStore) ClearCSRF(http.ResponseWriter, *http.Request) {
|
func (ms MockCSRFStore) ClearCSRF(http.ResponseWriter, *http.Request) {
|
||||||
ms.ResponseCSRF = ""
|
ms.ResponseCSRF = ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetCSRF returns the cookie and error
|
// GetCSRF returns the cookie and error
|
||||||
func (ms *MockCSRFStore) GetCSRF(*http.Request) (*http.Cookie, error) {
|
func (ms MockCSRFStore) GetCSRF(*http.Request) (*http.Cookie, error) {
|
||||||
return ms.Cookie, ms.GetError
|
return ms.Cookie, ms.GetError
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -35,16 +35,16 @@ type MockSessionStore struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClearSession clears the ResponseSession
|
// ClearSession clears the ResponseSession
|
||||||
func (ms *MockSessionStore) ClearSession(http.ResponseWriter, *http.Request) {
|
func (ms MockSessionStore) ClearSession(http.ResponseWriter, *http.Request) {
|
||||||
ms.ResponseSession = ""
|
ms.ResponseSession = ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadSession returns the session and a error
|
// LoadSession returns the session and a error
|
||||||
func (ms *MockSessionStore) LoadSession(*http.Request) (*SessionState, error) {
|
func (ms MockSessionStore) LoadSession(*http.Request) (*SessionState, error) {
|
||||||
return ms.Session, ms.LoadError
|
return ms.Session, ms.LoadError
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveSession returns a save error.
|
// SaveSession returns a save error.
|
||||||
func (ms *MockSessionStore) SaveSession(http.ResponseWriter, *http.Request, *SessionState) error {
|
func (ms MockSessionStore) SaveSession(http.ResponseWriter, *http.Request, *SessionState) error {
|
||||||
return ms.SaveError
|
return ms.SaveError
|
||||||
}
|
}
|
||||||
|
|
69
proxy/authenticator/authenticator.go
Normal file
69
proxy/authenticator/authenticator.go
Normal file
|
@ -0,0 +1,69 @@
|
||||||
|
package authenticator // import "github.com/pomerium/pomerium/proxy/authenticator"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/credentials"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
"github.com/pomerium/pomerium/internal/middleware"
|
||||||
|
pb "github.com/pomerium/pomerium/proto/authenticate"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Authenticator provides the authenticate service interface
|
||||||
|
type Authenticator interface {
|
||||||
|
// Redeem takes a code and returns a validated session or an error
|
||||||
|
Redeem(string) (*RedeemResponse, error)
|
||||||
|
// Refresh attempts to refresh a valid session with a refresh token. Returns a new access token
|
||||||
|
// and expiration, or an error.
|
||||||
|
Refresh(string) (string, time.Time, error)
|
||||||
|
// Validate evaluates a given oidc id_token for validity. Returns validity and any error.
|
||||||
|
Validate(string) (bool, error)
|
||||||
|
// Close closes the authenticator connection if any.
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// New returns a new identity provider based given its name.
|
||||||
|
// Returns an error if selected provided not found or if the identity provider is not known.
|
||||||
|
func New(uri *url.URL, internalURL, OverideCertificateName, key string) (p Authenticator, err error) {
|
||||||
|
// if no port given, assume https/443
|
||||||
|
port := uri.Port()
|
||||||
|
if port == "" {
|
||||||
|
port = "443"
|
||||||
|
}
|
||||||
|
authEndpoint := fmt.Sprintf("%s:%s", uri.Host, port)
|
||||||
|
|
||||||
|
cp, err := x509.SystemCertPool()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if internalURL != "" {
|
||||||
|
authEndpoint = internalURL
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info().Str("authEndpoint", authEndpoint).Msgf("proxy.New: grpc authenticate connection")
|
||||||
|
cert := credentials.NewTLS(&tls.Config{RootCAs: cp})
|
||||||
|
if OverideCertificateName != "" {
|
||||||
|
err = cert.OverrideServerName(OverideCertificateName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
grpcAuth := middleware.NewSharedSecretCred(key)
|
||||||
|
conn, err := grpc.Dial(
|
||||||
|
authEndpoint,
|
||||||
|
grpc.WithTransportCredentials(cert),
|
||||||
|
grpc.WithPerRPCCredentials(grpcAuth),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
authClient := pb.NewAuthenticatorClient(conn)
|
||||||
|
return &AuthenticateGRPC{conn: conn, client: authClient}, nil
|
||||||
|
}
|
36
proxy/authenticator/authenticator_test.go
Normal file
36
proxy/authenticator/authenticator_test.go
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
package authenticator
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/url"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNew(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
uri *url.URL
|
||||||
|
internalURL string
|
||||||
|
OverideCertificateName string
|
||||||
|
key string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
wantP Authenticator
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
// TODO: Add test cases.
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
gotP, err := New(tt.args.uri, tt.args.internalURL, tt.args.OverideCertificateName, tt.args.key)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(gotP, tt.wantP) {
|
||||||
|
t.Errorf("New() = %v, want %v", gotP, tt.wantP)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
98
proxy/authenticator/grpc.go
Normal file
98
proxy/authenticator/grpc.go
Normal file
|
@ -0,0 +1,98 @@
|
||||||
|
package authenticator // import "github.com/pomerium/pomerium/proxy/authenticator"
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang/protobuf/ptypes"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
|
pb "github.com/pomerium/pomerium/proto/authenticate"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RedeemResponse contains data from a authenticator redeem request.
|
||||||
|
type RedeemResponse struct {
|
||||||
|
AccessToken string
|
||||||
|
RefreshToken string
|
||||||
|
IDToken string
|
||||||
|
User string
|
||||||
|
Email string
|
||||||
|
Expiry time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthenticateGRPC is a gRPC implementation of an authenticator (authenticate client)
|
||||||
|
type AuthenticateGRPC struct {
|
||||||
|
conn *grpc.ClientConn
|
||||||
|
client pb.AuthenticatorClient
|
||||||
|
}
|
||||||
|
|
||||||
|
// Redeem makes an RPC call to the authenticate service to creates a session state
|
||||||
|
// from an encrypted code provided as a result of an oauth2 callback process.
|
||||||
|
func (a *AuthenticateGRPC) Redeem(code string) (*RedeemResponse, error) {
|
||||||
|
if code == "" {
|
||||||
|
return nil, errors.New("missing code")
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
|
defer cancel()
|
||||||
|
r, err := a.client.Authenticate(ctx, &pb.AuthenticateRequest{Code: code})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
expiry, err := ptypes.Timestamp(r.Expiry)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &RedeemResponse{
|
||||||
|
AccessToken: r.AccessToken,
|
||||||
|
RefreshToken: r.RefreshToken,
|
||||||
|
IDToken: r.IdToken,
|
||||||
|
User: r.User,
|
||||||
|
Email: r.Email,
|
||||||
|
Expiry: expiry,
|
||||||
|
// RefreshDeadline: (expiry).Truncate(time.Second),
|
||||||
|
// LifetimeDeadline: extendDeadline(p.CookieLifetimeTTL),
|
||||||
|
// ValidDeadline: extendDeadline(p.CookieExpire),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Refresh makes an RPC call to the authenticate service to attempt to refresh the
|
||||||
|
// user's session. Requires a valid refresh token. Will return an error if the identity provider
|
||||||
|
// has revoked the session or if the refresh token is no longer valid in this context.
|
||||||
|
func (a *AuthenticateGRPC) Refresh(refreshToken string) (string, time.Time, error) {
|
||||||
|
if refreshToken == "" {
|
||||||
|
return "", time.Time{}, errors.New("missing refresh token")
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
|
defer cancel()
|
||||||
|
r, err := a.client.Refresh(ctx, &pb.RefreshRequest{RefreshToken: refreshToken})
|
||||||
|
if err != nil {
|
||||||
|
return "", time.Time{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
expiry, err := ptypes.Timestamp(r.Expiry)
|
||||||
|
if err != nil {
|
||||||
|
return "", time.Time{}, err
|
||||||
|
}
|
||||||
|
return r.AccessToken, expiry, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate makes an RPC call to the authenticate service to validate the JWT id token;
|
||||||
|
// does NOT do nonce or revokation validation.
|
||||||
|
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
||||||
|
func (a *AuthenticateGRPC) Validate(idToken string) (bool, error) {
|
||||||
|
if idToken == "" {
|
||||||
|
return false, errors.New("missing id token")
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
|
defer cancel()
|
||||||
|
r, err := a.client.Validate(ctx, &pb.ValidateRequest{IdToken: idToken})
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return r.IsValid, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close tears down the ClientConn and all underlying connections.
|
||||||
|
func (a *AuthenticateGRPC) Close() error {
|
||||||
|
return a.conn.Close()
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package proxy
|
package authenticator // import "github.com/pomerium/pomerium/proxy/authenticator"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -6,8 +6,6 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/golang/protobuf/proto"
|
"github.com/golang/protobuf/proto"
|
||||||
"github.com/golang/protobuf/ptypes"
|
"github.com/golang/protobuf/ptypes"
|
||||||
|
@ -34,7 +32,7 @@ func (r *rpcMsg) String() string {
|
||||||
return fmt.Sprintf("is %s", r.msg)
|
return fmt.Sprintf("is %s", r.msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxy_AuthenticateRedeem(t *testing.T) {
|
func TestProxy_Redeem(t *testing.T) {
|
||||||
ctrl := gomock.NewController(t)
|
ctrl := gomock.NewController(t)
|
||||||
defer ctrl.Finish()
|
defer ctrl.Finish()
|
||||||
mockAuthenticateClient := mock.NewMockAuthenticatorClient(ctrl)
|
mockAuthenticateClient := mock.NewMockAuthenticatorClient(ctrl)
|
||||||
|
@ -55,29 +53,26 @@ func TestProxy_AuthenticateRedeem(t *testing.T) {
|
||||||
Email: "test@email.com",
|
Email: "test@email.com",
|
||||||
Expiry: mockExpire,
|
Expiry: mockExpire,
|
||||||
}, nil)
|
}, nil)
|
||||||
p := &Proxy{AuthenticatorClient: mockAuthenticateClient}
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
idToken string
|
idToken string
|
||||||
want *sessions.SessionState
|
want *RedeemResponse
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"good", "unit_test", &sessions.SessionState{
|
{"good", "unit_test", &RedeemResponse{
|
||||||
AccessToken: "mocked access token",
|
AccessToken: "mocked access token",
|
||||||
RefreshToken: "mocked refresh token",
|
RefreshToken: "mocked refresh token",
|
||||||
IDToken: "mocked id token",
|
IDToken: "mocked id token",
|
||||||
User: "user1",
|
User: "user1",
|
||||||
Email: "test@email.com",
|
Email: "test@email.com",
|
||||||
RefreshDeadline: (fixedDate).Truncate(time.Second),
|
Expiry: (fixedDate),
|
||||||
LifetimeDeadline: extendDeadline(p.CookieLifetimeTTL),
|
|
||||||
ValidDeadline: extendDeadline(p.CookieExpire),
|
|
||||||
}, false},
|
}, false},
|
||||||
{"empty code", "", nil, true},
|
{"empty code", "", nil, true},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
a := AuthenticateGRPC{client: mockAuthenticateClient}
|
||||||
got, err := p.AuthenticateRedeem(tt.idToken)
|
got, err := a.Redeem(tt.idToken)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("Proxy.AuthenticateValidate() error = %v,\n wantErr %v", err, tt.wantErr)
|
t.Errorf("Proxy.AuthenticateValidate() error = %v,\n wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
|
@ -113,7 +108,7 @@ func TestProxy_AuthenticateValidate(t *testing.T) {
|
||||||
&rpcMsg{msg: req},
|
&rpcMsg{msg: req},
|
||||||
).Return(&pb.ValidateReply{IsValid: false}, nil)
|
).Return(&pb.ValidateReply{IsValid: false}, nil)
|
||||||
|
|
||||||
p := &Proxy{AuthenticatorClient: mockAuthenticateClient}
|
ac := mockAuthenticateClient
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
idToken string
|
idToken string
|
||||||
|
@ -125,8 +120,9 @@ func TestProxy_AuthenticateValidate(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
a := AuthenticateGRPC{client: ac}
|
||||||
|
|
||||||
got, err := p.AuthenticateValidate(tt.idToken)
|
got, err := a.Validate(tt.idToken)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("Proxy.AuthenticateValidate() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("Proxy.AuthenticateValidate() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
|
@ -167,9 +163,9 @@ func TestProxy_AuthenticateRefresh(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
p := &Proxy{AuthenticatorClient: mockRefreshClient}
|
a := AuthenticateGRPC{client: mockRefreshClient}
|
||||||
|
|
||||||
got, gotExp, err := p.AuthenticateRefresh(tt.refreshToken)
|
got, gotExp, err := a.Refresh(tt.refreshToken)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("Proxy.AuthenticateRefresh() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("Proxy.AuthenticateRefresh() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
|
@ -183,22 +179,3 @@ func TestProxy_AuthenticateRefresh(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_extendDeadline(t *testing.T) {
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
ttl time.Duration
|
|
||||||
want time.Time
|
|
||||||
}{
|
|
||||||
{"good", time.Second, time.Now().Add(time.Second).Truncate(time.Second)},
|
|
||||||
{"test nanoseconds truncated", 500 * time.Nanosecond, time.Now().Truncate(time.Second)},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if got := extendDeadline(tt.ttl); !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("extendDeadline() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
35
proxy/authenticator/mock_authenticator.go
Normal file
35
proxy/authenticator/mock_authenticator.go
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
package authenticator // import "github.com/pomerium/pomerium/proxy/authenticator"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockAuthenticate is a mock authenticator interface
|
||||||
|
type MockAuthenticate struct {
|
||||||
|
RedeemError error
|
||||||
|
RedeemResponse *RedeemResponse
|
||||||
|
RefreshResponse string
|
||||||
|
RefreshTime time.Time
|
||||||
|
RefreshError error
|
||||||
|
ValidateResponse bool
|
||||||
|
ValidateError error
|
||||||
|
CloseError error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Redeem is a mocked implementation for authenticator testing.
|
||||||
|
func (a MockAuthenticate) Redeem(code string) (*RedeemResponse, error) {
|
||||||
|
return a.RedeemResponse, a.RedeemError
|
||||||
|
}
|
||||||
|
|
||||||
|
// Refresh is a mocked implementation for authenticator testing.
|
||||||
|
func (a MockAuthenticate) Refresh(refreshToken string) (string, time.Time, error) {
|
||||||
|
return a.RefreshResponse, a.RefreshTime, a.RefreshError
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate is a mocked implementation for authenticator testing.
|
||||||
|
func (a MockAuthenticate) Validate(idToken string) (bool, error) {
|
||||||
|
return a.ValidateResponse, a.ValidateError
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close is a mocked implementation for authenticator testing.
|
||||||
|
func (a MockAuthenticate) Close() error { return a.ValidateError }
|
|
@ -1,80 +0,0 @@
|
||||||
package proxy // import "github.com/pomerium/pomerium/proxy"
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/golang/protobuf/ptypes"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
|
||||||
pb "github.com/pomerium/pomerium/proto/authenticate"
|
|
||||||
)
|
|
||||||
|
|
||||||
// AuthenticateRedeem makes an RPC call to the authenticate service to creates a session state
|
|
||||||
// from an encrypted code provided as a result of an oauth2 callback process.
|
|
||||||
func (p *Proxy) AuthenticateRedeem(code string) (*sessions.SessionState, error) {
|
|
||||||
if code == "" {
|
|
||||||
return nil, errors.New("missing code")
|
|
||||||
}
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
||||||
defer cancel()
|
|
||||||
r, err := p.AuthenticatorClient.Authenticate(ctx, &pb.AuthenticateRequest{Code: code})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
expiry, err := ptypes.Timestamp(r.Expiry)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &sessions.SessionState{
|
|
||||||
AccessToken: r.AccessToken,
|
|
||||||
RefreshToken: r.RefreshToken,
|
|
||||||
IDToken: r.IdToken,
|
|
||||||
User: r.User,
|
|
||||||
Email: r.Email,
|
|
||||||
RefreshDeadline: (expiry).Truncate(time.Second),
|
|
||||||
LifetimeDeadline: extendDeadline(p.CookieLifetimeTTL),
|
|
||||||
ValidDeadline: extendDeadline(p.CookieExpire),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AuthenticateRefresh makes an RPC call to the authenticate service to attempt to refresh the
|
|
||||||
// user's session. Requires a valid refresh token. Will return an error if the identity provider
|
|
||||||
// has revoked the session or if the refresh token is no longer valid in this context.
|
|
||||||
func (p *Proxy) AuthenticateRefresh(refreshToken string) (string, time.Time, error) {
|
|
||||||
if refreshToken == "" {
|
|
||||||
return "", time.Time{}, errors.New("missing refresh token")
|
|
||||||
}
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
||||||
defer cancel()
|
|
||||||
r, err := p.AuthenticatorClient.Refresh(ctx, &pb.RefreshRequest{RefreshToken: refreshToken})
|
|
||||||
if err != nil {
|
|
||||||
return "", time.Time{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
expiry, err := ptypes.Timestamp(r.Expiry)
|
|
||||||
if err != nil {
|
|
||||||
return "", time.Time{}, err
|
|
||||||
}
|
|
||||||
return r.AccessToken, expiry, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AuthenticateValidate makes an RPC call to the authenticate service to validate the JWT id token;
|
|
||||||
// does NOT do nonce or revokation validation.
|
|
||||||
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
|
||||||
func (p *Proxy) AuthenticateValidate(idToken string) (bool, error) {
|
|
||||||
if idToken == "" {
|
|
||||||
return false, errors.New("missing id token")
|
|
||||||
}
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
||||||
defer cancel()
|
|
||||||
r, err := p.AuthenticatorClient.Validate(ctx, &pb.ValidateRequest{IdToken: idToken})
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
return r.IsValid, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func extendDeadline(ttl time.Duration) time.Time {
|
|
||||||
return time.Now().Add(ttl).Truncate(time.Second)
|
|
||||||
}
|
|
|
@ -160,7 +160,7 @@ func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// We begin the process of redeeming the code for an access token.
|
// We begin the process of redeeming the code for an access token.
|
||||||
session, err := p.AuthenticateRedeem(r.Form.Get("code"))
|
rr, err := p.AuthenticateClient.Redeem(r.Form.Get("code"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.FromRequest(r).Error().Err(err).Msg("error redeeming authorization code")
|
log.FromRequest(r).Error().Err(err).Msg("error redeeming authorization code")
|
||||||
httputil.ErrorResponse(w, r, "Internal error", http.StatusInternalServerError)
|
httputil.ErrorResponse(w, r, "Internal error", http.StatusInternalServerError)
|
||||||
|
@ -168,6 +168,10 @@ func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
encryptedState := r.Form.Get("state")
|
encryptedState := r.Form.Get("state")
|
||||||
|
log.Warn().
|
||||||
|
Str("encryptedState", encryptedState).
|
||||||
|
Msg("OK")
|
||||||
|
|
||||||
stateParameter := &StateParameter{}
|
stateParameter := &StateParameter{}
|
||||||
err = p.cipher.Unmarshal(encryptedState, stateParameter)
|
err = p.cipher.Unmarshal(encryptedState, stateParameter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -192,13 +196,11 @@ func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
httputil.ErrorResponse(w, r, "Internal error", http.StatusInternalServerError)
|
httputil.ErrorResponse(w, r, "Internal error", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if encryptedState == encryptedCSRF {
|
if encryptedState == encryptedCSRF {
|
||||||
log.FromRequest(r).Error().Msg("encrypted state and CSRF should not be equal")
|
log.FromRequest(r).Error().Msg("encrypted state and CSRF should not be equal")
|
||||||
httputil.ErrorResponse(w, r, "Bad request", http.StatusBadRequest)
|
httputil.ErrorResponse(w, r, "Bad request", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !reflect.DeepEqual(stateParameter, csrfParameter) {
|
if !reflect.DeepEqual(stateParameter, csrfParameter) {
|
||||||
log.FromRequest(r).Error().Msg("state and CSRF should be equal")
|
log.FromRequest(r).Error().Msg("state and CSRF should be equal")
|
||||||
httputil.ErrorResponse(w, r, "Bad request", http.StatusBadRequest)
|
httputil.ErrorResponse(w, r, "Bad request", http.StatusBadRequest)
|
||||||
|
@ -206,7 +208,16 @@ func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// We store the session in a cookie and redirect the user back to the application
|
// We store the session in a cookie and redirect the user back to the application
|
||||||
err = p.sessionStore.SaveSession(w, r, session)
|
err = p.sessionStore.SaveSession(w, r, &sessions.SessionState{
|
||||||
|
AccessToken: rr.AccessToken,
|
||||||
|
RefreshToken: rr.RefreshToken,
|
||||||
|
IDToken: rr.IDToken,
|
||||||
|
User: rr.User,
|
||||||
|
Email: rr.Email,
|
||||||
|
RefreshDeadline: (rr.Expiry).Truncate(time.Second),
|
||||||
|
LifetimeDeadline: extendDeadline(p.CookieLifetimeTTL),
|
||||||
|
ValidDeadline: extendDeadline(p.CookieExpire),
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.FromRequest(r).Error().Msg("error saving session")
|
log.FromRequest(r).Error().Msg("error saving session")
|
||||||
httputil.ErrorResponse(w, r, "Error saving session", http.StatusInternalServerError)
|
httputil.ErrorResponse(w, r, "Error saving session", http.StatusInternalServerError)
|
||||||
|
@ -216,8 +227,8 @@ func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
log.FromRequest(r).Info().
|
log.FromRequest(r).Info().
|
||||||
Str("code", r.Form.Get("code")).
|
Str("code", r.Form.Get("code")).
|
||||||
Str("state", r.Form.Get("state")).
|
Str("state", r.Form.Get("state")).
|
||||||
Str("RefreshToken", session.RefreshToken).
|
Str("RefreshToken", rr.RefreshToken).
|
||||||
Str("session", session.AccessToken).
|
Str("session", rr.AccessToken).
|
||||||
Str("RedirectURI", stateParameter.RedirectURI).
|
Str("RedirectURI", stateParameter.RedirectURI).
|
||||||
Msg("session")
|
Msg("session")
|
||||||
|
|
||||||
|
@ -242,10 +253,6 @@ func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
|
||||||
// OAuthStart. If successful, we proceed to proxy to the configured upstream.
|
// OAuthStart. If successful, we proceed to proxy to the configured upstream.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch err {
|
switch err {
|
||||||
case ErrUserNotAuthorized:
|
|
||||||
log.FromRequest(r).Debug().Err(err).Msg("proxy: user access forbidden")
|
|
||||||
httputil.ErrorResponse(w, r, "You don't have access", http.StatusForbidden)
|
|
||||||
return
|
|
||||||
case http.ErrNoCookie, sessions.ErrLifetimeExpired, sessions.ErrInvalidSession:
|
case http.ErrNoCookie, sessions.ErrLifetimeExpired, sessions.ErrInvalidSession:
|
||||||
log.FromRequest(r).Debug().Err(err).Msg("proxy: starting auth flow")
|
log.FromRequest(r).Debug().Err(err).Msg("proxy: starting auth flow")
|
||||||
p.OAuthStart(w, r)
|
p.OAuthStart(w, r)
|
||||||
|
@ -256,8 +263,9 @@ func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ! ! !
|
// ! ! !
|
||||||
// todo(bdd): ! Authorization checks will go here !
|
// todo(bdd): ! Authorization service goes here !
|
||||||
// ! ! !
|
// ! ! !
|
||||||
|
|
||||||
// We have validated the users request and now proxy their request to the provided upstream.
|
// We have validated the users request and now proxy their request to the provided upstream.
|
||||||
|
@ -293,7 +301,7 @@ func (p *Proxy) Authenticate(w http.ResponseWriter, r *http.Request) (err error)
|
||||||
// AccessToken's usually expire after 60 or so minutes. If offline_access scope is set, a
|
// AccessToken's usually expire after 60 or so minutes. If offline_access scope is set, a
|
||||||
// refresh token (which doesn't change) can be used to request a new access-token. If access
|
// refresh token (which doesn't change) can be used to request a new access-token. If access
|
||||||
// is revoked by identity provider, or no refresh token is set request will return an error
|
// is revoked by identity provider, or no refresh token is set request will return an error
|
||||||
accessToken, expiry, err := p.AuthenticateRefresh(session.RefreshToken)
|
accessToken, expiry, err := p.AuthenticateClient.Refresh(session.RefreshToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.FromRequest(r).Warn().
|
log.FromRequest(r).Warn().
|
||||||
Str("RefreshToken", session.RefreshToken).
|
Str("RefreshToken", session.RefreshToken).
|
||||||
|
@ -377,3 +385,7 @@ func (p *Proxy) GetSignOutURL(authenticateURL, redirectURL *url.URL) *url.URL {
|
||||||
a.RawQuery = params.Encode()
|
a.RawQuery = params.Encode()
|
||||||
return a
|
return a
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func extendDeadline(ttl time.Duration) time.Time {
|
||||||
|
return time.Now().Add(ttl).Truncate(time.Second)
|
||||||
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
@ -10,9 +11,34 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
"github.com/pomerium/pomerium/internal/version"
|
"github.com/pomerium/pomerium/internal/version"
|
||||||
|
"github.com/pomerium/pomerium/proxy/authenticator"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type mockCipher struct{}
|
||||||
|
|
||||||
|
func (a mockCipher) Encrypt(s []byte) ([]byte, error) {
|
||||||
|
if string(s) == "error" {
|
||||||
|
return []byte(""), errors.New("error encrypting")
|
||||||
|
}
|
||||||
|
return []byte("OK"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a mockCipher) Decrypt(s []byte) ([]byte, error) {
|
||||||
|
if string(s) == "error" {
|
||||||
|
return []byte(""), errors.New("error encrypting")
|
||||||
|
}
|
||||||
|
return []byte("OK"), nil
|
||||||
|
}
|
||||||
|
func (a mockCipher) Marshal(s interface{}) (string, error) { return "ok", nil }
|
||||||
|
func (a mockCipher) Unmarshal(s string, i interface{}) error {
|
||||||
|
if string(s) == "unmarshal error" || string(s) == "error" {
|
||||||
|
return errors.New("error")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestProxy_RobotsTxt(t *testing.T) {
|
func TestProxy_RobotsTxt(t *testing.T) {
|
||||||
proxy := Proxy{}
|
proxy := Proxy{}
|
||||||
req := httptest.NewRequest("GET", "/robots.txt", nil)
|
req := httptest.NewRequest("GET", "/robots.txt", nil)
|
||||||
|
@ -49,8 +75,6 @@ func TestProxy_GetRedirectURL(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxy_signRedirectURL(t *testing.T) {
|
func TestProxy_signRedirectURL(t *testing.T) {
|
||||||
fixedDate := time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC)
|
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
rawRedirect string
|
rawRedirect string
|
||||||
|
@ -194,135 +218,71 @@ func TestProxy_Handler(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// err := r.ParseForm()
|
|
||||||
// if err != nil {
|
|
||||||
// log.FromRequest(r).Error().Err(err).Msg("failed parsing request form")
|
|
||||||
// httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError)
|
|
||||||
// return
|
|
||||||
// }
|
|
||||||
// errorString := r.Form.Get("error")
|
|
||||||
// if errorString != "" {
|
|
||||||
// httputil.ErrorResponse(w, r, errorString, http.StatusForbidden)
|
|
||||||
// return
|
|
||||||
// }
|
|
||||||
// // We begin the process of redeeming the code for an access token.
|
|
||||||
// session, err := p.AuthenticateRedeem(r.Form.Get("code"))
|
|
||||||
// if err != nil {
|
|
||||||
// log.FromRequest(r).Error().Err(err).Msg("error redeeming authorization code")
|
|
||||||
// httputil.ErrorResponse(w, r, "Internal error", http.StatusInternalServerError)
|
|
||||||
// return
|
|
||||||
// }
|
|
||||||
|
|
||||||
// encryptedState := r.Form.Get("state")
|
|
||||||
// stateParameter := &StateParameter{}
|
|
||||||
// err = p.cipher.Unmarshal(encryptedState, stateParameter)
|
|
||||||
// if err != nil {
|
|
||||||
// log.FromRequest(r).Error().Err(err).Msg("could not unmarshal state")
|
|
||||||
// httputil.ErrorResponse(w, r, "Internal error", http.StatusInternalServerError)
|
|
||||||
// return
|
|
||||||
// }
|
|
||||||
|
|
||||||
// c, err := p.csrfStore.GetCSRF(r)
|
|
||||||
// if err != nil {
|
|
||||||
// log.FromRequest(r).Error().Err(err).Msg("failed parsing csrf cookie")
|
|
||||||
// httputil.ErrorResponse(w, r, err.Error(), http.StatusBadRequest)
|
|
||||||
// return
|
|
||||||
// }
|
|
||||||
// p.csrfStore.ClearCSRF(w, r)
|
|
||||||
|
|
||||||
// encryptedCSRF := c.Value
|
|
||||||
// csrfParameter := &StateParameter{}
|
|
||||||
// err = p.cipher.Unmarshal(encryptedCSRF, csrfParameter)
|
|
||||||
// if err != nil {
|
|
||||||
// log.FromRequest(r).Error().Err(err).Msg("couldn't unmarshal CSRF")
|
|
||||||
// httputil.ErrorResponse(w, r, "Internal error", http.StatusInternalServerError)
|
|
||||||
// return
|
|
||||||
// }
|
|
||||||
|
|
||||||
// if encryptedState == encryptedCSRF {
|
|
||||||
// log.FromRequest(r).Error().Msg("encrypted state and CSRF should not be equal")
|
|
||||||
// httputil.ErrorResponse(w, r, "Bad request", http.StatusBadRequest)
|
|
||||||
// return
|
|
||||||
// }
|
|
||||||
|
|
||||||
// if !reflect.DeepEqual(stateParameter, csrfParameter) {
|
|
||||||
// log.FromRequest(r).Error().Msg("state and CSRF should be equal")
|
|
||||||
// httputil.ErrorResponse(w, r, "Bad request", http.StatusBadRequest)
|
|
||||||
// return
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // We store the session in a cookie and redirect the user back to the application
|
|
||||||
// err = p.sessionStore.SaveSession(w, r, session)
|
|
||||||
// if err != nil {
|
|
||||||
// log.FromRequest(r).Error().Msg("error saving session")
|
|
||||||
// httputil.ErrorResponse(w, r, "Error saving session", http.StatusInternalServerError)
|
|
||||||
// return
|
|
||||||
// }
|
|
||||||
|
|
||||||
// log.FromRequest(r).Info().
|
|
||||||
// Str("code", r.Form.Get("code")).
|
|
||||||
// Str("state", r.Form.Get("state")).
|
|
||||||
// Str("RefreshToken", session.RefreshToken).
|
|
||||||
// Str("session", session.AccessToken).
|
|
||||||
// Str("RedirectURI", stateParameter.RedirectURI).
|
|
||||||
// Msg("session")
|
|
||||||
|
|
||||||
// // This is the redirect back to the original requested application
|
|
||||||
// http.Redirect(w, r, stateParameter.RedirectURI, http.StatusFound)
|
|
||||||
// }
|
|
||||||
|
|
||||||
// func TestProxy_OAuthCallback2(t *testing.T) {
|
|
||||||
// proxy, err := New(testOptions())
|
|
||||||
// if err != nil {
|
|
||||||
// t.Fatal(err)
|
|
||||||
// }
|
|
||||||
// testError := url.Values{"error": []string{"There was a bad error to handle"}}
|
|
||||||
// req := httptest.NewRequest("GET", "/oauth-callback", strings.NewReader(testError.Encode()))
|
|
||||||
// if err != nil {
|
|
||||||
// t.Fatal(err)
|
|
||||||
// }
|
|
||||||
// rr := httptest.NewRecorder()
|
|
||||||
// proxy.OAuthCallback)
|
|
||||||
// // expect oauth redirect
|
|
||||||
// if status := rr.Code; status != http.StatusInternalServerError {
|
|
||||||
// t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusInternalServerError)
|
|
||||||
// }
|
|
||||||
// // expected url
|
|
||||||
// // expected := `<a href="https://sso-auth.corp.beyondperimeter.com/sign_in`
|
|
||||||
// // body := rr.Body.String()
|
|
||||||
// // if !strings.HasPrefix(body, expected) {
|
|
||||||
// // t.Errorf("handler returned unexpected body: got %v want %v", body, expected)
|
|
||||||
// // }
|
|
||||||
// }
|
|
||||||
|
|
||||||
func TestProxy_OAuthCallback(t *testing.T) {
|
func TestProxy_OAuthCallback(t *testing.T) {
|
||||||
proxy, err := New(testOptions())
|
//todo(bdd): test malformed requests
|
||||||
if err != nil {
|
// https://github.com/golang/go/blob/master/src/net/http/request_test.go#L110
|
||||||
t.Fatal(err)
|
normalSession := sessions.MockSessionStore{
|
||||||
|
Session: &sessions.SessionState{
|
||||||
|
AccessToken: "AccessToken",
|
||||||
|
RefreshToken: "RefreshToken",
|
||||||
|
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
RefreshDeadline: time.Now().Add(-10 * time.Second),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
normalAuth := authenticator.MockAuthenticate{
|
||||||
|
RedeemResponse: &authenticator.RedeemResponse{
|
||||||
|
AccessToken: "AccessToken",
|
||||||
|
RefreshToken: "RefreshToken",
|
||||||
|
Expiry: time.Now().Add(10 * time.Second),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
normalCsrf := sessions.MockCSRFStore{
|
||||||
|
ResponseCSRF: "ok",
|
||||||
|
GetError: nil,
|
||||||
|
Cookie: &http.Cookie{
|
||||||
|
Name: "something_csrf",
|
||||||
|
Value: "csrf_state",
|
||||||
|
}}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
method string
|
csrf sessions.MockCSRFStore
|
||||||
params map[string]string
|
session sessions.MockSessionStore
|
||||||
wantCode int
|
authenticator authenticator.MockAuthenticate
|
||||||
|
params map[string]string
|
||||||
|
wantCode int
|
||||||
}{
|
}{
|
||||||
{"nil", http.MethodPost, nil, http.StatusInternalServerError},
|
{"good", normalCsrf, normalSession, normalAuth, map[string]string{"code": "code", "state": "state"}, http.StatusFound},
|
||||||
{"error", http.MethodPost, map[string]string{"error": "some error"}, http.StatusForbidden},
|
{"error", normalCsrf, normalSession, normalAuth, map[string]string{"error": "some error"}, http.StatusForbidden},
|
||||||
{"state", http.MethodPost, map[string]string{"code": "code"}, http.StatusInternalServerError},
|
{"code err", normalCsrf, normalSession, authenticator.MockAuthenticate{RedeemError: errors.New("error")}, map[string]string{"code": "error"}, http.StatusInternalServerError},
|
||||||
|
{"state err", normalCsrf, normalSession, normalAuth, map[string]string{"code": "code", "state": "error"}, http.StatusInternalServerError},
|
||||||
|
{"csrf err", sessions.MockCSRFStore{GetError: errors.New("error")}, normalSession, normalAuth, map[string]string{"code": "code", "state": "state"}, http.StatusBadRequest},
|
||||||
|
{"unmarshal err", sessions.MockCSRFStore{
|
||||||
|
Cookie: &http.Cookie{
|
||||||
|
Name: "something_csrf",
|
||||||
|
Value: "unmarshal error",
|
||||||
|
},
|
||||||
|
}, normalSession, normalAuth, map[string]string{"code": "code", "state": "state"}, http.StatusInternalServerError},
|
||||||
|
{"encrypted state != CSRF", normalCsrf, normalSession, normalAuth, map[string]string{"code": "code", "state": "csrf_state"}, http.StatusBadRequest},
|
||||||
|
{"session save err", normalCsrf, sessions.MockSessionStore{SaveError: errors.New("error")}, normalAuth, map[string]string{"code": "code", "state": "state"}, http.StatusInternalServerError},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
proxy, err := New(testOptions())
|
||||||
req := httptest.NewRequest(tt.method, "/.pomerium/callback", nil)
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
proxy.sessionStore = tt.session
|
||||||
|
proxy.csrfStore = tt.csrf
|
||||||
|
proxy.AuthenticateClient = tt.authenticator
|
||||||
|
proxy.cipher = mockCipher{}
|
||||||
|
// proxy.Csrf
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/.pomerium/callback", nil)
|
||||||
q := req.URL.Query()
|
q := req.URL.Query()
|
||||||
for k, v := range tt.params {
|
for k, v := range tt.params {
|
||||||
q.Add(k, v)
|
q.Add(k, v)
|
||||||
}
|
}
|
||||||
req.URL.RawQuery = q.Encode()
|
req.URL.RawQuery = q.Encode()
|
||||||
fmt.Println("OK OK OK OK")
|
|
||||||
|
|
||||||
fmt.Println(req.URL.String())
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
proxy.OAuthCallback(w, req)
|
proxy.OAuthCallback(w, req)
|
||||||
if status := w.Code; status != tt.wantCode {
|
if status := w.Code; status != tt.wantCode {
|
||||||
|
@ -330,4 +290,188 @@ func TestProxy_OAuthCallback(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
func Test_extendDeadline(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
ttl time.Duration
|
||||||
|
want time.Time
|
||||||
|
}{
|
||||||
|
{"good", time.Second, time.Now().Add(time.Second).Truncate(time.Second)},
|
||||||
|
{"test nanoseconds truncated", 500 * time.Nanosecond, time.Now().Truncate(time.Second)},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := extendDeadline(tt.ttl); !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("extendDeadline() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxy_router(t *testing.T) {
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
host string
|
||||||
|
mux map[string]string
|
||||||
|
route http.Handler
|
||||||
|
wantOk bool
|
||||||
|
}{
|
||||||
|
{"good corp", "https://corp.example.com", map[string]string{"corp.example.com": "example.com"}, nil, true},
|
||||||
|
{"good with slash", "https://corp.example.com/", map[string]string{"corp.example.com": "example.com"}, nil, true},
|
||||||
|
{"good with path", "https://corp.example.com/123", map[string]string{"corp.example.com": "example.com"}, nil, true},
|
||||||
|
{"multiple", "https://corp.example.com/", map[string]string{"corp.unrelated.com": "unrelated.com", "corp.example.com": "example.com"}, nil, true},
|
||||||
|
{"bad corp", "https://notcorp.example.com/123", map[string]string{"corp.example.com": "example.com"}, nil, false},
|
||||||
|
{"bad sub-sub", "https://notcorp.corp.example.com/123", map[string]string{"corp.example.com": "example.com"}, nil, false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
opts := testOptions()
|
||||||
|
opts.Routes = tt.mux
|
||||||
|
p, err := New(opts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
p.AuthenticateClient = authenticator.MockAuthenticate{}
|
||||||
|
p.cipher = mockCipher{}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", tt.host, nil)
|
||||||
|
_, ok := p.router(req)
|
||||||
|
if ok != tt.wantOk {
|
||||||
|
t.Errorf("Proxy.router() ok = %v, want %v", ok, tt.wantOk)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxy_Proxy(t *testing.T) {
|
||||||
|
goodSession := &sessions.SessionState{
|
||||||
|
AccessToken: "AccessToken",
|
||||||
|
RefreshToken: "RefreshToken",
|
||||||
|
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
}
|
||||||
|
expiredLifetime := &sessions.SessionState{
|
||||||
|
AccessToken: "AccessToken",
|
||||||
|
RefreshToken: "RefreshToken",
|
||||||
|
LifetimeDeadline: time.Now().Add(-10 * time.Second),
|
||||||
|
}
|
||||||
|
// expiredDeadline := &sessions.SessionState{
|
||||||
|
// AccessToken: "AccessToken",
|
||||||
|
// RefreshToken: "RefreshToken",
|
||||||
|
// LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
// RefreshDeadline: time.Now().Add(-10 * time.Second),
|
||||||
|
// }
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
host string
|
||||||
|
session sessions.SessionStore
|
||||||
|
authenticator authenticator.Authenticator
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
// weirdly, we want 503 here because that means proxy is trying to route a domain (example.com) that we dont control. Weird. I know.
|
||||||
|
{"good", "https://corp.example.com/test", sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, http.StatusServiceUnavailable},
|
||||||
|
{"unexpected error", "https://corp.example.com/test", sessions.MockSessionStore{LoadError: errors.New("ok")}, authenticator.MockAuthenticate{}, http.StatusInternalServerError},
|
||||||
|
// redirect to start auth process
|
||||||
|
{"expired lifetime", "https://corp.example.com/test", sessions.MockSessionStore{Session: expiredLifetime}, authenticator.MockAuthenticate{}, http.StatusFound},
|
||||||
|
{"unknown host", "https://notcorp.example.com/test", sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, http.StatusNotFound},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
opts := testOptions()
|
||||||
|
p, err := New(opts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
p.cipher = mockCipher{}
|
||||||
|
p.sessionStore = tt.session
|
||||||
|
p.AuthenticateClient = tt.authenticator
|
||||||
|
|
||||||
|
r := httptest.NewRequest("GET", tt.host, nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
p.Proxy(w, r)
|
||||||
|
if status := w.Code; status != tt.wantStatus {
|
||||||
|
t.Errorf("handler returned wrong status code: got %v want %v \n body %s", status, tt.wantStatus, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxy_Authenticate(t *testing.T) {
|
||||||
|
goodSession := &sessions.SessionState{
|
||||||
|
AccessToken: "AccessToken",
|
||||||
|
RefreshToken: "RefreshToken",
|
||||||
|
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
}
|
||||||
|
expiredLifetime := &sessions.SessionState{
|
||||||
|
AccessToken: "AccessToken",
|
||||||
|
RefreshToken: "RefreshToken",
|
||||||
|
LifetimeDeadline: time.Now().Add(-10 * time.Second),
|
||||||
|
}
|
||||||
|
expiredDeadline := &sessions.SessionState{
|
||||||
|
AccessToken: "AccessToken",
|
||||||
|
RefreshToken: "RefreshToken",
|
||||||
|
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
RefreshDeadline: time.Now().Add(-10 * time.Second),
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
host string
|
||||||
|
mux map[string]string
|
||||||
|
session sessions.SessionStore
|
||||||
|
authenticator authenticator.Authenticator
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"cannot load session",
|
||||||
|
"https://corp.example.com/",
|
||||||
|
map[string]string{"corp.example.com": "example.com"},
|
||||||
|
sessions.MockSessionStore{LoadError: errors.New("error")}, authenticator.MockAuthenticate{}, true},
|
||||||
|
{"expired lifetime",
|
||||||
|
"https://corp.example.com/",
|
||||||
|
map[string]string{"corp.example.com": "example.com"},
|
||||||
|
sessions.MockSessionStore{Session: expiredLifetime}, authenticator.MockAuthenticate{}, true},
|
||||||
|
{"expired session",
|
||||||
|
"https://corp.example.com/",
|
||||||
|
map[string]string{"corp.example.com": "example.com"},
|
||||||
|
sessions.MockSessionStore{Session: expiredDeadline}, authenticator.MockAuthenticate{}, false},
|
||||||
|
{"bad refresh authenticator",
|
||||||
|
"https://corp.example.com/",
|
||||||
|
map[string]string{"corp.example.com": "example.com"},
|
||||||
|
sessions.MockSessionStore{
|
||||||
|
Session: expiredDeadline,
|
||||||
|
},
|
||||||
|
authenticator.MockAuthenticate{RefreshError: errors.New("error")},
|
||||||
|
true},
|
||||||
|
|
||||||
|
{"good",
|
||||||
|
"https://corp.example.com/",
|
||||||
|
map[string]string{"corp.example.com": "example.com"},
|
||||||
|
sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
opts := testOptions()
|
||||||
|
opts.Routes = tt.mux
|
||||||
|
p, err := New(opts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
p.sessionStore = tt.session
|
||||||
|
p.AuthenticateClient = tt.authenticator
|
||||||
|
p.cipher = mockCipher{}
|
||||||
|
r := httptest.NewRequest("GET", tt.host, nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
if err := p.Authenticate(w, r); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Proxy.Authenticate() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,8 +1,6 @@
|
||||||
package proxy // import "github.com/pomerium/pomerium/proxy"
|
package proxy // import "github.com/pomerium/pomerium/proxy"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
|
||||||
"crypto/x509"
|
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -15,15 +13,12 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pomerium/envconfig"
|
"github.com/pomerium/envconfig"
|
||||||
"google.golang.org/grpc"
|
|
||||||
"google.golang.org/grpc/credentials"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/middleware"
|
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
"github.com/pomerium/pomerium/internal/templates"
|
"github.com/pomerium/pomerium/internal/templates"
|
||||||
pb "github.com/pomerium/pomerium/proto/authenticate"
|
"github.com/pomerium/pomerium/proxy/authenticator"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -40,8 +35,7 @@ type Options struct {
|
||||||
// Authenticate service settings
|
// Authenticate service settings
|
||||||
AuthenticateURL *url.URL `envconfig:"AUTHENTICATE_SERVICE_URL"`
|
AuthenticateURL *url.URL `envconfig:"AUTHENTICATE_SERVICE_URL"`
|
||||||
AuthenticateInternalURL string `envconfig:"AUTHENTICATE_INTERNAL_URL"`
|
AuthenticateInternalURL string `envconfig:"AUTHENTICATE_INTERNAL_URL"`
|
||||||
//
|
OverideCertificateName string `envconfig:"OVERIDE_CERTIFICATE_NAME"`
|
||||||
OverideCertificateName string `envconfig:"OVERIDE_CERTIFICATE_NAME"`
|
|
||||||
|
|
||||||
// SigningKey is a base64 encoded private key used to add a JWT-signature to proxied requests.
|
// SigningKey is a base64 encoded private key used to add a JWT-signature to proxied requests.
|
||||||
// See : https://www.pomerium.io/guide/signed-headers.html
|
// See : https://www.pomerium.io/guide/signed-headers.html
|
||||||
|
@ -131,21 +125,17 @@ func (o *Options) Validate() error {
|
||||||
type Proxy struct {
|
type Proxy struct {
|
||||||
SharedKey string
|
SharedKey string
|
||||||
|
|
||||||
// Authenticate Service Configuration
|
// services
|
||||||
AuthenticateURL *url.URL
|
AuthenticateURL *url.URL
|
||||||
AuthenticateInternalURL string
|
AuthenticateClient authenticator.Authenticator
|
||||||
AuthenticatorClient pb.AuthenticatorClient
|
|
||||||
// AuthenticateConn must be closed by Proxy's caller
|
|
||||||
AuthenticateConn *grpc.ClientConn
|
|
||||||
|
|
||||||
OverideCertificateName string
|
|
||||||
// session
|
// session
|
||||||
cipher cryptutil.Cipher
|
|
||||||
csrfStore sessions.CSRFStore
|
|
||||||
sessionStore sessions.SessionStore
|
|
||||||
CookieExpire time.Duration
|
CookieExpire time.Duration
|
||||||
CookieRefresh time.Duration
|
CookieRefresh time.Duration
|
||||||
CookieLifetimeTTL time.Duration
|
CookieLifetimeTTL time.Duration
|
||||||
|
cipher cryptutil.Cipher
|
||||||
|
csrfStore sessions.CSRFStore
|
||||||
|
sessionStore sessions.SessionStore
|
||||||
|
|
||||||
redirectURL *url.URL
|
redirectURL *url.URL
|
||||||
templates *template.Template
|
templates *template.Template
|
||||||
|
@ -154,6 +144,8 @@ type Proxy struct {
|
||||||
|
|
||||||
// New takes a Proxy service from options and a validation function.
|
// New takes a Proxy service from options and a validation function.
|
||||||
// Function returns an error if options fail to validate.
|
// Function returns an error if options fail to validate.
|
||||||
|
//
|
||||||
|
// Caller responsible for closing AuthenticateConn.
|
||||||
func New(opts *Options) (*Proxy, error) {
|
func New(opts *Options) (*Proxy, error) {
|
||||||
if opts == nil {
|
if opts == nil {
|
||||||
return nil, errors.New("options cannot be nil")
|
return nil, errors.New("options cannot be nil")
|
||||||
|
@ -182,20 +174,18 @@ func New(opts *Options) (*Proxy, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
p := &Proxy{
|
p := &Proxy{
|
||||||
// these fields make up the routing mechanism
|
|
||||||
mux: make(map[string]http.Handler),
|
mux: make(map[string]http.Handler),
|
||||||
|
// services
|
||||||
|
AuthenticateURL: opts.AuthenticateURL,
|
||||||
// session state
|
// session state
|
||||||
cipher: cipher,
|
cipher: cipher,
|
||||||
csrfStore: cookieStore,
|
csrfStore: cookieStore,
|
||||||
sessionStore: cookieStore,
|
sessionStore: cookieStore,
|
||||||
AuthenticateURL: opts.AuthenticateURL,
|
SharedKey: opts.SharedKey,
|
||||||
AuthenticateInternalURL: opts.AuthenticateInternalURL,
|
redirectURL: &url.URL{Path: "/.pomerium/callback"},
|
||||||
OverideCertificateName: opts.OverideCertificateName,
|
templates: templates.New(),
|
||||||
SharedKey: opts.SharedKey,
|
CookieExpire: opts.CookieExpire,
|
||||||
redirectURL: &url.URL{Path: "/.pomerium/callback"},
|
CookieLifetimeTTL: opts.CookieLifetimeTTL,
|
||||||
templates: templates.New(),
|
|
||||||
CookieExpire: opts.CookieExpire,
|
|
||||||
CookieLifetimeTTL: opts.CookieLifetimeTTL,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for from, to := range opts.Routes {
|
for from, to := range opts.Routes {
|
||||||
|
@ -209,41 +199,11 @@ func New(opts *Options) (*Proxy, error) {
|
||||||
p.Handle(fromURL.Host, handler)
|
p.Handle(fromURL.Host, handler)
|
||||||
log.Info().Str("from", fromURL.Host).Str("to", toURL.String()).Msg("proxy.New: new route")
|
log.Info().Str("from", fromURL.Host).Str("to", toURL.String()).Msg("proxy.New: new route")
|
||||||
}
|
}
|
||||||
// if no port given, assume https/443
|
p.AuthenticateClient, err = authenticator.New(
|
||||||
port := p.AuthenticateURL.Port()
|
opts.AuthenticateURL,
|
||||||
if port == "" {
|
opts.AuthenticateInternalURL,
|
||||||
port = "443"
|
opts.OverideCertificateName,
|
||||||
}
|
opts.SharedKey)
|
||||||
authEndpoint := fmt.Sprintf("%s:%s", p.AuthenticateURL.Host, port)
|
|
||||||
|
|
||||||
cp, err := x509.SystemCertPool()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if p.AuthenticateInternalURL != "" {
|
|
||||||
authEndpoint = p.AuthenticateInternalURL
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Info().Str("authEndpoint", authEndpoint).Msgf("proxy.New: grpc authenticate connection")
|
|
||||||
cert := credentials.NewTLS(&tls.Config{RootCAs: cp})
|
|
||||||
if p.OverideCertificateName != "" {
|
|
||||||
err = cert.OverrideServerName(p.OverideCertificateName)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
grpcAuth := middleware.NewSharedSecretCred(p.SharedKey)
|
|
||||||
p.AuthenticateConn, err = grpc.Dial(
|
|
||||||
authEndpoint,
|
|
||||||
grpc.WithTransportCredentials(cert),
|
|
||||||
grpc.WithPerRPCCredentials(grpcAuth),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
p.AuthenticatorClient = pb.NewAuthenticatorClient(p.AuthenticateConn)
|
|
||||||
|
|
||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -9,8 +9,11 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var fixedDate = time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC)
|
||||||
|
|
||||||
func TestOptionsFromEnvConfig(t *testing.T) {
|
func TestOptionsFromEnvConfig(t *testing.T) {
|
||||||
os.Clearenv()
|
os.Clearenv()
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue