mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-03 08:50:42 +02:00
proxy: add unit tests (#43)
This commit is contained in:
parent
cedf9922d3
commit
4f4f3965aa
12 changed files with 577 additions and 323 deletions
69
proxy/authenticator/authenticator.go
Normal file
69
proxy/authenticator/authenticator.go
Normal file
|
@ -0,0 +1,69 @@
|
|||
package authenticator // import "github.com/pomerium/pomerium/proxy/authenticator"
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/middleware"
|
||||
pb "github.com/pomerium/pomerium/proto/authenticate"
|
||||
)
|
||||
|
||||
// Authenticator provides the authenticate service interface
|
||||
type Authenticator interface {
|
||||
// Redeem takes a code and returns a validated session or an error
|
||||
Redeem(string) (*RedeemResponse, error)
|
||||
// Refresh attempts to refresh a valid session with a refresh token. Returns a new access token
|
||||
// and expiration, or an error.
|
||||
Refresh(string) (string, time.Time, error)
|
||||
// Validate evaluates a given oidc id_token for validity. Returns validity and any error.
|
||||
Validate(string) (bool, error)
|
||||
// Close closes the authenticator connection if any.
|
||||
Close() error
|
||||
}
|
||||
|
||||
// New returns a new identity provider based given its name.
|
||||
// Returns an error if selected provided not found or if the identity provider is not known.
|
||||
func New(uri *url.URL, internalURL, OverideCertificateName, key string) (p Authenticator, err error) {
|
||||
// if no port given, assume https/443
|
||||
port := uri.Port()
|
||||
if port == "" {
|
||||
port = "443"
|
||||
}
|
||||
authEndpoint := fmt.Sprintf("%s:%s", uri.Host, port)
|
||||
|
||||
cp, err := x509.SystemCertPool()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if internalURL != "" {
|
||||
authEndpoint = internalURL
|
||||
}
|
||||
|
||||
log.Info().Str("authEndpoint", authEndpoint).Msgf("proxy.New: grpc authenticate connection")
|
||||
cert := credentials.NewTLS(&tls.Config{RootCAs: cp})
|
||||
if OverideCertificateName != "" {
|
||||
err = cert.OverrideServerName(OverideCertificateName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
grpcAuth := middleware.NewSharedSecretCred(key)
|
||||
conn, err := grpc.Dial(
|
||||
authEndpoint,
|
||||
grpc.WithTransportCredentials(cert),
|
||||
grpc.WithPerRPCCredentials(grpcAuth),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authClient := pb.NewAuthenticatorClient(conn)
|
||||
return &AuthenticateGRPC{conn: conn, client: authClient}, nil
|
||||
}
|
36
proxy/authenticator/authenticator_test.go
Normal file
36
proxy/authenticator/authenticator_test.go
Normal file
|
@ -0,0 +1,36 @@
|
|||
package authenticator
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
type args struct {
|
||||
uri *url.URL
|
||||
internalURL string
|
||||
OverideCertificateName string
|
||||
key string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantP Authenticator
|
||||
wantErr bool
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotP, err := New(tt.args.uri, tt.args.internalURL, tt.args.OverideCertificateName, tt.args.key)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(gotP, tt.wantP) {
|
||||
t.Errorf("New() = %v, want %v", gotP, tt.wantP)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
98
proxy/authenticator/grpc.go
Normal file
98
proxy/authenticator/grpc.go
Normal file
|
@ -0,0 +1,98 @@
|
|||
package authenticator // import "github.com/pomerium/pomerium/proxy/authenticator"
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
pb "github.com/pomerium/pomerium/proto/authenticate"
|
||||
)
|
||||
|
||||
// RedeemResponse contains data from a authenticator redeem request.
|
||||
type RedeemResponse struct {
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
IDToken string
|
||||
User string
|
||||
Email string
|
||||
Expiry time.Time
|
||||
}
|
||||
|
||||
// AuthenticateGRPC is a gRPC implementation of an authenticator (authenticate client)
|
||||
type AuthenticateGRPC struct {
|
||||
conn *grpc.ClientConn
|
||||
client pb.AuthenticatorClient
|
||||
}
|
||||
|
||||
// Redeem makes an RPC call to the authenticate service to creates a session state
|
||||
// from an encrypted code provided as a result of an oauth2 callback process.
|
||||
func (a *AuthenticateGRPC) Redeem(code string) (*RedeemResponse, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("missing code")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
r, err := a.client.Authenticate(ctx, &pb.AuthenticateRequest{Code: code})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
expiry, err := ptypes.Timestamp(r.Expiry)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &RedeemResponse{
|
||||
AccessToken: r.AccessToken,
|
||||
RefreshToken: r.RefreshToken,
|
||||
IDToken: r.IdToken,
|
||||
User: r.User,
|
||||
Email: r.Email,
|
||||
Expiry: expiry,
|
||||
// RefreshDeadline: (expiry).Truncate(time.Second),
|
||||
// LifetimeDeadline: extendDeadline(p.CookieLifetimeTTL),
|
||||
// ValidDeadline: extendDeadline(p.CookieExpire),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Refresh makes an RPC call to the authenticate service to attempt to refresh the
|
||||
// user's session. Requires a valid refresh token. Will return an error if the identity provider
|
||||
// has revoked the session or if the refresh token is no longer valid in this context.
|
||||
func (a *AuthenticateGRPC) Refresh(refreshToken string) (string, time.Time, error) {
|
||||
if refreshToken == "" {
|
||||
return "", time.Time{}, errors.New("missing refresh token")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
r, err := a.client.Refresh(ctx, &pb.RefreshRequest{RefreshToken: refreshToken})
|
||||
if err != nil {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
|
||||
expiry, err := ptypes.Timestamp(r.Expiry)
|
||||
if err != nil {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
return r.AccessToken, expiry, nil
|
||||
}
|
||||
|
||||
// Validate makes an RPC call to the authenticate service to validate the JWT id token;
|
||||
// does NOT do nonce or revokation validation.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
||||
func (a *AuthenticateGRPC) Validate(idToken string) (bool, error) {
|
||||
if idToken == "" {
|
||||
return false, errors.New("missing id token")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
r, err := a.client.Validate(ctx, &pb.ValidateRequest{IdToken: idToken})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return r.IsValid, nil
|
||||
}
|
||||
|
||||
// Close tears down the ClientConn and all underlying connections.
|
||||
func (a *AuthenticateGRPC) Close() error {
|
||||
return a.conn.Close()
|
||||
}
|
181
proxy/authenticator/grpc_test.go
Normal file
181
proxy/authenticator/grpc_test.go
Normal file
|
@ -0,0 +1,181 @@
|
|||
package authenticator // import "github.com/pomerium/pomerium/proxy/authenticator"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
pb "github.com/pomerium/pomerium/proto/authenticate"
|
||||
mock "github.com/pomerium/pomerium/proto/authenticate/mock_authenticate"
|
||||
)
|
||||
|
||||
var fixedDate = time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC)
|
||||
|
||||
// rpcMsg implements the gomock.Matcher interface
|
||||
type rpcMsg struct {
|
||||
msg proto.Message
|
||||
}
|
||||
|
||||
func (r *rpcMsg) Matches(msg interface{}) bool {
|
||||
m, ok := msg.(proto.Message)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return proto.Equal(m, r.msg)
|
||||
}
|
||||
|
||||
func (r *rpcMsg) String() string {
|
||||
return fmt.Sprintf("is %s", r.msg)
|
||||
}
|
||||
|
||||
func TestProxy_Redeem(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockAuthenticateClient := mock.NewMockAuthenticatorClient(ctrl)
|
||||
req := &pb.AuthenticateRequest{Code: "unit_test"}
|
||||
mockExpire, err := ptypes.TimestampProto(fixedDate)
|
||||
if err != nil {
|
||||
t.Fatalf("%v failed converting timestamp", err)
|
||||
}
|
||||
|
||||
mockAuthenticateClient.EXPECT().Authenticate(
|
||||
gomock.Any(),
|
||||
&rpcMsg{msg: req},
|
||||
).Return(&pb.AuthenticateReply{
|
||||
AccessToken: "mocked access token",
|
||||
RefreshToken: "mocked refresh token",
|
||||
IdToken: "mocked id token",
|
||||
User: "user1",
|
||||
Email: "test@email.com",
|
||||
Expiry: mockExpire,
|
||||
}, nil)
|
||||
tests := []struct {
|
||||
name string
|
||||
idToken string
|
||||
want *RedeemResponse
|
||||
wantErr bool
|
||||
}{
|
||||
{"good", "unit_test", &RedeemResponse{
|
||||
AccessToken: "mocked access token",
|
||||
RefreshToken: "mocked refresh token",
|
||||
IDToken: "mocked id token",
|
||||
User: "user1",
|
||||
Email: "test@email.com",
|
||||
Expiry: (fixedDate),
|
||||
}, false},
|
||||
{"empty code", "", nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := AuthenticateGRPC{client: mockAuthenticateClient}
|
||||
got, err := a.Redeem(tt.idToken)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Proxy.AuthenticateValidate() error = %v,\n wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != nil {
|
||||
if got.AccessToken != "mocked access token" {
|
||||
t.Errorf("authenticate: invalid access token")
|
||||
}
|
||||
if got.RefreshToken != "mocked refresh token" {
|
||||
t.Errorf("authenticate: invalid refresh token")
|
||||
}
|
||||
if got.IDToken != "mocked id token" {
|
||||
t.Errorf("authenticate: invalid id token")
|
||||
}
|
||||
if got.User != "user1" {
|
||||
t.Errorf("authenticate: invalid user")
|
||||
}
|
||||
if got.Email != "test@email.com" {
|
||||
t.Errorf("authenticate: invalid email")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
func TestProxy_AuthenticateValidate(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockAuthenticateClient := mock.NewMockAuthenticatorClient(ctrl)
|
||||
req := &pb.ValidateRequest{IdToken: "unit_test"}
|
||||
|
||||
mockAuthenticateClient.EXPECT().Validate(
|
||||
gomock.Any(),
|
||||
&rpcMsg{msg: req},
|
||||
).Return(&pb.ValidateReply{IsValid: false}, nil)
|
||||
|
||||
ac := mockAuthenticateClient
|
||||
tests := []struct {
|
||||
name string
|
||||
idToken string
|
||||
want bool
|
||||
wantErr bool
|
||||
}{
|
||||
{"good", "unit_test", false, false},
|
||||
{"empty id token", "", false, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := AuthenticateGRPC{client: ac}
|
||||
|
||||
got, err := a.Validate(tt.idToken)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Proxy.AuthenticateValidate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("Proxy.AuthenticateValidate() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxy_AuthenticateRefresh(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockRefreshClient := mock.NewMockAuthenticatorClient(ctrl)
|
||||
req := &pb.RefreshRequest{RefreshToken: "unit_test"}
|
||||
mockExpire, err := ptypes.TimestampProto(fixedDate)
|
||||
if err != nil {
|
||||
t.Fatalf("%v failed converting timestamp", err)
|
||||
}
|
||||
mockRefreshClient.EXPECT().Refresh(
|
||||
gomock.Any(),
|
||||
&rpcMsg{msg: req},
|
||||
).Return(&pb.RefreshReply{
|
||||
AccessToken: "mocked access token",
|
||||
Expiry: mockExpire,
|
||||
}, nil).AnyTimes()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
refreshToken string
|
||||
wantAT string
|
||||
wantExp time.Time
|
||||
wantErr bool
|
||||
}{
|
||||
{"good", "unit_test", "mocked access token", fixedDate, false},
|
||||
{"missing refresh", "", "", time.Time{}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := AuthenticateGRPC{client: mockRefreshClient}
|
||||
|
||||
got, gotExp, err := a.Refresh(tt.refreshToken)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Proxy.AuthenticateRefresh() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.wantAT {
|
||||
t.Errorf("Proxy.AuthenticateRefresh() got = %v, want %v", got, tt.wantAT)
|
||||
}
|
||||
if !reflect.DeepEqual(gotExp, tt.wantExp) {
|
||||
t.Errorf("Proxy.AuthenticateRefresh() gotExp = %v, want %v", gotExp, tt.wantExp)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
35
proxy/authenticator/mock_authenticator.go
Normal file
35
proxy/authenticator/mock_authenticator.go
Normal file
|
@ -0,0 +1,35 @@
|
|||
package authenticator // import "github.com/pomerium/pomerium/proxy/authenticator"
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// MockAuthenticate is a mock authenticator interface
|
||||
type MockAuthenticate struct {
|
||||
RedeemError error
|
||||
RedeemResponse *RedeemResponse
|
||||
RefreshResponse string
|
||||
RefreshTime time.Time
|
||||
RefreshError error
|
||||
ValidateResponse bool
|
||||
ValidateError error
|
||||
CloseError error
|
||||
}
|
||||
|
||||
// Redeem is a mocked implementation for authenticator testing.
|
||||
func (a MockAuthenticate) Redeem(code string) (*RedeemResponse, error) {
|
||||
return a.RedeemResponse, a.RedeemError
|
||||
}
|
||||
|
||||
// Refresh is a mocked implementation for authenticator testing.
|
||||
func (a MockAuthenticate) Refresh(refreshToken string) (string, time.Time, error) {
|
||||
return a.RefreshResponse, a.RefreshTime, a.RefreshError
|
||||
}
|
||||
|
||||
// Validate is a mocked implementation for authenticator testing.
|
||||
func (a MockAuthenticate) Validate(idToken string) (bool, error) {
|
||||
return a.ValidateResponse, a.ValidateError
|
||||
}
|
||||
|
||||
// Close is a mocked implementation for authenticator testing.
|
||||
func (a MockAuthenticate) Close() error { return a.ValidateError }
|
Loading…
Add table
Add a link
Reference in a new issue