mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-28 18:06:34 +02:00
176 lines
4.6 KiB
Go
176 lines
4.6 KiB
Go
package evaluator
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/open-policy-agent/opa/ast"
|
|
"github.com/open-policy-agent/opa/rego"
|
|
"github.com/open-policy-agent/opa/types"
|
|
"golang.org/x/oauth2"
|
|
"golang.org/x/sync/singleflight"
|
|
"google.golang.org/api/idtoken"
|
|
|
|
"github.com/pomerium/pomerium/internal/log"
|
|
)
|
|
|
|
// GCP pre-defined values.
|
|
var (
|
|
GCPIdentityTokenExpiration = time.Minute * 45 // tokens expire after one hour according to the GCP docs
|
|
GCPIdentityDocURL = "http://metadata/computeMetadata/v1/instance/service-accounts/default/identity"
|
|
GCPIdentityNow = time.Now
|
|
GCPIdentityMaxBodySize int64 = 1024 * 1024 * 10
|
|
|
|
getGoogleCloudServerlessHeadersRegoOption = rego.Function2(®o.Function{
|
|
Name: "get_google_cloud_serverless_headers",
|
|
Decl: types.NewFunction(
|
|
types.Args(types.S, types.S),
|
|
types.NewObject(nil, types.NewDynamicProperty(types.S, types.S)),
|
|
),
|
|
}, func(_ rego.BuiltinContext, op1 *ast.Term, op2 *ast.Term) (*ast.Term, error) {
|
|
serviceAccount, ok := op1.Value.(ast.String)
|
|
if !ok {
|
|
return nil, fmt.Errorf("invalid service account type: %T", op1)
|
|
}
|
|
|
|
audience, ok := op2.Value.(ast.String)
|
|
if !ok {
|
|
return nil, fmt.Errorf("invalid audience type: %T", op2)
|
|
}
|
|
|
|
headers, err := getGoogleCloudServerlessHeaders(string(serviceAccount), string(audience))
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("error retrieving google cloud serverless headers")
|
|
return nil, fmt.Errorf("failed to get google cloud serverless headers: %w", err)
|
|
}
|
|
var kvs [][2]*ast.Term
|
|
for k, v := range headers {
|
|
kvs = append(kvs, [2]*ast.Term{ast.StringTerm(k), ast.StringTerm(v)})
|
|
}
|
|
|
|
return ast.ObjectTerm(kvs...), nil
|
|
})
|
|
)
|
|
|
|
type gcpIdentityTokenSource struct {
|
|
audience string
|
|
singleflight singleflight.Group
|
|
}
|
|
|
|
func (src *gcpIdentityTokenSource) Token() (*oauth2.Token, error) {
|
|
res, err, _ := src.singleflight.Do("", func() (any, error) {
|
|
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, GCPIdentityDocURL+"?"+url.Values{
|
|
"format": {"full"},
|
|
"audience": {src.audience},
|
|
}.Encode(), nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.Header.Add("Metadata-Flavor", "Google")
|
|
|
|
res, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer func() { _ = res.Body.Close() }()
|
|
|
|
bs, err := io.ReadAll(io.LimitReader(res.Body, GCPIdentityMaxBodySize))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return string(bs), nil
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &oauth2.Token{
|
|
AccessToken: strings.TrimSpace(res.(string)),
|
|
TokenType: "bearer",
|
|
Expiry: GCPIdentityNow().Add(GCPIdentityTokenExpiration),
|
|
}, nil
|
|
}
|
|
|
|
// gcpTokenSourceKey identifies a cached token source by the service account
// string exactly as supplied (before normalization) and the token audience.
type gcpTokenSourceKey struct {
	serviceAccount, audience string
}
|
|
|
|
var gcpTokenSources = struct {
|
|
sync.Mutex
|
|
m map[gcpTokenSourceKey]oauth2.TokenSource
|
|
}{
|
|
m: make(map[gcpTokenSourceKey]oauth2.TokenSource),
|
|
}
|
|
|
|
// normalizeServiceAccount turns a configured service account into credentials
// JSON. Surrounding whitespace is trimmed first. A value that then starts with
// "{" is returned as-is. Anything else is treated as standard (padded) base64
// and decoded. A decode error is returned together with an empty string.
func normalizeServiceAccount(serviceAccount string) (string, error) {
	trimmed := strings.TrimSpace(serviceAccount)
	if strings.HasPrefix(trimmed, "{") {
		return trimmed, nil
	}

	// Not raw JSON, so the service account must be base64 encoded.
	decoded, err := base64.StdEncoding.DecodeString(trimmed)
	if err != nil {
		return "", err
	}
	return string(decoded), nil
}
|
|
|
|
func getGoogleCloudServerlessTokenSource(serviceAccount, audience string) (oauth2.TokenSource, error) {
|
|
key := gcpTokenSourceKey{
|
|
serviceAccount: serviceAccount,
|
|
audience: audience,
|
|
}
|
|
|
|
gcpTokenSources.Lock()
|
|
defer gcpTokenSources.Unlock()
|
|
|
|
src, ok := gcpTokenSources.m[key]
|
|
if ok {
|
|
return src, nil
|
|
}
|
|
|
|
if serviceAccount == "" {
|
|
src = oauth2.ReuseTokenSource(new(oauth2.Token), &gcpIdentityTokenSource{
|
|
audience: audience,
|
|
})
|
|
} else {
|
|
serviceAccount, err := normalizeServiceAccount(serviceAccount)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
newSrc, err := idtoken.NewTokenSource(context.Background(), audience, idtoken.WithCredentialsJSON([]byte(serviceAccount)))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
src = newSrc
|
|
}
|
|
|
|
gcpTokenSources.m[key] = src
|
|
|
|
return src, nil
|
|
}
|
|
|
|
func getGoogleCloudServerlessHeaders(serviceAccount, audience string) (map[string]string, error) {
|
|
src, err := getGoogleCloudServerlessTokenSource(serviceAccount, audience)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error retrieving google cloud serverless token source: %w", err)
|
|
}
|
|
|
|
tok, err := src.Token()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error retrieving google cloud serverless token: %w", err)
|
|
}
|
|
|
|
return map[string]string{
|
|
"Authorization": "Bearer " + tok.AccessToken,
|
|
}, nil
|
|
}
|