mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-24 22:47:14 +02:00
authorize: use query instead of sync for databroker data (#3377)
This commit is contained in:
parent
fd82cc7870
commit
f61e7efe73
24 changed files with 661 additions and 1008 deletions
|
@ -8,6 +8,8 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/authorize/evaluator"
|
"github.com/pomerium/pomerium/authorize/evaluator"
|
||||||
"github.com/pomerium/pomerium/authorize/internal/store"
|
"github.com/pomerium/pomerium/authorize/internal/store"
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
|
@ -17,6 +19,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc"
|
"github.com/pomerium/pomerium/pkg/grpc"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Authorize struct holds
|
// Authorize struct holds
|
||||||
|
@ -25,8 +28,7 @@ type Authorize struct {
|
||||||
store *store.Store
|
store *store.Store
|
||||||
currentOptions *config.AtomicOptions
|
currentOptions *config.AtomicOptions
|
||||||
accessTracker *AccessTracker
|
accessTracker *AccessTracker
|
||||||
|
globalCache storage.Cache
|
||||||
dataBrokerInitialSync chan struct{}
|
|
||||||
|
|
||||||
// The stateLock prevents updating the evaluator store simultaneously with an evaluation.
|
// The stateLock prevents updating the evaluator store simultaneously with an evaluation.
|
||||||
// This should provide a consistent view of the data at a given server/record version and
|
// This should provide a consistent view of the data at a given server/record version and
|
||||||
|
@ -39,7 +41,7 @@ func New(cfg *config.Config) (*Authorize, error) {
|
||||||
a := &Authorize{
|
a := &Authorize{
|
||||||
currentOptions: config.NewAtomicOptions(),
|
currentOptions: config.NewAtomicOptions(),
|
||||||
store: store.New(),
|
store: store.New(),
|
||||||
dataBrokerInitialSync: make(chan struct{}),
|
globalCache: storage.NewGlobalCache(time.Minute),
|
||||||
}
|
}
|
||||||
a.accessTracker = NewAccessTracker(a, accessTrackerMaxSize, accessTrackerDebouncePeriod)
|
a.accessTracker = NewAccessTracker(a, accessTrackerMaxSize, accessTrackerDebouncePeriod)
|
||||||
|
|
||||||
|
@ -59,19 +61,16 @@ func (a *Authorize) GetDataBrokerServiceClient() databroker.DataBrokerServiceCli
|
||||||
|
|
||||||
// Run runs the authorize service.
|
// Run runs the authorize service.
|
||||||
func (a *Authorize) Run(ctx context.Context) error {
|
func (a *Authorize) Run(ctx context.Context) error {
|
||||||
go a.accessTracker.Run(ctx)
|
eg, ctx := errgroup.WithContext(ctx)
|
||||||
_ = grpc.WaitForReady(ctx, a.state.Load().dataBrokerClientConnection, time.Second*10)
|
eg.Go(func() error {
|
||||||
return newDataBrokerSyncer(a).Run(ctx)
|
a.accessTracker.Run(ctx)
|
||||||
}
|
|
||||||
|
|
||||||
// WaitForInitialSync blocks until the initial sync is complete.
|
|
||||||
func (a *Authorize) WaitForInitialSync(ctx context.Context) error {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return ctx.Err()
|
|
||||||
case <-a.dataBrokerInitialSync:
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
|
})
|
||||||
|
eg.Go(func() error {
|
||||||
|
_ = grpc.WaitForReady(ctx, a.state.Load().dataBrokerClientConnection, time.Second*10)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
return eg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateOptions(o *config.Options) error {
|
func validateOptions(o *config.Options) error {
|
||||||
|
|
|
@ -20,8 +20,6 @@ import (
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/encoding/jws"
|
"github.com/pomerium/pomerium/internal/encoding/jws"
|
||||||
"github.com/pomerium/pomerium/internal/testutil"
|
"github.com/pomerium/pomerium/internal/testutil"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestAuthorize_okResponse(t *testing.T) {
|
func TestAuthorize_okResponse(t *testing.T) {
|
||||||
|
@ -40,17 +38,7 @@ func TestAuthorize_okResponse(t *testing.T) {
|
||||||
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0})
|
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0})
|
||||||
a.state.Load().encoder = encoder
|
a.state.Load().encoder = encoder
|
||||||
a.currentOptions.Store(opt)
|
a.currentOptions.Store(opt)
|
||||||
a.store = store.NewFromProtos(0,
|
a.store = store.New()
|
||||||
&session.Session{
|
|
||||||
Id: "SESSION_ID",
|
|
||||||
UserId: "USER_ID",
|
|
||||||
},
|
|
||||||
&user.User{
|
|
||||||
Id: "USER_ID",
|
|
||||||
Name: "foo",
|
|
||||||
Email: "foo@example.com",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
pe, err := newPolicyEvaluator(opt, a.store)
|
pe, err := newPolicyEvaluator(opt, a.store)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
a.state.Load().evaluator = pe
|
a.state.Load().evaluator = pe
|
||||||
|
|
48
authorize/databroker.go
Normal file
48
authorize/databroker.go
Normal file
|
@ -0,0 +1,48 @@
|
||||||
|
package authorize
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/open-policy-agent/opa/storage"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||||
|
)
|
||||||
|
|
||||||
|
type sessionOrServiceAccount interface {
|
||||||
|
GetUserId() string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Authorize) getDataBrokerSessionOrServiceAccount(ctx context.Context, sessionID string) (s sessionOrServiceAccount, err error) {
|
||||||
|
ctx, span := trace.StartSpan(ctx, "authorize.getDataBrokerSessionOrServiceAccount")
|
||||||
|
defer span.End()
|
||||||
|
|
||||||
|
client := a.state.Load().dataBrokerClient
|
||||||
|
|
||||||
|
s, err = session.Get(ctx, client, sessionID)
|
||||||
|
if storage.IsNotFound(err) {
|
||||||
|
s, err = user.GetServiceAccount(ctx, client, sessionID)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := s.(*session.Session); ok {
|
||||||
|
a.accessTracker.TrackSessionAccess(sessionID)
|
||||||
|
}
|
||||||
|
if _, ok := s.(*user.ServiceAccount); ok {
|
||||||
|
a.accessTracker.TrackServiceAccountAccess(sessionID)
|
||||||
|
}
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Authorize) getDataBrokerUser(ctx context.Context, userID string) (u *user.User, err error) {
|
||||||
|
ctx, span := trace.StartSpan(ctx, "authorize.getDataBrokerUser")
|
||||||
|
defer span.End()
|
||||||
|
|
||||||
|
client := a.state.Load().dataBrokerClient
|
||||||
|
|
||||||
|
u, err = user.Get(ctx, client, userID)
|
||||||
|
return u, err
|
||||||
|
}
|
|
@ -72,8 +72,6 @@ type Result struct {
|
||||||
Allow RuleResult
|
Allow RuleResult
|
||||||
Deny RuleResult
|
Deny RuleResult
|
||||||
Headers http.Header
|
Headers http.Header
|
||||||
|
|
||||||
DataBrokerServerVersion, DataBrokerRecordVersion uint64
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// An Evaluator evaluates policies.
|
// An Evaluator evaluates policies.
|
||||||
|
@ -170,7 +168,6 @@ func (e *Evaluator) Evaluate(ctx context.Context, req *Request) (*Result, error)
|
||||||
Deny: policyOutput.Deny,
|
Deny: policyOutput.Deny,
|
||||||
Headers: headersOutput.Headers,
|
Headers: headersOutput.Headers,
|
||||||
}
|
}
|
||||||
res.DataBrokerServerVersion, res.DataBrokerRecordVersion = e.store.GetDataBrokerVersions()
|
|
||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,29 +2,24 @@ package evaluator
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"math"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/authorize/internal/store"
|
"github.com/pomerium/pomerium/authorize/internal/store"
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||||
"github.com/pomerium/pomerium/pkg/policy/criteria"
|
"github.com/pomerium/pomerium/pkg/policy/criteria"
|
||||||
"github.com/pomerium/pomerium/pkg/policy/parser"
|
"github.com/pomerium/pomerium/pkg/policy/parser"
|
||||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestEvaluator(t *testing.T) {
|
func TestEvaluator(t *testing.T) {
|
||||||
|
@ -36,13 +31,15 @@ func TestEvaluator(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
eval := func(t *testing.T, options []Option, data []proto.Message, req *Request) (*Result, error) {
|
eval := func(t *testing.T, options []Option, data []proto.Message, req *Request) (*Result, error) {
|
||||||
store := store.NewFromProtos(math.MaxUint64, data...)
|
ctx := context.Background()
|
||||||
|
ctx = storage.WithQuerier(ctx, storage.NewStaticQuerier(data...))
|
||||||
|
store := store.New()
|
||||||
store.UpdateIssuer("authenticate.example.com")
|
store.UpdateIssuer("authenticate.example.com")
|
||||||
store.UpdateJWTClaimHeaders(config.NewJWTClaimHeaders("email", "groups", "user", "CUSTOM_KEY"))
|
store.UpdateJWTClaimHeaders(config.NewJWTClaimHeaders("email", "groups", "user", "CUSTOM_KEY"))
|
||||||
store.UpdateSigningKey(privateJWK)
|
store.UpdateSigningKey(privateJWK)
|
||||||
e, err := New(context.Background(), store, options...)
|
e, err := New(ctx, store, options...)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return e.Evaluate(context.Background(), req)
|
return e.Evaluate(ctx, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
policies := []config.Policy{
|
policies := []config.Policy{
|
||||||
|
@ -511,104 +508,3 @@ func mustParseURL(str string) *url.URL {
|
||||||
}
|
}
|
||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkEvaluator_Evaluate(b *testing.B) {
|
|
||||||
store := store.New()
|
|
||||||
|
|
||||||
policies := []config.Policy{
|
|
||||||
{
|
|
||||||
From: "https://from.example.com",
|
|
||||||
To: config.WeightedURLs{
|
|
||||||
{URL: *mustParseURL("https://to.example.com")},
|
|
||||||
},
|
|
||||||
AllowedUsers: []string{"SOME_USER"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
options := []Option{
|
|
||||||
WithAuthenticateURL("https://authn.example.com"),
|
|
||||||
WithPolicies(policies),
|
|
||||||
}
|
|
||||||
|
|
||||||
e, err := New(context.Background(), store, options...)
|
|
||||||
if !assert.NoError(b, err) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
lastSessionID := ""
|
|
||||||
|
|
||||||
for i := 0; i < 100000; i++ {
|
|
||||||
sessionID := uuid.New().String()
|
|
||||||
lastSessionID = sessionID
|
|
||||||
userID := uuid.New().String()
|
|
||||||
data := protoutil.NewAny(&session.Session{
|
|
||||||
Version: fmt.Sprint(i),
|
|
||||||
Id: sessionID,
|
|
||||||
UserId: userID,
|
|
||||||
IdToken: &session.IDToken{
|
|
||||||
Issuer: "benchmark",
|
|
||||||
Subject: userID,
|
|
||||||
IssuedAt: timestamppb.Now(),
|
|
||||||
},
|
|
||||||
OauthToken: &session.OAuthToken{
|
|
||||||
AccessToken: "ACCESS TOKEN",
|
|
||||||
TokenType: "Bearer",
|
|
||||||
RefreshToken: "REFRESH TOKEN",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
store.UpdateRecord(0, &databroker.Record{
|
|
||||||
Version: uint64(i),
|
|
||||||
Type: "type.googleapis.com/session.Session",
|
|
||||||
Id: sessionID,
|
|
||||||
Data: data,
|
|
||||||
})
|
|
||||||
data = protoutil.NewAny(&user.User{
|
|
||||||
Version: fmt.Sprint(i),
|
|
||||||
Id: userID,
|
|
||||||
})
|
|
||||||
store.UpdateRecord(0, &databroker.Record{
|
|
||||||
Version: uint64(i),
|
|
||||||
Type: "type.googleapis.com/user.User",
|
|
||||||
Id: userID,
|
|
||||||
Data: data,
|
|
||||||
})
|
|
||||||
|
|
||||||
data = protoutil.NewAny(&directory.User{
|
|
||||||
Version: fmt.Sprint(i),
|
|
||||||
Id: userID,
|
|
||||||
GroupIds: []string{"1", "2", "3", "4"},
|
|
||||||
})
|
|
||||||
store.UpdateRecord(0, &databroker.Record{
|
|
||||||
Version: uint64(i),
|
|
||||||
Type: data.TypeUrl,
|
|
||||||
Id: userID,
|
|
||||||
Data: data,
|
|
||||||
})
|
|
||||||
|
|
||||||
data = protoutil.NewAny(&directory.Group{
|
|
||||||
Version: fmt.Sprint(i),
|
|
||||||
Id: fmt.Sprint(i),
|
|
||||||
})
|
|
||||||
store.UpdateRecord(0, &databroker.Record{
|
|
||||||
Version: uint64(i),
|
|
||||||
Type: data.TypeUrl,
|
|
||||||
Id: fmt.Sprint(i),
|
|
||||||
Data: data,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
ctx := context.Background()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
_, _ = e.Evaluate(ctx, &Request{
|
|
||||||
Policy: &policies[0],
|
|
||||||
HTTP: RequestHTTP{
|
|
||||||
Method: "GET",
|
|
||||||
URL: "https://example.com/path",
|
|
||||||
Headers: map[string]string{},
|
|
||||||
},
|
|
||||||
Session: RequestSession{
|
|
||||||
ID: lastSessionID,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -18,6 +18,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||||
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewHeadersRequestFromPolicy(t *testing.T) {
|
func TestNewHeadersRequestFromPolicy(t *testing.T) {
|
||||||
|
@ -51,13 +52,15 @@ func TestHeadersEvaluator(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
eval := func(t *testing.T, data []proto.Message, input *HeadersRequest) (*HeadersResponse, error) {
|
eval := func(t *testing.T, data []proto.Message, input *HeadersRequest) (*HeadersResponse, error) {
|
||||||
store := store.NewFromProtos(math.MaxUint64, data...)
|
ctx := context.Background()
|
||||||
|
ctx = storage.WithQuerier(ctx, storage.NewStaticQuerier(data...))
|
||||||
|
store := store.New()
|
||||||
store.UpdateIssuer("authenticate.example.com")
|
store.UpdateIssuer("authenticate.example.com")
|
||||||
store.UpdateJWTClaimHeaders(config.NewJWTClaimHeaders("email", "groups", "user", "CUSTOM_KEY"))
|
store.UpdateJWTClaimHeaders(config.NewJWTClaimHeaders("email", "groups", "user", "CUSTOM_KEY"))
|
||||||
store.UpdateSigningKey(privateJWK)
|
store.UpdateSigningKey(privateJWK)
|
||||||
e, err := NewHeadersEvaluator(context.Background(), store)
|
e, err := NewHeadersEvaluator(ctx, store)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return e.Evaluate(context.Background(), input)
|
return e.Evaluate(ctx, input)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("groups", func(t *testing.T) {
|
t.Run("groups", func(t *testing.T) {
|
||||||
|
|
|
@ -2,7 +2,6 @@ package evaluator
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"math"
|
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
@ -18,6 +17,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||||
"github.com/pomerium/pomerium/pkg/policy"
|
"github.com/pomerium/pomerium/pkg/policy"
|
||||||
"github.com/pomerium/pomerium/pkg/policy/criteria"
|
"github.com/pomerium/pomerium/pkg/policy/criteria"
|
||||||
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestPolicyEvaluator(t *testing.T) {
|
func TestPolicyEvaluator(t *testing.T) {
|
||||||
|
@ -29,13 +29,15 @@ func TestPolicyEvaluator(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
eval := func(t *testing.T, policy *config.Policy, data []proto.Message, input *PolicyRequest) (*PolicyResponse, error) {
|
eval := func(t *testing.T, policy *config.Policy, data []proto.Message, input *PolicyRequest) (*PolicyResponse, error) {
|
||||||
store := store.NewFromProtos(math.MaxUint64, data...)
|
ctx := context.Background()
|
||||||
|
ctx = storage.WithQuerier(ctx, storage.NewStaticQuerier(data...))
|
||||||
|
store := store.New()
|
||||||
store.UpdateIssuer("authenticate.example.com")
|
store.UpdateIssuer("authenticate.example.com")
|
||||||
store.UpdateJWTClaimHeaders(config.NewJWTClaimHeaders("email", "groups", "user", "CUSTOM_KEY"))
|
store.UpdateJWTClaimHeaders(config.NewJWTClaimHeaders("email", "groups", "user", "CUSTOM_KEY"))
|
||||||
store.UpdateSigningKey(privateJWK)
|
store.UpdateSigningKey(privateJWK)
|
||||||
e, err := NewPolicyEvaluator(context.Background(), store, policy)
|
e, err := NewPolicyEvaluator(ctx, store, policy)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return e.Evaluate(context.Background(), input)
|
return e.Evaluate(ctx, input)
|
||||||
}
|
}
|
||||||
|
|
||||||
p1 := &config.Policy{
|
p1 := &config.Policy{
|
||||||
|
|
|
@ -16,6 +16,8 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||||
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Check implements the envoy auth server gRPC endpoint.
|
// Check implements the envoy auth server gRPC endpoint.
|
||||||
|
@ -23,10 +25,16 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
|
||||||
ctx, span := trace.StartSpan(ctx, "authorize.grpc.Check")
|
ctx, span := trace.StartSpan(ctx, "authorize.grpc.Check")
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
// wait for the initial sync to complete so that data is available for evaluation
|
querier := storage.NewTracingQuerier(
|
||||||
if err := a.WaitForInitialSync(ctx); err != nil {
|
storage.NewCachingQuerier(
|
||||||
return nil, err
|
storage.NewCachingQuerier(
|
||||||
}
|
storage.NewQuerier(a.state.Load().dataBrokerClient),
|
||||||
|
a.globalCache,
|
||||||
|
),
|
||||||
|
storage.NewLocalCache(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
ctx = storage.WithQuerier(ctx, querier)
|
||||||
|
|
||||||
state := a.state.Load()
|
state := a.state.Load()
|
||||||
|
|
||||||
|
@ -48,11 +56,22 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
|
||||||
rawJWT, _ := loadRawSession(hreq, a.currentOptions.Load(), state.encoder)
|
rawJWT, _ := loadRawSession(hreq, a.currentOptions.Load(), state.encoder)
|
||||||
sessionState, _ := loadSession(state.encoder, rawJWT)
|
sessionState, _ := loadSession(state.encoder, rawJWT)
|
||||||
|
|
||||||
s, u, err := a.forceSync(ctx, sessionState)
|
var s sessionOrServiceAccount
|
||||||
|
var u *user.User
|
||||||
|
if sessionState != nil {
|
||||||
|
s, err = a.getDataBrokerSessionOrServiceAccount(ctx, sessionState.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn(ctx).Err(err).Msg("clearing session due to force sync failed")
|
log.Warn(ctx).Err(err).Msg("clearing session due to force sync failed")
|
||||||
sessionState = nil
|
sessionState = nil
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
if s != nil {
|
||||||
|
u, err = a.getDataBrokerUser(ctx, s.GetUserId())
|
||||||
|
if err != nil {
|
||||||
|
log.Warn(ctx).Err(err).Msg("clearing session due to force sync failed")
|
||||||
|
sessionState = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
req, err := a.getEvaluatorRequestFromCheckRequest(in, sessionState)
|
req, err := a.getEvaluatorRequestFromCheckRequest(in, sessionState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -337,8 +337,6 @@ func TestAuthorize_Check(t *testing.T) {
|
||||||
}
|
}
|
||||||
a.currentOptions.Store(&config.Options{ForwardAuthURLString: "https://forward-auth.example.com"})
|
a.currentOptions.Store(&config.Options{ForwardAuthURLString: "https://forward-auth.example.com"})
|
||||||
|
|
||||||
close(a.dataBrokerInitialSync)
|
|
||||||
|
|
||||||
cmpOpts := []cmp.Option{
|
cmpOpts := []cmp.Option{
|
||||||
cmpopts.IgnoreUnexported(envoy_service_auth_v3.CheckResponse{}),
|
cmpopts.IgnoreUnexported(envoy_service_auth_v3.CheckResponse{}),
|
||||||
cmpopts.IgnoreUnexported(status.Status{}),
|
cmpopts.IgnoreUnexported(status.Status{}),
|
||||||
|
|
|
@ -1,196 +0,0 @@
|
||||||
package store
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/kentik/patricia"
|
|
||||||
"github.com/kentik/patricia/string_tree"
|
|
||||||
"google.golang.org/protobuf/proto"
|
|
||||||
"google.golang.org/protobuf/types/known/structpb"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
indexField = "$index"
|
|
||||||
cidrField = "cidr"
|
|
||||||
)
|
|
||||||
|
|
||||||
type index struct {
|
|
||||||
mu sync.RWMutex
|
|
||||||
byType map[string]*recordIndex
|
|
||||||
}
|
|
||||||
|
|
||||||
func newIndex() *index {
|
|
||||||
idx := new(index)
|
|
||||||
idx.clear()
|
|
||||||
return idx
|
|
||||||
}
|
|
||||||
|
|
||||||
func (idx *index) clear() {
|
|
||||||
idx.mu.Lock()
|
|
||||||
defer idx.mu.Unlock()
|
|
||||||
idx.byType = map[string]*recordIndex{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (idx *index) delete(typeURL, id string) {
|
|
||||||
idx.mu.Lock()
|
|
||||||
defer idx.mu.Unlock()
|
|
||||||
|
|
||||||
ridx, ok := idx.byType[typeURL]
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ridx.delete(id)
|
|
||||||
|
|
||||||
if len(ridx.byID) == 0 {
|
|
||||||
delete(idx.byType, typeURL)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (idx *index) find(typeURL, id string) proto.Message {
|
|
||||||
idx.mu.RLock()
|
|
||||||
defer idx.mu.RUnlock()
|
|
||||||
|
|
||||||
ridx, ok := idx.byType[typeURL]
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return ridx.find(id)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (idx *index) get(typeURL, id string) proto.Message {
|
|
||||||
idx.mu.RLock()
|
|
||||||
defer idx.mu.RUnlock()
|
|
||||||
|
|
||||||
ridx, ok := idx.byType[typeURL]
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return ridx.get(id)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (idx *index) set(typeURL, id string, msg proto.Message) {
|
|
||||||
idx.mu.Lock()
|
|
||||||
defer idx.mu.Unlock()
|
|
||||||
|
|
||||||
ridx, ok := idx.byType[typeURL]
|
|
||||||
if !ok {
|
|
||||||
ridx = newRecordIndex()
|
|
||||||
idx.byType[typeURL] = ridx
|
|
||||||
}
|
|
||||||
ridx.set(id, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
// a recordIndex indexes records for of a specific type
|
|
||||||
type recordIndex struct {
|
|
||||||
byID map[string]proto.Message
|
|
||||||
byCIDRV4 *string_tree.TreeV4
|
|
||||||
byCIDRV6 *string_tree.TreeV6
|
|
||||||
}
|
|
||||||
|
|
||||||
// newRecordIndex creates a new record index.
|
|
||||||
func newRecordIndex() *recordIndex {
|
|
||||||
return &recordIndex{
|
|
||||||
byID: map[string]proto.Message{},
|
|
||||||
byCIDRV4: string_tree.NewTreeV4(),
|
|
||||||
byCIDRV6: string_tree.NewTreeV6(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (idx *recordIndex) delete(id string) {
|
|
||||||
r, ok := idx.byID[id]
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
delete(idx.byID, id)
|
|
||||||
|
|
||||||
addr4, addr6 := getIndexCIDR(r)
|
|
||||||
if addr4 != nil {
|
|
||||||
idx.byCIDRV4.Delete(*addr4, func(payload, val string) bool {
|
|
||||||
return payload == val
|
|
||||||
}, id)
|
|
||||||
}
|
|
||||||
if addr6 != nil {
|
|
||||||
idx.byCIDRV6.Delete(*addr6, func(payload, val string) bool {
|
|
||||||
return payload == val
|
|
||||||
}, id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (idx *recordIndex) find(idOrString string) proto.Message {
|
|
||||||
r, ok := idx.byID[idOrString]
|
|
||||||
if ok {
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
addrv4, addrv6, _ := patricia.ParseIPFromString(idOrString)
|
|
||||||
if addrv4 != nil {
|
|
||||||
found, id := idx.byCIDRV4.FindDeepestTag(*addrv4)
|
|
||||||
if found {
|
|
||||||
return idx.byID[id]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if addrv6 != nil {
|
|
||||||
found, id := idx.byCIDRV6.FindDeepestTag(*addrv6)
|
|
||||||
if found {
|
|
||||||
return idx.byID[id]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (idx *recordIndex) get(id string) proto.Message {
|
|
||||||
return idx.byID[id]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (idx *recordIndex) set(id string, msg proto.Message) {
|
|
||||||
_, ok := idx.byID[id]
|
|
||||||
if ok {
|
|
||||||
idx.delete(id)
|
|
||||||
}
|
|
||||||
|
|
||||||
idx.byID[id] = msg
|
|
||||||
addr4, addr6 := getIndexCIDR(msg)
|
|
||||||
if addr4 != nil {
|
|
||||||
idx.byCIDRV4.Set(*addr4, id)
|
|
||||||
}
|
|
||||||
if addr6 != nil {
|
|
||||||
idx.byCIDRV6.Set(*addr6, id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getIndexCIDR(msg proto.Message) (*patricia.IPv4Address, *patricia.IPv6Address) {
|
|
||||||
var s *structpb.Struct
|
|
||||||
if sv, ok := msg.(*structpb.Value); ok {
|
|
||||||
s = sv.GetStructValue()
|
|
||||||
} else {
|
|
||||||
s, _ = msg.(*structpb.Struct)
|
|
||||||
}
|
|
||||||
if s == nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
f, ok := s.Fields[indexField]
|
|
||||||
if !ok {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
obj := f.GetStructValue()
|
|
||||||
if obj == nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
cf, ok := obj.Fields[cidrField]
|
|
||||||
if !ok {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
c := cf.GetStringValue()
|
|
||||||
if c == "" {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
addr4, addr6, _ := patricia.ParseIPFromString(c)
|
|
||||||
return addr4, addr6
|
|
||||||
}
|
|
|
@ -1,74 +0,0 @@
|
||||||
package store
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"google.golang.org/protobuf/types/known/structpb"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestByID(t *testing.T) {
|
|
||||||
idx := newIndex()
|
|
||||||
|
|
||||||
r1 := &structpb.Struct{Fields: map[string]*structpb.Value{
|
|
||||||
"id": structpb.NewStringValue("r1"),
|
|
||||||
}}
|
|
||||||
|
|
||||||
idx.set("example.com/record", "r1", r1)
|
|
||||||
assert.Equal(t, r1, idx.get("example.com/record", "r1"))
|
|
||||||
idx.delete("example.com/record", "r1")
|
|
||||||
assert.Nil(t, idx.get("example.com/record", "r1"))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestByCIDR(t *testing.T) {
|
|
||||||
t.Run("ipv4", func(t *testing.T) {
|
|
||||||
idx := newIndex()
|
|
||||||
|
|
||||||
r1 := &structpb.Struct{Fields: map[string]*structpb.Value{
|
|
||||||
"$index": structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
|
|
||||||
"cidr": structpb.NewStringValue("192.168.0.0/16"),
|
|
||||||
}}),
|
|
||||||
"id": structpb.NewStringValue("r1"),
|
|
||||||
}}
|
|
||||||
idx.set("example.com/record", "r1", r1)
|
|
||||||
|
|
||||||
r2 := &structpb.Struct{Fields: map[string]*structpb.Value{
|
|
||||||
"$index": structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
|
|
||||||
"cidr": structpb.NewStringValue("192.168.0.0/24"),
|
|
||||||
}}),
|
|
||||||
"id": structpb.NewStringValue("r2"),
|
|
||||||
}}
|
|
||||||
idx.set("example.com/record", "r2", r2)
|
|
||||||
|
|
||||||
assert.Equal(t, r2, idx.find("example.com/record", "192.168.0.7"))
|
|
||||||
idx.delete("example.com/record", "r2")
|
|
||||||
assert.Equal(t, r1, idx.find("example.com/record", "192.168.0.7"))
|
|
||||||
idx.delete("example.com/record", "r1")
|
|
||||||
assert.Nil(t, idx.find("example.com/record", "192.168.0.7"))
|
|
||||||
})
|
|
||||||
t.Run("ipv6", func(t *testing.T) {
|
|
||||||
idx := newIndex()
|
|
||||||
|
|
||||||
r1 := &structpb.Struct{Fields: map[string]*structpb.Value{
|
|
||||||
"$index": structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
|
|
||||||
"cidr": structpb.NewStringValue("2001:db8::/32"),
|
|
||||||
}}),
|
|
||||||
"id": structpb.NewStringValue("r1"),
|
|
||||||
}}
|
|
||||||
idx.set("example.com/record", "r1", r1)
|
|
||||||
|
|
||||||
r2 := &structpb.Struct{Fields: map[string]*structpb.Value{
|
|
||||||
"$index": structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
|
|
||||||
"cidr": structpb.NewStringValue("2001:db8::/48"),
|
|
||||||
}}),
|
|
||||||
"id": structpb.NewStringValue("r2"),
|
|
||||||
}}
|
|
||||||
idx.set("example.com/record", "r2", r2)
|
|
||||||
|
|
||||||
assert.Equal(t, r2, idx.find("example.com/record", "2001:db8::"))
|
|
||||||
idx.delete("example.com/record", "r2")
|
|
||||||
assert.Equal(t, r1, idx.find("example.com/record", "2001:db8::"))
|
|
||||||
idx.delete("example.com/record", "r1")
|
|
||||||
assert.Nil(t, idx.find("example.com/record", "2001:db8::"))
|
|
||||||
})
|
|
||||||
}
|
|
|
@ -5,78 +5,33 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync/atomic"
|
|
||||||
|
|
||||||
"github.com/go-jose/go-jose/v3"
|
"github.com/go-jose/go-jose/v3"
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/open-policy-agent/opa/ast"
|
"github.com/open-policy-agent/opa/ast"
|
||||||
"github.com/open-policy-agent/opa/rego"
|
"github.com/open-policy-agent/opa/rego"
|
||||||
"github.com/open-policy-agent/opa/storage"
|
opastorage "github.com/open-policy-agent/opa/storage"
|
||||||
"github.com/open-policy-agent/opa/storage/inmem"
|
"github.com/open-policy-agent/opa/storage/inmem"
|
||||||
"github.com/open-policy-agent/opa/types"
|
"github.com/open-policy-agent/opa/types"
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
)
|
)
|
||||||
|
|
||||||
// A Store stores data for the OPA rego policy evaluation.
|
// A Store stores data for the OPA rego policy evaluation.
|
||||||
type Store struct {
|
type Store struct {
|
||||||
storage.Store
|
opastorage.Store
|
||||||
index *index
|
|
||||||
|
|
||||||
dataBrokerServerVersion, dataBrokerRecordVersion uint64
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new Store.
|
// New creates a new Store.
|
||||||
func New() *Store {
|
func New() *Store {
|
||||||
return &Store{
|
return &Store{
|
||||||
Store: inmem.New(),
|
Store: inmem.New(),
|
||||||
index: newIndex(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewFromProtos creates a new Store from an existing set of protobuf messages.
|
|
||||||
func NewFromProtos(serverVersion uint64, msgs ...proto.Message) *Store {
|
|
||||||
s := New()
|
|
||||||
for _, msg := range msgs {
|
|
||||||
any := protoutil.NewAny(msg)
|
|
||||||
record := new(databroker.Record)
|
|
||||||
record.ModifiedAt = timestamppb.Now()
|
|
||||||
record.Version = cryptutil.NewRandomUInt64()
|
|
||||||
record.Id = uuid.New().String()
|
|
||||||
record.Data = any
|
|
||||||
record.Type = any.TypeUrl
|
|
||||||
if hasID, ok := msg.(interface{ GetId() string }); ok {
|
|
||||||
record.Id = hasID.GetId()
|
|
||||||
}
|
|
||||||
|
|
||||||
s.UpdateRecord(serverVersion, record)
|
|
||||||
}
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearRecords removes all the records from the store.
|
|
||||||
func (s *Store) ClearRecords() {
|
|
||||||
s.index.clear()
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDataBrokerVersions gets the databroker versions.
|
|
||||||
func (s *Store) GetDataBrokerVersions() (serverVersion, recordVersion uint64) {
|
|
||||||
return atomic.LoadUint64(&s.dataBrokerServerVersion),
|
|
||||||
atomic.LoadUint64(&s.dataBrokerRecordVersion)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetRecordData gets a record's data from the store. `nil` is returned
|
|
||||||
// if no record exists for the given type and id.
|
|
||||||
func (s *Store) GetRecordData(typeURL, idOrValue string) proto.Message {
|
|
||||||
return s.index.find(typeURL, idOrValue)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateIssuer updates the issuer in the store. The issuer is used as part of JWT construction.
|
// UpdateIssuer updates the issuer in the store. The issuer is used as part of JWT construction.
|
||||||
func (s *Store) UpdateIssuer(issuer string) {
|
func (s *Store) UpdateIssuer(issuer string) {
|
||||||
s.write("/issuer", issuer)
|
s.write("/issuer", issuer)
|
||||||
|
@ -98,20 +53,6 @@ func (s *Store) UpdateRoutePolicies(routePolicies []config.Policy) {
|
||||||
s.write("/route_policies", routePolicies)
|
s.write("/route_policies", routePolicies)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateRecord updates a record in the store.
|
|
||||||
func (s *Store) UpdateRecord(serverVersion uint64, record *databroker.Record) {
|
|
||||||
if record.GetDeletedAt() != nil {
|
|
||||||
s.index.delete(record.GetType(), record.GetId())
|
|
||||||
} else {
|
|
||||||
msg, _ := record.GetData().UnmarshalNew()
|
|
||||||
s.index.set(record.GetType(), record.GetId(), msg)
|
|
||||||
}
|
|
||||||
s.write("/databroker_server_version", fmt.Sprint(serverVersion))
|
|
||||||
s.write("/databroker_record_version", fmt.Sprint(record.GetVersion()))
|
|
||||||
atomic.StoreUint64(&s.dataBrokerServerVersion, serverVersion)
|
|
||||||
atomic.StoreUint64(&s.dataBrokerRecordVersion, record.GetVersion())
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateSigningKey updates the signing key stored in the database. Signing operations
|
// UpdateSigningKey updates the signing key stored in the database. Signing operations
|
||||||
// in rego use JWKs, so we take in that format.
|
// in rego use JWKs, so we take in that format.
|
||||||
func (s *Store) UpdateSigningKey(signingKey *jose.JSONWebKey) {
|
func (s *Store) UpdateSigningKey(signingKey *jose.JSONWebKey) {
|
||||||
|
@ -120,7 +61,7 @@ func (s *Store) UpdateSigningKey(signingKey *jose.JSONWebKey) {
|
||||||
|
|
||||||
func (s *Store) write(rawPath string, value interface{}) {
|
func (s *Store) write(rawPath string, value interface{}) {
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
err := storage.Txn(ctx, s.Store, storage.WriteParams, func(txn storage.Transaction) error {
|
err := opastorage.Txn(ctx, s.Store, opastorage.WriteParams, func(txn opastorage.Transaction) error {
|
||||||
return s.writeTxn(txn, rawPath, value)
|
return s.writeTxn(txn, rawPath, value)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -129,23 +70,23 @@ func (s *Store) write(rawPath string, value interface{}) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) writeTxn(txn storage.Transaction, rawPath string, value interface{}) error {
|
func (s *Store) writeTxn(txn opastorage.Transaction, rawPath string, value interface{}) error {
|
||||||
p, ok := storage.ParsePath(rawPath)
|
p, ok := opastorage.ParsePath(rawPath)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("invalid path")
|
return fmt.Errorf("invalid path")
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(p) > 1 {
|
if len(p) > 1 {
|
||||||
err := storage.MakeDir(context.Background(), s, txn, p[:len(p)-1])
|
err := opastorage.MakeDir(context.Background(), s, txn, p[:len(p)-1])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var op storage.PatchOp = storage.ReplaceOp
|
var op opastorage.PatchOp = opastorage.ReplaceOp
|
||||||
_, err := s.Read(context.Background(), txn, p)
|
_, err := s.Read(context.Background(), txn, p)
|
||||||
if storage.IsNotFound(err) {
|
if opastorage.IsNotFound(err) {
|
||||||
op = storage.AddOp
|
op = opastorage.AddOp
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -167,23 +108,42 @@ func (s *Store) GetDataBrokerRecordOption() func(*rego.Rego) {
|
||||||
return nil, fmt.Errorf("invalid record type: %T", op1)
|
return nil, fmt.Errorf("invalid record type: %T", op1)
|
||||||
}
|
}
|
||||||
|
|
||||||
recordID, ok := op2.Value.(ast.String)
|
value, ok := op2.Value.(ast.String)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid record id: %T", op2)
|
return nil, fmt.Errorf("invalid record id: %T", op2)
|
||||||
}
|
}
|
||||||
|
|
||||||
msg := s.GetRecordData(string(recordType), string(recordID))
|
req := &databroker.QueryRequest{
|
||||||
|
Type: string(recordType),
|
||||||
|
Limit: 1,
|
||||||
|
}
|
||||||
|
req.SetFilterByIDOrIndex(string(value))
|
||||||
|
|
||||||
|
res, err := storage.GetQuerier(bctx.Context).Query(bctx.Context, req)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(bctx.Context).Err(err).Msg("authorize/store: error retrieving record")
|
||||||
|
return ast.NullTerm(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(res.GetRecords()) == 0 {
|
||||||
|
return ast.NullTerm(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, _ := res.GetRecords()[0].GetData().UnmarshalNew()
|
||||||
|
if msg == nil {
|
||||||
if msg == nil {
|
if msg == nil {
|
||||||
return ast.NullTerm(), nil
|
return ast.NullTerm(), nil
|
||||||
}
|
}
|
||||||
|
}
|
||||||
obj := toMap(msg)
|
obj := toMap(msg)
|
||||||
|
|
||||||
value, err := ast.InterfaceToValue(obj)
|
regoValue, err := ast.InterfaceToValue(obj)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
log.Error(bctx.Context).Err(err).Msg("authorize/store: error converting object to rego")
|
||||||
|
return ast.NullTerm(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return ast.NewTerm(value), nil
|
return ast.NewTerm(regoValue), nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,83 +0,0 @@
|
||||||
package store
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"google.golang.org/protobuf/types/known/structpb"
|
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
|
||||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestStore(t *testing.T) {
|
|
||||||
t.Run("records", func(t *testing.T) {
|
|
||||||
s := New()
|
|
||||||
u := &user.User{
|
|
||||||
Version: "v1",
|
|
||||||
Id: "u1",
|
|
||||||
Name: "name",
|
|
||||||
Email: "name@example.com",
|
|
||||||
}
|
|
||||||
any := protoutil.NewAny(u)
|
|
||||||
s.UpdateRecord(0, &databroker.Record{
|
|
||||||
Version: 1,
|
|
||||||
Type: any.GetTypeUrl(),
|
|
||||||
Id: u.GetId(),
|
|
||||||
Data: any,
|
|
||||||
})
|
|
||||||
|
|
||||||
v := s.GetRecordData(any.GetTypeUrl(), u.GetId())
|
|
||||||
assert.Equal(t, map[string]interface{}{
|
|
||||||
"version": "v1",
|
|
||||||
"id": "u1",
|
|
||||||
"name": "name",
|
|
||||||
"email": "name@example.com",
|
|
||||||
}, toMap(v))
|
|
||||||
|
|
||||||
s.UpdateRecord(0, &databroker.Record{
|
|
||||||
Version: 2,
|
|
||||||
Type: any.GetTypeUrl(),
|
|
||||||
Id: u.GetId(),
|
|
||||||
Data: any,
|
|
||||||
DeletedAt: timestamppb.Now(),
|
|
||||||
})
|
|
||||||
|
|
||||||
v = s.GetRecordData(any.GetTypeUrl(), u.GetId())
|
|
||||||
assert.Nil(t, v)
|
|
||||||
|
|
||||||
s.UpdateRecord(0, &databroker.Record{
|
|
||||||
Version: 3,
|
|
||||||
Type: any.GetTypeUrl(),
|
|
||||||
Id: u.GetId(),
|
|
||||||
Data: any,
|
|
||||||
})
|
|
||||||
|
|
||||||
v = s.GetRecordData(any.GetTypeUrl(), u.GetId())
|
|
||||||
assert.NotNil(t, v)
|
|
||||||
|
|
||||||
s.ClearRecords()
|
|
||||||
v = s.GetRecordData(any.GetTypeUrl(), u.GetId())
|
|
||||||
assert.Nil(t, v)
|
|
||||||
})
|
|
||||||
t.Run("cidr", func(t *testing.T) {
|
|
||||||
s := New()
|
|
||||||
any := protoutil.NewAny(&structpb.Struct{Fields: map[string]*structpb.Value{
|
|
||||||
"$index": structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
|
|
||||||
"cidr": structpb.NewStringValue("192.168.0.0/16"),
|
|
||||||
}}),
|
|
||||||
"id": structpb.NewStringValue("r1"),
|
|
||||||
}})
|
|
||||||
s.UpdateRecord(0, &databroker.Record{
|
|
||||||
Version: 1,
|
|
||||||
Type: any.GetTypeUrl(),
|
|
||||||
Id: "r1",
|
|
||||||
Data: any,
|
|
||||||
})
|
|
||||||
|
|
||||||
v := s.GetRecordData(any.GetTypeUrl(), "192.168.0.7")
|
|
||||||
assert.NotNil(t, v)
|
|
||||||
})
|
|
||||||
}
|
|
|
@ -12,9 +12,11 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/requestid"
|
"github.com/pomerium/pomerium/internal/telemetry/requestid"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/audit"
|
"github.com/pomerium/pomerium/pkg/grpc/audit"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||||
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (a *Authorize) logAuthorizeCheck(
|
func (a *Authorize) logAuthorizeCheck(
|
||||||
|
@ -39,7 +41,7 @@ func (a *Authorize) logAuthorizeCheck(
|
||||||
|
|
||||||
// session information
|
// session information
|
||||||
if s, ok := s.(*session.Session); ok {
|
if s, ok := s.(*session.Session); ok {
|
||||||
evt = a.populateLogSessionDetails(evt, s)
|
evt = a.populateLogSessionDetails(ctx, evt, s)
|
||||||
}
|
}
|
||||||
if sa, ok := s.(*user.ServiceAccount); ok {
|
if sa, ok := s.(*user.ServiceAccount); ok {
|
||||||
evt = evt.Str("service-account-id", sa.GetId())
|
evt = evt.Str("service-account-id", sa.GetId())
|
||||||
|
@ -61,8 +63,6 @@ func (a *Authorize) logAuthorizeCheck(
|
||||||
}
|
}
|
||||||
evt = evt.Str("user", u.GetId())
|
evt = evt.Str("user", u.GetId())
|
||||||
evt = evt.Str("email", u.GetEmail())
|
evt = evt.Str("email", u.GetEmail())
|
||||||
evt = evt.Uint64("databroker_server_version", res.DataBrokerServerVersion)
|
|
||||||
evt = evt.Uint64("databroker_record_version", res.DataBrokerRecordVersion)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// potentially sensitive, only log if debug mode
|
// potentially sensitive, only log if debug mode
|
||||||
|
@ -80,10 +80,6 @@ func (a *Authorize) logAuthorizeCheck(
|
||||||
Request: in,
|
Request: in,
|
||||||
Response: out,
|
Response: out,
|
||||||
}
|
}
|
||||||
if res != nil {
|
|
||||||
record.DatabrokerServerVersion = res.DataBrokerServerVersion
|
|
||||||
record.DatabrokerRecordVersion = res.DataBrokerRecordVersion
|
|
||||||
}
|
|
||||||
sealed, err := enc.Encrypt(record)
|
sealed, err := enc.Encrypt(record)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn(ctx).Err(err).Msg("authorize: error encrypting audit record")
|
log.Warn(ctx).Err(err).Msg("authorize: error encrypting audit record")
|
||||||
|
@ -96,26 +92,50 @@ func (a *Authorize) logAuthorizeCheck(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authorize) populateLogSessionDetails(evt *zerolog.Event, s *session.Session) *zerolog.Event {
|
func (a *Authorize) populateLogSessionDetails(ctx context.Context, evt *zerolog.Event, s *session.Session) *zerolog.Event {
|
||||||
evt = evt.Str("session-id", s.GetId())
|
evt = evt.Str("session-id", s.GetId())
|
||||||
if s.GetImpersonateSessionId() == "" {
|
if s.GetImpersonateSessionId() == "" {
|
||||||
return evt
|
return evt
|
||||||
}
|
}
|
||||||
|
|
||||||
|
querier := storage.GetQuerier(ctx)
|
||||||
|
|
||||||
evt = evt.Str("impersonate-session-id", s.GetImpersonateSessionId())
|
evt = evt.Str("impersonate-session-id", s.GetImpersonateSessionId())
|
||||||
impersonatedSession, ok := a.store.GetRecordData(
|
req := &databroker.QueryRequest{
|
||||||
grpcutil.GetTypeURL(new(session.Session)),
|
Type: grpcutil.GetTypeURL(new(session.Session)),
|
||||||
s.GetImpersonateSessionId(),
|
Limit: 1,
|
||||||
).(*session.Session)
|
}
|
||||||
|
req.SetFilterByID(s.GetImpersonateSessionId())
|
||||||
|
res, err := querier.Query(ctx, req)
|
||||||
|
if err != nil || len(res.GetRecords()) == 0 {
|
||||||
|
return evt
|
||||||
|
}
|
||||||
|
|
||||||
|
impersonatedSessionMsg, err := res.GetRecords()[0].GetData().UnmarshalNew()
|
||||||
|
if err != nil {
|
||||||
|
return evt
|
||||||
|
}
|
||||||
|
impersonatedSession, ok := impersonatedSessionMsg.(*session.Session)
|
||||||
if !ok {
|
if !ok {
|
||||||
return evt
|
return evt
|
||||||
}
|
}
|
||||||
evt = evt.Str("impersonate-user-id", impersonatedSession.GetUserId())
|
evt = evt.Str("impersonate-user-id", impersonatedSession.GetUserId())
|
||||||
|
|
||||||
impersonatedUser, ok := a.store.GetRecordData(
|
req = &databroker.QueryRequest{
|
||||||
grpcutil.GetTypeURL(new(user.User)),
|
Type: grpcutil.GetTypeURL(new(user.User)),
|
||||||
impersonatedSession.GetUserId(),
|
Limit: 1,
|
||||||
).(*user.User)
|
}
|
||||||
|
req.SetFilterByID(impersonatedSession.GetUserId())
|
||||||
|
res, err = querier.Query(ctx, req)
|
||||||
|
if err != nil || len(res.GetRecords()) == 0 {
|
||||||
|
return evt
|
||||||
|
}
|
||||||
|
|
||||||
|
impersonatedUserMsg, err := res.GetRecords()[0].GetData().UnmarshalNew()
|
||||||
|
if err != nil {
|
||||||
|
return evt
|
||||||
|
}
|
||||||
|
impersonatedUser, ok := impersonatedUserMsg.(*user.User)
|
||||||
if !ok {
|
if !ok {
|
||||||
return evt
|
return evt
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,209 +0,0 @@
|
||||||
package authorize
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
"google.golang.org/protobuf/proto"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
forceSyncRecordMaxWait = 5 * time.Second
|
|
||||||
)
|
|
||||||
|
|
||||||
type sessionOrServiceAccount interface {
|
|
||||||
GetUserId() string
|
|
||||||
}
|
|
||||||
|
|
||||||
type dataBrokerSyncer struct {
|
|
||||||
*databroker.Syncer
|
|
||||||
authorize *Authorize
|
|
||||||
signalOnce sync.Once
|
|
||||||
}
|
|
||||||
|
|
||||||
func newDataBrokerSyncer(authorize *Authorize) *dataBrokerSyncer {
|
|
||||||
syncer := &dataBrokerSyncer{
|
|
||||||
authorize: authorize,
|
|
||||||
}
|
|
||||||
syncer.Syncer = databroker.NewSyncer("authorize", syncer)
|
|
||||||
return syncer
|
|
||||||
}
|
|
||||||
|
|
||||||
func (syncer *dataBrokerSyncer) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
|
|
||||||
return syncer.authorize.state.Load().dataBrokerClient
|
|
||||||
}
|
|
||||||
|
|
||||||
func (syncer *dataBrokerSyncer) ClearRecords(ctx context.Context) {
|
|
||||||
syncer.authorize.stateLock.Lock()
|
|
||||||
syncer.authorize.store.ClearRecords()
|
|
||||||
syncer.authorize.stateLock.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (syncer *dataBrokerSyncer) UpdateRecords(ctx context.Context, serverVersion uint64, records []*databroker.Record) {
|
|
||||||
syncer.authorize.stateLock.Lock()
|
|
||||||
for _, record := range records {
|
|
||||||
syncer.authorize.store.UpdateRecord(serverVersion, record)
|
|
||||||
}
|
|
||||||
syncer.authorize.stateLock.Unlock()
|
|
||||||
|
|
||||||
// the first time we update records we signal the initial sync
|
|
||||||
syncer.signalOnce.Do(func() {
|
|
||||||
log.Info(ctx).Msg("initial sync from databroker complete")
|
|
||||||
close(syncer.authorize.dataBrokerInitialSync)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Authorize) forceSync(ctx context.Context, ss *sessions.State) (sessionOrServiceAccount, *user.User, error) {
|
|
||||||
ctx, span := trace.StartSpan(ctx, "authorize.forceSync")
|
|
||||||
defer span.End()
|
|
||||||
if ss == nil {
|
|
||||||
return nil, nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// if the session state has databroker versions, wait for those to finish syncing
|
|
||||||
if ss.DatabrokerServerVersion != 0 && ss.DatabrokerRecordVersion != 0 {
|
|
||||||
a.forceSyncToVersion(ctx, ss.DatabrokerServerVersion, ss.DatabrokerRecordVersion)
|
|
||||||
}
|
|
||||||
|
|
||||||
s := a.forceSyncSession(ctx, ss.ID)
|
|
||||||
if s == nil {
|
|
||||||
return nil, nil, errors.New("session not found")
|
|
||||||
}
|
|
||||||
u := a.forceSyncUser(ctx, s.GetUserId())
|
|
||||||
return s, u, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Authorize) forceSyncToVersion(ctx context.Context, serverVersion, recordVersion uint64) (ready bool) {
|
|
||||||
ctx, span := trace.StartSpan(ctx, "authorize.forceSyncToVersion")
|
|
||||||
defer span.End()
|
|
||||||
|
|
||||||
ctx, clearTimeout := context.WithTimeout(ctx, forceSyncRecordMaxWait)
|
|
||||||
defer clearTimeout()
|
|
||||||
|
|
||||||
ticker := time.NewTicker(time.Millisecond * 50)
|
|
||||||
defer ticker.Stop()
|
|
||||||
for {
|
|
||||||
currentServerVersion, currentRecordVersion := a.store.GetDataBrokerVersions()
|
|
||||||
// check if the local record version is up to date with the expected record version
|
|
||||||
if currentServerVersion == serverVersion && currentRecordVersion >= recordVersion {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return false
|
|
||||||
case <-ticker.C:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Authorize) forceSyncSession(ctx context.Context, sessionID string) sessionOrServiceAccount {
|
|
||||||
ctx, span := trace.StartSpan(ctx, "authorize.forceSyncSession")
|
|
||||||
defer span.End()
|
|
||||||
|
|
||||||
ctx, clearTimeout := context.WithTimeout(ctx, forceSyncRecordMaxWait)
|
|
||||||
defer clearTimeout()
|
|
||||||
|
|
||||||
s, ok := a.store.GetRecordData(grpcutil.GetTypeURL(new(session.Session)), sessionID).(*session.Session)
|
|
||||||
if ok {
|
|
||||||
a.accessTracker.TrackSessionAccess(sessionID)
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
sa, ok := a.store.GetRecordData(grpcutil.GetTypeURL(new(user.ServiceAccount)), sessionID).(*user.ServiceAccount)
|
|
||||||
if ok {
|
|
||||||
a.accessTracker.TrackServiceAccountAccess(sessionID)
|
|
||||||
return sa
|
|
||||||
}
|
|
||||||
|
|
||||||
// wait for the session to show up
|
|
||||||
record, err := a.waitForRecordSync(ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID)
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
s, ok = record.(*session.Session)
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
a.accessTracker.TrackSessionAccess(sessionID)
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Authorize) forceSyncUser(ctx context.Context, userID string) *user.User {
|
|
||||||
ctx, span := trace.StartSpan(ctx, "authorize.forceSyncUser")
|
|
||||||
defer span.End()
|
|
||||||
|
|
||||||
ctx, clearTimeout := context.WithTimeout(ctx, forceSyncRecordMaxWait)
|
|
||||||
defer clearTimeout()
|
|
||||||
|
|
||||||
u, ok := a.store.GetRecordData(grpcutil.GetTypeURL(new(user.User)), userID).(*user.User)
|
|
||||||
if ok {
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
// wait for the user to show up
|
|
||||||
record, err := a.waitForRecordSync(ctx, grpcutil.GetTypeURL(new(user.User)), userID)
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
u, ok = record.(*user.User)
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
// waitForRecordSync waits for the first sync of a record to complete
|
|
||||||
func (a *Authorize) waitForRecordSync(ctx context.Context, recordTypeURL, recordID string) (proto.Message, error) {
|
|
||||||
bo := backoff.NewExponentialBackOff()
|
|
||||||
bo.InitialInterval = time.Millisecond
|
|
||||||
bo.MaxElapsedTime = 0
|
|
||||||
bo.Reset()
|
|
||||||
|
|
||||||
for {
|
|
||||||
current := a.store.GetRecordData(recordTypeURL, recordID)
|
|
||||||
if current != nil {
|
|
||||||
// record found, so it's already synced
|
|
||||||
return current, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := a.state.Load().dataBrokerClient.Get(ctx, &databroker.GetRequest{
|
|
||||||
Type: recordTypeURL,
|
|
||||||
Id: recordID,
|
|
||||||
})
|
|
||||||
if status.Code(err) == codes.NotFound {
|
|
||||||
// record not found, so no need to wait
|
|
||||||
return nil, nil
|
|
||||||
} else if err != nil {
|
|
||||||
log.Error(ctx).
|
|
||||||
Err(err).
|
|
||||||
Str("type", recordTypeURL).
|
|
||||||
Str("id", recordID).
|
|
||||||
Msg("authorize: error retrieving record")
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
log.Warn(ctx).
|
|
||||||
Str("type", recordTypeURL).
|
|
||||||
Str("id", recordID).
|
|
||||||
Msg("authorize: first sync of record did not complete")
|
|
||||||
return nil, ctx.Err()
|
|
||||||
case <-time.After(bo.NextBackOff()):
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,150 +0,0 @@
|
||||||
package authorize
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"google.golang.org/grpc"
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
"google.golang.org/protobuf/proto"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
|
||||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestAuthorize_forceSyncToVersion(t *testing.T) {
|
|
||||||
o := &config.Options{
|
|
||||||
AuthenticateURLString: "https://authN.example.com",
|
|
||||||
DataBrokerURLString: "https://databroker.example.com",
|
|
||||||
SharedKey: "gXK6ggrlIW2HyKyUF9rUO4azrDgxhDPWqw9y+lJU7B8=",
|
|
||||||
Policies: testPolicies(t),
|
|
||||||
}
|
|
||||||
a, err := New(&config.Config{Options: o})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
a.store.UpdateRecord(1, &databroker.Record{
|
|
||||||
Version: 1,
|
|
||||||
})
|
|
||||||
t.Run("ready", func(t *testing.T) {
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
cancel()
|
|
||||||
assert.True(t, a.forceSyncToVersion(ctx, 1, 1))
|
|
||||||
})
|
|
||||||
t.Run("not ready", func(t *testing.T) {
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
cancel()
|
|
||||||
assert.False(t, a.forceSyncToVersion(ctx, 1, 2))
|
|
||||||
})
|
|
||||||
t.Run("becomes ready", func(t *testing.T) {
|
|
||||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
|
||||||
defer clearTimeout()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
<-time.After(time.Millisecond * 100)
|
|
||||||
a.store.UpdateRecord(1, &databroker.Record{
|
|
||||||
Version: 2,
|
|
||||||
})
|
|
||||||
}()
|
|
||||||
assert.True(t, a.forceSyncToVersion(ctx, 1, 2))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthorize_waitForRecordSync(t *testing.T) {
|
|
||||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*30)
|
|
||||||
defer clearTimeout()
|
|
||||||
|
|
||||||
o := &config.Options{
|
|
||||||
AuthenticateURLString: "https://authN.example.com",
|
|
||||||
DataBrokerURLString: "https://databroker.example.com",
|
|
||||||
SharedKey: "gXK6ggrlIW2HyKyUF9rUO4azrDgxhDPWqw9y+lJU7B8=",
|
|
||||||
Policies: testPolicies(t),
|
|
||||||
}
|
|
||||||
t.Run("skip if exists", func(t *testing.T) {
|
|
||||||
a, err := New(&config.Config{Options: o})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
a.store.UpdateRecord(0, newRecord(&session.Session{
|
|
||||||
Id: "SESSION_ID",
|
|
||||||
}))
|
|
||||||
a.state.Load().dataBrokerClient = mockDataBrokerServiceClient{
|
|
||||||
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
|
|
||||||
panic("should never be called")
|
|
||||||
},
|
|
||||||
}
|
|
||||||
a.waitForRecordSync(ctx, grpcutil.GetTypeURL(new(session.Session)), "SESSION_ID")
|
|
||||||
})
|
|
||||||
t.Run("skip if not found", func(t *testing.T) {
|
|
||||||
a, err := New(&config.Config{Options: o})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
callCount := 0
|
|
||||||
a.state.Load().dataBrokerClient = mockDataBrokerServiceClient{
|
|
||||||
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
|
|
||||||
callCount++
|
|
||||||
return nil, status.Error(codes.NotFound, "not found")
|
|
||||||
},
|
|
||||||
}
|
|
||||||
a.waitForRecordSync(ctx, grpcutil.GetTypeURL(new(session.Session)), "SESSION_ID")
|
|
||||||
assert.Equal(t, 1, callCount, "should be called once")
|
|
||||||
})
|
|
||||||
t.Run("poll", func(t *testing.T) {
|
|
||||||
a, err := New(&config.Config{Options: o})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
callCount := 0
|
|
||||||
a.state.Load().dataBrokerClient = mockDataBrokerServiceClient{
|
|
||||||
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
|
|
||||||
callCount++
|
|
||||||
switch callCount {
|
|
||||||
case 1:
|
|
||||||
s := &session.Session{Id: "SESSION_ID"}
|
|
||||||
a.store.UpdateRecord(0, newRecord(s))
|
|
||||||
return &databroker.GetResponse{Record: newRecord(s)}, nil
|
|
||||||
default:
|
|
||||||
panic("should never be called")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
a.waitForRecordSync(ctx, grpcutil.GetTypeURL(new(session.Session)), "SESSION_ID")
|
|
||||||
})
|
|
||||||
t.Run("timeout", func(t *testing.T) {
|
|
||||||
a, err := New(&config.Config{Options: o})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
tctx, clearTimeout := context.WithTimeout(ctx, time.Millisecond*100)
|
|
||||||
defer clearTimeout()
|
|
||||||
|
|
||||||
callCount := 0
|
|
||||||
a.state.Load().dataBrokerClient = mockDataBrokerServiceClient{
|
|
||||||
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
|
|
||||||
callCount++
|
|
||||||
s := &session.Session{Id: "SESSION_ID"}
|
|
||||||
return &databroker.GetResponse{Record: newRecord(s)}, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
a.waitForRecordSync(tctx, grpcutil.GetTypeURL(new(session.Session)), "SESSION_ID")
|
|
||||||
assert.Greater(t, callCount, 5) // should be ~ 20, but allow for non-determinism
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
type storableMessage interface {
|
|
||||||
proto.Message
|
|
||||||
GetId() string
|
|
||||||
}
|
|
||||||
|
|
||||||
func newRecord(msg storableMessage) *databroker.Record {
|
|
||||||
any := protoutil.NewAny(msg)
|
|
||||||
return &databroker.Record{
|
|
||||||
Version: 1,
|
|
||||||
Type: any.GetTypeUrl(),
|
|
||||||
Id: msg.GetId(),
|
|
||||||
Data: any,
|
|
||||||
}
|
|
||||||
}
|
|
16
go.mod
16
go.mod
|
@ -6,7 +6,9 @@ require (
|
||||||
contrib.go.opencensus.io/exporter/jaeger v0.2.1
|
contrib.go.opencensus.io/exporter/jaeger v0.2.1
|
||||||
contrib.go.opencensus.io/exporter/prometheus v0.4.1
|
contrib.go.opencensus.io/exporter/prometheus v0.4.1
|
||||||
contrib.go.opencensus.io/exporter/zipkin v0.1.2
|
contrib.go.opencensus.io/exporter/zipkin v0.1.2
|
||||||
|
github.com/CAFxX/httpcompression v0.0.8
|
||||||
github.com/DataDog/opencensus-go-exporter-datadog v0.0.0-20200406135749-5c268882acf0
|
github.com/DataDog/opencensus-go-exporter-datadog v0.0.0-20200406135749-5c268882acf0
|
||||||
|
github.com/VictoriaMetrics/fastcache v1.10.0
|
||||||
github.com/caddyserver/certmagic v0.16.0
|
github.com/caddyserver/certmagic v0.16.0
|
||||||
github.com/cenkalti/backoff/v4 v4.1.3
|
github.com/cenkalti/backoff/v4 v4.1.3
|
||||||
github.com/cespare/xxhash/v2 v2.1.2
|
github.com/cespare/xxhash/v2 v2.1.2
|
||||||
|
@ -29,9 +31,11 @@ require (
|
||||||
github.com/gorilla/handlers v1.5.1
|
github.com/gorilla/handlers v1.5.1
|
||||||
github.com/gorilla/mux v1.8.0
|
github.com/gorilla/mux v1.8.0
|
||||||
github.com/gorilla/websocket v1.5.0
|
github.com/gorilla/websocket v1.5.0
|
||||||
github.com/hashicorp/errwrap v1.1.0 // indirect
|
|
||||||
github.com/hashicorp/go-multierror v1.1.1
|
github.com/hashicorp/go-multierror v1.1.1
|
||||||
github.com/hashicorp/golang-lru v0.5.4
|
github.com/hashicorp/golang-lru v0.5.4
|
||||||
|
github.com/jackc/pgconn v1.12.1
|
||||||
|
github.com/jackc/pgtype v1.11.0
|
||||||
|
github.com/jackc/pgx/v4 v4.16.1
|
||||||
github.com/martinlindhe/base36 v1.1.0
|
github.com/martinlindhe/base36 v1.1.0
|
||||||
github.com/mholt/acmez v1.0.2
|
github.com/mholt/acmez v1.0.2
|
||||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||||
|
@ -75,14 +79,6 @@ require (
|
||||||
sigs.k8s.io/yaml v1.3.0
|
sigs.k8s.io/yaml v1.3.0
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
|
||||||
github.com/CAFxX/httpcompression v0.0.8
|
|
||||||
github.com/jackc/pgconn v1.12.1
|
|
||||||
github.com/jackc/pgtype v1.11.0
|
|
||||||
github.com/jackc/pgx/v4 v4.16.1
|
|
||||||
github.com/kentik/patricia v1.0.0
|
|
||||||
)
|
|
||||||
|
|
||||||
require (
|
require (
|
||||||
4d63.com/gochecknoglobals v0.1.0 // indirect
|
4d63.com/gochecknoglobals v0.1.0 // indirect
|
||||||
cloud.google.com/go/compute v1.6.1 // indirect
|
cloud.google.com/go/compute v1.6.1 // indirect
|
||||||
|
@ -148,6 +144,7 @@ require (
|
||||||
github.com/gofrs/flock v0.8.1 // indirect
|
github.com/gofrs/flock v0.8.1 // indirect
|
||||||
github.com/gogo/protobuf v1.3.2 // indirect
|
github.com/gogo/protobuf v1.3.2 // indirect
|
||||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
|
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
|
||||||
|
github.com/golang/snappy v0.0.4 // indirect
|
||||||
github.com/golangci/check v0.0.0-20180506172741-cfe4005ccda2 // indirect
|
github.com/golangci/check v0.0.0-20180506172741-cfe4005ccda2 // indirect
|
||||||
github.com/golangci/dupl v0.0.0-20180902072040-3e9179ac440a // indirect
|
github.com/golangci/dupl v0.0.0-20180902072040-3e9179ac440a // indirect
|
||||||
github.com/golangci/go-misc v0.0.0-20220329215616-d24fe342adfe // indirect
|
github.com/golangci/go-misc v0.0.0-20220329215616-d24fe342adfe // indirect
|
||||||
|
@ -166,6 +163,7 @@ require (
|
||||||
github.com/gostaticanalysis/comment v1.4.2 // indirect
|
github.com/gostaticanalysis/comment v1.4.2 // indirect
|
||||||
github.com/gostaticanalysis/forcetypeassert v0.1.0 // indirect
|
github.com/gostaticanalysis/forcetypeassert v0.1.0 // indirect
|
||||||
github.com/gostaticanalysis/nilerr v0.1.1 // indirect
|
github.com/gostaticanalysis/nilerr v0.1.1 // indirect
|
||||||
|
github.com/hashicorp/errwrap v1.1.0 // indirect
|
||||||
github.com/hashicorp/go-version v1.4.0 // indirect
|
github.com/hashicorp/go-version v1.4.0 // indirect
|
||||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||||
github.com/hexops/gotextdiff v1.0.3 // indirect
|
github.com/hexops/gotextdiff v1.0.3 // indirect
|
||||||
|
|
7
go.sum
7
go.sum
|
@ -165,6 +165,8 @@ github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWX
|
||||||
github.com/Shopify/sarama v1.30.0/go.mod h1:zujlQQx1kzHsh4jfV1USnptCQrHAEZ2Hk8fTKCulPVs=
|
github.com/Shopify/sarama v1.30.0/go.mod h1:zujlQQx1kzHsh4jfV1USnptCQrHAEZ2Hk8fTKCulPVs=
|
||||||
github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI=
|
github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI=
|
||||||
github.com/Shopify/toxiproxy/v2 v2.1.6-0.20210914104332-15ea381dcdae/go.mod h1:/cvHQkZ1fst0EmZnA5dFtiQdWCNCFYzb+uE2vqVgvx0=
|
github.com/Shopify/toxiproxy/v2 v2.1.6-0.20210914104332-15ea381dcdae/go.mod h1:/cvHQkZ1fst0EmZnA5dFtiQdWCNCFYzb+uE2vqVgvx0=
|
||||||
|
github.com/VictoriaMetrics/fastcache v1.10.0 h1:5hDJnLsKLpnUEToub7ETuRu8RCkb40woBZAUiKonXzY=
|
||||||
|
github.com/VictoriaMetrics/fastcache v1.10.0/go.mod h1:tjiYeEfYXCqacuvYw/7UoDIeJaNxq6132xHICNP77w8=
|
||||||
github.com/agnivade/levenshtein v1.0.1/go.mod h1:CURSv5d9Uaml+FovSIICkLbAUZ9S4RqaHDIsdSBg7lM=
|
github.com/agnivade/levenshtein v1.0.1/go.mod h1:CURSv5d9Uaml+FovSIICkLbAUZ9S4RqaHDIsdSBg7lM=
|
||||||
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
||||||
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
||||||
|
@ -175,6 +177,8 @@ github.com/alexflint/go-filemutex v0.0.0-20171022225611-72bdc8eae2ae/go.mod h1:C
|
||||||
github.com/alexflint/go-filemutex v1.1.0/go.mod h1:7P4iRhttt/nUvUOrYIhcpMzv2G6CY9UnI16Z+UJqRyk=
|
github.com/alexflint/go-filemutex v1.1.0/go.mod h1:7P4iRhttt/nUvUOrYIhcpMzv2G6CY9UnI16Z+UJqRyk=
|
||||||
github.com/alexkohler/prealloc v1.0.0 h1:Hbq0/3fJPQhNkN0dR95AVrr6R7tou91y0uHG5pOcUuw=
|
github.com/alexkohler/prealloc v1.0.0 h1:Hbq0/3fJPQhNkN0dR95AVrr6R7tou91y0uHG5pOcUuw=
|
||||||
github.com/alexkohler/prealloc v1.0.0/go.mod h1:VetnK3dIgFBBKmg0YnD9F9x6Icjd+9cvfHR56wJVlKE=
|
github.com/alexkohler/prealloc v1.0.0/go.mod h1:VetnK3dIgFBBKmg0YnD9F9x6Icjd+9cvfHR56wJVlKE=
|
||||||
|
github.com/allegro/bigcache v1.2.1-0.20190218064605-e24eb225f156 h1:eMwmnE/GDgah4HI848JfFxHt+iPb26b4zyfspmqY0/8=
|
||||||
|
github.com/allegro/bigcache v1.2.1-0.20190218064605-e24eb225f156/go.mod h1:Cb/ax3seSYIx7SuZdm2G2xzfwmv3TPSk2ucNfQESPXM=
|
||||||
github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883 h1:bvNMNQO63//z+xNgfBlViaCIJKLlCJ6/fmUseuG0wVQ=
|
github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883 h1:bvNMNQO63//z+xNgfBlViaCIJKLlCJ6/fmUseuG0wVQ=
|
||||||
github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883/go.mod h1:rCTlJbsFo29Kk6CurOXKm700vrz8f0KW0JNfpkRJY/8=
|
github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883/go.mod h1:rCTlJbsFo29Kk6CurOXKm700vrz8f0KW0JNfpkRJY/8=
|
||||||
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
|
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
|
||||||
|
@ -995,8 +999,6 @@ github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8
|
||||||
github.com/julz/importas v0.1.0 h1:F78HnrsjY3cR7j0etXy5+TU1Zuy7Xt08X/1aJnH5xXY=
|
github.com/julz/importas v0.1.0 h1:F78HnrsjY3cR7j0etXy5+TU1Zuy7Xt08X/1aJnH5xXY=
|
||||||
github.com/julz/importas v0.1.0/go.mod h1:oSFU2R4XK/P7kNBrnL/FEQlDGN1/6WoxXEjSSXO0DV0=
|
github.com/julz/importas v0.1.0/go.mod h1:oSFU2R4XK/P7kNBrnL/FEQlDGN1/6WoxXEjSSXO0DV0=
|
||||||
github.com/k0kubun/colorstring v0.0.0-20150214042306-9440f1994b88/go.mod h1:3w7q1U84EfirKl04SVQ/s7nPm1ZPhiXd34z40TNz36k=
|
github.com/k0kubun/colorstring v0.0.0-20150214042306-9440f1994b88/go.mod h1:3w7q1U84EfirKl04SVQ/s7nPm1ZPhiXd34z40TNz36k=
|
||||||
github.com/kentik/patricia v1.0.0 h1:jx/8kXf0JvQEHNPX4njL+PDzpxxqNKg0RjA8hJcX38A=
|
|
||||||
github.com/kentik/patricia v1.0.0/go.mod h1:e0nkPLU9NQl8v05ukfHU6+R5ykbKcXO+NqaC3ifTm0Y=
|
|
||||||
github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q=
|
github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q=
|
||||||
github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00=
|
github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00=
|
||||||
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||||
|
@ -2054,6 +2056,7 @@ golang.org/x/sys v0.0.0-20220128215802-99c3d69c2c27/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||||
golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20220328115105-d36c6a25d886/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220328115105-d36c6a25d886/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20220422013727-9388b58f7150/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220422013727-9388b58f7150/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20220502124256-b6088ccd6cba/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220502124256-b6088ccd6cba/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
|
|
@ -28,8 +28,6 @@ type Record struct {
|
||||||
|
|
||||||
Request *v3.CheckRequest `protobuf:"bytes,1,opt,name=request,proto3" json:"request,omitempty"`
|
Request *v3.CheckRequest `protobuf:"bytes,1,opt,name=request,proto3" json:"request,omitempty"`
|
||||||
Response *v3.CheckResponse `protobuf:"bytes,2,opt,name=response,proto3" json:"response,omitempty"`
|
Response *v3.CheckResponse `protobuf:"bytes,2,opt,name=response,proto3" json:"response,omitempty"`
|
||||||
DatabrokerServerVersion uint64 `protobuf:"varint,3,opt,name=databroker_server_version,json=databrokerServerVersion,proto3" json:"databroker_server_version,omitempty"`
|
|
||||||
DatabrokerRecordVersion uint64 `protobuf:"varint,4,opt,name=databroker_record_version,json=databrokerRecordVersion,proto3" json:"databroker_record_version,omitempty"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (x *Record) Reset() {
|
func (x *Record) Reset() {
|
||||||
|
@ -78,20 +76,6 @@ func (x *Record) GetResponse() *v3.CheckResponse {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (x *Record) GetDatabrokerServerVersion() uint64 {
|
|
||||||
if x != nil {
|
|
||||||
return x.DatabrokerServerVersion
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (x *Record) GetDatabrokerRecordVersion() uint64 {
|
|
||||||
if x != nil {
|
|
||||||
return x.DatabrokerRecordVersion
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
var File_audit_proto protoreflect.FileDescriptor
|
var File_audit_proto protoreflect.FileDescriptor
|
||||||
|
|
||||||
var file_audit_proto_rawDesc = []byte{
|
var file_audit_proto_rawDesc = []byte{
|
||||||
|
@ -99,7 +83,7 @@ var file_audit_proto_rawDesc = []byte{
|
||||||
0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2e, 0x61, 0x75, 0x64, 0x69, 0x74, 0x1a, 0x29, 0x65,
|
0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2e, 0x61, 0x75, 0x64, 0x69, 0x74, 0x1a, 0x29, 0x65,
|
||||||
0x6e, 0x76, 0x6f, 0x79, 0x2f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2f, 0x61, 0x75, 0x74,
|
0x6e, 0x76, 0x6f, 0x79, 0x2f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2f, 0x61, 0x75, 0x74,
|
||||||
0x68, 0x2f, 0x76, 0x33, 0x2f, 0x65, 0x78, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x5f, 0x61, 0x75,
|
0x68, 0x2f, 0x76, 0x33, 0x2f, 0x65, 0x78, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x5f, 0x61, 0x75,
|
||||||
0x74, 0x68, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x81, 0x02, 0x0a, 0x06, 0x52, 0x65, 0x63,
|
0x74, 0x68, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x89, 0x01, 0x0a, 0x06, 0x52, 0x65, 0x63,
|
||||||
0x6f, 0x72, 0x64, 0x12, 0x3d, 0x0a, 0x07, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x18, 0x01,
|
0x6f, 0x72, 0x64, 0x12, 0x3d, 0x0a, 0x07, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x18, 0x01,
|
||||||
0x20, 0x01, 0x28, 0x0b, 0x32, 0x23, 0x2e, 0x65, 0x6e, 0x76, 0x6f, 0x79, 0x2e, 0x73, 0x65, 0x72,
|
0x20, 0x01, 0x28, 0x0b, 0x32, 0x23, 0x2e, 0x65, 0x6e, 0x76, 0x6f, 0x79, 0x2e, 0x73, 0x65, 0x72,
|
||||||
0x76, 0x69, 0x63, 0x65, 0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x76, 0x33, 0x2e, 0x43, 0x68, 0x65,
|
0x76, 0x69, 0x63, 0x65, 0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x76, 0x33, 0x2e, 0x43, 0x68, 0x65,
|
||||||
|
@ -108,18 +92,10 @@ var file_audit_proto_rawDesc = []byte{
|
||||||
0x20, 0x01, 0x28, 0x0b, 0x32, 0x24, 0x2e, 0x65, 0x6e, 0x76, 0x6f, 0x79, 0x2e, 0x73, 0x65, 0x72,
|
0x20, 0x01, 0x28, 0x0b, 0x32, 0x24, 0x2e, 0x65, 0x6e, 0x76, 0x6f, 0x79, 0x2e, 0x73, 0x65, 0x72,
|
||||||
0x76, 0x69, 0x63, 0x65, 0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x76, 0x33, 0x2e, 0x43, 0x68, 0x65,
|
0x76, 0x69, 0x63, 0x65, 0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x76, 0x33, 0x2e, 0x43, 0x68, 0x65,
|
||||||
0x63, 0x6b, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x52, 0x08, 0x72, 0x65, 0x73, 0x70,
|
0x63, 0x6b, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x52, 0x08, 0x72, 0x65, 0x73, 0x70,
|
||||||
0x6f, 0x6e, 0x73, 0x65, 0x12, 0x3a, 0x0a, 0x19, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x6f, 0x6b,
|
0x6f, 0x6e, 0x73, 0x65, 0x42, 0x2d, 0x5a, 0x2b, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63,
|
||||||
0x65, 0x72, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x5f, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f,
|
0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65,
|
||||||
0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x04, 0x52, 0x17, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x6f,
|
0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x61, 0x75,
|
||||||
0x6b, 0x65, 0x72, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e,
|
0x64, 0x69, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||||
0x12, 0x3a, 0x0a, 0x19, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x6f, 0x6b, 0x65, 0x72, 0x5f, 0x72,
|
|
||||||
0x65, 0x63, 0x6f, 0x72, 0x64, 0x5f, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x04, 0x20,
|
|
||||||
0x01, 0x28, 0x04, 0x52, 0x17, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x6f, 0x6b, 0x65, 0x72, 0x52,
|
|
||||||
0x65, 0x63, 0x6f, 0x72, 0x64, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x42, 0x2d, 0x5a, 0x2b,
|
|
||||||
0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72,
|
|
||||||
0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67,
|
|
||||||
0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x61, 0x75, 0x64, 0x69, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f,
|
|
||||||
0x74, 0x6f, 0x33,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|
|
@ -8,6 +8,4 @@ import "envoy/service/auth/v3/external_auth.proto";
|
||||||
message Record {
|
message Record {
|
||||||
envoy.service.auth.v3.CheckRequest request = 1;
|
envoy.service.auth.v3.CheckRequest request = 1;
|
||||||
envoy.service.auth.v3.CheckResponse response = 2;
|
envoy.service.auth.v3.CheckResponse response = 2;
|
||||||
uint64 databroker_server_version = 3;
|
|
||||||
uint64 databroker_record_version = 4;
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
|
structpb "google.golang.org/protobuf/types/known/structpb"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||||
|
@ -118,6 +119,27 @@ func (x *PutResponse) GetRecord() *Record {
|
||||||
return records[0]
|
return records[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetFilterByID sets the filter to an id.
|
||||||
|
func (x *QueryRequest) SetFilterByID(id string) {
|
||||||
|
x.Filter = &structpb.Struct{Fields: map[string]*structpb.Value{
|
||||||
|
"id": structpb.NewStringValue(id),
|
||||||
|
}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetFilterByIDOrIndex sets the filter to an id or an index.
|
||||||
|
func (x *QueryRequest) SetFilterByIDOrIndex(idOrIndex string) {
|
||||||
|
x.Filter = &structpb.Struct{Fields: map[string]*structpb.Value{
|
||||||
|
"$or": structpb.NewListValue(&structpb.ListValue{Values: []*structpb.Value{
|
||||||
|
structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
|
||||||
|
"id": structpb.NewStringValue(idOrIndex),
|
||||||
|
}}),
|
||||||
|
structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
|
||||||
|
"$index": structpb.NewStringValue(idOrIndex),
|
||||||
|
}}),
|
||||||
|
}}),
|
||||||
|
}}
|
||||||
|
}
|
||||||
|
|
||||||
// default is 4MB, but we'll do 1MB
|
// default is 4MB, but we'll do 1MB
|
||||||
const maxMessageSize = 1024 * 1024 * 1
|
const maxMessageSize = 1024 * 1024 * 1
|
||||||
|
|
||||||
|
|
155
pkg/storage/cache.go
Normal file
155
pkg/storage/cache.go
Normal file
|
@ -0,0 +1,155 @@
|
||||||
|
package storage
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/VictoriaMetrics/fastcache"
|
||||||
|
"golang.org/x/sync/singleflight"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A Cache will return cached data when available or call update when not.
|
||||||
|
type Cache interface {
|
||||||
|
GetOrUpdate(
|
||||||
|
ctx context.Context,
|
||||||
|
key []byte,
|
||||||
|
update func(ctx context.Context) ([]byte, error),
|
||||||
|
) ([]byte, error)
|
||||||
|
Invalidate(key []byte)
|
||||||
|
}
|
||||||
|
|
||||||
|
type localCache struct {
|
||||||
|
singleflight singleflight.Group
|
||||||
|
mu sync.RWMutex
|
||||||
|
m map[string][]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewLocalCache creates a new Cache backed by a map.
|
||||||
|
func NewLocalCache() Cache {
|
||||||
|
return &localCache{
|
||||||
|
m: make(map[string][]byte),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cache *localCache) GetOrUpdate(
|
||||||
|
ctx context.Context,
|
||||||
|
key []byte,
|
||||||
|
update func(ctx context.Context) ([]byte, error),
|
||||||
|
) ([]byte, error) {
|
||||||
|
strkey := string(key)
|
||||||
|
|
||||||
|
cache.mu.RLock()
|
||||||
|
cached, ok := cache.m[strkey]
|
||||||
|
cache.mu.RUnlock()
|
||||||
|
if ok {
|
||||||
|
return cached, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
v, err, _ := cache.singleflight.Do(strkey, func() (interface{}, error) {
|
||||||
|
cache.mu.RLock()
|
||||||
|
cached, ok := cache.m[strkey]
|
||||||
|
cache.mu.RUnlock()
|
||||||
|
if ok {
|
||||||
|
return cached, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := update(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
cache.mu.Lock()
|
||||||
|
cache.m[strkey] = result
|
||||||
|
cache.mu.Unlock()
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return v.([]byte), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cache *localCache) Invalidate(key []byte) {
|
||||||
|
cache.mu.Lock()
|
||||||
|
delete(cache.m, string(key))
|
||||||
|
cache.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
type globalCache struct {
|
||||||
|
ttl time.Duration
|
||||||
|
|
||||||
|
singleflight singleflight.Group
|
||||||
|
mu sync.RWMutex
|
||||||
|
fastcache *fastcache.Cache
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewGlobalCache creates a new Cache backed by fastcache and a TTL.
|
||||||
|
func NewGlobalCache(ttl time.Duration) Cache {
|
||||||
|
return &globalCache{
|
||||||
|
ttl: ttl,
|
||||||
|
fastcache: fastcache.New(256 * 1024 * 1024), // up to 256MB of RAM
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cache *globalCache) GetOrUpdate(
|
||||||
|
ctx context.Context,
|
||||||
|
key []byte,
|
||||||
|
update func(ctx context.Context) ([]byte, error),
|
||||||
|
) ([]byte, error) {
|
||||||
|
now := time.Now()
|
||||||
|
data, expiry, ok := cache.get(key)
|
||||||
|
if ok && now.Before(expiry) {
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
v, err, _ := cache.singleflight.Do(string(key), func() (interface{}, error) {
|
||||||
|
data, expiry, ok := cache.get(key)
|
||||||
|
if ok && now.Before(expiry) {
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
value, err := update(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cache.set(key, value)
|
||||||
|
return value, nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return v.([]byte), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cache *globalCache) Invalidate(key []byte) {
|
||||||
|
cache.mu.Lock()
|
||||||
|
cache.fastcache.Del(key)
|
||||||
|
cache.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cache *globalCache) get(k []byte) (data []byte, expiry time.Time, ok bool) {
|
||||||
|
cache.mu.RLock()
|
||||||
|
item := cache.fastcache.Get(nil, k)
|
||||||
|
cache.mu.RUnlock()
|
||||||
|
if len(item) < 8 {
|
||||||
|
return data, expiry, false
|
||||||
|
}
|
||||||
|
|
||||||
|
unix, data := binary.LittleEndian.Uint64(item), item[8:]
|
||||||
|
expiry = time.UnixMilli(int64(unix))
|
||||||
|
return data, expiry, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cache *globalCache) set(k, v []byte) {
|
||||||
|
unix := time.Now().Add(cache.ttl).UnixMilli()
|
||||||
|
item := make([]byte, len(v)+8)
|
||||||
|
binary.LittleEndian.PutUint64(item, uint64(unix))
|
||||||
|
copy(item[8:], v)
|
||||||
|
|
||||||
|
cache.mu.Lock()
|
||||||
|
cache.fastcache.Set(k, item)
|
||||||
|
cache.mu.Unlock()
|
||||||
|
}
|
73
pkg/storage/cache_test.go
Normal file
73
pkg/storage/cache_test.go
Normal file
|
@ -0,0 +1,73 @@
|
||||||
|
package storage
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLocalCache(t *testing.T) {
|
||||||
|
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||||
|
defer clearTimeout()
|
||||||
|
|
||||||
|
callCount := 0
|
||||||
|
update := func(ctx context.Context) ([]byte, error) {
|
||||||
|
callCount++
|
||||||
|
return []byte("v1"), nil
|
||||||
|
}
|
||||||
|
c := NewLocalCache()
|
||||||
|
v, err := c.GetOrUpdate(ctx, []byte("k1"), update)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, []byte("v1"), v)
|
||||||
|
assert.Equal(t, 1, callCount)
|
||||||
|
|
||||||
|
v, err = c.GetOrUpdate(ctx, []byte("k1"), update)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, []byte("v1"), v)
|
||||||
|
assert.Equal(t, 1, callCount)
|
||||||
|
|
||||||
|
c.Invalidate([]byte("k1"))
|
||||||
|
|
||||||
|
v, err = c.GetOrUpdate(ctx, []byte("k1"), update)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, []byte("v1"), v)
|
||||||
|
assert.Equal(t, 2, callCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGlobalCache(t *testing.T) {
|
||||||
|
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||||
|
defer clearTimeout()
|
||||||
|
|
||||||
|
callCount := 0
|
||||||
|
update := func(ctx context.Context) ([]byte, error) {
|
||||||
|
callCount++
|
||||||
|
return []byte("v1"), nil
|
||||||
|
}
|
||||||
|
c := NewGlobalCache(time.Millisecond * 100)
|
||||||
|
v, err := c.GetOrUpdate(ctx, []byte("k1"), update)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, []byte("v1"), v)
|
||||||
|
assert.Equal(t, 1, callCount)
|
||||||
|
|
||||||
|
v, err = c.GetOrUpdate(ctx, []byte("k1"), update)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, []byte("v1"), v)
|
||||||
|
assert.Equal(t, 1, callCount)
|
||||||
|
|
||||||
|
c.Invalidate([]byte("k1"))
|
||||||
|
|
||||||
|
v, err = c.GetOrUpdate(ctx, []byte("k1"), update)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, []byte("v1"), v)
|
||||||
|
assert.Equal(t, 2, callCount)
|
||||||
|
|
||||||
|
assert.Eventually(t, func() bool {
|
||||||
|
_, err := c.GetOrUpdate(ctx, []byte("k1"), func(ctx context.Context) ([]byte, error) {
|
||||||
|
return nil, fmt.Errorf("ERROR")
|
||||||
|
})
|
||||||
|
return err != nil
|
||||||
|
}, time.Second, time.Millisecond*10, "should honor TTL")
|
||||||
|
}
|
210
pkg/storage/querier.go
Normal file
210
pkg/storage/querier.go
Normal file
|
@ -0,0 +1,210 @@
|
||||||
|
package storage
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
grpc "google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
status "google.golang.org/grpc/status"
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
|
"google.golang.org/protobuf/types/known/structpb"
|
||||||
|
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A Querier is a read-only subset of the client methods
|
||||||
|
type Querier interface {
|
||||||
|
Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// nilQuerier always returns NotFound.
|
||||||
|
type nilQuerier struct{}
|
||||||
|
|
||||||
|
func (nilQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
|
||||||
|
return nil, status.Error(codes.NotFound, "not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
var querierKey struct{}
|
||||||
|
|
||||||
|
// GetQuerier gets the databroker Querier from the context.
|
||||||
|
func GetQuerier(ctx context.Context) Querier {
|
||||||
|
q, ok := ctx.Value(querierKey).(Querier)
|
||||||
|
if !ok {
|
||||||
|
q = nilQuerier{}
|
||||||
|
}
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithQuerier sets the databroker Querier on a context.
|
||||||
|
func WithQuerier(ctx context.Context, querier Querier) context.Context {
|
||||||
|
return context.WithValue(ctx, querierKey, querier)
|
||||||
|
}
|
||||||
|
|
||||||
|
type staticQuerier struct {
|
||||||
|
records []*databroker.Record
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewStaticQuerier creates a Querier that returns statically defined protobuf records.
|
||||||
|
func NewStaticQuerier(msgs ...proto.Message) Querier {
|
||||||
|
getter := &staticQuerier{}
|
||||||
|
for _, msg := range msgs {
|
||||||
|
any := protoutil.NewAny(msg)
|
||||||
|
record := new(databroker.Record)
|
||||||
|
record.ModifiedAt = timestamppb.Now()
|
||||||
|
record.Version = cryptutil.NewRandomUInt64()
|
||||||
|
record.Id = uuid.New().String()
|
||||||
|
record.Data = any
|
||||||
|
record.Type = any.TypeUrl
|
||||||
|
if hasID, ok := msg.(interface{ GetId() string }); ok {
|
||||||
|
record.Id = hasID.GetId()
|
||||||
|
}
|
||||||
|
getter.records = append(getter.records, record)
|
||||||
|
}
|
||||||
|
return getter
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query queries for records.
|
||||||
|
func (q *staticQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
|
||||||
|
expr, err := FilterExpressionFromStruct(in.GetFilter())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
filter, err := RecordStreamFilterFromFilterExpression(expr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
res := new(databroker.QueryResponse)
|
||||||
|
for _, record := range q.records {
|
||||||
|
if record.GetType() != in.GetType() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !filter(record) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if in.GetQuery() != "" && !MatchAny(record.GetData(), in.GetQuery()) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
res.Records = append(res.Records, record)
|
||||||
|
}
|
||||||
|
|
||||||
|
var total int
|
||||||
|
res.Records, total = databroker.ApplyOffsetAndLimit(
|
||||||
|
res.Records,
|
||||||
|
int(in.GetOffset()),
|
||||||
|
int(in.GetLimit()),
|
||||||
|
)
|
||||||
|
res.TotalCount = int64(total)
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type clientQuerier struct {
|
||||||
|
client databroker.DataBrokerServiceClient
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewQuerier creates a new Querier that implements the Querier interface by making calls to the databroker over gRPC.
|
||||||
|
func NewQuerier(client databroker.DataBrokerServiceClient) Querier {
|
||||||
|
return &clientQuerier{client: client}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query queries for records.
|
||||||
|
func (q *clientQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
|
||||||
|
return q.client.Query(ctx, in, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// A TracingQuerier records calls to Query.
|
||||||
|
type TracingQuerier struct {
|
||||||
|
underlying Querier
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
traces []QueryTrace
|
||||||
|
}
|
||||||
|
|
||||||
|
// A QueryTrace traces a call to Query.
|
||||||
|
type QueryTrace struct {
|
||||||
|
ServerVersion, RecordVersion uint64
|
||||||
|
|
||||||
|
RecordType string
|
||||||
|
Query string
|
||||||
|
Filter *structpb.Struct
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTracingQuerier creates a new TracingQuerier.
|
||||||
|
func NewTracingQuerier(q Querier) *TracingQuerier {
|
||||||
|
return &TracingQuerier{
|
||||||
|
underlying: q,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query queries for records.
|
||||||
|
func (q *TracingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
|
||||||
|
res, err := q.underlying.Query(ctx, in, opts...)
|
||||||
|
if err == nil {
|
||||||
|
q.mu.Lock()
|
||||||
|
q.traces = append(q.traces, QueryTrace{
|
||||||
|
RecordType: in.GetType(),
|
||||||
|
Query: in.GetQuery(),
|
||||||
|
Filter: in.GetFilter(),
|
||||||
|
})
|
||||||
|
q.mu.Unlock()
|
||||||
|
}
|
||||||
|
return res, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Traces returns all the traces.
|
||||||
|
func (q *TracingQuerier) Traces() []QueryTrace {
|
||||||
|
q.mu.Lock()
|
||||||
|
traces := make([]QueryTrace, len(q.traces))
|
||||||
|
copy(traces, q.traces)
|
||||||
|
q.mu.Unlock()
|
||||||
|
return traces
|
||||||
|
}
|
||||||
|
|
||||||
|
type cachingQuerier struct {
|
||||||
|
q Querier
|
||||||
|
cache Cache
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCachingQuerier creates a new querier that caches results in a Cache.
|
||||||
|
func NewCachingQuerier(q Querier, cache Cache) Querier {
|
||||||
|
return &cachingQuerier{
|
||||||
|
q: q,
|
||||||
|
cache: cache,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *cachingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
|
||||||
|
key, err := (&proto.MarshalOptions{
|
||||||
|
Deterministic: true,
|
||||||
|
}).Marshal(in)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
rawResult, err := q.cache.GetOrUpdate(ctx, key, func(ctx context.Context) ([]byte, error) {
|
||||||
|
res, err := q.q.Query(ctx, in, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return proto.Marshal(res)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var res databroker.QueryResponse
|
||||||
|
err = proto.Unmarshal(rawResult, &res)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &res, nil
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue