From 6810091d3897910ef447a377a6136c767e671ce0 Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Fri, 17 Nov 2023 09:04:59 -0700 Subject: [PATCH] core/zero: add support for managed mode from config file (#4756) --- cmd/pomerium/main.go | 4 ++-- internal/zero/cmd/command.go | 8 +++---- internal/zero/cmd/env.go | 24 +++++++++++++++++--- internal/zero/cmd/env_test.go | 41 +++++++++++++++++++++++++++++++++++ 4 files changed, 68 insertions(+), 9 deletions(-) create mode 100644 internal/zero/cmd/env_test.go diff --git a/cmd/pomerium/main.go b/cmd/pomerium/main.go index f4f290438..490da99f9 100644 --- a/cmd/pomerium/main.go +++ b/cmd/pomerium/main.go @@ -32,8 +32,8 @@ func main() { ctx := context.Background() runFn := run - if zero_cmd.IsManagedMode() { - runFn = zero_cmd.Run + if zero_cmd.IsManagedMode(*configFile) { + runFn = func(ctx context.Context) error { return zero_cmd.Run(ctx, *configFile) } } if err := runFn(ctx); err != nil && !errors.Is(err, context.Canceled) { diff --git a/internal/zero/cmd/command.go b/internal/zero/cmd/command.go index 24321b4f0..8ed466106 100644 --- a/internal/zero/cmd/command.go +++ b/internal/zero/cmd/command.go @@ -17,13 +17,13 @@ import ( ) // Run runs the pomerium zero command. -func Run(ctx context.Context) error { +func Run(ctx context.Context, configFile string) error { err := setupLogger() if err != nil { return fmt.Errorf("error setting up logger: %w", err) } - token := getToken() + token := getToken(configFile) if token == "" { return errors.New("no token provided") } @@ -37,8 +37,8 @@ func Run(ctx context.Context) error { } // IsManagedMode returns true if Pomerium should start in managed mode using this command. -func IsManagedMode() bool { - return getToken() != "" +func IsManagedMode(configFile string) bool { + return getToken(configFile) != "" } func withInterrupt(ctx context.Context) context.Context { diff --git a/internal/zero/cmd/env.go b/internal/zero/cmd/env.go index 5ee6a3c37..a06811613 100644 --- a/internal/zero/cmd/env.go +++ b/internal/zero/cmd/env.go @@ -1,6 +1,10 @@ package cmd -import "os" +import ( + "os" + + "github.com/spf13/viper" +) const ( // PomeriumZeroTokenEnv is the environment variable name for the API token. @@ -8,6 +12,20 @@ const ( PomeriumZeroTokenEnv = "POMERIUM_ZERO_TOKEN" ) -func getToken() string { - return os.Getenv(PomeriumZeroTokenEnv) +func getToken(configFile string) string { + if token, ok := os.LookupEnv(PomeriumZeroTokenEnv); ok { + return token + } + + if configFile != "" { + // load the token from the config file + v := viper.New() + v.SetConfigFile(configFile) + if v.ReadInConfig() == nil { + return v.GetString("pomerium_zero_token") + } + } + + // we will fallback to normal pomerium if empty + return "" } diff --git a/internal/zero/cmd/env_test.go b/internal/zero/cmd/env_test.go new file mode 100644 index 000000000..34573e330 --- /dev/null +++ b/internal/zero/cmd/env_test.go @@ -0,0 +1,41 @@ +package cmd + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_getToken(t *testing.T) { + t.Run("empty", func(t *testing.T) { + assert.Equal(t, "", getToken("")) + }) + t.Run("env", func(t *testing.T) { + t.Setenv("POMERIUM_ZERO_TOKEN", "FROM_ENV") + assert.Equal(t, "FROM_ENV", getToken("")) + }) + t.Run("json", func(t *testing.T) { + fp := filepath.Join(t.TempDir(), "config.json") + require.NoError(t, os.WriteFile(fp, []byte(`{ + "pomerium_zero_token": "FROM_JSON" + }`), 0o644)) + assert.Equal(t, "FROM_JSON", getToken(fp)) + }) + t.Run("yaml", func(t *testing.T) { + fp := filepath.Join(t.TempDir(), "config.yaml") + require.NoError(t, os.WriteFile(fp, []byte(` +pomerium_zero_token: FROM_YAML +`), 0o644)) + assert.Equal(t, "FROM_YAML", getToken(fp)) + }) + t.Run("toml", func(t *testing.T) { + fp := filepath.Join(t.TempDir(), "config.toml") + require.NoError(t, os.WriteFile(fp, []byte(` +pomerium_zero_token = "FROM_TOML" +`), 0o644)) + assert.Equal(t, "FROM_TOML", getToken(fp)) + }) +}