authorize: add databroker server and record version to result, force sync via polling (#2024)

* authorize: add databroker server and record version to result, force sync via polling

* wrap inmem store to take read lock when grabbing databroker versions

* address code review comments

* reset max to 0
This commit is contained in:
Caleb Doxsey 2021-03-31 10:09:06 -06:00 committed by GitHub
parent 8f97b0d6ee
commit d7ab817de7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 467 additions and 362 deletions

View file

@ -39,7 +39,7 @@ func TestAuthorize_okResponse(t *testing.T) {
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0})
a.state.Load().encoder = encoder
a.currentOptions.Store(opt)
a.store = evaluator.NewStoreFromProtos(
a.store = evaluator.NewStoreFromProtos(0,
&session.Session{
Id: "SESSION_ID",
UserId: "USER_ID",

View file

@ -14,7 +14,7 @@ func TestCustomEvaluator(t *testing.T) {
store := NewStore()
t.Run("bool deny", func(t *testing.T) {
ce := NewCustomEvaluator(store.opaStore)
ce := NewCustomEvaluator(store)
res, err := ce.Evaluate(ctx, &CustomEvaluatorRequest{
RegoPolicy: `
package pomerium.custom_policy
@ -29,7 +29,7 @@ func TestCustomEvaluator(t *testing.T) {
assert.Empty(t, res.Reason)
})
t.Run("set deny", func(t *testing.T) {
ce := NewCustomEvaluator(store.opaStore)
ce := NewCustomEvaluator(store)
res, err := ce.Evaluate(ctx, &CustomEvaluatorRequest{
RegoPolicy: `
package pomerium.custom_policy
@ -44,7 +44,7 @@ func TestCustomEvaluator(t *testing.T) {
assert.Equal(t, "test", res.Reason)
})
t.Run("missing package", func(t *testing.T) {
ce := NewCustomEvaluator(store.opaStore)
ce := NewCustomEvaluator(store)
res, err := ce.Evaluate(ctx, &CustomEvaluatorRequest{
RegoPolicy: `allow = true`,
})

View file

@ -7,7 +7,6 @@ import (
"encoding/base64"
"fmt"
"net/http"
"strconv"
"github.com/open-policy-agent/opa/rego"
"gopkg.in/square/go-jose.v2"
@ -29,7 +28,7 @@ type Evaluator struct {
// New creates a new Evaluator.
func New(options *config.Options, store *Store) (*Evaluator, error) {
e := &Evaluator{
custom: NewCustomEvaluator(store.opaStore),
custom: NewCustomEvaluator(store),
policies: options.GetAllPolicies(),
store: store,
}
@ -55,7 +54,7 @@ func New(options *config.Options, store *Store) (*Evaluator, error) {
store.UpdateSigningKey(jwk)
e.rego = rego.New(
rego.Store(store.opaStore),
rego.Store(store),
rego.Module("pomerium.authz", string(authzPolicy)),
rego.Query("result = data.pomerium.authz"),
getGoogleCloudServerlessHeadersRegoOption,
@ -91,6 +90,9 @@ func (e *Evaluator) Evaluate(ctx context.Context, req *Request) (*Result, error)
MatchingPolicy: getMatchingPolicy(res[0].Bindings.WithoutWildcards(), e.policies),
Headers: getHeadersVar(res[0].Bindings.WithoutWildcards()),
}
evalResult.DataBrokerServerVersion, evalResult.DataBrokerRecordVersion = getDataBrokerVersions(
res[0].Bindings,
)
allow := getAllowVar(res[0].Bindings.WithoutWildcards())
// evaluate any custom policies
@ -181,95 +183,3 @@ func (e *Evaluator) newInput(req *Request, isValidClientCertificate bool) *input
i.IsValidClientCertificate = isValidClientCertificate
return i
}
// Result is the result of evaluation.
type Result struct {
Status int
Message string
Headers map[string]string
MatchingPolicy *config.Policy
}
func getMatchingPolicy(vars rego.Vars, policies []config.Policy) *config.Policy {
result, ok := vars["result"].(map[string]interface{})
if !ok {
return nil
}
idx, err := strconv.Atoi(fmt.Sprint(result["route_policy_idx"]))
if err != nil {
return nil
}
if idx >= len(policies) {
return nil
}
return &policies[idx]
}
func getAllowVar(vars rego.Vars) bool {
result, ok := vars["result"].(map[string]interface{})
if !ok {
return false
}
allow, ok := result["allow"].(bool)
if !ok {
return false
}
return allow
}
func getDenyVar(vars rego.Vars) []Result {
result, ok := vars["result"].(map[string]interface{})
if !ok {
return nil
}
denials, ok := result["deny"].([]interface{})
if !ok {
return nil
}
results := make([]Result, 0, len(denials))
for _, denial := range denials {
denial, ok := denial.([]interface{})
if !ok || len(denial) != 2 {
continue
}
status, err := strconv.Atoi(fmt.Sprint(denial[0]))
if err != nil {
log.Error().Err(err).Msg("invalid type in deny")
continue
}
msg := fmt.Sprint(denial[1])
results = append(results, Result{
Status: status,
Message: msg,
})
}
return results
}
func getHeadersVar(vars rego.Vars) map[string]string {
headers := make(map[string]string)
result, ok := vars["result"].(map[string]interface{})
if !ok {
return headers
}
m, ok := result["identity_headers"].(map[string]interface{})
if !ok {
return headers
}
for k, v := range m {
headers[k] = fmt.Sprint(v)
}
return headers
}

View file

@ -25,7 +25,7 @@ import (
func TestJSONMarshal(t *testing.T) {
opt := config.NewDefaultOptions()
opt.AuthenticateURLString = "https://authenticate.example.com"
e, err := New(opt, NewStoreFromProtos(
e, err := New(opt, NewStoreFromProtos(0,
&session.Session{
UserId: "user1",
},
@ -100,7 +100,7 @@ func TestEvaluator_Evaluate(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
store := NewStoreFromProtos()
store := NewStoreFromProtos(0)
data, _ := ptypes.MarshalAny(&session.Session{
Version: "1",
Id: sessionID,
@ -116,7 +116,7 @@ func TestEvaluator_Evaluate(t *testing.T) {
RefreshToken: "REFRESH TOKEN",
},
})
store.UpdateRecord(&databroker.Record{
store.UpdateRecord(0, &databroker.Record{
Version: 1,
Type: "type.googleapis.com/session.Session",
Id: sessionID,
@ -127,7 +127,7 @@ func TestEvaluator_Evaluate(t *testing.T) {
Id: userID,
Email: "foo@example.com",
})
store.UpdateRecord(&databroker.Record{
store.UpdateRecord(0, &databroker.Record{
Version: 1,
Type: "type.googleapis.com/user.User",
Id: userID,
@ -189,7 +189,7 @@ func BenchmarkEvaluator_Evaluate(b *testing.B) {
RefreshToken: "REFRESH TOKEN",
},
})
store.UpdateRecord(&databroker.Record{
store.UpdateRecord(0, &databroker.Record{
Version: uint64(i),
Type: "type.googleapis.com/session.Session",
Id: sessionID,
@ -199,7 +199,7 @@ func BenchmarkEvaluator_Evaluate(b *testing.B) {
Version: fmt.Sprint(i),
Id: userID,
})
store.UpdateRecord(&databroker.Record{
store.UpdateRecord(0, &databroker.Record{
Version: uint64(i),
Type: "type.googleapis.com/user.User",
Id: userID,
@ -211,7 +211,7 @@ func BenchmarkEvaluator_Evaluate(b *testing.B) {
Id: userID,
GroupIds: []string{"1", "2", "3", "4"},
})
store.UpdateRecord(&databroker.Record{
store.UpdateRecord(0, &databroker.Record{
Version: uint64(i),
Type: data.TypeUrl,
Id: userID,
@ -222,7 +222,7 @@ func BenchmarkEvaluator_Evaluate(b *testing.B) {
Version: fmt.Sprint(i),
Id: fmt.Sprint(i),
})
store.UpdateRecord(&databroker.Record{
store.UpdateRecord(0, &databroker.Record{
Version: uint64(i),
Type: data.TypeUrl,
Id: fmt.Sprint(i),

View file

@ -5,6 +5,10 @@ default allow = false
# 5 minutes from now in seconds
five_minutes := (time.now_ns() / 1e9) + (60 * 5)
# databroker versions to know which version of the data was evaluated
databroker_server_version := data.databroker_server_version
databroker_record_version := data.databroker_record_version
route_policy_idx := first_allowed_route_policy_idx(input.http.url)
route_policy := data.route_policies[route_policy_idx]

View file

@ -3,6 +3,7 @@ package evaluator
import (
"context"
"encoding/json"
"math"
"testing"
"time"
@ -11,6 +12,7 @@ import (
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
"google.golang.org/protobuf/types/known/wrapperspb"
"gopkg.in/square/go-jose.v2"
"gopkg.in/square/go-jose.v2/jwt"
@ -37,13 +39,13 @@ func TestOPA(t *testing.T) {
eval := func(t *testing.T, policies []config.Policy, data []proto.Message, req *Request, isValidClientCertificate bool) rego.Result {
authzPolicy, err := readPolicy()
require.NoError(t, err)
store := NewStoreFromProtos(data...)
store := NewStoreFromProtos(math.MaxUint64, data...)
store.UpdateIssuer("authenticate.example.com")
store.UpdateJWTClaimHeaders(config.NewJWTClaimHeaders("email", "groups", "user"))
store.UpdateRoutePolicies(policies)
store.UpdateSigningKey(privateJWK)
r := rego.New(
rego.Store(store.opaStore),
rego.Store(store),
rego.Module("pomerium.authz", string(authzPolicy)),
rego.Query("result = data.pomerium.authz"),
getGoogleCloudServerlessHeadersRegoOption,
@ -646,4 +648,12 @@ func TestOPA(t *testing.T) {
}, true)
assert.True(t, res.Bindings["result"].(M)["allow"].(bool))
})
t.Run("databroker versions", func(t *testing.T) {
res := eval(t, nil, []proto.Message{
wrapperspb.String("test"),
}, &Request{}, false)
serverVersion, recordVersion := getDataBrokerVersions(res.Bindings)
assert.Equal(t, uint64(math.MaxUint64), serverVersion)
assert.NotEqual(t, uint64(0), recordVersion) // random
})
}

View file

@ -0,0 +1,115 @@
package evaluator
import (
"fmt"
"strconv"
"github.com/open-policy-agent/opa/rego"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/log"
)
// Result is the result of evaluation.
type Result struct {
Status int
Message string
Headers map[string]string
MatchingPolicy *config.Policy
DataBrokerServerVersion, DataBrokerRecordVersion uint64
}
func getMatchingPolicy(vars rego.Vars, policies []config.Policy) *config.Policy {
result, ok := vars["result"].(map[string]interface{})
if !ok {
return nil
}
idx, err := strconv.Atoi(fmt.Sprint(result["route_policy_idx"]))
if err != nil {
return nil
}
if idx >= len(policies) {
return nil
}
return &policies[idx]
}
func getAllowVar(vars rego.Vars) bool {
result, ok := vars["result"].(map[string]interface{})
if !ok {
return false
}
allow, ok := result["allow"].(bool)
if !ok {
return false
}
return allow
}
func getDenyVar(vars rego.Vars) []Result {
result, ok := vars["result"].(map[string]interface{})
if !ok {
return nil
}
denials, ok := result["deny"].([]interface{})
if !ok {
return nil
}
results := make([]Result, 0, len(denials))
for _, denial := range denials {
denial, ok := denial.([]interface{})
if !ok || len(denial) != 2 {
continue
}
status, err := strconv.Atoi(fmt.Sprint(denial[0]))
if err != nil {
log.Error().Err(err).Msg("invalid type in deny")
continue
}
msg := fmt.Sprint(denial[1])
results = append(results, Result{
Status: status,
Message: msg,
})
}
return results
}
func getHeadersVar(vars rego.Vars) map[string]string {
headers := make(map[string]string)
result, ok := vars["result"].(map[string]interface{})
if !ok {
return headers
}
m, ok := result["identity_headers"].(map[string]interface{})
if !ok {
return headers
}
for k, v := range m {
headers[k] = fmt.Sprint(v)
}
return headers
}
func getDataBrokerVersions(vars rego.Vars) (serverVersion, recordVersion uint64) {
result, ok := vars["result"].(map[string]interface{})
if !ok {
return 0, 0
}
serverVersion, _ = strconv.ParseUint(fmt.Sprint(result["databroker_server_version"]), 10, 64)
recordVersion, _ = strconv.ParseUint(fmt.Sprint(result["databroker_record_version"]), 10, 64)
return serverVersion, recordVersion
}

View file

@ -25,7 +25,7 @@ import (
// A Store stores data for the OPA rego policy evaluation.
type Store struct {
opaStore storage.Store
storage.Store
mu sync.RWMutex
dataBrokerData map[string]map[string]proto.Message
@ -34,13 +34,13 @@ type Store struct {
// NewStore creates a new Store.
func NewStore() *Store {
return &Store{
opaStore: inmem.New(),
Store: inmem.New(),
dataBrokerData: make(map[string]map[string]proto.Message),
}
}
// NewStoreFromProtos creates a new Store from an existing set of protobuf messages.
func NewStoreFromProtos(msgs ...proto.Message) *Store {
func NewStoreFromProtos(serverVersion uint64, msgs ...proto.Message) *Store {
s := NewStore()
for _, msg := range msgs {
any, err := anypb.New(msg)
@ -58,11 +58,34 @@ func NewStoreFromProtos(msgs ...proto.Message) *Store {
record.Id = hasID.GetId()
}
s.UpdateRecord(record)
s.UpdateRecord(serverVersion, record)
}
return s
}
// NewTransaction calls the underlying store NewTransaction and takes the transaction lock.
func (s *Store) NewTransaction(ctx context.Context, params ...storage.TransactionParams) (storage.Transaction, error) {
txn, err := s.Store.NewTransaction(ctx, params...)
if err != nil {
return nil, err
}
s.mu.RLock()
return txn, err
}
// Commit calls the underlying store Commit and releases the transaction lock.
func (s *Store) Commit(ctx context.Context, txn storage.Transaction) error {
err := s.Store.Commit(ctx, txn)
s.mu.RUnlock()
return err
}
// Abort calls the underlying store Abort and releases the transaction lock.
func (s *Store) Abort(ctx context.Context, txn storage.Transaction) {
s.Store.Abort(ctx, txn)
s.mu.RUnlock()
}
// ClearRecords removes all the records from the store.
func (s *Store) ClearRecords() {
s.mu.Lock()
@ -107,10 +130,13 @@ func (s *Store) UpdateRoutePolicies(routePolicies []config.Policy) {
}
// UpdateRecord updates a record in the store.
func (s *Store) UpdateRecord(record *databroker.Record) {
func (s *Store) UpdateRecord(serverVersion uint64, record *databroker.Record) {
s.mu.Lock()
defer s.mu.Unlock()
s.write("/databroker_server_version", fmt.Sprint(serverVersion))
s.write("/databroker_record_version", fmt.Sprint(record.GetVersion()))
m, ok := s.dataBrokerData[record.GetType()]
if !ok {
m = make(map[string]proto.Message)
@ -130,31 +156,8 @@ func (s *Store) UpdateSigningKey(signingKey *jose.JSONWebKey) {
}
func (s *Store) write(rawPath string, value interface{}) {
p, ok := storage.ParsePath(rawPath)
if !ok {
log.Error().
Str("path", rawPath).
Msg("opa-store: invalid path, ignoring data")
return
}
err := storage.Txn(context.Background(), s.opaStore, storage.WriteParams, func(txn storage.Transaction) error {
if len(p) > 1 {
err := storage.MakeDir(context.Background(), s.opaStore, txn, p[:len(p)-1])
if err != nil {
return err
}
}
var op storage.PatchOp = storage.ReplaceOp
_, err := s.opaStore.Read(context.Background(), txn, p)
if storage.IsNotFound(err) {
op = storage.AddOp
} else if err != nil {
return err
}
return s.opaStore.Write(context.Background(), txn, op, p, value)
err := storage.Txn(context.Background(), s.Store, storage.WriteParams, func(txn storage.Transaction) error {
return s.writeTxn(txn, rawPath, value)
})
if err != nil {
log.Error().Err(err).Msg("opa-store: error writing data")
@ -162,6 +165,30 @@ func (s *Store) write(rawPath string, value interface{}) {
}
}
func (s *Store) writeTxn(txn storage.Transaction, rawPath string, value interface{}) error {
p, ok := storage.ParsePath(rawPath)
if !ok {
return fmt.Errorf("invalid path")
}
if len(p) > 1 {
err := storage.MakeDir(context.Background(), s, txn, p[:len(p)-1])
if err != nil {
return err
}
}
var op storage.PatchOp = storage.ReplaceOp
_, err := s.Read(context.Background(), txn, p)
if storage.IsNotFound(err) {
op = storage.AddOp
} else if err != nil {
return err
}
return s.Write(context.Background(), txn, op, p, value)
}
// GetDataBrokerRecordOption returns a function option that can retrieve databroker data.
func (s *Store) GetDataBrokerRecordOption() func(*rego.Rego) {
return rego.Function2(&rego.Function{

View file

@ -21,7 +21,7 @@ func TestStore(t *testing.T) {
Email: "name@example.com",
}
any, _ := anypb.New(u)
s.UpdateRecord(&databroker.Record{
s.UpdateRecord(0, &databroker.Record{
Version: 1,
Type: any.GetTypeUrl(),
Id: u.GetId(),
@ -36,7 +36,7 @@ func TestStore(t *testing.T) {
"email": "name@example.com",
}, toMap(v))
s.UpdateRecord(&databroker.Record{
s.UpdateRecord(0, &databroker.Record{
Version: 2,
Type: any.GetTypeUrl(),
Id: u.GetId(),
@ -47,7 +47,7 @@ func TestStore(t *testing.T) {
v = s.GetRecordData(any.GetTypeUrl(), u.GetId())
assert.Nil(t, v)
s.UpdateRecord(&databroker.Record{
s.UpdateRecord(0, &databroker.Record{
Version: 3,
Type: any.GetTypeUrl(),
Id: u.GetId(),

View file

@ -3,7 +3,6 @@ package authorize
import (
"context"
"encoding/base64"
"errors"
"io/ioutil"
"net/http"
"net/url"
@ -19,10 +18,7 @@ import (
"github.com/pomerium/pomerium/internal/telemetry/requestid"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/grpcutil"
envoy_service_auth_v3 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3"
)
@ -83,81 +79,6 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
return a.deniedResponse(in, int32(reply.Status), reply.Message, nil)
}
func (a *Authorize) forceSync(ctx context.Context, ss *sessions.State) (*user.User, error) {
ctx, span := trace.StartSpan(ctx, "authorize.forceSync")
defer span.End()
if ss == nil {
return nil, nil
}
s := a.forceSyncSession(ctx, ss.ID)
if s == nil {
return nil, errors.New("session not found")
}
u := a.forceSyncUser(ctx, s.GetUserId())
return u, nil
}
func (a *Authorize) forceSyncSession(ctx context.Context, sessionID string) interface{ GetUserId() string } {
ctx, span := trace.StartSpan(ctx, "authorize.forceSyncSession")
defer span.End()
state := a.state.Load()
s, ok := a.store.GetRecordData(grpcutil.GetTypeURL(new(session.Session)), sessionID).(*session.Session)
if ok {
return s
}
sa, ok := a.store.GetRecordData(grpcutil.GetTypeURL(new(user.ServiceAccount)), sessionID).(*user.ServiceAccount)
if ok {
return sa
}
res, err := state.dataBrokerClient.Get(ctx, &databroker.GetRequest{
Type: grpcutil.GetTypeURL(new(session.Session)),
Id: sessionID,
})
if err != nil {
log.Warn().Err(err).Msg("failed to get session from databroker")
return nil
}
if current := a.store.GetRecordData(grpcutil.GetTypeURL(new(session.Session)), sessionID); current == nil {
a.store.UpdateRecord(res.GetRecord())
}
s, _ = a.store.GetRecordData(grpcutil.GetTypeURL(new(session.Session)), sessionID).(*session.Session)
return s
}
func (a *Authorize) forceSyncUser(ctx context.Context, userID string) *user.User {
ctx, span := trace.StartSpan(ctx, "authorize.forceSyncUser")
defer span.End()
state := a.state.Load()
u, ok := a.store.GetRecordData(grpcutil.GetTypeURL(new(user.User)), userID).(*user.User)
if ok {
return u
}
res, err := state.dataBrokerClient.Get(ctx, &databroker.GetRequest{
Type: grpcutil.GetTypeURL(new(user.User)),
Id: userID,
})
if err != nil {
log.Warn().Err(err).Msg("failed to get user from databroker")
return nil
}
if current := a.store.GetRecordData(grpcutil.GetTypeURL(new(user.User)), userID); current == nil {
a.store.UpdateRecord(res.GetRecord())
}
u, _ = a.store.GetRecordData(grpcutil.GetTypeURL(new(user.User)), userID).(*user.User)
return u
}
func getForwardAuthURL(r *http.Request) *url.URL {
urqQuery := r.URL.Query().Get("uri")
u, _ := urlutil.ParseAndValidateURL(urqQuery)
@ -329,6 +250,8 @@ func logAuthorizeCheck(
evt = evt.Str("message", reply.Message)
evt = evt.Str("user", u.GetId())
evt = evt.Str("email", u.GetEmail())
evt = evt.Uint64("databroker_server_version", reply.DataBrokerServerVersion)
evt = evt.Uint64("databroker_record_version", reply.DataBrokerRecordVersion)
}
// potentially sensitive, only log if debug mode

View file

@ -2,12 +2,10 @@ package authorize
import (
"context"
"errors"
"net/url"
"testing"
envoy_service_auth_v3 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3"
"github.com/golang/protobuf/ptypes"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/stretchr/testify/assert"
@ -21,8 +19,6 @@ import (
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
)
const certPEM = `
@ -313,132 +309,6 @@ func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) {
assert.Equal(t, expect, actual)
}
func TestSync(t *testing.T) {
mockSession := func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
data, _ := ptypes.MarshalAny(&session.Session{
Id: in.GetId(),
UserId: "user1",
})
return &databroker.GetResponse{
Record: &databroker.Record{
Version: 1,
Type: data.GetTypeUrl(),
Id: in.GetId(),
Data: data,
},
}, nil
}
mockUser := func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
data, _ := ptypes.MarshalAny(&user.User{Id: in.GetId()})
return &databroker.GetResponse{
Record: &databroker.Record{
Version: 1,
Type: data.GetTypeUrl(),
Id: in.GetId(),
Data: data,
},
}, nil
}
mockGetByType := map[string]func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error){
"type.googleapis.com/session.Session": mockSession,
"type.googleapis.com/user.User": mockUser,
}
dbdClient := mockDataBrokerServiceClient{
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
if in.GetId() == "not-existed-id" {
return nil, errors.New("not found")
}
f, ok := mockGetByType[in.GetType()]
if !ok {
return nil, errors.New("not found")
}
return f(ctx, in, opts...)
},
}
o := &config.Options{
AuthenticateURLString: "https://authN.example.com",
DataBrokerURLString: "https://databroker.example.com",
SharedKey: "gXK6ggrlIW2HyKyUF9rUO4azrDgxhDPWqw9y+lJU7B8=",
Policies: testPolicies(t),
}
ctx := context.Background()
tests := []struct {
name string
sessionState *sessions.State
databrokerClient mockDataBrokerServiceClient
wantErr bool
}{
{
"good with data in databroker data",
&sessions.State{ID: "dbd_session_id"},
mockDataBrokerServiceClient{
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
data, _ := ptypes.MarshalAny(&session.Session{
Id: in.GetId(),
UserId: "dbd_user1",
})
if in.GetType() == "type.googleapis.com/user.User" {
data, _ = ptypes.MarshalAny(&user.User{
Id: "dbd_user1",
})
}
return &databroker.GetResponse{
Record: &databroker.Record{
Version: 1,
Type: data.GetTypeUrl(),
Id: in.GetId(),
Data: data,
},
}, nil
},
},
false,
},
{"good", &sessions.State{ID: "SESSION_ID"}, dbdClient, false},
{"nil session state", nil, dbdClient, false},
{"not found session state", &sessions.State{ID: "not-existed-id"}, dbdClient, true},
{
"user not found",
&sessions.State{ID: "session_with_not_found_user"},
mockDataBrokerServiceClient{
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
if in.GetType() == "type.googleapis.com/user.User" {
return nil, errors.New("user not found")
}
data, _ := ptypes.MarshalAny(&session.Session{
Id: in.GetId(),
UserId: "user1",
})
return &databroker.GetResponse{
Record: &databroker.Record{
Version: 1,
Type: data.GetTypeUrl(),
Id: in.GetId(),
Data: data,
},
}, nil
},
},
false,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
a, err := New(&config.Config{Options: o})
require.NoError(t, err)
a.state.Load().dataBrokerClient = dbdClient
_, err = a.forceSync(ctx, tc.sessionState)
assert.True(t, (err != nil) == tc.wantErr)
})
}
}
type mockDataBrokerServiceClient struct {
databroker.DataBrokerServiceClient

View file

@ -2,11 +2,32 @@ package authorize
import (
"context"
"errors"
"sync"
"time"
"github.com/cenkalti/backoff/v4"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/grpcutil"
)
const (
forceSyncRecordMaxWait = 5 * time.Second
)
type sessionOrServiceAccount interface {
GetUserId() string
}
type dataBrokerSyncer struct {
*databroker.Syncer
authorize *Authorize
@ -29,9 +50,9 @@ func (syncer *dataBrokerSyncer) ClearRecords(ctx context.Context) {
syncer.authorize.store.ClearRecords()
}
func (syncer *dataBrokerSyncer) UpdateRecords(ctx context.Context, records []*databroker.Record) {
func (syncer *dataBrokerSyncer) UpdateRecords(ctx context.Context, serverVersion uint64, records []*databroker.Record) {
for _, record := range records {
syncer.authorize.store.UpdateRecord(record)
syncer.authorize.store.UpdateRecord(serverVersion, record)
}
// the first time we update records we signal the initial sync
@ -39,3 +60,112 @@ func (syncer *dataBrokerSyncer) UpdateRecords(ctx context.Context, records []*da
close(syncer.authorize.dataBrokerInitialSync)
})
}
func (a *Authorize) forceSync(ctx context.Context, ss *sessions.State) (*user.User, error) {
ctx, span := trace.StartSpan(ctx, "authorize.forceSync")
defer span.End()
if ss == nil {
return nil, nil
}
s := a.forceSyncSession(ctx, ss.ID)
if s == nil {
return nil, errors.New("session not found")
}
u := a.forceSyncUser(ctx, s.GetUserId())
return u, nil
}
func (a *Authorize) forceSyncSession(ctx context.Context, sessionID string) sessionOrServiceAccount {
ctx, span := trace.StartSpan(ctx, "authorize.forceSyncSession")
defer span.End()
ctx, clearTimeout := context.WithTimeout(ctx, forceSyncRecordMaxWait)
defer clearTimeout()
s, ok := a.store.GetRecordData(grpcutil.GetTypeURL(new(session.Session)), sessionID).(*session.Session)
if ok {
return s
}
sa, ok := a.store.GetRecordData(grpcutil.GetTypeURL(new(user.ServiceAccount)), sessionID).(*user.ServiceAccount)
if ok {
return sa
}
// wait for the session to show up
record, err := a.waitForRecordSync(ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID)
if err != nil {
return nil
}
s, ok = record.(*session.Session)
if !ok {
return nil
}
return s
}
func (a *Authorize) forceSyncUser(ctx context.Context, userID string) *user.User {
ctx, span := trace.StartSpan(ctx, "authorize.forceSyncUser")
defer span.End()
ctx, clearTimeout := context.WithTimeout(ctx, forceSyncRecordMaxWait)
defer clearTimeout()
u, ok := a.store.GetRecordData(grpcutil.GetTypeURL(new(user.User)), userID).(*user.User)
if ok {
return u
}
// wait for the user to show up
record, err := a.waitForRecordSync(ctx, grpcutil.GetTypeURL(new(user.User)), userID)
if err != nil {
return nil
}
u, ok = record.(*user.User)
if !ok {
return nil
}
return u
}
// waitForRecordSync waits for the first sync of a record to complete
func (a *Authorize) waitForRecordSync(ctx context.Context, recordTypeURL, recordID string) (proto.Message, error) {
bo := backoff.NewExponentialBackOff()
bo.InitialInterval = time.Millisecond
bo.MaxElapsedTime = 0
bo.Reset()
for {
current := a.store.GetRecordData(recordTypeURL, recordID)
if current != nil {
// record found, so it's already synced
return current, nil
}
_, err := a.state.Load().dataBrokerClient.Get(ctx, &databroker.GetRequest{
Type: recordTypeURL,
Id: recordID,
})
if status.Code(err) == codes.NotFound {
// record not found, so no need to wait
return nil, nil
} else if err != nil {
log.Error().
Err(err).
Str("type", recordTypeURL).
Str("id", recordID).
Msg("authorize: error retrieving record")
return nil, err
}
select {
case <-ctx.Done():
log.Warn().
Str("type", recordTypeURL).
Str("id", recordID).
Msg("authorize: first sync of record did not complete")
return nil, ctx.Err()
case <-time.After(bo.NextBackOff()):
}
}
}

116
authorize/sync_test.go Normal file
View file

@ -0,0 +1,116 @@
package authorize
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpcutil"
)
func TestAuthorize_waitForRecordSync(t *testing.T) {
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*30)
defer clearTimeout()
o := &config.Options{
AuthenticateURLString: "https://authN.example.com",
DataBrokerURLString: "https://databroker.example.com",
SharedKey: "gXK6ggrlIW2HyKyUF9rUO4azrDgxhDPWqw9y+lJU7B8=",
Policies: testPolicies(t),
}
t.Run("skip if exists", func(t *testing.T) {
a, err := New(&config.Config{Options: o})
require.NoError(t, err)
a.store.UpdateRecord(0, newRecord(&session.Session{
Id: "SESSION_ID",
}))
a.state.Load().dataBrokerClient = mockDataBrokerServiceClient{
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
panic("should never be called")
},
}
a.waitForRecordSync(ctx, grpcutil.GetTypeURL(new(session.Session)), "SESSION_ID")
})
t.Run("skip if not found", func(t *testing.T) {
a, err := New(&config.Config{Options: o})
require.NoError(t, err)
callCount := 0
a.state.Load().dataBrokerClient = mockDataBrokerServiceClient{
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
callCount++
return nil, status.Error(codes.NotFound, "not found")
},
}
a.waitForRecordSync(ctx, grpcutil.GetTypeURL(new(session.Session)), "SESSION_ID")
assert.Equal(t, 1, callCount, "should be called once")
})
t.Run("poll", func(t *testing.T) {
a, err := New(&config.Config{Options: o})
require.NoError(t, err)
callCount := 0
a.state.Load().dataBrokerClient = mockDataBrokerServiceClient{
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
callCount++
switch callCount {
case 1:
s := &session.Session{Id: "SESSION_ID"}
a.store.UpdateRecord(0, newRecord(s))
return &databroker.GetResponse{Record: newRecord(s)}, nil
default:
panic("should never be called")
}
},
}
a.waitForRecordSync(ctx, grpcutil.GetTypeURL(new(session.Session)), "SESSION_ID")
})
t.Run("timeout", func(t *testing.T) {
a, err := New(&config.Config{Options: o})
require.NoError(t, err)
tctx, clearTimeout := context.WithTimeout(ctx, time.Millisecond*100)
defer clearTimeout()
callCount := 0
a.state.Load().dataBrokerClient = mockDataBrokerServiceClient{
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
callCount++
s := &session.Session{Id: "SESSION_ID"}
return &databroker.GetResponse{Record: newRecord(s)}, nil
},
}
a.waitForRecordSync(tctx, grpcutil.GetTypeURL(new(session.Session)), "SESSION_ID")
assert.Greater(t, callCount, 5) // should be ~ 20, but allow for non-determinism
})
}
type storableMessage interface {
proto.Message
GetId() string
}
func newRecord(msg storableMessage) *databroker.Record {
any, err := anypb.New(msg)
if err != nil {
panic(err)
}
return &databroker.Record{
Version: 1,
Type: any.GetTypeUrl(),
Id: msg.GetId(),
Data: any,
}
}

View file

@ -206,7 +206,7 @@ func (s *syncerHandler) ClearRecords(ctx context.Context) {
s.src.mu.Unlock()
}
func (s *syncerHandler) UpdateRecords(ctx context.Context, records []*databroker.Record) {
func (s *syncerHandler) UpdateRecords(ctx context.Context, serverVersion uint64, records []*databroker.Record) {
if len(records) == 0 {
return
}

View file

@ -50,7 +50,7 @@ func (syncer *dataBrokerSyncer) GetDataBrokerServiceClient() databroker.DataBrok
return syncer.cfg.Load().dataBrokerClient
}
func (syncer *dataBrokerSyncer) UpdateRecords(ctx context.Context, records []*databroker.Record) {
func (syncer *dataBrokerSyncer) UpdateRecords(ctx context.Context, serverVersion uint64, records []*databroker.Record) {
select {
case <-ctx.Done():
case syncer.update <- updateRecordsMessage{records: records}:

View file

@ -39,7 +39,7 @@ func WithTypeURL(typeURL string) SyncerOption {
type SyncerHandler interface {
GetDataBrokerServiceClient() DataBrokerServiceClient
ClearRecords(ctx context.Context)
UpdateRecords(ctx context.Context, records []*Record)
UpdateRecords(ctx context.Context, serverVersion uint64, records []*Record)
}
// A Syncer is a helper type for working with Sync and SyncLatest. It will make a call to
@ -122,7 +122,7 @@ func (syncer *Syncer) init(ctx context.Context) error {
syncer.recordVersion = recordVersion
syncer.serverVersion = serverVersion
syncer.handler.UpdateRecords(ctx, records)
syncer.handler.UpdateRecords(ctx, serverVersion, records)
return nil
}
@ -157,7 +157,7 @@ func (syncer *Syncer) sync(ctx context.Context) error {
}
syncer.recordVersion = res.GetRecord().GetVersion()
if syncer.cfg.typeURL == "" || syncer.cfg.typeURL == res.GetRecord().GetType() {
syncer.handler.UpdateRecords(ctx, []*Record{res.GetRecord()})
syncer.handler.UpdateRecords(ctx, syncer.serverVersion, []*Record{res.GetRecord()})
}
}
}

View file

@ -19,7 +19,7 @@ import (
type testSyncerHandler struct {
getDataBrokerServiceClient func() DataBrokerServiceClient
clearRecords func(ctx context.Context)
updateRecords func(ctx context.Context, records []*Record)
updateRecords func(ctx context.Context, serverVersion uint64, records []*Record)
}
func (t testSyncerHandler) GetDataBrokerServiceClient() DataBrokerServiceClient {
@ -30,8 +30,8 @@ func (t testSyncerHandler) ClearRecords(ctx context.Context) {
t.clearRecords(ctx)
}
func (t testSyncerHandler) UpdateRecords(ctx context.Context, records []*Record) {
t.updateRecords(ctx, records)
func (t testSyncerHandler) UpdateRecords(ctx context.Context, serverVersion uint64, records []*Record) {
t.updateRecords(ctx, serverVersion, records)
}
type testServer struct {
@ -166,7 +166,7 @@ func TestSyncer(t *testing.T) {
clearRecords: func(ctx context.Context) {
clearCh <- struct{}{}
},
updateRecords: func(ctx context.Context, records []*Record) {
updateRecords: func(ctx context.Context, serverVersion uint64, records []*Record) {
updateCh <- records
},
})