diff --git a/cmd/pomerium/main.go b/cmd/pomerium/main.go index 38a39da5c..c167c1647 100644 --- a/cmd/pomerium/main.go +++ b/cmd/pomerium/main.go @@ -4,29 +4,10 @@ import ( "context" "flag" "fmt" - "net" - "os" - "os/signal" - "sync" - "syscall" - "github.com/pomerium/pomerium/authenticate" - "github.com/pomerium/pomerium/authorize" - "github.com/pomerium/pomerium/cache" - "github.com/pomerium/pomerium/config" - "github.com/pomerium/pomerium/internal/controlplane" - "github.com/pomerium/pomerium/internal/envoy" - pbCache "github.com/pomerium/pomerium/internal/grpc/cache" - "github.com/pomerium/pomerium/internal/httputil" + "github.com/pomerium/pomerium/internal/cmd/pomerium" "github.com/pomerium/pomerium/internal/log" - "github.com/pomerium/pomerium/internal/telemetry/metrics" - "github.com/pomerium/pomerium/internal/telemetry/trace" - "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/version" - "github.com/pomerium/pomerium/proxy" - - envoy_service_auth_v2 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v2" - "golang.org/x/sync/errgroup" ) var versionFlag = flag.Bool("version", false, "prints the version") @@ -44,179 +25,5 @@ func run(ctx context.Context) error { fmt.Println(version.FullVersion()) return nil } - opt, err := config.NewOptionsFromConfig(*configFile) - if err != nil { - return err - } - var optionsUpdaters []config.OptionsUpdater - - log.Info().Str("version", version.FullVersion()).Msg("cmd/pomerium") - - if err := setupMetrics(opt); err != nil { - return err - } - if err := setupTracing(opt); err != nil { - return err - } - - // setup the control plane - controlPlane, err := controlplane.NewServer() - if err != nil { - return fmt.Errorf("error creating control plane: %w", err) - } - optionsUpdaters = append(optionsUpdaters, controlPlane) - err = controlPlane.UpdateOptions(*opt) - if err != nil { - return fmt.Errorf("error updating control plane options: %w", err) - } - - _, grpcPort, _ := net.SplitHostPort(controlPlane.GRPCListener.Addr().String()) - _, httpPort, _ := net.SplitHostPort(controlPlane.HTTPListener.Addr().String()) - - // create envoy server - envoyServer, err := envoy.NewServer(grpcPort, httpPort) - if err != nil { - return fmt.Errorf("error creating envoy server") - } - - // add services - if err := setupAuthenticate(opt, controlPlane); err != nil { - return err - } - if err := setupAuthorize(opt, controlPlane, &optionsUpdaters); err != nil { - return err - } - if err := setupCache(opt, controlPlane); err != nil { - return err - } - if err := setupProxy(opt, controlPlane); err != nil { - return err - } - - // start the config change listener - go config.WatchChanges(*configFile, opt, optionsUpdaters) - - ctx, cancel := context.WithCancel(ctx) - go func() { - ch := make(chan os.Signal, 2) - signal.Notify(ch, os.Interrupt) - signal.Notify(ch, syscall.SIGTERM) - <-ch - cancel() - }() - - // run everything - eg, ctx := errgroup.WithContext(ctx) - eg.Go(func() error { - return controlPlane.Run(ctx) - }) - eg.Go(func() error { - return envoyServer.Run(ctx) - }) - return eg.Wait() -} - -func setupAuthenticate(opt *config.Options, controlPlane *controlplane.Server) error { - if !config.IsAuthenticate(opt.Services) { - return nil - } - - svc, err := authenticate.New(*opt) - if err != nil { - return fmt.Errorf("error creating authenticate service: %w", err) - } - host := urlutil.StripPort(opt.AuthenticateURL.Host) - sr := controlPlane.HTTPRouter.Host(host).Subrouter() - svc.Mount(sr) - log.Info().Str("host", host).Msg("enabled authenticate service") - - return nil -} - -func setupAuthorize(opt *config.Options, controlPlane *controlplane.Server, optionsUpdaters *[]config.OptionsUpdater) error { - if !config.IsAuthorize(opt.Services) { - return nil - } - - svc, err := authorize.New(*opt) - if err != nil { - return fmt.Errorf("error creating authorize service: %w", err) - } - envoy_service_auth_v2.RegisterAuthorizationServer(controlPlane.GRPCServer, svc) - - log.Info().Msg("enabled authorize service") - - *optionsUpdaters = append(*optionsUpdaters, svc) - err = svc.UpdateOptions(*opt) - if err != nil { - return fmt.Errorf("error updating authorize options: %w", err) - } - return nil -} - -func setupCache(opt *config.Options, controlPlane *controlplane.Server) error { - if !config.IsCache(opt.Services) { - return nil - } - - svc, err := cache.New(*opt) - if err != nil { - return fmt.Errorf("error creating config service: %w", err) - } - defer svc.Close() - pbCache.RegisterCacheServer(controlPlane.GRPCServer, svc) - log.Info().Msg("enabled cache service") - return nil -} - -func setupMetrics(opt *config.Options) error { - if opt.MetricsAddr != "" { - handler, err := metrics.PrometheusHandler() - if err != nil { - return err - } - metrics.SetBuildInfo(opt.Services) - metrics.RegisterInfoMetrics() - serverOpts := &httputil.ServerOptions{ - Addr: opt.MetricsAddr, - Insecure: true, - Service: "metrics", - } - var wg sync.WaitGroup - _, err = httputil.NewServer(serverOpts, handler, &wg) - if err != nil { - return err - } - } - return nil -} - -func setupProxy(opt *config.Options, controlPlane *controlplane.Server) error { - if !config.IsProxy(opt.Services) { - return nil - } - - svc, err := proxy.New(*opt) - if err != nil { - return fmt.Errorf("error creating proxy service: %w", err) - } - controlPlane.HTTPRouter.PathPrefix("/").Handler(svc) - return nil -} - -func setupTracing(opt *config.Options) error { - if opt.TracingProvider != "" { - tracingOpts := &trace.TracingOptions{ - Provider: opt.TracingProvider, - Service: opt.Services, - Debug: opt.TracingDebug, - JaegerAgentEndpoint: opt.TracingJaegerAgentEndpoint, - JaegerCollectorEndpoint: opt.TracingJaegerCollectorEndpoint, - ZipkinEndpoint: opt.ZipkinEndpoint, - } - if err := trace.RegisterTracing(tracingOpts); err != nil { - return err - } - } - return nil + return pomerium.Run(ctx, *configFile) } diff --git a/internal/cmd/pomerium/pomerium.go b/internal/cmd/pomerium/pomerium.go new file mode 100644 index 000000000..6879a6a4d --- /dev/null +++ b/internal/cmd/pomerium/pomerium.go @@ -0,0 +1,225 @@ +// Package pomerium houses the main pomerium CLI command. +// +package pomerium + +import ( + "context" + "fmt" + "net" + "os" + "os/signal" + "sync" + "syscall" + + envoy_service_auth_v2 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v2" + "golang.org/x/sync/errgroup" + + "github.com/pomerium/pomerium/authenticate" + "github.com/pomerium/pomerium/authorize" + "github.com/pomerium/pomerium/cache" + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/controlplane" + "github.com/pomerium/pomerium/internal/envoy" + pbCache "github.com/pomerium/pomerium/internal/grpc/cache" + "github.com/pomerium/pomerium/internal/httputil" + "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/internal/telemetry/metrics" + "github.com/pomerium/pomerium/internal/telemetry/trace" + "github.com/pomerium/pomerium/internal/urlutil" + "github.com/pomerium/pomerium/internal/version" + "github.com/pomerium/pomerium/proxy" +) + +// Run runs the main pomerium application. +func Run(ctx context.Context, configFile string) error { + opt, err := config.NewOptionsFromConfig(configFile) + if err != nil { + return err + } + var optionsUpdaters []config.OptionsUpdater + + log.Info().Str("version", version.FullVersion()).Msg("cmd/pomerium") + + if err := setupMetrics(ctx, opt); err != nil { + return err + } + if err := setupTracing(ctx, opt); err != nil { + return err + } + + // setup the control plane + controlPlane, err := controlplane.NewServer() + if err != nil { + return fmt.Errorf("error creating control plane: %w", err) + } + optionsUpdaters = append(optionsUpdaters, controlPlane) + err = controlPlane.UpdateOptions(*opt) + if err != nil { + return fmt.Errorf("error updating control plane options: %w", err) + } + + _, grpcPort, _ := net.SplitHostPort(controlPlane.GRPCListener.Addr().String()) + _, httpPort, _ := net.SplitHostPort(controlPlane.HTTPListener.Addr().String()) + + // create envoy server + envoyServer, err := envoy.NewServer(grpcPort, httpPort) + if err != nil { + return fmt.Errorf("error creating envoy server") + } + + // add services + if err := setupAuthenticate(opt, controlPlane); err != nil { + return err + } + if err := setupAuthorize(opt, controlPlane, &optionsUpdaters); err != nil { + return err + } + if err := setupCache(opt, controlPlane); err != nil { + return err + } + if err := setupProxy(opt, controlPlane); err != nil { + return err + } + + // start the config change listener + go config.WatchChanges(configFile, opt, optionsUpdaters) + + ctx, cancel := context.WithCancel(ctx) + go func() { + ch := make(chan os.Signal, 2) + defer signal.Stop(ch) + + signal.Notify(ch, os.Interrupt) + signal.Notify(ch, syscall.SIGTERM) + + select { + case <-ch: + case <-ctx.Done(): + } + cancel() + }() + + // run everything + eg, ctx := errgroup.WithContext(ctx) + eg.Go(func() error { + return controlPlane.Run(ctx) + }) + eg.Go(func() error { + return envoyServer.Run(ctx) + }) + return eg.Wait() +} + +func setupAuthenticate(opt *config.Options, controlPlane *controlplane.Server) error { + if !config.IsAuthenticate(opt.Services) { + return nil + } + + svc, err := authenticate.New(*opt) + if err != nil { + return fmt.Errorf("error creating authenticate service: %w", err) + } + host := urlutil.StripPort(opt.AuthenticateURL.Host) + sr := controlPlane.HTTPRouter.Host(host).Subrouter() + svc.Mount(sr) + log.Info().Str("host", host).Msg("enabled authenticate service") + + return nil +} + +func setupAuthorize(opt *config.Options, controlPlane *controlplane.Server, optionsUpdaters *[]config.OptionsUpdater) error { + if !config.IsAuthorize(opt.Services) { + return nil + } + + svc, err := authorize.New(*opt) + if err != nil { + return fmt.Errorf("error creating authorize service: %w", err) + } + envoy_service_auth_v2.RegisterAuthorizationServer(controlPlane.GRPCServer, svc) + + log.Info().Msg("enabled authorize service") + + *optionsUpdaters = append(*optionsUpdaters, svc) + err = svc.UpdateOptions(*opt) + if err != nil { + return fmt.Errorf("error updating authorize options: %w", err) + } + return nil +} + +func setupCache(opt *config.Options, controlPlane *controlplane.Server) error { + if !config.IsCache(opt.Services) { + return nil + } + + svc, err := cache.New(*opt) + if err != nil { + return fmt.Errorf("error creating config service: %w", err) + } + defer svc.Close() + pbCache.RegisterCacheServer(controlPlane.GRPCServer, svc) + log.Info().Msg("enabled cache service") + return nil +} + +func setupMetrics(ctx context.Context, opt *config.Options) error { + if opt.MetricsAddr != "" { + handler, err := metrics.PrometheusHandler() + if err != nil { + return err + } + metrics.SetBuildInfo(opt.Services) + metrics.RegisterInfoMetrics() + serverOpts := &httputil.ServerOptions{ + Addr: opt.MetricsAddr, + Insecure: true, + Service: "metrics", + } + var wg sync.WaitGroup + srv, err := httputil.NewServer(serverOpts, handler, &wg) + if err != nil { + return err + } + go func() { + <-ctx.Done() + _ = srv.Close() + }() + } + return nil +} + +func setupProxy(opt *config.Options, controlPlane *controlplane.Server) error { + if !config.IsProxy(opt.Services) { + return nil + } + + svc, err := proxy.New(*opt) + if err != nil { + return fmt.Errorf("error creating proxy service: %w", err) + } + controlPlane.HTTPRouter.PathPrefix("/").Handler(svc) + return nil +} + +func setupTracing(ctx context.Context, opt *config.Options) error { + if opt.TracingProvider != "" { + tracingOpts := &trace.TracingOptions{ + Provider: opt.TracingProvider, + Service: opt.Services, + Debug: opt.TracingDebug, + JaegerAgentEndpoint: opt.TracingJaegerAgentEndpoint, + JaegerCollectorEndpoint: opt.TracingJaegerCollectorEndpoint, + ZipkinEndpoint: opt.ZipkinEndpoint, + } + exporter, err := trace.RegisterTracing(tracingOpts) + if err != nil { + return err + } + go func() { + <-ctx.Done() + trace.UnregisterTracing(exporter) + }() + } + return nil +} diff --git a/cmd/pomerium/main_test.go b/internal/cmd/pomerium/pomerium_test.go similarity index 69% rename from cmd/pomerium/main_test.go rename to internal/cmd/pomerium/pomerium_test.go index 21c317dfa..04323760a 100644 --- a/cmd/pomerium/main_test.go +++ b/internal/cmd/pomerium/pomerium_test.go @@ -1,4 +1,4 @@ -package main +package pomerium import ( "context" @@ -23,7 +23,7 @@ func Test_setupTracing(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - setupTracing(tt.opt) + setupTracing(context.Background(), tt.opt) }) } } @@ -41,7 +41,7 @@ func Test_setupMetrics(t *testing.T) { c := make(chan os.Signal, 1) signal.Notify(c, syscall.SIGINT) defer signal.Stop(c) - setupMetrics(tt.opt) + setupMetrics(context.Background(), tt.opt) syscall.Kill(syscall.Getpid(), syscall.SIGINT) waitSig(t, c, syscall.SIGINT) }) @@ -64,13 +64,11 @@ func Test_run(t *testing.T) { t.Parallel() tests := []struct { name string - versionFlag bool configFileFlag string wantErr bool }{ - {"simply print version", true, "", false}, - {"nil configuration", false, "", true}, - {"bad proxy no authenticate url", false, ` + {"nil configuration", "", true}, + {"bad proxy no authenticate url", ` { "address": ":9433", "grpc_address": ":9444", @@ -82,7 +80,7 @@ func Test_run(t *testing.T) { "policy": [{ "from": "https://pomerium.io", "to": "https://httpbin.org" }] } `, true}, - {"bad authenticate no cookie secret", false, ` + {"bad authenticate no cookie secret", ` { "address": ":9433", "grpc_address": ":9444", @@ -93,7 +91,7 @@ func Test_run(t *testing.T) { "policy": [{ "from": "https://pomerium.io", "to": "https://httpbin.org" }] } `, true}, - {"bad authorize service bad shared key", false, ` + {"bad authorize service bad shared key", ` { "address": ":9433", "grpc_address": ":9444", @@ -105,7 +103,7 @@ func Test_run(t *testing.T) { "policy": [{ "from": "https://pomerium.io", "to": "https://httpbin.org" }] } `, true}, - {"bad http port", false, ` + {"bad http port", ` { "address": ":-1", "grpc_address": ":9444", @@ -119,7 +117,7 @@ func Test_run(t *testing.T) { "policy": [{ "from": "https://pomerium.io", "to": "https://httpbin.org" }] } `, true}, - {"bad redirect port", false, ` + {"bad redirect port", ` { "address": ":9433", "http_redirect_addr":":-1", @@ -134,7 +132,7 @@ func Test_run(t *testing.T) { "policy": [{ "from": "https://pomerium.io", "to": "https://httpbin.org" }] } `, true}, - {"bad metrics port ", false, ` + {"bad metrics port ", ` { "address": ":9433", "metrics_address": ":-1", @@ -148,7 +146,7 @@ func Test_run(t *testing.T) { "policy": [{ "from": "https://pomerium.io", "to": "https://httpbin.org" }] } `, true}, - {"malformed tracing provider", false, ` + {"malformed tracing provider", ` { "tracing_provider": "bad tracing provider", "address": ":9433", @@ -163,55 +161,9 @@ func Test_run(t *testing.T) { "policy": [{ "from": "https://pomerium.io", "to": "https://httpbin.org" }] } `, true}, - // {"simple cache", false, ` - // { - // "address": ":9433", - // "grpc_address": ":9444", - // "grpc_insecure": false, - // "insecure_server": true, - // "cache_service_url": "https://authorize.corp.example", - // "authenticate_service_url": "https://authenticate.corp.example", - // "shared_secret": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=", - // "cookie_secret": "zixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=", - // "services": "cache", - // "cache_store": "bolt", - // "policy": [{ "from": "https://pomerium.io", "to": "https://httpbin.org" }] - // } - // `, false}, - // {"malformed cache", false, ` - // { - // "address": ":9433", - // "grpc_address": ":9444", - // "grpc_insecure": false, - // "insecure_server": true, - // "cache_service_url": "https://authorize.corp.example", - // "authenticate_service_url": "https://authenticate.corp.example", - // "shared_secret": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=", - // "cookie_secret": "zixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=", - // "services": "cache", - // "cache_store": "bad bolt", - // "policy": [{ "from": "https://pomerium.io", "to": "https://httpbin.org" }] - // } - // `, true}, - // {"bad cache port", false, ` - // { - // "address": ":9433", - // "grpc_address": ":9999999", - // "grpc_insecure": false, - // "insecure_server": true, - // "cache_service_url": "https://authorize.corp.example", - // "authenticate_service_url": "https://authenticate.corp.example", - // "shared_secret": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=", - // "cookie_secret": "zixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=", - // "services": "cache", - // "cache_store": "bolt", - // "policy": [{ "from": "https://pomerium.io", "to": "https://httpbin.org" }] - // } - // `, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - versionFlag = &tt.versionFlag tmpFile, err := ioutil.TempFile(os.TempDir(), "*.json") if err != nil { t.Fatal("Cannot create temporary file", err) @@ -222,12 +174,12 @@ func Test_run(t *testing.T) { tmpFile.Close() t.Fatal(err) } - configFile = &fn + configFile := fn ctx, clearTimeout := context.WithTimeout(context.Background(), 500*time.Millisecond) defer clearTimeout() - err = run(ctx) + err = Run(ctx, configFile) if (err != nil) != tt.wantErr { t.Errorf("run() error = %v, wantErr %v", err, tt.wantErr) } diff --git a/internal/telemetry/trace/trace.go b/internal/telemetry/trace/trace.go index 18fa6264f..05b7b5e72 100644 --- a/internal/telemetry/trace/trace.go +++ b/internal/telemetry/trace/trace.go @@ -44,28 +44,34 @@ type TracingOptions struct { } // RegisterTracing creates a new trace exporter from TracingOptions. -func RegisterTracing(opts *TracingOptions) error { +func RegisterTracing(opts *TracingOptions) (trace.Exporter, error) { + var exporter trace.Exporter var err error switch opts.Provider { case JaegerTracingProviderName: - err = registerJaeger(opts) + exporter, err = registerJaeger(opts) case ZipkinTracingProviderName: - err = registerZipkin(opts) + exporter, err = registerZipkin(opts) default: - return fmt.Errorf("telemetry/trace: provider %s unknown", opts.Provider) + return nil, fmt.Errorf("telemetry/trace: provider %s unknown", opts.Provider) } if err != nil { - return err + return nil, err } if opts.Debug { log.Debug().Msg("telemetry/trace: debug on, sample everything") trace.ApplyConfig(trace.Config{DefaultSampler: trace.AlwaysSample()}) } log.Debug().Interface("Opts", opts).Msg("telemetry/trace: exporter created") - return nil + return exporter, nil } -func registerJaeger(opts *TracingOptions) error { +// UnregisterTracing unregisters a trace exporter. +func UnregisterTracing(exporter trace.Exporter) { + trace.UnregisterExporter(exporter) +} + +func registerJaeger(opts *TracingOptions) (trace.Exporter, error) { jex, err := jaeger.NewExporter( jaeger.Options{ AgentEndpoint: opts.JaegerAgentEndpoint, @@ -73,16 +79,16 @@ func registerJaeger(opts *TracingOptions) error { ServiceName: opts.Service, }) if err != nil { - return err + return nil, err } trace.RegisterExporter(jex) - return nil + return jex, nil } -func registerZipkin(opts *TracingOptions) error { +func registerZipkin(opts *TracingOptions) (trace.Exporter, error) { localEndpoint, err := zipkin.NewEndpoint(opts.Service, "") if err != nil { - return fmt.Errorf("telemetry/trace: could not create local endpoint: %w", err) + return nil, fmt.Errorf("telemetry/trace: could not create local endpoint: %w", err) } reporter := zipkinHTTP.NewReporter(opts.ZipkinEndpoint) @@ -90,7 +96,7 @@ func registerZipkin(opts *TracingOptions) error { exporter := ocZipkin.NewExporter(reporter, localEndpoint) trace.RegisterExporter(exporter) - return nil + return exporter, nil } // StartSpan starts a new child span of the current span in the context. If diff --git a/internal/telemetry/trace/trace_test.go b/internal/telemetry/trace/trace_test.go index afcdf48cd..770024299 100644 --- a/internal/telemetry/trace/trace_test.go +++ b/internal/telemetry/trace/trace_test.go @@ -15,7 +15,7 @@ func TestRegisterTracing(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := RegisterTracing(tt.opts); (err != nil) != tt.wantErr { + if _, err := RegisterTracing(tt.opts); (err != nil) != tt.wantErr { t.Errorf("RegisterTracing() error = %v, wantErr %v", err, tt.wantErr) } })