pomerium/authorize/evaluator/google_cloud_serverless.go
Caleb Doxsey dad954ae16
core/logging: change log.Error function (#5251)
* core/logging: change log.Error function

* use request id
2024-09-05 15:42:46 -06:00

176 lines
4.6 KiB
Go

package evaluator
import (
"context"
"encoding/base64"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/rego"
"github.com/open-policy-agent/opa/types"
"golang.org/x/oauth2"
"golang.org/x/sync/singleflight"
"google.golang.org/api/idtoken"
"github.com/pomerium/pomerium/internal/log"
)
// GCP pre-defined values.
//
// NOTE(review): these are variables rather than constants, presumably so that
// tests can override them — confirm against the test suite.
var (
	// GCPIdentityTokenExpiration is the lifetime assigned to identity tokens
	// fetched from the metadata server.
	GCPIdentityTokenExpiration = time.Minute * 45 // tokens expire after one hour according to the GCP docs

	// GCPIdentityDocURL is the metadata server endpoint queried for identity tokens.
	GCPIdentityDocURL = "http://metadata/computeMetadata/v1/instance/service-accounts/default/identity"

	// GCPIdentityNow returns the current time used when computing token expiry.
	GCPIdentityNow = time.Now

	// GCPIdentityMaxBodySize is the maximum number of bytes read from the
	// metadata server response (10 MiB).
	GCPIdentityMaxBodySize int64 = 1024 * 1024 * 10

	// getGoogleCloudServerlessHeadersRegoOption registers the rego builtin
	// get_google_cloud_serverless_headers(service_account, audience). Both
	// arguments are strings and the result is an object mapping header names
	// to header values.
	getGoogleCloudServerlessHeadersRegoOption = rego.Function2(&rego.Function{
		Name: "get_google_cloud_serverless_headers",
		Decl: types.NewFunction(
			types.Args(types.S, types.S),
			types.NewObject(nil, types.NewDynamicProperty(types.S, types.S)),
		),
	}, func(_ rego.BuiltinContext, op1 *ast.Term, op2 *ast.Term) (*ast.Term, error) {
		// Report the type of the term's value: the term itself is always
		// *ast.Term, which says nothing about what the policy passed in.
		serviceAccount, ok := op1.Value.(ast.String)
		if !ok {
			return nil, fmt.Errorf("invalid service account type: %T", op1.Value)
		}
		audience, ok := op2.Value.(ast.String)
		if !ok {
			return nil, fmt.Errorf("invalid audience type: %T", op2.Value)
		}

		headers, err := getGoogleCloudServerlessHeaders(string(serviceAccount), string(audience))
		if err != nil {
			// The failure is logged here in addition to being returned to the
			// policy engine.
			log.Error().Err(err).Msg("error retrieving google cloud serverless headers")
			return nil, fmt.Errorf("failed to get google cloud serverless headers: %w", err)
		}

		// Convert the header map into key/value term pairs for the rego object.
		kvs := make([][2]*ast.Term, 0, len(headers))
		for k, v := range headers {
			kvs = append(kvs, [2]*ast.Term{ast.StringTerm(k), ast.StringTerm(v)})
		}
		return ast.ObjectTerm(kvs...), nil
	})
)
// gcpIdentityTokenSource retrieves identity tokens for a single audience by
// querying the URL in GCPIdentityDocURL (see its Token method).
type gcpIdentityTokenSource struct {
// audience is sent as the "audience" query parameter of the identity request.
audience string
// singleflight collapses concurrent Token calls into one in-flight request.
singleflight singleflight.Group
}
// Token fetches an identity token for src.audience from the metadata endpoint
// at GCPIdentityDocURL.
//
// Concurrent calls on the same source share a single HTTP request through the
// singleflight group. At most GCPIdentityMaxBodySize bytes of the response
// body are read; the body, trimmed of surrounding whitespace, becomes the
// access token. The expiry is set to GCPIdentityNow() plus
// GCPIdentityTokenExpiration.
//
// An error is returned when the request cannot be built or sent, when the
// endpoint answers with a non-2xx status, or when the body cannot be read.
func (src *gcpIdentityTokenSource) Token() (*oauth2.Token, error) {
	res, err, _ := src.singleflight.Do("", func() (any, error) {
		query := url.Values{
			"format":   {"full"},
			"audience": {src.audience},
		}
		req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, GCPIdentityDocURL+"?"+query.Encode(), nil)
		if err != nil {
			return nil, err
		}
		req.Header.Add("Metadata-Flavor", "Google")

		resp, err := http.DefaultClient.Do(req)
		if err != nil {
			return nil, err
		}
		defer func() { _ = resp.Body.Close() }()

		// Without this check an error page from the endpoint would be
		// returned to the caller as if it were a valid bearer token.
		if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
			return nil, fmt.Errorf("unexpected status code from gcp identity endpoint: %d", resp.StatusCode)
		}

		bs, err := io.ReadAll(io.LimitReader(resp.Body, GCPIdentityMaxBodySize))
		if err != nil {
			return nil, err
		}
		return string(bs), nil
	})
	if err != nil {
		return nil, err
	}

	return &oauth2.Token{
		AccessToken: strings.TrimSpace(res.(string)),
		TokenType:   "bearer",
		Expiry:      GCPIdentityNow().Add(GCPIdentityTokenExpiration),
	}, nil
}
// gcpTokenSourceKey is the lookup key for the gcpTokenSources cache.
type gcpTokenSourceKey struct {
// serviceAccount is the service account string exactly as supplied by the
// caller; the key is built before any normalization takes place.
serviceAccount string
// audience is the audience the token source issues tokens for.
audience string
}
// gcpTokenSources caches token sources keyed by service account and audience.
// The embedded mutex must be held while reading or writing m.
// NOTE(review): no code visible in this file removes entries from m — confirm
// whether unbounded growth is a concern for callers.
var gcpTokenSources = struct {
sync.Mutex
m map[gcpTokenSourceKey]oauth2.TokenSource
}{
m: make(map[gcpTokenSourceKey]oauth2.TokenSource),
}
// normalizeServiceAccount returns the service account credentials in their
// plain form. Surrounding whitespace is stripped first. A value that then
// starts with "{" is returned as-is; anything else is treated as standard
// base64 and decoded. A decoding failure is returned as the error.
func normalizeServiceAccount(serviceAccount string) (string, error) {
	trimmed := strings.TrimSpace(serviceAccount)

	// Already looks like a JSON document: nothing to decode.
	if strings.HasPrefix(trimmed, "{") {
		return trimmed, nil
	}

	decoded, err := base64.StdEncoding.DecodeString(trimmed)
	if err != nil {
		return "", err
	}
	return string(decoded), nil
}
// getGoogleCloudServerlessTokenSource returns the token source for the given
// service account and audience, creating and caching it in gcpTokenSources on
// first use.
//
// With an empty serviceAccount the source is a gcpIdentityTokenSource wrapped
// in oauth2.ReuseTokenSource. Otherwise the service account is normalized and
// handed to idtoken.NewTokenSource as credentials JSON. Errors from
// normalization or from idtoken are returned unchanged and nothing is cached.
func getGoogleCloudServerlessTokenSource(serviceAccount, audience string) (oauth2.TokenSource, error) {
	cacheKey := gcpTokenSourceKey{serviceAccount: serviceAccount, audience: audience}

	// The lock is held for the whole lookup-or-create sequence.
	gcpTokenSources.Lock()
	defer gcpTokenSources.Unlock()

	if cached, found := gcpTokenSources.m[cacheKey]; found {
		return cached, nil
	}

	var created oauth2.TokenSource
	if serviceAccount != "" {
		credentialsJSON, err := normalizeServiceAccount(serviceAccount)
		if err != nil {
			return nil, err
		}
		idSource, err := idtoken.NewTokenSource(
			context.Background(),
			audience,
			idtoken.WithCredentialsJSON([]byte(credentialsJSON)),
		)
		if err != nil {
			return nil, err
		}
		created = idSource
	} else {
		identitySource := &gcpIdentityTokenSource{audience: audience}
		created = oauth2.ReuseTokenSource(new(oauth2.Token), identitySource)
	}

	gcpTokenSources.m[cacheKey] = created
	return created, nil
}
// getGoogleCloudServerlessHeaders obtains a token for the given service
// account and audience and returns it as a single "Authorization" header with
// a "Bearer " prefix. Failures to obtain the token source or the token are
// wrapped and returned.
func getGoogleCloudServerlessHeaders(serviceAccount, audience string) (map[string]string, error) {
	tokenSource, err := getGoogleCloudServerlessTokenSource(serviceAccount, audience)
	if err != nil {
		return nil, fmt.Errorf("error retrieving google cloud serverless token source: %w", err)
	}

	token, err := tokenSource.Token()
	if err != nil {
		return nil, fmt.Errorf("error retrieving google cloud serverless token: %w", err)
	}

	headers := make(map[string]string, 1)
	headers["Authorization"] = "Bearer " + token.AccessToken
	return headers, nil
}