core/zero: add support for managed mode from config file (#4756)

This commit is contained in:
Caleb Doxsey 2023-11-17 09:04:59 -07:00 committed by GitHub
parent eb729a53f8
commit 6810091d38
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 68 additions and 9 deletions

View file

@ -32,8 +32,8 @@ func main() {
ctx := context.Background()
runFn := run
if zero_cmd.IsManagedMode() {
runFn = zero_cmd.Run
if zero_cmd.IsManagedMode(*configFile) {
runFn = func(ctx context.Context) error { return zero_cmd.Run(ctx, *configFile) }
}
if err := runFn(ctx); err != nil && !errors.Is(err, context.Canceled) {

View file

@ -17,13 +17,13 @@ import (
)
// Run runs the pomerium zero command.
func Run(ctx context.Context) error {
func Run(ctx context.Context, configFile string) error {
err := setupLogger()
if err != nil {
return fmt.Errorf("error setting up logger: %w", err)
}
token := getToken()
token := getToken(configFile)
if token == "" {
return errors.New("no token provided")
}
@ -37,8 +37,8 @@ func Run(ctx context.Context) error {
}
// IsManagedMode returns true if Pomerium should start in managed mode using this command.
func IsManagedMode() bool {
return getToken() != ""
func IsManagedMode(configFile string) bool {
return getToken(configFile) != ""
}
func withInterrupt(ctx context.Context) context.Context {

View file

@ -1,6 +1,10 @@
package cmd
import "os"
import (
"os"
"github.com/spf13/viper"
)
const (
// PomeriumZeroTokenEnv is the environment variable name for the API token.
@ -8,6 +12,20 @@ const (
PomeriumZeroTokenEnv = "POMERIUM_ZERO_TOKEN"
)
func getToken() string {
return os.Getenv(PomeriumZeroTokenEnv)
func getToken(configFile string) string {
if token, ok := os.LookupEnv(PomeriumZeroTokenEnv); ok {
return token
}
if configFile != "" {
// load the token from the config file
v := viper.New()
v.SetConfigFile(configFile)
if v.ReadInConfig() == nil {
return v.GetString("pomerium_zero_token")
}
}
// we will fallback to normal pomerium if empty
return ""
}

View file

@ -0,0 +1,41 @@
package cmd
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_getToken(t *testing.T) {
t.Run("empty", func(t *testing.T) {
assert.Equal(t, "", getToken(""))
})
t.Run("env", func(t *testing.T) {
t.Setenv("POMERIUM_ZERO_TOKEN", "FROM_ENV")
assert.Equal(t, "FROM_ENV", getToken(""))
})
t.Run("json", func(t *testing.T) {
fp := filepath.Join(t.TempDir(), "config.json")
require.NoError(t, os.WriteFile(fp, []byte(`{
"pomerium_zero_token": "FROM_JSON"
}`), 0o644))
assert.Equal(t, "FROM_JSON", getToken(fp))
})
t.Run("yaml", func(t *testing.T) {
fp := filepath.Join(t.TempDir(), "config.yaml")
require.NoError(t, os.WriteFile(fp, []byte(`
pomerium_zero_token: FROM_YAML
`), 0o644))
assert.Equal(t, "FROM_YAML", getToken(fp))
})
t.Run("toml", func(t *testing.T) {
fp := filepath.Join(t.TempDir(), "config.toml")
require.NoError(t, os.WriteFile(fp, []byte(`
pomerium_zero_token = "FROM_TOML"
`), 0o644))
assert.Equal(t, "FROM_TOML", getToken(fp))
})
}