authorize: use query instead of sync for databroker data (#3377)

Caleb Doxsey 2022-06-01 15:40:07 -06:00 committed by GitHub
parent fd82cc7870
commit f61e7efe73
24 changed files with 661 additions and 1008 deletions
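
In short: the authorize service no longer keeps its own synced, in-memory copy of databroker records. Each Check request now builds a layered querier (a per-request local cache over a shared, TTL-bounded global cache over the databroker gRPC client), attaches it to the request context, and evaluation code looks records up on demand. A rough sketch of that wiring, assembled only from calls that appear in the diff below; the package and function names here are illustrative, not part of the commit:

// Sketch (hypothetical wiring, not code from this commit).
package sketch

import (
    "context"
    "time"

    "github.com/pomerium/pomerium/pkg/grpc/databroker"
    "github.com/pomerium/pomerium/pkg/storage"
)

// newRequestContext mirrors the construction added to Check below: wrap the raw
// databroker querier in a shared global cache plus a per-request local cache,
// then install the result on the context for downstream lookups.
func newRequestContext(ctx context.Context, client databroker.DataBrokerServiceClient) context.Context {
    // In the diff the global cache is created once in New() and reused across
    // requests; it is created inline here only to keep the sketch self-contained.
    globalCache := storage.NewGlobalCache(time.Minute)
    querier := storage.NewTracingQuerier(
        storage.NewCachingQuerier(
            storage.NewCachingQuerier(
                storage.NewQuerier(client),
                globalCache,
            ),
            storage.NewLocalCache(),
        ),
    )
    return storage.WithQuerier(ctx, querier)
}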


@@ -8,6 +8,8 @@ import (
     "sync"
     "time"
 
+    "golang.org/x/sync/errgroup"
+
     "github.com/pomerium/pomerium/authorize/evaluator"
     "github.com/pomerium/pomerium/authorize/internal/store"
     "github.com/pomerium/pomerium/config"
@@ -17,6 +19,7 @@ import (
     "github.com/pomerium/pomerium/pkg/cryptutil"
     "github.com/pomerium/pomerium/pkg/grpc"
     "github.com/pomerium/pomerium/pkg/grpc/databroker"
+    "github.com/pomerium/pomerium/pkg/storage"
 )
 
 // Authorize struct holds
@@ -25,8 +28,7 @@ type Authorize struct {
     store          *store.Store
     currentOptions *config.AtomicOptions
     accessTracker  *AccessTracker
-
-    dataBrokerInitialSync chan struct{}
+    globalCache    storage.Cache
 
     // The stateLock prevents updating the evaluator store simultaneously with an evaluation.
     // This should provide a consistent view of the data at a given server/record version and
@@ -37,9 +39,9 @@ type Authorize struct {
 // New validates and creates a new Authorize service from a set of config options.
 func New(cfg *config.Config) (*Authorize, error) {
     a := &Authorize{
         currentOptions: config.NewAtomicOptions(),
         store:          store.New(),
-        dataBrokerInitialSync: make(chan struct{}),
+        globalCache:    storage.NewGlobalCache(time.Minute),
     }
 
     a.accessTracker = NewAccessTracker(a, accessTrackerMaxSize, accessTrackerDebouncePeriod)
@@ -59,19 +61,16 @@ func (a *Authorize) GetDataBrokerServiceClient() databroker.DataBrokerServiceCli
 // Run runs the authorize service.
 func (a *Authorize) Run(ctx context.Context) error {
-    go a.accessTracker.Run(ctx)
-    _ = grpc.WaitForReady(ctx, a.state.Load().dataBrokerClientConnection, time.Second*10)
-    return newDataBrokerSyncer(a).Run(ctx)
-}
-
-// WaitForInitialSync blocks until the initial sync is complete.
-func (a *Authorize) WaitForInitialSync(ctx context.Context) error {
-    select {
-    case <-ctx.Done():
-        return ctx.Err()
-    case <-a.dataBrokerInitialSync:
-    }
-    return nil
+    eg, ctx := errgroup.WithContext(ctx)
+    eg.Go(func() error {
+        a.accessTracker.Run(ctx)
+        return nil
+    })
+    eg.Go(func() error {
+        _ = grpc.WaitForReady(ctx, a.state.Load().dataBrokerClientConnection, time.Second*10)
+        return nil
+    })
+    return eg.Wait()
 }
 
 func validateOptions(o *config.Options) error {


@@ -20,8 +20,6 @@ import (
     "github.com/pomerium/pomerium/config"
     "github.com/pomerium/pomerium/internal/encoding/jws"
     "github.com/pomerium/pomerium/internal/testutil"
-    "github.com/pomerium/pomerium/pkg/grpc/session"
-    "github.com/pomerium/pomerium/pkg/grpc/user"
 )
 
 func TestAuthorize_okResponse(t *testing.T) {
@@ -40,17 +38,7 @@ func TestAuthorize_okResponse(t *testing.T) {
     encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0})
     a.state.Load().encoder = encoder
     a.currentOptions.Store(opt)
-    a.store = store.NewFromProtos(0,
-        &session.Session{
-            Id: "SESSION_ID",
-            UserId: "USER_ID",
-        },
-        &user.User{
-            Id: "USER_ID",
-            Name: "foo",
-            Email: "foo@example.com",
-        },
-    )
+    a.store = store.New()
     pe, err := newPolicyEvaluator(opt, a.store)
     require.NoError(t, err)
     a.state.Load().evaluator = pe

authorize/databroker.go (new file)

@@ -0,0 +1,48 @@
package authorize
import (
"context"
"github.com/open-policy-agent/opa/storage"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
)
type sessionOrServiceAccount interface {
GetUserId() string
}
func (a *Authorize) getDataBrokerSessionOrServiceAccount(ctx context.Context, sessionID string) (s sessionOrServiceAccount, err error) {
ctx, span := trace.StartSpan(ctx, "authorize.getDataBrokerSessionOrServiceAccount")
defer span.End()
client := a.state.Load().dataBrokerClient
s, err = session.Get(ctx, client, sessionID)
if storage.IsNotFound(err) {
s, err = user.GetServiceAccount(ctx, client, sessionID)
}
if err != nil {
return nil, err
}
if _, ok := s.(*session.Session); ok {
a.accessTracker.TrackSessionAccess(sessionID)
}
if _, ok := s.(*user.ServiceAccount); ok {
a.accessTracker.TrackServiceAccountAccess(sessionID)
}
return s, nil
}
func (a *Authorize) getDataBrokerUser(ctx context.Context, userID string) (u *user.User, err error) {
ctx, span := trace.StartSpan(ctx, "authorize.getDataBrokerUser")
defer span.End()
client := a.state.Load().dataBrokerClient
u, err = user.Get(ctx, client, userID)
return u, err
}


@@ -72,8 +72,6 @@ type Result struct {
     Allow   RuleResult
     Deny    RuleResult
     Headers http.Header
-
-    DataBrokerServerVersion, DataBrokerRecordVersion uint64
 }
 
 // An Evaluator evaluates policies.
@@ -170,7 +168,6 @@ func (e *Evaluator) Evaluate(ctx context.Context, req *Request) (*Result, error)
         Deny:    policyOutput.Deny,
         Headers: headersOutput.Headers,
     }
-    res.DataBrokerServerVersion, res.DataBrokerRecordVersion = e.store.GetDataBrokerVersions()
     return res, nil
 }


@@ -2,29 +2,24 @@ package evaluator
 
 import (
     "context"
-    "fmt"
-    "math"
     "net/http"
     "net/url"
     "testing"
 
-    "github.com/google/uuid"
     "github.com/stretchr/testify/assert"
     "github.com/stretchr/testify/require"
     "google.golang.org/protobuf/proto"
-    "google.golang.org/protobuf/types/known/timestamppb"
 
     "github.com/pomerium/pomerium/authorize/internal/store"
     "github.com/pomerium/pomerium/config"
     "github.com/pomerium/pomerium/internal/httputil"
     "github.com/pomerium/pomerium/pkg/cryptutil"
-    "github.com/pomerium/pomerium/pkg/grpc/databroker"
     "github.com/pomerium/pomerium/pkg/grpc/directory"
     "github.com/pomerium/pomerium/pkg/grpc/session"
     "github.com/pomerium/pomerium/pkg/grpc/user"
     "github.com/pomerium/pomerium/pkg/policy/criteria"
     "github.com/pomerium/pomerium/pkg/policy/parser"
-    "github.com/pomerium/pomerium/pkg/protoutil"
+    "github.com/pomerium/pomerium/pkg/storage"
 )
 
 func TestEvaluator(t *testing.T) {
@@ -36,13 +31,15 @@ func TestEvaluator(t *testing.T) {
     require.NoError(t, err)
 
     eval := func(t *testing.T, options []Option, data []proto.Message, req *Request) (*Result, error) {
-        store := store.NewFromProtos(math.MaxUint64, data...)
+        ctx := context.Background()
+        ctx = storage.WithQuerier(ctx, storage.NewStaticQuerier(data...))
+        store := store.New()
         store.UpdateIssuer("authenticate.example.com")
         store.UpdateJWTClaimHeaders(config.NewJWTClaimHeaders("email", "groups", "user", "CUSTOM_KEY"))
         store.UpdateSigningKey(privateJWK)
-        e, err := New(context.Background(), store, options...)
+        e, err := New(ctx, store, options...)
         require.NoError(t, err)
-        return e.Evaluate(context.Background(), req)
+        return e.Evaluate(ctx, req)
     }
 
     policies := []config.Policy{
@@ -511,104 +508,3 @@ func mustParseURL(str string) *url.URL {
     }
     return u
 }
func BenchmarkEvaluator_Evaluate(b *testing.B) {
store := store.New()
policies := []config.Policy{
{
From: "https://from.example.com",
To: config.WeightedURLs{
{URL: *mustParseURL("https://to.example.com")},
},
AllowedUsers: []string{"SOME_USER"},
},
}
options := []Option{
WithAuthenticateURL("https://authn.example.com"),
WithPolicies(policies),
}
e, err := New(context.Background(), store, options...)
if !assert.NoError(b, err) {
return
}
lastSessionID := ""
for i := 0; i < 100000; i++ {
sessionID := uuid.New().String()
lastSessionID = sessionID
userID := uuid.New().String()
data := protoutil.NewAny(&session.Session{
Version: fmt.Sprint(i),
Id: sessionID,
UserId: userID,
IdToken: &session.IDToken{
Issuer: "benchmark",
Subject: userID,
IssuedAt: timestamppb.Now(),
},
OauthToken: &session.OAuthToken{
AccessToken: "ACCESS TOKEN",
TokenType: "Bearer",
RefreshToken: "REFRESH TOKEN",
},
})
store.UpdateRecord(0, &databroker.Record{
Version: uint64(i),
Type: "type.googleapis.com/session.Session",
Id: sessionID,
Data: data,
})
data = protoutil.NewAny(&user.User{
Version: fmt.Sprint(i),
Id: userID,
})
store.UpdateRecord(0, &databroker.Record{
Version: uint64(i),
Type: "type.googleapis.com/user.User",
Id: userID,
Data: data,
})
data = protoutil.NewAny(&directory.User{
Version: fmt.Sprint(i),
Id: userID,
GroupIds: []string{"1", "2", "3", "4"},
})
store.UpdateRecord(0, &databroker.Record{
Version: uint64(i),
Type: data.TypeUrl,
Id: userID,
Data: data,
})
data = protoutil.NewAny(&directory.Group{
Version: fmt.Sprint(i),
Id: fmt.Sprint(i),
})
store.UpdateRecord(0, &databroker.Record{
Version: uint64(i),
Type: data.TypeUrl,
Id: fmt.Sprint(i),
Data: data,
})
}
b.ResetTimer()
ctx := context.Background()
for i := 0; i < b.N; i++ {
_, _ = e.Evaluate(ctx, &Request{
Policy: &policies[0],
HTTP: RequestHTTP{
Method: "GET",
URL: "https://example.com/path",
Headers: map[string]string{},
},
Session: RequestSession{
ID: lastSessionID,
},
})
}
}


@@ -18,6 +18,7 @@ import (
     "github.com/pomerium/pomerium/pkg/grpc/directory"
     "github.com/pomerium/pomerium/pkg/grpc/session"
     "github.com/pomerium/pomerium/pkg/grpc/user"
+    "github.com/pomerium/pomerium/pkg/storage"
 )
 
 func TestNewHeadersRequestFromPolicy(t *testing.T) {
@@ -51,13 +52,15 @@ func TestHeadersEvaluator(t *testing.T) {
     require.NoError(t, err)
 
     eval := func(t *testing.T, data []proto.Message, input *HeadersRequest) (*HeadersResponse, error) {
-        store := store.NewFromProtos(math.MaxUint64, data...)
+        ctx := context.Background()
+        ctx = storage.WithQuerier(ctx, storage.NewStaticQuerier(data...))
+        store := store.New()
         store.UpdateIssuer("authenticate.example.com")
         store.UpdateJWTClaimHeaders(config.NewJWTClaimHeaders("email", "groups", "user", "CUSTOM_KEY"))
         store.UpdateSigningKey(privateJWK)
-        e, err := NewHeadersEvaluator(context.Background(), store)
+        e, err := NewHeadersEvaluator(ctx, store)
         require.NoError(t, err)
-        return e.Evaluate(context.Background(), input)
+        return e.Evaluate(ctx, input)
     }
 
     t.Run("groups", func(t *testing.T) {


@@ -2,7 +2,6 @@ package evaluator
 
 import (
     "context"
-    "math"
     "strings"
     "testing"
@@ -18,6 +17,7 @@ import (
     "github.com/pomerium/pomerium/pkg/grpc/user"
     "github.com/pomerium/pomerium/pkg/policy"
     "github.com/pomerium/pomerium/pkg/policy/criteria"
+    "github.com/pomerium/pomerium/pkg/storage"
 )
 
 func TestPolicyEvaluator(t *testing.T) {
@@ -29,13 +29,15 @@ func TestPolicyEvaluator(t *testing.T) {
     require.NoError(t, err)
 
     eval := func(t *testing.T, policy *config.Policy, data []proto.Message, input *PolicyRequest) (*PolicyResponse, error) {
-        store := store.NewFromProtos(math.MaxUint64, data...)
+        ctx := context.Background()
+        ctx = storage.WithQuerier(ctx, storage.NewStaticQuerier(data...))
+        store := store.New()
         store.UpdateIssuer("authenticate.example.com")
         store.UpdateJWTClaimHeaders(config.NewJWTClaimHeaders("email", "groups", "user", "CUSTOM_KEY"))
         store.UpdateSigningKey(privateJWK)
-        e, err := NewPolicyEvaluator(context.Background(), store, policy)
+        e, err := NewPolicyEvaluator(ctx, store, policy)
         require.NoError(t, err)
-        return e.Evaluate(context.Background(), input)
+        return e.Evaluate(ctx, input)
     }
 
     p1 := &config.Policy{


@@ -16,6 +16,8 @@ import (
     "github.com/pomerium/pomerium/internal/sessions"
     "github.com/pomerium/pomerium/internal/telemetry/trace"
     "github.com/pomerium/pomerium/internal/urlutil"
+    "github.com/pomerium/pomerium/pkg/grpc/user"
+    "github.com/pomerium/pomerium/pkg/storage"
 )
 
 // Check implements the envoy auth server gRPC endpoint.
@@ -23,10 +25,16 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
     ctx, span := trace.StartSpan(ctx, "authorize.grpc.Check")
     defer span.End()
 
-    // wait for the initial sync to complete so that data is available for evaluation
-    if err := a.WaitForInitialSync(ctx); err != nil {
-        return nil, err
-    }
+    querier := storage.NewTracingQuerier(
+        storage.NewCachingQuerier(
+            storage.NewCachingQuerier(
+                storage.NewQuerier(a.state.Load().dataBrokerClient),
+                a.globalCache,
+            ),
+            storage.NewLocalCache(),
+        ),
+    )
+    ctx = storage.WithQuerier(ctx, querier)
 
     state := a.state.Load()
@@ -48,10 +56,21 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
     rawJWT, _ := loadRawSession(hreq, a.currentOptions.Load(), state.encoder)
     sessionState, _ := loadSession(state.encoder, rawJWT)
 
-    s, u, err := a.forceSync(ctx, sessionState)
-    if err != nil {
-        log.Warn(ctx).Err(err).Msg("clearing session due to force sync failed")
-        sessionState = nil
+    var s sessionOrServiceAccount
+    var u *user.User
+    if sessionState != nil {
+        s, err = a.getDataBrokerSessionOrServiceAccount(ctx, sessionState.ID)
+        if err != nil {
+            log.Warn(ctx).Err(err).Msg("clearing session due to force sync failed")
+            sessionState = nil
+        }
+    }
+    if s != nil {
+        u, err = a.getDataBrokerUser(ctx, s.GetUserId())
+        if err != nil {
+            log.Warn(ctx).Err(err).Msg("clearing session due to force sync failed")
+            sessionState = nil
+        }
     }
 
     req, err := a.getEvaluatorRequestFromCheckRequest(in, sessionState)
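
With the querier installed on the context by Check, downstream code no longer reads from a synced index; it pulls the querier back out of the context and issues bounded point queries, as the store and logging changes below do. A hypothetical helper showing the lookup shape (not part of this commit; the function name and wrapper package are assumed):

// lookupUser is illustrative only; the commit's real helpers live in authorize/databroker.go above.
package sketch

import (
    "context"

    "github.com/pomerium/pomerium/pkg/grpc/databroker"
    "github.com/pomerium/pomerium/pkg/grpc/user"
    "github.com/pomerium/pomerium/pkg/grpcutil"
    "github.com/pomerium/pomerium/pkg/storage"
)

func lookupUser(ctx context.Context, userID string) (*user.User, error) {
    // Ask for exactly one record of the user type, filtered by ID.
    req := &databroker.QueryRequest{
        Type:  grpcutil.GetTypeURL(new(user.User)),
        Limit: 1,
    }
    req.SetFilterByID(userID)

    res, err := storage.GetQuerier(ctx).Query(ctx, req)
    if err != nil {
        return nil, err
    }
    if len(res.GetRecords()) == 0 {
        // Not found: callers in this diff treat a missing record as "clear the session".
        return nil, nil
    }

    msg, err := res.GetRecords()[0].GetData().UnmarshalNew()
    if err != nil {
        return nil, err
    }
    u, _ := msg.(*user.User)
    return u, nil
}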


@@ -337,8 +337,6 @@ func TestAuthorize_Check(t *testing.T) {
     }
     a.currentOptions.Store(&config.Options{ForwardAuthURLString: "https://forward-auth.example.com"})
 
-    close(a.dataBrokerInitialSync)
-
     cmpOpts := []cmp.Option{
         cmpopts.IgnoreUnexported(envoy_service_auth_v3.CheckResponse{}),
         cmpopts.IgnoreUnexported(status.Status{}),


@@ -1,196 +0,0 @@
package store
import (
"sync"
"github.com/kentik/patricia"
"github.com/kentik/patricia/string_tree"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/structpb"
)
const (
indexField = "$index"
cidrField = "cidr"
)
type index struct {
mu sync.RWMutex
byType map[string]*recordIndex
}
func newIndex() *index {
idx := new(index)
idx.clear()
return idx
}
func (idx *index) clear() {
idx.mu.Lock()
defer idx.mu.Unlock()
idx.byType = map[string]*recordIndex{}
}
func (idx *index) delete(typeURL, id string) {
idx.mu.Lock()
defer idx.mu.Unlock()
ridx, ok := idx.byType[typeURL]
if !ok {
return
}
ridx.delete(id)
if len(ridx.byID) == 0 {
delete(idx.byType, typeURL)
}
}
func (idx *index) find(typeURL, id string) proto.Message {
idx.mu.RLock()
defer idx.mu.RUnlock()
ridx, ok := idx.byType[typeURL]
if !ok {
return nil
}
return ridx.find(id)
}
func (idx *index) get(typeURL, id string) proto.Message {
idx.mu.RLock()
defer idx.mu.RUnlock()
ridx, ok := idx.byType[typeURL]
if !ok {
return nil
}
return ridx.get(id)
}
func (idx *index) set(typeURL, id string, msg proto.Message) {
idx.mu.Lock()
defer idx.mu.Unlock()
ridx, ok := idx.byType[typeURL]
if !ok {
ridx = newRecordIndex()
idx.byType[typeURL] = ridx
}
ridx.set(id, msg)
}
// a recordIndex indexes records for of a specific type
type recordIndex struct {
byID map[string]proto.Message
byCIDRV4 *string_tree.TreeV4
byCIDRV6 *string_tree.TreeV6
}
// newRecordIndex creates a new record index.
func newRecordIndex() *recordIndex {
return &recordIndex{
byID: map[string]proto.Message{},
byCIDRV4: string_tree.NewTreeV4(),
byCIDRV6: string_tree.NewTreeV6(),
}
}
func (idx *recordIndex) delete(id string) {
r, ok := idx.byID[id]
if !ok {
return
}
delete(idx.byID, id)
addr4, addr6 := getIndexCIDR(r)
if addr4 != nil {
idx.byCIDRV4.Delete(*addr4, func(payload, val string) bool {
return payload == val
}, id)
}
if addr6 != nil {
idx.byCIDRV6.Delete(*addr6, func(payload, val string) bool {
return payload == val
}, id)
}
}
func (idx *recordIndex) find(idOrString string) proto.Message {
r, ok := idx.byID[idOrString]
if ok {
return r
}
addrv4, addrv6, _ := patricia.ParseIPFromString(idOrString)
if addrv4 != nil {
found, id := idx.byCIDRV4.FindDeepestTag(*addrv4)
if found {
return idx.byID[id]
}
}
if addrv6 != nil {
found, id := idx.byCIDRV6.FindDeepestTag(*addrv6)
if found {
return idx.byID[id]
}
}
return nil
}
func (idx *recordIndex) get(id string) proto.Message {
return idx.byID[id]
}
func (idx *recordIndex) set(id string, msg proto.Message) {
_, ok := idx.byID[id]
if ok {
idx.delete(id)
}
idx.byID[id] = msg
addr4, addr6 := getIndexCIDR(msg)
if addr4 != nil {
idx.byCIDRV4.Set(*addr4, id)
}
if addr6 != nil {
idx.byCIDRV6.Set(*addr6, id)
}
}
func getIndexCIDR(msg proto.Message) (*patricia.IPv4Address, *patricia.IPv6Address) {
var s *structpb.Struct
if sv, ok := msg.(*structpb.Value); ok {
s = sv.GetStructValue()
} else {
s, _ = msg.(*structpb.Struct)
}
if s == nil {
return nil, nil
}
f, ok := s.Fields[indexField]
if !ok {
return nil, nil
}
obj := f.GetStructValue()
if obj == nil {
return nil, nil
}
cf, ok := obj.Fields[cidrField]
if !ok {
return nil, nil
}
c := cf.GetStringValue()
if c == "" {
return nil, nil
}
addr4, addr6, _ := patricia.ParseIPFromString(c)
return addr4, addr6
}


@@ -1,74 +0,0 @@
package store
import (
"testing"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/types/known/structpb"
)
func TestByID(t *testing.T) {
idx := newIndex()
r1 := &structpb.Struct{Fields: map[string]*structpb.Value{
"id": structpb.NewStringValue("r1"),
}}
idx.set("example.com/record", "r1", r1)
assert.Equal(t, r1, idx.get("example.com/record", "r1"))
idx.delete("example.com/record", "r1")
assert.Nil(t, idx.get("example.com/record", "r1"))
}
func TestByCIDR(t *testing.T) {
t.Run("ipv4", func(t *testing.T) {
idx := newIndex()
r1 := &structpb.Struct{Fields: map[string]*structpb.Value{
"$index": structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
"cidr": structpb.NewStringValue("192.168.0.0/16"),
}}),
"id": structpb.NewStringValue("r1"),
}}
idx.set("example.com/record", "r1", r1)
r2 := &structpb.Struct{Fields: map[string]*structpb.Value{
"$index": structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
"cidr": structpb.NewStringValue("192.168.0.0/24"),
}}),
"id": structpb.NewStringValue("r2"),
}}
idx.set("example.com/record", "r2", r2)
assert.Equal(t, r2, idx.find("example.com/record", "192.168.0.7"))
idx.delete("example.com/record", "r2")
assert.Equal(t, r1, idx.find("example.com/record", "192.168.0.7"))
idx.delete("example.com/record", "r1")
assert.Nil(t, idx.find("example.com/record", "192.168.0.7"))
})
t.Run("ipv6", func(t *testing.T) {
idx := newIndex()
r1 := &structpb.Struct{Fields: map[string]*structpb.Value{
"$index": structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
"cidr": structpb.NewStringValue("2001:db8::/32"),
}}),
"id": structpb.NewStringValue("r1"),
}}
idx.set("example.com/record", "r1", r1)
r2 := &structpb.Struct{Fields: map[string]*structpb.Value{
"$index": structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
"cidr": structpb.NewStringValue("2001:db8::/48"),
}}),
"id": structpb.NewStringValue("r2"),
}}
idx.set("example.com/record", "r2", r2)
assert.Equal(t, r2, idx.find("example.com/record", "2001:db8::"))
idx.delete("example.com/record", "r2")
assert.Equal(t, r1, idx.find("example.com/record", "2001:db8::"))
idx.delete("example.com/record", "r1")
assert.Nil(t, idx.find("example.com/record", "2001:db8::"))
})
}


@@ -5,78 +5,33 @@ import (
     "context"
     "encoding/json"
     "fmt"
-    "sync/atomic"
 
     "github.com/go-jose/go-jose/v3"
-    "github.com/google/uuid"
     "github.com/open-policy-agent/opa/ast"
     "github.com/open-policy-agent/opa/rego"
-    "github.com/open-policy-agent/opa/storage"
+    opastorage "github.com/open-policy-agent/opa/storage"
     "github.com/open-policy-agent/opa/storage/inmem"
     "github.com/open-policy-agent/opa/types"
     "google.golang.org/protobuf/proto"
-    "google.golang.org/protobuf/types/known/timestamppb"
 
     "github.com/pomerium/pomerium/config"
     "github.com/pomerium/pomerium/internal/log"
-    "github.com/pomerium/pomerium/pkg/cryptutil"
     "github.com/pomerium/pomerium/pkg/grpc/databroker"
-    "github.com/pomerium/pomerium/pkg/protoutil"
+    "github.com/pomerium/pomerium/pkg/storage"
 )
 
 // A Store stores data for the OPA rego policy evaluation.
 type Store struct {
-    storage.Store
-    index *index
-    dataBrokerServerVersion, dataBrokerRecordVersion uint64
+    opastorage.Store
 }
 
 // New creates a new Store.
 func New() *Store {
     return &Store{
         Store: inmem.New(),
-        index: newIndex(),
     }
 }
 
-// NewFromProtos creates a new Store from an existing set of protobuf messages.
-func NewFromProtos(serverVersion uint64, msgs ...proto.Message) *Store {
-    s := New()
-    for _, msg := range msgs {
-        any := protoutil.NewAny(msg)
-        record := new(databroker.Record)
-        record.ModifiedAt = timestamppb.Now()
-        record.Version = cryptutil.NewRandomUInt64()
-        record.Id = uuid.New().String()
-        record.Data = any
-        record.Type = any.TypeUrl
-        if hasID, ok := msg.(interface{ GetId() string }); ok {
-            record.Id = hasID.GetId()
-        }
-        s.UpdateRecord(serverVersion, record)
-    }
-    return s
-}
-
-// ClearRecords removes all the records from the store.
-func (s *Store) ClearRecords() {
-    s.index.clear()
-}
-
-// GetDataBrokerVersions gets the databroker versions.
-func (s *Store) GetDataBrokerVersions() (serverVersion, recordVersion uint64) {
-    return atomic.LoadUint64(&s.dataBrokerServerVersion),
-        atomic.LoadUint64(&s.dataBrokerRecordVersion)
-}
-
-// GetRecordData gets a record's data from the store. `nil` is returned
-// if no record exists for the given type and id.
-func (s *Store) GetRecordData(typeURL, idOrValue string) proto.Message {
-    return s.index.find(typeURL, idOrValue)
-}
-
 // UpdateIssuer updates the issuer in the store. The issuer is used as part of JWT construction.
 func (s *Store) UpdateIssuer(issuer string) {
     s.write("/issuer", issuer)
@@ -98,20 +53,6 @@ func (s *Store) UpdateRoutePolicies(routePolicies []config.Policy) {
     s.write("/route_policies", routePolicies)
 }
 
-// UpdateRecord updates a record in the store.
-func (s *Store) UpdateRecord(serverVersion uint64, record *databroker.Record) {
-    if record.GetDeletedAt() != nil {
-        s.index.delete(record.GetType(), record.GetId())
-    } else {
-        msg, _ := record.GetData().UnmarshalNew()
-        s.index.set(record.GetType(), record.GetId(), msg)
-    }
-    s.write("/databroker_server_version", fmt.Sprint(serverVersion))
-    s.write("/databroker_record_version", fmt.Sprint(record.GetVersion()))
-    atomic.StoreUint64(&s.dataBrokerServerVersion, serverVersion)
-    atomic.StoreUint64(&s.dataBrokerRecordVersion, record.GetVersion())
-}
-
 // UpdateSigningKey updates the signing key stored in the database. Signing operations
 // in rego use JWKs, so we take in that format.
 func (s *Store) UpdateSigningKey(signingKey *jose.JSONWebKey) {
@@ -120,7 +61,7 @@ func (s *Store) UpdateSigningKey(signingKey *jose.JSONWebKey) {
 
 func (s *Store) write(rawPath string, value interface{}) {
     ctx := context.TODO()
-    err := storage.Txn(ctx, s.Store, storage.WriteParams, func(txn storage.Transaction) error {
+    err := opastorage.Txn(ctx, s.Store, opastorage.WriteParams, func(txn opastorage.Transaction) error {
         return s.writeTxn(txn, rawPath, value)
     })
     if err != nil {
@@ -129,23 +70,23 @@ func (s *Store) write(rawPath string, value interface{}) {
     }
 }
 
-func (s *Store) writeTxn(txn storage.Transaction, rawPath string, value interface{}) error {
-    p, ok := storage.ParsePath(rawPath)
+func (s *Store) writeTxn(txn opastorage.Transaction, rawPath string, value interface{}) error {
+    p, ok := opastorage.ParsePath(rawPath)
     if !ok {
         return fmt.Errorf("invalid path")
     }
     if len(p) > 1 {
-        err := storage.MakeDir(context.Background(), s, txn, p[:len(p)-1])
+        err := opastorage.MakeDir(context.Background(), s, txn, p[:len(p)-1])
         if err != nil {
             return err
         }
     }
-    var op storage.PatchOp = storage.ReplaceOp
+    var op opastorage.PatchOp = opastorage.ReplaceOp
     _, err := s.Read(context.Background(), txn, p)
-    if storage.IsNotFound(err) {
-        op = storage.AddOp
+    if opastorage.IsNotFound(err) {
+        op = opastorage.AddOp
     } else if err != nil {
         return err
     }
@@ -167,23 +108,42 @@ func (s *Store) GetDataBrokerRecordOption() func(*rego.Rego) {
         return nil, fmt.Errorf("invalid record type: %T", op1)
     }
-    recordID, ok := op2.Value.(ast.String)
+    value, ok := op2.Value.(ast.String)
     if !ok {
         return nil, fmt.Errorf("invalid record id: %T", op2)
     }
-    msg := s.GetRecordData(string(recordType), string(recordID))
-    if msg == nil {
+    req := &databroker.QueryRequest{
+        Type:  string(recordType),
+        Limit: 1,
+    }
+    req.SetFilterByIDOrIndex(string(value))
+    res, err := storage.GetQuerier(bctx.Context).Query(bctx.Context, req)
+    if err != nil {
+        log.Error(bctx.Context).Err(err).Msg("authorize/store: error retrieving record")
+        return ast.NullTerm(), nil
+    }
+    if len(res.GetRecords()) == 0 {
+        return ast.NullTerm(), nil
+    }
+    msg, _ := res.GetRecords()[0].GetData().UnmarshalNew()
+    if msg == nil {
         return ast.NullTerm(), nil
     }
     obj := toMap(msg)
-    value, err := ast.InterfaceToValue(obj)
+    regoValue, err := ast.InterfaceToValue(obj)
     if err != nil {
-        return nil, err
+        log.Error(bctx.Context).Err(err).Msg("authorize/store: error converting object to rego")
+        return ast.NullTerm(), nil
     }
-    return ast.NewTerm(value), nil
+    return ast.NewTerm(regoValue), nil
     })
 }


@@ -1,83 +0,0 @@
package store
import (
"testing"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/types/known/structpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/protoutil"
)
func TestStore(t *testing.T) {
t.Run("records", func(t *testing.T) {
s := New()
u := &user.User{
Version: "v1",
Id: "u1",
Name: "name",
Email: "name@example.com",
}
any := protoutil.NewAny(u)
s.UpdateRecord(0, &databroker.Record{
Version: 1,
Type: any.GetTypeUrl(),
Id: u.GetId(),
Data: any,
})
v := s.GetRecordData(any.GetTypeUrl(), u.GetId())
assert.Equal(t, map[string]interface{}{
"version": "v1",
"id": "u1",
"name": "name",
"email": "name@example.com",
}, toMap(v))
s.UpdateRecord(0, &databroker.Record{
Version: 2,
Type: any.GetTypeUrl(),
Id: u.GetId(),
Data: any,
DeletedAt: timestamppb.Now(),
})
v = s.GetRecordData(any.GetTypeUrl(), u.GetId())
assert.Nil(t, v)
s.UpdateRecord(0, &databroker.Record{
Version: 3,
Type: any.GetTypeUrl(),
Id: u.GetId(),
Data: any,
})
v = s.GetRecordData(any.GetTypeUrl(), u.GetId())
assert.NotNil(t, v)
s.ClearRecords()
v = s.GetRecordData(any.GetTypeUrl(), u.GetId())
assert.Nil(t, v)
})
t.Run("cidr", func(t *testing.T) {
s := New()
any := protoutil.NewAny(&structpb.Struct{Fields: map[string]*structpb.Value{
"$index": structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
"cidr": structpb.NewStringValue("192.168.0.0/16"),
}}),
"id": structpb.NewStringValue("r1"),
}})
s.UpdateRecord(0, &databroker.Record{
Version: 1,
Type: any.GetTypeUrl(),
Id: "r1",
Data: any,
})
v := s.GetRecordData(any.GetTypeUrl(), "192.168.0.7")
assert.NotNil(t, v)
})
}


@@ -12,9 +12,11 @@ import (
     "github.com/pomerium/pomerium/internal/telemetry/requestid"
     "github.com/pomerium/pomerium/internal/telemetry/trace"
     "github.com/pomerium/pomerium/pkg/grpc/audit"
+    "github.com/pomerium/pomerium/pkg/grpc/databroker"
     "github.com/pomerium/pomerium/pkg/grpc/session"
     "github.com/pomerium/pomerium/pkg/grpc/user"
    "github.com/pomerium/pomerium/pkg/grpcutil"
+    "github.com/pomerium/pomerium/pkg/storage"
 )
 
 func (a *Authorize) logAuthorizeCheck(
@@ -39,7 +41,7 @@ func (a *Authorize) logAuthorizeCheck(
     // session information
     if s, ok := s.(*session.Session); ok {
-        evt = a.populateLogSessionDetails(evt, s)
+        evt = a.populateLogSessionDetails(ctx, evt, s)
     }
     if sa, ok := s.(*user.ServiceAccount); ok {
         evt = evt.Str("service-account-id", sa.GetId())
@@ -61,8 +63,6 @@ func (a *Authorize) logAuthorizeCheck(
        }
        evt = evt.Str("user", u.GetId())
        evt = evt.Str("email", u.GetEmail())
-       evt = evt.Uint64("databroker_server_version", res.DataBrokerServerVersion)
-       evt = evt.Uint64("databroker_record_version", res.DataBrokerRecordVersion)
     }
 
     // potentially sensitive, only log if debug mode
@@ -80,10 +80,6 @@ func (a *Authorize) logAuthorizeCheck(
            Request:  in,
            Response: out,
        }
-       if res != nil {
-           record.DatabrokerServerVersion = res.DataBrokerServerVersion
-           record.DatabrokerRecordVersion = res.DataBrokerRecordVersion
-       }
        sealed, err := enc.Encrypt(record)
        if err != nil {
            log.Warn(ctx).Err(err).Msg("authorize: error encrypting audit record")
@@ -96,26 +92,50 @@
        }
     }
 }
 
-func (a *Authorize) populateLogSessionDetails(evt *zerolog.Event, s *session.Session) *zerolog.Event {
+func (a *Authorize) populateLogSessionDetails(ctx context.Context, evt *zerolog.Event, s *session.Session) *zerolog.Event {
     evt = evt.Str("session-id", s.GetId())
     if s.GetImpersonateSessionId() == "" {
         return evt
     }
 
+    querier := storage.GetQuerier(ctx)
+
     evt = evt.Str("impersonate-session-id", s.GetImpersonateSessionId())
-    impersonatedSession, ok := a.store.GetRecordData(
-        grpcutil.GetTypeURL(new(session.Session)),
-        s.GetImpersonateSessionId(),
-    ).(*session.Session)
+    req := &databroker.QueryRequest{
+        Type:  grpcutil.GetTypeURL(new(session.Session)),
+        Limit: 1,
+    }
+    req.SetFilterByID(s.GetImpersonateSessionId())
+    res, err := querier.Query(ctx, req)
+    if err != nil || len(res.GetRecords()) == 0 {
+        return evt
+    }
+    impersonatedSessionMsg, err := res.GetRecords()[0].GetData().UnmarshalNew()
+    if err != nil {
+        return evt
+    }
+    impersonatedSession, ok := impersonatedSessionMsg.(*session.Session)
     if !ok {
         return evt
     }
     evt = evt.Str("impersonate-user-id", impersonatedSession.GetUserId())
-    impersonatedUser, ok := a.store.GetRecordData(
-        grpcutil.GetTypeURL(new(user.User)),
-        impersonatedSession.GetUserId(),
-    ).(*user.User)
+    req = &databroker.QueryRequest{
+        Type:  grpcutil.GetTypeURL(new(user.User)),
+        Limit: 1,
+    }
+    req.SetFilterByID(impersonatedSession.GetUserId())
+    res, err = querier.Query(ctx, req)
+    if err != nil || len(res.GetRecords()) == 0 {
+        return evt
+    }
+    impersonatedUserMsg, err := res.GetRecords()[0].GetData().UnmarshalNew()
+    if err != nil {
+        return evt
+    }
+    impersonatedUser, ok := impersonatedUserMsg.(*user.User)
     if !ok {
         return evt
     }


@@ -1,209 +0,0 @@
package authorize
import (
"context"
"errors"
"sync"
"time"
"github.com/cenkalti/backoff/v4"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/grpcutil"
)
const (
forceSyncRecordMaxWait = 5 * time.Second
)
type sessionOrServiceAccount interface {
GetUserId() string
}
type dataBrokerSyncer struct {
*databroker.Syncer
authorize *Authorize
signalOnce sync.Once
}
func newDataBrokerSyncer(authorize *Authorize) *dataBrokerSyncer {
syncer := &dataBrokerSyncer{
authorize: authorize,
}
syncer.Syncer = databroker.NewSyncer("authorize", syncer)
return syncer
}
func (syncer *dataBrokerSyncer) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
return syncer.authorize.state.Load().dataBrokerClient
}
func (syncer *dataBrokerSyncer) ClearRecords(ctx context.Context) {
syncer.authorize.stateLock.Lock()
syncer.authorize.store.ClearRecords()
syncer.authorize.stateLock.Unlock()
}
func (syncer *dataBrokerSyncer) UpdateRecords(ctx context.Context, serverVersion uint64, records []*databroker.Record) {
syncer.authorize.stateLock.Lock()
for _, record := range records {
syncer.authorize.store.UpdateRecord(serverVersion, record)
}
syncer.authorize.stateLock.Unlock()
// the first time we update records we signal the initial sync
syncer.signalOnce.Do(func() {
log.Info(ctx).Msg("initial sync from databroker complete")
close(syncer.authorize.dataBrokerInitialSync)
})
}
func (a *Authorize) forceSync(ctx context.Context, ss *sessions.State) (sessionOrServiceAccount, *user.User, error) {
ctx, span := trace.StartSpan(ctx, "authorize.forceSync")
defer span.End()
if ss == nil {
return nil, nil, nil
}
// if the session state has databroker versions, wait for those to finish syncing
if ss.DatabrokerServerVersion != 0 && ss.DatabrokerRecordVersion != 0 {
a.forceSyncToVersion(ctx, ss.DatabrokerServerVersion, ss.DatabrokerRecordVersion)
}
s := a.forceSyncSession(ctx, ss.ID)
if s == nil {
return nil, nil, errors.New("session not found")
}
u := a.forceSyncUser(ctx, s.GetUserId())
return s, u, nil
}
func (a *Authorize) forceSyncToVersion(ctx context.Context, serverVersion, recordVersion uint64) (ready bool) {
ctx, span := trace.StartSpan(ctx, "authorize.forceSyncToVersion")
defer span.End()
ctx, clearTimeout := context.WithTimeout(ctx, forceSyncRecordMaxWait)
defer clearTimeout()
ticker := time.NewTicker(time.Millisecond * 50)
defer ticker.Stop()
for {
currentServerVersion, currentRecordVersion := a.store.GetDataBrokerVersions()
// check if the local record version is up to date with the expected record version
if currentServerVersion == serverVersion && currentRecordVersion >= recordVersion {
return true
}
select {
case <-ctx.Done():
return false
case <-ticker.C:
}
}
}
func (a *Authorize) forceSyncSession(ctx context.Context, sessionID string) sessionOrServiceAccount {
ctx, span := trace.StartSpan(ctx, "authorize.forceSyncSession")
defer span.End()
ctx, clearTimeout := context.WithTimeout(ctx, forceSyncRecordMaxWait)
defer clearTimeout()
s, ok := a.store.GetRecordData(grpcutil.GetTypeURL(new(session.Session)), sessionID).(*session.Session)
if ok {
a.accessTracker.TrackSessionAccess(sessionID)
return s
}
sa, ok := a.store.GetRecordData(grpcutil.GetTypeURL(new(user.ServiceAccount)), sessionID).(*user.ServiceAccount)
if ok {
a.accessTracker.TrackServiceAccountAccess(sessionID)
return sa
}
// wait for the session to show up
record, err := a.waitForRecordSync(ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID)
if err != nil {
return nil
}
s, ok = record.(*session.Session)
if !ok {
return nil
}
a.accessTracker.TrackSessionAccess(sessionID)
return s
}
func (a *Authorize) forceSyncUser(ctx context.Context, userID string) *user.User {
ctx, span := trace.StartSpan(ctx, "authorize.forceSyncUser")
defer span.End()
ctx, clearTimeout := context.WithTimeout(ctx, forceSyncRecordMaxWait)
defer clearTimeout()
u, ok := a.store.GetRecordData(grpcutil.GetTypeURL(new(user.User)), userID).(*user.User)
if ok {
return u
}
// wait for the user to show up
record, err := a.waitForRecordSync(ctx, grpcutil.GetTypeURL(new(user.User)), userID)
if err != nil {
return nil
}
u, ok = record.(*user.User)
if !ok {
return nil
}
return u
}
// waitForRecordSync waits for the first sync of a record to complete
func (a *Authorize) waitForRecordSync(ctx context.Context, recordTypeURL, recordID string) (proto.Message, error) {
bo := backoff.NewExponentialBackOff()
bo.InitialInterval = time.Millisecond
bo.MaxElapsedTime = 0
bo.Reset()
for {
current := a.store.GetRecordData(recordTypeURL, recordID)
if current != nil {
// record found, so it's already synced
return current, nil
}
_, err := a.state.Load().dataBrokerClient.Get(ctx, &databroker.GetRequest{
Type: recordTypeURL,
Id: recordID,
})
if status.Code(err) == codes.NotFound {
// record not found, so no need to wait
return nil, nil
} else if err != nil {
log.Error(ctx).
Err(err).
Str("type", recordTypeURL).
Str("id", recordID).
Msg("authorize: error retrieving record")
return nil, err
}
select {
case <-ctx.Done():
log.Warn(ctx).
Str("type", recordTypeURL).
Str("id", recordID).
Msg("authorize: first sync of record did not complete")
return nil, ctx.Err()
case <-time.After(bo.NextBackOff()):
}
}
}


@@ -1,150 +0,0 @@
package authorize
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpcutil"
"github.com/pomerium/pomerium/pkg/protoutil"
)
func TestAuthorize_forceSyncToVersion(t *testing.T) {
o := &config.Options{
AuthenticateURLString: "https://authN.example.com",
DataBrokerURLString: "https://databroker.example.com",
SharedKey: "gXK6ggrlIW2HyKyUF9rUO4azrDgxhDPWqw9y+lJU7B8=",
Policies: testPolicies(t),
}
a, err := New(&config.Config{Options: o})
require.NoError(t, err)
a.store.UpdateRecord(1, &databroker.Record{
Version: 1,
})
t.Run("ready", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
assert.True(t, a.forceSyncToVersion(ctx, 1, 1))
})
t.Run("not ready", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
assert.False(t, a.forceSyncToVersion(ctx, 1, 2))
})
t.Run("becomes ready", func(t *testing.T) {
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
defer clearTimeout()
go func() {
<-time.After(time.Millisecond * 100)
a.store.UpdateRecord(1, &databroker.Record{
Version: 2,
})
}()
assert.True(t, a.forceSyncToVersion(ctx, 1, 2))
})
}
func TestAuthorize_waitForRecordSync(t *testing.T) {
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*30)
defer clearTimeout()
o := &config.Options{
AuthenticateURLString: "https://authN.example.com",
DataBrokerURLString: "https://databroker.example.com",
SharedKey: "gXK6ggrlIW2HyKyUF9rUO4azrDgxhDPWqw9y+lJU7B8=",
Policies: testPolicies(t),
}
t.Run("skip if exists", func(t *testing.T) {
a, err := New(&config.Config{Options: o})
require.NoError(t, err)
a.store.UpdateRecord(0, newRecord(&session.Session{
Id: "SESSION_ID",
}))
a.state.Load().dataBrokerClient = mockDataBrokerServiceClient{
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
panic("should never be called")
},
}
a.waitForRecordSync(ctx, grpcutil.GetTypeURL(new(session.Session)), "SESSION_ID")
})
t.Run("skip if not found", func(t *testing.T) {
a, err := New(&config.Config{Options: o})
require.NoError(t, err)
callCount := 0
a.state.Load().dataBrokerClient = mockDataBrokerServiceClient{
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
callCount++
return nil, status.Error(codes.NotFound, "not found")
},
}
a.waitForRecordSync(ctx, grpcutil.GetTypeURL(new(session.Session)), "SESSION_ID")
assert.Equal(t, 1, callCount, "should be called once")
})
t.Run("poll", func(t *testing.T) {
a, err := New(&config.Config{Options: o})
require.NoError(t, err)
callCount := 0
a.state.Load().dataBrokerClient = mockDataBrokerServiceClient{
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
callCount++
switch callCount {
case 1:
s := &session.Session{Id: "SESSION_ID"}
a.store.UpdateRecord(0, newRecord(s))
return &databroker.GetResponse{Record: newRecord(s)}, nil
default:
panic("should never be called")
}
},
}
a.waitForRecordSync(ctx, grpcutil.GetTypeURL(new(session.Session)), "SESSION_ID")
})
t.Run("timeout", func(t *testing.T) {
a, err := New(&config.Config{Options: o})
require.NoError(t, err)
tctx, clearTimeout := context.WithTimeout(ctx, time.Millisecond*100)
defer clearTimeout()
callCount := 0
a.state.Load().dataBrokerClient = mockDataBrokerServiceClient{
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
callCount++
s := &session.Session{Id: "SESSION_ID"}
return &databroker.GetResponse{Record: newRecord(s)}, nil
},
}
a.waitForRecordSync(tctx, grpcutil.GetTypeURL(new(session.Session)), "SESSION_ID")
assert.Greater(t, callCount, 5) // should be ~ 20, but allow for non-determinism
})
}
type storableMessage interface {
proto.Message
GetId() string
}
func newRecord(msg storableMessage) *databroker.Record {
any := protoutil.NewAny(msg)
return &databroker.Record{
Version: 1,
Type: any.GetTypeUrl(),
Id: msg.GetId(),
Data: any,
}
}

go.mod

@@ -6,7 +6,9 @@ require (
 	contrib.go.opencensus.io/exporter/jaeger v0.2.1
 	contrib.go.opencensus.io/exporter/prometheus v0.4.1
 	contrib.go.opencensus.io/exporter/zipkin v0.1.2
+	github.com/CAFxX/httpcompression v0.0.8
 	github.com/DataDog/opencensus-go-exporter-datadog v0.0.0-20200406135749-5c268882acf0
+	github.com/VictoriaMetrics/fastcache v1.10.0
 	github.com/caddyserver/certmagic v0.16.0
 	github.com/cenkalti/backoff/v4 v4.1.3
 	github.com/cespare/xxhash/v2 v2.1.2
@@ -29,9 +31,11 @@ require (
 	github.com/gorilla/handlers v1.5.1
 	github.com/gorilla/mux v1.8.0
 	github.com/gorilla/websocket v1.5.0
-	github.com/hashicorp/errwrap v1.1.0 // indirect
 	github.com/hashicorp/go-multierror v1.1.1
 	github.com/hashicorp/golang-lru v0.5.4
+	github.com/jackc/pgconn v1.12.1
+	github.com/jackc/pgtype v1.11.0
+	github.com/jackc/pgx/v4 v4.16.1
 	github.com/martinlindhe/base36 v1.1.0
 	github.com/mholt/acmez v1.0.2
 	github.com/mitchellh/hashstructure/v2 v2.0.2
@@ -75,14 +79,6 @@ require (
 	sigs.k8s.io/yaml v1.3.0
 )
 
-require (
-	github.com/CAFxX/httpcompression v0.0.8
-	github.com/jackc/pgconn v1.12.1
-	github.com/jackc/pgtype v1.11.0
-	github.com/jackc/pgx/v4 v4.16.1
-	github.com/kentik/patricia v1.0.0
-)
-
 require (
 	4d63.com/gochecknoglobals v0.1.0 // indirect
 	cloud.google.com/go/compute v1.6.1 // indirect
@@ -148,6 +144,7 @@ require (
 	github.com/gofrs/flock v0.8.1 // indirect
 	github.com/gogo/protobuf v1.3.2 // indirect
 	github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
+	github.com/golang/snappy v0.0.4 // indirect
 	github.com/golangci/check v0.0.0-20180506172741-cfe4005ccda2 // indirect
 	github.com/golangci/dupl v0.0.0-20180902072040-3e9179ac440a // indirect
 	github.com/golangci/go-misc v0.0.0-20220329215616-d24fe342adfe // indirect
@@ -166,6 +163,7 @@ require (
 	github.com/gostaticanalysis/comment v1.4.2 // indirect
 	github.com/gostaticanalysis/forcetypeassert v0.1.0 // indirect
 	github.com/gostaticanalysis/nilerr v0.1.1 // indirect
+	github.com/hashicorp/errwrap v1.1.0 // indirect
 	github.com/hashicorp/go-version v1.4.0 // indirect
 	github.com/hashicorp/hcl v1.0.0 // indirect
 	github.com/hexops/gotextdiff v1.0.3 // indirect

go.sum

@@ -165,6 +165,8 @@ github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWX
 github.com/Shopify/sarama v1.30.0/go.mod h1:zujlQQx1kzHsh4jfV1USnptCQrHAEZ2Hk8fTKCulPVs=
 github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI=
 github.com/Shopify/toxiproxy/v2 v2.1.6-0.20210914104332-15ea381dcdae/go.mod h1:/cvHQkZ1fst0EmZnA5dFtiQdWCNCFYzb+uE2vqVgvx0=
+github.com/VictoriaMetrics/fastcache v1.10.0 h1:5hDJnLsKLpnUEToub7ETuRu8RCkb40woBZAUiKonXzY=
+github.com/VictoriaMetrics/fastcache v1.10.0/go.mod h1:tjiYeEfYXCqacuvYw/7UoDIeJaNxq6132xHICNP77w8=
 github.com/agnivade/levenshtein v1.0.1/go.mod h1:CURSv5d9Uaml+FovSIICkLbAUZ9S4RqaHDIsdSBg7lM=
 github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
 github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
@@ -175,6 +177,8 @@ github.com/alexflint/go-filemutex v0.0.0-20171022225611-72bdc8eae2ae/go.mod h1:C
 github.com/alexflint/go-filemutex v1.1.0/go.mod h1:7P4iRhttt/nUvUOrYIhcpMzv2G6CY9UnI16Z+UJqRyk=
 github.com/alexkohler/prealloc v1.0.0 h1:Hbq0/3fJPQhNkN0dR95AVrr6R7tou91y0uHG5pOcUuw=
 github.com/alexkohler/prealloc v1.0.0/go.mod h1:VetnK3dIgFBBKmg0YnD9F9x6Icjd+9cvfHR56wJVlKE=
+github.com/allegro/bigcache v1.2.1-0.20190218064605-e24eb225f156 h1:eMwmnE/GDgah4HI848JfFxHt+iPb26b4zyfspmqY0/8=
+github.com/allegro/bigcache v1.2.1-0.20190218064605-e24eb225f156/go.mod h1:Cb/ax3seSYIx7SuZdm2G2xzfwmv3TPSk2ucNfQESPXM=
 github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883 h1:bvNMNQO63//z+xNgfBlViaCIJKLlCJ6/fmUseuG0wVQ=
 github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883/go.mod h1:rCTlJbsFo29Kk6CurOXKm700vrz8f0KW0JNfpkRJY/8=
 github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
@@ -995,8 +999,6 @@ github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8
 github.com/julz/importas v0.1.0 h1:F78HnrsjY3cR7j0etXy5+TU1Zuy7Xt08X/1aJnH5xXY=
 github.com/julz/importas v0.1.0/go.mod h1:oSFU2R4XK/P7kNBrnL/FEQlDGN1/6WoxXEjSSXO0DV0=
 github.com/k0kubun/colorstring v0.0.0-20150214042306-9440f1994b88/go.mod h1:3w7q1U84EfirKl04SVQ/s7nPm1ZPhiXd34z40TNz36k=
-github.com/kentik/patricia v1.0.0 h1:jx/8kXf0JvQEHNPX4njL+PDzpxxqNKg0RjA8hJcX38A=
-github.com/kentik/patricia v1.0.0/go.mod h1:e0nkPLU9NQl8v05ukfHU6+R5ykbKcXO+NqaC3ifTm0Y=
 github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q=
 github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00=
 github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
@@ -2054,6 +2056,7 @@ golang.org/x/sys v0.0.0-20220128215802-99c3d69c2c27/go.mod h1:oPkhp1MJrh7nUepCBc
 golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20220328115105-d36c6a25d886/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20220422013727-9388b58f7150/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20220502124256-b6088ccd6cba/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=


@@ -26,10 +26,8 @@ type Record struct {
     sizeCache protoimpl.SizeCache
     unknownFields protoimpl.UnknownFields
 
     Request *v3.CheckRequest `protobuf:"bytes,1,opt,name=request,proto3" json:"request,omitempty"`
     Response *v3.CheckResponse `protobuf:"bytes,2,opt,name=response,proto3" json:"response,omitempty"`
-    DatabrokerServerVersion uint64 `protobuf:"varint,3,opt,name=databroker_server_version,json=databrokerServerVersion,proto3" json:"databroker_server_version,omitempty"`
-    DatabrokerRecordVersion uint64 `protobuf:"varint,4,opt,name=databroker_record_version,json=databrokerRecordVersion,proto3" json:"databroker_record_version,omitempty"`
 }
 
 func (x *Record) Reset() {
@@ -78,20 +76,6 @@ func (x *Record) GetResponse() *v3.CheckResponse {
     return nil
 }
 
-func (x *Record) GetDatabrokerServerVersion() uint64 {
-    if x != nil {
-        return x.DatabrokerServerVersion
-    }
-    return 0
-}
-
-func (x *Record) GetDatabrokerRecordVersion() uint64 {
-    if x != nil {
-        return x.DatabrokerRecordVersion
-    }
-    return 0
-}
-
 var File_audit_proto protoreflect.FileDescriptor
 
 var file_audit_proto_rawDesc = []byte{
@ -99,7 +83,7 @@ var file_audit_proto_rawDesc = []byte{
0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2e, 0x61, 0x75, 0x64, 0x69, 0x74, 0x1a, 0x29, 0x65, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2e, 0x61, 0x75, 0x64, 0x69, 0x74, 0x1a, 0x29, 0x65,
0x6e, 0x76, 0x6f, 0x79, 0x2f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2f, 0x61, 0x75, 0x74, 0x6e, 0x76, 0x6f, 0x79, 0x2f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2f, 0x61, 0x75, 0x74,
0x68, 0x2f, 0x76, 0x33, 0x2f, 0x65, 0x78, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x5f, 0x61, 0x75, 0x68, 0x2f, 0x76, 0x33, 0x2f, 0x65, 0x78, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x5f, 0x61, 0x75,
0x74, 0x68, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x81, 0x02, 0x0a, 0x06, 0x52, 0x65, 0x63, 0x74, 0x68, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x89, 0x01, 0x0a, 0x06, 0x52, 0x65, 0x63,
0x6f, 0x72, 0x64, 0x12, 0x3d, 0x0a, 0x07, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x18, 0x01, 0x6f, 0x72, 0x64, 0x12, 0x3d, 0x0a, 0x07, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x18, 0x01,
0x20, 0x01, 0x28, 0x0b, 0x32, 0x23, 0x2e, 0x65, 0x6e, 0x76, 0x6f, 0x79, 0x2e, 0x73, 0x65, 0x72, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x23, 0x2e, 0x65, 0x6e, 0x76, 0x6f, 0x79, 0x2e, 0x73, 0x65, 0x72,
0x76, 0x69, 0x63, 0x65, 0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x76, 0x33, 0x2e, 0x43, 0x68, 0x65, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x76, 0x33, 0x2e, 0x43, 0x68, 0x65,
@ -108,18 +92,10 @@ var file_audit_proto_rawDesc = []byte{
0x20, 0x01, 0x28, 0x0b, 0x32, 0x24, 0x2e, 0x65, 0x6e, 0x76, 0x6f, 0x79, 0x2e, 0x73, 0x65, 0x72, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x24, 0x2e, 0x65, 0x6e, 0x76, 0x6f, 0x79, 0x2e, 0x73, 0x65, 0x72,
0x76, 0x69, 0x63, 0x65, 0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x76, 0x33, 0x2e, 0x43, 0x68, 0x65, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x76, 0x33, 0x2e, 0x43, 0x68, 0x65,
0x63, 0x6b, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x52, 0x08, 0x72, 0x65, 0x73, 0x70, 0x63, 0x6b, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x52, 0x08, 0x72, 0x65, 0x73, 0x70,
0x6f, 0x6e, 0x73, 0x65, 0x12, 0x3a, 0x0a, 0x19, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x6f, 0x6b, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x2d, 0x5a, 0x2b, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63,
0x65, 0x72, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x5f, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65,
0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x04, 0x52, 0x17, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x6f, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x61, 0x75,
0x6b, 0x65, 0x72, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x64, 0x69, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
0x12, 0x3a, 0x0a, 0x19, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x6f, 0x6b, 0x65, 0x72, 0x5f, 0x72,
0x65, 0x63, 0x6f, 0x72, 0x64, 0x5f, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x04, 0x20,
0x01, 0x28, 0x04, 0x52, 0x17, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x6f, 0x6b, 0x65, 0x72, 0x52,
0x65, 0x63, 0x6f, 0x72, 0x64, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x42, 0x2d, 0x5a, 0x2b,
0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72,
0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67,
0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x61, 0x75, 0x64, 0x69, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x33,
} }
var ( var (

View file

@ -8,6 +8,4 @@ import "envoy/service/auth/v3/external_auth.proto";
message Record { message Record {
envoy.service.auth.v3.CheckRequest request = 1; envoy.service.auth.v3.CheckRequest request = 1;
envoy.service.auth.v3.CheckResponse response = 2; envoy.service.auth.v3.CheckResponse response = 2;
uint64 databroker_server_version = 3;
uint64 databroker_record_version = 4;
} }

View file

@ -7,6 +7,7 @@ import (
"io" "io"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
structpb "google.golang.org/protobuf/types/known/structpb"
"github.com/pomerium/pomerium/pkg/grpcutil" "github.com/pomerium/pomerium/pkg/grpcutil"
"github.com/pomerium/pomerium/pkg/protoutil" "github.com/pomerium/pomerium/pkg/protoutil"
@ -118,6 +119,27 @@ func (x *PutResponse) GetRecord() *Record {
return records[0] return records[0]
} }
// SetFilterByID sets the filter so that only the record with the given id matches.
func (x *QueryRequest) SetFilterByID(id string) {
x.Filter = &structpb.Struct{Fields: map[string]*structpb.Value{
"id": structpb.NewStringValue(id),
}}
}
// SetFilterByIDOrIndex sets the filter so that a record matches either by id or by its $index field.
func (x *QueryRequest) SetFilterByIDOrIndex(idOrIndex string) {
x.Filter = &structpb.Struct{Fields: map[string]*structpb.Value{
"$or": structpb.NewListValue(&structpb.ListValue{Values: []*structpb.Value{
structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
"id": structpb.NewStringValue(idOrIndex),
}}),
structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
"$index": structpb.NewStringValue(idOrIndex),
}}),
}}),
}}
}
// default is 4MB, but we'll do 1MB // default is 4MB, but we'll do 1MB
const maxMessageSize = 1024 * 1024 * 1 const maxMessageSize = 1024 * 1024 * 1
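
For reference, a minimal usage sketch (not part of this commit) of how the new filter helpers might be used to fetch a single record by id or index instead of syncing the whole record type. The record type URL and the surrounding function are assumptions for illustration; only SetFilterByIDOrIndex and the Query RPC come from the diff above.

package example // editorial usage sketch, not part of this commit

import (
	"context"
	"fmt"

	"github.com/pomerium/pomerium/pkg/grpc/databroker"
)

// querySessionByID fetches one record via Query rather than a full sync.
func querySessionByID(ctx context.Context, client databroker.DataBrokerServiceClient, sessionID string) (*databroker.Record, error) {
	req := &databroker.QueryRequest{
		Type:  "type.googleapis.com/session.Session", // assumed record type URL
		Limit: 1,
	}
	// match either the record id or a secondary $index value
	req.SetFilterByIDOrIndex(sessionID)

	res, err := client.Query(ctx, req)
	if err != nil {
		return nil, err
	}
	if len(res.GetRecords()) == 0 {
		return nil, fmt.Errorf("session %s not found", sessionID)
	}
	return res.GetRecords()[0], nil
}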

155
pkg/storage/cache.go Normal file
View file

@ -0,0 +1,155 @@
package storage
import (
"context"
"encoding/binary"
"sync"
"time"
"github.com/VictoriaMetrics/fastcache"
"golang.org/x/sync/singleflight"
)
// A Cache will return cached data when available or call update when not.
type Cache interface {
GetOrUpdate(
ctx context.Context,
key []byte,
update func(ctx context.Context) ([]byte, error),
) ([]byte, error)
Invalidate(key []byte)
}
type localCache struct {
singleflight singleflight.Group
mu sync.RWMutex
m map[string][]byte
}
// NewLocalCache creates a new Cache backed by a map.
func NewLocalCache() Cache {
return &localCache{
m: make(map[string][]byte),
}
}
func (cache *localCache) GetOrUpdate(
ctx context.Context,
key []byte,
update func(ctx context.Context) ([]byte, error),
) ([]byte, error) {
strkey := string(key)
cache.mu.RLock()
cached, ok := cache.m[strkey]
cache.mu.RUnlock()
if ok {
return cached, nil
}
v, err, _ := cache.singleflight.Do(strkey, func() (interface{}, error) {
cache.mu.RLock()
cached, ok := cache.m[strkey]
cache.mu.RUnlock()
if ok {
return cached, nil
}
result, err := update(ctx)
if err != nil {
return nil, err
}
cache.mu.Lock()
cache.m[strkey] = result
cache.mu.Unlock()
return result, nil
})
if err != nil {
return nil, err
}
return v.([]byte), nil
}
func (cache *localCache) Invalidate(key []byte) {
cache.mu.Lock()
delete(cache.m, string(key))
cache.mu.Unlock()
}
type globalCache struct {
ttl time.Duration
singleflight singleflight.Group
mu sync.RWMutex
fastcache *fastcache.Cache
}
// NewGlobalCache creates a new Cache backed by fastcache and a TTL.
func NewGlobalCache(ttl time.Duration) Cache {
return &globalCache{
ttl: ttl,
fastcache: fastcache.New(256 * 1024 * 1024), // up to 256MB of RAM
}
}
func (cache *globalCache) GetOrUpdate(
ctx context.Context,
key []byte,
update func(ctx context.Context) ([]byte, error),
) ([]byte, error) {
now := time.Now()
data, expiry, ok := cache.get(key)
if ok && now.Before(expiry) {
return data, nil
}
v, err, _ := cache.singleflight.Do(string(key), func() (interface{}, error) {
data, expiry, ok := cache.get(key)
if ok && now.Before(expiry) {
return data, nil
}
value, err := update(ctx)
if err != nil {
return nil, err
}
cache.set(key, value)
return value, nil
})
if err != nil {
return nil, err
}
return v.([]byte), nil
}
func (cache *globalCache) Invalidate(key []byte) {
cache.mu.Lock()
cache.fastcache.Del(key)
cache.mu.Unlock()
}
func (cache *globalCache) get(k []byte) (data []byte, expiry time.Time, ok bool) {
cache.mu.RLock()
item := cache.fastcache.Get(nil, k)
cache.mu.RUnlock()
if len(item) < 8 {
return data, expiry, false
}
unix, data := binary.LittleEndian.Uint64(item), item[8:]
expiry = time.UnixMilli(int64(unix))
return data, expiry, true
}
func (cache *globalCache) set(k, v []byte) {
unix := time.Now().Add(cache.ttl).UnixMilli()
item := make([]byte, len(v)+8)
binary.LittleEndian.PutUint64(item, uint64(unix))
copy(item[8:], v)
cache.mu.Lock()
cache.fastcache.Set(k, item)
cache.mu.Unlock()
}
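
A usage sketch for the cache above, assuming only the exported storage API shown in this file; the key and the update callback are illustrative.

package example // editorial usage sketch, not part of this commit

import (
	"context"
	"fmt"
	"time"

	"github.com/pomerium/pomerium/pkg/storage"
)

func cachedLookup(ctx context.Context) ([]byte, error) {
	// entries expire roughly one minute after being written
	cache := storage.NewGlobalCache(time.Minute)

	key := []byte("example-key")
	data, err := cache.GetOrUpdate(ctx, key, func(ctx context.Context) ([]byte, error) {
		// concurrent callers for the same key share one update via singleflight;
		// the result is reused until the TTL elapses or Invalidate is called
		fmt.Println("computing value")
		return []byte("expensive result"), nil
	})
	if err != nil {
		return nil, err
	}

	// force the next GetOrUpdate for this key to recompute
	cache.Invalidate(key)
	return data, nil
}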

73
pkg/storage/cache_test.go Normal file
View file

@ -0,0 +1,73 @@
package storage
import (
"context"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestLocalCache(t *testing.T) {
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
defer clearTimeout()
callCount := 0
update := func(ctx context.Context) ([]byte, error) {
callCount++
return []byte("v1"), nil
}
c := NewLocalCache()
v, err := c.GetOrUpdate(ctx, []byte("k1"), update)
assert.NoError(t, err)
assert.Equal(t, []byte("v1"), v)
assert.Equal(t, 1, callCount)
v, err = c.GetOrUpdate(ctx, []byte("k1"), update)
assert.NoError(t, err)
assert.Equal(t, []byte("v1"), v)
assert.Equal(t, 1, callCount)
c.Invalidate([]byte("k1"))
v, err = c.GetOrUpdate(ctx, []byte("k1"), update)
assert.NoError(t, err)
assert.Equal(t, []byte("v1"), v)
assert.Equal(t, 2, callCount)
}
func TestGlobalCache(t *testing.T) {
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
defer clearTimeout()
callCount := 0
update := func(ctx context.Context) ([]byte, error) {
callCount++
return []byte("v1"), nil
}
c := NewGlobalCache(time.Millisecond * 100)
v, err := c.GetOrUpdate(ctx, []byte("k1"), update)
assert.NoError(t, err)
assert.Equal(t, []byte("v1"), v)
assert.Equal(t, 1, callCount)
v, err = c.GetOrUpdate(ctx, []byte("k1"), update)
assert.NoError(t, err)
assert.Equal(t, []byte("v1"), v)
assert.Equal(t, 1, callCount)
c.Invalidate([]byte("k1"))
v, err = c.GetOrUpdate(ctx, []byte("k1"), update)
assert.NoError(t, err)
assert.Equal(t, []byte("v1"), v)
assert.Equal(t, 2, callCount)
assert.Eventually(t, func() bool {
_, err := c.GetOrUpdate(ctx, []byte("k1"), func(ctx context.Context) ([]byte, error) {
return nil, fmt.Errorf("ERROR")
})
return err != nil
}, time.Second, time.Millisecond*10, "should honor TTL")
}

210
pkg/storage/querier.go Normal file
View file

@ -0,0 +1,210 @@
package storage
import (
"context"
"sync"
"github.com/google/uuid"
grpc "google.golang.org/grpc"
"google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/structpb"
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/protoutil"
)
// A Querier is a read-only subset of the databroker client methods.
type Querier interface {
Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error)
}
// nilQuerier always returns NotFound.
type nilQuerier struct{}
func (nilQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
return nil, status.Error(codes.NotFound, "not found")
}
var querierKey struct{}
// GetQuerier gets the databroker Querier from the context.
func GetQuerier(ctx context.Context) Querier {
q, ok := ctx.Value(querierKey).(Querier)
if !ok {
q = nilQuerier{}
}
return q
}
// WithQuerier sets the databroker Querier on a context.
func WithQuerier(ctx context.Context, querier Querier) context.Context {
return context.WithValue(ctx, querierKey, querier)
}
type staticQuerier struct {
records []*databroker.Record
}
// NewStaticQuerier creates a Querier that returns statically defined protobuf records.
func NewStaticQuerier(msgs ...proto.Message) Querier {
getter := &staticQuerier{}
for _, msg := range msgs {
any := protoutil.NewAny(msg)
record := new(databroker.Record)
record.ModifiedAt = timestamppb.Now()
record.Version = cryptutil.NewRandomUInt64()
record.Id = uuid.New().String()
record.Data = any
record.Type = any.TypeUrl
if hasID, ok := msg.(interface{ GetId() string }); ok {
record.Id = hasID.GetId()
}
getter.records = append(getter.records, record)
}
return getter
}
// Query queries for records.
func (q *staticQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
expr, err := FilterExpressionFromStruct(in.GetFilter())
if err != nil {
return nil, err
}
filter, err := RecordStreamFilterFromFilterExpression(expr)
if err != nil {
return nil, err
}
res := new(databroker.QueryResponse)
for _, record := range q.records {
if record.GetType() != in.GetType() {
continue
}
if !filter(record) {
continue
}
if in.GetQuery() != "" && !MatchAny(record.GetData(), in.GetQuery()) {
continue
}
res.Records = append(res.Records, record)
}
var total int
res.Records, total = databroker.ApplyOffsetAndLimit(
res.Records,
int(in.GetOffset()),
int(in.GetLimit()),
)
res.TotalCount = int64(total)
return res, nil
}
type clientQuerier struct {
client databroker.DataBrokerServiceClient
}
// NewQuerier creates a new Querier that queries the databroker over gRPC.
func NewQuerier(client databroker.DataBrokerServiceClient) Querier {
return &clientQuerier{client: client}
}
// Query queries for records.
func (q *clientQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
return q.client.Query(ctx, in, opts...)
}
// A TracingQuerier records calls to Query.
type TracingQuerier struct {
underlying Querier
mu sync.Mutex
traces []QueryTrace
}
// A QueryTrace traces a call to Query.
type QueryTrace struct {
ServerVersion, RecordVersion uint64
RecordType string
Query string
Filter *structpb.Struct
}
// NewTracingQuerier creates a new TracingQuerier.
func NewTracingQuerier(q Querier) *TracingQuerier {
return &TracingQuerier{
underlying: q,
}
}
// Query queries for records.
func (q *TracingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
res, err := q.underlying.Query(ctx, in, opts...)
if err == nil {
q.mu.Lock()
q.traces = append(q.traces, QueryTrace{
RecordType: in.GetType(),
Query: in.GetQuery(),
Filter: in.GetFilter(),
})
q.mu.Unlock()
}
return res, err
}
// Traces returns all the traces.
func (q *TracingQuerier) Traces() []QueryTrace {
q.mu.Lock()
traces := make([]QueryTrace, len(q.traces))
copy(traces, q.traces)
q.mu.Unlock()
return traces
}
type cachingQuerier struct {
q Querier
cache Cache
}
// NewCachingQuerier creates a new querier that caches results in a Cache.
func NewCachingQuerier(q Querier, cache Cache) Querier {
return &cachingQuerier{
q: q,
cache: cache,
}
}
func (q *cachingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
key, err := (&proto.MarshalOptions{
Deterministic: true,
}).Marshal(in)
if err != nil {
return nil, err
}
rawResult, err := q.cache.GetOrUpdate(ctx, key, func(ctx context.Context) ([]byte, error) {
res, err := q.q.Query(ctx, in, opts...)
if err != nil {
return nil, err
}
return proto.Marshal(res)
})
if err != nil {
return nil, err
}
var res databroker.QueryResponse
err = proto.Unmarshal(rawResult, &res)
if err != nil {
return nil, err
}
return &res, nil
}
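
A wiring sketch, using only the constructors defined above, of how a caching, gRPC-backed Querier might be attached to a request context so that downstream code can retrieve it with storage.GetQuerier(ctx); the function name is an assumption.

package example // editorial wiring sketch, not part of this commit

import (
	"context"
	"time"

	"github.com/pomerium/pomerium/pkg/grpc/databroker"
	"github.com/pomerium/pomerium/pkg/storage"
)

// withCachedQuerier composes a gRPC-backed Querier with a TTL cache and stores
// it on the context for later retrieval via storage.GetQuerier.
func withCachedQuerier(ctx context.Context, client databroker.DataBrokerServiceClient) context.Context {
	q := storage.NewCachingQuerier(
		storage.NewQuerier(client),
		storage.NewGlobalCache(time.Minute),
	)
	return storage.WithQuerier(ctx, q)
}

If no Querier has been attached to the context, storage.GetQuerier falls back to the nil querier above, which fails every Query with a NotFound status.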