proxy: add unit tests (#43)

This commit is contained in:
Bobby DeSimone 2019-02-11 20:15:01 -08:00 committed by GitHub
parent cedf9922d3
commit 4f4f3965aa
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 577 additions and 323 deletions

View file

@ -0,0 +1,69 @@
package authenticator // import "github.com/pomerium/pomerium/proxy/authenticator"
import (
"crypto/tls"
"crypto/x509"
"fmt"
"net/url"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/middleware"
pb "github.com/pomerium/pomerium/proto/authenticate"
)
// Authenticator provides the authenticate service interface
type Authenticator interface {
// Redeem takes a code and returns a validated session or an error
Redeem(string) (*RedeemResponse, error)
// Refresh attempts to refresh a valid session with a refresh token. Returns a new access token
// and expiration, or an error.
Refresh(string) (string, time.Time, error)
// Validate evaluates a given oidc id_token for validity. Returns validity and any error.
Validate(string) (bool, error)
// Close closes the authenticator connection if any.
Close() error
}
// New returns a new identity provider based given its name.
// Returns an error if selected provided not found or if the identity provider is not known.
func New(uri *url.URL, internalURL, OverideCertificateName, key string) (p Authenticator, err error) {
// if no port given, assume https/443
port := uri.Port()
if port == "" {
port = "443"
}
authEndpoint := fmt.Sprintf("%s:%s", uri.Host, port)
cp, err := x509.SystemCertPool()
if err != nil {
return nil, err
}
if internalURL != "" {
authEndpoint = internalURL
}
log.Info().Str("authEndpoint", authEndpoint).Msgf("proxy.New: grpc authenticate connection")
cert := credentials.NewTLS(&tls.Config{RootCAs: cp})
if OverideCertificateName != "" {
err = cert.OverrideServerName(OverideCertificateName)
if err != nil {
return nil, err
}
}
grpcAuth := middleware.NewSharedSecretCred(key)
conn, err := grpc.Dial(
authEndpoint,
grpc.WithTransportCredentials(cert),
grpc.WithPerRPCCredentials(grpcAuth),
)
if err != nil {
return nil, err
}
authClient := pb.NewAuthenticatorClient(conn)
return &AuthenticateGRPC{conn: conn, client: authClient}, nil
}

View file

@ -0,0 +1,36 @@
package authenticator
import (
"net/url"
"reflect"
"testing"
)
func TestNew(t *testing.T) {
type args struct {
uri *url.URL
internalURL string
OverideCertificateName string
key string
}
tests := []struct {
name string
args args
wantP Authenticator
wantErr bool
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotP, err := New(tt.args.uri, tt.args.internalURL, tt.args.OverideCertificateName, tt.args.key)
if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(gotP, tt.wantP) {
t.Errorf("New() = %v, want %v", gotP, tt.wantP)
}
})
}
}

View file

@ -0,0 +1,98 @@
package authenticator // import "github.com/pomerium/pomerium/proxy/authenticator"
import (
"context"
"errors"
"time"
"github.com/golang/protobuf/ptypes"
"google.golang.org/grpc"
pb "github.com/pomerium/pomerium/proto/authenticate"
)
// RedeemResponse contains data from a authenticator redeem request.
type RedeemResponse struct {
AccessToken string
RefreshToken string
IDToken string
User string
Email string
Expiry time.Time
}
// AuthenticateGRPC is a gRPC implementation of an authenticator (authenticate client)
type AuthenticateGRPC struct {
conn *grpc.ClientConn
client pb.AuthenticatorClient
}
// Redeem makes an RPC call to the authenticate service to creates a session state
// from an encrypted code provided as a result of an oauth2 callback process.
func (a *AuthenticateGRPC) Redeem(code string) (*RedeemResponse, error) {
if code == "" {
return nil, errors.New("missing code")
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
r, err := a.client.Authenticate(ctx, &pb.AuthenticateRequest{Code: code})
if err != nil {
return nil, err
}
expiry, err := ptypes.Timestamp(r.Expiry)
if err != nil {
return nil, err
}
return &RedeemResponse{
AccessToken: r.AccessToken,
RefreshToken: r.RefreshToken,
IDToken: r.IdToken,
User: r.User,
Email: r.Email,
Expiry: expiry,
// RefreshDeadline: (expiry).Truncate(time.Second),
// LifetimeDeadline: extendDeadline(p.CookieLifetimeTTL),
// ValidDeadline: extendDeadline(p.CookieExpire),
}, nil
}
// Refresh makes an RPC call to the authenticate service to attempt to refresh the
// user's session. Requires a valid refresh token. Will return an error if the identity provider
// has revoked the session or if the refresh token is no longer valid in this context.
func (a *AuthenticateGRPC) Refresh(refreshToken string) (string, time.Time, error) {
if refreshToken == "" {
return "", time.Time{}, errors.New("missing refresh token")
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
r, err := a.client.Refresh(ctx, &pb.RefreshRequest{RefreshToken: refreshToken})
if err != nil {
return "", time.Time{}, err
}
expiry, err := ptypes.Timestamp(r.Expiry)
if err != nil {
return "", time.Time{}, err
}
return r.AccessToken, expiry, nil
}
// Validate makes an RPC call to the authenticate service to validate the JWT id token;
// does NOT do nonce or revokation validation.
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
func (a *AuthenticateGRPC) Validate(idToken string) (bool, error) {
if idToken == "" {
return false, errors.New("missing id token")
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
r, err := a.client.Validate(ctx, &pb.ValidateRequest{IdToken: idToken})
if err != nil {
return false, err
}
return r.IsValid, nil
}
// Close tears down the ClientConn and all underlying connections.
func (a *AuthenticateGRPC) Close() error {
return a.conn.Close()
}

View file

@ -0,0 +1,181 @@
package authenticator // import "github.com/pomerium/pomerium/proxy/authenticator"
import (
"fmt"
"reflect"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/ptypes"
pb "github.com/pomerium/pomerium/proto/authenticate"
mock "github.com/pomerium/pomerium/proto/authenticate/mock_authenticate"
)
var fixedDate = time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC)
// rpcMsg implements the gomock.Matcher interface
type rpcMsg struct {
msg proto.Message
}
func (r *rpcMsg) Matches(msg interface{}) bool {
m, ok := msg.(proto.Message)
if !ok {
return false
}
return proto.Equal(m, r.msg)
}
func (r *rpcMsg) String() string {
return fmt.Sprintf("is %s", r.msg)
}
func TestProxy_Redeem(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockAuthenticateClient := mock.NewMockAuthenticatorClient(ctrl)
req := &pb.AuthenticateRequest{Code: "unit_test"}
mockExpire, err := ptypes.TimestampProto(fixedDate)
if err != nil {
t.Fatalf("%v failed converting timestamp", err)
}
mockAuthenticateClient.EXPECT().Authenticate(
gomock.Any(),
&rpcMsg{msg: req},
).Return(&pb.AuthenticateReply{
AccessToken: "mocked access token",
RefreshToken: "mocked refresh token",
IdToken: "mocked id token",
User: "user1",
Email: "test@email.com",
Expiry: mockExpire,
}, nil)
tests := []struct {
name string
idToken string
want *RedeemResponse
wantErr bool
}{
{"good", "unit_test", &RedeemResponse{
AccessToken: "mocked access token",
RefreshToken: "mocked refresh token",
IDToken: "mocked id token",
User: "user1",
Email: "test@email.com",
Expiry: (fixedDate),
}, false},
{"empty code", "", nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := AuthenticateGRPC{client: mockAuthenticateClient}
got, err := a.Redeem(tt.idToken)
if (err != nil) != tt.wantErr {
t.Errorf("Proxy.AuthenticateValidate() error = %v,\n wantErr %v", err, tt.wantErr)
return
}
if got != nil {
if got.AccessToken != "mocked access token" {
t.Errorf("authenticate: invalid access token")
}
if got.RefreshToken != "mocked refresh token" {
t.Errorf("authenticate: invalid refresh token")
}
if got.IDToken != "mocked id token" {
t.Errorf("authenticate: invalid id token")
}
if got.User != "user1" {
t.Errorf("authenticate: invalid user")
}
if got.Email != "test@email.com" {
t.Errorf("authenticate: invalid email")
}
}
})
}
}
func TestProxy_AuthenticateValidate(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockAuthenticateClient := mock.NewMockAuthenticatorClient(ctrl)
req := &pb.ValidateRequest{IdToken: "unit_test"}
mockAuthenticateClient.EXPECT().Validate(
gomock.Any(),
&rpcMsg{msg: req},
).Return(&pb.ValidateReply{IsValid: false}, nil)
ac := mockAuthenticateClient
tests := []struct {
name string
idToken string
want bool
wantErr bool
}{
{"good", "unit_test", false, false},
{"empty id token", "", false, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := AuthenticateGRPC{client: ac}
got, err := a.Validate(tt.idToken)
if (err != nil) != tt.wantErr {
t.Errorf("Proxy.AuthenticateValidate() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("Proxy.AuthenticateValidate() = %v, want %v", got, tt.want)
}
})
}
}
func TestProxy_AuthenticateRefresh(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockRefreshClient := mock.NewMockAuthenticatorClient(ctrl)
req := &pb.RefreshRequest{RefreshToken: "unit_test"}
mockExpire, err := ptypes.TimestampProto(fixedDate)
if err != nil {
t.Fatalf("%v failed converting timestamp", err)
}
mockRefreshClient.EXPECT().Refresh(
gomock.Any(),
&rpcMsg{msg: req},
).Return(&pb.RefreshReply{
AccessToken: "mocked access token",
Expiry: mockExpire,
}, nil).AnyTimes()
tests := []struct {
name string
refreshToken string
wantAT string
wantExp time.Time
wantErr bool
}{
{"good", "unit_test", "mocked access token", fixedDate, false},
{"missing refresh", "", "", time.Time{}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := AuthenticateGRPC{client: mockRefreshClient}
got, gotExp, err := a.Refresh(tt.refreshToken)
if (err != nil) != tt.wantErr {
t.Errorf("Proxy.AuthenticateRefresh() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.wantAT {
t.Errorf("Proxy.AuthenticateRefresh() got = %v, want %v", got, tt.wantAT)
}
if !reflect.DeepEqual(gotExp, tt.wantExp) {
t.Errorf("Proxy.AuthenticateRefresh() gotExp = %v, want %v", gotExp, tt.wantExp)
}
})
}
}

View file

@ -0,0 +1,35 @@
package authenticator // import "github.com/pomerium/pomerium/proxy/authenticator"
import (
"time"
)
// MockAuthenticate is a mock authenticator interface
type MockAuthenticate struct {
RedeemError error
RedeemResponse *RedeemResponse
RefreshResponse string
RefreshTime time.Time
RefreshError error
ValidateResponse bool
ValidateError error
CloseError error
}
// Redeem is a mocked implementation for authenticator testing.
func (a MockAuthenticate) Redeem(code string) (*RedeemResponse, error) {
return a.RedeemResponse, a.RedeemError
}
// Refresh is a mocked implementation for authenticator testing.
func (a MockAuthenticate) Refresh(refreshToken string) (string, time.Time, error) {
return a.RefreshResponse, a.RefreshTime, a.RefreshError
}
// Validate is a mocked implementation for authenticator testing.
func (a MockAuthenticate) Validate(idToken string) (bool, error) {
return a.ValidateResponse, a.ValidateError
}
// Close is a mocked implementation for authenticator testing.
func (a MockAuthenticate) Close() error { return a.ValidateError }