internal/config: refactor option parsing

- authorize: build whitelist from policy's URLs instead of strings.
- internal/httputil: merged httputil and https package.
- internal/config: merged config and policy packages.
- internal/metrics: removed unused measure struct.
- proxy/clients: refactor Addr fields to be urls.
- proxy: remove unused extend deadline function.
- proxy: use handler middleware for reverse proxy leg.
- proxy: change the way websocket requests are made (route based).

General improvements
- omitted value from range in several cases where for loop could be simplified.
- added error checking to many tests.
- standardize url parsing.
- remove unnecessary return statements.

- proxy: add self-signed certificate support. #179
- proxy: add skip tls certificate verification. #179
- proxy: Refactor websocket support to be route based. #204
This commit is contained in:
Bobby DeSimone 2019-07-04 10:12:25 -07:00
parent 28efa3359b
commit 7558d5b0de
No known key found for this signature in database
GPG key ID: AEE4CF12FE86D07E
38 changed files with 1354 additions and 1079 deletions

View file

@ -6,11 +6,15 @@
- Add programmatic authentication support. [GH-177] - Add programmatic authentication support. [GH-177]
- Add Prometheus format metrics endpoint. [GH-35] - Add Prometheus format metrics endpoint. [GH-35]
- Add policy setting to enable self-signed certificate support. [GH-179]
- Add policy setting to skip tls certificate verification. [GH-179]
### CHANGED ### CHANGED
- Policy `to` and `from` settings must be set to valid HTTP URLs including [schemes](https://en.wikipedia.org/wiki/Uniform_Resource_Identifier) and hostnames (e.g. `http.corp.domain.example` should now be `https://http.corp.domain.example`).
- Proxy's sign out handler `{}/.pomerium/sign_out` now accepts an optional `redirect_uri` parameter which can be used to specify a custom redirect page, so long as it is under the same top-level domain. [GH-183] - Proxy's sign out handler `{}/.pomerium/sign_out` now accepts an optional `redirect_uri` parameter which can be used to specify a custom redirect page, so long as it is under the same top-level domain. [GH-183]
- Policy configuration can now be empty at startup [GH-190] - Policy configuration can now be empty at startup. [GH-190]
- Websocket support is now set per-route instead of globally. [GH-204]
### FIXED ### FIXED

View file

@ -18,7 +18,7 @@ import (
// The checks do not modify the internal state of the Option structure. Returns // The checks do not modify the internal state of the Option structure. Returns
// on first error found. // on first error found.
func ValidateOptions(o config.Options) error { func ValidateOptions(o config.Options) error {
if o.AuthenticateURL.Hostname() == "" { if o.AuthenticateURL == nil {
return errors.New("authenticate: 'AUTHENTICATE_SERVICE_URL' missing") return errors.New("authenticate: 'AUTHENTICATE_SERVICE_URL' missing")
} }
if o.ClientID == "" { if o.ClientID == "" {
@ -35,7 +35,7 @@ func ValidateOptions(o config.Options) error {
return fmt.Errorf("authenticate: 'COOKIE_SECRET' must be base64 encoded: %v", err) return fmt.Errorf("authenticate: 'COOKIE_SECRET' must be base64 encoded: %v", err)
} }
if len(decodedCookieSecret) != 32 { if len(decodedCookieSecret) != 32 {
return fmt.Errorf("authenticate: 'COOKIE_SECRET' should be 32; got %d", len(decodedCookieSecret)) return fmt.Errorf("authenticate: 'COOKIE_SECRET' %s be 32; got %d", o.CookieSecret, len(decodedCookieSecret))
} }
return nil return nil
} }
@ -80,7 +80,7 @@ func New(opts config.Options) (*Authenticate, error) {
provider, err := identity.New( provider, err := identity.New(
opts.Provider, opts.Provider,
&identity.Provider{ &identity.Provider{
RedirectURL: &redirectURL, RedirectURL: redirectURL,
ProviderName: opts.Provider, ProviderName: opts.Provider,
ProviderURL: opts.ProviderURL, ProviderURL: opts.ProviderURL,
ClientID: opts.ClientID, ClientID: opts.ClientID,
@ -97,7 +97,7 @@ func New(opts config.Options) (*Authenticate, error) {
} }
return &Authenticate{ return &Authenticate{
SharedKey: opts.SharedKey, SharedKey: opts.SharedKey,
RedirectURL: &redirectURL, RedirectURL: redirectURL,
templates: templates.New(), templates: templates.New(),
csrfStore: cookieStore, csrfStore: cookieStore,
sessionStore: cookieStore, sessionStore: cookieStore,

View file

@ -1,53 +1,49 @@
package authenticate package authenticate
import ( import (
"net/url"
"testing" "testing"
"time"
"github.com/pomerium/pomerium/internal/config" "github.com/pomerium/pomerium/internal/config"
) )
func testOptions() config.Options { func newTestOptions(t *testing.T) *config.Options {
redirectURL, _ := url.Parse("https://example.com/oauth2/callback") opts, err := config.NewOptions("https://authenticate.example", "https://authorize.example")
return config.Options{ if err != nil {
AuthenticateURL: *redirectURL, t.Fatal(err)
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
ClientID: "test-client-id",
ClientSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=",
CookieSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=",
CookieRefresh: time.Duration(1) * time.Hour,
CookieExpire: time.Duration(168) * time.Hour,
CookieName: "pomerium",
} }
opts.ClientID = "client-id"
opts.Provider = "google"
opts.ClientSecret = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw="
opts.CookieSecret = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw="
return opts
} }
func TestOptions_Validate(t *testing.T) { func TestOptions_Validate(t *testing.T) {
good := testOptions() good := newTestOptions(t)
badRedirectURL := testOptions() badRedirectURL := newTestOptions(t)
badRedirectURL.AuthenticateURL = url.URL{} badRedirectURL.AuthenticateURL = nil
emptyClientID := testOptions() emptyClientID := newTestOptions(t)
emptyClientID.ClientID = "" emptyClientID.ClientID = ""
emptyClientSecret := testOptions() emptyClientSecret := newTestOptions(t)
emptyClientSecret.ClientSecret = "" emptyClientSecret.ClientSecret = ""
emptyCookieSecret := testOptions() emptyCookieSecret := newTestOptions(t)
emptyCookieSecret.CookieSecret = "" emptyCookieSecret.CookieSecret = ""
invalidCookieSecret := testOptions() invalidCookieSecret := newTestOptions(t)
invalidCookieSecret.CookieSecret = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw^" invalidCookieSecret.CookieSecret = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw^"
shortCookieLength := testOptions() shortCookieLength := newTestOptions(t)
shortCookieLength.CookieSecret = "gN3xnvfsAwfCXxnJorGLKUG4l2wC8sS8nfLMhcStPg==" shortCookieLength.CookieSecret = "gN3xnvfsAwfCXxnJorGLKUG4l2wC8sS8nfLMhcStPg=="
badSharedKey := testOptions() badSharedKey := newTestOptions(t)
badSharedKey.SharedKey = "" badSharedKey.SharedKey = ""
badAuthenticateURL := testOptions() badAuthenticateURL := newTestOptions(t)
badAuthenticateURL.AuthenticateURL = url.URL{} badAuthenticateURL.AuthenticateURL = nil
tests := []struct { tests := []struct {
name string name string
o config.Options o *config.Options
wantErr bool wantErr bool
}{ }{
{"minimum options", good, false}, {"minimum options", good, false},
{"nil options", config.Options{}, true}, {"nil options", &config.Options{}, true},
{"bad redirect url", badRedirectURL, true}, {"bad redirect url", badRedirectURL, true},
{"no cookie secret", emptyCookieSecret, true}, {"no cookie secret", emptyCookieSecret, true},
{"invalid cookie secret", invalidCookieSecret, true}, {"invalid cookie secret", invalidCookieSecret, true},
@ -59,8 +55,7 @@ func TestOptions_Validate(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
o := tt.o if err := ValidateOptions(*tt.o); (err != nil) != tt.wantErr {
if err := ValidateOptions(o); (err != nil) != tt.wantErr {
t.Errorf("Options.Validate() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Options.Validate() error = %v, wantErr %v", err, tt.wantErr)
} }
}) })
@ -68,25 +63,24 @@ func TestOptions_Validate(t *testing.T) {
} }
func TestNew(t *testing.T) { func TestNew(t *testing.T) {
good := testOptions() good := newTestOptions(t)
good.Provider = "google"
badRedirectURL := testOptions() badRedirectURL := newTestOptions(t)
badRedirectURL.AuthenticateURL = url.URL{} badRedirectURL.AuthenticateURL = nil
tests := []struct { tests := []struct {
name string name string
opts config.Options opts *config.Options
// want *Authenticate // want *Authenticate
wantErr bool wantErr bool
}{ }{
{"good", good, false}, {"good", good, false},
{"empty opts", config.Options{}, true}, {"empty opts", &config.Options{}, true},
{"fails to validate", badRedirectURL, true}, {"fails to validate", badRedirectURL, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
_, err := New(tt.opts) _, err := New(*tt.opts)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
return return

View file

@ -309,5 +309,4 @@ func (a *Authenticate) ExchangeToken(w http.ResponseWriter, r *http.Request) {
httputil.ErrorResponse(w, r, &httputil.Error{Code: http.StatusInternalServerError, Message: "authenticate: failed returning new session"}) httputil.ErrorResponse(w, r, &httputil.Error{Code: http.StatusInternalServerError, Message: "authenticate: failed returning new session"})
return return
} }
return
} }

View file

@ -4,11 +4,8 @@ import (
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/config" "github.com/pomerium/pomerium/internal/config"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/policy"
) )
// ValidateOptions checks to see if configuration values are valid for the // ValidateOptions checks to see if configuration values are valid for the
@ -21,7 +18,6 @@ func ValidateOptions(o config.Options) error {
if len(decoded) != 32 { if len(decoded) != 32 {
return fmt.Errorf("authorize: `SHARED_SECRET` want 32 but got %d bytes", len(decoded)) return fmt.Errorf("authorize: `SHARED_SECRET` want 32 but got %d bytes", len(decoded))
} }
return nil return nil
} }
@ -50,7 +46,7 @@ func New(opts config.Options) (*Authorize, error) {
// NewIdentityWhitelist returns an indentity validator. // NewIdentityWhitelist returns an indentity validator.
// todo(bdd) : a radix-tree implementation is probably more efficient // todo(bdd) : a radix-tree implementation is probably more efficient
func NewIdentityWhitelist(policies []policy.Policy, admins []string) IdentityValidator { func NewIdentityWhitelist(policies []config.Policy, admins []string) IdentityValidator {
return newIdentityWhitelistMap(policies, admins) return newIdentityWhitelistMap(policies, admins)
} }
@ -59,7 +55,7 @@ func (a *Authorize) ValidIdentity(route string, identity *Identity) bool {
return a.identityAccess.Valid(route, identity) return a.identityAccess.Valid(route, identity)
} }
// UpdateOptions updates internal structres based on config.Options // UpdateOptions updates internal structures based on config.Options
func (a *Authorize) UpdateOptions(o config.Options) error { func (a *Authorize) UpdateOptions(o config.Options) error {
log.Info().Msg("authorize: updating options") log.Info().Msg("authorize: updating options")
a.identityAccess = NewIdentityWhitelist(o.Policies, o.Administrators) a.identityAccess = NewIdentityWhitelist(o.Policies, o.Administrators)

View file

@ -4,26 +4,24 @@ import (
"testing" "testing"
"github.com/pomerium/pomerium/internal/config" "github.com/pomerium/pomerium/internal/config"
"github.com/pomerium/pomerium/internal/policy"
) )
func TestNew(t *testing.T) { func TestNew(t *testing.T) {
t.Parallel() t.Parallel()
policies := testPolicies() policies := testPolicies(t)
tests := []struct { tests := []struct {
name string name string
SharedKey string SharedKey string
Policies []policy.Policy Policies []config.Policy
wantErr bool wantErr bool
}{ }{
{"good", "gXK6ggrlIW2HyKyUF9rUO4azrDgxhDPWqw9y+lJU7B8=", policies, false}, {"good", "gXK6ggrlIW2HyKyUF9rUO4azrDgxhDPWqw9y+lJU7B8=", policies, false},
{"bad shared secret", "AZA85podM73CjLCjViDNz1EUvvejKpWp7Hysr0knXA==", policies, true}, {"bad shared secret", "AZA85podM73CjLCjViDNz1EUvvejKpWp7Hysr0knXA==", policies, true},
{"really bad shared secret", "sup", policies, true}, {"really bad shared secret", "sup", policies, true},
{"validation error, short secret", "AZA85podM73CjLCjViDNz1EUvvejKpWp7Hysr0knXA==", policies, true}, {"validation error, short secret", "AZA85podM73CjLCjViDNz1EUvvejKpWp7Hysr0knXA==", policies, true},
{"empty options", "", []policy.Policy{}, true}, // special case {"empty options", "", []config.Policy{}, true}, // special case
{"missing policies", "gXK6ggrlIW2HyKyUF9rUO4azrDgxhDPWqw9y+lJU7B8=", []policy.Policy{}, false}, // special case
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -43,10 +41,13 @@ func TestNew(t *testing.T) {
} }
} }
func testPolicies() []policy.Policy { func testPolicies(t *testing.T) []config.Policy {
testPolicy := policy.Policy{From: "pomerium.io", To: "httpbin.org", AllowedEmails: []string{"test@gmail.com"}} testPolicy := config.Policy{From: "https://pomerium.io", To: "http://httpbin.org", AllowedEmails: []string{"test@gmail.com"}}
testPolicy.Validate() err := testPolicy.Validate()
policies := []policy.Policy{ if err != nil {
t.Fatal(err)
}
policies := []config.Policy{
testPolicy, testPolicy,
} }
@ -55,31 +56,39 @@ func testPolicies() []policy.Policy {
func Test_UpdateOptions(t *testing.T) { func Test_UpdateOptions(t *testing.T) {
t.Parallel() t.Parallel()
policies := testPolicies() policies := testPolicies(t)
newPolicy := policy.Policy{From: "foo.notatld", To: "bar.notatld", AllowedEmails: []string{"test@gmail.com"}} newPolicy := config.Policy{From: "https://source.example", To: "http://destination.example", AllowedEmails: []string{"test@gmail.com"}}
newPolicy.Validate() if err := newPolicy.Validate(); err != nil {
newPolicies := []policy.Policy{ t.Fatal(err)
}
newPolicies := []config.Policy{
newPolicy, newPolicy,
} }
identity := &Identity{Email: "test@gmail.com"} identity := &Identity{Email: "test@gmail.com"}
tests := []struct { tests := []struct {
name string name string
SharedKey string SharedKey string
Policies []policy.Policy Policies []config.Policy
newPolices []policy.Policy newPolices []config.Policy
route string route string
wantAllowed bool wantAllowed bool
}{ }{
{"good", "gXK6ggrlIW2HyKyUF9rUO4azrDgxhDPWqw9y+lJU7B8=", policies, policies, "pomerium.io", true}, {"good", "gXK6ggrlIW2HyKyUF9rUO4azrDgxhDPWqw9y+lJU7B8=", policies, policies, "pomerium.io", true},
{"changed", "gXK6ggrlIW2HyKyUF9rUO4azrDgxhDPWqw9y+lJU7B8=", policies, newPolicies, "foo.notatld", true}, {"changed", "gXK6ggrlIW2HyKyUF9rUO4azrDgxhDPWqw9y+lJU7B8=", policies, newPolicies, "source.example", true},
{"changed and missing", "gXK6ggrlIW2HyKyUF9rUO4azrDgxhDPWqw9y+lJU7B8=", policies, newPolicies, "pomerium.io", false}, {"changed and missing", "gXK6ggrlIW2HyKyUF9rUO4azrDgxhDPWqw9y+lJU7B8=", policies, newPolicies, "pomerium.io", false},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
o := config.Options{SharedKey: tt.SharedKey, Policies: tt.Policies} o := config.Options{SharedKey: tt.SharedKey, Policies: tt.Policies}
authorize, _ := New(o) authorize, err := New(o)
if err != nil {
t.Fatal(err)
}
o.Policies = tt.newPolices o.Policies = tt.newPolices
authorize.UpdateOptions(o) if err := authorize.UpdateOptions(o); err != nil {
t.Fatal(err)
}
allowed := authorize.ValidIdentity(tt.route, identity) allowed := authorize.ValidIdentity(tt.route, identity)
if allowed != tt.wantAllowed { if allowed != tt.wantAllowed {

View file

@ -5,8 +5,8 @@ import (
"strings" "strings"
"sync" "sync"
"github.com/pomerium/pomerium/internal/config"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/policy"
) )
// Identity contains a user's identity information. // Identity contains a user's identity information.
@ -55,28 +55,24 @@ type whitelist struct {
// newIdentityWhitelistMap takes a slice of policies and creates a hashmap of identity // newIdentityWhitelistMap takes a slice of policies and creates a hashmap of identity
// authorizations per-route for each allowed group, domain, and email. // authorizations per-route for each allowed group, domain, and email.
func newIdentityWhitelistMap(policies []policy.Policy, admins []string) *whitelist { func newIdentityWhitelistMap(policies []config.Policy, admins []string) *whitelist {
if len(policies) == 0 {
policyCount := len(policies) log.Warn().Msg("authorize: loaded configuration with no policies")
if policyCount == 0 {
log.Warn().Msg("authorize: loaded configuration with no policies specified")
} }
log.Info().Int("policy-count", policyCount).Msg("authorize: updated policies")
var wl whitelist var wl whitelist
wl.access = make(map[string]bool, len(policies)*3) wl.access = make(map[string]bool, len(policies)*3)
for _, p := range policies { for _, p := range policies {
for _, group := range p.AllowedGroups { for _, group := range p.AllowedGroups {
wl.PutGroup(p.From, group) wl.PutGroup(p.Source.Host, group)
log.Debug().Str("route", p.From).Str("group", group).Msg("add group") log.Debug().Str("route", p.Source.Host).Str("group", group).Msg("add group")
} }
for _, domain := range p.AllowedDomains { for _, domain := range p.AllowedDomains {
wl.PutDomain(p.From, domain) wl.PutDomain(p.Source.Host, domain)
log.Debug().Str("route", p.From).Str("domain", domain).Msg("add domain") log.Debug().Str("route", p.Source.Host).Str("domain", domain).Msg("add domain")
} }
for _, email := range p.AllowedEmails { for _, email := range p.AllowedEmails {
wl.PutEmail(p.From, email) wl.PutEmail(p.Source.Host, email)
log.Debug().Str("route", p.From).Str("email", email).Msg("add email") log.Debug().Str("route", p.Source.Host).Str("email", email).Msg("add email")
} }
} }

View file

@ -3,7 +3,7 @@ package authorize
import ( import (
"testing" "testing"
"github.com/pomerium/pomerium/internal/policy" "github.com/pomerium/pomerium/internal/config"
) )
func TestIdentity_EmailDomain(t *testing.T) { func TestIdentity_EmailDomain(t *testing.T) {
@ -32,40 +32,46 @@ func Test_IdentityWhitelistMap(t *testing.T) {
t.Parallel() t.Parallel()
tests := []struct { tests := []struct {
name string name string
policies []policy.Policy policies []config.Policy
route string route string
Identity *Identity Identity *Identity
admins []string admins []string
want bool want bool
}{ }{
{"valid domain", []policy.Policy{{From: "example.com", AllowedDomains: []string{"example.com"}}}, "example.com", &Identity{Email: "user@example.com"}, nil, true}, {"valid domain", []config.Policy{{From: "https://from.example", To: "https://to.example", AllowedDomains: []string{"example.com"}}}, "from.example", &Identity{Email: "user@example.com"}, nil, true},
{"valid domain with admins", []policy.Policy{{From: "example.com", AllowedDomains: []string{"example.com"}}}, "example.com", &Identity{Email: "user@example.com"}, []string{"admin@example.com"}, true}, {"valid domain with admins", []config.Policy{{From: "https://from.example", To: "https://to.example", AllowedDomains: []string{"example.com"}}}, "from.example", &Identity{Email: "user@example.com"}, []string{"admin@example.com"}, true},
{"invalid domain prepend", []policy.Policy{{From: "example.com", AllowedDomains: []string{"example.com"}}}, "example.com", &Identity{Email: "a@1example.com"}, nil, false}, {"invalid domain prepend", []config.Policy{{From: "https://from.example", To: "https://to.example", AllowedDomains: []string{"example.com"}}}, "from.example", &Identity{Email: "a@1example.com"}, nil, false},
{"invalid domain postpend", []policy.Policy{{From: "example.com", AllowedDomains: []string{"example.com"}}}, "example.com", &Identity{Email: "user@example.com2"}, nil, false}, {"invalid domain postpend", []config.Policy{{From: "https://from.example", To: "https://to.example", AllowedDomains: []string{"example.com"}}}, "from.example", &Identity{Email: "user@example.com2"}, nil, false},
{"valid group", []policy.Policy{{From: "example.com", AllowedGroups: []string{"admin"}}}, "example.com", &Identity{Email: "user@example.com", Groups: []string{"admin"}}, nil, true}, {"valid group", []config.Policy{{From: "https://from.example", To: "https://to.example", AllowedGroups: []string{"admin"}}}, "from.example", &Identity{Email: "user@example.com", Groups: []string{"admin"}}, nil, true},
{"invalid group", []policy.Policy{{From: "example.com", AllowedGroups: []string{"admin"}}}, "example.com", &Identity{Email: "user@example.com", Groups: []string{"everyone"}}, nil, false}, {"invalid group", []config.Policy{{From: "https://from.example", To: "https://to.example", AllowedGroups: []string{"admin"}}}, "from.example", &Identity{Email: "user@example.com", Groups: []string{"everyone"}}, nil, false},
{"invalid empty", []policy.Policy{{From: "example.com", AllowedGroups: []string{"admin"}}}, "example.com", &Identity{Email: "user@example.com", Groups: []string{""}}, nil, false}, {"invalid empty", []config.Policy{{From: "https://from.example", To: "https://to.example", AllowedGroups: []string{"admin"}}}, "from.example", &Identity{Email: "user@example.com", Groups: []string{""}}, nil, false},
{"valid group multiple", []policy.Policy{{From: "example.com", AllowedGroups: []string{"admin"}}}, "example.com", &Identity{Email: "user@example.com", Groups: []string{"everyone", "admin"}}, nil, true}, {"valid group multiple", []config.Policy{{From: "https://from.example", To: "https://to.example", AllowedGroups: []string{"admin"}}}, "from.example", &Identity{Email: "user@example.com", Groups: []string{"everyone", "admin"}}, nil, true},
{"invalid group multiple", []policy.Policy{{From: "example.com", AllowedGroups: []string{"admin"}}}, "example.com", &Identity{Email: "user@example.com", Groups: []string{"everyones", "sadmin"}}, nil, false}, {"invalid group multiple", []config.Policy{{From: "https://from.example", To: "https://to.example", AllowedGroups: []string{"admin"}}}, "from.example", &Identity{Email: "user@example.com", Groups: []string{"everyones", "sadmin"}}, nil, false},
{"valid user email", []policy.Policy{{From: "example.com", AllowedEmails: []string{"user@example.com"}}}, "example.com", &Identity{Email: "user@example.com"}, nil, true}, {"valid user email", []config.Policy{{From: "https://from.example", To: "https://to.example", AllowedEmails: []string{"user@example.com"}}}, "from.example", &Identity{Email: "user@example.com"}, nil, true},
{"invalid user email", []policy.Policy{{From: "example.com", AllowedEmails: []string{"user@example.com"}}}, "example.com", &Identity{Email: "user2@example.com"}, nil, false}, {"invalid user email", []config.Policy{{From: "https://from.example", To: "https://to.example", AllowedEmails: []string{"user@example.com"}}}, "from.example", &Identity{Email: "user2@example.com"}, nil, false},
{"empty everything", []policy.Policy{{From: "example.com"}}, "example.com", &Identity{Email: "user2@example.com"}, nil, false}, {"empty everything", []config.Policy{{From: "https://from.example", To: "https://to.example"}}, "from.example", &Identity{Email: "user2@example.com"}, nil, false},
{"empty policy", []policy.Policy{}, "example.com", &Identity{Email: "user@example.com"}, nil, false}, {"empty policy", []config.Policy{}, "from.example", &Identity{Email: "user2@example.com"}, nil, false},
// impersonation related // impersonation related
{"admin not impersonating allowed", []policy.Policy{{From: "example.com", AllowedDomains: []string{"example.com"}}}, "example.com", &Identity{Email: "admin@example.com"}, []string{"admin@example.com"}, true}, {"admin not impersonating allowed", []config.Policy{{From: "https://from.example", To: "https://to.example", AllowedDomains: []string{"example.com"}}}, "from.example", &Identity{Email: "admin@example.com"}, []string{"admin@example.com"}, true},
{"admin not impersonating denied", []policy.Policy{{From: "example.com", AllowedDomains: []string{"example.com"}}}, "example.com", &Identity{Email: "admin@admin-domain.com"}, []string{"admin@admin-domain.com"}, false}, {"admin not impersonating denied", []config.Policy{{From: "https://from.example", To: "https://to.example", AllowedDomains: []string{"example.com"}}}, "from.example", &Identity{Email: "admin@admin-domain.com"}, []string{"admin@admin-domain.com"}, false},
{"impersonating match domain", []policy.Policy{{From: "example.com", AllowedDomains: []string{"example.com"}}}, "example.com", &Identity{Email: "admin@admin-domain.com", ImpersonateEmail: "user@example.com"}, []string{"admin@admin-domain.com"}, true}, {"impersonating match domain", []config.Policy{{From: "https://from.example", To: "https://to.example", AllowedDomains: []string{"example.com"}}}, "from.example", &Identity{Email: "admin@admin-domain.com", ImpersonateEmail: "user@example.com"}, []string{"admin@admin-domain.com"}, true},
{"impersonating does not match domain", []policy.Policy{{From: "example.com", AllowedDomains: []string{"example.com"}}}, "example.com", &Identity{Email: "admin@admin-domain.com", ImpersonateEmail: "user@not-example.com"}, []string{"admin@admin-domain.com"}, false}, {"impersonating does not match domain", []config.Policy{{From: "https://from.example", To: "https://to.example", AllowedDomains: []string{"example.com"}}}, "from.example", &Identity{Email: "admin@admin-domain.com", ImpersonateEmail: "user@not-example.com"}, []string{"admin@admin-domain.com"}, false},
{"impersonating match email", []policy.Policy{{From: "example.com", AllowedEmails: []string{"user@example.com"}}}, "example.com", &Identity{Email: "admin@admin-domain.com", ImpersonateEmail: "user@example.com"}, []string{"admin@admin-domain.com"}, true}, {"impersonating match email", []config.Policy{{From: "https://from.example", To: "https://to.example", AllowedEmails: []string{"user@example.com"}}}, "from.example", &Identity{Email: "admin@admin-domain.com", ImpersonateEmail: "user@example.com"}, []string{"admin@admin-domain.com"}, true},
{"impersonating does not match email", []policy.Policy{{From: "example.com", AllowedEmails: []string{"user@example.com"}}}, "example.com", &Identity{Email: "admin@admin-domain.com", ImpersonateEmail: "user@not-example.com"}, []string{"admin@admin-domain.com"}, false}, {"impersonating does not match email", []config.Policy{{From: "https://from.example", To: "https://to.example", AllowedEmails: []string{"user@example.com"}}}, "from.example", &Identity{Email: "admin@admin-domain.com", ImpersonateEmail: "user@not-example.com"}, []string{"admin@admin-domain.com"}, false},
{"impersonating match groups", []policy.Policy{{From: "example.com", AllowedGroups: []string{"support"}}}, "example.com", &Identity{Email: "admin@admin-domain.com", ImpersonateGroups: []string{"support"}}, []string{"admin@admin-domain.com"}, true}, {"impersonating match groups", []config.Policy{{From: "https://from.example", To: "https://to.example", AllowedGroups: []string{"support"}}}, "from.example", &Identity{Email: "admin@admin-domain.com", ImpersonateGroups: []string{"support"}}, []string{"admin@admin-domain.com"}, true},
{"impersonating match many groups", []policy.Policy{{From: "example.com", AllowedGroups: []string{"support"}}}, "example.com", &Identity{Email: "admin@admin-domain.com", ImpersonateGroups: []string{"a", "b", "c", "support"}}, []string{"admin@admin-domain.com"}, true}, {"impersonating match many groups", []config.Policy{{From: "https://from.example", To: "https://to.example", AllowedGroups: []string{"support"}}}, "from.example", &Identity{Email: "admin@admin-domain.com", ImpersonateGroups: []string{"a", "b", "c", "support"}}, []string{"admin@admin-domain.com"}, true},
{"impersonating does not match groups", []policy.Policy{{From: "example.com", AllowedGroups: []string{"support"}}}, "example.com", &Identity{Email: "admin@admin-domain.com", ImpersonateGroups: []string{"not support"}}, []string{"admin@admin-domain.com"}, false}, {"impersonating does not match groups", []config.Policy{{From: "https://from.example", To: "https://to.example", AllowedGroups: []string{"support"}}}, "from.example", &Identity{Email: "admin@admin-domain.com", ImpersonateGroups: []string{"not support"}}, []string{"admin@admin-domain.com"}, false},
{"impersonating does not match many groups", []policy.Policy{{From: "example.com", AllowedGroups: []string{"support"}}}, "example.com", &Identity{Email: "admin@admin-domain.com", ImpersonateGroups: []string{"not support", "b", "c"}}, []string{"admin@admin-domain.com"}, false}, {"impersonating does not match many groups", []config.Policy{{From: "https://from.example", To: "https://to.example", AllowedGroups: []string{"support"}}}, "from.example", &Identity{Email: "admin@admin-domain.com", ImpersonateGroups: []string{"not support", "b", "c"}}, []string{"admin@admin-domain.com"}, false},
{"impersonating does not match empty groups", []policy.Policy{{From: "example.com", AllowedGroups: []string{"support"}}}, "example.com", &Identity{Email: "admin@admin-domain.com", ImpersonateGroups: []string{""}}, []string{"admin@admin-domain.com"}, false}, {"impersonating does not match empty groups", []config.Policy{{From: "https://from.example", To: "https://to.example", AllowedGroups: []string{"support"}}}, "from.example", &Identity{Email: "admin@admin-domain.com", ImpersonateGroups: []string{""}}, []string{"admin@admin-domain.com"}, false},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
for i := range tt.policies {
if err := (&tt.policies[i]).Validate(); err != nil {
t.Fatal(err)
}
}
wl := NewIdentityWhitelist(tt.policies, tt.admins) wl := NewIdentityWhitelist(tt.policies, tt.admins)
if got := wl.Valid(tt.route, tt.Identity); got != tt.want { if got := wl.Valid(tt.route, tt.Identity); got != tt.want {
t.Errorf("wl.Valid() = %v, want %v", got, tt.want) t.Errorf("wl.Valid() = %v, want %v", got, tt.want)

View file

@ -15,7 +15,7 @@ import (
"github.com/pomerium/pomerium/authenticate" "github.com/pomerium/pomerium/authenticate"
"github.com/pomerium/pomerium/authorize" "github.com/pomerium/pomerium/authorize"
"github.com/pomerium/pomerium/internal/config" "github.com/pomerium/pomerium/internal/config"
"github.com/pomerium/pomerium/internal/https" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/metrics" "github.com/pomerium/pomerium/internal/metrics"
"github.com/pomerium/pomerium/internal/middleware" "github.com/pomerium/pomerium/internal/middleware"
@ -46,17 +46,17 @@ func main() {
mux := http.NewServeMux() mux := http.NewServeMux()
_, err = newAuthenticateService(opt, mux, grpcServer) _, err = newAuthenticateService(*opt, mux, grpcServer)
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("cmd/pomerium: authenticate") log.Fatal().Err(err).Msg("cmd/pomerium: authenticate")
} }
authz, err := newAuthorizeService(opt, grpcServer) authz, err := newAuthorizeService(*opt, grpcServer)
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("cmd/pomerium: authorize") log.Fatal().Err(err).Msg("cmd/pomerium: authorize")
} }
proxy, err := newProxyService(opt, mux) proxy, err := newProxyService(*opt, mux)
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("cmd/pomerium: proxy") log.Fatal().Err(err).Msg("cmd/pomerium: proxy")
} }
@ -73,7 +73,7 @@ func main() {
// defer proxyService.AuthenticateClient.Close() // defer proxyService.AuthenticateClient.Close()
// defer proxyService.AuthorizeClient.Close() // defer proxyService.AuthorizeClient.Close()
httpOpts := &https.Options{ httpOpts := &httputil.Options{
Addr: opt.Addr, Addr: opt.Addr,
Cert: opt.Cert, Cert: opt.Cert,
Key: opt.Key, Key: opt.Key,
@ -95,7 +95,7 @@ func main() {
defer srv.Close() defer srv.Close()
} }
if err := https.ListenAndServeTLS(httpOpts, wrapMiddleware(opt, mux), grpcServer); err != nil { if err := httputil.ListenAndServeTLS(httpOpts, wrapMiddleware(opt, mux), grpcServer); err != nil {
log.Fatal().Err(err).Msg("cmd/pomerium: https server") log.Fatal().Err(err).Msg("cmd/pomerium: https server")
} }
} }
@ -166,7 +166,7 @@ func newPromListener(addr string) {
log.Error().Err(metrics.NewPromHTTPListener(addr)).Str("MetricsAddr", addr).Msg("cmd/pomerium: could not start metrics exporter") log.Error().Err(metrics.NewPromHTTPListener(addr)).Str("MetricsAddr", addr).Msg("cmd/pomerium: could not start metrics exporter")
} }
func wrapMiddleware(o config.Options, mux *http.ServeMux) http.Handler { func wrapMiddleware(o *config.Options, mux *http.ServeMux) http.Handler {
c := middleware.NewChain() c := middleware.NewChain()
c = c.Append(metrics.HTTPMetricsHandler("proxy")) c = c.Append(metrics.HTTPMetricsHandler("proxy"))
c = c.Append(log.NewHandler(log.Logger)) c = c.Append(log.NewHandler(log.Logger))
@ -194,10 +194,10 @@ func wrapMiddleware(o config.Options, mux *http.ServeMux) http.Handler {
return c.Then(mux) return c.Then(mux)
} }
func parseOptions(configFile string) (config.Options, error) { func parseOptions(configFile string) (*config.Options, error) {
o, err := config.OptionsFromViper(configFile) o, err := config.OptionsFromViper(configFile)
if err != nil { if err != nil {
return o, err return nil, err
} }
if o.Debug { if o.Debug {
log.SetDebugMode() log.SetDebugMode()
@ -209,8 +209,12 @@ func parseOptions(configFile string) (config.Options, error) {
return o, nil return o, nil
} }
func handleConfigUpdate(opt config.Options, services []config.OptionsUpdater) config.Options { func handleConfigUpdate(opt *config.Options, services []config.OptionsUpdater) *config.Options {
newOpt, err := parseOptions(*configFile) newOpt, err := parseOptions(*configFile)
if err != nil {
log.Error().Err(err).Msg("cmd/pomerium: could not reload configuration")
return opt
}
optChecksum := opt.Checksum() optChecksum := opt.Checksum()
newOptChecksum := newOpt.Checksum() newOptChecksum := newOpt.Checksum()
@ -224,22 +228,10 @@ func handleConfigUpdate(opt config.Options, services []config.OptionsUpdater) co
return opt return opt
} }
if err != nil { log.Info().Str("checksum", newOptChecksum).Msg("cmd/pomerium: checksum changed")
log.Error().
Err(err).
Msg("cmd/pomerium: could not reload configuration")
return opt
}
log.Info().
Str("config-checksum", newOptChecksum).
Msg("cmd/pomerium: running configuration has changed")
for _, service := range services { for _, service := range services {
err := service.UpdateOptions(newOpt) if err := service.UpdateOptions(*newOpt); err != nil {
if err != nil { log.Error().Err(err).Msg("cmd/pomerium: could not update options")
log.Error().
Err(err).
Msg("cmd/pomerium: could not update options")
} }
} }

View file

@ -11,8 +11,6 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/pomerium/pomerium/internal/policy"
"github.com/pomerium/pomerium/internal/config" "github.com/pomerium/pomerium/internal/config"
"github.com/pomerium/pomerium/internal/middleware" "github.com/pomerium/pomerium/internal/middleware"
"google.golang.org/grpc" "google.golang.org/grpc"
@ -72,21 +70,22 @@ func Test_newAuthenticateService(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
authURL, _ := url.Parse("http://auth.server.com") testOpts, err := config.NewOptions("https://authenticate.example", "https://authorize.example")
testOpts := config.NewOptions() if err != nil {
t.Fatal(err)
}
testOpts.Provider = "google" testOpts.Provider = "google"
testOpts.ClientSecret = "TEST" testOpts.ClientSecret = "TEST"
testOpts.SharedKey = "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=" testOpts.SharedKey = "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="
testOpts.CookieSecret = "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=" testOpts.CookieSecret = "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="
testOpts.AuthenticateURL = *authURL
testOpts.Services = tt.s testOpts.Services = tt.s
if tt.Field != "" { if tt.Field != "" {
testOptsField := reflect.ValueOf(&testOpts).Elem().FieldByName(tt.Field) testOptsField := reflect.ValueOf(testOpts).Elem().FieldByName(tt.Field)
testOptsField.Set(reflect.ValueOf(tt).FieldByName("Value")) testOptsField.Set(reflect.ValueOf(tt).FieldByName("Value"))
} }
_, err := newAuthenticateService(testOpts, mux, grpcServer) _, err = newAuthenticateService(*testOpts, mux, grpcServer)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("newAuthenticateService() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("newAuthenticateService() error = %v, wantErr %v", err, tt.wantErr)
return return
@ -116,21 +115,26 @@ func Test_newAuthorizeService(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
testOpts := config.NewOptions() testOpts, err := config.NewOptions("https://some.example", "https://some.example")
if err != nil {
t.Fatal(err)
}
testOpts.Services = tt.s testOpts.Services = tt.s
testOpts.CookieSecret = "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=" testOpts.CookieSecret = "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="
testPolicy := policy.Policy{From: "pomerium.io", To: "httpbin.org"} testPolicy := config.Policy{From: "http://some.example", To: "https://some.example"}
testPolicy.Validate() if err := testPolicy.Validate(); err != nil {
testOpts.Policies = []policy.Policy{ t.Fatal(err)
}
testOpts.Policies = []config.Policy{
testPolicy, testPolicy,
} }
if tt.Field != "" { if tt.Field != "" {
testOptsField := reflect.ValueOf(&testOpts).Elem().FieldByName(tt.Field) testOptsField := reflect.ValueOf(testOpts).Elem().FieldByName(tt.Field)
testOptsField.Set(reflect.ValueOf(tt).FieldByName("Value")) testOptsField.Set(reflect.ValueOf(tt).FieldByName("Value"))
} }
_, err := newAuthorizeService(testOpts, grpcServer) _, err = newAuthorizeService(*testOpts, grpcServer)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("newAuthorizeService() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("newAuthorizeService() error = %v, wantErr %v", err, tt.wantErr)
return return
@ -156,26 +160,31 @@ func Test_newProxyeService(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
mux := http.NewServeMux() mux := http.NewServeMux()
testOpts := config.NewOptions() testOpts, err := config.NewOptions("https://authenticate.example", "https://authorize.example")
testPolicy := policy.Policy{From: "pomerium.io", To: "httpbin.org"} if err != nil {
testPolicy.Validate() t.Fatal(err)
testOpts.Policies = []policy.Policy{ }
testPolicy := config.Policy{From: "http://some.example", To: "http://some.example"}
if err := testPolicy.Validate(); err != nil {
t.Fatal(err)
}
testOpts.Policies = []config.Policy{
testPolicy, testPolicy,
} }
AuthenticateURL, _ := url.Parse("https://authenticate.example.com") AuthenticateURL, _ := url.Parse("https://authenticate.example.com")
AuthorizeURL, _ := url.Parse("https://authorize.example.com") AuthorizeURL, _ := url.Parse("https://authorize.example.com")
testOpts.AuthenticateURL = *AuthenticateURL testOpts.AuthenticateURL = AuthenticateURL
testOpts.AuthorizeURL = *AuthorizeURL testOpts.AuthorizeURL = AuthorizeURL
testOpts.CookieSecret = "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=" testOpts.CookieSecret = "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="
testOpts.Services = tt.s testOpts.Services = tt.s
if tt.Field != "" { if tt.Field != "" {
testOptsField := reflect.ValueOf(&testOpts).Elem().FieldByName(tt.Field) testOptsField := reflect.ValueOf(testOpts).Elem().FieldByName(tt.Field)
testOptsField.Set(reflect.ValueOf(tt).FieldByName("Value")) testOptsField.Set(reflect.ValueOf(tt).FieldByName("Value"))
} }
_, err := newProxyService(testOpts, mux) _, err = newProxyService(*testOpts, mux)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("newProxyService() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("newProxyService() error = %v, wantErr %v", err, tt.wantErr)
return return
@ -205,7 +214,7 @@ func Test_wrapMiddleware(t *testing.T) {
}) })
mux.Handle("/404", h) mux.Handle("/404", h)
out := wrapMiddleware(o, mux) out := wrapMiddleware(&o, mux)
out.ServeHTTP(rr, req) out.ServeHTTP(rr, req)
expected := fmt.Sprintf("OK") expected := fmt.Sprintf("OK")
body := rr.Body.String() body := rr.Body.String()
@ -216,27 +225,31 @@ func Test_wrapMiddleware(t *testing.T) {
} }
func Test_parseOptions(t *testing.T) { func Test_parseOptions(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
envKey string envKey string
envValue string envValue string
servicesEnvKey string
wantSharedKey string servicesEnvValue string
wantErr bool wantSharedKey string
wantErr bool
}{ }{
{"no shared secret", "", "", "", true}, {"no shared secret", "", "", "SERVICES", "authenticate", "skip", true},
{"good", "SHARED_SECRET", "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=", "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=", false}, {"no shared secret in all mode", "", "", "", "", "", false},
{"good", "SHARED_SECRET", "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=", "", "", "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=", false},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
os.Setenv(tt.servicesEnvKey, tt.servicesEnvValue)
os.Setenv(tt.envKey, tt.envValue) os.Setenv(tt.envKey, tt.envValue)
defer os.Unsetenv(tt.envKey) defer os.Unsetenv(tt.envKey)
defer os.Unsetenv(tt.servicesEnvKey)
got, err := parseOptions("") got, err := parseOptions("")
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("parseOptions() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("parseOptions() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if got.SharedKey != tt.wantSharedKey { if got != nil && got.Services != "all" && got.SharedKey != tt.wantSharedKey {
t.Errorf("parseOptions()\n") t.Errorf("parseOptions()\n")
t.Errorf("got: %+v\n", got.SharedKey) t.Errorf("got: %+v\n", got.SharedKey)
t.Errorf("want: %+v\n", tt.wantSharedKey) t.Errorf("want: %+v\n", tt.wantSharedKey)
@ -265,22 +278,29 @@ func Test_handleConfigUpdate(t *testing.T) {
os.Setenv("SHARED_SECRET", "foo") os.Setenv("SHARED_SECRET", "foo")
defer os.Unsetenv("SHARED_SECRET") defer os.Unsetenv("SHARED_SECRET")
blankOpts := config.NewOptions() blankOpts, err := config.NewOptions("https://authenticate.example", "https://authorize.example")
goodOpts, _ := config.OptionsFromViper("") if err != nil {
t.Fatal(err)
}
goodOpts, err := config.OptionsFromViper("")
if err != nil {
t.Fatal(err)
}
tests := []struct { tests := []struct {
name string name string
service *mockService service *mockService
oldOpts config.Options oldOpts config.Options
wantUpdate bool wantUpdate bool
}{ }{
{"good", &mockService{fail: false}, blankOpts, true}, {"good", &mockService{fail: false}, *blankOpts, true},
{"bad", &mockService{fail: true}, blankOpts, true}, {"bad", &mockService{fail: true}, *blankOpts, true},
{"no change", &mockService{fail: false}, goodOpts, false}, {"no change", &mockService{fail: false}, *goodOpts, false},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
handleConfigUpdate(tt.oldOpts, []config.OptionsUpdater{tt.service}) handleConfigUpdate(&tt.oldOpts, []config.OptionsUpdater{tt.service})
if tt.service.Updated != tt.wantUpdate { if tt.service.Updated != tt.wantUpdate {
t.Errorf("Failed to update config on service") t.Errorf("Failed to update config on service")
} }

View file

@ -15,7 +15,7 @@ If you are coming from a kubernetes or docker background this should feel famili
Using both [environmental variables] and config file keys is allowed and encouraged (for instance, secret keys are probably best set as environmental variables). However, if duplicate configuration keys are found, environment variables take precedence. Using both [environmental variables] and config file keys is allowed and encouraged (for instance, secret keys are probably best set as environmental variables). However, if duplicate configuration keys are found, environment variables take precedence.
Pomerium will automatically reload the configuration file if it is changed. At this time, only policy is re-configured when this reload occurs, but additional options may be added in the future. It is suggested that your policy is stored in a configuration file so that you can take advantage of this feature. Pomerium will automatically reload the configuration file if it is changed. At this time, only policy is re-configured when this reload occurs, but additional options may be added in the future. It is suggested that your policy is stored in a configuration file so that you can take advantage of this feature.
## Global settings ## Global settings
@ -73,7 +73,7 @@ head -c32 /dev/urandom | base64
::: danger ::: danger
Enabling the debug flag will result in sensitive information being logged!!! Enabling the debug flag will result in sensitive information being logged!!!
::: :::
@ -149,19 +149,6 @@ Timeouts set the global server timeouts. For route-specific timeouts, see [polic
If set, the HTTP Redirect Address specifies the host and port to redirect http to https traffic on. If unset, no redirect server is started. If set, the HTTP Redirect Address specifies the host and port to redirect http to https traffic on. If unset, no redirect server is started.
### Websocket Connections
- Environmental Variable: `ALLOW_WEBSOCKETS`
- Config File Key: `allow_websockets`
- Type: `bool`
- Default: `false`
If set, enables proxying of websocket connections.
Otherwise the proxy responds with `400 Bad Request` to all websocket connections.
**Use with caution:** By definition, websockets are long-lived connections, so [global timeouts](#global-timeouts) are not enforced.
Allowing websocket connections to the proxy could result in abuse via DOS attacks.
### Metrics Address ### Metrics Address
- Environmental Variable: `METRICS_ADDRESS` - Environmental Variable: `METRICS_ADDRESS`
@ -171,31 +158,32 @@ Allowing websocket connections to the proxy could result in abuse via DOS attack
- Default: `disabled` - Default: `disabled`
- Optional - Optional
Expose a prometheus format HTTP endpoint on the specified port. Disabled by default. Expose a prometheus format HTTP endpoint on the specified port. Disabled by default.
**Use with caution:** the endpoint can expose frontend and backend server names or addresses. Do not expose the metrics port publicly. **Use with caution:** the endpoint can expose frontend and backend server names or addresses. Do not expose the metrics port publicly.
#### Metrics tracked #### Metrics tracked
| Name | Type | Description | Name | Type | Description
|:------------- |:-------------|:-----| :------------------------------ | :-------- | :--------------------------------------------
|http_server_requests_total| Counter | Total HTTP server requests handled by service| http_server_requests_total | Counter | Total HTTP server requests handled by service
|http_server_response_size_bytes| Histogram | HTTP server response size by service| http_server_response_size_bytes | Histogram | HTTP server response size by service
|http_server_request_duration_ms| Histogram | HTTP server request duration by service| http_server_request_duration_ms | Histogram | HTTP server request duration by service
|http_client_requests_total| Counter | Total HTTP client requests made by service| http_client_requests_total | Counter | Total HTTP client requests made by service
|http_client_response_size_bytes| Histogram | HTTP client response size by service| http_client_response_size_bytes | Histogram | HTTP client response size by service
|http_client_request_duration_ms| Histogram | HTTP client request duration by service| http_client_request_duration_ms | Histogram | HTTP client request duration by service
|grpc_client_requests_total| Counter | Total GRPC client requests made by service| grpc_client_requests_total | Counter | Total GRPC client requests made by service
|grpc_client_response_size_bytes| Histogram | GRPC client response size by service| grpc_client_response_size_bytes | Histogram | GRPC client response size by service
|grpc_client_request_duration_ms| Histogram | GRPC client request duration by service| grpc_client_request_duration_ms | Histogram | GRPC client request duration by service
### Policy ### Policy
- Environmental Variable: `POLICY` - Environmental Variable: `POLICY`
- Config File Key: `policy` - Config File Key: `policy`
- Type: [base64 encoded] `string` or inline policy structure in config file - Type: [base64 encoded] `string` or inline policy structure in config file
- Required - Required
- Required to forward traffic. Pomerium will safely start without a policy configured, but will be unable to authorize or proxy traffic until the configuration is updated to contain a policy.
- Required to forward traffic. Pomerium will safely start without a policy configured, but will be unable to authorize or proxy traffic until the configuration is updated to contain a policy.
Policy contains route specific settings, and access control details. If you are configuring via POLICY environment variable, just the contents of the policy needs to be passed. If you are configuring via file, the policy should be present under the policy key. For example, Policy contains route specific settings, and access control details. If you are configuring via POLICY environment variable, just the contents of the policy needs to be passed. If you are configuring via file, the policy should be present under the policy key. For example,
@ -277,6 +265,34 @@ If this setting is enabled, no whitelists (e.g. Allowed Users) should be provide
Policy timeout establishes the per-route timeout value. Cannot exceed global timeout values. Policy timeout establishes the per-route timeout value. Cannot exceed global timeout values.
#### Websocket Connections
- Config File Key: `allow_websockets`
- Type: `bool`
- Default: `false`
If set, enables proxying of websocket connections.
**Use with caution:** By definition, websockets are long-lived connections, so [global timeouts](#global-timeouts) are not enforced. Allowing websocket connections to the proxy could result in abuse via [DOS attacks](https://www.cloudflare.com/learning/ddos/ddos-attack-tools/slowloris/).
#### TLS Skip Verification
- Config File Key: `tls_skip_verify`
- Type: `bool`
- Default: `false`
TLS Skip Verification controls whether a client verifies the server's certificate chain and host name. If enabled, TLS accepts any certificate presented by the server and any host name in that certificate. In this mode, TLS is susceptible to man-in-the-middle attacks. This should be used only for testing.
#### TLS Custom Certificate Authority
- Config File Key: `tls_custom_ca`
- Type: [base64 encoded] `string`
- Optional
TLS Custom Certificate Authority defines the set of root certificate authorities that clients use when verifying server certificates.
Note: This setting will replace (not append) the system's trust store for a given route.
## Authenticate Service ## Authenticate Service
### Authenticate Service URL ### Authenticate Service URL
@ -398,7 +414,7 @@ If your load balancer does not support gRPC pass-through you'll need to set this
- Optional (but typically required if Authenticate Internal Service Address is set) - Optional (but typically required if Authenticate Internal Service Address is set)
- Example: `*.corp.example.com` if wild card or `authenticate.corp.example.com`/`authorize.corp.example.com` - Example: `*.corp.example.com` if wild card or `authenticate.corp.example.com`/`authorize.corp.example.com`
When Authenticate Internal Service Address is set, secure service communication can fail because the external certificate name will not match the internally routed service hostname/[SNI](<https://en.wikipedia.org/wiki/Server_Name_Indication>). This setting allows you to override that check. When Authenticate Internal Service Address is set, secure service communication can fail because the external certificate name will not match the internally routed service hostname/[SNI](https://en.wikipedia.org/wiki/Server_Name_Indication). This setting allows you to override that check.
### Certificate Authority ### Certificate Authority
@ -414,17 +430,19 @@ Certificate Authority is set when behind-the-ingress service communication uses
- Environmental Variable: `HEADERS` - Environmental Variable: `HEADERS`
- Config File Key: `headers` - Config File Key: `headers`
- Type: map of `strings` key value pairs - Type: map of `strings` key value pairs
- Examples: - Examples:
- Comma Separated:
`X-Content-Type-Options:nosniff,X-Frame-Options:SAMEORIGIN` - Comma Separated: `X-Content-Type-Options:nosniff,X-Frame-Options:SAMEORIGIN`
- JSON: `'{"X-Test": "X-Value"}'` - JSON: `'{"X-Test": "X-Value"}'`
- YAML: - YAML:
```yaml
headers: ```yaml
X-Test: X-Value headers:
``` X-Test: X-Value
```
- To disable: `disable:true` - To disable: `disable:true`
- Default : - Default :
```javascript ```javascript
@ -460,7 +478,6 @@ Refresh cooldown is the minimum amount of time between allowed manually refreshe
Default Upstream Timeout is the default timeout applied to a proxied route when no `timeout` key is specified by the policy. Default Upstream Timeout is the default timeout applied to a proxied route when no `timeout` key is specified by the policy.
[base64 encoded]: https://en.wikipedia.org/wiki/Base64 [base64 encoded]: https://en.wikipedia.org/wiki/Base64
[environmental variables]: https://en.wikipedia.org/wiki/Environment_variable [environmental variables]: https://en.wikipedia.org/wiki/Environment_variable
[identity provider]: ./identity-providers.md [identity provider]: ./identity-providers.md

View file

@ -0,0 +1,58 @@
package config // import "github.com/pomerium/pomerium/internal/config"
import "os"
// findPwd returns best guess at current working directory
func findPwd() string {
p, err := os.Getwd()
if err != nil {
return "."
}
return p
}
// IsValidService checks to see if a service is a valid service mode
func IsValidService(s string) bool {
switch s {
case
"all",
"proxy",
"authorize",
"authenticate":
return true
}
return false
}
// IsAuthenticate checks to see if we should be running the authenticate service
func IsAuthenticate(s string) bool {
switch s {
case
"all",
"authenticate":
return true
}
return false
}
// IsAuthorize checks to see if we should be running the authorize service
func IsAuthorize(s string) bool {
switch s {
case
"all",
"authorize":
return true
}
return false
}
// IsProxy checks to see if we should be running the proxy service
func IsProxy(s string) bool {
switch s {
case
"all",
"proxy":
return true
}
return false
}

View file

@ -0,0 +1,94 @@
package config // import "github.com/pomerium/pomerium/internal/config"
import (
"testing"
)
func Test_isValidService(t *testing.T) {
tests := []struct {
name string
service string
want bool
}{
{"proxy", "proxy", true},
{"all", "all", true},
{"authenticate", "authenticate", true},
{"authenticate bad case", "AuThenticate", false},
{"authorize implemented", "authorize", true},
{"jiberish", "xd23", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := IsValidService(tt.service); got != tt.want {
t.Errorf("isValidService() = %v, want %v", got, tt.want)
}
})
}
}
func Test_isAuthenticate(t *testing.T) {
tests := []struct {
name string
service string
want bool
}{
{"proxy", "proxy", false},
{"all", "all", true},
{"authenticate", "authenticate", true},
{"authenticate bad case", "AuThenticate", false},
{"authorize implemented", "authorize", false},
{"jiberish", "xd23", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := IsAuthenticate(tt.service); got != tt.want {
t.Errorf("isAuthenticate() = %v, want %v", got, tt.want)
}
})
}
}
func Test_isAuthorize(t *testing.T) {
tests := []struct {
name string
service string
want bool
}{
{"proxy", "proxy", false},
{"all", "all", true},
{"authorize", "authorize", true},
{"authorize bad case", "AuThorize", false},
{"authenticate implemented", "authenticate", false},
{"jiberish", "xd23", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := IsAuthorize(tt.service); got != tt.want {
t.Errorf("isAuthenticate() = %v, want %v", got, tt.want)
}
})
}
}
func Test_IsProxy(t *testing.T) {
tests := []struct {
name string
service string
want bool
}{
{"proxy", "proxy", true},
{"all", "all", true},
{"authorize", "authorize", false},
{"proxy bad case", "PrOxY", false},
{"jiberish", "xd23", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := IsProxy(tt.service); got != tt.want {
t.Errorf("IsProxy() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -1,19 +1,20 @@
package config package config // import "github.com/pomerium/pomerium/internal/config"
import ( import (
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt" "fmt"
"net/url" "net/url"
"os"
"path/filepath" "path/filepath"
"reflect" "reflect"
"strings" "strings"
"time" "time"
"github.com/mitchellh/hashstructure" "github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/policy" "github.com/pomerium/pomerium/internal/urlutil"
"github.com/mitchellh/hashstructure"
"github.com/spf13/viper" "github.com/spf13/viper"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
) )
@ -22,9 +23,6 @@ import (
const DisableHeaderKey = "disable" const DisableHeaderKey = "disable"
// Options are the global environmental flags used to set up pomerium's services. // Options are the global environmental flags used to set up pomerium's services.
// If a base64 encoded certificate and key are not provided as environmental variables,
// or if a file location is not provided, the server will attempt to find a matching keypair
// in the local directory as `./cert.pem` and `./privkey.pem` respectively.
type Options struct { type Options struct {
// Debug outputs human-readable logs to Stdout. // Debug outputs human-readable logs to Stdout.
Debug bool `mapstructure:"pomerium_debug"` Debug bool `mapstructure:"pomerium_debug"`
@ -42,10 +40,10 @@ type Options struct {
Services string `mapstructure:"services"` Services string `mapstructure:"services"`
// Addr specifies the host and port on which the server should serve // Addr specifies the host and port on which the server should serve
// HTTPS requests. If empty, ":https" is used. // HTTPS requests. If empty, ":https" (localhost:443) is used.
Addr string `mapstructure:"address"` Addr string `mapstructure:"address"`
// Cert and Key specifies the base64 encoded TLS certificates to use. // Cert and Key specifies the TLS certificates to use.
Cert string `mapstructure:"certificate"` Cert string `mapstructure:"certificate"`
Key string `mapstructure:"certificate_key"` Key string `mapstructure:"certificate_key"`
@ -64,15 +62,15 @@ type Options struct {
ReadHeaderTimeout time.Duration `mapstructure:"timeout_read_header"` ReadHeaderTimeout time.Duration `mapstructure:"timeout_read_header"`
IdleTimeout time.Duration `mapstructure:"timeout_idle"` IdleTimeout time.Duration `mapstructure:"timeout_idle"`
// Policy is a base64 encoded yaml blob which enumerates // Policies define per-route configuration and access control policies.
// per-route access control policies. Policies []Policy
PolicyEnv string PolicyEnv string
PolicyFile string `mapstructure:"policy_file"` PolicyFile string `mapstructure:"policy_file"`
// AuthenticateURL represents the externally accessible http endpoints // AuthenticateURL represents the externally accessible http endpoints
// used for authentication requests and callbacks // used for authentication requests and callbacks
AuthenticateURLString string `mapstructure:"authenticate_service_url"` AuthenticateURLString string `mapstructure:"authenticate_service_url"`
AuthenticateURL url.URL AuthenticateURL *url.URL
// Session/Cookie management // Session/Cookie management
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie
@ -93,8 +91,6 @@ type Options struct {
Scopes []string `mapstructure:"idp_scopes"` Scopes []string `mapstructure:"idp_scopes"`
ServiceAccount string `mapstructure:"idp_service_account"` ServiceAccount string `mapstructure:"idp_service_account"`
Policies []policy.Policy
// Administrators contains a set of emails with users who have super user // Administrators contains a set of emails with users who have super user
// (sudo) access including the ability to impersonate other users' access // (sudo) access including the ability to impersonate other users' access
Administrators []string `mapstructure:"administrators"` Administrators []string `mapstructure:"administrators"`
@ -104,20 +100,20 @@ type Options struct {
// NOTE: As many load balancers do not support externally routed gRPC so // NOTE: As many load balancers do not support externally routed gRPC so
// this may be an internal location. // this may be an internal location.
AuthenticateInternalAddrString string `mapstructure:"authenticate_internal_url"` AuthenticateInternalAddrString string `mapstructure:"authenticate_internal_url"`
AuthenticateInternalAddr url.URL AuthenticateInternalAddr *url.URL
// AuthorizeURL is the routable destination of the authorize service's // AuthorizeURL is the routable destination of the authorize service's
// gRPC endpoint. NOTE: As many load balancers do not support // gRPC endpoint. NOTE: As many load balancers do not support
// externally routed gRPC so this may be an internal location. // externally routed gRPC so this may be an internal location.
AuthorizeURLString string `mapstructure:"authorize_service_url"` AuthorizeURLString string `mapstructure:"authorize_service_url"`
AuthorizeURL url.URL AuthorizeURL *url.URL
// Settings to enable custom behind-the-ingress service communication // Settings to enable custom behind-the-ingress service communication
OverrideCertificateName string `mapstructure:"override_certificate_name"` OverrideCertificateName string `mapstructure:"override_certificate_name"`
CA string `mapstructure:"certificate_authority"` CA string `mapstructure:"certificate_authority"`
CAFile string `mapstructure:"certificate_authority_file"` CAFile string `mapstructure:"certificate_authority_file"`
// SigningKey is a base64 encoded private key used to add a JWT-signature. // SigningKey is the private key used to add a JWT-signature.
// https://www.pomerium.io/docs/signed-headers.html // https://www.pomerium.io/docs/signed-headers.html
SigningKey string `mapstructure:"signing_key"` SigningKey string `mapstructure:"signing_key"`
@ -128,219 +124,170 @@ type Options struct {
// RefreshCooldown limits the rate a user can refresh her session // RefreshCooldown limits the rate a user can refresh her session
RefreshCooldown time.Duration `mapstructure:"refresh_cooldown"` RefreshCooldown time.Duration `mapstructure:"refresh_cooldown"`
// Sub-routes //Routes map[string]string `mapstructure:"routes"`
Routes map[string]string `mapstructure:"routes"` DefaultUpstreamTimeout time.Duration `mapstructure:"default_upstream_timeout"`
DefaultUpstreamTimeout time.Duration `mapstructure:"default_upstream_timeout"`
// Enable proxying of websocket connections. Defaults to "false".
// Caution: Enabling this feature could result in abuse via DOS attacks.
AllowWebsockets bool `mapstructure:"allow_websockets"`
// Address/Port to bind to for prometheus metrics // Address/Port to bind to for prometheus metrics
MetricsAddr string `mapstructure:"metrics_address"` MetricsAddr string `mapstructure:"metrics_address"`
} }
// NewOptions returns a new options struct with default values var defaultOptions = Options{
func NewOptions() Options { Debug: false,
o := Options{ LogLevel: "debug",
Debug: false, Services: "all",
LogLevel: "debug", CookieHTTPOnly: true,
Services: "all", CookieSecure: true,
CookieHTTPOnly: true, CookieExpire: time.Duration(14) * time.Hour,
CookieSecure: true, CookieRefresh: time.Duration(30) * time.Minute,
CookieExpire: time.Duration(14) * time.Hour, CookieName: "_pomerium",
CookieRefresh: time.Duration(30) * time.Minute, DefaultUpstreamTimeout: time.Duration(30) * time.Second,
CookieName: "_pomerium", Headers: map[string]string{
DefaultUpstreamTimeout: time.Duration(30) * time.Second, "X-Content-Type-Options": "nosniff",
Headers: map[string]string{ "X-Frame-Options": "SAMEORIGIN",
"X-Content-Type-Options": "nosniff", "X-XSS-Protection": "1; mode=block",
"X-Frame-Options": "SAMEORIGIN", "Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload",
"X-XSS-Protection": "1; mode=block", },
"Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload", Addr: ":https",
}, CertFile: filepath.Join(findPwd(), "cert.pem"),
Addr: ":https", KeyFile: filepath.Join(findPwd(), "privkey.pem"),
CertFile: filepath.Join(findPwd(), "cert.pem"), ReadHeaderTimeout: 10 * time.Second,
KeyFile: filepath.Join(findPwd(), "privkey.pem"), ReadTimeout: 30 * time.Second,
ReadHeaderTimeout: 10 * time.Second, WriteTimeout: 0, // support streaming by default
ReadTimeout: 30 * time.Second, IdleTimeout: 5 * time.Minute,
WriteTimeout: 0, // support streaming by default RefreshCooldown: time.Duration(5 * time.Minute),
IdleTimeout: 5 * time.Minute, }
RefreshCooldown: time.Duration(5 * time.Minute),
AllowWebsockets: false, // NewOptions returns a minimal options configuration built from default options.
// Any modifications to the structure should be followed up by a subsequent
// call to validate.
func NewOptions(authenticateURL, authorizeURL string) (*Options, error) {
o := defaultOptions
o.AuthenticateURLString = authenticateURL
o.AuthorizeURLString = authorizeURL
if err := o.Validate(); err != nil {
return nil, fmt.Errorf("internal/config: validation error %s", err)
} }
return o return &o, nil
} }
// OptionsFromViper builds the main binary's configuration // OptionsFromViper builds the main binary's configuration
// options by parsing environmental variables and config file // options by parsing environmental variables and config file
func OptionsFromViper(configFile string) (Options, error) { func OptionsFromViper(configFile string) (*Options, error) {
o := NewOptions() // start a copy of the default options
o := defaultOptions
// Load up config // Load up config
o.bindEnvs() o.bindEnvs()
if configFile != "" { if configFile != "" {
log.Info().
Str("file", configFile).
Msg("loading config from file")
viper.SetConfigFile(configFile) viper.SetConfigFile(configFile)
err := viper.ReadInConfig() if err := viper.ReadInConfig(); err != nil {
if err != nil { return nil, fmt.Errorf("internal/config: failed to read config: %s", err)
return o, fmt.Errorf("failed to read config: %s", err)
} }
} }
err := viper.Unmarshal(&o) if err := viper.Unmarshal(&o); err != nil {
if err != nil { return nil, fmt.Errorf("internal/config: failed to unmarshal config: %s", err)
return o, fmt.Errorf("failed to load options from config: %s", err)
} }
// Turn URL strings into url structs if err := o.Validate(); err != nil {
err = o.parseURLs() return nil, fmt.Errorf("internal/config: validation error %s", err)
if err != nil {
return o, fmt.Errorf("failed to parse URLs: %s", err)
} }
return &o, nil
// Load and initialize policy
err = o.parsePolicy()
if err != nil {
return o, fmt.Errorf("failed to parse Policy: %s", err)
}
// Parse Headers
err = o.parseHeaders()
if err != nil {
return o, fmt.Errorf("failed to parse Headers: %s", err)
}
if o.Debug {
log.SetDebugMode()
}
if o.LogLevel != "" {
log.SetLevel(o.LogLevel)
}
if _, disable := o.Headers[DisableHeaderKey]; disable {
o.Headers = make(map[string]string)
}
err = o.validate()
if err != nil {
return o, err
}
log.Debug().
Str("config-checksum", o.Checksum()).
Msg("read configuration with checksum")
return o, nil
} }
// validate ensures the Options fields are properly formed and present // Validate ensures the Options fields are properly formed, present, and hydrated.
func (o *Options) validate() error { func (o *Options) Validate() error {
if !IsValidService(o.Services) { if !IsValidService(o.Services) {
return fmt.Errorf("%s is an invalid service type", o.Services) return fmt.Errorf("%s is an invalid service type", o.Services)
} }
// shared key must be set for all modes other than "all"
if o.SharedKey == "" { if o.SharedKey == "" {
return errors.New("shared-key cannot be empty") if o.Services == "all" {
o.SharedKey = cryptutil.GenerateRandomString(32)
} else {
return errors.New("shared-key cannot be empty")
}
} }
if len(o.Routes) != 0 { if o.AuthenticateURLString != "" {
return errors.New("routes setting is deprecated, use policy instead") u, err := urlutil.ParseAndValidateURL(o.AuthenticateURLString)
if err != nil {
return fmt.Errorf("bad authenticate-url %s : %v", o.AuthenticateURLString, err)
}
o.AuthenticateURL = u
} }
if o.AuthorizeURLString != "" {
u, err := urlutil.ParseAndValidateURL(o.AuthorizeURLString)
if err != nil {
return fmt.Errorf("bad authorize-url %s : %v", o.AuthorizeURLString, err)
}
o.AuthorizeURL = u
}
if o.AuthenticateInternalAddrString != "" {
u, err := urlutil.ParseAndValidateURL(o.AuthenticateInternalAddrString)
if err != nil {
return fmt.Errorf("bad authenticate-internal-addr %s : %v", o.AuthenticateInternalAddrString, err)
}
o.AuthenticateInternalAddr = u
}
if o.PolicyFile != "" { if o.PolicyFile != "" {
return errors.New("Setting POLICY_FILE is deprecated, use policy env var or config file instead") return errors.New("policy file setting is deprecated")
}
if err := o.parsePolicy(); err != nil {
return fmt.Errorf("failed to parse policy: %s", err)
}
if err := o.parseHeaders(); err != nil {
return fmt.Errorf("failed to parse headers: %s", err)
}
if _, disable := o.Headers[DisableHeaderKey]; disable {
o.Headers = make(map[string]string)
} }
return nil return nil
} }
// parsePolicy initializes policy // parsePolicy initializes policy to the options from either base64 environmental
// variables or from a file
func (o *Options) parsePolicy() error { func (o *Options) parsePolicy() error {
var policies []policy.Policy var policies []Policy
// Parse from base64 env var // Parse from base64 env var
if o.PolicyEnv != "" { if o.PolicyEnv != "" {
policyBytes, err := base64.StdEncoding.DecodeString(o.PolicyEnv) policyBytes, err := base64.StdEncoding.DecodeString(o.PolicyEnv)
if err != nil { if err != nil {
return fmt.Errorf("Could not decode POLICY env var: %s", err) return fmt.Errorf("could not decode POLICY env var: %s", err)
} }
if err := yaml.Unmarshal(policyBytes, &policies); err != nil { if err := yaml.Unmarshal(policyBytes, &policies); err != nil {
return fmt.Errorf("Could not parse POLICY env var: %s", err) return fmt.Errorf("could not unmarshal policy yaml: %s", err)
} }
// Parse from file
} else { } else {
err := viper.UnmarshalKey("policy", &policies) // Parse from file
if err != nil { if err := viper.UnmarshalKey("policy", &policies); err != nil {
return err return err
} }
} }
if len(policies) != 0 {
o.Policies = policies
}
// Finish initializing policies // Finish initializing policies
for i := range policies { for i := range o.Policies {
err := (&policies[i]).Validate() if err := (&o.Policies[i]).Validate(); err != nil {
if err != nil {
return err return err
} }
} }
o.Policies = policies
return nil return nil
} }
// parseAndValidateURL wraps standard library's default url.Parse because it's much more // parseHeaders handles unmarshalling any custom headers correctly from the
// lenient about what type of urls it accepts than pomerium can be. // environment or viper's parsed keys
func parseAndValidateURL(rawurl string) (*url.URL, error) {
u, err := url.Parse(rawurl)
if err != nil {
return nil, err
}
if u.Host == "" {
return nil, fmt.Errorf("%s does have a valid hostname", rawurl)
}
if u.Scheme == "" || u.Scheme != "https" {
return nil, fmt.Errorf("%s does have a valid https scheme", rawurl)
}
return u, nil
}
// parseURLs parses URL strings into actual URL pointers
func (o *Options) parseURLs() error {
if o.AuthenticateURLString != "" {
AuthenticateURL, err := parseAndValidateURL(o.AuthenticateURLString)
if err != nil {
return fmt.Errorf("internal/config: bad authenticate-url %s : %v", o.AuthenticateURLString, err)
}
o.AuthenticateURL = *AuthenticateURL
}
if o.AuthorizeURLString != "" {
AuthorizeURL, err := parseAndValidateURL(o.AuthorizeURLString)
if err != nil {
return fmt.Errorf("internal/config: bad authorize-url %s : %v", o.AuthorizeURLString, err)
}
o.AuthorizeURL = *AuthorizeURL
}
if o.AuthenticateInternalAddrString != "" {
AuthenticateInternalAddr, err := parseAndValidateURL(o.AuthenticateInternalAddrString)
if err != nil {
return fmt.Errorf("internal/config: bad authenticate-internal-addr %s : %v", o.AuthenticateInternalAddrString, err)
}
o.AuthenticateInternalAddr = *AuthenticateInternalAddr
}
return nil
}
// parseHeaders handles unmarshalling any custom headers correctly from the environment or
// viper's parsed keys
func (o *Options) parseHeaders() error { func (o *Options) parseHeaders() error {
var headers map[string]string var headers map[string]string
if o.HeadersEnv != "" { if o.HeadersEnv != "" {
// Handle JSON by default via viper // Handle JSON by default via viper
if headers = viper.GetStringMapString("HeadersEnv"); len(headers) == 0 { if headers = viper.GetStringMapString("HeadersEnv"); len(headers) == 0 {
// Try to parse "Key1:Value1,Key2:Value2" syntax // Try to parse "Key1:Value1,Key2:Value2" syntax
headerSlice := strings.Split(o.HeadersEnv, ",") headerSlice := strings.Split(o.HeadersEnv, ",")
for n := range headerSlice { for n := range headerSlice {
@ -350,7 +297,7 @@ func (o *Options) parseHeaders() error {
} else { } else {
// Something went wrong // Something went wrong
return fmt.Errorf("Failed to decode headers environment variable from '%s'", o.HeadersEnv) return fmt.Errorf("failed to decode headers from '%s'", o.HeadersEnv)
} }
} }
@ -358,7 +305,7 @@ func (o *Options) parseHeaders() error {
o.Headers = headers o.Headers = headers
} else if viper.IsSet("headers") { } else if viper.IsSet("headers") {
if err := viper.UnmarshalKey("headers", &headers); err != nil { if err := viper.UnmarshalKey("headers", &headers); err != nil {
return err return fmt.Errorf("header %s failed to parse: %s", viper.Get("headers"), err)
} }
o.Headers = headers o.Headers = headers
} }
@ -381,61 +328,6 @@ func (o *Options) bindEnvs() {
viper.BindEnv("HeadersEnv", "HEADERS") viper.BindEnv("HeadersEnv", "HEADERS")
} }
// findPwd returns best guess at current working directory
func findPwd() string {
p, err := os.Getwd()
if err != nil {
return "."
}
return p
}
// IsValidService checks to see if a service is a valid service mode
func IsValidService(s string) bool {
switch s {
case
"all",
"proxy",
"authorize",
"authenticate":
return true
}
return false
}
// IsAuthenticate checks to see if we should be running the authenticate service
func IsAuthenticate(s string) bool {
switch s {
case
"all",
"authenticate":
return true
}
return false
}
// IsAuthorize checks to see if we should be running the authorize service
func IsAuthorize(s string) bool {
switch s {
case
"all",
"authorize":
return true
}
return false
}
// IsProxy checks to see if we should be running the proxy service
func IsProxy(s string) bool {
switch s {
case
"all",
"proxy":
return true
}
return false
}
// OptionsUpdater updates local state based on an Options struct // OptionsUpdater updates local state based on an Options struct
type OptionsUpdater interface { type OptionsUpdater interface {
UpdateOptions(Options) error UpdateOptions(Options) error
@ -444,10 +336,9 @@ type OptionsUpdater interface {
// Checksum returns the checksum of the current options struct // Checksum returns the checksum of the current options struct
func (o *Options) Checksum() string { func (o *Options) Checksum() string {
hash, err := hashstructure.Hash(o, nil) hash, err := hashstructure.Hash(o, nil)
if err != nil { if err != nil {
log.Warn().Msg("could not calculate Option checksum") log.Warn().Err(err).Msg("internal/config: checksum failure")
return "no checksum availablle" return "no checksum available"
} }
return fmt.Sprintf("%x", hash) return fmt.Sprintf("%x", hash)
} }

View file

@ -6,18 +6,16 @@ import (
"io/ioutil" "io/ioutil"
"net/url" "net/url"
"os" "os"
"reflect"
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/pomerium/pomerium/internal/policy" "github.com/google/go-cmp/cmp/cmpopts"
"github.com/spf13/viper" "github.com/spf13/viper"
) )
func Test_validate(t *testing.T) { func Test_validate(t *testing.T) {
testOptions := func() Options { testOptions := func() Options {
o := NewOptions() o := defaultOptions
o.SharedKey = "test" o.SharedKey = "test"
o.Services = "all" o.Services = "all"
return o return o
@ -27,8 +25,10 @@ func Test_validate(t *testing.T) {
badServices.Services = "blue" badServices.Services = "blue"
badSecret := testOptions() badSecret := testOptions()
badSecret.SharedKey = "" badSecret.SharedKey = ""
badRoutes := testOptions() badSecret.Services = "authenticate"
badRoutes.Routes = map[string]string{"foo": "bar"} badSecretAllServices := testOptions()
badSecretAllServices.SharedKey = ""
badPolicyFile := testOptions() badPolicyFile := testOptions()
badPolicyFile.PolicyFile = "file" badPolicyFile.PolicyFile = "file"
@ -40,12 +40,12 @@ func Test_validate(t *testing.T) {
{"good default with no env settings", good, false}, {"good default with no env settings", good, false},
{"invalid service type", badServices, true}, {"invalid service type", badServices, true},
{"missing shared secret", badSecret, true}, {"missing shared secret", badSecret, true},
{"routes present", badRoutes, true}, {"missing shared secret but all service", badSecretAllServices, false},
{"policy file specified", badPolicyFile, true}, {"policy file specified", badPolicyFile, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
err := tt.testOpts.validate() err := tt.testOpts.Validate()
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("optionsFromEnvConfig() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("optionsFromEnvConfig() error = %v, wantErr %v", err, tt.wantErr)
return return
@ -54,95 +54,6 @@ func Test_validate(t *testing.T) {
} }
} }
func Test_isValidService(t *testing.T) {
tests := []struct {
name string
service string
want bool
}{
{"proxy", "proxy", true},
{"all", "all", true},
{"authenticate", "authenticate", true},
{"authenticate bad case", "AuThenticate", false},
{"authorize implemented", "authorize", true},
{"jiberish", "xd23", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := IsValidService(tt.service); got != tt.want {
t.Errorf("isValidService() = %v, want %v", got, tt.want)
}
})
}
}
func Test_isAuthenticate(t *testing.T) {
tests := []struct {
name string
service string
want bool
}{
{"proxy", "proxy", false},
{"all", "all", true},
{"authenticate", "authenticate", true},
{"authenticate bad case", "AuThenticate", false},
{"authorize implemented", "authorize", false},
{"jiberish", "xd23", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := IsAuthenticate(tt.service); got != tt.want {
t.Errorf("isAuthenticate() = %v, want %v", got, tt.want)
}
})
}
}
func Test_isAuthorize(t *testing.T) {
tests := []struct {
name string
service string
want bool
}{
{"proxy", "proxy", false},
{"all", "all", true},
{"authorize", "authorize", true},
{"authorize bad case", "AuThorize", false},
{"authenticate implemented", "authenticate", false},
{"jiberish", "xd23", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := IsAuthorize(tt.service); got != tt.want {
t.Errorf("isAuthenticate() = %v, want %v", got, tt.want)
}
})
}
}
func Test_IsProxy(t *testing.T) {
tests := []struct {
name string
service string
want bool
}{
{"proxy", "proxy", true},
{"all", "all", true},
{"authorize", "authorize", false},
{"proxy bad case", "PrOxY", false},
{"jiberish", "xd23", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := IsProxy(tt.service); got != tt.want {
t.Errorf("IsProxy() = %v, want %v", got, tt.want)
}
})
}
}
func Test_bindEnvs(t *testing.T) { func Test_bindEnvs(t *testing.T) {
o := &Options{} o := &Options{}
os.Clearenv() os.Clearenv()
@ -171,44 +82,6 @@ func Test_bindEnvs(t *testing.T) {
} }
} }
func Test_parseURLs(t *testing.T) {
tests := []struct {
name string
authorizeURL string
authenticateURL string
authenticateInternalURL string
wantErr bool
}{
{"good", "https://authz.mydomain.example", "https://authn.mydomain.example", "https://internal.svc.local", false},
{"bad not https scheme", "http://authz.mydomain.example", "http://authn.mydomain.example", "http://internal.svc.local", true},
{"missing scheme", "authz.mydomain.example", "authn.mydomain.example", "internal.svc.local", true},
{"bad authorize", "notaurl", "https://authn.mydomain.example", "", true},
{"bad authenticate", "https://authz.mydomain.example", "notaurl", "", true},
{"bad authenticate internal", "", "", "just.some.naked.domain.example", true},
{"only authn", "", "https://authn.mydomain.example", "", false},
{"only authz", "https://authz.mydomain.example", "", "", false},
{"malformed", "http://a b.com/", "", "", true},
}
for _, test := range tests {
o := &Options{
AuthenticateURLString: test.authenticateURL,
AuthorizeURLString: test.authorizeURL,
AuthenticateInternalAddrString: test.authenticateInternalURL,
}
err := o.parseURLs()
if (err != nil) != test.wantErr {
t.Errorf("Failed to parse URLs %v: %s", test, err)
}
if err == nil && o.AuthenticateURL.String() != test.authenticateURL {
t.Errorf("Failed to update AuthenticateURL: %v", test)
}
if err == nil && o.AuthorizeURL.String() != test.authorizeURL {
t.Errorf("Failed to update AuthorizeURL: %v", test)
}
}
}
func Test_parseHeaders(t *testing.T) { func Test_parseHeaders(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@ -219,15 +92,15 @@ func Test_parseHeaders(t *testing.T) {
}{ }{
{"good env", map[string]string{"X-Custom-1": "foo", "X-Custom-2": "bar"}, `{"X-Custom-1":"foo", "X-Custom-2":"bar"}`, map[string]string{"X": "foo"}, false}, {"good env", map[string]string{"X-Custom-1": "foo", "X-Custom-2": "bar"}, `{"X-Custom-1":"foo", "X-Custom-2":"bar"}`, map[string]string{"X": "foo"}, false},
{"good env not_json", map[string]string{"X-Custom-1": "foo", "X-Custom-2": "bar"}, `X-Custom-1:foo,X-Custom-2:bar`, map[string]string{"X": "foo"}, false}, {"good env not_json", map[string]string{"X-Custom-1": "foo", "X-Custom-2": "bar"}, `X-Custom-1:foo,X-Custom-2:bar`, map[string]string{"X": "foo"}, false},
{"good viper", map[string]string{"X-Custom-1": "foo", "X-Custom-2": "bar"}, "", map[string]string{"X-Custom-1": "foo", "X-Custom-2": "bar"}, false},
{"bad env", map[string]string{}, "xyyyy", map[string]string{"X": "foo"}, true}, {"bad env", map[string]string{}, "xyyyy", map[string]string{"X": "foo"}, true},
{"bad env not_json", map[string]string{"X-Custom-1": "foo", "X-Custom-2": "bar"}, `X-Custom-1:foo,X-Custom-2bar`, map[string]string{"X": "foo"}, true}, {"bad env not_json", map[string]string{"X-Custom-1": "foo", "X-Custom-2": "bar"}, `X-Custom-1:foo,X-Custom-2bar`, map[string]string{"X": "foo"}, true},
{"bad viper", map[string]string{}, "", "notaheaderstruct", true}, {"bad viper", map[string]string{}, "", "notaheaderstruct", true},
{"good viper", map[string]string{"X-Custom-1": "foo", "X-Custom-2": "bar"}, "", map[string]string{"X-Custom-1": "foo", "X-Custom-2": "bar"}, false},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
o := NewOptions() o := defaultOptions
viper.Set("headers", tt.viperHeaders) viper.Set("headers", tt.viperHeaders)
viper.Set("HeadersEnv", tt.envHeaders) viper.Set("HeadersEnv", tt.envHeaders)
o.HeadersEnv = tt.envHeaders o.HeadersEnv = tt.envHeaders
@ -241,23 +114,28 @@ func Test_parseHeaders(t *testing.T) {
if !tt.wantErr && !cmp.Equal(tt.want, o.Headers) { if !tt.wantErr && !cmp.Equal(tt.want, o.Headers) {
t.Errorf("Did get expected headers: %s", cmp.Diff(tt.want, o.Headers)) t.Errorf("Did get expected headers: %s", cmp.Diff(tt.want, o.Headers))
} }
viper.Reset()
}) })
} }
} }
func Test_OptionsFromViper(t *testing.T) { func Test_OptionsFromViper(t *testing.T) {
testPolicy := policy.Policy{ viper.Reset()
testPolicy := Policy{
To: "https://httpbin.org", To: "https://httpbin.org",
From: "https://pomerium.io", From: "https://pomerium.io",
} }
testPolicy.Validate() if err := testPolicy.Validate(); err != nil {
testPolicies := []policy.Policy{ t.Fatal(err)
}
testPolicies := []Policy{
testPolicy, testPolicy,
} }
goodConfigBytes := []byte(`{"authorize_service_url":"https://authorize.corp.example","authenticate_service_url":"https://authenticate.corp.example","shared_secret":"Setec Astronomy","service":"all","policy":[{"from":"https://pomerium.io","to":"https://httpbin.org"}]}`) goodConfigBytes := []byte(`{"authorize_service_url":"https://authorize.corp.example","authenticate_service_url":"https://authenticate.corp.example","shared_secret":"Setec Astronomy","service":"all","policy":[{"from":"https://pomerium.io","to":"https://httpbin.org"}]}`)
goodOptions := NewOptions() goodOptions := defaultOptions
goodOptions.SharedKey = "Setec Astronomy" goodOptions.SharedKey = "Setec Astronomy"
goodOptions.Services = "all" goodOptions.Services = "all"
goodOptions.Policies = testPolicies goodOptions.Policies = testPolicies
@ -272,21 +150,23 @@ func Test_OptionsFromViper(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
goodOptions.AuthorizeURL = *authorize goodOptions.AuthorizeURL = authorize
goodOptions.AuthenticateURL = *authenticate goodOptions.AuthenticateURL = authenticate
if err := goodOptions.Validate(); err != nil {
t.Fatal(err)
}
badConfigBytes := []byte("badjson!") badConfigBytes := []byte("badjson!")
badUnmarshalConfigBytes := []byte(`"debug": "blue"`) badUnmarshalConfigBytes := []byte(`"debug": "blue"`)
tests := []struct { tests := []struct {
name string name string
configBytes []byte configBytes []byte
want Options want *Options
wantErr bool wantErr bool
}{ }{
{"good", goodConfigBytes, goodOptions, false}, {"good", goodConfigBytes, &goodOptions, false},
{"bad json", badConfigBytes, NewOptions(), true}, {"bad json", badConfigBytes, nil, true},
{"bad unmarshal", badUnmarshalConfigBytes, NewOptions(), true}, {"bad unmarshal", badUnmarshalConfigBytes, nil, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -302,8 +182,13 @@ func Test_OptionsFromViper(t *testing.T) {
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("OptionsFromViper() error = \n%v, wantErr \n%v", err, tt.wantErr) t.Errorf("OptionsFromViper() error = \n%v, wantErr \n%v", err, tt.wantErr)
} }
if tt.want != nil {
if err := tt.want.Validate(); err != nil {
t.Fatal(err)
}
}
if diff := cmp.Diff(got, tt.want); diff != "" { if diff := cmp.Diff(got, tt.want); diff != "" {
t.Errorf("OptionsFromViper() = \n%s\n, \ngot\n%v\n, want \n%v", diff, got, tt.want) t.Errorf("OptionsFromViper() = \n%s\n, \ngot\n%+v\n, want \n%+v", diff, got, tt.want)
} }
}) })
@ -318,6 +203,8 @@ func Test_OptionsFromViper(t *testing.T) {
func Test_parsePolicyEnv(t *testing.T) { func Test_parsePolicyEnv(t *testing.T) {
t.Parallel() t.Parallel()
viper.Reset()
source := "https://pomerium.io" source := "https://pomerium.io"
sourceURL, _ := url.ParseRequestURI(source) sourceURL, _ := url.ParseRequestURI(source)
dest := "https://httpbin.org" dest := "https://httpbin.org"
@ -326,12 +213,12 @@ func Test_parsePolicyEnv(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
policyBytes []byte policyBytes []byte
want []policy.Policy want []Policy
wantErr bool wantErr bool
}{ }{
{"simple json", []byte(fmt.Sprintf(`[{"from": "%s","to":"%s"}]`, source, dest)), []policy.Policy{{From: source, To: dest, Source: sourceURL, Destination: destURL}}, false}, {"simple json", []byte(fmt.Sprintf(`[{"from": "%s","to":"%s"}]`, source, dest)), []Policy{{From: source, To: dest, Source: sourceURL, Destination: destURL}}, false},
{"bad from", []byte(`[{"from": "%","to":"httpbin.org"}]`), nil, true}, {"bad from", []byte(`[{"from": "%","to":"httpbin.org"}]`), []Policy{{From: "%", To: "httpbin.org"}}, true},
{"bad to", []byte(`[{"from": "pomerium.io","to":"%"}]`), nil, true}, {"bad to", []byte(`[{"from": "pomerium.io","to":"%"}]`), []Policy{{From: "pomerium.io", To: "%"}}, true},
{"simple error", []byte(`{}`), nil, true}, {"simple error", []byte(`{}`), nil, true},
} }
for _, tt := range tests { for _, tt := range tests {
@ -341,11 +228,11 @@ func Test_parsePolicyEnv(t *testing.T) {
o.PolicyEnv = base64.StdEncoding.EncodeToString(tt.policyBytes) o.PolicyEnv = base64.StdEncoding.EncodeToString(tt.policyBytes)
err := o.parsePolicy() err := o.parsePolicy()
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("parasePolicy() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("parsePolicyEnv() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if !reflect.DeepEqual(o.Policies, tt.want) { if diff := cmp.Diff(o.Policies, tt.want); diff != "" {
t.Errorf("parasePolicy() = \n%v, want \n%v", o, tt.want) t.Errorf("parsePolicyEnv() = %s", diff)
} }
}) })
} }
@ -355,11 +242,12 @@ func Test_parsePolicyEnv(t *testing.T) {
o.PolicyEnv = "foo" o.PolicyEnv = "foo"
err := o.parsePolicy() err := o.parsePolicy()
if err == nil { if err == nil {
t.Errorf("parasePolicy() did not catch bad base64 %v", o) t.Errorf("parsePolicyEnv() did not catch bad base64 %v", o)
} }
} }
func Test_parsePolicyFile(t *testing.T) { func Test_parsePolicyFile(t *testing.T) {
viper.Reset()
source := "https://pomerium.io" source := "https://pomerium.io"
sourceURL, _ := url.ParseRequestURI(source) sourceURL, _ := url.ParseRequestURI(source)
dest := "https://httpbin.org" dest := "https://httpbin.org"
@ -368,38 +256,41 @@ func Test_parsePolicyFile(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
policyBytes []byte policyBytes []byte
want []policy.Policy want []Policy
wantErr bool wantErr bool
}{ }{
{"simple json", []byte(fmt.Sprintf(`{"policy":[{"from": "%s","to":"%s"}]}`, source, dest)), []policy.Policy{{From: source, To: dest, Source: sourceURL, Destination: destURL}}, false}, {"simple json", []byte(fmt.Sprintf(`{"policy":[{"from": "%s","to":"%s"}]}`, source, dest)), []Policy{{From: source, To: dest, Source: sourceURL, Destination: destURL}}, false},
{"bad from", []byte(`{"policy":[{"from": "%","to":"httpbin.org"}]}`), nil, true}, {"bad from", []byte(`{"policy":[{"from": "%","to":"httpbin.org"}]}`), nil, true},
{"bad to", []byte(`{"policy":[{"from": "pomerium.io","to":"%"}]}`), nil, true}, {"bad to", []byte(`{"policy":[{"from": "pomerium.io","to":"%"}]}`), nil, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
o := new(Options)
tempFile, _ := ioutil.TempFile("", "*.json") tempFile, _ := ioutil.TempFile("", "*.json")
defer tempFile.Close() defer tempFile.Close()
defer os.Remove(tempFile.Name()) defer os.Remove(tempFile.Name())
tempFile.Write(tt.policyBytes) tempFile.Write(tt.policyBytes)
o = new(Options) o := new(Options)
viper.SetConfigFile(tempFile.Name()) viper.SetConfigFile(tempFile.Name())
err := viper.ReadInConfig() if err := viper.ReadInConfig(); err != nil {
err = o.parsePolicy() t.Fatal(err)
}
err := o.parsePolicy()
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("parasePolicy() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("parsePolicyEnv() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if !reflect.DeepEqual(o.Policies, tt.want) { if err == nil {
t.Errorf("parasePolicy() = \n%v, want \n%v", o, tt.want) if diff := cmp.Diff(o.Policies, tt.want); diff != "" {
t.Errorf("parsePolicyEnv() = diff:%s", diff)
}
} }
}) })
} }
} }
func Test_Checksum(t *testing.T) { func Test_Checksum(t *testing.T) {
o := NewOptions() o := defaultOptions
oldChecksum := o.Checksum() oldChecksum := o.Checksum()
o.SharedKey = "changemeplease" o.SharedKey = "changemeplease"
@ -417,3 +308,103 @@ func Test_Checksum(t *testing.T) {
t.Error("Checksum() inconsistent") t.Error("Checksum() inconsistent")
} }
} }
func TestNewOptions(t *testing.T) {
viper.Reset()
tests := []struct {
name string
authenticateURL string
authorizeURL string
want *Options
wantErr bool
}{
{"good", "https://authenticate.example", "https://authorize.example", nil, false},
{"bad authenticate url no scheme", "authenticate.example", "https://authorize.example", nil, true},
{"bad authenticate url no host", "https://", "https://authorize.example", nil, true},
{"bad authorize url no scheme", "https://authenticate.example", "authorize.example", nil, true},
{"bad authorize url no host", "https://authenticate.example", "https://", nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := NewOptions(tt.authenticateURL, tt.authorizeURL)
if (err != nil) != tt.wantErr {
t.Errorf("NewOptions() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}
func TestOptionsFromViper(t *testing.T) {
opts := []cmp.Option{
cmpopts.IgnoreFields(Options{}, "AuthenticateInternalAddr", "DefaultUpstreamTimeout", "CookieRefresh", "CookieExpire", "Services", "Addr", "RefreshCooldown", "LogLevel", "KeyFile", "CertFile", "SharedKey", "ReadTimeout", "ReadHeaderTimeout", "IdleTimeout"),
cmpopts.IgnoreFields(Policy{}, "Source", "Destination"),
}
tests := []struct {
name string
configBytes []byte
want *Options
wantErr bool
}{
{"good",
[]byte(`{"policy":[{"from": "https://from.example","to":"https://to.example"}]}`),
&Options{
Policies: []Policy{{From: "https://from.example", To: "https://to.example"}},
CookieName: "_pomerium",
CookieSecure: true,
CookieHTTPOnly: true,
Headers: map[string]string{
"Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload",
"X-Content-Type-Options": "nosniff",
"X-Frame-Options": "SAMEORIGIN",
"X-XSS-Protection": "1; mode=block",
}},
false},
{"good with authenticate internal url",
[]byte(`{"authenticate_internal_url": "https://internal.example","policy":[{"from": "https://from.example","to":"https://to.example"}]}`),
&Options{
AuthenticateInternalAddrString: "https://internal.example",
Policies: []Policy{{From: "https://from.example", To: "https://to.example"}},
CookieName: "_pomerium",
CookieSecure: true,
CookieHTTPOnly: true,
Headers: map[string]string{
"Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload",
"X-Content-Type-Options": "nosniff",
"X-Frame-Options": "SAMEORIGIN",
"X-XSS-Protection": "1; mode=block",
}},
false},
{"good disable header",
[]byte(`{"headers": {"disable":"true"},"policy":[{"from": "https://from.example","to":"https://to.example"}]}`),
&Options{
Policies: []Policy{{From: "https://from.example", To: "https://to.example"}},
CookieName: "_pomerium",
CookieSecure: true,
CookieHTTPOnly: true,
Headers: map[string]string{}},
false},
{"bad authenticate internal url", []byte(`{"authenticate_internal_url": "internal.example","policy":[{"from": "https://from.example","to":"https://to.example"}]}`), nil, true},
{"bad url", []byte(`{"policy":[{"from": "https://","to":"https://to.example"}]}`), nil, true},
{"bad policy", []byte(`{"policy":[{"allow_public_unauthenticated_access": "dog","to":"https://to.example"}]}`), nil, true},
{"bad file", []byte(`{''''}`), nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tempFile, _ := ioutil.TempFile("", "*.json")
defer tempFile.Close()
defer os.Remove(tempFile.Name())
tempFile.Write(tt.configBytes)
got, err := OptionsFromViper(tempFile.Name())
if (err != nil) != tt.wantErr {
t.Errorf("OptionsFromViper() error = %v, wantErr %v", err, tt.wantErr)
return
}
if diff := cmp.Diff(got, tt.want, opts...); diff != "" {
t.Errorf("NewOptions() = %s", diff)
}
})
}
}

71
internal/config/policy.go Normal file
View file

@ -0,0 +1,71 @@
package config // import "github.com/pomerium/pomerium/internal/config"
import (
"errors"
"fmt"
"net/url"
"time"
"github.com/pomerium/pomerium/internal/urlutil"
)
// Policy contains route specific configuration and access settings.
type Policy struct {
From string `mapstructure:"from" yaml:"from"`
To string `mapstructure:"to" yaml:"to"`
// Identity related policy
AllowedEmails []string `mapstructure:"allowed_users" yaml:"allowed_users"`
AllowedGroups []string `mapstructure:"allowed_groups" yaml:"allowed_groups"`
AllowedDomains []string `mapstructure:"allowed_domains" yaml:"allowed_domains"`
Source *url.URL
Destination *url.URL
// Allow unauthenticated HTTP OPTIONS requests as per the CORS spec
// https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS#Preflighted_requests
CORSAllowPreflight bool `mapstructure:"cors_allow_preflight" yaml:"cors_allow_preflight"`
// Allow any public request to access this route. **Bypasses authentication**
AllowPublicUnauthenticatedAccess bool `mapstructure:"allow_public_unauthenticated_access" yaml:"allow_public_unauthenticated_access"`
// UpstreamTimeout is the route specific timeout. Must be less than the global
// timeout. If unset, route will fallback to the proxy's DefaultUpstreamTimeout.
UpstreamTimeout time.Duration `mapstructure:"timeout" yaml:"timeout"`
// Enable proxying of websocket connections by removing the default timeout handler.
// Caution: Enabling this feature could result in abuse via DOS attacks.
AllowWebsockets bool `mapstructure:"allow_websockets" yaml:"allow_websockets"`
// TLSSkipVerify controls whether a client verifies the server's certificate
// chain and host name.
// If TLSSkipVerify is true, TLS accepts any certificate presented by the
// server and any host name in that certificate.
// In this mode, TLS is susceptible to man-in-the-middle attacks.
// This should be used only for testing.
TLSSkipVerify bool `mapstructure:"tls_skip_verify" yaml:"tls_skip_verify"`
// TLSCustomCA defines the root certificate to use with a given
// route when verifying server certificates.
TLSCustomCA string `mapstructure:"tls_custom_ca" yaml:"tls_custom_ca"`
}
// Validate checks the validity of a policy.
func (p *Policy) Validate() error {
var err error
p.Source, err = urlutil.ParseAndValidateURL(p.From)
if err != nil {
return fmt.Errorf("internal/config: bad source url %s", err)
}
p.Destination, err = urlutil.ParseAndValidateURL(p.To)
if err != nil {
return fmt.Errorf("internal/config: bad destination url %s", err)
}
// Only allow public access if no other whitelists are in place
if p.AllowPublicUnauthenticatedAccess && (p.AllowedDomains != nil || p.AllowedGroups != nil || p.AllowedEmails != nil) {
return errors.New("internal/config: route marked as public but contains whitelists")
}
return nil
}

View file

@ -0,0 +1,43 @@
package config // import "github.com/pomerium/pomerium/internal/config"
import (
"testing"
)
func Test_Validate(t *testing.T) {
t.Parallel()
basePolicy := Policy{From: "https://httpbin.corp.example", To: "https://httpbin.corp.notatld"}
corsPolicy := basePolicy
corsPolicy.CORSAllowPreflight = true
publicPolicy := basePolicy
publicPolicy.AllowPublicUnauthenticatedAccess = true
publicAndWhitelistPolicy := publicPolicy
publicAndWhitelistPolicy.AllowedEmails = []string{"test@gmail.com"}
tests := []struct {
name string
policy Policy
wantErr bool
}{
{"good", basePolicy, false},
{"empty to host", Policy{From: "https://httpbin.corp.example", To: "https://"}, true},
{"empty from host", Policy{From: "https://", To: "https://httpbin.corp.example"}, true},
{"empty from scheme", Policy{From: "httpbin.corp.example", To: "https://httpbin.corp.example"}, true},
{"empty to scheme", Policy{From: "https://httpbin.corp.example", To: "//httpbin.corp.example"}, true},
{"cors policy", corsPolicy, false},
{"public policy", publicPolicy, false},
{"public and whitelist", publicAndWhitelistPolicy, true},
{"route must have", publicAndWhitelistPolicy, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.policy.Validate()
if (err != nil) != tt.wantErr {
t.Errorf("Validate() error = %v, want %v", err, tt.wantErr)
}
})
}
}

View file

@ -13,14 +13,32 @@ import (
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
) )
const DefaultKeySize = 32
// GenerateKey generates a random 32-byte key. // GenerateKey generates a random 32-byte key.
//
// Panics if source of randomness fails. // Panics if source of randomness fails.
func GenerateKey() []byte { func GenerateKey() []byte {
key := make([]byte, 32) return randomBytes(DefaultKeySize)
if _, err := rand.Read(key); err != nil { }
// GenerateRandomString returns base64 encoded securely generated random string
// of a given set of bytes.
//
// Panics if source of randomness fails.
func GenerateRandomString(c int) string {
return base64.StdEncoding.EncodeToString(randomBytes(c))
}
func randomBytes(c int) []byte {
if c < 0 {
c = DefaultKeySize
}
b := make([]byte, c)
if _, err := rand.Read(b); err != nil {
panic(err) panic(err)
} }
return key return b
} }
// Cipher provides methods to encrypt and decrypt values. // Cipher provides methods to encrypt and decrypt values.

View file

@ -1,8 +1,9 @@
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil" package cryptutil
import ( import (
"crypto/rand" "crypto/rand"
"crypto/sha1" "crypto/sha1"
"encoding/base64"
"fmt" "fmt"
"reflect" "reflect"
"sync" "sync"
@ -162,3 +163,29 @@ func TestCipherDataRace(t *testing.T) {
} }
wg.Wait() wg.Wait()
} }
func TestGenerateRandomString(t *testing.T) {
t.Parallel()
tests := []struct {
name string
c int
want int
}{
{"simple", 32, 32},
{"zero", 0, 0},
{"negative", -1, 32},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
o := GenerateRandomString(tt.c)
b, err := base64.StdEncoding.DecodeString(o)
if err != nil {
t.Error(err)
}
got := len(b)
if got != tt.want {
t.Errorf("GenerateRandomString() = %d, want %d", got, tt.want)
}
})
}
}

View file

@ -1,4 +1,4 @@
package https // import "github.com/pomerium/pomerium/internal/https" package httputil // import "github.com/pomerium/pomerium/internal/httputil"
import ( import (
"crypto/tls" "crypto/tls"
@ -14,6 +14,7 @@ import (
"github.com/pomerium/pomerium/internal/fileutil" "github.com/pomerium/pomerium/internal/fileutil"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"google.golang.org/grpc" "google.golang.org/grpc"
) )

View file

@ -18,8 +18,7 @@ func newPromHTTPHandler() http.Handler {
// TODO this is a cheap way to get thorough go process // TODO this is a cheap way to get thorough go process
// stats. It will not work with additional exporters. // stats. It will not work with additional exporters.
// It should turn into an FR to the OC framework // It should turn into an FR to the OC framework
var reg *prom.Registry reg := prom.DefaultRegisterer.(*prom.Registry)
reg = prom.DefaultRegisterer.(*prom.Registry)
pe, _ := ocProm.NewExporter(ocProm.Options{ pe, _ := ocProm.NewExporter(ocProm.Options{
Namespace: "pomerium", Namespace: "pomerium",
Registry: reg, Registry: reg,

View file

@ -13,12 +13,6 @@ import (
"go.opencensus.io/stats/view" "go.opencensus.io/stats/view"
) )
type measure struct {
Name string
Tags map[string]string
Measure int
}
func newTestMux() http.Handler { func newTestMux() http.Handler {
mux := http.NewServeMux() mux := http.NewServeMux()
mux.HandleFunc("/good", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/good", func(w http.ResponseWriter, r *http.Request) {

View file

@ -0,0 +1,42 @@
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
import (
"net/http"
"strings"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/log"
)
func SignRequest(signer cryptutil.JWTSigner, id, email, groups, header string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
jwt, err := signer.SignJWT(
r.Header.Get(id),
r.Header.Get(email),
r.Header.Get(groups))
if err != nil {
log.Warn().Err(err).Msg("internal/middleware: failed signing request")
} else {
r.Header.Set(header, jwt)
}
next.ServeHTTP(w, r)
})
}
}
// StripPomeriumCookie ensures that every response includes some basic security headers
func StripPomeriumCookie(cookieName string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
headers := make([]string, len(r.Cookies()))
for _, cookie := range r.Cookies() {
if cookie.Name != cookieName {
headers = append(headers, cookie.String())
}
}
r.Header.Set("Cookie", strings.Join(headers, ";"))
next.ServeHTTP(w, r)
})
}
}

View file

@ -0,0 +1,94 @@
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/pomerium/pomerium/internal/cryptutil"
)
const exampleKey = `-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIM3mpZIWXCX9yEgxU6s57CbtbUNDBSCEAtQF5fUWHpcQoAoGCCqGSM49
AwEHoUQDQgAEhPQv+LACPVNmBTK0xSTzbpEPkRrk1eUt1BOa32SEfUPzNi4IWeZ/
KKITt2q1IqpV2KMSbVDyr9ijv/Xh98iyEw==
-----END EC PRIVATE KEY-----
`
func TestSignRequest(t *testing.T) {
tests := []struct {
name string
id string
email string
groups string
header string
}{
{"good", "id", "email", "group", "Jwt"},
}
req, err := http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
t.Fatal(err)
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.Header.Set(fmt.Sprintf("%s-header", tt.id), tt.id)
r.Header.Set(fmt.Sprintf("%s-header", tt.email), tt.email)
r.Header.Set(fmt.Sprintf("%s-header", tt.groups), tt.groups)
})
rr := httptest.NewRecorder()
signer, err := cryptutil.NewES256Signer([]byte(exampleKey), "audience")
if err != nil {
t.Fatal(err)
}
handler := SignRequest(signer, tt.id, tt.email, tt.groups, tt.header)(testHandler)
handler.ServeHTTP(rr, req)
jwt := req.Header["Jwt"]
if len(jwt) != 1 {
t.Errorf("no jwt found %v", req.Header)
}
})
}
}
func TestStripPomeriumCookie(t *testing.T) {
tests := []struct {
name string
pomeriumCookie string
otherCookies []string
}{
{"good", "pomerium", []string{"x", "y", "z"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
for _, cookie := range r.Cookies() {
if cookie.Name == tt.pomeriumCookie {
t.Errorf("cookie not stripped %s", r.Cookies())
}
}
})
rr := httptest.NewRecorder()
for _, cn := range tt.otherCookies {
http.SetCookie(rr, &http.Cookie{
Name: cn,
Value: "some other cookie",
})
}
http.SetCookie(rr, &http.Cookie{
Name: tt.pomeriumCookie,
Value: "pomerium cookie!",
})
req := &http.Request{Header: http.Header{"Cookie": rr.HeaderMap["Set-Cookie"]}}
handler := StripPomeriumCookie(tt.pomeriumCookie)(testHandler)
handler.ServeHTTP(rr, req)
})
}
}

View file

@ -1,64 +0,0 @@
package policy // import "github.com/pomerium/pomerium/internal/policy"
import (
"errors"
"fmt"
"net/url"
"strings"
"time"
)
// Policy contains authorization policy information.
// todo(bdd) : add upstream timeout and configuration settings
type Policy struct {
//
From string `mapstructure:"from" yaml:"from"`
To string `mapstructure:"to" yaml:"to"`
// Identity related policy
AllowedEmails []string `mapstructure:"allowed_users" yaml:"allowed_users"`
AllowedGroups []string `mapstructure:"allowed_groups" yaml:"allowed_groups"`
AllowedDomains []string `mapstructure:"allowed_domains" yaml:"allowed_domains"`
Source *url.URL
Destination *url.URL
// Allow unauthenticated HTTP OPTIONS requests as per the CORS spec
// https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS#Preflighted_requests
CORSAllowPreflight bool `mapstructure:"cors_allow_preflight" yaml:"cors_allow_preflight"`
// Allow any public request to access this route. **Bypasses authentication**
AllowPublicUnauthenticatedAccess bool `mapstructure:"allow_public_unauthenticated_access" yaml:"allow_public_unauthenticated_access"`
// UpstreamTimeout is the route specific timeout. Must be less than the global
// timeout. If unset, route will fallback to the proxy's DefaultUpstreamTimeout.
UpstreamTimeout time.Duration `mapstructure:"timeout" yaml:"timeout"`
}
// Validate parses the source and destination URLs in the Policy
func (p *Policy) Validate() (err error) {
p.Source, err = urlParse(p.From)
if err != nil {
return err
}
p.Destination, err = urlParse(p.To)
if err != nil {
return err
}
// Only allow public access if no other whitelists are in place
if p.AllowPublicUnauthenticatedAccess && (p.AllowedDomains != nil || p.AllowedGroups != nil || p.AllowedEmails != nil) {
return errors.New("route marked as public but contains whitelists")
}
return nil
}
// URLParse wraps url.Parse to add a scheme if none-exists.
// https://github.com/golang/go/issues/12585
func urlParse(uri string) (*url.URL, error) {
if !strings.Contains(uri, "://") {
uri = fmt.Sprintf("https://%s", uri)
}
return url.ParseRequestURI(uri)
}

View file

@ -1,66 +0,0 @@
package policy
import (
"net/url"
"reflect"
"testing"
)
func Test_urlParse(t *testing.T) {
t.Parallel()
tests := []struct {
name string
uri string
want *url.URL
wantErr bool
}{
{"good url without schema", "accounts.google.com", &url.URL{Scheme: "https", Host: "accounts.google.com"}, false},
{"good url with schema", "https://accounts.google.com", &url.URL{Scheme: "https", Host: "accounts.google.com"}, false},
{"bad url, malformed", "https://accounts.google.^", nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := urlParse(tt.uri)
if (err != nil) != tt.wantErr {
t.Errorf("urlParse() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("urlParse() = %v, want %v", got, tt.want)
}
})
}
}
func Test_Validate(t *testing.T) {
t.Parallel()
basePolicy := Policy{From: "httpbin.corp.example", To: "httpbin.corp.notatld"}
corsPolicy := basePolicy
corsPolicy.CORSAllowPreflight = true
publicPolicy := basePolicy
publicPolicy.AllowPublicUnauthenticatedAccess = true
publicAndWhitelistPolicy := publicPolicy
publicAndWhitelistPolicy.AllowedEmails = []string{"test@gmail.com"}
tests := []struct {
name string
policy Policy
wantErr bool
}{
{"good", basePolicy, false},
{"cors policy", corsPolicy, false},
{"public policy", publicPolicy, false},
{"public and whitelist", publicAndWhitelistPolicy, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.policy.Validate()
if (err != nil) != tt.wantErr {
t.Errorf("Validate() error = %v, want %v", err, tt.wantErr)
}
})
}
}

View file

@ -57,7 +57,6 @@ func (s *RestStore) ClearSession(w http.ResponseWriter, r *http.Request) {
"error_description": "The token has expired." "error_description": "The token has expired."
}` }`
w.Write([]byte(errMsg)) w.Write([]byte(errMsg))
return
} }
// LoadSession attempts to load a pomerium session from a Bearer Token set // LoadSession attempts to load a pomerium session from a Bearer Token set

View file

@ -1,6 +1,10 @@
package urlutil // import "github.com/pomerium/pomerium/internal/urlutil" package urlutil // import "github.com/pomerium/pomerium/internal/urlutil"
import "strings" import (
"fmt"
"net/url"
"strings"
)
// StripPort returns a host, without any port number. // StripPort returns a host, without any port number.
// //
@ -17,3 +21,19 @@ func StripPort(hostport string) string {
} }
return hostport[:colon] return hostport[:colon]
} }
// ParseAndValidateURL wraps standard library's default url.Parse because
// it's much more lenient about what type of urls it accepts than pomerium.
func ParseAndValidateURL(rawurl string) (*url.URL, error) {
u, err := url.Parse(rawurl)
if err != nil {
return nil, err
}
if u.Scheme == "" {
return nil, fmt.Errorf("%s url does contain a valid scheme. Did you mean https://%s?", rawurl, rawurl)
}
if u.Host == "" {
return nil, fmt.Errorf("%s url does contain a valid hostname", rawurl)
}
return u, nil
}

View file

@ -1,6 +1,11 @@
package urlutil // import "github.com/pomerium/pomerium/internal/urlutil" package urlutil
import "testing" import (
"net/url"
"testing"
"github.com/google/go-cmp/cmp"
)
func Test_StripPort(t *testing.T) { func Test_StripPort(t *testing.T) {
t.Parallel() t.Parallel()
@ -27,3 +32,30 @@ func Test_StripPort(t *testing.T) {
}) })
} }
} }
func TestParseAndValidateURL(t *testing.T) {
tests := []struct {
name string
rawurl string
want *url.URL
wantErr bool
}{
{"good", "https://some.example", &url.URL{Scheme: "https", Host: "some.example"}, false},
{"bad schema", "//some.example", nil, true},
{"bad hostname", "https://", nil, true},
{"bad parse", "https://^", nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ParseAndValidateURL(tt.rawurl)
if (err != nil) != tt.wantErr {
t.Errorf("ParseAndValidateURL() error = %v, wantErr %v", err, tt.wantErr)
return
}
if diff := cmp.Diff(got, tt.want); diff != "" {
t.Errorf("TestParseAndValidateURL() = %s", diff)
}
})
}
}

View file

@ -3,6 +3,7 @@ package clients // import "github.com/pomerium/pomerium/proxy/clients"
import ( import (
"context" "context"
"fmt" "fmt"
"net/url"
"reflect" "reflect"
"strings" "strings"
"testing" "testing"
@ -23,8 +24,8 @@ func TestNew(t *testing.T) {
opts *Options opts *Options
wantErr bool wantErr bool
}{ }{
{"grpc good", "grpc", &Options{Addr: "test", InternalAddr: "intranet.local", SharedSecret: "secret"}, false}, {"grpc good", "grpc", &Options{Addr: &url.URL{Scheme: "https", Host: "localhost.example"}, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, SharedSecret: "secret"}, false},
{"grpc missing shared secret", "grpc", &Options{Addr: "test", InternalAddr: "intranet.local", SharedSecret: ""}, true}, {"grpc missing shared secret", "grpc", &Options{Addr: &url.URL{Scheme: "https", Host: "localhost.example"}, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, SharedSecret: ""}, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -211,15 +212,17 @@ func TestNewGRPC(t *testing.T) {
wantTarget string wantTarget string
}{ }{
{"no shared secret", &Options{}, true, "proxy/authenticator: grpc client requires shared secret", ""}, {"no shared secret", &Options{}, true, "proxy/authenticator: grpc client requires shared secret", ""},
{"empty connection", &Options{Addr: "", SharedSecret: "shh"}, true, "proxy/authenticator: connection address required", ""}, {"empty connection", &Options{Addr: nil, SharedSecret: "shh"}, true, "proxy/authenticator: connection address required", ""},
{"both internal and addr empty", &Options{Addr: "", InternalAddr: "", SharedSecret: "shh"}, true, "proxy/authenticator: connection address required", ""}, {"both internal and addr empty", &Options{Addr: nil, InternalAddr: nil, SharedSecret: "shh"}, true, "proxy/authenticator: connection address required", ""},
{"internal addr with port", &Options{Addr: "", InternalAddr: "intranet.local:8443", SharedSecret: "shh"}, false, "", "intranet.local:8443"}, {"addr with port", &Options{Addr: &url.URL{Scheme: "https", Host: "localhost.example:8443"}, SharedSecret: "shh"}, false, "", "localhost.example:8443"},
{"internal addr without port", &Options{Addr: "", InternalAddr: "intranet.local", SharedSecret: "shh"}, false, "", "intranet.local:443"}, {"addr without port", &Options{Addr: &url.URL{Scheme: "https", Host: "localhost.example"}, SharedSecret: "shh"}, false, "", "localhost.example:443"},
{"cert override", &Options{Addr: "", InternalAddr: "intranet.local", OverrideCertificateName: "*.local", SharedSecret: "shh"}, false, "", "intranet.local:443"}, {"internal addr with port", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example:8443"}, SharedSecret: "shh"}, false, "", "localhost.example:8443"},
{"custom ca", &Options{Addr: "", InternalAddr: "intranet.local", OverrideCertificateName: "*.local", SharedSecret: "shh", CA: "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSURFVENDQWZrQ0ZBWHhneFg5K0hjWlBVVVBEK0laV0NGNUEvVTdNQTBHQ1NxR1NJYjNEUUVCQ3dVQU1FVXgKQ3pBSkJnTlZCQVlUQWtGVk1STXdFUVlEVlFRSURBcFRiMjFsTFZOMFlYUmxNU0V3SHdZRFZRUUtEQmhKYm5SbApjbTVsZENCWGFXUm5hWFJ6SUZCMGVTQk1kR1F3SGhjTk1Ua3dNakk0TVRnMU1EQTNXaGNOTWprd01qSTFNVGcxCk1EQTNXakJGTVFzd0NRWURWUVFHRXdKQlZURVRNQkVHQTFVRUNBd0tVMjl0WlMxVGRHRjBaVEVoTUI4R0ExVUUKQ2d3WVNXNTBaWEp1WlhRZ1YybGtaMmwwY3lCUWRIa2dUSFJrTUlJQklqQU5CZ2txaGtpRzl3MEJBUUVGQUFPQwpBUThBTUlJQkNnS0NBUUVBOVRFMEFiaTdnMHhYeURkVUtEbDViNTBCT05ZVVVSc3F2THQrSWkwdlpjMzRRTHhOClJrT0hrOFZEVUgzcUt1N2UrNGVubUdLVVNUdzRPNFlkQktiSWRJTFpnb3o0YitNL3FVOG5adVpiN2pBVTdOYWkKajMzVDVrbXB3L2d4WHNNUzNzdUpXUE1EUDB3Z1BUZUVRK2J1bUxVWmpLdUVIaWNTL0l5dmtaVlBzRlE4NWlaUwpkNXE2a0ZGUUdjWnFXeFg0dlhDV25Sd3E3cHY3TThJd1RYc1pYSVRuNXB5Z3VTczNKb29GQkg5U3ZNTjRKU25GCmJMK0t6ekduMy9ScXFrTXpMN3FUdkMrNWxVT3UxUmNES21mZXBuVGVaN1IyVnJUQm42NndWMjVHRnBkSDIzN00KOXhJVkJrWEd1U2NvWHVPN1lDcWFrZkt6aXdoRTV4UmRaa3gweXdJREFRQUJNQTBHQ1NxR1NJYjNEUUVCQ3dVQQpBNElCQVFCaHRWUEI0OCs4eFZyVmRxM1BIY3k5QkxtVEtrRFl6N2Q0ODJzTG1HczBuVUdGSTFZUDdmaFJPV3ZxCktCTlpkNEI5MUpwU1NoRGUrMHpoNno4WG5Ha01mYnRSYWx0NHEwZ3lKdk9hUWhqQ3ZCcSswTFk5d2NLbXpFdnMKcTRiNUZ5NXNpRUZSekJLTmZtTGwxTTF2cW1hNmFCVnNYUUhPREdzYS83dE5MalZ2ay9PYm52cFg3UFhLa0E3cQpLMTQvV0tBRFBJWm9mb00xMzB4Q1RTYXVpeXROajlnWkx1WU9leEZhblVwNCt2MHBYWS81OFFSNTk2U0ROVTlKClJaeDhwTzBTaUYvZXkxVUZXbmpzdHBjbTQzTFVQKzFwU1hFeVhZOFJrRTI2QzNvdjNaTFNKc2pMbC90aXVqUlgKZUJPOWorWDdzS0R4amdtajBPbWdpVkpIM0YrUAotLS0tLUVORCBDRVJUSUZJQ0FURS0tLS0tCg=="}, false, "", "intranet.local:443"}, {"internal addr without port", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, SharedSecret: "shh"}, false, "", "localhost.example:443"},
{"bad ca encoding", &Options{Addr: "", InternalAddr: "intranet.local", OverrideCertificateName: "*.local", SharedSecret: "shh", CA: "^"}, true, "", "intranet.local:443"}, {"cert override", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, OverrideCertificateName: "*.local", SharedSecret: "shh"}, false, "", "localhost.example:443"},
{"custom ca file", &Options{Addr: "", InternalAddr: "intranet.local", OverrideCertificateName: "*.local", SharedSecret: "shh", CAFile: "testdata/example.crt"}, false, "", "intranet.local:443"}, {"custom ca", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, OverrideCertificateName: "*.local", SharedSecret: "shh", CA: "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSURFVENDQWZrQ0ZBWHhneFg5K0hjWlBVVVBEK0laV0NGNUEvVTdNQTBHQ1NxR1NJYjNEUUVCQ3dVQU1FVXgKQ3pBSkJnTlZCQVlUQWtGVk1STXdFUVlEVlFRSURBcFRiMjFsTFZOMFlYUmxNU0V3SHdZRFZRUUtEQmhKYm5SbApjbTVsZENCWGFXUm5hWFJ6SUZCMGVTQk1kR1F3SGhjTk1Ua3dNakk0TVRnMU1EQTNXaGNOTWprd01qSTFNVGcxCk1EQTNXakJGTVFzd0NRWURWUVFHRXdKQlZURVRNQkVHQTFVRUNBd0tVMjl0WlMxVGRHRjBaVEVoTUI4R0ExVUUKQ2d3WVNXNTBaWEp1WlhRZ1YybGtaMmwwY3lCUWRIa2dUSFJrTUlJQklqQU5CZ2txaGtpRzl3MEJBUUVGQUFPQwpBUThBTUlJQkNnS0NBUUVBOVRFMEFiaTdnMHhYeURkVUtEbDViNTBCT05ZVVVSc3F2THQrSWkwdlpjMzRRTHhOClJrT0hrOFZEVUgzcUt1N2UrNGVubUdLVVNUdzRPNFlkQktiSWRJTFpnb3o0YitNL3FVOG5adVpiN2pBVTdOYWkKajMzVDVrbXB3L2d4WHNNUzNzdUpXUE1EUDB3Z1BUZUVRK2J1bUxVWmpLdUVIaWNTL0l5dmtaVlBzRlE4NWlaUwpkNXE2a0ZGUUdjWnFXeFg0dlhDV25Sd3E3cHY3TThJd1RYc1pYSVRuNXB5Z3VTczNKb29GQkg5U3ZNTjRKU25GCmJMK0t6ekduMy9ScXFrTXpMN3FUdkMrNWxVT3UxUmNES21mZXBuVGVaN1IyVnJUQm42NndWMjVHRnBkSDIzN00KOXhJVkJrWEd1U2NvWHVPN1lDcWFrZkt6aXdoRTV4UmRaa3gweXdJREFRQUJNQTBHQ1NxR1NJYjNEUUVCQ3dVQQpBNElCQVFCaHRWUEI0OCs4eFZyVmRxM1BIY3k5QkxtVEtrRFl6N2Q0ODJzTG1HczBuVUdGSTFZUDdmaFJPV3ZxCktCTlpkNEI5MUpwU1NoRGUrMHpoNno4WG5Ha01mYnRSYWx0NHEwZ3lKdk9hUWhqQ3ZCcSswTFk5d2NLbXpFdnMKcTRiNUZ5NXNpRUZSekJLTmZtTGwxTTF2cW1hNmFCVnNYUUhPREdzYS83dE5MalZ2ay9PYm52cFg3UFhLa0E3cQpLMTQvV0tBRFBJWm9mb00xMzB4Q1RTYXVpeXROajlnWkx1WU9leEZhblVwNCt2MHBYWS81OFFSNTk2U0ROVTlKClJaeDhwTzBTaUYvZXkxVUZXbmpzdHBjbTQzTFVQKzFwU1hFeVhZOFJrRTI2QzNvdjNaTFNKc2pMbC90aXVqUlgKZUJPOWorWDdzS0R4amdtajBPbWdpVkpIM0YrUAotLS0tLUVORCBDRVJUSUZJQ0FURS0tLS0tCg=="}, false, "", "localhost.example:443"},
{"bad custom ca file", &Options{Addr: "", InternalAddr: "intranet.local", OverrideCertificateName: "*.local", SharedSecret: "shh", CAFile: "testdata/example.crt2"}, true, "", "intranet.local:443"}, {"bad ca encoding", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, OverrideCertificateName: "*.local", SharedSecret: "shh", CA: "^"}, true, "", "localhost.example:443"},
{"custom ca file", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, OverrideCertificateName: "*.local", SharedSecret: "shh", CAFile: "testdata/example.crt"}, false, "", "localhost.example:443"},
{"bad custom ca file", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, OverrideCertificateName: "*.local", SharedSecret: "shh", CAFile: "testdata/example.crt2"}, true, "", "localhost.example:443"},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {

View file

@ -5,6 +5,7 @@ import (
"testing" "testing"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/proto/authorize" "github.com/pomerium/pomerium/proto/authorize"
mock "github.com/pomerium/pomerium/proto/authorize/mock_authorize" mock "github.com/pomerium/pomerium/proto/authorize/mock_authorize"

View file

@ -7,15 +7,15 @@ import (
"errors" "errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/url"
"strings" "strings"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/metrics" "github.com/pomerium/pomerium/internal/metrics"
"github.com/pomerium/pomerium/internal/middleware"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/middleware"
) )
const defaultGRPCPort = 443 const defaultGRPCPort = 443
@ -23,10 +23,10 @@ const defaultGRPCPort = 443
// Options contains options for connecting to a pomerium rpc service. // Options contains options for connecting to a pomerium rpc service.
type Options struct { type Options struct {
// Addr is the location of the authenticate service. e.g. "service.corp.example:8443" // Addr is the location of the authenticate service. e.g. "service.corp.example:8443"
Addr string Addr *url.URL
// InternalAddr is the internal (behind the ingress) address to use when // InternalAddr is the internal (behind the ingress) address to use when
// making a connection. If empty, Addr is used. // making a connection. If empty, Addr is used.
InternalAddr string InternalAddr *url.URL
// OverrideCertificateName overrides the server name used to verify the hostname on the // OverrideCertificateName overrides the server name used to verify the hostname on the
// returned certificates from the server. gRPC internals also use it to override the virtual // returned certificates from the server. gRPC internals also use it to override the virtual
// hosting name if it is set. // hosting name if it is set.
@ -45,16 +45,17 @@ func NewGRPCClientConn(opts *Options) (*grpc.ClientConn, error) {
if opts.SharedSecret == "" { if opts.SharedSecret == "" {
return nil, errors.New("proxy/clients: grpc client requires shared secret") return nil, errors.New("proxy/clients: grpc client requires shared secret")
} }
if opts.InternalAddr == nil && opts.Addr == nil {
return nil, errors.New("proxy/clients: connection address required")
}
grpcAuth := middleware.NewSharedSecretCred(opts.SharedSecret) grpcAuth := middleware.NewSharedSecretCred(opts.SharedSecret)
var connAddr string var connAddr string
if opts.InternalAddr != "" { if opts.InternalAddr != nil {
connAddr = opts.InternalAddr connAddr = opts.InternalAddr.Host
} else { } else {
connAddr = opts.Addr connAddr = opts.Addr.Host
}
if connAddr == "" {
return nil, errors.New("proxy/clients: connection address required")
} }
// no colon exists in the connection string, assume one must be added manually // no colon exists in the connection string, assume one must be added manually
if !strings.Contains(connAddr, ":") { if !strings.Contains(connAddr, ":") {

View file

@ -9,12 +9,10 @@ import (
"time" "time"
"github.com/pomerium/pomerium/internal/config" "github.com/pomerium/pomerium/internal/config"
"github.com/pomerium/pomerium/internal/cryptutil" "github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/middleware" "github.com/pomerium/pomerium/internal/middleware"
"github.com/pomerium/pomerium/internal/policy"
"github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/templates" "github.com/pomerium/pomerium/internal/templates"
) )
@ -345,7 +343,6 @@ func (p *Proxy) UserDashboard(w http.ResponseWriter, r *http.Request) {
CSRF: csrf.SessionID, CSRF: csrf.SessionID,
} }
templates.New().ExecuteTemplate(w, "dashboard.html", t) templates.New().ExecuteTemplate(w, "dashboard.html", t)
return
} }
// Refresh redeems and extends an existing authenticated oidc session with // Refresh redeems and extends an existing authenticated oidc session with
@ -366,8 +363,7 @@ func (p *Proxy) Refresh(w http.ResponseWriter, r *http.Request) {
return return
} }
// reject a refresh if it's been less than 5 minutes to prevent a bad actor // reject a refresh if it's been less than the refresh cooldown to prevent abuse
// trying to DOS the identity provider.
if time.Since(iss) < p.refreshCooldown { if time.Since(iss) < p.refreshCooldown {
log.FromRequest(r).Error().Dur("cooldown", p.refreshCooldown).Err(err).Msg("proxy: refresh cooldown") log.FromRequest(r).Error().Dur("cooldown", p.refreshCooldown).Err(err).Msg("proxy: refresh cooldown")
httpErr := &httputil.Error{ httpErr := &httputil.Error{
@ -467,22 +463,21 @@ func (p *Proxy) authenticate(w http.ResponseWriter, r *http.Request, s *sessions
if err != nil { if err != nil {
return fmt.Errorf("proxy: session refresh failed : %v", err) return fmt.Errorf("proxy: session refresh failed : %v", err)
} }
err = p.sessionStore.SaveSession(w, r, s) if err := p.sessionStore.SaveSession(w, r, s); err != nil {
if err != nil {
return fmt.Errorf("proxy: refresh failed : %v", err) return fmt.Errorf("proxy: refresh failed : %v", err)
} }
} else { } else {
valid, err := p.AuthenticateClient.Validate(r.Context(), s.IDToken) valid, err := p.AuthenticateClient.Validate(r.Context(), s.IDToken)
if err != nil || !valid { if err != nil || !valid {
return fmt.Errorf("proxy: session valid: %v : %v", valid, err) return fmt.Errorf("proxy: session validate failed: %v : %v", valid, err)
} }
} }
return nil return nil
} }
// router attempts to find a route for a request. If a route is successfully matched, // router attempts to find a route for a request. If a route is successfully matched,
// it returns the route information and a bool value of `true`. If a route can not be matched, // it returns the route information and a bool value of `true`. If a route can
// a nil value for the route and false bool value is returned. // not be matched, a nil value for the route and false bool value is returned.
func (p *Proxy) router(r *http.Request) (http.Handler, bool) { func (p *Proxy) router(r *http.Request) (http.Handler, bool) {
config, ok := p.routeConfigs[r.Host] config, ok := p.routeConfigs[r.Host]
if ok { if ok {
@ -494,7 +489,7 @@ func (p *Proxy) router(r *http.Request) (http.Handler, bool) {
// policy attempts to find a policy for a request. If a policy is successfully matched, // policy attempts to find a policy for a request. If a policy is successfully matched,
// it returns the policy information and a bool value of `true`. If a policy can not be matched, // it returns the policy information and a bool value of `true`. If a policy can not be matched,
// a nil value for the policy and false bool value is returned. // a nil value for the policy and false bool value is returned.
func (p *Proxy) policy(r *http.Request) (*policy.Policy, bool) { func (p *Proxy) policy(r *http.Request) (*config.Policy, bool) {
config, ok := p.routeConfigs[r.Host] config, ok := p.routeConfigs[r.Host]
if ok { if ok {
return &config.policy, true return &config.policy, true
@ -546,32 +541,3 @@ func (p *Proxy) GetSignOutURL(authenticateURL, redirectURL *url.URL) *url.URL {
a.RawQuery = params.Encode() a.RawQuery = params.Encode()
return a return a
} }
func extendDeadline(ttl time.Duration) time.Time {
return time.Now().Add(ttl).Truncate(time.Second)
}
// websocketHandlerFunc splits request serving with timeouts depending on the protocol
func websocketHandlerFunc(baseHandler http.Handler, timeoutHandler http.Handler, o config.Options) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Do not use timeouts for websockets because they are long-lived connections.
if r.ProtoMajor == 1 &&
strings.EqualFold(r.Header.Get("Connection"), "upgrade") &&
strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
if o.AllowWebsockets {
baseHandler.ServeHTTP(w, r)
return
}
log.FromRequest(r).Warn().Msg("proxy: attempt to proxy a websocket connection, but websocket support is disabled in the configuration")
httpErr := &httputil.Error{Message: "websockets not supported by proxy", Code: http.StatusBadRequest}
httputil.ErrorResponse(w, r, httpErr)
return
}
// All other non-websocket requests are served with timeouts to prevent abuse
timeoutHandler.ServeHTTP(w, r)
})
}

View file

@ -14,7 +14,6 @@ import (
"github.com/pomerium/pomerium/internal/config" "github.com/pomerium/pomerium/internal/config"
"github.com/pomerium/pomerium/internal/cryptutil" "github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/policy"
"github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/proxy/clients" "github.com/pomerium/pomerium/proxy/clients"
) )
@ -108,13 +107,12 @@ func TestProxy_GetSignOutURL(t *testing.T) {
redirect string redirect string
wantPrefix string wantPrefix string
}{ }{
{"without scheme", "auth.corp.pomerium.io", "hello.corp.pomerium.io", "https://auth.corp.pomerium.io/sign_out?redirect_uri=https%3A%2F%2Fhello.corp.pomerium.io"}, {"good", "https://auth.corp.pomerium.io", "https://hello.corp.pomerium.io", "https://auth.corp.pomerium.io/sign_out?redirect_uri=https%3A%2F%2Fhello.corp.pomerium.io"},
{"with scheme", "https://auth.corp.pomerium.io", "https://hello.corp.pomerium.io", "https://auth.corp.pomerium.io/sign_out?redirect_uri=https%3A%2F%2Fhello.corp.pomerium.io"},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
authenticateURL, _ := urlParse(tt.authenticate) authenticateURL, _ := url.Parse(tt.authenticate)
redirectURL, _ := urlParse(tt.redirect) redirectURL, _ := url.Parse(tt.redirect)
p := &Proxy{} p := &Proxy{}
// signature is ignored as it is tested above. Avoids testing time.Now // signature is ignored as it is tested above. Avoids testing time.Now
@ -135,14 +133,13 @@ func TestProxy_GetSignInURL(t *testing.T) {
wantPrefix string wantPrefix string
}{ }{
{"without scheme", "auth.corp.pomerium.io", "hello.corp.pomerium.io", "example_state", "https://auth.corp.pomerium.io/sign_in?redirect_uri=https%3A%2F%2Fhello.corp.pomerium.io&response_type=code&shared_secret=shared-secret"}, {"good", "https://auth.corp.pomerium.io", "https://hello.corp.pomerium.io", "example_state", "https://auth.corp.pomerium.io/sign_in?redirect_uri=https%3A%2F%2Fhello.corp.pomerium.io&response_type=code&shared_secret=shared-secret"},
{"with scheme", "https://auth.corp.pomerium.io", "https://hello.corp.pomerium.io", "example_state", "https://auth.corp.pomerium.io/sign_in?redirect_uri=https%3A%2F%2Fhello.corp.pomerium.io&response_type=code&shared_secret=shared-secret"},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
p := &Proxy{SharedKey: "shared-secret"} p := &Proxy{SharedKey: "shared-secret"}
authenticateURL, _ := urlParse(tt.authenticate) authenticateURL, _ := url.Parse(tt.authenticate)
redirectURL, _ := urlParse(tt.redirect) redirectURL, _ := url.Parse(tt.redirect)
if got := p.GetSignInURL(authenticateURL, redirectURL, tt.state); !strings.HasPrefix(got.String(), tt.wantPrefix) { if got := p.GetSignInURL(authenticateURL, redirectURL, tt.state); !strings.HasPrefix(got.String(), tt.wantPrefix) {
t.Errorf("Proxy.GetSignOutURL() = %v, wantPrefix %v", got.String(), tt.wantPrefix) t.Errorf("Proxy.GetSignOutURL() = %v, wantPrefix %v", got.String(), tt.wantPrefix)
@ -153,7 +150,12 @@ func TestProxy_GetSignInURL(t *testing.T) {
} }
func TestProxy_Signout(t *testing.T) { func TestProxy_Signout(t *testing.T) {
proxy, err := New(testOptions()) opts := testOptions(t)
err := ValidateOptions(opts)
if err != nil {
t.Fatal(err)
}
proxy, err := New(opts)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -171,7 +173,7 @@ func TestProxy_Signout(t *testing.T) {
} }
func TestProxy_OAuthStart(t *testing.T) { func TestProxy_OAuthStart(t *testing.T) {
proxy, err := New(testOptions()) proxy, err := New(testOptions(t))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -184,14 +186,14 @@ func TestProxy_OAuthStart(t *testing.T) {
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusFound) t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusFound)
} }
// expected url // expected url
expected := `<a href="https://authenticate.corp.beyondperimeter.com/sign_in` expected := `<a href="https://authenticate.example/sign_in`
body := rr.Body.String() body := rr.Body.String()
if !strings.HasPrefix(body, expected) { if !strings.HasPrefix(body, expected) {
t.Errorf("handler returned unexpected body: got %v want %v", body, expected) t.Errorf("handler returned unexpected body: got %v want %v", body, expected)
} }
} }
func TestProxy_Handler(t *testing.T) { func TestProxy_Handler(t *testing.T) {
proxy, err := New(testOptions()) proxy, err := New(testOptions(t))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -209,32 +211,16 @@ func TestProxy_Handler(t *testing.T) {
} }
} }
func Test_extendDeadline(t *testing.T) {
tests := []struct {
name string
ttl time.Duration
want time.Time
}{
{"good", time.Second, time.Now().Add(time.Second).Truncate(time.Second)},
{"test nanoseconds truncated", 500 * time.Nanosecond, time.Now().Truncate(time.Second)},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := extendDeadline(tt.ttl); !reflect.DeepEqual(got, tt.want) {
t.Errorf("extendDeadline() = %v, want %v", got, tt.want)
}
})
}
}
func TestProxy_router(t *testing.T) { func TestProxy_router(t *testing.T) {
testPolicy := policy.Policy{From: "corp.example.com", To: "example.com"} testPolicy := config.Policy{From: "https://corp.example.com", To: "https://example.com"}
testPolicy.Validate() if err := testPolicy.Validate(); err != nil {
policies := []policy.Policy{testPolicy} t.Fatal(err)
}
policies := []config.Policy{testPolicy}
tests := []struct { tests := []struct {
name string name string
host string host string
mux []policy.Policy mux []config.Policy
route http.Handler route http.Handler
wantOk bool wantOk bool
}{ }{
@ -242,13 +228,13 @@ func TestProxy_router(t *testing.T) {
{"good with slash", "https://corp.example.com/", policies, nil, true}, {"good with slash", "https://corp.example.com/", policies, nil, true},
{"good with path", "https://corp.example.com/123", policies, nil, true}, {"good with path", "https://corp.example.com/123", policies, nil, true},
// {"multiple", "https://corp.example.com/", map[string]string{"corp.unrelated.com": "unrelated.com", "corp.example.com": "example.com"}, nil, true}, // {"multiple", "https://corp.example.com/", map[string]string{"corp.unrelated.com": "unrelated.com", "corp.example.com": "example.com"}, nil, true},
{"no policies", "https://notcorp.example.com/123", []policy.Policy{}, nil, false}, {"no policies", "https://notcorp.example.com/123", []config.Policy{}, nil, false},
{"bad corp", "https://notcorp.example.com/123", policies, nil, false}, {"bad corp", "https://notcorp.example.com/123", policies, nil, false},
{"bad sub-sub", "https://notcorp.corp.example.com/123", policies, nil, false}, {"bad sub-sub", "https://notcorp.corp.example.com/123", policies, nil, false},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
opts := testOptions() opts := testOptions(t)
opts.Policies = tt.mux opts.Policies = tt.mux
p, err := New(opts) p, err := New(opts)
if err != nil { if err != nil {
@ -278,11 +264,10 @@ func TestProxy_Proxy(t *testing.T) {
})) }))
defer ts.Close() defer ts.Close()
opts, optsWs := testOptionsTestServer(ts.URL), testOptionsTestServer(ts.URL) opts := testOptionsTestServer(t, ts.URL)
optsCORS := testOptionsWithCORS(ts.URL) optsCORS := testOptionsWithCORS(t, ts.URL)
optsPublic := testOptionsWithPublicAccess(ts.URL) optsPublic := testOptionsWithPublicAccess(t, ts.URL)
optsNoPolicies := testOptionsWithEmptyPolicies(ts.URL) optsNoPolicies := testOptionsWithEmptyPolicies(t, ts.URL)
optsWs.AllowWebsockets = true
defaultHeaders, goodCORSHeaders, badCORSHeaders, headersWs := http.Header{}, http.Header{}, http.Header{}, http.Header{} defaultHeaders, goodCORSHeaders, badCORSHeaders, headersWs := http.Header{}, http.Header{}, http.Header{}, http.Header{}
goodCORSHeaders.Set("origin", "anything") goodCORSHeaders.Set("origin", "anything")
@ -325,15 +310,15 @@ func TestProxy_Proxy(t *testing.T) {
{"public access, but unknown host", optsPublic, http.MethodGet, defaultHeaders, "https://nothttpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized}, {"public access, but unknown host", optsPublic, http.MethodGet, defaultHeaders, "https://nothttpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized},
// no session, redirect to login // no session, redirect to login
{"no http found (no session)", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{LoadError: http.ErrNoCookie}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest}, {"no http found (no session)", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{LoadError: http.ErrNoCookie}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest},
// Should be expecting a 101 Switching Protocols, but expect a 200 OK because we don't have a websocket backend to respond
{"ws supported, ws connection", optsWs, http.MethodGet, headersWs, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
{"ws supported, http connection", optsWs, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
{"ws unsupported, ws connection", opts, http.MethodGet, headersWs, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest},
{"No policies", optsNoPolicies, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound}, {"No policies", optsNoPolicies, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{ValidateResponse: true}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
err := ValidateOptions(tt.options)
if err != nil {
t.Fatal(err)
}
p, err := New(tt.options) p, err := New(tt.options)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -361,7 +346,7 @@ func TestProxy_Proxy(t *testing.T) {
} }
func TestProxy_UserDashboard(t *testing.T) { func TestProxy_UserDashboard(t *testing.T) {
opts := testOptions() opts := testOptions(t)
tests := []struct { tests := []struct {
name string name string
options config.Options options config.Options
@ -409,9 +394,9 @@ func TestProxy_UserDashboard(t *testing.T) {
} }
func TestProxy_Refresh(t *testing.T) { func TestProxy_Refresh(t *testing.T) {
opts := testOptions() opts := testOptions(t)
opts.RefreshCooldown = 0 opts.RefreshCooldown = 0
timeSinceError := testOptions() timeSinceError := testOptions(t)
timeSinceError.RefreshCooldown = time.Duration(int(^uint(0) >> 1)) timeSinceError.RefreshCooldown = time.Duration(int(^uint(0) >> 1))
tests := []struct { tests := []struct {
@ -455,7 +440,7 @@ func TestProxy_Refresh(t *testing.T) {
} }
func TestProxy_Impersonate(t *testing.T) { func TestProxy_Impersonate(t *testing.T) {
opts := testOptions() opts := testOptions(t)
tests := []struct { tests := []struct {
name string name string
@ -535,7 +520,7 @@ func TestProxy_OAuthCallback(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
proxy, err := New(testOptions()) proxy, err := New(testOptions(t))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -576,7 +561,7 @@ func TestProxy_SignOut(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
opts := testOptions() opts := testOptions(t)
p, err := New(opts) p, err := New(opts)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View file

@ -1,6 +1,8 @@
package proxy // import "github.com/pomerium/pomerium/proxy" package proxy // import "github.com/pomerium/pomerium/proxy"
import ( import (
"crypto/tls"
"crypto/x509"
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt" "fmt"
@ -9,14 +11,13 @@ import (
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"net/url" "net/url"
"strings"
"time" "time"
"github.com/pomerium/pomerium/internal/config" "github.com/pomerium/pomerium/internal/config"
"github.com/pomerium/pomerium/internal/cryptutil" "github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/metrics" "github.com/pomerium/pomerium/internal/metrics"
"github.com/pomerium/pomerium/internal/policy" "github.com/pomerium/pomerium/internal/middleware"
"github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/templates" "github.com/pomerium/pomerium/internal/templates"
"github.com/pomerium/pomerium/internal/tripper" "github.com/pomerium/pomerium/internal/tripper"
@ -44,13 +45,13 @@ func ValidateOptions(o config.Options) error {
if len(decoded) != 32 { if len(decoded) != 32 {
return fmt.Errorf("`SHARED_SECRET` want 32 but got %d bytes", len(decoded)) return fmt.Errorf("`SHARED_SECRET` want 32 but got %d bytes", len(decoded))
} }
if o.AuthenticateURL.String() == "" { if o.AuthenticateURL == nil || o.AuthenticateURL.String() == "" {
return errors.New("missing setting: authenticate-service-url") return errors.New("missing setting: authenticate-service-url")
} }
if o.AuthenticateURL.Scheme != "https" { if o.AuthenticateURL.Scheme != "https" {
return errors.New("authenticate-service-url must be a valid https url") return errors.New("authenticate-service-url must be a valid https url")
} }
if o.AuthorizeURL.String() == "" { if o.AuthorizeURL == nil || o.AuthorizeURL.String() == "" {
return errors.New("missing setting: authorize-service-url") return errors.New("missing setting: authorize-service-url")
} }
if o.AuthorizeURL.Scheme != "https" { if o.AuthorizeURL.Scheme != "https" {
@ -67,40 +68,42 @@ func ValidateOptions(o config.Options) error {
return fmt.Errorf("cookie secret expects 32 bytes but got %d", len(decodedCookieSecret)) return fmt.Errorf("cookie secret expects 32 bytes but got %d", len(decodedCookieSecret))
} }
if len(o.SigningKey) != 0 { if len(o.SigningKey) != 0 {
_, err := base64.StdEncoding.DecodeString(o.SigningKey) decodedSigningKey, err := base64.StdEncoding.DecodeString(o.SigningKey)
if err != nil { if err != nil {
return fmt.Errorf("signing key is invalid base64: %v", err) return fmt.Errorf("signing key is invalid base64: %v", err)
} }
_, err = cryptutil.NewES256Signer(decodedSigningKey, "localhost")
if err != nil {
return fmt.Errorf("invalid signing key is : %v", err)
}
} }
return nil return nil
} }
// Proxy stores all the information associated with proxying a request. // Proxy stores all the information associated with proxying a request.
type Proxy struct { type Proxy struct {
SharedKey string // SharedKey used to mutually authenticate service communication
SharedKey string
// authenticate service
AuthenticateURL *url.URL AuthenticateURL *url.URL
AuthenticateClient clients.Authenticator AuthenticateClient clients.Authenticator
AuthorizeClient clients.Authorizer
// authorize service cipher cryptutil.Cipher
AuthorizeClient clients.Authorizer cookieName string
csrfStore sessions.CSRFStore
// session defaultUpstreamTimeout time.Duration
cipher cryptutil.Cipher redirectURL *url.URL
csrfStore sessions.CSRFStore refreshCooldown time.Duration
sessionStore sessions.SessionStore restStore sessions.SessionStore
restStore sessions.SessionStore routeConfigs map[string]*routeConfig
sessionStore sessions.SessionStore
redirectURL *url.URL signingKey string
templates *template.Template templates *template.Template
routeConfigs map[string]*routeConfig
refreshCooldown time.Duration
} }
type routeConfig struct { type routeConfig struct {
mux http.Handler mux http.Handler
policy policy.Policy policy config.Policy
} }
// New takes a Proxy service from options and a validation function. // New takes a Proxy service from options and a validation function.
@ -134,29 +137,32 @@ func New(opts config.Options) (*Proxy, error) {
return nil, err return nil, err
} }
p := &Proxy{ p := &Proxy{
SharedKey: opts.SharedKey,
routeConfigs: make(map[string]*routeConfig), routeConfigs: make(map[string]*routeConfig),
// services // services
AuthenticateURL: &opts.AuthenticateURL, AuthenticateURL: opts.AuthenticateURL,
// session state
cipher: cipher, cipher: cipher,
csrfStore: cookieStore, cookieName: opts.CookieName,
sessionStore: cookieStore, csrfStore: cookieStore,
restStore: restStore, defaultUpstreamTimeout: opts.DefaultUpstreamTimeout,
SharedKey: opts.SharedKey, redirectURL: &url.URL{Path: "/.pomerium/callback"},
redirectURL: &url.URL{Path: "/.pomerium/callback"}, refreshCooldown: opts.RefreshCooldown,
templates: templates.New(), restStore: restStore,
refreshCooldown: opts.RefreshCooldown, sessionStore: cookieStore,
signingKey: opts.SigningKey,
templates: templates.New(),
} }
err = p.UpdatePolicies(opts) if err := p.UpdatePolicies(&opts); err != nil {
if err != nil {
return nil, err return nil, err
} }
p.AuthenticateClient, err = clients.NewAuthenticateClient("grpc", p.AuthenticateClient, err = clients.NewAuthenticateClient("grpc",
&clients.Options{ &clients.Options{
Addr: opts.AuthenticateURL.Host, Addr: opts.AuthenticateURL,
InternalAddr: opts.AuthenticateInternalAddr.Host, InternalAddr: opts.AuthenticateInternalAddr,
OverrideCertificateName: opts.OverrideCertificateName, OverrideCertificateName: opts.OverrideCertificateName,
SharedSecret: opts.SharedKey, SharedSecret: opts.SharedKey,
CA: opts.CA, CA: opts.CA,
@ -167,7 +173,7 @@ func New(opts config.Options) (*Proxy, error) {
} }
p.AuthorizeClient, err = clients.NewAuthorizeClient("grpc", p.AuthorizeClient, err = clients.NewAuthorizeClient("grpc",
&clients.Options{ &clients.Options{
Addr: opts.AuthorizeURL.Host, Addr: opts.AuthorizeURL,
OverrideCertificateName: opts.OverrideCertificateName, OverrideCertificateName: opts.OverrideCertificateName,
SharedSecret: opts.SharedKey, SharedSecret: opts.SharedKey,
CA: opts.CA, CA: opts.CA,
@ -177,26 +183,44 @@ func New(opts config.Options) (*Proxy, error) {
} }
// UpdatePolicies updates the handlers based on the configured policies // UpdatePolicies updates the handlers based on the configured policies
func (p *Proxy) UpdatePolicies(opts config.Options) error { func (p *Proxy) UpdatePolicies(opts *config.Options) error {
routeConfigs := make(map[string]*routeConfig) routeConfigs := make(map[string]*routeConfig, len(opts.Policies))
if len(opts.Policies) == 0 {
policyCount := len(opts.Policies) log.Warn().Msg("proxy: configuration has no policies")
if policyCount == 0 {
log.Warn().Msg("proxy: loaded configuration with no policies specified")
} }
log.Info().Int("policy-count", policyCount).Msg("proxy: updated policies") for _, policy := range opts.Policies {
if err := policy.Validate(); err != nil {
return fmt.Errorf("proxy: couldn't update policies %s", err)
}
proxy := NewReverseProxy(policy.Destination)
// build http transport (roundtripper) middleware chain
// todo(bdd): this will make vet complain, it is safe
// and can be replaced with transport.Clone() in go 1.13
// https://go-review.googlesource.com/c/go/+/174597/
// https://github.com/golang/go/issues/26013#issuecomment-399481302
transport := *(http.DefaultTransport.(*http.Transport))
c := tripper.NewChain()
c = c.Append(metrics.HTTPMetricsRoundTripper("proxy"))
if policy.TLSSkipVerify {
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
}
if policy.TLSCustomCA != "" {
rootCA, err := p.customCAPool(policy.TLSCustomCA)
if err != nil {
return fmt.Errorf("proxy: couldn't add custom ca to policy %s", policy.From)
}
transport.TLSClientConfig = &tls.Config{RootCAs: rootCA}
}
proxy.Transport = c.Then(&transport)
for _, route := range opts.Policies { handler, err := p.newReverseProxyHandler(proxy, &policy)
proxy := NewReverseProxy(route.Destination)
handler, err := NewReverseProxyHandler(opts, proxy, &route)
if err != nil { if err != nil {
return err return err
} }
routeConfigs[route.Source.Host] = &routeConfig{ routeConfigs[policy.Source.Host] = &routeConfig{
mux: handler, mux: handler,
policy: route, policy: policy,
} }
log.Info().Str("src", route.Source.Host).Str("dst", route.Destination.Host).Msg("proxy: new route")
} }
p.routeConfigs = routeConfigs p.routeConfigs = routeConfigs
return nil return nil
@ -204,40 +228,12 @@ func (p *Proxy) UpdatePolicies(opts config.Options) error {
// UpstreamProxy stores information for proxying the request to the upstream. // UpstreamProxy stores information for proxying the request to the upstream.
type UpstreamProxy struct { type UpstreamProxy struct {
name string name string
cookieName string handler http.Handler
handler http.Handler
signer cryptutil.JWTSigner
} }
// deleteUpstreamCookies deletes the session cookie from the request header string. // ServeHTTP handles the second (reverse-proxying) leg of pomerium's request flow
func deleteUpstreamCookies(req *http.Request, cookieName string) {
headers := []string{}
for _, cookie := range req.Cookies() {
if cookie.Name != cookieName {
headers = append(headers, cookie.String())
}
}
req.Header.Set("Cookie", strings.Join(headers, ";"))
}
func (u *UpstreamProxy) signRequest(r *http.Request) {
if u.signer != nil {
jwt, err := u.signer.SignJWT(
r.Header.Get(HeaderUserID),
r.Header.Get(HeaderEmail),
r.Header.Get(HeaderGroups))
if err == nil {
r.Header.Set(HeaderJWT, jwt)
}
}
}
// ServeHTTP signs the http request and deletes cookie headers
// before calling the upstream's ServeHTTP function.
func (u *UpstreamProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (u *UpstreamProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
deleteUpstreamCookies(r, u.cookieName)
u.signRequest(r)
u.handler.ServeHTTP(w, r) u.handler.ServeHTTP(w, r)
} }
@ -247,8 +243,6 @@ func NewReverseProxy(to *url.URL) *httputil.ReverseProxy {
proxy := httputil.NewSingleHostReverseProxy(to) proxy := httputil.NewSingleHostReverseProxy(to)
sublogger := log.With().Str("proxy", to.Host).Logger() sublogger := log.With().Str("proxy", to.Host).Logger()
proxy.ErrorLog = stdlog.New(&log.StdLogWrapper{Logger: &sublogger}, "", 0) proxy.ErrorLog = stdlog.New(&log.StdLogWrapper{Logger: &sublogger}, "", 0)
// todo(bdd): default is already http.DefaultTransport)
// proxy.Transport = defaultUpstreamTransport
director := proxy.Director director := proxy.Director
proxy.Director = func(req *http.Request) { proxy.Director = func(req *http.Request) {
// Identifies the originating IP addresses of a client connecting to // Identifies the originating IP addresses of a client connecting to
@ -257,51 +251,62 @@ func NewReverseProxy(to *url.URL) *httputil.ReverseProxy {
director(req) director(req)
req.Host = to.Host req.Host = to.Host
} }
chain := tripper.NewChain().Append(metrics.HTTPMetricsRoundTripper("proxy"))
proxy.Transport = chain.Then(nil)
return proxy return proxy
} }
// NewReverseProxyHandler applies handler specific options to a given route. // newRouteSigner creates a route specific signer.
func NewReverseProxyHandler(o config.Options, proxy *httputil.ReverseProxy, route *policy.Policy) (http.Handler, error) { func (p *Proxy) newRouteSigner(audience string) (cryptutil.JWTSigner, error) {
up := &UpstreamProxy{ decodedSigningKey, err := base64.StdEncoding.DecodeString(p.signingKey)
name: route.Destination.Host, if err != nil {
handler: proxy, return nil, err
cookieName: o.CookieName,
} }
if len(o.SigningKey) != 0 { return cryptutil.NewES256Signer(decodedSigningKey, audience)
decodedSigningKey, _ := base64.StdEncoding.DecodeString(o.SigningKey) }
signer, err := cryptutil.NewES256Signer(decodedSigningKey, route.Source.Host)
func (p *Proxy) customCAPool(cert string) (*x509.CertPool, error) {
certPool := x509.NewCertPool()
decodedCert, err := base64.StdEncoding.DecodeString(cert)
if err != nil {
return nil, fmt.Errorf("failed to decode cert: %s", err)
}
if ok := certPool.AppendCertsFromPEM(decodedCert); !ok {
return nil, fmt.Errorf("could not append cert: %s", decodedCert)
}
return certPool, nil
}
// newReverseProxyHandler applies handler specific options to a given route.
func (p *Proxy) newReverseProxyHandler(rp *httputil.ReverseProxy, route *config.Policy) (http.Handler, error) {
var handler http.Handler
handler = &UpstreamProxy{
name: route.Destination.Host,
handler: rp,
}
c := middleware.NewChain()
c = c.Append(middleware.StripPomeriumCookie(p.cookieName))
// if signing key is set, add signer to middleware
if len(p.signingKey) != 0 {
signer, err := p.newRouteSigner(route.Source.Host)
if err != nil { if err != nil {
return nil, err return nil, err
} }
up.signer = signer c = c.Append(middleware.SignRequest(signer, HeaderUserID, HeaderEmail, HeaderGroups, HeaderJWT))
} }
timeout := o.DefaultUpstreamTimeout // websockets cannot use the non-hijackable timeout-handler
if route.UpstreamTimeout != 0 { if !route.AllowWebsockets {
timeout = route.UpstreamTimeout timeout := p.defaultUpstreamTimeout
if route.UpstreamTimeout != 0 {
timeout = route.UpstreamTimeout
}
timeoutMsg := fmt.Sprintf("%s failed to respond within the %s timeout period", route.Destination.Host, timeout)
handler = http.TimeoutHandler(handler, timeout, timeoutMsg)
} }
timeoutMsg := fmt.Sprintf("%s failed to respond within the %s timeout period", route.Destination.Host, timeout)
timeoutHandler := http.TimeoutHandler(up, timeout, timeoutMsg)
return websocketHandlerFunc(up, timeoutHandler, o), nil
}
// urlParse wraps url.Parse to add a scheme if none-exists. return c.Then(handler), nil
// https://github.com/golang/go/issues/12585
func urlParse(uri string) (*url.URL, error) {
if !strings.Contains(uri, "://") {
uri = fmt.Sprintf("https://%s", uri)
}
return url.ParseRequestURI(uri)
} }
// UpdateOptions updates internal structures based on config.Options // UpdateOptions updates internal structures based on config.Options
func (p *Proxy) UpdateOptions(o config.Options) error { func (p *Proxy) UpdateOptions(o config.Options) error {
log.Info().Msg("proxy: updating options") return p.UpdatePolicies(&o)
err := p.UpdatePolicies(o)
if err != nil {
return fmt.Errorf("Could not update policies: %s", err)
}
return nil
} }

View file

@ -10,12 +10,19 @@ import (
"time" "time"
"github.com/pomerium/pomerium/internal/config" "github.com/pomerium/pomerium/internal/config"
"github.com/pomerium/pomerium/internal/policy"
) )
var fixedDate = time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC) var fixedDate = time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC)
func newTestOptions(t *testing.T) *config.Options {
opts, err := config.NewOptions("https://authenticate.example", "https://authorize.example")
if err != nil {
t.Fatal(err)
}
opts.CookieSecret = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw="
return opts
}
func TestNewReverseProxy(t *testing.T) { func TestNewReverseProxy(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
@ -54,14 +61,19 @@ func TestNewReverseProxyHandler(t *testing.T) {
backendHost := net.JoinHostPort(backendHostname, backendPort) backendHost := net.JoinHostPort(backendHostname, backendPort)
proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/") proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/")
proxyHandler := NewReverseProxy(proxyURL) proxyHandler := NewReverseProxy(proxyURL)
opts := config.NewOptions() opts := newTestOptions(t)
opts.SigningKey = "LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSU0zbXBaSVdYQ1g5eUVneFU2czU3Q2J0YlVOREJTQ0VBdFFGNWZVV0hwY1FvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFaFBRditMQUNQVk5tQlRLMHhTVHpicEVQa1JyazFlVXQxQk9hMzJTRWZVUHpOaTRJV2VaLwpLS0lUdDJxMUlxcFYyS01TYlZEeXI5aWp2L1hoOThpeUV3PT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQo=" opts.SigningKey = "LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSU0zbXBaSVdYQ1g5eUVneFU2czU3Q2J0YlVOREJTQ0VBdFFGNWZVV0hwY1FvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFaFBRditMQUNQVk5tQlRLMHhTVHpicEVQa1JyazFlVXQxQk9hMzJTRWZVUHpOaTRJV2VaLwpLS0lUdDJxMUlxcFYyS01TYlZEeXI5aWp2L1hoOThpeUV3PT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQo="
testPolicy := policy.Policy{From: "corp.example.com", To: "example.com", UpstreamTimeout: 1 * time.Second} testPolicy := config.Policy{From: "https://corp.example.com", To: "https://example.com", UpstreamTimeout: 1 * time.Second}
testPolicy.Validate() if err := testPolicy.Validate(); err != nil {
t.Fatal(err)
handle, err := NewReverseProxyHandler(opts, proxyHandler, &testPolicy) }
p, err := New(*opts)
if err != nil { if err != nil {
t.Errorf("got %q", err) t.Fatal(err)
}
handle, err := p.newReverseProxyHandler(proxyHandler, &testPolicy)
if err != nil {
t.Fatal(err)
} }
frontend := httptest.NewServer(handle) frontend := httptest.NewServer(handle)
@ -77,109 +89,104 @@ func TestNewReverseProxyHandler(t *testing.T) {
} }
} }
func testOptions() config.Options { func testOptions(t *testing.T) config.Options {
authenticateService, _ := url.Parse("https://authenticate.corp.beyondperimeter.com") authenticateService, _ := url.Parse("https://authenticate.corp.beyondperimeter.com")
authorizeService, _ := url.Parse("https://authorize.corp.beyondperimeter.com") authorizeService, _ := url.Parse("https://authorize.corp.beyondperimeter.com")
opts := config.NewOptions() opts := newTestOptions(t)
testPolicy := policy.Policy{From: "corp.example.notatld", To: "example.notatld"} testPolicy := config.Policy{From: "https://corp.example.example", To: "https://example.example"}
testPolicy.Validate() opts.Policies = []config.Policy{testPolicy}
opts.Policies = []policy.Policy{testPolicy} opts.AuthenticateURL = authenticateService
opts.AuthenticateURL = *authenticateService opts.AuthorizeURL = authorizeService
opts.AuthorizeURL = *authorizeService
opts.SharedKey = "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=" opts.SharedKey = "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="
opts.CookieSecret = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=" opts.CookieSecret = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw="
opts.CookieName = "pomerium" opts.CookieName = "pomerium"
return opts err := opts.Validate()
if err != nil {
t.Fatal(err)
}
return *opts
} }
func testOptionsTestServer(uri string) config.Options { func testOptionsTestServer(t *testing.T, uri string) config.Options {
authenticateService, _ := url.Parse("https://authenticate.corp.beyondperimeter.com") authenticateService, _ := url.Parse("https://authenticate.corp.beyondperimeter.com")
authorizeService, _ := url.Parse("https://authorize.corp.beyondperimeter.com") authorizeService, _ := url.Parse("https://authorize.corp.beyondperimeter.com")
// RFC 2606 testPolicy := config.Policy{
testPolicy := policy.Policy{ From: "https://httpbin.corp.example",
From: "httpbin.corp.example",
To: uri, To: uri,
} }
testPolicy.Validate() if err := testPolicy.Validate(); err != nil {
opts := config.NewOptions() t.Fatal(err)
opts.Policies = []policy.Policy{testPolicy} }
opts.AuthenticateURL = *authenticateService opts := newTestOptions(t)
opts.AuthorizeURL = *authorizeService opts.Policies = []config.Policy{testPolicy}
opts.AuthenticateURL = authenticateService
opts.AuthorizeURL = authorizeService
opts.SharedKey = "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=" opts.SharedKey = "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="
opts.CookieSecret = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=" opts.CookieSecret = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw="
opts.CookieName = "pomerium" opts.CookieName = "pomerium"
return opts return *opts
} }
func testOptionsWithCORS(uri string) config.Options { func testOptionsWithCORS(t *testing.T, uri string) config.Options {
testPolicy := policy.Policy{ testPolicy := config.Policy{
From: "httpbin.corp.example", From: "https://httpbin.corp.example",
To: uri, To: uri,
CORSAllowPreflight: true, CORSAllowPreflight: true,
} }
testPolicy.Validate() if err := testPolicy.Validate(); err != nil {
opts := testOptionsTestServer(uri) t.Fatal(err)
opts.Policies = []policy.Policy{testPolicy} }
opts := testOptionsTestServer(t, uri)
opts.Policies = []config.Policy{testPolicy}
return opts return opts
} }
func testOptionsWithPublicAccess(uri string) config.Options { func testOptionsWithPublicAccess(t *testing.T, uri string) config.Options {
testPolicy := policy.Policy{ testPolicy := config.Policy{
From: "httpbin.corp.example", From: "https://httpbin.corp.example",
To: uri, To: uri,
AllowPublicUnauthenticatedAccess: true, AllowPublicUnauthenticatedAccess: true,
} }
testPolicy.Validate() if err := testPolicy.Validate(); err != nil {
opts := testOptions() t.Fatal(err)
opts.Policies = []policy.Policy{testPolicy}
return opts
}
func testOptionsWithPublicAccessAndWhitelist(uri string) config.Options {
testPolicy := policy.Policy{
From: "httpbin.corp.example",
To: uri,
AllowPublicUnauthenticatedAccess: true,
AllowedEmails: []string{"test@gmail.com"},
} }
testPolicy.Validate() opts := testOptions(t)
opts := testOptions() opts.Policies = []config.Policy{testPolicy}
opts.Policies = []policy.Policy{testPolicy}
return opts return opts
} }
func testOptionsWithEmptyPolicies(uri string) config.Options { func testOptionsWithEmptyPolicies(t *testing.T, uri string) config.Options {
opts := testOptionsTestServer(uri) opts := testOptionsTestServer(t, uri)
opts.Policies = []policy.Policy{} opts.Policies = []config.Policy{}
return opts return opts
} }
func TestOptions_Validate(t *testing.T) { func TestOptions_Validate(t *testing.T) {
good := testOptions() good := testOptions(t)
badAuthURL := testOptions() badAuthURL := testOptions(t)
badAuthURL.AuthenticateURL = url.URL{} badAuthURL.AuthenticateURL = nil
authurl, _ := url.Parse("http://authenticate.corp.beyondperimeter.com") authurl, _ := url.Parse("http://authenticate.corp.beyondperimeter.com")
authenticateBadScheme := testOptions() authenticateBadScheme := testOptions(t)
authenticateBadScheme.AuthenticateURL = *authurl authenticateBadScheme.AuthenticateURL = authurl
authorizeBadSCheme := testOptions() authorizeBadSCheme := testOptions(t)
authorizeBadSCheme.AuthorizeURL = *authurl authorizeBadSCheme.AuthorizeURL = authurl
authorizeNil := testOptions() authorizeNil := testOptions(t)
authorizeNil.AuthorizeURL = url.URL{} authorizeNil.AuthorizeURL = nil
emptyCookieSecret := testOptions() emptyCookieSecret := testOptions(t)
emptyCookieSecret.CookieSecret = "" emptyCookieSecret.CookieSecret = ""
invalidCookieSecret := testOptions() invalidCookieSecret := testOptions(t)
invalidCookieSecret.CookieSecret = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw^" invalidCookieSecret.CookieSecret = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw^"
shortCookieLength := testOptions() shortCookieLength := testOptions(t)
shortCookieLength.CookieSecret = "gN3xnvfsAwfCXxnJorGLKUG4l2wC8sS8nfLMhcStPg==" shortCookieLength.CookieSecret = "gN3xnvfsAwfCXxnJorGLKUG4l2wC8sS8nfLMhcStPg=="
invalidSignKey := testOptions() invalidSignKey := testOptions(t)
invalidSignKey.SigningKey = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw^" invalidSignKey.SigningKey = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw^"
badSharedKey := testOptions() badSharedKey := testOptions(t)
badSharedKey.SharedKey = "" badSharedKey.SharedKey = ""
sharedKeyBadBas64 := testOptions() sharedKeyBadBas64 := testOptions(t)
sharedKeyBadBas64.SharedKey = "%(*@389" sharedKeyBadBas64.SharedKey = "%(*@389"
missingPolicy := testOptions() missingPolicy := testOptions(t)
missingPolicy.Policies = []policy.Policy{} missingPolicy.Policies = []config.Policy{}
tests := []struct { tests := []struct {
name string name string
@ -197,7 +204,6 @@ func TestOptions_Validate(t *testing.T) {
{"short cookie secret", shortCookieLength, true}, {"short cookie secret", shortCookieLength, true},
{"no shared secret", badSharedKey, true}, {"no shared secret", badSharedKey, true},
{"invalid signing key", invalidSignKey, true}, {"invalid signing key", invalidSignKey, true},
{"missing policy", missingPolicy, false},
{"shared secret bad base64", sharedKeyBadBas64, true}, {"shared secret bad base64", sharedKeyBadBas64, true},
} }
for _, tt := range tests { for _, tt := range tests {
@ -212,10 +218,10 @@ func TestOptions_Validate(t *testing.T) {
func TestNew(t *testing.T) { func TestNew(t *testing.T) {
good := testOptions() good := testOptions(t)
shortCookieLength := testOptions() shortCookieLength := testOptions(t)
shortCookieLength.CookieSecret = "gN3xnvfsAwfCXxnJorGLKUG4l2wC8sS8nfLMhcStPg==" shortCookieLength.CookieSecret = "gN3xnvfsAwfCXxnJorGLKUG4l2wC8sS8nfLMhcStPg=="
badRoutedProxy := testOptions() badRoutedProxy := testOptions(t)
badRoutedProxy.SigningKey = "YmFkIGtleQo=" badRoutedProxy.SigningKey = "YmFkIGtleQo="
tests := []struct { tests := []struct {
name string name string
@ -240,7 +246,7 @@ func TestNew(t *testing.T) {
t.Errorf("New() expected valid proxy struct") t.Errorf("New() expected valid proxy struct")
} }
if got != nil && len(got.routeConfigs) != tt.numRoutes { if got != nil && len(got.routeConfigs) != tt.numRoutes {
t.Errorf("New() = num routeConfigs \n%+v, want \n%+v", got, tt.numRoutes) t.Errorf("New() = num routeConfigs \n%+v, want \n%+v \nfrom %+v", got, tt.numRoutes, tt.opts)
} }
}) })
} }
@ -248,34 +254,65 @@ func TestNew(t *testing.T) {
func Test_UpdateOptions(t *testing.T) { func Test_UpdateOptions(t *testing.T) {
good := testOptions() good := testOptions(t)
bad := testOptions() newPolicy := config.Policy{To: "http://foo.example", From: "http://bar.example"}
bad.SigningKey = "f" newPolicies := testOptions(t)
newPolicy := policy.Policy{To: "foo.notatld", From: "bar.notatld"} newPolicies.Policies = []config.Policy{
newPolicy.Validate()
newPolicies := []policy.Policy{
newPolicy, newPolicy,
} }
err := newPolicy.Validate()
if err != nil {
t.Fatal(err)
}
badPolicyURL := config.Policy{To: "http://", From: "http://bar.example"}
badNewPolicy := testOptions(t)
badNewPolicy.Policies = []config.Policy{
badPolicyURL,
}
disableTLSPolicy := config.Policy{To: "http://foo.example", From: "http://bar.example", TLSSkipVerify: true}
disableTLSPolicies := testOptions(t)
disableTLSPolicies.Policies = []config.Policy{
disableTLSPolicy,
}
customCAPolicy := config.Policy{To: "http://foo.example", From: "http://bar.example", TLSCustomCA: "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSURlVENDQW1HZ0F3SUJBZ0lKQUszMmhoR0JIcmFtTUEwR0NTcUdTSWIzRFFFQkN3VUFNR0l4Q3pBSkJnTlYKQkFZVEFsVlRNUk13RVFZRFZRUUlEQXBEWVd4cFptOXlibWxoTVJZd0ZBWURWUVFIREExVFlXNGdSbkpoYm1OcApjMk52TVE4d0RRWURWUVFLREFaQ1lXUlRVMHd4RlRBVEJnTlZCQU1NRENvdVltRmtjM05zTG1OdmJUQWVGdzB4Ck9UQTJNVEl4TlRNeE5UbGFGdzB5TVRBMk1URXhOVE14TlRsYU1HSXhDekFKQmdOVkJBWVRBbFZUTVJNd0VRWUQKVlFRSURBcERZV3hwWm05eWJtbGhNUll3RkFZRFZRUUhEQTFUWVc0Z1JuSmhibU5wYzJOdk1ROHdEUVlEVlFRSwpEQVpDWVdSVFUwd3hGVEFUQmdOVkJBTU1EQ291WW1Ga2MzTnNMbU52YlRDQ0FTSXdEUVlKS29aSWh2Y05BUUVCCkJRQURnZ0VQQURDQ0FRb0NnZ0VCQU1JRTdQaU03Z1RDczloUTFYQll6Sk1ZNjF5b2FFbXdJclg1bFo2eEt5eDIKUG16QVMyQk1UT3F5dE1BUGdMYXcrWExKaGdMNVhFRmRFeXQvY2NSTHZPbVVMbEEzcG1jY1lZejJRVUxGUnRNVwpoeWVmZE9zS25SRlNKaUZ6YklSTWVWWGswV3ZvQmoxSUZWS3RzeWpicXY5dS8yQ1ZTbmRyT2ZFazBURzIzVTNBCnhQeFR1VzFDcmJWOC9xNzFGZEl6U09jaWNjZkNGSHBzS09vM1N0L3FiTFZ5dEg1YW9oYmNhYkZYUk5zS0VxdmUKd3c5SGRGeEJJdUdhK1J1VDVxMGlCaWt1c2JwSkhBd25ucVA3aS9kQWNnQ3NrZ2paakZlRVU0RUZ5K2IrYTFTWQpRQ2VGeHhDN2MzRHZhUmhCQjBWVmZQbGtQejBzdzZsODY1TWFUSWJSeW9VQ0F3RUFBYU15TURBd0NRWURWUjBUCkJBSXdBREFqQmdOVkhSRUVIREFhZ2d3cUxtSmhaSE56YkM1amIyMkNDbUpoWkhOemJDNWpiMjB3RFFZSktvWkkKaHZjTkFRRUxCUUFEZ2dFQkFJaTV1OXc4bWdUNnBwQ2M3eHNHK0E5ZkkzVzR6K3FTS2FwaHI1bHM3MEdCS2JpWQpZTEpVWVpoUGZXcGgxcXRra1UwTEhGUG04M1ZhNTJlSUhyalhUMFZlNEt0TzFuMElBZkl0RmFXNjJDSmdoR1luCmp6dzByeXpnQzRQeUZwTk1uTnRCcm9QdS9iUGdXaU1nTE9OcEVaaGlneDRROHdmMVkvVTlzK3pDQ3hvSmxhS1IKTVhidVE4N1g3bS85VlJueHhvNk56NVpmN09USFRwTk9JNlZqYTBCeGJtSUFVNnlyaXc5VXJnaWJYZk9qM2o2bgpNVExCdWdVVklCMGJCYWFzSnNBTUsrdzRMQU52YXBlWjBET1NuT1I0S0syNEowT3lvRjVmSG1wNTllTTE3SW9GClFxQmh6cG1RVWd1bmVjRVc4QlRxck5wRzc5UjF1K1YrNHd3Y2tQYz0KLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQo="}
customCAPolicies := testOptions(t)
customCAPolicies.Policies = []config.Policy{
customCAPolicy,
}
badCustomCAPolicy := config.Policy{To: "http://foo.example", From: "http://bar.example", TLSCustomCA: "=@@"}
badCustomCAPolicies := testOptions(t)
badCustomCAPolicies.Policies = []config.Policy{
badCustomCAPolicy,
}
tests := []struct { tests := []struct {
name string name string
opts config.Options originalOptions config.Options
newPolicy []policy.Policy updatedOptions config.Options
host string signingKey string
wantErr bool host string
wantRoute bool wantErr bool
wantRoute bool
}{ }{
{"good", good, good.Policies, "https://corp.example.notatld", false, true}, {"good no change", good, good, "", "https://corp.example.example", false, true},
{"changed", good, newPolicies, "https://bar.notatld", false, true}, {"changed", good, newPolicies, "", "https://bar.example", false, true},
{"changed and missing", good, newPolicies, "https://corp.example.notatld", false, false}, {"changed and missing", good, newPolicies, "", "https://corp.example.example", false, false},
{"bad options", bad, good.Policies, "https://corp.example.notatld", true, false}, // todo(bdd): not sure what intent of this test is?
{"bad signing key", good, newPolicies, "^bad base 64", "https://corp.example.example", true, false},
{"bad change bad policy url", good, badNewPolicy, "", "https://bar.example", true, false},
// todo: stand up a test server using self signed certificates
{"disable tls verification", good, disableTLSPolicies, "", "https://bar.example", false, true},
{"custom root ca", good, customCAPolicies, "", "https://bar.example", false, true},
{"bad custom root ca base64", good, badCustomCAPolicies, "", "https://bar.example", true, false},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
o := tt.opts p, err := New(tt.originalOptions)
p, _ := New(o) if err != nil {
t.Fatal(err)
}
o.Policies = tt.newPolicy p.signingKey = tt.signingKey
err := p.UpdateOptions(o) err = p.UpdateOptions(tt.updatedOptions)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("UpdateOptions: err = %v, wantErr = %v", err, tt.wantErr) t.Errorf("UpdateOptions: err = %v, wantErr = %v", err, tt.wantErr)
return return