all: support insecure mode

- pomerium/authenticate: add cookie secure setting
- internal/config: transport security validation moved to options
- internal/config: certificate struct hydrated
- internal/grpcutil: add grpc server mirroring http one
- internal/grpcutil: move grpc middleware
- cmd/pomerium: use run wrapper around main to pass back errors
- cmd/pomerium: add waitgroup (block on) all servers http/grpc

Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
This commit is contained in:
Bobby DeSimone 2019-09-30 23:50:39 -07:00
parent 40920b9092
commit df822a4bae
No known key found for this signature in database
GPG key ID: AEE4CF12FE86D07E
26 changed files with 1039 additions and 1090 deletions

View file

@ -45,6 +45,7 @@ type Authenticate struct {
RedirectURL *url.URL
cookieName string
cookieSecure bool
cookieDomain string
cookieSecret []byte
templates *template.Template
@ -108,5 +109,6 @@ func New(opts config.Options) (*Authenticate, error) {
cookieSecret: decodedCookieSecret,
cookieName: opts.CookieName,
cookieDomain: opts.CookieDomain,
cookieSecure: opts.CookieSecure,
}, nil
}

View file

@ -7,14 +7,19 @@ import (
)
func newTestOptions(t *testing.T) *config.Options {
opts, err := config.NewMinimalOptions("https://authenticate.example", "https://authorize.example")
if err != nil {
t.Fatal(err)
}
opts := config.NewDefaultOptions()
opts.AuthenticateURLString = "https://authenticate.example"
opts.AuthorizeURLString = "https://authorize.example"
opts.InsecureServer = true
opts.ClientID = "client-id"
opts.Provider = "google"
opts.ClientSecret = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw="
opts.CookieSecret = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw="
err := opts.Validate()
if err != nil {
t.Fatal(err)
}
return opts
}

View file

@ -37,6 +37,7 @@ func (a *Authenticate) Handler() http.Handler {
r.Use(middleware.SetHeaders(CSPHeaders))
r.Use(csrf.Protect(
a.cookieSecret,
csrf.Secure(a.cookieSecure),
csrf.Path("/"),
csrf.Domain(a.cookieDomain),
csrf.UnsafePaths([]string{callbackPath}), // enforce CSRF on "safe" handler

View file

@ -4,7 +4,7 @@ import (
"flag"
"fmt"
"net/http"
"os"
"sync"
"time"
"github.com/fsnotify/fsnotify"
@ -15,6 +15,7 @@ import (
"github.com/pomerium/pomerium/authenticate"
"github.com/pomerium/pomerium/authorize"
"github.com/pomerium/pomerium/internal/config"
"github.com/pomerium/pomerium/internal/grpcutil"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/middleware"
@ -30,36 +31,47 @@ var versionFlag = flag.Bool("version", false, "prints the version")
var configFile = flag.String("config", "", "Specify configuration file location")
func main() {
if err := run(); err != nil {
log.Fatal().Err(err).Msg("cmd/pomerium")
}
}
func run() error {
flag.Parse()
if *versionFlag {
fmt.Println(version.FullVersion())
os.Exit(0)
return nil
}
opt, err := config.ParseOptions(*configFile)
opt, err := config.NewOptionsFromConfig(*configFile)
if err != nil {
log.Fatal().Err(err).Msg("cmd/pomerium: options")
return err
}
log.Info().Str("version", version.FullVersion()).Msg("cmd/pomerium")
setupMetrics(opt)
setupTracing(opt)
setupHTTPRedirectServer(opt)
r := newGlobalRouter(opt)
grpcServer := setupGRPCServer(opt)
_, err = newAuthenticateService(*opt, r.Host(urlutil.StripPort(opt.AuthenticateURL.Host)).Subrouter())
if err != nil {
log.Fatal().Err(err).Msg("cmd/pomerium: authenticate")
// since we can have multiple listeners, we create a wait group
var wg sync.WaitGroup
if err := setupMetrics(opt, &wg); err != nil {
return err
}
if err := setupTracing(opt); err != nil {
return err
}
if err := setupHTTPRedirectServer(opt, &wg); err != nil {
return err
}
authz, err := newAuthorizeService(*opt, grpcServer)
r := newGlobalRouter(opt)
_, err = newAuthenticateService(*opt, r)
if err != nil {
log.Fatal().Err(err).Msg("cmd/pomerium: authorize")
return err
}
authz, err := newAuthorizeService(*opt, &wg)
if err != nil {
return err
}
proxy, err := newProxyService(*opt, r)
if err != nil {
log.Fatal().Err(err).Msg("cmd/pomerium: proxy")
return err
}
if proxy != nil {
defer proxy.AuthorizeClient.Close()
@ -71,13 +83,15 @@ func main() {
log.Info().Str("file", e.Name).Msg("cmd/pomerium: config file changed")
opt = config.HandleConfigUpdate(*configFile, opt, []config.OptionsUpdater{authz, proxy})
})
srv, err := httputil.NewTLSServer(configToServerOptions(opt), r, grpcServer)
if err != nil {
log.Fatal().Err(err).Msg("cmd/pomerium: couldn't start pomerium")
}
httputil.Shutdown(srv)
os.Exit(0)
srv, err := httputil.NewServer(httpServerOptions(opt), r, &wg)
if err != nil {
return err
}
go httputil.Shutdown(srv)
// Blocks and waits until ALL WaitGroup members have signaled completion
wg.Wait()
return nil
}
func newAuthenticateService(opt config.Options, r *mux.Router) (*authenticate.Authenticate, error) {
@ -88,19 +102,33 @@ func newAuthenticateService(opt config.Options, r *mux.Router) (*authenticate.Au
if err != nil {
return nil, err
}
r.PathPrefix("/").Handler(service.Handler())
sr := r.Host(urlutil.StripPort(opt.AuthenticateURL.Host)).Subrouter()
sr.PathPrefix("/").Handler(service.Handler())
return service, nil
}
func newAuthorizeService(opt config.Options, rpc *grpc.Server) (*authorize.Authorize, error) {
func newAuthorizeService(opt config.Options, wg *sync.WaitGroup) (*authorize.Authorize, error) {
if !config.IsAuthorize(opt.Services) {
return nil, nil
}
log.Info().Interface("opts", opt).Msg("newAuthorizeService")
service, err := authorize.New(opt)
if err != nil {
return nil, err
}
pbAuthorize.RegisterAuthorizerServer(rpc, service)
regFn := func(s *grpc.Server) {
pbAuthorize.RegisterAuthorizerServer(s, service)
}
so := &grpcutil.ServerOptions{
Addr: opt.GRPCAddr,
SharedKey: opt.SharedKey,
}
if !opt.GRPCInsecure {
so.TLSCertificate = opt.TLSCertificate
}
grpcSrv := grpcutil.NewServer(so, regFn, wg)
go grpcutil.Shutdown(grpcSrv)
return service, nil
}
@ -146,35 +174,25 @@ func newGlobalRouter(o *config.Options) *mux.Router {
return mux
}
func configToServerOptions(opt *config.Options) *httputil.ServerOptions {
return &httputil.ServerOptions{
Addr: opt.Addr,
Cert: opt.Cert,
Key: opt.Key,
CertFile: opt.CertFile,
KeyFile: opt.KeyFile,
ReadTimeout: opt.ReadTimeout,
WriteTimeout: opt.WriteTimeout,
ReadHeaderTimeout: opt.ReadHeaderTimeout,
IdleTimeout: opt.IdleTimeout,
}
}
func setupMetrics(opt *config.Options) {
func setupMetrics(opt *config.Options, wg *sync.WaitGroup) error {
if opt.MetricsAddr != "" {
if handler, err := metrics.PrometheusHandler(); err != nil {
log.Error().Err(err).Msg("cmd/pomerium: metrics failed to start")
} else {
metrics.SetBuildInfo(opt.Services)
metrics.RegisterInfoMetrics()
serverOpts := &httputil.ServerOptions{Addr: opt.MetricsAddr}
srv := httputil.NewHTTPServer(serverOpts, handler)
go httputil.Shutdown(srv)
handler, err := metrics.PrometheusHandler()
if err != nil {
return err
}
metrics.SetBuildInfo(opt.Services)
metrics.RegisterInfoMetrics()
serverOpts := &httputil.ServerOptions{Addr: opt.MetricsAddr}
srv, err := httputil.NewServer(serverOpts, handler, wg)
if err != nil {
return err
}
go httputil.Shutdown(srv)
}
return nil
}
func setupTracing(opt *config.Options) {
func setupTracing(opt *config.Options) error {
if opt.TracingProvider != "" {
tracingOpts := &trace.TracingOptions{
Provider: opt.TracingProvider,
@ -184,25 +202,31 @@ func setupTracing(opt *config.Options) {
JaegerCollectorEndpoint: opt.TracingJaegerCollectorEndpoint,
}
if err := trace.RegisterTracing(tracingOpts); err != nil {
log.Error().Err(err).Msg("cmd/pomerium: couldn't register tracing")
} else {
log.Info().Interface("options", tracingOpts).Msg("cmd/pomerium: metrics configured")
return err
}
}
return nil
}
func setupHTTPRedirectServer(opt *config.Options) {
func setupHTTPRedirectServer(opt *config.Options, wg *sync.WaitGroup) error {
if opt.HTTPRedirectAddr != "" {
serverOpts := httputil.ServerOptions{Addr: opt.HTTPRedirectAddr}
srv := httputil.NewHTTPServer(&serverOpts, httputil.RedirectHandler())
srv, err := httputil.NewServer(&serverOpts, httputil.RedirectHandler(), wg)
if err != nil {
return err
}
go httputil.Shutdown(srv)
}
return nil
}
func setupGRPCServer(opt *config.Options) *grpc.Server {
grpcAuth := middleware.NewSharedSecretCred(opt.SharedKey)
grpcOpts := []grpc.ServerOption{
grpc.UnaryInterceptor(grpcAuth.ValidateRequest),
grpc.StatsHandler(metrics.NewGRPCServerStatsHandler(opt.Services))}
return grpc.NewServer(grpcOpts...)
func httpServerOptions(opt *config.Options) *httputil.ServerOptions {
return &httputil.ServerOptions{
Addr: opt.Addr,
TLSCertificate: opt.TLSCertificate,
ReadTimeout: opt.ReadTimeout,
WriteTimeout: opt.WriteTimeout,
ReadHeaderTimeout: opt.ReadHeaderTimeout,
IdleTimeout: opt.IdleTimeout,
}
}

View file

@ -3,12 +3,12 @@ package main
import (
"fmt"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"os"
"os/signal"
"reflect"
"sync"
"syscall"
"testing"
"time"
@ -16,151 +16,8 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/pomerium/pomerium/internal/config"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/middleware"
"google.golang.org/grpc"
)
func Test_newAuthenticateService(t *testing.T) {
mux := httputil.NewRouter()
tests := []struct {
name string
s string
Field string
Value string
wantHostname string
wantErr bool
}{
{"wrong service", "proxy", "", "", "", false},
{"bad", "authenticate", "SharedKey", "error!", "", true},
{"good", "authenticate", "ClientID", "test", "auth.server.com", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testOpts, err := config.NewMinimalOptions("https://authenticate.example", "https://authorize.example")
if err != nil {
t.Fatal(err)
}
testOpts.Provider = "google"
testOpts.ClientSecret = "TEST"
testOpts.SharedKey = "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="
testOpts.CookieSecret = "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="
testOpts.Services = tt.s
if tt.Field != "" {
testOptsField := reflect.ValueOf(testOpts).Elem().FieldByName(tt.Field)
testOptsField.Set(reflect.ValueOf(tt).FieldByName("Value"))
}
_, err = newAuthenticateService(*testOpts, mux)
if (err != nil) != tt.wantErr {
t.Errorf("newAuthenticateService() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}
func Test_newAuthorizeService(t *testing.T) {
os.Clearenv()
grpcAuth := middleware.NewSharedSecretCred("test")
grpcOpts := []grpc.ServerOption{grpc.UnaryInterceptor(grpcAuth.ValidateRequest)}
grpcServer := grpc.NewServer(grpcOpts...)
tests := []struct {
name string
s string
Field string
Value string
wantErr bool
}{
{"wrong service", "proxy", "", "", false},
{"bad option parsing", "authorize", "SharedKey", "false", true},
{"good", "authorize", "SharedKey", "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testOpts, err := config.NewMinimalOptions("https://some.example", "https://some.example")
if err != nil {
t.Fatal(err)
}
testOpts.Services = tt.s
testOpts.CookieSecret = "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="
testPolicy := config.Policy{From: "http://some.example", To: "https://some.example"}
if err := testPolicy.Validate(); err != nil {
t.Fatal(err)
}
testOpts.Policies = []config.Policy{
testPolicy,
}
if tt.Field != "" {
testOptsField := reflect.ValueOf(testOpts).Elem().FieldByName(tt.Field)
testOptsField.Set(reflect.ValueOf(tt).FieldByName("Value"))
}
_, err = newAuthorizeService(*testOpts, grpcServer)
if (err != nil) != tt.wantErr {
t.Errorf("newAuthorizeService() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}
func Test_newProxyeService(t *testing.T) {
os.Clearenv()
tests := []struct {
name string
s string
Field string
Value string
wantErr bool
}{
{"wrong service", "authenticate", "", "", false},
{"bad option parsing", "proxy", "SharedKey", "false", true},
{"good", "proxy", "SharedKey", "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mux := httputil.NewRouter()
testOpts, err := config.NewMinimalOptions("https://authenticate.example", "https://authorize.example")
if err != nil {
t.Fatal(err)
}
testPolicy := config.Policy{From: "http://some.example", To: "http://some.example"}
if err := testPolicy.Validate(); err != nil {
t.Fatal(err)
}
testOpts.Policies = []config.Policy{
testPolicy,
}
AuthenticateURL, _ := url.Parse("https://authenticate.example.com")
AuthorizeURL, _ := url.Parse("https://authorize.example.com")
testOpts.AuthenticateURL = AuthenticateURL
testOpts.AuthorizeURL = AuthorizeURL
testOpts.CookieSecret = "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="
testOpts.Services = tt.s
if tt.Field != "" {
testOptsField := reflect.ValueOf(testOpts).Elem().FieldByName(tt.Field)
testOptsField.Set(reflect.ValueOf(tt).FieldByName("Value"))
}
_, err = newProxyService(*testOpts, mux)
if (err != nil) != tt.wantErr {
t.Errorf("newProxyService() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}
func Test_newGlobalRouter(t *testing.T) {
o := config.Options{
Services: "all",
@ -192,7 +49,7 @@ func Test_newGlobalRouter(t *testing.T) {
}
}
func Test_configToServerOptions(t *testing.T) {
func Test_httpServerOptions(t *testing.T) {
tests := []struct {
name string
opt *config.Options
@ -202,25 +59,8 @@ func Test_configToServerOptions(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if diff := cmp.Diff(configToServerOptions(tt.opt), tt.want); diff != "" {
t.Errorf("configToServerOptions() = \n %s", diff)
}
})
}
}
func Test_setupGRPCServer(t *testing.T) {
tests := []struct {
name string
opt *config.Options
dontWant *grpc.Server
}{
{"good", &config.Options{SharedKey: "test"}, nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if diff := cmp.Diff(setupGRPCServer(tt.opt), tt.dontWant); diff == "" {
t.Errorf("setupGRPCServer() = \n %s", diff)
if diff := cmp.Diff(httpServerOptions(tt.opt), tt.want); diff != "" {
t.Errorf("httpServerOptions() = \n %s", diff)
}
})
}
@ -255,7 +95,9 @@ func Test_setupMetrics(t *testing.T) {
c := make(chan os.Signal, 1)
signal.Notify(c, syscall.SIGINT)
defer signal.Stop(c)
setupMetrics(tt.opt)
var wg sync.WaitGroup
setupMetrics(tt.opt, &wg)
syscall.Kill(syscall.Getpid(), syscall.SIGINT)
waitSig(t, c, syscall.SIGINT)
@ -265,18 +107,26 @@ func Test_setupMetrics(t *testing.T) {
func Test_setupHTTPRedirectServer(t *testing.T) {
tests := []struct {
name string
opt *config.Options
name string
opt *config.Options
wantErr bool
}{
{"dont register aything", &config.Options{}},
{"good redirect server", &config.Options{HTTPRedirectAddr: "localhost:0"}},
{"dont register anything", &config.Options{}, false},
{"good redirect server", &config.Options{HTTPRedirectAddr: "localhost:0"}, false},
{"bad redirect server port", &config.Options{HTTPRedirectAddr: "localhost:-1"}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := make(chan os.Signal, 1)
var wg sync.WaitGroup
signal.Notify(c, syscall.SIGINT)
defer signal.Stop(c)
setupHTTPRedirectServer(tt.opt)
err := setupHTTPRedirectServer(tt.opt, &wg)
if (err != nil) != tt.wantErr {
t.Errorf("run() error = %v, wantErr %v", err, tt.wantErr)
}
syscall.Kill(syscall.Getpid(), syscall.SIGINT)
waitSig(t, c, syscall.SIGINT)
@ -294,3 +144,167 @@ func waitSig(t *testing.T, c <-chan os.Signal, sig os.Signal) {
t.Fatalf("timeout waiting for %v", sig)
}
}
func Test_run(t *testing.T) {
os.Clearenv()
t.Parallel()
tests := []struct {
name string
versionFlag bool
configFileFlag string
wantErr bool
}{
{"simply print version", true, "", false},
{"nil configuration", false, "", true},
{"simple proxy", false, `
{
"address": ":9433",
"grpc_address": ":9444",
"grpc_insecure": true,
"insecure_server": true,
"authorize_service_url": "https://authorize.corp.example",
"authenticate_service_url": "https://authenticate.corp.example",
"shared_secret": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=",
"cookie_secret": "zixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=",
"services": "proxy",
"policy": [{ "from": "https://pomerium.io", "to": "https://httpbin.org" }]
}
`, false},
{"simple authorize", false, `
{
"address": ":9433",
"grpc_address": ":9444",
"grpc_insecure": false,
"insecure_server": true,
"authorize_service_url": "https://authorize.corp.example",
"authenticate_service_url": "https://authenticate.corp.example",
"shared_secret": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=",
"cookie_secret": "zixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=",
"services": "authorize",
"policy": [{ "from": "https://pomerium.io", "to": "https://httpbin.org" }]
}
`, false},
{"bad proxy no authenticate url", false, `
{
"address": ":9433",
"grpc_address": ":9444",
"insecure_server": true,
"authorize_service_url": "https://authorize.corp.example",
"shared_secret": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=",
"cookie_secret": "zixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=",
"services": "proxy",
"policy": [{ "from": "https://pomerium.io", "to": "https://httpbin.org" }]
}
`, true},
{"bad authenticate no cookie secret", false, `
{
"address": ":9433",
"grpc_address": ":9444",
"insecure_server": true,
"authenticate_service_url": "https://authenticate.corp.example",
"shared_secret": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=",
"services": "authenticate",
"policy": [{ "from": "https://pomerium.io", "to": "https://httpbin.org" }]
}
`, true},
{"bad authorize service bad shared key", false, `
{
"address": ":9433",
"grpc_address": ":9444",
"insecure_server": true,
"authorize_service_url": "https://authorize.corp.example",
"shared_secret": "^^^",
"cookie_secret": "zixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=",
"services": "authorize",
"policy": [{ "from": "https://pomerium.io", "to": "https://httpbin.org" }]
}
`, true},
{"bad http port", false, `
{
"address": ":-1",
"grpc_address": ":9444",
"grpc_insecure": true,
"insecure_server": true,
"authorize_service_url": "https://authorize.corp.example",
"authenticate_service_url": "https://authenticate.corp.example",
"shared_secret": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=",
"cookie_secret": "zixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=",
"services": "proxy",
"policy": [{ "from": "https://pomerium.io", "to": "https://httpbin.org" }]
}
`, true},
{"bad redirect port", false, `
{
"address": ":9433",
"http_redirect_addr":":-1",
"grpc_address": ":9444",
"grpc_insecure": true,
"insecure_server": true,
"authorize_service_url": "https://authorize.corp.example",
"authenticate_service_url": "https://authenticate.corp.example",
"shared_secret": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=",
"cookie_secret": "zixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=",
"services": "proxy",
"policy": [{ "from": "https://pomerium.io", "to": "https://httpbin.org" }]
}
`, true},
{"bad metrics port ", false, `
{
"address": ":9433",
"metrics_address": ":-1",
"grpc_insecure": true,
"insecure_server": true,
"authorize_service_url": "https://authorize.corp.example",
"authenticate_service_url": "https://authenticate.corp.example",
"shared_secret": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=",
"cookie_secret": "zixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=",
"services": "proxy",
"policy": [{ "from": "https://pomerium.io", "to": "https://httpbin.org" }]
}
`, true},
{"malformed tracing provider", false, `
{
"tracing_provider": "bad tracing provider",
"address": ":9433",
"grpc_address": ":9444",
"grpc_insecure": true,
"insecure_server": true,
"authorize_service_url": "https://authorize.corp.example",
"authenticate_service_url": "https://authenticate.corp.example",
"shared_secret": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=",
"cookie_secret": "zixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=",
"services": "proxy",
"policy": [{ "from": "https://pomerium.io", "to": "https://httpbin.org" }]
}
`, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
versionFlag = &tt.versionFlag
tmpFile, err := ioutil.TempFile(os.TempDir(), "*.json")
if err != nil {
t.Fatal("Cannot create temporary file", err)
}
defer os.Remove(tmpFile.Name())
fn := tmpFile.Name()
if _, err := tmpFile.Write([]byte(tt.configFileFlag)); err != nil {
tmpFile.Close()
t.Fatal(err)
}
configFile = &fn
proc, err := os.FindProcess(os.Getpid())
if err != nil {
t.Fatal(err)
}
go func() {
time.Sleep(time.Millisecond * 500)
proc.Signal(os.Interrupt)
}()
err = run()
if (err != nil) != tt.wantErr {
t.Errorf("run() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

View file

@ -17,6 +17,7 @@
- Fixed an issue where CSRF would fail if multiple tabs were open. [GH-306](https://github.com/pomerium/pomerium/issues/306)
- Fixed an issue where pomerium would clean double slashes from paths.[GH-262](https://github.com/pomerium/pomerium/issues/262)
- Fixed a bug where the impersonate form would persist an empty string for groups value if none set.[GH-303](https://github.com/pomerium/pomerium/issues/303)
### Changed
@ -24,15 +25,13 @@
- Authenticate service no longer uses gRPC.
- The global request logger now captures the full array of proxies from `X-Forwarded-For`, in addition to just the client IP.
- Options code refactored to eliminate global Viper state. [GH-332](https://github.com/pomerium/pomerium/pull/332/files)
- Pomerium will no longer default to looking for certificates in the root directory. [GH-328](https://github.com/pomerium/pomerium/issues/328)
- Pomerium will validate that either `insecure_server`, or a valid certificate bundle is set. [GH-328](https://github.com/pomerium/pomerium/issues/328)
### Removed
- Removed `AUTHENTICATE_INTERNAL_URL`/`authenticate_internal_url` which is no longer used.
## Fixed
- Fixed a bug where the impersonate form would persist an empty string for groups value if none set.[GH-303](https://github.com/pomerium/pomerium/issues/303)
## v0.3.0
### New

View file

@ -45,7 +45,7 @@ Service mode sets the pomerium service(s) to run. If testing, you may want to se
- Default: `:443`
- Required
Address specifies the host and port to serve HTTPS and gRPC requests from. If empty, `:443` is used.
Address specifies the host and port to serve HTTP requests from. If empty, `:443` is used.
## Administrators
@ -112,6 +112,21 @@ If `false`
Log level sets the global logging level for pomerium. Only logs of the desired level and above will be logged.
## Insecure Server
- Environmental Variable: `INSECURE_SERVER`
- Config File Key: `insecure_server`
- Type: `bool`
- Required if certificates unset
Turning on insecure server mode will result in pomerium starting, and operating without any protocol encryption in transit.
This setting can be useful in a situation where you have Pomerium behind a TLS terminating ingress or proxy. However, even in that case, it is highly recommended to use TLS to protect the confidentiality and integrity of service communication even behind the ingress using self-signed certificates or an internal CA. Please see our helm-chart for an example of just that.
:::warning
Pomerium should _never_ be exposed to the internet without TLS encryption.
:::
## Certificate
- Environmental Variable: either `CERTIFICATE` or `CERTIFICATE_FILE`
@ -119,7 +134,7 @@ Log level sets the global logging level for pomerium. Only logs of the desired l
- Type: [base64 encoded] `string` or relative file location
- Required
Certificate is the x509 _public-key_ used to establish secure HTTP and gRPC connections. If unset, pomerium will attempt to find and use `./cert.pem`.
Certificate is the x509 _public-key_ used to establish secure HTTP and gRPC connections.
## Certificate Key
@ -128,7 +143,7 @@ Certificate is the x509 _public-key_ used to establish secure HTTP and gRPC conn
- Type: [base64 encoded] `string`
- Required
Certificate key is the x509 _private-key_ used to establish secure HTTP and gRPC connections. If unset, pomerium will attempt to find and use `./privkey.pem`.
Certificate key is the x509 _private-key_ used to establish secure HTTP and gRPC connections.
## Global Timeouts
@ -148,9 +163,28 @@ Timeouts set the global server timeouts. For route-specific timeouts, see [polic
These settings control upstream connections to the Authorize service.
## GRPC Address
- Environmental Variable: `GRPC_ADDRESS`
- Config File Key: `grpc_address`
- Type: `string`
- Example: `:443`, `:8443`
- Default: `:443` or `:5443` if in all-in-one mode
Address specifies the host and port to serve GRPC requests from. Defaults to `:443` (or `:5443` in all in one mode).
## GRPC Insecure
- Environmental Variable: `GRPC_INSECURE`
- Config File Key: `grpc_insecure`
- Type: `bool`
- Default: `:443` (or `:5443` if in all-in-one mode)
If set, GRPC Insecure disables transport security for communication between the proxy and authorize components. If running in all-in-one mode, defaults to true as communication will run over localhost's own socket.
### GRPC Client Timeout
Maxmimum time before canceling an upstream RPC request. During transient failures, the proxy will retry upstreams for this duration, if possible. You should leave this high enough to handle backend service restart and rediscovery so that client requests do not fail.
Maximum time before canceling an upstream RPC request. During transient failures, the proxy will retry upstreams for this duration, if possible. You should leave this high enough to handle backend service restart and rediscovery so that client requests do not fail.
- Environmental Variable: `GRPC_CLIENT_TIMEOUT`
- Config File Key: `grpc_client_timeout`

View file

@ -7,6 +7,12 @@ description: >-
# Upgrade Guide
## Since 0.3.0
### Breaking: No default certificate location
In previous versions, if no explicit certificate pair (in base64 or file form) was set, Pomerium would make a last ditch effort to check for certificate files (`cert.key`/`privkey.pem`) in the root directory. With the introduction of insecure server configuration, we've removed that functionality. If there settings for certificates and insecure server mode are unset, pomerium will give a appropriate error instead of a failed to find/open certificate error.
## Since 0.2.0
Pomerium `v0.3.0` has no known breaking changes compared to `v0.2.0`.

7
go.mod
View file

@ -21,17 +21,18 @@ require (
github.com/rs/zerolog v1.14.3
github.com/spf13/afero v1.2.2 // indirect
github.com/spf13/jwalterweatherman v1.1.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/spf13/viper v1.4.0
github.com/stretchr/testify v1.3.0 // indirect
github.com/stretchr/testify v1.4.0 // indirect
go.opencensus.io v0.22.0
golang.org/x/crypto v0.0.0-20190611184440-5c40567a22f8
golang.org/x/net v0.0.0-20190611141213-3f473d35a33a
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45
golang.org/x/sys v0.0.0-20190610200419-93c9922d18ae // indirect
golang.org/x/sys v0.0.0-20190927073244-c990c680b611 // indirect
google.golang.org/api v0.6.0
google.golang.org/appengine v1.6.1 // indirect
google.golang.org/genproto v0.0.0-20190611190212-a7e196e89fd3 // indirect
google.golang.org/grpc v1.22.0
gopkg.in/square/go-jose.v2 v2.3.1
gopkg.in/yaml.v2 v2.2.2
gopkg.in/yaml.v2 v2.2.3
)

12
go.sum
View file

@ -171,13 +171,15 @@ github.com/spf13/jwalterweatherman v1.1.0 h1:ue6voC5bR5F8YxI5S67j9i582FU4Qvo2bmq
github.com/spf13/jwalterweatherman v1.1.0/go.mod h1:aNWZUN0dPAAO/Ljvb5BEdw96iTZ0EXowPYD95IqWIGo=
github.com/spf13/pflag v1.0.3 h1:zPAT6CGy6wXeQ7NtTnaTerfKOsV6V6F8agHXFiazDkg=
github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.4.0 h1:yXHLWeravcrgGyFSyCgdYpXQ9dR9c/WED3pg1RhxqEU=
github.com/spf13/viper v1.4.0/go.mod h1:PTJ7Z/lr49W6bUbkmS1V3by4uWynFiR9p7+dSq/yZzE=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc=
github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU=
@ -245,8 +247,8 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190610200419-93c9922d18ae h1:xiXzMMEQdQcric9hXtr1QU98MHunKK7OTtsoU6bYWs4=
golang.org/x/sys v0.0.0-20190610200419-93c9922d18ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190927073244-c990c680b611 h1:q9u40nxWT5zRClI/uU9dHCiYGottAg6Nzz4YUQyHxdA=
golang.org/x/sys v0.0.0-20190927073244-c990c680b611/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@ -305,6 +307,8 @@ gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bl
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.3 h1:fvjTMHxHEw/mxHbtzPi3JCcKXQRAnQTBRo6YCJSVHKI=
gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099 h1:XJP7lxbSxWLOMNdBE4B/STaqVy6L73o0knwj2vIlxnw=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=

View file

@ -1,13 +1,24 @@
package config // import "github.com/pomerium/pomerium/internal/config"
const (
// ServiceAll represents running all services in "all-in-one" mode
ServiceAll = "all"
// ServiceProxy represents running the proxy service component
ServiceProxy = "proxy"
// ServiceAuthorize represents running the authorize service component
ServiceAuthorize = "authorize"
// ServiceAuthenticate represents running the authenticate service component
ServiceAuthenticate = "authenticate"
)
// IsValidService checks to see if a service is a valid service mode
func IsValidService(s string) bool {
switch s {
case
"all",
"proxy",
"authorize",
"authenticate":
ServiceAll,
ServiceProxy,
ServiceAuthorize,
ServiceAuthenticate:
return true
}
return false
@ -17,8 +28,8 @@ func IsValidService(s string) bool {
func IsAuthenticate(s string) bool {
switch s {
case
"all",
"authenticate":
ServiceAll,
ServiceAuthenticate:
return true
}
return false
@ -28,8 +39,8 @@ func IsAuthenticate(s string) bool {
func IsAuthorize(s string) bool {
switch s {
case
"all",
"authorize":
ServiceAll,
ServiceAuthorize:
return true
}
return false
@ -39,9 +50,14 @@ func IsAuthorize(s string) bool {
func IsProxy(s string) bool {
switch s {
case
"all",
"proxy":
ServiceAll,
ServiceProxy:
return true
}
return false
}
// IsAll checks to see if we should be running all services
func IsAll(s string) bool {
return s == ServiceAll
}

View file

@ -1,18 +1,17 @@
package config // import "github.com/pomerium/pomerium/internal/config"
import (
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
"net/url"
"path/filepath"
"reflect"
"strconv"
"strings"
"time"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/fileutil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/metrics"
"github.com/pomerium/pomerium/internal/urlutil"
@ -25,8 +24,8 @@ import (
// DisableHeaderKey is the key used to check whether to disable setting header
const DisableHeaderKey = "disable"
// Options are the global environmental flags used to set up pomerium's services. Use NewXXXOptions() methods
// for a safely initialized data structure.
// Options are the global environmental flags used to set up pomerium's services.
// Use NewXXXOptions() methods for a safely initialized data structure.
type Options struct {
// Debug outputs human-readable logs to Stdout.
Debug bool `mapstructure:"pomerium_debug"`
@ -47,14 +46,22 @@ type Options struct {
// HTTPS requests. If empty, ":443" (localhost:443) is used.
Addr string `mapstructure:"address"`
// Cert and Key specifies the TLS certificates to use.
// InsecureServer when enabled disables all transport security.
// In this mode, Pomerium is susceptible to man-in-the-middle attacks.
// This should be used only for testing.
InsecureServer bool `mapstructure:"insecure_server"`
// Cert and Key is the x509 certificate used to hydrate TLSCertificate
Cert string `mapstructure:"certificate"`
Key string `mapstructure:"certificate_key"`
// CertFile and KeyFile specifies the TLS certificates to use.
// CertFile and KeyFile is the x509 certificate used to hydrate TLSCertificate
CertFile string `mapstructure:"certificate_file"`
KeyFile string `mapstructure:"certificate_key_file"`
// TLSCertificate is the hydrated tls.Certificate.
TLSCertificate *tls.Certificate
// HttpRedirectAddr, if set, specifies the host and port to run the HTTP
// to HTTPS redirect server on. If empty, no redirect server is started.
HTTPRedirectAddr string `mapstructure:"http_redirect_addr"`
@ -131,7 +138,7 @@ type Options struct {
TracingDebug bool `mapstructure:"tracing_debug"`
// Jaeger
//
// CollectorEndpoint is the full url to the Jaeger HTTP Thrift collector.
// For example, http://localhost:14268/api/traces
TracingJaegerCollectorEndpoint string `mapstructure:"tracing_jaeger_collector_endpoint"`
@ -140,13 +147,22 @@ type Options struct {
TracingJaegerAgentEndpoint string `mapstructure:"tracing_jaeger_agent_endpoint"`
// GRPC Service Settings
// GRPCAddr specifies the host and port on which the server should serve
// gRPC requests. If running in all-in-one mode, ":5443" (localhost:5443) is used.
GRPCAddr string `mapstructure:"grpc_address"`
// GRPCInsecure disables transport security.
// If running in all-in-one mode, defaults to true.
GRPCInsecure bool `mapstructure:"grpc_insecure"`
GRPCClientTimeout time.Duration `mapstructure:"grpc_client_timeout"`
GRPCClientDNSRoundRobin bool `mapstructure:"grpc_client_dns_roundrobin"`
// Scoped viper instance
viper *viper.Viper
}
// DefaultOptions are the default configuration options for pomerium
var defaultOptions = Options{
Debug: false,
LogLevel: "debug",
@ -164,47 +180,50 @@ var defaultOptions = Options{
"Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload",
},
Addr: ":443",
CertFile: filepath.Join(fileutil.Getwd(), "cert.pem"),
KeyFile: filepath.Join(fileutil.Getwd(), "privkey.pem"),
ReadHeaderTimeout: 10 * time.Second,
ReadTimeout: 30 * time.Second,
WriteTimeout: 0, // support streaming by default
IdleTimeout: 5 * time.Minute,
RefreshCooldown: 5 * time.Minute,
GRPCAddr: ":443",
GRPCClientTimeout: 10 * time.Second, // Try to withstand transient service failures for a single request
GRPCClientDNSRoundRobin: true,
}
// NewOptions creates a new Options struct with only viper initialized
func NewOptions() *Options {
o := Options{}
o.viper = viper.New()
return &o
}
// NewDefaultOptions returns an Options struct with defaults set and viper initialized
// NewDefaultOptions returns a copy the default options. It's the caller's
// responsibility to do a follow up Validate call.
func NewDefaultOptions() *Options {
o := defaultOptions
o.viper = viper.New()
return &o
newOpts := defaultOptions
newOpts.viper = viper.New()
return &newOpts
}
// NewMinimalOptions returns a minimal options configuration built from default options.
// Any modifications to the structure should be followed up by a subsequent
// call to validate.
func NewMinimalOptions(authenticateURL, authorizeURL string) (*Options, error) {
o := NewDefaultOptions()
o.AuthenticateURLString = authenticateURL
o.AuthorizeURLString = authorizeURL
if err := o.Validate(); err != nil {
return nil, fmt.Errorf("internal/config: validation error %s", err)
// NewOptionsFromConfig builds the main binary's configuration options by parsing
// environmental variables and config file
func NewOptionsFromConfig(configFile string) (*Options, error) {
o, err := optionsFromViper(configFile)
if err != nil {
return nil, fmt.Errorf("internal/config: options from viper %w", err)
}
if o.Debug {
log.SetDebugMode()
}
if o.LogLevel != "" {
log.SetLevel(o.LogLevel)
}
metrics.AddPolicyCountCallback(o.Services, func() int64 {
return int64(len(o.Policies))
})
checksumDec, err := strconv.ParseUint(o.Checksum(), 16, 64)
if err != nil {
log.Warn().Err(err).Msg("internal/config: could not parse config checksum into decimal")
}
metrics.SetConfigChecksum(o.Services, checksumDec)
return o, nil
}
// OptionsFromViper builds the main binary's configuration
// options by parsing environmental variables and config file
func OptionsFromViper(configFile string) (*Options, error) {
func optionsFromViper(configFile string) (*Options, error) {
// start a copy of the default options
o := NewDefaultOptions()
// New viper instance to save into Options later
@ -218,71 +237,22 @@ func OptionsFromViper(configFile string) (*Options, error) {
if configFile != "" {
v.SetConfigFile(configFile)
if err := v.ReadInConfig(); err != nil {
return nil, fmt.Errorf("internal/config: failed to read config: %s", err)
return nil, fmt.Errorf("failed to read config: %w", err)
}
}
if err := v.Unmarshal(&o); err != nil {
return nil, fmt.Errorf("internal/config: failed to unmarshal config: %s", err)
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
}
o.viper = v
if err := o.Validate(); err != nil {
return nil, fmt.Errorf("internal/config: validation error %s", err)
return nil, fmt.Errorf("validation error %w", err)
}
return o, nil
}
// Validate ensures the Options fields are properly formed, present, and hydrated.
func (o *Options) Validate() error {
if !IsValidService(o.Services) {
return fmt.Errorf("%s is an invalid service type", o.Services)
}
// shared key must be set for all modes other than "all"
if o.SharedKey == "" {
if o.Services == "all" {
o.SharedKey = cryptutil.NewBase64Key()
} else {
return errors.New("shared-key cannot be empty")
}
}
if o.AuthenticateURLString != "" {
u, err := urlutil.ParseAndValidateURL(o.AuthenticateURLString)
if err != nil {
return fmt.Errorf("bad authenticate-url %s : %v", o.AuthenticateURLString, err)
}
o.AuthenticateURL = u
}
if o.AuthorizeURLString != "" {
u, err := urlutil.ParseAndValidateURL(o.AuthorizeURLString)
if err != nil {
return fmt.Errorf("bad authorize-url %s : %v", o.AuthorizeURLString, err)
}
o.AuthorizeURL = u
}
if o.PolicyFile != "" {
return errors.New("policy file setting is deprecated")
}
if err := o.parsePolicy(); err != nil {
return fmt.Errorf("failed to parse policy: %s", err)
}
if err := o.parseHeaders(); err != nil {
return fmt.Errorf("failed to parse headers: %s", err)
}
if _, disable := o.Headers[DisableHeaderKey]; disable {
o.Headers = make(map[string]string)
}
return nil
}
// parsePolicy initializes policy to the options from either base64 environmental
// variables or from a file
func (o *Options) parsePolicy() error {
@ -291,12 +261,12 @@ func (o *Options) parsePolicy() error {
if o.PolicyEnv != "" {
policyBytes, err := base64.StdEncoding.DecodeString(o.PolicyEnv)
if err != nil {
return fmt.Errorf("could not decode POLICY env var: %s", err)
return fmt.Errorf("could not decode POLICY env var: %w", err)
}
if err := yaml.Unmarshal(policyBytes, &policies); err != nil {
return fmt.Errorf("could not unmarshal policy yaml: %s", err)
return fmt.Errorf("could not unmarshal policy yaml: %w", err)
}
} else if err := o.viper.UnmarshalKey("policy", &policies); err != nil {
} else if err := o.viperUnmarshalKey("policy", &policies); err != nil {
return err
}
if len(policies) != 0 {
@ -311,6 +281,18 @@ func (o *Options) parsePolicy() error {
return nil
}
func (o *Options) viperUnmarshalKey(key string, rawVal interface{}) error {
return o.viper.UnmarshalKey(key, &rawVal)
}
func (o *Options) viperSet(key string, value interface{}) {
o.viper.Set(key, value)
}
func (o *Options) viperIsSet(key string) bool {
return o.viper.IsSet(key)
}
// parseHeaders handles unmarshalling any custom headers correctly from the
// environment or viper's parsed keys
func (o *Options) parseHeaders() error {
@ -333,8 +315,8 @@ func (o *Options) parseHeaders() error {
}
o.Headers = headers
} else if o.viper.IsSet("headers") {
if err := o.viper.UnmarshalKey("headers", &headers); err != nil {
} else if o.viperIsSet("headers") {
if err := o.viperUnmarshalKey("headers", &headers); err != nil {
return fmt.Errorf("header %s failed to parse: %s", o.viper.Get("headers"), err)
}
o.Headers = headers
@ -342,7 +324,8 @@ func (o *Options) parseHeaders() error {
return nil
}
// bindEnvs binds a viper instance to each env var of an Options struct based on the mapstructure tag
// bindEnvs binds a viper instance to each env var of an Options struct based
// on the mapstructure tag
func bindEnvs(o *Options, v *viper.Viper) error {
tagName := `mapstructure`
t := reflect.TypeOf(*o)
@ -370,6 +353,81 @@ func bindEnvs(o *Options, v *viper.Viper) error {
return nil
}
// Validate ensures the Options fields are valid, and hydrated.
func (o *Options) Validate() error {
var err error
if !IsValidService(o.Services) {
return fmt.Errorf("internal/config: %s is an invalid service type", o.Services)
}
if IsAll(o.Services) {
// mutual auth between services on the same host can be generated at runtime
if o.SharedKey == "" {
o.SharedKey = cryptutil.NewBase64Key()
}
// in all in one mode we are running just over the local socket
o.GRPCInsecure = true
// to avoid port collision when running on localhost
if o.GRPCAddr == defaultOptions.GRPCAddr {
o.GRPCAddr = ":5443"
}
// and we can set the corresponding client
if o.AuthorizeURLString == "" {
o.AuthorizeURLString = "https://localhost:5443"
}
}
if o.SharedKey == "" {
return errors.New("internal/config: shared-key cannot be empty")
}
if o.AuthenticateURLString != "" {
u, err := urlutil.ParseAndValidateURL(o.AuthenticateURLString)
if err != nil {
return fmt.Errorf("internal/config: bad authenticate-url %s : %v", o.AuthenticateURLString, err)
}
o.AuthenticateURL = u
}
if o.AuthorizeURLString != "" {
u, err := urlutil.ParseAndValidateURL(o.AuthorizeURLString)
if err != nil {
return fmt.Errorf("internal/config: bad authorize-url %s : %w", o.AuthorizeURLString, err)
}
o.AuthorizeURL = u
}
if o.PolicyFile != "" {
return errors.New("internal/config: policy file setting is deprecated")
}
if err := o.parsePolicy(); err != nil {
return fmt.Errorf("internal/config: failed to parse policy: %w", err)
}
if err := o.parseHeaders(); err != nil {
return fmt.Errorf("internal/config: failed to parse headers: %w", err)
}
if _, disable := o.Headers[DisableHeaderKey]; disable {
o.Headers = make(map[string]string)
}
if o.InsecureServer {
log.Warn().Msg("internal/config: insecure mode enabled")
} else if o.Cert != "" || o.Key != "" {
o.TLSCertificate, err = cryptutil.CertifcateFromBase64(o.Cert, o.Key)
} else if o.CertFile != "" || o.KeyFile != "" {
o.TLSCertificate, err = cryptutil.CertificateFromFile(o.CertFile, o.KeyFile)
} else {
err = errors.New("internal/config:no certificates supplied nor was insecure mode set")
}
if err != nil {
return err
}
return nil
}
// OptionsUpdater updates local state based on an Options struct
type OptionsUpdater interface {
UpdateOptions(Options) error
@ -385,34 +443,10 @@ func (o *Options) Checksum() string {
return fmt.Sprintf("%x", hash)
}
func ParseOptions(configFile string) (*Options, error) {
o, err := OptionsFromViper(configFile)
if err != nil {
return nil, err
}
if o.Debug {
log.SetDebugMode()
}
if o.LogLevel != "" {
log.SetLevel(o.LogLevel)
}
metrics.AddPolicyCountCallback(o.Services, func() int64 {
return int64(len(o.Policies))
})
checksumDec, err := strconv.ParseUint(o.Checksum(), 16, 64)
if err != nil {
log.Warn().Err(err).Msg("internal/config: could not parse config checksum into decimal")
}
metrics.SetConfigChecksum(o.Services, checksumDec)
return o, nil
}
func HandleConfigUpdate(configFile string, opt *Options, services []OptionsUpdater) *Options {
newOpt, err := ParseOptions(configFile)
newOpt, err := NewOptionsFromConfig(configFile)
if err != nil {
log.Error().Err(err).Msg("config: could not reload configuration")
log.Error().Err(err).Msg("internal/config: could not reload configuration")
metrics.SetConfigInfo(opt.Services, false, "")
return opt
}
@ -426,16 +460,16 @@ func HandleConfigUpdate(configFile string, opt *Options, services []OptionsUpdat
return opt
}
errored := false
var updateFailed bool
for _, service := range services {
if err := service.UpdateOptions(*newOpt); err != nil {
log.Error().Err(err).Msg("internal/config: could not update options")
errored = true
updateFailed = true
metrics.SetConfigInfo(opt.Services, false, "")
}
}
if !errored {
if !updateFailed {
metrics.SetConfigInfo(newOpt.Services, true, newOptChecksum)
}
return newOpt

View file

@ -6,6 +6,7 @@ import (
"io/ioutil"
"net/url"
"os"
"sync"
"testing"
"github.com/google/go-cmp/cmp"
@ -15,14 +16,16 @@ import (
var cmpOptIgnoreUnexported = cmpopts.IgnoreUnexported(Options{})
func Test_validate(t *testing.T) {
func Test_Validate(t *testing.T) {
t.Parallel()
testOptions := func() Options {
testOptions := func() *Options {
o := NewDefaultOptions()
o.SharedKey = "test"
o.Services = "all"
return *o
o.CertFile = "./testdata/example-cert.pem"
o.KeyFile = "./testdata/example-key.pem"
return o
}
good := testOptions()
badServices := testOptions()
@ -38,7 +41,7 @@ func Test_validate(t *testing.T) {
tests := []struct {
name string
testOpts Options
testOpts *Options
wantErr bool
}{
{"good default with no env settings", good, false},
@ -51,7 +54,7 @@ func Test_validate(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
err := tt.testOpts.Validate()
if (err != nil) != tt.wantErr {
t.Errorf("optionsFromEnvConfig() error = %v, wantErr %v", err, tt.wantErr)
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
@ -59,7 +62,8 @@ func Test_validate(t *testing.T) {
}
func Test_bindEnvs(t *testing.T) {
o := NewOptions()
o := new(Options)
o.viper = viper.New()
v := viper.New()
os.Clearenv()
defer os.Unsetenv("POMERIUM_DEBUG")
@ -92,7 +96,7 @@ func Test_bindEnvs(t *testing.T) {
}
func Test_parseHeaders(t *testing.T) {
t.Parallel()
// t.Parallel()
tests := []struct {
name string
want map[string]string
@ -110,11 +114,16 @@ func Test_parseHeaders(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
o := NewDefaultOptions()
o.viper.Set("headers", tt.viperHeaders)
o.viper.Set("HeadersEnv", tt.envHeaders)
var (
o *Options
mu sync.Mutex
)
mu.Lock()
defer mu.Unlock()
o = NewDefaultOptions()
o.viperSet("headers", tt.viperHeaders)
o.viperSet("HeadersEnv", tt.envHeaders)
o.HeadersEnv = tt.envHeaders
err := o.parseHeaders()
if (err != nil) != tt.wantErr {
@ -129,130 +138,6 @@ func Test_parseHeaders(t *testing.T) {
}
func Test_OptionsFromViper(t *testing.T) {
testPolicy := Policy{
To: "https://httpbin.org",
From: "https://pomerium.io",
}
if err := testPolicy.Validate(); err != nil {
t.Fatal(err)
}
testPolicies := []Policy{
testPolicy,
}
goodConfigBytes := []byte(`{"authorize_service_url":"https://authorize.corp.example","authenticate_service_url":"https://authenticate.corp.example","shared_secret":"Setec Astronomy","service":"all","policy":[{"from":"https://pomerium.io","to":"https://httpbin.org"}]}`)
goodOptions := *(NewDefaultOptions())
goodOptions.SharedKey = "Setec Astronomy"
goodOptions.Services = "all"
goodOptions.Policies = testPolicies
goodOptions.CookieName = "oatmeal"
goodOptions.AuthorizeURLString = "https://authorize.corp.example"
goodOptions.AuthenticateURLString = "https://authenticate.corp.example"
authorize, err := url.Parse(goodOptions.AuthorizeURLString)
if err != nil {
t.Fatal(err)
}
authenticate, err := url.Parse(goodOptions.AuthenticateURLString)
if err != nil {
t.Fatal(err)
}
goodOptions.AuthorizeURL = authorize
goodOptions.AuthenticateURL = authenticate
if err := goodOptions.Validate(); err != nil {
t.Fatal(err)
}
badConfigBytes := []byte("badjson!")
badUnmarshalConfigBytes := []byte(`"debug": "blue"`)
tests := []struct {
name string
configBytes []byte
want *Options
wantErr bool
}{
{"good", goodConfigBytes, &goodOptions, false},
{"bad json", badConfigBytes, nil, true},
{"bad unmarshal", badUnmarshalConfigBytes, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
os.Clearenv()
os.Setenv("COOKIE_NAME", "oatmeal")
defer os.Unsetenv("COOKIE_NAME")
tempFile, _ := ioutil.TempFile("", "*.json")
defer tempFile.Close()
defer os.Remove(tempFile.Name())
tempFile.Write(tt.configBytes)
got, err := OptionsFromViper(tempFile.Name())
if (err != nil) != tt.wantErr {
t.Errorf("OptionsFromViper() error = \n%v, wantErr \n%v", err, tt.wantErr)
}
if tt.want != nil {
if err := tt.want.Validate(); err != nil {
t.Fatal(err)
}
}
if diff := cmp.Diff(got, tt.want, cmpOptIgnoreUnexported); diff != "" {
t.Errorf("OptionsFromViper() = \n%s\n, \ngot\n%+v\n, want \n%+v", diff, got, tt.want)
}
})
}
// Test for missing config file
_, err = OptionsFromViper("filedoesnotexist")
if err == nil {
t.Errorf("OptionsFromViper(): Did when loading missing file")
}
}
func Test_parsePolicyEnv(t *testing.T) {
t.Parallel()
source := "https://pomerium.io"
sourceURL, _ := url.ParseRequestURI(source)
dest := "https://httpbin.org"
destURL, _ := url.ParseRequestURI(dest)
tests := []struct {
name string
policyBytes []byte
want []Policy
wantErr bool
}{
{"simple json", []byte(fmt.Sprintf(`[{"from": "%s","to":"%s"}]`, source, dest)), []Policy{{From: source, To: dest, Source: sourceURL, Destination: destURL}}, false},
{"bad from", []byte(`[{"from": "%","to":"httpbin.org"}]`), []Policy{{From: "%", To: "httpbin.org"}}, true},
{"bad to", []byte(`[{"from": "pomerium.io","to":"%"}]`), []Policy{{From: "pomerium.io", To: "%"}}, true},
{"simple error", []byte(`{}`), nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
o := NewOptions()
o.PolicyEnv = base64.StdEncoding.EncodeToString(tt.policyBytes)
err := o.parsePolicy()
if (err != nil) != tt.wantErr {
t.Errorf("parsePolicyEnv() error = %v, wantErr %v", err, tt.wantErr)
return
}
if diff := cmp.Diff(o.Policies, tt.want); diff != "" {
t.Errorf("parsePolicyEnv() = %s", diff)
}
})
}
// Catch bad base64
o := NewOptions()
o.PolicyEnv = "foo"
err := o.parsePolicy()
if err == nil {
t.Errorf("parsePolicyEnv() did not catch bad base64 %v", o)
}
}
func Test_parsePolicyFile(t *testing.T) {
t.Parallel()
source := "https://pomerium.io"
@ -276,7 +161,8 @@ func Test_parsePolicyFile(t *testing.T) {
defer tempFile.Close()
defer os.Remove(tempFile.Name())
tempFile.Write(tt.policyBytes)
o := NewOptions()
var o Options
o.viper = viper.New()
o.viper.SetConfigFile(tempFile.Name())
if err := o.viper.ReadInConfig(); err != nil {
t.Fatal(err)
@ -316,36 +202,10 @@ func Test_Checksum(t *testing.T) {
}
}
func TestNewOptions(t *testing.T) {
t.Parallel()
tests := []struct {
name string
authenticateURL string
authorizeURL string
want *Options
wantErr bool
}{
{"good", "https://authenticate.example", "https://authorize.example", nil, false},
{"bad authenticate url no scheme", "authenticate.example", "https://authorize.example", nil, true},
{"bad authenticate url no host", "https://", "https://authorize.example", nil, true},
{"bad authorize url no scheme", "https://authenticate.example", "authorize.example", nil, true},
{"bad authorize url no host", "https://authenticate.example", "https://", nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := NewMinimalOptions(tt.authenticateURL, tt.authorizeURL)
if (err != nil) != tt.wantErr {
t.Errorf("NewOptions() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}
func TestOptionsFromViper(t *testing.T) {
t.Parallel()
opts := []cmp.Option{
cmpopts.IgnoreFields(Options{}, "DefaultUpstreamTimeout", "CookieRefresh", "CookieExpire", "Services", "Addr", "RefreshCooldown", "LogLevel", "KeyFile", "CertFile", "SharedKey", "ReadTimeout", "ReadHeaderTimeout", "IdleTimeout", "GRPCClientTimeout", "GRPCClientDNSRoundRobin"),
cmpopts.IgnoreFields(Options{}, "CookieSecret", "GRPCInsecure", "GRPCAddr", "AuthorizeURL", "AuthorizeURLString", "DefaultUpstreamTimeout", "CookieRefresh", "CookieExpire", "Services", "Addr", "RefreshCooldown", "LogLevel", "KeyFile", "CertFile", "SharedKey", "ReadTimeout", "ReadHeaderTimeout", "IdleTimeout", "GRPCClientTimeout", "GRPCClientDNSRoundRobin"),
cmpopts.IgnoreFields(Policy{}, "Source", "Destination"),
cmpOptIgnoreUnexported,
}
@ -357,11 +217,12 @@ func TestOptionsFromViper(t *testing.T) {
wantErr bool
}{
{"good",
[]byte(`{"policy":[{"from": "https://from.example","to":"https://to.example"}]}`),
[]byte(`{"insecure_server":true,"policy":[{"from": "https://from.example","to":"https://to.example"}]}`),
&Options{
Policies: []Policy{{From: "https://from.example", To: "https://to.example"}},
CookieName: "_pomerium",
CookieSecure: true,
InsecureServer: true,
CookieHTTPOnly: true,
Headers: map[string]string{
"Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload",
@ -371,12 +232,13 @@ func TestOptionsFromViper(t *testing.T) {
}},
false},
{"good disable header",
[]byte(`{"headers": {"disable":"true"},"policy":[{"from": "https://from.example","to":"https://to.example"}]}`),
[]byte(`{"insecure_server":true,"headers": {"disable":"true"},"policy":[{"from": "https://from.example","to":"https://to.example"}]}`),
&Options{
Policies: []Policy{{From: "https://from.example", To: "https://to.example"}},
CookieName: "_pomerium",
CookieSecure: true,
CookieHTTPOnly: true,
InsecureServer: true,
Headers: map[string]string{}},
false},
{"bad url", []byte(`{"policy":[{"from": "https://","to":"https://to.example"}]}`), nil, true},
@ -390,49 +252,45 @@ func TestOptionsFromViper(t *testing.T) {
defer tempFile.Close()
defer os.Remove(tempFile.Name())
tempFile.Write(tt.configBytes)
got, err := OptionsFromViper(tempFile.Name())
got, err := optionsFromViper(tempFile.Name())
if (err != nil) != tt.wantErr {
t.Errorf("OptionsFromViper() error = %v, wantErr %v", err, tt.wantErr)
t.Errorf("optionsFromViper() error = %v, wantErr %v", err, tt.wantErr)
return
}
if diff := cmp.Diff(got, tt.want, opts...); diff != "" {
t.Errorf("NewOptions() = %s", diff)
t.Errorf("NewOptionsFromConfig() = %s", diff)
}
})
}
}
func Test_parseOptions(t *testing.T) {
func Test_NewOptionsFromConfigEnvVar(t *testing.T) {
tests := []struct {
name string
envKey string
envValue string
servicesEnvKey string
servicesEnvValue string
wantSharedKey string
wantErr bool
name string
envKeyPairs map[string]string
wantErr bool
}{
{"no shared secret", "", "", "SERVICES", "authenticate", "skip", true},
{"no shared secret in all mode", "", "", "", "", "", false},
{"good", "SHARED_SECRET", "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=", "", "", "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=", false},
{"good", map[string]string{"INSECURE_SERVER": "true", "SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="}, false},
{"bad no shared secret", map[string]string{"INSECURE_SERVER": "true", "SERVICES": "authenticate"}, true},
{"good no shared secret in all mode", map[string]string{"INSECURE_SERVER": "true"}, false},
{"bad header", map[string]string{"HEADERS": "x;y;z", "INSECURE_SERVER": "true", "SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="}, true},
{"bad authenticate url", map[string]string{"AUTHENTICATE_SERVICE_URL": "authenticate.example", "INSECURE_SERVER": "true", "SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="}, true},
{"bad authorize url", map[string]string{"AUTHORIZE_SERVICE_URL": "authorize.example", "INSECURE_SERVER": "true", "SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="}, true},
{"bad cert base64", map[string]string{"CERTIFICATE": "bad cert", "SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="}, true},
{"bad cert key base64", map[string]string{"CERTIFICATE_KEY": "bad cert", "SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="}, true},
{"bad no certs no insecure mode set", map[string]string{"SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="}, true},
{"good disable headers ", map[string]string{"HEADERS": "disable:true", "INSECURE_SERVER": "true", "SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
os.Setenv(tt.servicesEnvKey, tt.servicesEnvValue)
os.Setenv(tt.envKey, tt.envValue)
defer os.Unsetenv(tt.envKey)
defer os.Unsetenv(tt.servicesEnvKey)
got, err := ParseOptions("")
if (err != nil) != tt.wantErr {
t.Errorf("ParseOptions() error = %v, wantErr %v", err, tt.wantErr)
return
for k, v := range tt.envKeyPairs {
os.Setenv(k, v)
defer os.Unsetenv(k)
}
if got != nil && got.Services != "all" && got.SharedKey != tt.wantSharedKey {
t.Errorf("ParseOptions()\n")
t.Errorf("got: %+v\n", got.SharedKey)
t.Errorf("want: %+v\n", tt.wantSharedKey)
_, err := NewOptionsFromConfig("")
if (err != nil) != tt.wantErr {
t.Errorf("NewOptionsFromConfig() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
@ -453,42 +311,102 @@ func (m *mockService) UpdateOptions(o Options) error {
}
func Test_HandleConfigUpdate(t *testing.T) {
os.Clearenv()
os.Setenv("SHARED_SECRET", "foo")
defer os.Unsetenv("SHARED_SECRET")
blankOpts, err := NewMinimalOptions("https://authenticate.example", "https://authorize.example")
if err != nil {
t.Fatal(err)
}
goodOpts, err := OptionsFromViper("")
if err != nil {
t.Fatal(err)
}
tests := []struct {
name string
envarKey string
envarValue string
service *mockService
oldOpts Options
wantUpdate bool
name string
oldEnvKeyPairs map[string]string
newEnvKeyPairs map[string]string
service *mockService
wantUpdate bool
}{
{"good", "", "", &mockService{fail: false}, *blankOpts, true},
{"good set debug", "POMERIUM_DEBUG", "true", &mockService{fail: false}, *blankOpts, true},
{"bad", "", "", &mockService{fail: true}, *blankOpts, true},
{"no change", "", "", &mockService{fail: false}, *goodOpts, false},
{"bad policy file unmarshal error", "POLICY", base64.StdEncoding.EncodeToString([]byte("{json:}")), &mockService{fail: false}, *blankOpts, false},
{"bad header key", "SERVICES", "error", &mockService{fail: false}, *blankOpts, false},
{"bad header header value", "HEADERS", "x;y;z", &mockService{fail: false}, *blankOpts, false},
{"good",
map[string]string{
"INSECURE_SERVER": "true",
"AUTHENTICATE_SERVICE_URL": "https://authenticate.example",
"AUTHORIZE_SERVICE_URL": "https://authorize.example"},
map[string]string{
"INSECURE_SERVER": "true",
"AUTHENTICATE_SERVICE_URL": "https://authenticate.example",
"AUTHORIZE_SERVICE_URL": "https://authorize.example"},
&mockService{fail: false},
true},
{"good set debug",
map[string]string{
"INSECURE_SERVER": "true",
"AUTHENTICATE_SERVICE_URL": "https://authenticate.example",
"AUTHORIZE_SERVICE_URL": "https://authorize.example"},
map[string]string{
"POMERIUM_DEBUG": "true",
"INSECURE_SERVER": "true",
"AUTHENTICATE_SERVICE_URL": "https://authenticate.example",
"AUTHORIZE_SERVICE_URL": "https://authorize.example"},
&mockService{fail: false},
true},
{"bad",
map[string]string{
"INSECURE_SERVER": "true",
"AUTHENTICATE_SERVICE_URL": "https://authenticate.example",
"AUTHORIZE_SERVICE_URL": "https://authorize.example"},
map[string]string{
"INSECURE_SERVER": "true",
"AUTHENTICATE_SERVICE_URL": "https://authenticate.example",
"AUTHORIZE_SERVICE_URL": "https://authorize.example"},
&mockService{fail: true},
true},
{"bad policy file unmarshal error",
map[string]string{
"INSECURE_SERVER": "true",
"AUTHENTICATE_SERVICE_URL": "https://authenticate.example",
"AUTHORIZE_SERVICE_URL": "https://authorize.example"},
map[string]string{
"POLICY": base64.StdEncoding.EncodeToString([]byte("{json:}")),
"INSECURE_SERVER": "true",
"AUTHENTICATE_SERVICE_URL": "https://authenticate.example",
"AUTHORIZE_SERVICE_URL": "https://authorize.example"},
&mockService{fail: false},
false},
{"bad header key",
map[string]string{
"INSECURE_SERVER": "true",
"AUTHENTICATE_SERVICE_URL": "https://authenticate.example",
"AUTHORIZE_SERVICE_URL": "https://authorize.example"},
map[string]string{
"SERVICES": "error",
"INSECURE_SERVER": "true",
"AUTHENTICATE_SERVICE_URL": "https://authenticate.example",
"AUTHORIZE_SERVICE_URL": "https://authorize.example"},
&mockService{fail: false},
false},
{"bad header header value",
map[string]string{
"INSECURE_SERVER": "true",
"AUTHENTICATE_SERVICE_URL": "https://authenticate.example",
"AUTHORIZE_SERVICE_URL": "https://authorize.example"},
map[string]string{
"HEADERS": "x;y;z",
"INSECURE_SERVER": "true",
"AUTHENTICATE_SERVICE_URL": "https://authenticate.example",
"AUTHORIZE_SERVICE_URL": "https://authorize.example"},
&mockService{fail: false},
false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
os.Setenv(tt.envarKey, tt.envarValue)
defer os.Unsetenv(tt.envarKey)
HandleConfigUpdate("", &tt.oldOpts, []OptionsUpdater{tt.service})
for k, v := range tt.oldEnvKeyPairs {
os.Setenv(k, v)
}
oldOpts, err := NewOptionsFromConfig("")
if err != nil {
t.Fatal(err)
}
for k := range tt.oldEnvKeyPairs {
os.Unsetenv(k)
}
for k, v := range tt.newEnvKeyPairs {
os.Setenv(k, v)
defer os.Unsetenv(k)
}
HandleConfigUpdate("", oldOpts, []OptionsUpdater{tt.service})
if tt.service.Updated != tt.wantUpdate {
t.Errorf("Failed to update config on service")
}

View file

@ -4,7 +4,7 @@ import (
"testing"
)
func Test_Validate(t *testing.T) {
func Test_PolicyValidate(t *testing.T) {
t.Parallel()
tests := []struct {

100
internal/grpcutil/grpc.go Normal file
View file

@ -0,0 +1,100 @@
package grpcutil // import "github.com/pomerium/pomerium/internal/grpcutil"
import (
"crypto/tls"
"net"
"os"
"os/signal"
"sync"
"syscall"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/metrics"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)
// NewServer creates a new gRPC serve.
// It is the callers responsibility to close the resturned server.
func NewServer(opt *ServerOptions, registrationFn func(s *grpc.Server), wg *sync.WaitGroup) *grpc.Server {
if opt == nil {
opt = defaultServerOptions
} else {
opt.applyServerDefaults()
}
ln, err := net.Listen("tcp", opt.Addr)
if err != nil {
log.Fatal().Str("addr", opt.Addr).Err(err).Msg("internal/grpcutil: unexpected ")
}
grpcAuth := NewSharedSecretCred(opt.SharedKey)
grpcOpts := []grpc.ServerOption{
grpc.UnaryInterceptor(grpcAuth.ValidateRequest),
grpc.StatsHandler(metrics.NewGRPCServerStatsHandler(opt.Addr))}
if opt.TLSCertificate != nil {
log.Debug().Str("addr", opt.Addr).Msg("internal/grpcutil: with TLS")
cert := credentials.NewServerTLSFromCert(opt.TLSCertificate)
grpcOpts = append(grpcOpts, grpc.Creds(cert))
} else {
log.Warn().Str("addr", opt.Addr).Msg("internal/grpcutil: insecure server")
}
srv := grpc.NewServer(grpcOpts...)
registrationFn(srv)
wg.Add(1)
go func() {
defer wg.Done()
if err := srv.Serve(ln); err != grpc.ErrServerStopped {
log.Error().Str("addr", opt.Addr).Err(err).Msg("internal/grpcutil: unexpected shutdown")
}
}()
return srv
}
// ServerOptions contains the configurations settings for a gRPC server.
type ServerOptions struct {
// Addr specifies the host and port on which the server should serve
// gRPC requests. If empty, ":443" is used.
Addr string
// SharedKey is the shared secret authorization key used to mutually authenticate
// requests between services.
SharedKey string
// TLS certificates to use, if any.
TLSCertificate *tls.Certificate
// InsecureServer when enabled disables all transport security.
// In this mode, Pomerium is susceptible to man-in-the-middle attacks.
// This should be used only for testing.
InsecureServer bool
}
var defaultServerOptions = &ServerOptions{
Addr: ":443",
}
func (o *ServerOptions) applyServerDefaults() {
if o.Addr == "" {
o.Addr = defaultServerOptions.Addr
}
}
// Shutdown attempts to shut down the server when a os interrupt or sigterm
// signal are received without interrupting any
// active connections. Shutdown stops the server from
// accepting new connections and RPCs and blocks until all the pending RPCs are
// finished.
func Shutdown(srv *grpc.Server) {
sigint := make(chan os.Signal, 1)
signal.Notify(sigint, os.Interrupt)
signal.Notify(sigint, syscall.SIGTERM)
rec := <-sigint
log.Info().Str("signal", rec.String()).Msg("internal/grpcutil: shutting down servers")
srv.GracefulStop()
log.Info().Str("signal", rec.String()).Msg("internal/grpcutil: shut down servers")
}

View file

@ -1,4 +1,4 @@
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
package grpcutil // import "github.com/pomerium/pomerium/internal/grpcutil"
import (
"context"
@ -24,7 +24,8 @@ func (s SharedSecretCred) GetRequestMetadata(context.Context, ...string) (map[st
return map[string]string{"authorization": s.sharedSecret}, nil
}
// RequireTransportSecurity should be true as we want to have it encrypted over the wire.
// RequireTransportSecurity indicates whether the credentials requires
// transport security.
func (s SharedSecretCred) RequireTransportSecurity() bool { return false }
// ValidateRequest ensures a valid token exists within a request's metadata. If

View file

@ -1,76 +0,0 @@
package httputil // import "github.com/pomerium/pomerium/internal/httputil"
import (
"context"
"fmt"
stdlog "log"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/urlutil"
)
// NewHTTPServer starts a http server given a set of options and a handler.
//
// It is the caller's responsibility to Close() or Shutdown() the returned
// server.
func NewHTTPServer(opt *ServerOptions, h http.Handler) *http.Server {
if opt == nil {
opt = defaultHTTPServerOptions
} else {
opt.applyHTTPDefaults()
}
sublogger := log.With().Str("addr", opt.Addr).Logger()
srv := http.Server{
Addr: opt.Addr,
ReadHeaderTimeout: opt.ReadHeaderTimeout,
ReadTimeout: opt.ReadTimeout,
WriteTimeout: opt.WriteTimeout,
IdleTimeout: opt.IdleTimeout,
Handler: h,
ErrorLog: stdlog.New(&log.StdLogWrapper{Logger: &sublogger}, "", 0),
}
go func() {
if err := srv.ListenAndServe(); err != http.ErrServerClosed {
log.Error().Str("addr", opt.Addr).Err(err).Msg("internal/httputil: unexpected shutdown")
}
}()
return &srv
}
func RedirectHandler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Connection", "close")
url := fmt.Sprintf("https://%s%s", urlutil.StripPort(r.Host), r.URL.String())
http.Redirect(w, r, url, http.StatusMovedPermanently)
})
}
// Shutdown attempts to shut down the server when a os interrupt or sigterm
// signal are received without interrupting any
// active connections. Shutdown works by first closing all open
// listeners, then closing all idle connections, and then waiting
// indefinitely for connections to return to idle and then shut down.
// If the provided context expires before the shutdown is complete,
// Shutdown returns the context's error, otherwise it returns any
// error returned from closing the Server's underlying Listener(s).
//
// When Shutdown is called, Serve, ListenAndServe, and
// ListenAndServeTLS immediately return ErrServerClosed.
func Shutdown(srv *http.Server) {
sigint := make(chan os.Signal, 1)
signal.Notify(sigint, os.Interrupt)
signal.Notify(sigint, syscall.SIGTERM)
rec := <-sigint
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
log.Info().Str("signal", rec.String()).Msg("internal/httputil: shutting down servers")
if err := srv.Shutdown(ctx); err != nil {
log.Error().Err(err).Msg("internal/httputil: shutdown failed")
}
}

View file

@ -1,49 +0,0 @@
package httputil
import (
"fmt"
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"testing"
)
func TestNewHTTPServer(t *testing.T) {
tests := []struct {
name string
opts *ServerOptions
// wantErr bool
}{
{"localhost:9232", &ServerOptions{Addr: "localhost:9232"}},
{"localhost:65536", &ServerOptions{Addr: "localhost:-1"}}, // will fail, but won't err
{"empty", &ServerOptions{}},
{"empty", nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
srv := NewHTTPServer(tt.opts, RedirectHandler())
defer srv.Close()
// we cheat a little bit here and use the httptest server to test the client
ts := httptest.NewServer(srv.Handler)
defer ts.Close()
client := ts.Client()
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
res, err := client.Get(ts.URL)
if err != nil {
log.Fatal(err)
}
greeting, err := ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
log.Fatal(err)
}
fmt.Printf("%s", greeting)
})
}
}

View file

@ -1,10 +1,8 @@
package httputil // import "github.com/pomerium/pomerium/internal/httputil"
import (
"path/filepath"
"crypto/tls"
"time"
"github.com/pomerium/pomerium/internal/fileutil"
)
// ServerOptions contains the configurations settings for a http server.
@ -14,11 +12,7 @@ type ServerOptions struct {
Addr string
// TLS certificates to use.
Cert string
Key string
CertFile string
KeyFile string
TLSCertificate *tls.Certificate
// Timeouts
ReadHeaderTimeout time.Duration
ReadTimeout time.Duration
@ -26,62 +20,28 @@ type ServerOptions struct {
IdleTimeout time.Duration
}
var defaultTLSServerOptions = &ServerOptions{
var defaultServerOptions = &ServerOptions{
Addr: ":443",
CertFile: filepath.Join(fileutil.Getwd(), "cert.pem"),
KeyFile: filepath.Join(fileutil.Getwd(), "privkey.pem"),
ReadHeaderTimeout: 10 * time.Second,
ReadTimeout: 30 * time.Second,
WriteTimeout: 0, // support streaming by default
IdleTimeout: 5 * time.Minute,
}
func (o *ServerOptions) applyTLSDefaults() {
func (o *ServerOptions) applyServerDefaults() {
if o.Addr == "" {
o.Addr = defaultTLSServerOptions.Addr
}
if o.Cert == "" && o.CertFile == "" {
o.CertFile = defaultTLSServerOptions.CertFile
}
if o.Key == "" && o.KeyFile == "" {
o.KeyFile = defaultTLSServerOptions.KeyFile
o.Addr = defaultServerOptions.Addr
}
if o.ReadHeaderTimeout == 0 {
o.ReadHeaderTimeout = defaultTLSServerOptions.ReadHeaderTimeout
o.ReadHeaderTimeout = defaultServerOptions.ReadHeaderTimeout
}
if o.ReadTimeout == 0 {
o.ReadTimeout = defaultTLSServerOptions.ReadTimeout
o.ReadTimeout = defaultServerOptions.ReadTimeout
}
if o.WriteTimeout == 0 {
o.WriteTimeout = defaultTLSServerOptions.WriteTimeout
o.WriteTimeout = defaultServerOptions.WriteTimeout
}
if o.IdleTimeout == 0 {
o.IdleTimeout = defaultTLSServerOptions.IdleTimeout
}
}
var defaultHTTPServerOptions = &ServerOptions{
Addr: ":80",
ReadHeaderTimeout: 10 * time.Second,
ReadTimeout: 5 * time.Second,
WriteTimeout: 5 * time.Second,
IdleTimeout: 5 * time.Minute,
}
func (o *ServerOptions) applyHTTPDefaults() {
if o.Addr == "" {
o.Addr = defaultHTTPServerOptions.Addr
}
if o.ReadHeaderTimeout == 0 {
o.ReadHeaderTimeout = defaultHTTPServerOptions.ReadHeaderTimeout
}
if o.ReadTimeout == 0 {
o.ReadTimeout = defaultHTTPServerOptions.ReadTimeout
}
if o.WriteTimeout == 0 {
o.WriteTimeout = defaultHTTPServerOptions.WriteTimeout
}
if o.IdleTimeout == 0 {
o.IdleTimeout = defaultHTTPServerOptions.IdleTimeout
o.IdleTimeout = defaultServerOptions.IdleTimeout
}
}

View file

@ -1,49 +1,37 @@
package httputil // import "github.com/pomerium/pomerium/internal/httputil"
import (
"context"
"crypto/tls"
"fmt"
stdlog "log"
"net"
"net/http"
"strings"
"os"
"os/signal"
"sync"
"syscall"
"time"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/urlutil"
)
// NewTLSServer creates a new TLS server given a set of options, handlers, and
// optionally a set of gRPC endpoints as well.
// It is the callers responsibility to close the resturned server.
func NewTLSServer(opt *ServerOptions, httpHandler http.Handler, grpcHandler http.Handler) (*http.Server, error) {
// NewServer creates a new HTTP server given a set of options, handler, and
// waitgroup. It is the callers responsibility to close the resturned server.
func NewServer(opt *ServerOptions, h http.Handler, wg *sync.WaitGroup) (*http.Server, error) {
if opt == nil {
opt = defaultTLSServerOptions
opt = defaultServerOptions
} else {
opt.applyTLSDefaults()
opt.applyServerDefaults()
}
var cert *tls.Certificate
var err error
if opt.Cert != "" && opt.Key != "" {
cert, err = cryptutil.CertifcateFromBase64(opt.Cert, opt.Key)
} else {
cert, err = cryptutil.CertificateFromFile(opt.CertFile, opt.KeyFile)
}
if err != nil {
return nil, fmt.Errorf("internal/httputil: failed loading x509 certificate: %v", err)
}
config := newDefaultTLSConfig(cert)
ln, err := net.Listen("tcp", opt.Addr)
if err != nil {
return nil, err
}
ln = tls.NewListener(ln, config)
var h http.Handler
if grpcHandler == nil {
h = httpHandler
} else {
h = grpcHandlerFunc(grpcHandler, httpHandler)
if opt.TLSCertificate != nil {
ln = tls.NewListener(ln, newDefaultTLSConfig(opt.TLSCertificate))
}
sublogger := log.With().Str("addr", opt.Addr).Logger()
@ -53,16 +41,16 @@ func NewTLSServer(opt *ServerOptions, httpHandler http.Handler, grpcHandler http
ReadTimeout: opt.ReadTimeout,
WriteTimeout: opt.WriteTimeout,
IdleTimeout: opt.IdleTimeout,
TLSConfig: config,
Handler: h,
ErrorLog: stdlog.New(&log.StdLogWrapper{Logger: &sublogger}, "", 0),
}
wg.Add(1)
go func() {
defer wg.Done()
if err := srv.Serve(ln); err != http.ErrServerClosed {
log.Error().Err(err).Msg("internal/httputil: tls server crashed")
}
}()
return srv, nil
}
@ -98,15 +86,35 @@ func newDefaultTLSConfig(cert *tls.Certificate) *tls.Config {
return tlsConfig
}
// grpcHandlerFunc splits request serving between gRPC and HTTPS depending on
// the request type. Requires HTTP/2 to be enabled.
func grpcHandlerFunc(rpcServer http.Handler, other http.Handler) http.Handler {
// RedirectHandler takes an incoming request and redirects to its HTTPS counterpart
func RedirectHandler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ct := r.Header.Get("Content-Type")
if r.ProtoMajor == 2 && strings.Contains(ct, "application/grpc") {
rpcServer.ServeHTTP(w, r)
} else {
other.ServeHTTP(w, r)
}
w.Header().Set("Connection", "close")
url := fmt.Sprintf("https://%s", urlutil.StripPort(r.Host))
http.Redirect(w, r, url, http.StatusMovedPermanently)
})
}
// Shutdown attempts to shut down the server when a os interrupt or sigterm
// signal are received without interrupting any
// active connections. Shutdown works by first closing all open
// listeners, then closing all idle connections, and then waiting
// indefinitely for connections to return to idle and then shut down.
// If the provided context expires before the shutdown is complete,
// Shutdown returns the context's error, otherwise it returns any
// error returned from closing the Server's underlying Listener(s).
//
// When Shutdown is called, Serve, ListenAndServe, and
// ListenAndServeTLS immediately return ErrServerClosed.
func Shutdown(srv *http.Server) {
sigint := make(chan os.Signal, 1)
signal.Notify(sigint, os.Interrupt)
signal.Notify(sigint, syscall.SIGTERM)
rec := <-sigint
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
log.Info().Str("signal", rec.String()).Msg("internal/httputil: shutting down servers")
if err := srv.Shutdown(ctx); err != nil {
log.Error().Err(err).Msg("internal/httputil: shutdown failed")
}
}

View file

@ -0,0 +1,158 @@
package httputil
import (
"encoding/base64"
"fmt"
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"os"
"os/signal"
"sync"
"syscall"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/pomerium/pomerium/internal/cryptutil"
)
const privKey = `-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIMQiDy26/R4ca/OdnjIf8OEDeHcw8yB5SDV9FD500CW5oAoGCCqGSM49
AwEHoUQDQgAEFumdSrEe9dnPEUU3LuyC8l6MM6PefNgpSsRL4GrD22XITMjqDKFr
jqJTf0Fo1ZWm4v+Eds6s88rsLzEC+cKLRQ==
-----END EC PRIVATE KEY-----`
const pubKey = `-----BEGIN CERTIFICATE-----
MIIBeDCCAR+gAwIBAgIUUGE8w2S7XzpkVLbNq5QUxyVOwqEwCgYIKoZIzj0EAwIw
ETEPMA0GA1UEAwwGdW51c2VkMCAXDTE5MDcxNTIzNDQyOVoYDzQ3NTcwNjExMjM0
NDI5WjARMQ8wDQYDVQQDDAZ1bnVzZWQwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNC
AAQW6Z1KsR712c8RRTcu7ILyXowzo9582ClKxEvgasPbZchMyOoMoWuOolN/QWjV
labi/4R2zqzzyuwvMQL5wotFo1MwUTAdBgNVHQ4EFgQURYdcaniRqBHXeaM79LtV
pyJ4EwAwHwYDVR0jBBgwFoAURYdcaniRqBHXeaM79LtVpyJ4EwAwDwYDVR0TAQH/
BAUwAwEB/zAKBggqhkjOPQQDAgNHADBEAiBHbhVnGbwXqaMZ1dB8eBAK56jyeWDZ
2PWXmFMTu7+RywIgaZ7UwVNB2k7KjEEBiLm0PIRcpJmczI2cP9+ZMIkPHHw=
-----END CERTIFICATE-----`
func TestNewServer(t *testing.T) {
certb64, err := cryptutil.CertifcateFromBase64(
base64.StdEncoding.EncodeToString([]byte(pubKey)),
base64.StdEncoding.EncodeToString([]byte(privKey)))
if err != nil {
t.Fatal(err)
}
t.Parallel()
tests := []struct {
name string
opt *ServerOptions
httpHandler http.Handler
// want *http.Server
wantErr bool
}{
{"good basic http handler",
&ServerOptions{
Addr: "127.0.0.1:0",
TLSCertificate: certb64,
},
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, http")
}),
false},
// todo(bdd): fails travis-ci
// {"good no address",
// &ServerOptions{
// TLSCertificate: certb64,
// },
// http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// fmt.Fprintln(w, "Hello, http")
// }),
// false},
// todo(bdd): fails travis-ci
// {"empty handler",
// nil,
// http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// fmt.Fprintln(w, "Hello, http")
// }),
// false},
{"bad port - invalid port range ",
&ServerOptions{
Addr: "127.0.0.1:65536",
TLSCertificate: certb64,
}, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, http")
}),
true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var wg sync.WaitGroup
srv, err := NewServer(tt.opt, tt.httpHandler, &wg)
if (err != nil) != tt.wantErr {
t.Errorf("NewServer() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err == nil {
// we cheat a little bit here and use the httptest server to test the client
ts := httptest.NewServer(srv.Handler)
defer ts.Close()
client := ts.Client()
res, err := client.Get(ts.URL)
if err != nil {
log.Fatal(err)
}
greeting, err := ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
log.Fatal(err)
}
fmt.Printf("%s", greeting)
}
if srv != nil {
// simulate a sigterm and cleanup the server
c := make(chan os.Signal, 1)
signal.Notify(c, syscall.SIGINT)
defer signal.Stop(c)
go Shutdown(srv)
syscall.Kill(syscall.Getpid(), syscall.SIGINT)
waitSig(t, c, syscall.SIGINT)
}
})
}
}
func waitSig(t *testing.T, c <-chan os.Signal, sig os.Signal) {
select {
case s := <-c:
if s != sig {
t.Fatalf("signal was %v, want %v", s, sig)
}
case <-time.After(1 * time.Second):
t.Fatalf("timeout waiting for %v", sig)
}
}
func TestRedirectHandler(t *testing.T) {
tests := []struct {
name string
wantStatus int
wantBody string
}{
{"http://example", http.StatusMovedPermanently, "<a href=\"https://example\">Moved Permanently</a>.\n\n"},
{"http://example:8080", http.StatusMovedPermanently, "<a href=\"https://example\">Moved Permanently</a>.\n\n"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "http://example/", nil)
rr := httptest.NewRecorder()
RedirectHandler().ServeHTTP(rr, req)
if diff := cmp.Diff(tt.wantStatus, rr.Code); diff != "" {
t.Errorf("TestRedirectHandler() code diff :%s", diff)
}
if diff := cmp.Diff(tt.wantBody, rr.Body.String()); diff != "" {
t.Errorf("TestRedirectHandler() body diff :%s", diff)
}
})
}
}

View file

@ -1,210 +0,0 @@
package httputil // import "github.com/pomerium/pomerium/internal/httputil"
import (
"encoding/base64"
"fmt"
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"os"
"os/signal"
"syscall"
"testing"
"time"
)
const privKey = `-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIMQiDy26/R4ca/OdnjIf8OEDeHcw8yB5SDV9FD500CW5oAoGCCqGSM49
AwEHoUQDQgAEFumdSrEe9dnPEUU3LuyC8l6MM6PefNgpSsRL4GrD22XITMjqDKFr
jqJTf0Fo1ZWm4v+Eds6s88rsLzEC+cKLRQ==
-----END EC PRIVATE KEY-----`
const pubKey = `-----BEGIN CERTIFICATE-----
MIIBeDCCAR+gAwIBAgIUUGE8w2S7XzpkVLbNq5QUxyVOwqEwCgYIKoZIzj0EAwIw
ETEPMA0GA1UEAwwGdW51c2VkMCAXDTE5MDcxNTIzNDQyOVoYDzQ3NTcwNjExMjM0
NDI5WjARMQ8wDQYDVQQDDAZ1bnVzZWQwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNC
AAQW6Z1KsR712c8RRTcu7ILyXowzo9582ClKxEvgasPbZchMyOoMoWuOolN/QWjV
labi/4R2zqzzyuwvMQL5wotFo1MwUTAdBgNVHQ4EFgQURYdcaniRqBHXeaM79LtV
pyJ4EwAwHwYDVR0jBBgwFoAURYdcaniRqBHXeaM79LtVpyJ4EwAwDwYDVR0TAQH/
BAUwAwEB/zAKBggqhkjOPQQDAgNHADBEAiBHbhVnGbwXqaMZ1dB8eBAK56jyeWDZ
2PWXmFMTu7+RywIgaZ7UwVNB2k7KjEEBiLm0PIRcpJmczI2cP9+ZMIkPHHw=
-----END CERTIFICATE-----`
func TestNewTLSServer(t *testing.T) {
t.Parallel()
tests := []struct {
name string
opt *ServerOptions
httpHandler http.Handler
grpcHandler http.Handler
// want *http.Server
wantErr bool
}{
{"good basic http handler",
&ServerOptions{
Addr: "127.0.0.1:0",
Cert: base64.StdEncoding.EncodeToString([]byte(pubKey)),
Key: base64.StdEncoding.EncodeToString([]byte(privKey)),
},
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, http")
}),
nil,
false},
{"good basic http and grpc handler",
&ServerOptions{
Addr: "127.0.0.1:0",
Cert: base64.StdEncoding.EncodeToString([]byte(pubKey)),
Key: base64.StdEncoding.EncodeToString([]byte(privKey)),
},
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, http")
}),
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, grpc")
}),
false},
{"good with cert files",
&ServerOptions{
Addr: "127.0.0.1:0",
CertFile: "test_data/cert.pem",
KeyFile: "test_data/privkey.pem",
},
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, http")
}),
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, grpc")
}),
false},
{"unreadable cert file",
&ServerOptions{
Addr: "127.0.0.1:0",
CertFile: "test_data",
KeyFile: "test_data/privkey.pem",
},
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, http")
}),
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, grpc")
}),
true},
{"unreadable key file",
&ServerOptions{
Addr: "127.0.0.1:0",
CertFile: "./test_data/cert.pem",
KeyFile: "./test_data",
},
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, http")
}),
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, grpc")
}),
true},
{"unreadable key file",
&ServerOptions{
Addr: "127.0.0.1:0",
CertFile: "./test_data/cert.pem",
KeyFile: "./test_data/file-does-not-exist",
},
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, http")
}),
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, grpc")
}),
true},
{"bad private key base64",
&ServerOptions{
Addr: "127.0.0.1:0",
Cert: base64.StdEncoding.EncodeToString([]byte(pubKey)),
Key: "bad guy",
}, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, http")
}),
nil,
true},
{"bad public key base64",
&ServerOptions{
Addr: "127.0.0.1:9999",
Key: base64.StdEncoding.EncodeToString([]byte(pubKey)),
Cert: "bad guy",
}, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, http")
}),
nil,
true},
{"bad port - invalid port range ",
&ServerOptions{
Addr: "127.0.0.1:65536",
Cert: base64.StdEncoding.EncodeToString([]byte(pubKey)),
Key: base64.StdEncoding.EncodeToString([]byte(privKey)),
}, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, http")
}),
nil,
true},
{"nil apply default but will fail",
nil,
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, http")
}),
nil,
true},
{"empty, apply defaults to missing",
&ServerOptions{},
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, http")
}),
nil,
true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
srv, err := NewTLSServer(tt.opt, tt.httpHandler, tt.grpcHandler)
if (err != nil) != tt.wantErr {
t.Errorf("NewTLSServer() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err == nil {
// we cheat a little bit here and use the httptest server to test the client
ts := httptest.NewTLSServer(srv.Handler)
defer ts.Close()
client := ts.Client()
res, err := client.Get(ts.URL)
if err != nil {
log.Fatal(err)
}
greeting, err := ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
log.Fatal(err)
}
fmt.Printf("%s", greeting)
}
if srv != nil {
// simulate a sigterm and cleanup the server
c := make(chan os.Signal, 1)
signal.Notify(c, syscall.SIGINT)
defer signal.Stop(c)
go Shutdown(srv)
syscall.Kill(syscall.Getpid(), syscall.SIGINT)
waitSig(t, c, syscall.SIGINT)
}
})
}
}
func waitSig(t *testing.T, c <-chan os.Signal, sig os.Signal) {
select {
case s := <-c:
if s != sig {
t.Fatalf("signal was %v, want %v", s, sig)
}
case <-time.After(1 * time.Second):
t.Fatalf("timeout waiting for %v", sig)
}
}

View file

@ -81,6 +81,7 @@ func TestAuthorizeGRPC_IsAdmin(t *testing.T) {
}
func TestNewGRPC(t *testing.T) {
t.Parallel()
tests := []struct {
name string
opts *Options
@ -100,6 +101,8 @@ func TestNewGRPC(t *testing.T) {
{"bad ca encoding", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, OverrideCertificateName: "*.local", SharedSecret: "shh", CA: "^"}, true, "", "localhost.example:443"},
{"custom ca file", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, OverrideCertificateName: "*.local", SharedSecret: "shh", CAFile: "testdata/example.crt"}, false, "", "localhost.example:443"},
{"bad custom ca file", &Options{Addr: nil, InternalAddr: &url.URL{Scheme: "https", Host: "localhost.example"}, OverrideCertificateName: "*.local", SharedSecret: "shh", CAFile: "testdata/example.crt2"}, true, "", "localhost.example:443"},
{"valid with insecure", &Options{Addr: &url.URL{Scheme: "https", Host: "localhost.example:8443"}, SharedSecret: "shh", WithInsecure: true}, false, "", "localhost.example:8443"},
{"valid client round robin", &Options{Addr: &url.URL{Scheme: "https", Host: "localhost.example:8443"}, SharedSecret: "shh", ClientDNSRoundRobin: true}, false, "", "dns:///localhost.example:8443"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

View file

@ -12,8 +12,8 @@ import (
"strings"
"time"
"github.com/pomerium/pomerium/internal/grpcutil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/middleware"
"github.com/pomerium/pomerium/internal/telemetry/metrics"
"go.opencensus.io/plugin/ocgrpc"
@ -45,6 +45,10 @@ type Options struct {
RequestTimeout time.Duration
// ClientDNSRoundRobin enables or disables DNS resolver based load balancing
ClientDNSRoundRobin bool
// WithInsecure disables transport security for this ClientConn.
// Note that transport security is required unless WithInsecure is set.
WithInsecure bool
}
// NewGRPCClientConn returns a new gRPC pomerium service client connection.
@ -57,7 +61,6 @@ func NewGRPCClientConn(opts *Options) (*grpc.ClientConn, error) {
return nil, errors.New("proxy/clients: connection address required")
}
grpcAuth := middleware.NewSharedSecretCred(opts.SharedSecret)
var connAddr string
if opts.InternalAddr != nil {
@ -69,60 +72,61 @@ func NewGRPCClientConn(opts *Options) (*grpc.ClientConn, error) {
if !strings.Contains(connAddr, ":") {
connAddr = fmt.Sprintf("%s:%d", connAddr, defaultGRPCPort)
}
var cp *x509.CertPool
if opts.CA != "" || opts.CAFile != "" {
cp = x509.NewCertPool()
var ca []byte
var err error
if opts.CA != "" {
ca, err = base64.StdEncoding.DecodeString(opts.CA)
if err != nil {
return nil, fmt.Errorf("failed to decode certificate authority: %v", err)
}
} else {
ca, err = ioutil.ReadFile(opts.CAFile)
if err != nil {
return nil, fmt.Errorf("certificate authority file %v not readable: %v", opts.CAFile, err)
}
}
if ok := cp.AppendCertsFromPEM(ca); !ok {
return nil, fmt.Errorf("failed to append CA cert to certPool")
}
log.Debug().Msg("proxy/clients: using a custom certificate authority")
} else {
newCp, err := x509.SystemCertPool()
if err != nil {
return nil, err
}
cp = newCp
log.Debug().Msg("proxy/clients: using system certificate pool")
}
log.Debug().Str("cert-override-name", opts.OverrideCertificateName).Str("addr", connAddr).Msgf("proxy/clients: grpc connection")
cert := credentials.NewTLS(&tls.Config{RootCAs: cp})
// override allowed certificate name string, typically used when doing behind ingress connection
if opts.OverrideCertificateName != "" {
err := cert.OverrideServerName(opts.OverrideCertificateName)
if err != nil {
return nil, err
}
}
dialOptions := []grpc.DialOption{
grpc.WithTransportCredentials(cert),
grpc.WithPerRPCCredentials(grpcAuth),
grpc.WithPerRPCCredentials(grpcutil.NewSharedSecretCred(opts.SharedSecret)),
grpc.WithChainUnaryInterceptor(metrics.GRPCClientInterceptor("proxy"), grpcTimeoutInterceptor(opts.RequestTimeout)),
grpc.WithStatsHandler(&ocgrpc.ClientHandler{}),
grpc.WithDefaultCallOptions([]grpc.CallOption{grpc.WaitForReady(true)}...),
}
if opts.WithInsecure {
log.Info().Str("addr", connAddr).Msg("proxy/clients: grpc with insecure")
dialOptions = append(dialOptions, grpc.WithInsecure())
} else {
rootCAs, err := x509.SystemCertPool()
if err != nil {
log.Warn().Msg("proxy/clients: failed getting system cert pool making new one")
rootCAs = x509.NewCertPool()
}
if opts.CA != "" || opts.CAFile != "" {
var ca []byte
var err error
if opts.CA != "" {
ca, err = base64.StdEncoding.DecodeString(opts.CA)
if err != nil {
return nil, fmt.Errorf("failed to decode certificate authority: %v", err)
}
} else {
ca, err = ioutil.ReadFile(opts.CAFile)
if err != nil {
return nil, fmt.Errorf("certificate authority file %v not readable: %v", opts.CAFile, err)
}
}
if ok := rootCAs.AppendCertsFromPEM(ca); !ok {
return nil, fmt.Errorf("failed to append CA cert to certPool")
}
log.Debug().Msg("proxy/clients: added custom certificate authority")
}
cert := credentials.NewTLS(&tls.Config{RootCAs: rootCAs})
// override allowed certificate name string, typically used when doing behind ingress connection
if opts.OverrideCertificateName != "" {
log.Debug().Str("cert-override-name", opts.OverrideCertificateName).Msg("proxy/clients: grpc")
err := cert.OverrideServerName(opts.OverrideCertificateName)
if err != nil {
return nil, err
}
}
// finally add our credential
dialOptions = append(dialOptions, grpc.WithTransportCredentials(cert))
}
if opts.ClientDNSRoundRobin {
dialOptions = append(dialOptions, grpc.WithBalancerName(roundrobin.Name), grpc.WithDisableServiceConfig())
connAddr = fmt.Sprintf("dns:///%s", connAddr)
}
return grpc.Dial(
connAddr,
dialOptions...,

View file

@ -147,6 +147,7 @@ func New(opts config.Options) (*Proxy, error) {
CAFile: opts.CAFile,
RequestTimeout: opts.GRPCClientTimeout,
ClientDNSRoundRobin: opts.GRPCClientDNSRoundRobin,
WithInsecure: opts.GRPCInsecure,
})
return p, err
}

View file

@ -10,27 +10,18 @@ import (
"github.com/pomerium/pomerium/internal/config"
)
func newTestOptions(t *testing.T) *config.Options {
opts, err := config.NewMinimalOptions("https://authenticate.example", "https://authorize.example")
if err != nil {
t.Fatal(err)
}
opts.CookieSecret = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw="
return opts
}
func testOptions(t *testing.T) config.Options {
authenticateService, _ := url.Parse("https://authenticate.corp.beyondperimeter.com")
authorizeService, _ := url.Parse("https://authorize.corp.beyondperimeter.com")
opts := config.NewDefaultOptions()
opts.AuthenticateURLString = "https://authenticate.example"
opts.AuthorizeURLString = "https://authorize.example"
opts := newTestOptions(t)
testPolicy := config.Policy{From: "https://corp.example.example", To: "https://example.example"}
opts.Policies = []config.Policy{testPolicy}
opts.AuthenticateURL = authenticateService
opts.AuthorizeURL = authorizeService
opts.InsecureServer = true
opts.CookieSecure = false
opts.Services = config.ServiceAll
opts.SharedKey = "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="
opts.CookieSecret = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw="
opts.CookieName = "pomerium"
err := opts.Validate()
if err != nil {
t.Fatal(err)