pomerium/authorize/databroker.go
2025-01-20 19:42:45 -05:00

128 lines
3.5 KiB
Go

package authorize
import (
"context"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/grpcutil"
"github.com/pomerium/pomerium/pkg/storage"
octrace "go.opencensus.io/trace"
)
type sessionOrServiceAccount interface {
GetUserId() string
Validate() error
}
func getDataBrokerRecord(
ctx context.Context,
recordType string,
recordID string,
lowestRecordVersion uint64,
) (*databroker.Record, error) {
ctx, span := trace.StartSpan(ctx, "authorize.getDataBrokerRecord")
span.AddAttributes(
octrace.StringAttribute("record_type", recordType),
octrace.StringAttribute("record_id", recordID),
octrace.Int64Attribute("lowest_record_version", int64(lowestRecordVersion)),
)
defer span.End()
q := storage.GetQuerier(ctx)
req := &databroker.QueryRequest{
Type: recordType,
Limit: 1,
}
req.SetFilterByIDOrIndex(recordID)
res, err := q.Query(ctx, req)
if err != nil {
span.SetStatus(octrace.Status{Code: octrace.StatusCodeInternal, Message: err.Error()})
return nil, err
}
if len(res.GetRecords()) == 0 {
span.SetStatus(octrace.Status{Code: octrace.StatusCodeNotFound})
return nil, storage.ErrNotFound
}
// if the current record version is less than the lowest we'll accept, invalidate the cache
if v := res.GetRecords()[0].GetVersion(); v < lowestRecordVersion {
span.AddAttributes(octrace.Int64Attribute("got_record_version", int64(v)))
q.InvalidateCache(ctx, req)
} else {
return res.GetRecords()[0], nil
}
// retry with an up to date cache
res, err = q.Query(ctx, req)
if err != nil {
span.SetStatus(octrace.Status{Code: octrace.StatusCodeInternal, Message: err.Error()})
return nil, err
}
if len(res.GetRecords()) == 0 {
span.SetStatus(octrace.Status{Code: octrace.StatusCodeNotFound})
return nil, storage.ErrNotFound
}
return res.GetRecords()[0], nil
}
func (a *Authorize) getDataBrokerSessionOrServiceAccount(
ctx context.Context,
sessionID string,
dataBrokerRecordVersion uint64,
) (s sessionOrServiceAccount, err error) {
ctx, span := trace.StartSpan(ctx, "authorize.getDataBrokerSessionOrServiceAccount")
defer span.End()
record, err := getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID, dataBrokerRecordVersion)
if storage.IsNotFound(err) {
record, err = getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.ServiceAccount)), sessionID, dataBrokerRecordVersion)
}
if err != nil {
return nil, err
}
msg, err := record.GetData().UnmarshalNew()
if err != nil {
return nil, err
}
s = msg.(sessionOrServiceAccount)
if err := s.Validate(); err != nil {
return nil, err
}
if _, ok := s.(*session.Session); ok {
a.accessTracker.TrackSessionAccess(sessionID)
}
if _, ok := s.(*user.ServiceAccount); ok {
a.accessTracker.TrackServiceAccountAccess(sessionID)
}
return s, nil
}
func (a *Authorize) getDataBrokerUser(
ctx context.Context,
userID string,
) (*user.User, error) {
ctx, span := trace.StartSpan(ctx, "authorize.getDataBrokerUser")
span.AddAttributes(octrace.StringAttribute("user_id", userID))
defer span.End()
record, err := getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.User)), userID, 0)
if err != nil {
span.SetStatus(octrace.Status{Code: octrace.StatusCodeInternal, Message: err.Error()})
return nil, err
}
var u user.User
err = record.GetData().UnmarshalTo(&u)
if err != nil {
return nil, err
}
return &u, nil
}