Merge branch 'main' into kralicky/pgx-tracing

Joe Kralicky 2025-01-24 14:56:08 -05:00 committed by GitHub
commit d1b8e3b92f
25 changed files with 938 additions and 216 deletions

@ -6,11 +6,13 @@ import (
"context"
"fmt"
"slices"
"sync"
"time"
"github.com/rs/zerolog"
oteltrace "go.opentelemetry.io/otel/trace"
"golang.org/x/sync/errgroup"
"github.com/pomerium/datasource/pkg/directory"
"github.com/pomerium/pomerium/authorize/evaluator"
"github.com/pomerium/pomerium/authorize/internal/store"
"github.com/pomerium/pomerium/config"
@ -21,21 +23,16 @@ import (
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage"
oteltrace "go.opentelemetry.io/otel/trace"
)
// Authorize struct holds
type Authorize struct {
state *atomicutil.Value[*authorizeState]
store *store.Store
currentOptions *atomicutil.Value[*config.Options]
accessTracker *AccessTracker
globalCache storage.Cache
// The stateLock prevents updating the evaluator store simultaneously with an evaluation.
// This should provide a consistent view of the data at a given server/record version and
// avoid partial updates.
stateLock sync.RWMutex
state *atomicutil.Value[*authorizeState]
store *store.Store
currentOptions *atomicutil.Value[*config.Options]
accessTracker *AccessTracker
globalCache storage.Cache
groupsCacheWarmer *cacheWarmer
tracerProvider oteltrace.TracerProvider
tracer oteltrace.Tracer
@ -60,6 +57,7 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
}
a.state = atomicutil.NewValue(state)
a.groupsCacheWarmer = newCacheWarmer(state.dataBrokerClientConnection, a.globalCache, directory.GroupRecordType)
return a, nil
}
@ -70,8 +68,16 @@ func (a *Authorize) GetDataBrokerServiceClient() databroker.DataBrokerServiceCli
// Run runs the authorize service.
func (a *Authorize) Run(ctx context.Context) error {
a.accessTracker.Run(ctx)
return nil
eg, ctx := errgroup.WithContext(ctx)
eg.Go(func() error {
a.accessTracker.Run(ctx)
return nil
})
eg.Go(func() error {
a.groupsCacheWarmer.Run(ctx)
return nil
})
return eg.Wait()
}
func validateOptions(o *config.Options) error {
@ -150,9 +156,13 @@ func newPolicyEvaluator(
func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
currentState := a.state.Load()
a.currentOptions.Store(cfg.Options)
if state, err := newAuthorizeStateFromConfig(ctx, a.tracerProvider, cfg, a.store, currentState.evaluator); err != nil {
if newState, err := newAuthorizeStateFromConfig(ctx, a.tracerProvider, cfg, a.store, currentState.evaluator); err != nil {
log.Ctx(ctx).Error().Err(err).Msg("authorize: error updating state")
} else {
a.state.Store(state)
a.state.Store(newState)
if currentState.dataBrokerClientConnection != newState.dataBrokerClientConnection {
a.groupsCacheWarmer.UpdateConn(newState.dataBrokerClientConnection)
}
}
}

authorize/cache_warmer.go (new file, 122 lines)

@ -0,0 +1,122 @@
package authorize
import (
"context"
"time"
"google.golang.org/grpc"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage"
)
type cacheWarmer struct {
cc *grpc.ClientConn
cache storage.Cache
typeURL string
updatedCC chan *grpc.ClientConn
}
func newCacheWarmer(
cc *grpc.ClientConn,
cache storage.Cache,
typeURL string,
) *cacheWarmer {
return &cacheWarmer{
cc: cc,
cache: cache,
typeURL: typeURL,
updatedCC: make(chan *grpc.ClientConn, 1),
}
}
func (cw *cacheWarmer) UpdateConn(cc *grpc.ClientConn) {
for {
select {
case cw.updatedCC <- cc:
return
default:
}
select {
case <-cw.updatedCC:
default:
}
}
}
func (cw *cacheWarmer) Run(ctx context.Context) {
// Run a syncer for the cache warmer until the underlying databroker connection is changed.
// When that happens cancel the currently running syncer and start a new one.
runCtx, runCancel := context.WithCancel(ctx)
go cw.run(runCtx, cw.cc)
for {
select {
case <-ctx.Done():
runCancel()
return
case cc := <-cw.updatedCC:
log.Ctx(ctx).Info().Msg("cache-warmer: received updated databroker client connection, restarting syncer")
cw.cc = cc
runCancel()
runCtx, runCancel = context.WithCancel(ctx)
go cw.run(runCtx, cw.cc)
}
}
}
func (cw *cacheWarmer) run(ctx context.Context, cc *grpc.ClientConn) {
log.Ctx(ctx).Debug().Str("type-url", cw.typeURL).Msg("cache-warmer: running databroker syncer to warm cache")
_ = databroker.NewSyncer(ctx, "cache-warmer", cacheWarmerSyncerHandler{
client: databroker.NewDataBrokerServiceClient(cc),
cache: cw.cache,
}, databroker.WithTypeURL(cw.typeURL)).Run(ctx)
}
type cacheWarmerSyncerHandler struct {
client databroker.DataBrokerServiceClient
cache storage.Cache
}
func (h cacheWarmerSyncerHandler) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
return h.client
}
func (h cacheWarmerSyncerHandler) ClearRecords(_ context.Context) {
h.cache.InvalidateAll()
}
func (h cacheWarmerSyncerHandler) UpdateRecords(ctx context.Context, serverVersion uint64, records []*databroker.Record) {
for _, record := range records {
req := &databroker.QueryRequest{
Type: record.Type,
Limit: 1,
}
req.SetFilterByIDOrIndex(record.Id)
res := &databroker.QueryResponse{
Records: []*databroker.Record{record},
TotalCount: 1,
ServerVersion: serverVersion,
RecordVersion: record.Version,
}
expiry := time.Now().Add(time.Hour * 24 * 365)
key, err := storage.MarshalQueryRequest(req)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("cache-warmer: failed to marshal query request")
continue
}
value, err := storage.MarshalQueryResponse(res)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("cache-warmer: failed to marshal query response")
continue
}
h.cache.Set(expiry, key, value)
}
}
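
UpdateConn above uses a small "latest value wins" idiom: a one-slot buffered channel is either filled with the new connection or drained first, so the newest value always replaces any pending one and the caller never blocks. A minimal standalone sketch of the same idiom (the latest type and main wiring are illustrative, not part of this change):

package main

import "fmt"

// latest holds at most one pending value; Send never blocks and always
// leaves the most recently sent value in the channel.
type latest[T any] struct{ ch chan T }

func newLatest[T any]() *latest[T] { return &latest[T]{ch: make(chan T, 1)} }

func (l *latest[T]) Send(v T) {
	for {
		select {
		case l.ch <- v: // slot was free: value stored
			return
		default:
		}
		select {
		case <-l.ch: // slot was full: drop the stale value and retry
		default:
		}
	}
}

func main() {
	l := newLatest[int]()
	l.Send(1)
	l.Send(2)           // replaces the pending 1
	fmt.Println(<-l.ch) // prints 2
}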

@ -0,0 +1,52 @@
package authorize
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/trace/noop"
"google.golang.org/grpc"
"github.com/pomerium/pomerium/internal/databroker"
"github.com/pomerium/pomerium/internal/testutil"
databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/protoutil"
"github.com/pomerium/pomerium/pkg/storage"
)
func TestCacheWarmer(t *testing.T) {
t.Parallel()
ctx := testutil.GetContext(t, 10*time.Minute)
cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) {
databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New(ctx, noop.NewTracerProvider()))
})
t.Cleanup(func() { cc.Close() })
client := databrokerpb.NewDataBrokerServiceClient(cc)
_, err := client.Put(ctx, &databrokerpb.PutRequest{
Records: []*databrokerpb.Record{
{Type: "example.com/record", Id: "e1", Data: protoutil.NewAnyBool(true)},
{Type: "example.com/record", Id: "e2", Data: protoutil.NewAnyBool(true)},
},
})
require.NoError(t, err)
cache := storage.NewGlobalCache(time.Minute)
cw := newCacheWarmer(cc, cache, "example.com/record")
go cw.Run(ctx)
assert.Eventually(t, func() bool {
req := &databrokerpb.QueryRequest{
Type: "example.com/record",
Limit: 1,
}
req.SetFilterByIDOrIndex("e1")
res, err := storage.NewCachingQuerier(storage.NewStaticQuerier(), cache).Query(ctx, req)
require.NoError(t, err)
return len(res.GetRecords()) == 1
}, 10*time.Second, time.Millisecond*100)
}

@ -37,10 +37,8 @@ func Test_getDataBrokerRecord(t *testing.T) {
s1 := &session.Session{Id: "s1", Version: fmt.Sprint(tc.recordVersion)}
sq := storage.NewStaticQuerier(s1)
tsq := storage.NewTracingQuerier(sq)
cq := storage.NewCachingQuerier(tsq, storage.NewLocalCache())
tcq := storage.NewTracingQuerier(cq)
qctx := storage.WithQuerier(ctx, tcq)
cq := storage.NewCachingQuerier(sq, storage.NewGlobalCache(time.Minute))
qctx := storage.WithQuerier(ctx, cq)
s, err := getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion)
assert.NoError(t, err)
@ -49,11 +47,6 @@ func Test_getDataBrokerRecord(t *testing.T) {
s, err = getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion)
assert.NoError(t, err)
assert.NotNil(t, s)
assert.Len(t, tsq.Traces(), tc.underlyingQueryCount,
"should have %d traces to the underlying querier", tc.underlyingQueryCount)
assert.Len(t, tcq.Traces(), tc.cachedQueryCount,
"should have %d traces to the cached querier", tc.cachedQueryCount)
})
}
}

@ -241,6 +241,7 @@ var internalPathsNeedingLogin = set.From([]string{
"/.pomerium/jwt",
"/.pomerium/user",
"/.pomerium/webauthn",
"/.pomerium/api/v1/routes",
})
func (e *Evaluator) evaluateInternal(_ context.Context, req *Request) (*PolicyResponse, error) {

@ -31,14 +31,9 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
ctx, span := a.tracer.Start(ctx, "authorize.grpc.Check")
defer span.End()
querier := storage.NewTracingQuerier(
storage.NewCachingQuerier(
storage.NewCachingQuerier(
storage.NewQuerier(a.state.Load().dataBrokerClient),
a.globalCache,
),
storage.NewLocalCache(),
),
querier := storage.NewCachingQuerier(
storage.NewQuerier(a.state.Load().dataBrokerClient),
a.globalCache,
)
ctx = storage.WithQuerier(ctx, querier)
@ -74,10 +69,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
return nil, err
}
// take the state lock here so we don't update while evaluating
a.stateLock.RLock()
res, err := state.evaluator.Evaluate(ctx, req)
a.stateLock.RUnlock()
if err != nil {
log.Ctx(ctx).Error().Err(err).Str("request-id", requestID).Msg("error during OPA evaluation")
return nil, err

@ -150,8 +150,8 @@ func (b *Builder) BuildBootstrapDynamicResources(
// BuildBootstrapLayeredRuntime builds the layered runtime for the envoy bootstrap.
func (b *Builder) BuildBootstrapLayeredRuntime(ctx context.Context) (*envoy_config_bootstrap_v3.LayeredRuntime, error) {
flushIntervalMs := 5000
minFlushSpans := 3
flushIntervalMs := trace.BatchSpanProcessorScheduleDelay()
minFlushSpans := trace.BatchSpanProcessorMaxExportBatchSize()
if trace.DebugFlagsFromContext(ctx).Check(trace.EnvoyFlushEverySpan) {
minFlushSpans = 1
flushIntervalMs = math.MaxInt32
@ -166,15 +166,12 @@ func (b *Builder) BuildBootstrapLayeredRuntime(ctx context.Context) (*envoy_conf
"tracing": map[string]any{
"opentelemetry": map[string]any{
"flush_interval_ms": flushIntervalMs,
// For most requests, envoy generates 3 spans:
// Note: for most requests, envoy generates 3 spans:
// - ingress (downstream->envoy)
// - ext_authz check request (envoy->pomerium)
// - egress (envoy->upstream)
// The default value is 5, which usually leads to delayed exports.
// This can be set lower, e.g. 1 to have envoy export every span
// individually (useful for testing), but 3 is a reasonable default.
// If set to 1, also set flush_interval_ms to a very large number to
// effectively disable it.
// Some requests only generate 2 spans, e.g. if there is no upstream
// request made or auth fails.
"min_flush_spans": minFlushSpans,
},
},

@ -51,7 +51,7 @@ func TestBuilder_BuildBootstrapLayeredRuntime(t *testing.T) {
"tracing": {
"opentelemetry": {
"flush_interval_ms": 5000,
"min_flush_spans": 3
"min_flush_spans": 512
}
}
}

@ -9,6 +9,7 @@ import (
"testing"
"time"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/testenv"
"github.com/pomerium/pomerium/internal/testenv/envutil"
"github.com/pomerium/pomerium/internal/testenv/scenarios"
@ -18,18 +19,30 @@ import (
)
var (
numRoutes int
dumpErrLogs bool
numRoutes int
dumpErrLogs bool
enableTracing bool
publicRoutes bool
)
func init() {
flag.IntVar(&numRoutes, "routes", 100, "number of routes")
flag.BoolVar(&dumpErrLogs, "dump-err-logs", false, "if the test fails, write all captured logs to a file (testdata/<test-name>)")
flag.BoolVar(&enableTracing, "enable-tracing", false, "enable tracing")
flag.BoolVar(&publicRoutes, "public-routes", false, "use public unauthenticated routes")
}
func TestRequestLatency(t *testing.T) {
resume := envutil.PauseProfiling(t)
env := testenv.New(t, testenv.Silent())
var env testenv.Environment
if enableTracing {
receiver := scenarios.NewOTLPTraceReceiver()
env = testenv.New(t, testenv.Silent(), testenv.WithTraceClient(receiver.NewGRPCClient()))
env.Add(receiver)
} else {
env = testenv.New(t, testenv.Silent())
}
users := []*scenarios.User{}
for i := range numRoutes {
users = append(users, &scenarios.User{
@ -47,9 +60,12 @@ func TestRequestLatency(t *testing.T) {
routes := make([]testenv.Route, numRoutes)
for i := range numRoutes {
routes[i] = up.Route().
From(env.SubdomainURL(fmt.Sprintf("from-%d", i))).
// Policy(func(p *config.Policy) { p.AllowPublicUnauthenticatedAccess = true })
PPL(fmt.Sprintf(`{"allow":{"and":["email":{"is":"user%d@example.com"}]}}`, i))
From(env.SubdomainURL(fmt.Sprintf("from-%d", i)))
if publicRoutes {
routes[i] = routes[i].Policy(func(p *config.Policy) { p.AllowPublicUnauthenticatedAccess = true })
} else {
routes[i] = routes[i].PPL(fmt.Sprintf(`{"allow":{"and":["email":{"is":"user%d@example.com"}]}}`, i))
}
}
env.AddUpstream(up)

@ -2,9 +2,12 @@ package trace
import (
"context"
"os"
"strconv"
"go.opentelemetry.io/contrib/propagators/autoprop"
"go.opentelemetry.io/otel"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
"go.opentelemetry.io/otel/trace"
"go.opentelemetry.io/otel/trace/embedded"
)
@ -44,3 +47,25 @@ var _ trace.Tracer = panicTracer{}
func (p panicTracer) Start(context.Context, string, ...trace.SpanStartOption) (context.Context, trace.Span) {
panic("global tracer used")
}
// functions below mimic those with the same name in otel/sdk/internal/env/env.go
func BatchSpanProcessorScheduleDelay() int {
const defaultValue = sdktrace.DefaultScheduleDelay
if v, ok := os.LookupEnv("OTEL_BSP_SCHEDULE_DELAY"); ok {
if n, err := strconv.Atoi(v); err == nil {
return n
}
}
return defaultValue
}
func BatchSpanProcessorMaxExportBatchSize() int {
const defaultValue = sdktrace.DefaultMaxExportBatchSize
if v, ok := os.LookupEnv("OTEL_BSP_MAX_EXPORT_BATCH_SIZE"); ok {
if n, err := strconv.Atoi(v); err == nil {
return n
}
}
return defaultValue
}
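
BatchSpanProcessorScheduleDelay and BatchSpanProcessorMaxExportBatchSize mirror the unexported lookups in otel/sdk/internal/env and feed the envoy bootstrap values above, so the standard OTEL_BSP_* environment variables now drive both the SDK batch span processor and envoy's flush_interval_ms / min_flush_spans. A short sketch of the lookup-with-fallback behavior (the intFromEnv helper and main function are illustrative only):

package main

import (
	"fmt"
	"os"
	"strconv"

	sdktrace "go.opentelemetry.io/otel/sdk/trace"
)

// intFromEnv returns the environment variable if it parses as an
// integer, otherwise the supplied default.
func intFromEnv(key string, def int) int {
	if v, ok := os.LookupEnv(key); ok {
		if n, err := strconv.Atoi(v); err == nil {
			return n
		}
	}
	return def
}

func main() {
	// Without overrides these print the SDK defaults: 5000 and 512.
	fmt.Println(intFromEnv("OTEL_BSP_SCHEDULE_DELAY", sdktrace.DefaultScheduleDelay))
	fmt.Println(intFromEnv("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", sdktrace.DefaultMaxExportBatchSize))

	// Overriding the variables changes the computed envoy flush settings too,
	// since BuildBootstrapLayeredRuntime now reads the same values.
	os.Setenv("OTEL_BSP_SCHEDULE_DELAY", "1000")
	fmt.Println(intFromEnv("OTEL_BSP_SCHEDULE_DELAY", sdktrace.DefaultScheduleDelay)) // 1000
}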

@ -50,7 +50,7 @@ func NewServer(ctx context.Context) *ExporterServer {
}
func (srv *ExporterServer) Start(ctx context.Context) {
lis := bufconn.Listen(4096)
lis := bufconn.Listen(2 * 1024 * 1024)
go func() {
if err := srv.remoteClient.Start(ctx); err != nil {
panic(err)
@ -95,5 +95,6 @@ func (srv *ExporterServer) Shutdown(ctx context.Context) error {
if err := srv.remoteClient.Stop(ctx); err != nil {
errs = append(errs, err)
}
srv.cc.Close()
return errors.Join(errs...)
}

@ -77,6 +77,7 @@ func (rec *OTLPTraceReceiver) Attach(ctx context.Context) {
// Modify implements testenv.Modifier.
func (rec *OTLPTraceReceiver) Modify(cfg *config.Config) {
cfg.Options.TracingProvider = "otlp"
cfg.Options.TracingOTLPEndpoint = rec.GRPCEndpointURL().Value()
}
func (rec *OTLPTraceReceiver) handleV1Traces(w http.ResponseWriter, r *http.Request) {

@ -12,13 +12,6 @@ import (
"testing"
"time"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/internal/testenv"
"github.com/pomerium/pomerium/internal/testenv/scenarios"
"github.com/pomerium/pomerium/internal/testenv/snippets"
"github.com/pomerium/pomerium/internal/testenv/upstreams"
. "github.com/pomerium/pomerium/internal/testutil/tracetest" //nolint:revive
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
@ -27,6 +20,14 @@ import (
sdktrace "go.opentelemetry.io/otel/sdk/trace"
semconv "go.opentelemetry.io/otel/semconv/v1.17.0"
oteltrace "go.opentelemetry.io/otel/trace"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/internal/testenv"
"github.com/pomerium/pomerium/internal/testenv/scenarios"
"github.com/pomerium/pomerium/internal/testenv/snippets"
"github.com/pomerium/pomerium/internal/testenv/upstreams"
. "github.com/pomerium/pomerium/internal/testutil/tracetest" //nolint:revive
)
func otlpTraceReceiverOrFromEnv(t *testing.T) (modifier testenv.Modifier, newRemoteClient func() otlptrace.Client, getResults func() *TraceResults) {
@ -116,7 +117,7 @@ func TestOTLPTracing(t *testing.T) {
Exact: true,
CheckDetachedSpans: true,
},
Match{Name: testEnvironmentLocalTest, TraceCount: 1, Services: []string{"Test Environment", "Control Plane", "Data Broker"}},
Match{Name: testEnvironmentLocalTest, TraceCount: 1, Services: []string{"Authorize", "Test Environment", "Control Plane", "Data Broker"}},
Match{Name: testEnvironmentAuthenticate, TraceCount: 1, Services: allServices},
Match{Name: authenticateOAuth2Client, TraceCount: Greater(0)},
Match{Name: idpServerGetUserinfo, TraceCount: EqualToMatch(authenticateOAuth2Client)},

@ -18,64 +18,8 @@ type Cache interface {
update func(ctx context.Context) ([]byte, error),
) ([]byte, error)
Invalidate(key []byte)
}
type localCache struct {
singleflight singleflight.Group
mu sync.RWMutex
m map[string][]byte
}
// NewLocalCache creates a new Cache backed by a map.
func NewLocalCache() Cache {
return &localCache{
m: make(map[string][]byte),
}
}
func (cache *localCache) GetOrUpdate(
ctx context.Context,
key []byte,
update func(ctx context.Context) ([]byte, error),
) ([]byte, error) {
strkey := string(key)
cache.mu.RLock()
cached, ok := cache.m[strkey]
cache.mu.RUnlock()
if ok {
return cached, nil
}
v, err, _ := cache.singleflight.Do(strkey, func() (any, error) {
cache.mu.RLock()
cached, ok := cache.m[strkey]
cache.mu.RUnlock()
if ok {
return cached, nil
}
result, err := update(ctx)
if err != nil {
return nil, err
}
cache.mu.Lock()
cache.m[strkey] = result
cache.mu.Unlock()
return result, nil
})
if err != nil {
return nil, err
}
return v.([]byte), nil
}
func (cache *localCache) Invalidate(key []byte) {
cache.mu.Lock()
delete(cache.m, string(key))
cache.mu.Unlock()
InvalidateAll()
Set(expiry time.Time, key, value []byte)
}
type globalCache struct {
@ -115,7 +59,7 @@ func (cache *globalCache) GetOrUpdate(
if err != nil {
return nil, err
}
cache.set(key, value)
cache.set(time.Now().Add(cache.ttl), key, value)
return value, nil
})
if err != nil {
@ -130,6 +74,16 @@ func (cache *globalCache) Invalidate(key []byte) {
cache.mu.Unlock()
}
func (cache *globalCache) InvalidateAll() {
cache.mu.Lock()
cache.fastcache.Reset()
cache.mu.Unlock()
}
func (cache *globalCache) Set(expiry time.Time, key, value []byte) {
cache.set(expiry, key, value)
}
func (cache *globalCache) get(k []byte) (data []byte, expiry time.Time, ok bool) {
cache.mu.RLock()
item := cache.fastcache.Get(nil, k)
@ -143,13 +97,13 @@ func (cache *globalCache) get(k []byte) (data []byte, expiry time.Time, ok bool)
return data, expiry, true
}
func (cache *globalCache) set(k, v []byte) {
unix := time.Now().Add(cache.ttl).UnixMilli()
item := make([]byte, len(v)+8)
func (cache *globalCache) set(expiry time.Time, key, value []byte) {
unix := expiry.UnixMilli()
item := make([]byte, len(value)+8)
binary.LittleEndian.PutUint64(item, uint64(unix))
copy(item[8:], v)
copy(item[8:], value)
cache.mu.Lock()
cache.fastcache.Set(k, item)
cache.fastcache.Set(key, item)
cache.mu.Unlock()
}
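
For reference, globalCache stores each entry as the value prefixed with an 8-byte little-endian unix-millisecond expiry: GetOrUpdate derives the expiry from the cache TTL, while the new Set lets callers such as the cache warmer choose their own (roughly one year above). A minimal sketch of that layout (encode and decode are illustrative helpers, not part of the package):

package main

import (
	"encoding/binary"
	"fmt"
	"time"
)

// encode prefixes a value with its expiry as 8 little-endian bytes of
// unix milliseconds, matching globalCache.set.
func encode(expiry time.Time, value []byte) []byte {
	item := make([]byte, len(value)+8)
	binary.LittleEndian.PutUint64(item, uint64(expiry.UnixMilli()))
	copy(item[8:], value)
	return item
}

// decode reverses encode and reports whether the entry is still live.
func decode(item []byte, now time.Time) (value []byte, live bool) {
	if len(item) < 8 {
		return nil, false
	}
	expiry := time.UnixMilli(int64(binary.LittleEndian.Uint64(item)))
	return item[8:], now.Before(expiry)
}

func main() {
	item := encode(time.Now().Add(time.Hour), []byte("v1"))
	v, live := decode(item, time.Now())
	fmt.Println(string(v), live) // v1 true
}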

@ -9,34 +9,6 @@ import (
"github.com/stretchr/testify/assert"
)
func TestLocalCache(t *testing.T) {
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
defer clearTimeout()
callCount := 0
update := func(_ context.Context) ([]byte, error) {
callCount++
return []byte("v1"), nil
}
c := NewLocalCache()
v, err := c.GetOrUpdate(ctx, []byte("k1"), update)
assert.NoError(t, err)
assert.Equal(t, []byte("v1"), v)
assert.Equal(t, 1, callCount)
v, err = c.GetOrUpdate(ctx, []byte("k1"), update)
assert.NoError(t, err)
assert.Equal(t, []byte("v1"), v)
assert.Equal(t, 1, callCount)
c.Invalidate([]byte("k1"))
v, err = c.GetOrUpdate(ctx, []byte("k1"), update)
assert.NoError(t, err)
assert.Equal(t, []byte("v1"), v)
assert.Equal(t, 2, callCount)
}
func TestGlobalCache(t *testing.T) {
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
defer clearTimeout()
@ -70,4 +42,10 @@ func TestGlobalCache(t *testing.T) {
})
return err != nil
}, time.Second, time.Millisecond*10, "should honor TTL")
c.Set(time.Now().Add(time.Hour), []byte("k2"), []byte("v2"))
v, err = c.GetOrUpdate(ctx, []byte("k2"), update)
assert.NoError(t, err)
assert.Equal(t, []byte("v2"), v)
assert.Equal(t, 2, callCount)
}

@ -4,7 +4,6 @@ import (
"context"
"encoding/json"
"strconv"
"sync"
"github.com/google/uuid"
grpc "google.golang.org/grpc"
@ -12,7 +11,6 @@ import (
status "google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/structpb"
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/pkg/cryptutil"
@ -162,59 +160,6 @@ func (q *clientQuerier) Query(ctx context.Context, in *databroker.QueryRequest,
return q.client.Query(ctx, in, opts...)
}
// A TracingQuerier records calls to Query.
type TracingQuerier struct {
underlying Querier
mu sync.Mutex
traces []QueryTrace
}
// A QueryTrace traces a call to Query.
type QueryTrace struct {
ServerVersion, RecordVersion uint64
RecordType string
Query string
Filter *structpb.Struct
}
// NewTracingQuerier creates a new TracingQuerier.
func NewTracingQuerier(q Querier) *TracingQuerier {
return &TracingQuerier{
underlying: q,
}
}
// InvalidateCache invalidates the cache.
func (q *TracingQuerier) InvalidateCache(ctx context.Context, in *databroker.QueryRequest) {
q.underlying.InvalidateCache(ctx, in)
}
// Query queries for records.
func (q *TracingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
res, err := q.underlying.Query(ctx, in, opts...)
if err == nil {
q.mu.Lock()
q.traces = append(q.traces, QueryTrace{
RecordType: in.GetType(),
Query: in.GetQuery(),
Filter: in.GetFilter(),
})
q.mu.Unlock()
}
return res, err
}
// Traces returns all the traces.
func (q *TracingQuerier) Traces() []QueryTrace {
q.mu.Lock()
traces := make([]QueryTrace, len(q.traces))
copy(traces, q.traces)
q.mu.Unlock()
return traces
}
type cachingQuerier struct {
q Querier
cache Cache
@ -240,9 +185,7 @@ func (q *cachingQuerier) InvalidateCache(ctx context.Context, in *databroker.Que
}
func (q *cachingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
key, err := (&proto.MarshalOptions{
Deterministic: true,
}).Marshal(in)
key, err := MarshalQueryRequest(in)
if err != nil {
return nil, err
}
@ -252,7 +195,7 @@ func (q *cachingQuerier) Query(ctx context.Context, in *databroker.QueryRequest,
if err != nil {
return nil, err
}
return proto.Marshal(res)
return MarshalQueryResponse(res)
})
if err != nil {
return nil, err
@ -265,3 +208,17 @@ func (q *cachingQuerier) Query(ctx context.Context, in *databroker.QueryRequest,
}
return &res, nil
}
// MarshalQueryRequest marshals the query request.
func MarshalQueryRequest(req *databroker.QueryRequest) ([]byte, error) {
return (&proto.MarshalOptions{
Deterministic: true,
}).Marshal(req)
}
// MarshalQueryResponse marshals the query response.
func MarshalQueryResponse(res *databroker.QueryResponse) ([]byte, error) {
return (&proto.MarshalOptions{
Deterministic: true,
}).Marshal(res)
}
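
Exporting MarshalQueryRequest and MarshalQueryResponse is what ties the cache warmer to the caching querier: the warmer stores each synced record under the same deterministic key that a later Query call computes, so the first authorize check for that record is a warm-cache hit. A small sketch of that key equality (the type URL and record ID are example values only):

package main

import (
	"bytes"
	"fmt"

	"github.com/pomerium/pomerium/pkg/grpc/databroker"
	"github.com/pomerium/pomerium/pkg/storage"
)

func main() {
	// The cache warmer builds this request for each record it syncs...
	warm := &databroker.QueryRequest{Type: "example.com/record", Limit: 1}
	warm.SetFilterByIDOrIndex("e1")

	// ...and the caching querier builds an equivalent request at query time,
	// so deterministic marshaling yields the same cache key.
	query := &databroker.QueryRequest{Type: "example.com/record", Limit: 1}
	query.SetFilterByIDOrIndex("e1")

	k1, _ := storage.MarshalQueryRequest(warm)
	k2, _ := storage.MarshalQueryRequest(query)
	fmt.Println(bytes.Equal(k1, k2)) // true: warmed entries become cache hits
}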

@ -40,11 +40,32 @@ func (p *Proxy) registerDashboardHandlers(r *mux.Router, opts *config.Options) *
c.Path("/").Handler(httputil.HandlerFunc(p.Callback)).Methods(http.MethodGet)
// Programmatic API handlers and middleware
a := r.PathPrefix(dashboardPath + "/api").Subrouter()
// login api handler generates a user-navigable login url to authenticate
a.Path("/v1/login").Handler(httputil.HandlerFunc(p.ProgrammaticLogin)).
Queries(urlutil.QueryRedirectURI, "").
Methods(http.MethodGet)
// gorilla mux has a bug that prevents HTTP 405 errors from being returned properly so we do all this manually
// https://github.com/gorilla/mux/issues/739
r.PathPrefix(dashboardPath + "/api").
Handler(httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
switch r.URL.Path {
// login api handler generates a user-navigable login url to authenticate
case dashboardPath + "/api/v1/login":
if r.Method != http.MethodGet {
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
return nil
}
if !r.URL.Query().Has(urlutil.QueryRedirectURI) {
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
return nil
}
return p.ProgrammaticLogin(w, r)
case dashboardPath + "/api/v1/routes":
if r.Method != http.MethodGet {
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
return nil
}
return p.routesPortalJSON(w, r)
}
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
return nil
}))
return r
}
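
Because of the gorilla/mux issue referenced above, method mismatches under the dashboard /api prefix are dispatched by hand so clients receive a proper 405 rather than a 404. A standalone sketch of the same pattern with plain net/http (the path and empty JSON body are illustrative):

package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

// apiHandler mimics the manual dispatch: unknown paths get 404, known
// paths with the wrong method get 405.
func apiHandler(w http.ResponseWriter, r *http.Request) {
	switch r.URL.Path {
	case "/.pomerium/api/v1/routes":
		if r.Method != http.MethodGet {
			http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
			return
		}
		fmt.Fprint(w, `{"routes":[]}`)
	default:
		http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
	}
}

func main() {
	srv := httptest.NewServer(http.HandlerFunc(apiHandler))
	defer srv.Close()

	res, _ := http.Post(srv.URL+"/.pomerium/api/v1/routes", "application/json", nil)
	fmt.Println(res.StatusCode) // 405, not 404
}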

proxy/handlers_portal.go (new file, 56 lines)

@ -0,0 +1,56 @@
package proxy
import (
"encoding/json"
"net/http"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/handlers"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/proxy/portal"
)
func (p *Proxy) routesPortalJSON(w http.ResponseWriter, r *http.Request) error {
u := p.getUserInfoData(r)
rs := p.getPortalRoutes(u)
m := map[string]any{}
m["routes"] = rs
b, err := json.Marshal(m)
if err != nil {
return httputil.NewError(http.StatusInternalServerError, err)
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write(b)
return nil
}
func (p *Proxy) getPortalRoutes(u handlers.UserInfoData) []portal.Route {
options := p.currentOptions.Load()
pu := p.getPortalUser(u)
var routes []*config.Policy
for route := range options.GetAllPolicies() {
if portal.CheckRouteAccess(pu, route) {
routes = append(routes, route)
}
}
return portal.RoutesFromConfigRoutes(routes)
}
func (p *Proxy) getPortalUser(u handlers.UserInfoData) portal.User {
pu := portal.User{}
pu.SessionID = u.Session.GetId()
pu.UserID = u.User.GetId()
pu.Email = u.User.GetEmail()
for _, dg := range u.DirectoryGroups {
if v := dg.ID; v != "" {
pu.Groups = append(pu.Groups, dg.ID)
}
if v := dg.Name; v != "" {
pu.Groups = append(pu.Groups, dg.Name)
}
}
return pu
}

@ -0,0 +1,51 @@
package proxy
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/httputil"
)
func TestProxy_routesPortalJSON(t *testing.T) {
ctx := context.Background()
cfg := &config.Config{Options: config.NewDefaultOptions()}
to, err := config.ParseWeightedUrls("https://to.example.com")
require.NoError(t, err)
cfg.Options.Routes = append(cfg.Options.Routes, config.Policy{
Name: "public",
Description: "PUBLIC ROUTE",
LogoURL: "https://logo.example.com",
From: "https://from.example.com",
To: to,
AllowPublicUnauthenticatedAccess: true,
})
proxy, err := New(ctx, cfg)
require.NoError(t, err)
r := httptest.NewRequest(http.MethodGet, "/.pomerium/api/v1/routes", nil)
w := httptest.NewRecorder()
router := httputil.NewRouter()
router = proxy.registerDashboardHandlers(router, cfg.Options)
router.ServeHTTP(w, r)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "application/json", w.Header().Get("Content-Type"))
assert.JSONEq(t, `{"routes":[
{
"id": "4e71df99c0317efb",
"name": "public",
"from": "https://from.example.com",
"type": "http",
"description": "PUBLIC ROUTE",
"logo_url": "https://logo.example.com"
}
]}`, w.Body.String())
}

proxy/portal/filter.go (new file, 105 lines)

@ -0,0 +1,105 @@
package portal
import (
"strings"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/pkg/policy/parser"
)
// User is the computed user information needed for access decisions.
type User struct {
SessionID string
UserID string
Email string
Groups []string
}
// CheckRouteAccess checks if the user has access to the route.
func CheckRouteAccess(user User, route *config.Policy) bool {
// check the main policy
ppl := route.ToPPL()
if checkPPLAccess(user, ppl) {
return true
}
// check sub-policies
for _, sp := range route.SubPolicies {
if sp.SourcePPL == "" {
continue
}
ppl, err := parser.New().ParseYAML(strings.NewReader(sp.SourcePPL))
if err != nil {
// ignore invalid PPL
continue
}
if checkPPLAccess(user, ppl) {
return true
}
}
// nothing matched
return false
}
func checkPPLAccess(user User, ppl *parser.Policy) bool {
for _, r := range ppl.Rules {
// ignore deny rules
if r.Action != parser.ActionAllow {
continue
}
// ignore complex rules
if len(r.Nor) > 0 || len(r.Not) > 0 || len(r.And) > 1 {
continue
}
cs := append(append([]parser.Criterion{}, r.Or...), r.And...)
for _, c := range cs {
ok := checkPPLCriterionAccess(user, c)
if ok {
return true
}
}
}
return false
}
func checkPPLCriterionAccess(user User, criterion parser.Criterion) bool {
switch criterion.Name {
case "accept":
return true
}
// require a session
if user.SessionID == "" {
return false
}
switch criterion.Name {
case "authenticated_user":
return true
}
// require a user
if user.UserID == "" {
return false
}
switch criterion.Name {
case "domain":
parts := strings.SplitN(user.Email, "@", 2)
return len(parts) == 2 && matchString(parts[1], criterion.Data)
case "email":
return matchString(user.Email, criterion.Data)
case "groups":
return matchStringList(user.Groups, criterion.Data)
case "user":
return matchString(user.UserID, criterion.Data)
}
return false
}

@ -0,0 +1,67 @@
package portal_test
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/pkg/policy/parser"
"github.com/pomerium/pomerium/proxy/portal"
)
func TestCheckRouteAccess(t *testing.T) {
t.Parallel()
u1 := portal.User{}
u2 := portal.User{SessionID: "s2", UserID: "u2", Email: "u2@example.com", Groups: []string{"g2"}}
for _, tc := range []struct {
name string
user portal.User
route *config.Policy
}{
{"no ppl", u1, &config.Policy{}},
{"allow_any_authenticated_user", u1, &config.Policy{AllowAnyAuthenticatedUser: true}},
{"allowed_domains", u2, &config.Policy{AllowedDomains: []string{"not.example.com"}}},
{"allowed_users", u2, &config.Policy{AllowedUsers: []string{"u3"}}},
{"not conditionals", u2, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"not": [{"accept": 1}]}}`)}},
{"nor conditionals", u2, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"nor": [{"accept": 1}]}}`)}},
{"and conditionals", u2, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"and": [{"accept": 1}, {"accept": 1}]}}`)}},
{"authenticated_user", u1, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"or": [{"authenticated_user": 1}]}}`)}},
{"domain", u2, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"or": [{"domain": "not.example.com"}]}}`)}},
{"email", u1, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"or": [{"email": "u2@example.com"}]}}`)}},
{"groups", u2, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"or": [{"groups": {"has": "g3"}}]}}`)}},
} {
assert.False(t, portal.CheckRouteAccess(tc.user, tc.route), "%s: should deny access for %v to %v",
tc.name, tc.user, tc.route)
}
for _, tc := range []struct {
name string
user portal.User
route *config.Policy
}{
{"allow_public_unauthenticated_access", u1, &config.Policy{AllowPublicUnauthenticatedAccess: true}},
{"allow_any_authenticated_user", u2, &config.Policy{AllowAnyAuthenticatedUser: true}},
{"allowed_domains", u2, &config.Policy{AllowedDomains: []string{"example.com"}}},
{"allowed_users", u2, &config.Policy{AllowedUsers: []string{"u2"}}},
{"and conditionals", u2, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"and": [{"accept": 1}]}}`)}},
{"or conditionals", u2, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"or": [{"reject": 1}, {"accept": 1}]}}`)}},
{"authenticated_user", u2, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"or": [{"authenticated_user": 1}]}}`)}},
{"domain", u2, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"or": [{"domain": "example.com"}]}}`)}},
{"email", u2, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"or": [{"email": "u2@example.com"}]}}`)}},
{"groups", u2, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"or": [{"groups": {"has": "g2"}}]}}`)}},
} {
assert.True(t, portal.CheckRouteAccess(tc.user, tc.route), "%s: should grant access for %v to %v",
tc.name, tc.user, tc.route)
}
}
func mustParsePPL(t testing.TB, raw string) *config.PPLPolicy {
ppl, err := parser.New().ParseJSON(strings.NewReader(raw))
require.NoError(t, err)
return &config.PPLPolicy{Policy: ppl}
}

proxy/portal/matchers.go (new file, 108 lines)

@ -0,0 +1,108 @@
package portal
import (
"strings"
"github.com/pomerium/pomerium/pkg/policy/parser"
)
type matcher[T any] func(left T, right parser.Value) bool
var stringMatchers = map[string]matcher[string]{
"contains": matchStringContains,
"ends_with": matchStringEndsWith,
"is": matchStringIs,
"starts_with": matchStringStartsWith,
}
var stringListMatchers = map[string]matcher[[]string]{
"has": matchStringListHas,
"is": matchStringListIs,
}
func matchString(left string, right parser.Value) bool {
obj, ok := right.(parser.Object)
if !ok {
obj = parser.Object{
"is": right,
}
}
for k, v := range obj {
f, ok := stringMatchers[k]
if !ok {
return false
}
ok = f(left, v)
if ok {
return true
}
}
return false
}
func matchStringContains(left string, right parser.Value) bool {
str, ok := right.(parser.String)
if !ok {
return false
}
return strings.Contains(left, string(str))
}
func matchStringEndsWith(left string, right parser.Value) bool {
str, ok := right.(parser.String)
if !ok {
return false
}
return strings.HasSuffix(left, string(str))
}
func matchStringIs(left string, right parser.Value) bool {
str, ok := right.(parser.String)
if !ok {
return false
}
return left == string(str)
}
func matchStringStartsWith(left string, right parser.Value) bool {
str, ok := right.(parser.String)
if !ok {
return false
}
return strings.HasPrefix(left, string(str))
}
func matchStringList(left []string, right parser.Value) bool {
obj, ok := right.(parser.Object)
if !ok {
obj = parser.Object{
"has": right,
}
}
for k, v := range obj {
f, ok := stringListMatchers[k]
if !ok {
return false
}
ok = f(left, v)
if ok {
return true
}
}
return false
}
func matchStringListHas(left []string, right parser.Value) bool {
for _, str := range left {
if matchStringIs(str, right) {
return true
}
}
return false
}
func matchStringListIs(left []string, right parser.Value) bool {
return len(left) == 1 && matchStringListHas(left, right)
}

@ -0,0 +1,83 @@
package portal
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/pkg/policy/parser"
)
func Test_matchString(t *testing.T) {
t.Parallel()
t.Run("string", func(t *testing.T) {
assert.True(t, matchString("TEST", mustParseValue(t, `"TEST"`)))
})
t.Run("bool", func(t *testing.T) {
assert.False(t, matchString("true", mustParseValue(t, `true`)))
})
t.Run("number", func(t *testing.T) {
assert.False(t, matchString("1", mustParseValue(t, `1`)))
})
t.Run("null", func(t *testing.T) {
assert.False(t, matchString("null", mustParseValue(t, `null`)))
})
t.Run("array", func(t *testing.T) {
assert.False(t, matchString("[]", mustParseValue(t, `[]`)))
})
t.Run("contains", func(t *testing.T) {
assert.True(t, matchString("XYZ", mustParseValue(t, `{"contains":"Y"}`)))
assert.False(t, matchString("XYZ", mustParseValue(t, `{"contains":"A"}`)))
})
t.Run("ends_with", func(t *testing.T) {
assert.True(t, matchString("XYZ", mustParseValue(t, `{"ends_with":"Z"}`)))
assert.False(t, matchString("XYZ", mustParseValue(t, `{"ends_with":"X"}`)))
})
t.Run("is", func(t *testing.T) {
assert.True(t, matchString("XYZ", mustParseValue(t, `{"is":"XYZ"}`)))
assert.False(t, matchString("XYZ", mustParseValue(t, `{"is":"X"}`)))
})
t.Run("starts_with", func(t *testing.T) {
assert.True(t, matchString("XYZ", mustParseValue(t, `{"starts_with":"X"}`)))
assert.False(t, matchString("XYZ", mustParseValue(t, `{"starts_with":"Z"}`)))
})
}
func Test_matchStringList(t *testing.T) {
t.Parallel()
t.Run("string", func(t *testing.T) {
assert.True(t, matchStringList([]string{"X", "Y", "Z"}, mustParseValue(t, `"Y"`)))
assert.False(t, matchStringList([]string{"X", "Y", "Z"}, mustParseValue(t, `"A"`)))
})
t.Run("bool", func(t *testing.T) {
assert.False(t, matchStringList([]string{"true"}, mustParseValue(t, `true`)))
})
t.Run("number", func(t *testing.T) {
assert.False(t, matchStringList([]string{"1"}, mustParseValue(t, `1`)))
})
t.Run("null", func(t *testing.T) {
assert.False(t, matchStringList([]string{"null"}, mustParseValue(t, `null`)))
})
t.Run("array", func(t *testing.T) {
assert.False(t, matchStringList([]string{"[]"}, mustParseValue(t, `[]`)))
})
t.Run("has", func(t *testing.T) {
assert.True(t, matchStringList([]string{"X", "Y", "Z"}, mustParseValue(t, `{"has":"Y"}`)))
assert.False(t, matchStringList([]string{"X", "Y", "Z"}, mustParseValue(t, `{"has":"A"}`)))
})
t.Run("is", func(t *testing.T) {
assert.True(t, matchStringList([]string{"X"}, mustParseValue(t, `{"is":"X"}`)))
assert.False(t, matchStringList([]string{"X", "Y", "Z"}, mustParseValue(t, `{"is":"Y"}`)))
assert.False(t, matchStringList([]string{"X", "Y", "Z"}, mustParseValue(t, `{"is":"A"}`)))
})
}
func mustParseValue(t testing.TB, raw string) parser.Value {
v, err := parser.ParseValue(strings.NewReader(raw))
require.NoError(t, err)
return v
}

proxy/portal/portal.go (new file, 60 lines)

@ -0,0 +1,60 @@
// Package portal contains the code for the routes portal
package portal
import (
"fmt"
"strings"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/zero/importutil"
)
// A Route is a portal route.
type Route struct {
ID string `json:"id"`
Name string `json:"name"`
Type string `json:"type"`
From string `json:"from"`
Description string `json:"description"`
ConnectCommand string `json:"connect_command,omitempty"`
LogoURL string `json:"logo_url"`
}
// RoutesFromConfigRoutes converts config routes into portal routes.
func RoutesFromConfigRoutes(routes []*config.Policy) []Route {
prs := make([]Route, len(routes))
for i, route := range routes {
pr := Route{}
pr.ID = route.ID
if pr.ID == "" {
pr.ID = fmt.Sprintf("%x", route.MustRouteID())
}
pr.Name = route.Name
pr.From = route.From
fromURL, err := urlutil.ParseAndValidateURL(route.From)
if err == nil {
if strings.HasPrefix(fromURL.Scheme, "tcp+") {
pr.Type = "tcp"
pr.ConnectCommand = "pomerium-cli tcp " + fromURL.Host
} else if strings.HasPrefix(fromURL.Scheme, "udp+") {
pr.Type = "udp"
pr.ConnectCommand = "pomerium-cli udp " + fromURL.Host
} else {
pr.Type = "http"
}
} else {
pr.Type = "http"
}
pr.Description = route.Description
pr.LogoURL = route.LogoURL
prs[i] = pr
}
// generate names if they're empty
for i, name := range importutil.GenerateRouteNames(routes) {
if prs[i].Name == "" {
prs[i].Name = name
}
}
return prs
}

@ -0,0 +1,71 @@
package portal_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/proxy/portal"
)
func TestRouteFromConfigRoute(t *testing.T) {
t.Parallel()
to1, err := config.ParseWeightedUrls("https://to.example.com")
require.NoError(t, err)
to2, err := config.ParseWeightedUrls("tcp://postgres:5432")
require.NoError(t, err)
assert.Equal(t, []portal.Route{
{
ID: "4e71df99c0317efb",
Name: "from",
Type: "http",
From: "https://from.example.com",
Description: "ROUTE #1",
LogoURL: "https://logo.example.com",
},
{
ID: "7c377f11cdb9700e",
Name: "from-path",
Type: "http",
From: "https://from.example.com",
},
{
ID: "708e3cbd0bbe8547",
Name: "postgres",
Type: "tcp",
From: "tcp+https://postgres.example.com:5432",
ConnectCommand: "pomerium-cli tcp postgres.example.com:5432",
},
{
ID: "2dd08d87486e051a",
Name: "dns",
Type: "udp",
From: "udp+https://dns.example.com:53",
ConnectCommand: "pomerium-cli udp dns.example.com:53",
},
}, portal.RoutesFromConfigRoutes([]*config.Policy{
{
From: "https://from.example.com",
To: to1,
Description: "ROUTE #1",
LogoURL: "https://logo.example.com",
},
{
From: "https://from.example.com",
To: to1,
Path: "/path",
},
{
From: "tcp+https://postgres.example.com:5432",
To: to2,
},
{
From: "udp+https://dns.example.com:53",
To: to2,
},
}))
}