mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-10 15:47:36 +02:00
proxy: add tests (#44)
This commit is contained in:
parent
4f4f3965aa
commit
09744f6adb
8 changed files with 185 additions and 102 deletions
|
@ -1,18 +1,7 @@
|
|||
package authenticator // import "github.com/pomerium/pomerium/proxy/authenticator"
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/middleware"
|
||||
pb "github.com/pomerium/pomerium/proto/authenticate"
|
||||
)
|
||||
|
||||
// Authenticator provides the authenticate service interface
|
||||
|
@ -28,42 +17,24 @@ type Authenticator interface {
|
|||
Close() error
|
||||
}
|
||||
|
||||
// New returns a new identity provider based given its name.
|
||||
// Returns an error if selected provided not found or if the identity provider is not known.
|
||||
func New(uri *url.URL, internalURL, OverideCertificateName, key string) (p Authenticator, err error) {
|
||||
// if no port given, assume https/443
|
||||
port := uri.Port()
|
||||
if port == "" {
|
||||
port = "443"
|
||||
}
|
||||
authEndpoint := fmt.Sprintf("%s:%s", uri.Host, port)
|
||||
|
||||
cp, err := x509.SystemCertPool()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if internalURL != "" {
|
||||
authEndpoint = internalURL
|
||||
}
|
||||
|
||||
log.Info().Str("authEndpoint", authEndpoint).Msgf("proxy.New: grpc authenticate connection")
|
||||
cert := credentials.NewTLS(&tls.Config{RootCAs: cp})
|
||||
if OverideCertificateName != "" {
|
||||
err = cert.OverrideServerName(OverideCertificateName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
grpcAuth := middleware.NewSharedSecretCred(key)
|
||||
conn, err := grpc.Dial(
|
||||
authEndpoint,
|
||||
grpc.WithTransportCredentials(cert),
|
||||
grpc.WithPerRPCCredentials(grpcAuth),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authClient := pb.NewAuthenticatorClient(conn)
|
||||
return &AuthenticateGRPC{conn: conn, client: authClient}, nil
|
||||
// Options contains options for connecting to an authenticate service .
|
||||
type Options struct {
|
||||
// Addr is the location of the authenticate service. Used if InternalAddr is not set.
|
||||
Addr string
|
||||
Port int
|
||||
// InternalAddr is the internal (behind the ingress) address to use when making an
|
||||
// authentication connection. If empty, Addr is used.
|
||||
InternalAddr string
|
||||
// OverrideServerName overrides the server name used to verify the hostname on the
|
||||
// returned certificates from the server. gRPC internals also use it to override the virtual
|
||||
// hosting name if it is set.
|
||||
OverideCertificateName string
|
||||
// Shared secret is used to authenticate a authenticate-client with a authenticate-server.
|
||||
SharedSecret string
|
||||
}
|
||||
|
||||
// New returns a new authenticate service client. Takes a client implementation name as an argument.
|
||||
// Currently only gRPC is supported and is always returned.
|
||||
func New(name string, opts *Options) (a Authenticator, err error) {
|
||||
return NewGRPC(opts)
|
||||
}
|
||||
|
|
|
@ -1,36 +1,58 @@
|
|||
package authenticator
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
type args struct {
|
||||
uri *url.URL
|
||||
internalURL string
|
||||
OverideCertificateName string
|
||||
key string
|
||||
func TestMockAuthenticate(t *testing.T) {
|
||||
// Absurd, but I caught a typo this way.
|
||||
fixedDate := time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC)
|
||||
redeemResponse := &RedeemResponse{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
Expiry: fixedDate,
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantP Authenticator
|
||||
wantErr bool
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
ma := &MockAuthenticate{
|
||||
RedeemError: errors.New("RedeemError"),
|
||||
RedeemResponse: redeemResponse,
|
||||
RefreshResponse: "RefreshResponse",
|
||||
RefreshTime: fixedDate,
|
||||
RefreshError: errors.New("RefreshError"),
|
||||
ValidateResponse: true,
|
||||
ValidateError: errors.New("ValidateError"),
|
||||
CloseError: errors.New("CloseError"),
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotP, err := New(tt.args.uri, tt.args.internalURL, tt.args.OverideCertificateName, tt.args.key)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(gotP, tt.wantP) {
|
||||
t.Errorf("New() = %v, want %v", gotP, tt.wantP)
|
||||
}
|
||||
})
|
||||
got, gotErr := ma.Redeem("a")
|
||||
if gotErr.Error() != "RedeemError" {
|
||||
t.Errorf("unexpected value for gotErr %s", gotErr)
|
||||
}
|
||||
if !reflect.DeepEqual(redeemResponse, got) {
|
||||
t.Errorf("unexpected value for redeemResponse %s", got)
|
||||
}
|
||||
gotToken, gotTime, gotErr := ma.Refresh("a")
|
||||
if gotErr.Error() != "RefreshError" {
|
||||
t.Errorf("unexpected value for gotErr %s", gotErr)
|
||||
}
|
||||
if !reflect.DeepEqual(gotToken, "RefreshResponse") {
|
||||
t.Errorf("unexpected value for gotToken %s", gotToken)
|
||||
}
|
||||
if !gotTime.Equal(fixedDate) {
|
||||
t.Errorf("unexpected value for gotTime %s", gotTime)
|
||||
}
|
||||
|
||||
ok, gotErr := ma.Validate("a")
|
||||
if !ok {
|
||||
t.Errorf("unexpected value for ok : %t", ok)
|
||||
}
|
||||
if gotErr.Error() != "ValidateError" {
|
||||
t.Errorf("unexpected value for gotErr %s", gotErr)
|
||||
}
|
||||
gotErr = ma.Close()
|
||||
if gotErr.Error() != "CloseError" {
|
||||
t.Errorf("unexpected value for ma.CloseError %s", gotErr)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1,15 +1,73 @@
|
|||
package authenticator // import "github.com/pomerium/pomerium/proxy/authenticator"
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/middleware"
|
||||
pb "github.com/pomerium/pomerium/proto/authenticate"
|
||||
)
|
||||
|
||||
// NewGRPC returns a new authenticate service client.
|
||||
func NewGRPC(opts *Options) (p Authenticator, err error) {
|
||||
// gRPC uses a pre-shared secret middleware to establish authentication b/w server and client
|
||||
if opts.SharedSecret == "" {
|
||||
return nil, errors.New("proxy/authenticator: grpc client requires shared secret")
|
||||
}
|
||||
grpcAuth := middleware.NewSharedSecretCred(opts.SharedSecret)
|
||||
|
||||
var connAddr string
|
||||
if opts.InternalAddr != "" {
|
||||
connAddr = opts.InternalAddr
|
||||
} else {
|
||||
connAddr = opts.Addr
|
||||
}
|
||||
if connAddr == "" {
|
||||
return nil, errors.New("proxy/authenticator: connection address required")
|
||||
}
|
||||
// no colon exists in the connection string, assume one must be added manually
|
||||
if !strings.Contains(":", connAddr) {
|
||||
connAddr = fmt.Sprintf("%s:%d", connAddr, opts.Port)
|
||||
}
|
||||
|
||||
cp, err := x509.SystemCertPool()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("OverideCertificateName", opts.OverideCertificateName).
|
||||
Str("addr", connAddr).Msgf("proxy/authenticator: grpc connection")
|
||||
cert := credentials.NewTLS(&tls.Config{RootCAs: cp})
|
||||
|
||||
// overide allowed certificate name string, typically used when doing behind ingress connection
|
||||
if opts.OverideCertificateName != "" {
|
||||
err = cert.OverrideServerName(opts.OverideCertificateName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
conn, err := grpc.Dial(
|
||||
connAddr,
|
||||
grpc.WithTransportCredentials(cert),
|
||||
grpc.WithPerRPCCredentials(grpcAuth),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authClient := pb.NewAuthenticatorClient(conn)
|
||||
return &AuthenticateGRPC{conn: conn, client: authClient}, nil
|
||||
}
|
||||
|
||||
// RedeemResponse contains data from a authenticator redeem request.
|
||||
type RedeemResponse struct {
|
||||
AccessToken string
|
||||
|
@ -49,9 +107,6 @@ func (a *AuthenticateGRPC) Redeem(code string) (*RedeemResponse, error) {
|
|||
User: r.User,
|
||||
Email: r.Email,
|
||||
Expiry: expiry,
|
||||
// RefreshDeadline: (expiry).Truncate(time.Second),
|
||||
// LifetimeDeadline: extendDeadline(p.CookieLifetimeTTL),
|
||||
// ValidDeadline: extendDeadline(p.CookieExpire),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
package authenticator // import "github.com/pomerium/pomerium/proxy/authenticator"
|
||||
package authenticator
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -179,3 +180,35 @@ func TestProxy_AuthenticateRefresh(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewGRPC(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
opts *Options
|
||||
wantErr bool
|
||||
wantErrStr string
|
||||
}{
|
||||
{"no shared secret", &Options{}, true, "proxy/authenticator: grpc client requires shared secret"},
|
||||
{"empty connection", &Options{Addr: "", SharedSecret: "shh"}, true, "proxy/authenticator: connection address required"},
|
||||
{"empty connections", &Options{Addr: "", InternalAddr: "", SharedSecret: "shh"}, true, "proxy/authenticator: connection address required"},
|
||||
{"internal addr", &Options{Addr: "", InternalAddr: "intranet.local", SharedSecret: "shh"}, false, "proxy/authenticator: connection address required"},
|
||||
{"cert overide", &Options{Addr: "", InternalAddr: "intranet.local", OverideCertificateName: "*.local", SharedSecret: "shh"}, false, "proxy/authenticator: connection address required"},
|
||||
|
||||
// {"addr and internal ", &Options{Addr: "localhost", InternalAddr: "local.localhost", SharedSecret: "shh"}, nil, true, ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := NewGRPC(tt.opts)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewGRPC() error = %v, wantErr %v", err, tt.wantErr)
|
||||
if !strings.EqualFold(err.Error(), tt.wantErrStr) {
|
||||
t.Errorf("NewGRPC() error = %v did not contain wantErr %v", err, tt.wantErrStr)
|
||||
}
|
||||
|
||||
return
|
||||
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
// MockAuthenticate is a mock authenticator interface
|
||||
// MockAuthenticate provides a mocked implementation of the authenticator interface.
|
||||
type MockAuthenticate struct {
|
||||
RedeemError error
|
||||
RedeemResponse *RedeemResponse
|
||||
|
@ -16,20 +16,20 @@ type MockAuthenticate struct {
|
|||
CloseError error
|
||||
}
|
||||
|
||||
// Redeem is a mocked implementation for authenticator testing.
|
||||
// Redeem is a mocked authenticator client function.
|
||||
func (a MockAuthenticate) Redeem(code string) (*RedeemResponse, error) {
|
||||
return a.RedeemResponse, a.RedeemError
|
||||
}
|
||||
|
||||
// Refresh is a mocked implementation for authenticator testing.
|
||||
// Refresh is a mocked authenticator client function.
|
||||
func (a MockAuthenticate) Refresh(refreshToken string) (string, time.Time, error) {
|
||||
return a.RefreshResponse, a.RefreshTime, a.RefreshError
|
||||
}
|
||||
|
||||
// Validate is a mocked implementation for authenticator testing.
|
||||
// Validate is a mocked authenticator client function.
|
||||
func (a MockAuthenticate) Validate(idToken string) (bool, error) {
|
||||
return a.ValidateResponse, a.ValidateError
|
||||
}
|
||||
|
||||
// Close is a mocked implementation for authenticator testing.
|
||||
func (a MockAuthenticate) Close() error { return a.ValidateError }
|
||||
// Close is a mocked authenticator client function.
|
||||
func (a MockAuthenticate) Close() error { return a.CloseError }
|
||||
|
|
|
@ -42,7 +42,6 @@ func (p *Proxy) Handler() http.Handler {
|
|||
mux.HandleFunc("/robots.txt", p.RobotsTxt)
|
||||
mux.HandleFunc("/.pomerium/sign_out", p.SignOut)
|
||||
mux.HandleFunc("/.pomerium/callback", p.OAuthCallback)
|
||||
mux.HandleFunc("/.pomerium/auth", p.AuthenticateOnly)
|
||||
mux.HandleFunc("/", p.Proxy)
|
||||
|
||||
// middleware chain
|
||||
|
@ -236,15 +235,6 @@ func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
|
|||
http.Redirect(w, r, stateParameter.RedirectURI, http.StatusFound)
|
||||
}
|
||||
|
||||
// AuthenticateOnly calls the Authenticate handler.
|
||||
func (p *Proxy) AuthenticateOnly(w http.ResponseWriter, r *http.Request) {
|
||||
err := p.Authenticate(w, r)
|
||||
if err != nil {
|
||||
http.Error(w, "unauthorized request", http.StatusUnauthorized)
|
||||
}
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
}
|
||||
|
||||
// Proxy authenticates a request, either proxying the request if it is authenticated,
|
||||
// or starting the authenticate service for validation if not.
|
||||
func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -274,7 +264,6 @@ func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
|
|||
httputil.ErrorResponse(w, r, "unknown route to proxy", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
route.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
|
|
|
@ -429,6 +429,12 @@ func TestProxy_Authenticate(t *testing.T) {
|
|||
authenticator authenticator.Authenticator
|
||||
wantErr bool
|
||||
}{
|
||||
{"cannot save session",
|
||||
"https://corp.example.com/",
|
||||
map[string]string{"corp.example.com": "example.com"},
|
||||
sessions.MockSessionStore{Session: goodSession, SaveError: errors.New("error")},
|
||||
authenticator.MockAuthenticate{}, true},
|
||||
|
||||
{"cannot load session",
|
||||
"https://corp.example.com/",
|
||||
map[string]string{"corp.example.com": "example.com"},
|
||||
|
|
|
@ -33,9 +33,10 @@ const (
|
|||
// Options represents the configurations available for the proxy service.
|
||||
type Options struct {
|
||||
// Authenticate service settings
|
||||
AuthenticateURL *url.URL `envconfig:"AUTHENTICATE_SERVICE_URL"`
|
||||
AuthenticateInternalURL string `envconfig:"AUTHENTICATE_INTERNAL_URL"`
|
||||
OverideCertificateName string `envconfig:"OVERIDE_CERTIFICATE_NAME"`
|
||||
AuthenticateURL *url.URL `envconfig:"AUTHENTICATE_SERVICE_URL"`
|
||||
AuthenticateInternalAddr string `envconfig:"AUTHENTICATE_INTERNAL_URL"`
|
||||
OverideCertificateName string `envconfig:"OVERIDE_CERTIFICATE_NAME"`
|
||||
AuthenticatePort int `envconfig:"AUTHENTICATE_SERVICE_PORT"`
|
||||
|
||||
// SigningKey is a base64 encoded private key used to add a JWT-signature to proxied requests.
|
||||
// See : https://www.pomerium.io/guide/signed-headers.html
|
||||
|
@ -67,6 +68,8 @@ var defaultOptions = &Options{
|
|||
CookieRefresh: time.Duration(30) * time.Minute,
|
||||
CookieLifetimeTTL: time.Duration(720) * time.Hour,
|
||||
DefaultUpstreamTimeout: time.Duration(10) * time.Second,
|
||||
// services
|
||||
AuthenticatePort: 443,
|
||||
}
|
||||
|
||||
// OptionsFromEnvConfig builds the IdentityProvider service's configuration
|
||||
|
@ -199,11 +202,15 @@ func New(opts *Options) (*Proxy, error) {
|
|||
p.Handle(fromURL.Host, handler)
|
||||
log.Info().Str("from", fromURL.Host).Str("to", toURL.String()).Msg("proxy.New: new route")
|
||||
}
|
||||
|
||||
p.AuthenticateClient, err = authenticator.New(
|
||||
opts.AuthenticateURL,
|
||||
opts.AuthenticateInternalURL,
|
||||
opts.OverideCertificateName,
|
||||
opts.SharedKey)
|
||||
"grpc",
|
||||
&authenticator.Options{
|
||||
Addr: opts.AuthenticateURL.Host,
|
||||
InternalAddr: opts.AuthenticateInternalAddr,
|
||||
OverideCertificateName: opts.OverideCertificateName,
|
||||
SharedSecret: opts.SharedKey,
|
||||
})
|
||||
return p, nil
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue