mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-03 16:59:22 +02:00
authorize: rewrite header evaluator to use go instead of rego (#5362)
* authorize: rewrite header evaluator to use go instead of rego * cache signed jwt * re-add missing trace * address comments
This commit is contained in:
parent
177f789e63
commit
37017e2a5b
7 changed files with 576 additions and 411 deletions
|
@ -5,6 +5,7 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
|
@ -27,6 +28,10 @@ import (
|
|||
// A Store stores data for the OPA rego policy evaluation.
|
||||
type Store struct {
|
||||
opastorage.Store
|
||||
|
||||
googleCloudServerlessAuthenticationServiceAccount atomic.Pointer[string]
|
||||
jwtClaimHeaders atomic.Pointer[map[string]string]
|
||||
signingKey atomic.Pointer[jose.JSONWebKey]
|
||||
}
|
||||
|
||||
// New creates a new Store.
|
||||
|
@ -36,15 +41,37 @@ func New() *Store {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Store) GetGoogleCloudServerlessAuthenticationServiceAccount() string {
|
||||
v := s.googleCloudServerlessAuthenticationServiceAccount.Load()
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
return *v
|
||||
}
|
||||
|
||||
func (s *Store) GetJWTClaimHeaders() map[string]string {
|
||||
m := s.jwtClaimHeaders.Load()
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
return *m
|
||||
}
|
||||
|
||||
func (s *Store) GetSigningKey() *jose.JSONWebKey {
|
||||
return s.signingKey.Load()
|
||||
}
|
||||
|
||||
// UpdateGoogleCloudServerlessAuthenticationServiceAccount updates the google cloud serverless authentication
|
||||
// service account in the store.
|
||||
func (s *Store) UpdateGoogleCloudServerlessAuthenticationServiceAccount(serviceAccount string) {
|
||||
s.write("/google_cloud_serverless_authentication_service_account", serviceAccount)
|
||||
s.googleCloudServerlessAuthenticationServiceAccount.Store(&serviceAccount)
|
||||
}
|
||||
|
||||
// UpdateJWTClaimHeaders updates the jwt claim headers in the store.
|
||||
func (s *Store) UpdateJWTClaimHeaders(jwtClaimHeaders map[string]string) {
|
||||
s.write("/jwt_claim_headers", jwtClaimHeaders)
|
||||
s.jwtClaimHeaders.Store(&jwtClaimHeaders)
|
||||
}
|
||||
|
||||
// UpdateRoutePolicies updates the route policies in the store.
|
||||
|
@ -56,6 +83,7 @@ func (s *Store) UpdateRoutePolicies(routePolicies []*config.Policy) {
|
|||
// in rego use JWKs, so we take in that format.
|
||||
func (s *Store) UpdateSigningKey(signingKey *jose.JSONWebKey) {
|
||||
s.write("/signing_key", signingKey)
|
||||
s.signingKey.Store(signingKey)
|
||||
}
|
||||
|
||||
func (s *Store) write(rawPath string, value any) {
|
||||
|
@ -111,40 +139,17 @@ func (s *Store) GetDataBrokerRecordOption() func(*rego.Rego) {
|
|||
}
|
||||
span.AddAttributes(octrace.StringAttribute("record_type", recordType.String()))
|
||||
|
||||
value, ok := op2.Value.(ast.String)
|
||||
recordIDOrIndex, ok := op2.Value.(ast.String)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid record id: %T", op2)
|
||||
}
|
||||
span.AddAttributes(octrace.StringAttribute("record_id", value.String()))
|
||||
span.AddAttributes(octrace.StringAttribute("record_id", recordIDOrIndex.String()))
|
||||
|
||||
req := &databroker.QueryRequest{
|
||||
Type: string(recordType),
|
||||
Limit: 1,
|
||||
}
|
||||
req.SetFilterByIDOrIndex(string(value))
|
||||
|
||||
res, err := storage.GetQuerier(ctx).Query(ctx, req)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("authorize/store: error retrieving record")
|
||||
return ast.NullTerm(), nil
|
||||
}
|
||||
|
||||
if len(res.GetRecords()) == 0 {
|
||||
return ast.NullTerm(), nil
|
||||
}
|
||||
|
||||
msg, _ := res.GetRecords()[0].GetData().UnmarshalNew()
|
||||
msg := s.GetDataBrokerRecord(ctx, string(recordType), string(recordIDOrIndex))
|
||||
if msg == nil {
|
||||
return ast.NullTerm(), nil
|
||||
}
|
||||
|
||||
// exclude expired records
|
||||
if hasExpiresAt, ok := msg.(interface{ GetExpiresAt() *timestamppb.Timestamp }); ok && hasExpiresAt.GetExpiresAt() != nil {
|
||||
if hasExpiresAt.GetExpiresAt().AsTime().Before(time.Now()) {
|
||||
return ast.NullTerm(), nil
|
||||
}
|
||||
}
|
||||
|
||||
obj := toMap(msg)
|
||||
|
||||
regoValue, err := ast.InterfaceToValue(obj)
|
||||
|
@ -157,6 +162,38 @@ func (s *Store) GetDataBrokerRecordOption() func(*rego.Rego) {
|
|||
})
|
||||
}
|
||||
|
||||
func (s *Store) GetDataBrokerRecord(ctx context.Context, recordType, recordIDOrIndex string) proto.Message {
|
||||
req := &databroker.QueryRequest{
|
||||
Type: recordType,
|
||||
Limit: 1,
|
||||
}
|
||||
req.SetFilterByIDOrIndex(recordIDOrIndex)
|
||||
|
||||
res, err := storage.GetQuerier(ctx).Query(ctx, req)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("authorize/store: error retrieving record")
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(res.GetRecords()) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
msg, _ := res.GetRecords()[0].GetData().UnmarshalNew()
|
||||
if msg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// exclude expired records
|
||||
if hasExpiresAt, ok := msg.(interface{ GetExpiresAt() *timestamppb.Timestamp }); ok && hasExpiresAt.GetExpiresAt() != nil {
|
||||
if hasExpiresAt.GetExpiresAt().AsTime().Before(time.Now()) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
func toMap(msg proto.Message) map[string]any {
|
||||
bs, _ := json.Marshal(msg)
|
||||
var obj map[string]any
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue