mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-02 20:06:03 +02:00
authorize: performance improvements (#3723)
This commit is contained in:
parent
a3cfe8fa42
commit
02df20f10a
4 changed files with 50 additions and 20 deletions
|
@ -87,12 +87,23 @@ func (a *Authorize) getDataBrokerSessionOrServiceAccount(
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authorize) getDataBrokerUser(ctx context.Context, userID string) (u *user.User, err error) {
|
func (a *Authorize) getDataBrokerUser(
|
||||||
|
ctx context.Context,
|
||||||
|
userID string,
|
||||||
|
dataBrokerRecordVersion uint64,
|
||||||
|
) (*user.User, error) {
|
||||||
ctx, span := trace.StartSpan(ctx, "authorize.getDataBrokerUser")
|
ctx, span := trace.StartSpan(ctx, "authorize.getDataBrokerUser")
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
client := a.state.Load().dataBrokerClient
|
record, err := getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.User)), userID, dataBrokerRecordVersion)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
u, err = user.Get(ctx, client, userID)
|
var u user.User
|
||||||
return u, err
|
err = record.GetData().UnmarshalTo(&u)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &u, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
|
|
||||||
"github.com/go-jose/go-jose/v3"
|
"github.com/go-jose/go-jose/v3"
|
||||||
"github.com/open-policy-agent/opa/rego"
|
"github.com/open-policy-agent/opa/rego"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/authorize/internal/store"
|
"github.com/pomerium/pomerium/authorize/internal/store"
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
|
@ -147,18 +148,29 @@ func (e *Evaluator) Evaluate(ctx context.Context, req *Request) (*Result, error)
|
||||||
return nil, fmt.Errorf("authorize: error validating client certificate: %w", err)
|
return nil, fmt.Errorf("authorize: error validating client certificate: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
policyOutput, err := policyEvaluator.Evaluate(ctx, &PolicyRequest{
|
eg, ectx := errgroup.WithContext(ctx)
|
||||||
|
|
||||||
|
var policyOutput *PolicyResponse
|
||||||
|
eg.Go(func() error {
|
||||||
|
var err error
|
||||||
|
policyOutput, err = policyEvaluator.Evaluate(ectx, &PolicyRequest{
|
||||||
HTTP: req.HTTP,
|
HTTP: req.HTTP,
|
||||||
Session: req.Session,
|
Session: req.Session,
|
||||||
IsValidClientCertificate: isValidClientCertificate,
|
IsValidClientCertificate: isValidClientCertificate,
|
||||||
})
|
})
|
||||||
if err != nil {
|
return err
|
||||||
return nil, err
|
})
|
||||||
}
|
|
||||||
|
|
||||||
|
var headersOutput *HeadersResponse
|
||||||
|
eg.Go(func() error {
|
||||||
headersReq := NewHeadersRequestFromPolicy(req.Policy)
|
headersReq := NewHeadersRequestFromPolicy(req.Policy)
|
||||||
headersReq.Session = req.Session
|
headersReq.Session = req.Session
|
||||||
headersOutput, err := e.headersEvaluators.Evaluate(ctx, headersReq)
|
var err error
|
||||||
|
headersOutput, err = e.headersEvaluators.Evaluate(ectx, headersReq)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
|
||||||
|
err = eg.Wait()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -66,8 +66,8 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
|
||||||
sessionState = nil
|
sessionState = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if s != nil {
|
if sessionState != nil && s != nil {
|
||||||
u, _ = a.getDataBrokerUser(ctx, s.GetUserId()) // ignore any missing user error
|
u, _ = a.getDataBrokerUser(ctx, s.GetUserId(), sessionState.DatabrokerRecordVersion) // ignore any missing user error
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := a.getEvaluatorRequestFromCheckRequest(in, sessionState)
|
req, err := a.getEvaluatorRequestFromCheckRequest(in, sessionState)
|
||||||
|
|
|
@ -13,11 +13,13 @@ import (
|
||||||
opastorage "github.com/open-policy-agent/opa/storage"
|
opastorage "github.com/open-policy-agent/opa/storage"
|
||||||
"github.com/open-policy-agent/opa/storage/inmem"
|
"github.com/open-policy-agent/opa/storage/inmem"
|
||||||
"github.com/open-policy-agent/opa/types"
|
"github.com/open-policy-agent/opa/types"
|
||||||
|
octrace "go.opencensus.io/trace"
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
"github.com/pomerium/pomerium/pkg/storage"
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
)
|
)
|
||||||
|
@ -105,15 +107,20 @@ func (s *Store) GetDataBrokerRecordOption() func(*rego.Rego) {
|
||||||
types.NewObject(nil, types.NewDynamicProperty(types.S, types.S)),
|
types.NewObject(nil, types.NewDynamicProperty(types.S, types.S)),
|
||||||
),
|
),
|
||||||
}, func(bctx rego.BuiltinContext, op1 *ast.Term, op2 *ast.Term) (*ast.Term, error) {
|
}, func(bctx rego.BuiltinContext, op1 *ast.Term, op2 *ast.Term) (*ast.Term, error) {
|
||||||
|
ctx, span := trace.StartSpan(bctx.Context, "rego.get_databroker_record")
|
||||||
|
defer span.End()
|
||||||
|
|
||||||
recordType, ok := op1.Value.(ast.String)
|
recordType, ok := op1.Value.(ast.String)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid record type: %T", op1)
|
return nil, fmt.Errorf("invalid record type: %T", op1)
|
||||||
}
|
}
|
||||||
|
span.AddAttributes(octrace.StringAttribute("record_type", recordType.String()))
|
||||||
|
|
||||||
value, ok := op2.Value.(ast.String)
|
value, ok := op2.Value.(ast.String)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid record id: %T", op2)
|
return nil, fmt.Errorf("invalid record id: %T", op2)
|
||||||
}
|
}
|
||||||
|
span.AddAttributes(octrace.StringAttribute("record_id", value.String()))
|
||||||
|
|
||||||
req := &databroker.QueryRequest{
|
req := &databroker.QueryRequest{
|
||||||
Type: string(recordType),
|
Type: string(recordType),
|
||||||
|
@ -121,9 +128,9 @@ func (s *Store) GetDataBrokerRecordOption() func(*rego.Rego) {
|
||||||
}
|
}
|
||||||
req.SetFilterByIDOrIndex(string(value))
|
req.SetFilterByIDOrIndex(string(value))
|
||||||
|
|
||||||
res, err := storage.GetQuerier(bctx.Context).Query(bctx.Context, req)
|
res, err := storage.GetQuerier(ctx).Query(ctx, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(bctx.Context).Err(err).Msg("authorize/store: error retrieving record")
|
log.Error(ctx).Err(err).Msg("authorize/store: error retrieving record")
|
||||||
return ast.NullTerm(), nil
|
return ast.NullTerm(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -147,7 +154,7 @@ func (s *Store) GetDataBrokerRecordOption() func(*rego.Rego) {
|
||||||
|
|
||||||
regoValue, err := ast.InterfaceToValue(obj)
|
regoValue, err := ast.InterfaceToValue(obj)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(bctx.Context).Err(err).Msg("authorize/store: error converting object to rego")
|
log.Error(ctx).Err(err).Msg("authorize/store: error converting object to rego")
|
||||||
return ast.NullTerm(), nil
|
return ast.NullTerm(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue