mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-18 11:37:08 +02:00
authorize: move service account normalization to its own function
This helps testing the code easier, increase coverage.
This commit is contained in:
parent
e6c78f10e9
commit
0624658e4b
2 changed files with 46 additions and 13 deletions
|
@ -80,6 +80,20 @@ var (
|
|||
}
|
||||
)
|
||||
|
||||
func normalizeServiceAccount(serviceAccount string) (string, error) {
|
||||
serviceAccount = strings.TrimSpace(serviceAccount)
|
||||
|
||||
// the service account can be base64 encoded
|
||||
if !strings.HasPrefix(serviceAccount, "{") {
|
||||
bs, err := base64.StdEncoding.DecodeString(serviceAccount)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
serviceAccount = string(bs)
|
||||
}
|
||||
return serviceAccount, nil
|
||||
}
|
||||
|
||||
func getGoogleCloudServerlessTokenSource(serviceAccount, audience string) (oauth2.TokenSource, error) {
|
||||
key := gcpTokenSourceKey{
|
||||
serviceAccount: serviceAccount,
|
||||
|
@ -99,22 +113,15 @@ func getGoogleCloudServerlessTokenSource(serviceAccount, audience string) (oauth
|
|||
audience: audience,
|
||||
})
|
||||
} else {
|
||||
serviceAccount = strings.TrimSpace(serviceAccount)
|
||||
|
||||
// the service account can be base64 encoded
|
||||
if !strings.HasPrefix(serviceAccount, "{") {
|
||||
bs, err := base64.StdEncoding.DecodeString(serviceAccount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
serviceAccount = string(bs)
|
||||
}
|
||||
|
||||
var err error
|
||||
src, err = idtoken.NewTokenSource(context.Background(), audience, idtoken.WithCredentialsJSON([]byte(serviceAccount)))
|
||||
serviceAccount, err := normalizeServiceAccount(serviceAccount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newSrc, err := idtoken.NewTokenSource(context.Background(), audience, idtoken.WithCredentialsJSON([]byte(serviceAccount)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
src = newSrc
|
||||
}
|
||||
|
||||
gcpTokenSources.m[key] = src
|
||||
|
|
|
@ -38,3 +38,29 @@ func TestGCPIdentityTokenSource(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
assert.Equal(t, "2020-01-01T01:00:00Z", token.AccessToken)
|
||||
}
|
||||
|
||||
func Test_normalizeServiceAccount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serviceAccount string
|
||||
expectedServiceAccount string
|
||||
wantError bool
|
||||
}{
|
||||
{"empty", "", "", false},
|
||||
{"leading spaces", ` {"service_account": "foo"}`, `{"service_account": "foo"}`, false},
|
||||
{"trailing spaces", `{"service_account": "foo"} `, `{"service_account": "foo"}`, false},
|
||||
{"leading+trailing spaces", ` {"service_account": "foo"} `, `{"service_account": "foo"}`, false},
|
||||
{"base64", "eyJzZXJ2aWNlX2FjY291bnQiOiAiZm9vIn0=", `{"service_account": "foo"}`, false},
|
||||
{"invalid base64", "--eyJzZXJ2aWNlX2FjY291bnQiOiAiZm9vIn0=--", "", true},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
gotServiceAccount, err := normalizeServiceAccount(tc.serviceAccount)
|
||||
assert.True(t, (err != nil) == tc.wantError)
|
||||
assert.Equal(t, tc.expectedServiceAccount, gotServiceAccount)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue