From 3fd66c1401948ba6d4e830fb80a3656739d3a22c Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 1 Sep 2020 23:09:32 +0700 Subject: [PATCH] internal/directory/okta: accept non-json service account (#1359) (#1360) Fixes #1354 Co-authored-by: Cuong Manh Le --- internal/directory/okta/okta.go | 5 ++--- internal/directory/okta/okta_test.go | 26 ++++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/internal/directory/okta/okta.go b/internal/directory/okta/okta.go index 10e42f8f0..2e23a143f 100644 --- a/internal/directory/okta/okta.go +++ b/internal/directory/okta/okta.go @@ -291,9 +291,8 @@ func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) { } var serviceAccount ServiceAccount - err = json.Unmarshal(bs, &serviceAccount) - if err != nil { - return nil, err + if err := json.Unmarshal(bs, &serviceAccount); err != nil { + serviceAccount.APIKey = string(bs) } if serviceAccount.APIKey == "" { diff --git a/internal/directory/okta/okta_test.go b/internal/directory/okta/okta_test.go index c30e4afc9..ad1da56d4 100644 --- a/internal/directory/okta/okta_test.go +++ b/internal/directory/okta/okta_test.go @@ -13,6 +13,7 @@ import ( "github.com/go-chi/chi" "github.com/go-chi/chi/middleware" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/tomnomnom/linkheader" "github.com/pomerium/pomerium/pkg/grpc/directory" @@ -207,3 +208,28 @@ func mustParseURL(rawurl string) *url.URL { } return u } + +func TestParseServiceAccount(t *testing.T) { + tests := []struct { + name string + rawServiceAccount string + apiKey string + wantErr bool + }{ + {"json", "ewogICAgImFwaV9rZXkiOiAiZm9vIgp9Cg==", "foo", false}, + {"value", "Zm9v", "foo", false}, + {"empty", "", "", true}, + {"invalid", "Zm9v---", "", true}, + } + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got, err := ParseServiceAccount(tc.rawServiceAccount) + require.True(t, (err != nil) == tc.wantErr) + if tc.apiKey != "" { + assert.Equal(t, tc.apiKey, got.APIKey) + } + }) + } +}