authorize: move service account normalization to its own function

This helps testing the code easier, increase coverage.
This commit is contained in:
Cuong Manh Le 2020-08-06 13:19:23 +07:00
parent e6c78f10e9
commit 0624658e4b
2 changed files with 46 additions and 13 deletions

View file

@ -80,6 +80,20 @@ var (
}
)
func normalizeServiceAccount(serviceAccount string) (string, error) {
serviceAccount = strings.TrimSpace(serviceAccount)
// the service account can be base64 encoded
if !strings.HasPrefix(serviceAccount, "{") {
bs, err := base64.StdEncoding.DecodeString(serviceAccount)
if err != nil {
return "", err
}
serviceAccount = string(bs)
}
return serviceAccount, nil
}
func getGoogleCloudServerlessTokenSource(serviceAccount, audience string) (oauth2.TokenSource, error) {
key := gcpTokenSourceKey{
serviceAccount: serviceAccount,
@ -99,22 +113,15 @@ func getGoogleCloudServerlessTokenSource(serviceAccount, audience string) (oauth
audience: audience,
})
} else {
serviceAccount = strings.TrimSpace(serviceAccount)
// the service account can be base64 encoded
if !strings.HasPrefix(serviceAccount, "{") {
bs, err := base64.StdEncoding.DecodeString(serviceAccount)
if err != nil {
return nil, err
}
serviceAccount = string(bs)
}
var err error
src, err = idtoken.NewTokenSource(context.Background(), audience, idtoken.WithCredentialsJSON([]byte(serviceAccount)))
serviceAccount, err := normalizeServiceAccount(serviceAccount)
if err != nil {
return nil, err
}
newSrc, err := idtoken.NewTokenSource(context.Background(), audience, idtoken.WithCredentialsJSON([]byte(serviceAccount)))
if err != nil {
return nil, err
}
src = newSrc
}
gcpTokenSources.m[key] = src

View file

@ -38,3 +38,29 @@ func TestGCPIdentityTokenSource(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, "2020-01-01T01:00:00Z", token.AccessToken)
}
func Test_normalizeServiceAccount(t *testing.T) {
tests := []struct {
name string
serviceAccount string
expectedServiceAccount string
wantError bool
}{
{"empty", "", "", false},
{"leading spaces", ` {"service_account": "foo"}`, `{"service_account": "foo"}`, false},
{"trailing spaces", `{"service_account": "foo"} `, `{"service_account": "foo"}`, false},
{"leading+trailing spaces", ` {"service_account": "foo"} `, `{"service_account": "foo"}`, false},
{"base64", "eyJzZXJ2aWNlX2FjY291bnQiOiAiZm9vIn0=", `{"service_account": "foo"}`, false},
{"invalid base64", "--eyJzZXJ2aWNlX2FjY291bnQiOiAiZm9vIn0=--", "", true},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
gotServiceAccount, err := normalizeServiceAccount(tc.serviceAccount)
assert.True(t, (err != nil) == tc.wantError)
assert.Equal(t, tc.expectedServiceAccount, gotServiceAccount)
})
}
}