Mirror of https://github.com/pomerium/pomerium.git, synced 2025-04-28 18:06:34 +02:00.
* identity: add support for verifying access and identity tokens * allow overriding with policy option * authenticate: add verify endpoints * wip * implement session creation * add verify test * implement idp token login * fix tests * add pr permission * make session ids route-specific * rename method * add test * add access token test * test for newUserFromIDPClaims * more tests * make the session id per-idp * use type for * add test * remove nil checks
114 lines
2.6 KiB
Go
114 lines
2.6 KiB
Go
package authorize
|
|
|
|
import (
	"context"
	"fmt"

	"google.golang.org/grpc"

	"github.com/pomerium/pomerium/pkg/grpc/databroker"
	"github.com/pomerium/pomerium/pkg/grpc/session"
	"github.com/pomerium/pomerium/pkg/grpc/user"
	"github.com/pomerium/pomerium/pkg/grpcutil"
	"github.com/pomerium/pomerium/pkg/storage"
)
|
|
|
|
// sessionOrServiceAccount is the common view of the databroker records that
// can back an authorization request. In this file it is satisfied by
// *session.Session and *user.ServiceAccount.
type sessionOrServiceAccount interface {
	// Validate is called before the record is used; a non-nil error causes
	// the record to be rejected.
	Validate() error
	// GetUserId returns the associated user id (presumably the owning user —
	// confirm against the message definitions).
	GetUserId() string
	// GetId returns the record's own id.
	GetId() string
}
|
|
|
|
// getDataBrokerRecord returns the first databroker record of type recordType
// matched by recordID (via QueryRequest.SetFilterByIDOrIndex).
//
// Queries go through the storage querier attached to ctx. If the record found
// has a version lower than lowestRecordVersion, the querier's cache for this
// request is invalidated and the query is issued once more. The result of the
// retry is returned as-is, even if its version is still below
// lowestRecordVersion.
//
// storage.ErrNotFound is returned when a query yields no records. Any other
// query error is returned unchanged.
func getDataBrokerRecord(
	ctx context.Context,
	recordType string,
	recordID string,
	lowestRecordVersion uint64,
) (*databroker.Record, error) {
	q := storage.GetQuerier(ctx)

	req := &databroker.QueryRequest{
		Type:  recordType,
		Limit: 1,
	}
	req.SetFilterByIDOrIndex(recordID)

	// queryFirst runs req and returns its first record. The initial lookup and
	// the post-invalidation retry share it, so both use the same call options.
	// In particular, the retry also waits for the connection to become ready.
	queryFirst := func() (*databroker.Record, error) {
		res, err := q.Query(ctx, req, grpc.WaitForReady(true))
		if err != nil {
			return nil, err
		}
		records := res.GetRecords()
		if len(records) == 0 {
			return nil, storage.ErrNotFound
		}
		return records[0], nil
	}

	record, err := queryFirst()
	if err != nil {
		return nil, err
	}
	if record.GetVersion() >= lowestRecordVersion {
		return record, nil
	}

	// The record is older than the lowest version we'll accept, so the cached
	// answer is stale: invalidate it and retry with an up-to-date cache.
	q.InvalidateCache(ctx, req)
	return queryFirst()
}
|
|
|
|
func (a *Authorize) getDataBrokerSessionOrServiceAccount(
|
|
ctx context.Context,
|
|
sessionID string,
|
|
dataBrokerRecordVersion uint64,
|
|
) (s sessionOrServiceAccount, err error) {
|
|
ctx, span := a.tracer.Start(ctx, "authorize.getDataBrokerSessionOrServiceAccount")
|
|
defer span.End()
|
|
|
|
record, err := getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID, dataBrokerRecordVersion)
|
|
if storage.IsNotFound(err) {
|
|
record, err = getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.ServiceAccount)), sessionID, dataBrokerRecordVersion)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
msg, err := record.GetData().UnmarshalNew()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
s = msg.(sessionOrServiceAccount)
|
|
if err := s.Validate(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if _, ok := s.(*session.Session); ok {
|
|
a.accessTracker.TrackSessionAccess(sessionID)
|
|
}
|
|
if _, ok := s.(*user.ServiceAccount); ok {
|
|
a.accessTracker.TrackServiceAccountAccess(sessionID)
|
|
}
|
|
return s, nil
|
|
}
|
|
|
|
func (a *Authorize) getDataBrokerUser(
|
|
ctx context.Context,
|
|
userID string,
|
|
) (*user.User, error) {
|
|
ctx, span := a.tracer.Start(ctx, "authorize.getDataBrokerUser")
|
|
defer span.End()
|
|
|
|
record, err := getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.User)), userID, 0)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var u user.User
|
|
err = record.GetData().UnmarshalTo(&u)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &u, nil
|
|
}
|