mirror of
https://github.com/pomerium/pomerium.git
synced 2025-07-24 03:59:49 +02:00
Merge branch 'main' into kralicky/pgx-tracing
This commit is contained in:
commit
d1b8e3b92f
25 changed files with 938 additions and 216 deletions
|
@ -6,11 +6,13 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
oteltrace "go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/pomerium/datasource/pkg/directory"
|
||||
"github.com/pomerium/pomerium/authorize/evaluator"
|
||||
"github.com/pomerium/pomerium/authorize/internal/store"
|
||||
"github.com/pomerium/pomerium/config"
|
||||
|
@ -21,21 +23,16 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
oteltrace "go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
// Authorize struct holds
|
||||
type Authorize struct {
|
||||
state *atomicutil.Value[*authorizeState]
|
||||
store *store.Store
|
||||
currentOptions *atomicutil.Value[*config.Options]
|
||||
accessTracker *AccessTracker
|
||||
globalCache storage.Cache
|
||||
|
||||
// The stateLock prevents updating the evaluator store simultaneously with an evaluation.
|
||||
// This should provide a consistent view of the data at a given server/record version and
|
||||
// avoid partial updates.
|
||||
stateLock sync.RWMutex
|
||||
state *atomicutil.Value[*authorizeState]
|
||||
store *store.Store
|
||||
currentOptions *atomicutil.Value[*config.Options]
|
||||
accessTracker *AccessTracker
|
||||
globalCache storage.Cache
|
||||
groupsCacheWarmer *cacheWarmer
|
||||
|
||||
tracerProvider oteltrace.TracerProvider
|
||||
tracer oteltrace.Tracer
|
||||
|
@ -60,6 +57,7 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
|
|||
}
|
||||
a.state = atomicutil.NewValue(state)
|
||||
|
||||
a.groupsCacheWarmer = newCacheWarmer(state.dataBrokerClientConnection, a.globalCache, directory.GroupRecordType)
|
||||
return a, nil
|
||||
}
|
||||
|
||||
|
@ -70,8 +68,16 @@ func (a *Authorize) GetDataBrokerServiceClient() databroker.DataBrokerServiceCli
|
|||
|
||||
// Run runs the authorize service.
|
||||
func (a *Authorize) Run(ctx context.Context) error {
|
||||
a.accessTracker.Run(ctx)
|
||||
return nil
|
||||
eg, ctx := errgroup.WithContext(ctx)
|
||||
eg.Go(func() error {
|
||||
a.accessTracker.Run(ctx)
|
||||
return nil
|
||||
})
|
||||
eg.Go(func() error {
|
||||
a.groupsCacheWarmer.Run(ctx)
|
||||
return nil
|
||||
})
|
||||
return eg.Wait()
|
||||
}
|
||||
|
||||
func validateOptions(o *config.Options) error {
|
||||
|
@ -150,9 +156,13 @@ func newPolicyEvaluator(
|
|||
func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
||||
currentState := a.state.Load()
|
||||
a.currentOptions.Store(cfg.Options)
|
||||
if state, err := newAuthorizeStateFromConfig(ctx, a.tracerProvider, cfg, a.store, currentState.evaluator); err != nil {
|
||||
if newState, err := newAuthorizeStateFromConfig(ctx, a.tracerProvider, cfg, a.store, currentState.evaluator); err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("authorize: error updating state")
|
||||
} else {
|
||||
a.state.Store(state)
|
||||
a.state.Store(newState)
|
||||
|
||||
if currentState.dataBrokerClientConnection != newState.dataBrokerClientConnection {
|
||||
a.groupsCacheWarmer.UpdateConn(newState.dataBrokerClientConnection)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
122
authorize/cache_warmer.go
Normal file
122
authorize/cache_warmer.go
Normal file
|
@ -0,0 +1,122 @@
|
|||
package authorize
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
||||
type cacheWarmer struct {
|
||||
cc *grpc.ClientConn
|
||||
cache storage.Cache
|
||||
typeURL string
|
||||
|
||||
updatedCC chan *grpc.ClientConn
|
||||
}
|
||||
|
||||
func newCacheWarmer(
|
||||
cc *grpc.ClientConn,
|
||||
cache storage.Cache,
|
||||
typeURL string,
|
||||
) *cacheWarmer {
|
||||
return &cacheWarmer{
|
||||
cc: cc,
|
||||
cache: cache,
|
||||
typeURL: typeURL,
|
||||
|
||||
updatedCC: make(chan *grpc.ClientConn, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func (cw *cacheWarmer) UpdateConn(cc *grpc.ClientConn) {
|
||||
for {
|
||||
select {
|
||||
case cw.updatedCC <- cc:
|
||||
return
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case <-cw.updatedCC:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cw *cacheWarmer) Run(ctx context.Context) {
|
||||
// Run a syncer for the cache warmer until the underlying databroker connection is changed.
|
||||
// When that happens cancel the currently running syncer and start a new one.
|
||||
|
||||
runCtx, runCancel := context.WithCancel(ctx)
|
||||
go cw.run(runCtx, cw.cc)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
runCancel()
|
||||
return
|
||||
case cc := <-cw.updatedCC:
|
||||
log.Ctx(ctx).Info().Msg("cache-warmer: received updated databroker client connection, restarting syncer")
|
||||
cw.cc = cc
|
||||
runCancel()
|
||||
runCtx, runCancel = context.WithCancel(ctx)
|
||||
go cw.run(runCtx, cw.cc)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cw *cacheWarmer) run(ctx context.Context, cc *grpc.ClientConn) {
|
||||
log.Ctx(ctx).Debug().Str("type-url", cw.typeURL).Msg("cache-warmer: running databroker syncer to warm cache")
|
||||
_ = databroker.NewSyncer(ctx, "cache-warmer", cacheWarmerSyncerHandler{
|
||||
client: databroker.NewDataBrokerServiceClient(cc),
|
||||
cache: cw.cache,
|
||||
}, databroker.WithTypeURL(cw.typeURL)).Run(ctx)
|
||||
}
|
||||
|
||||
type cacheWarmerSyncerHandler struct {
|
||||
client databroker.DataBrokerServiceClient
|
||||
cache storage.Cache
|
||||
}
|
||||
|
||||
func (h cacheWarmerSyncerHandler) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
|
||||
return h.client
|
||||
}
|
||||
|
||||
func (h cacheWarmerSyncerHandler) ClearRecords(_ context.Context) {
|
||||
h.cache.InvalidateAll()
|
||||
}
|
||||
|
||||
func (h cacheWarmerSyncerHandler) UpdateRecords(ctx context.Context, serverVersion uint64, records []*databroker.Record) {
|
||||
for _, record := range records {
|
||||
req := &databroker.QueryRequest{
|
||||
Type: record.Type,
|
||||
Limit: 1,
|
||||
}
|
||||
req.SetFilterByIDOrIndex(record.Id)
|
||||
|
||||
res := &databroker.QueryResponse{
|
||||
Records: []*databroker.Record{record},
|
||||
TotalCount: 1,
|
||||
ServerVersion: serverVersion,
|
||||
RecordVersion: record.Version,
|
||||
}
|
||||
|
||||
expiry := time.Now().Add(time.Hour * 24 * 365)
|
||||
key, err := storage.MarshalQueryRequest(req)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("cache-warmer: failed to marshal query request")
|
||||
continue
|
||||
}
|
||||
value, err := storage.MarshalQueryResponse(res)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("cache-warmer: failed to marshal query response")
|
||||
continue
|
||||
}
|
||||
|
||||
h.cache.Set(expiry, key, value)
|
||||
}
|
||||
}
|
52
authorize/cache_warmer_test.go
Normal file
52
authorize/cache_warmer_test.go
Normal file
|
@ -0,0 +1,52 @@
|
|||
package authorize
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel/trace/noop"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/databroker"
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
||||
func TestCacheWarmer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.GetContext(t, 10*time.Minute)
|
||||
cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) {
|
||||
databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New(ctx, noop.NewTracerProvider()))
|
||||
})
|
||||
t.Cleanup(func() { cc.Close() })
|
||||
|
||||
client := databrokerpb.NewDataBrokerServiceClient(cc)
|
||||
_, err := client.Put(ctx, &databrokerpb.PutRequest{
|
||||
Records: []*databrokerpb.Record{
|
||||
{Type: "example.com/record", Id: "e1", Data: protoutil.NewAnyBool(true)},
|
||||
{Type: "example.com/record", Id: "e2", Data: protoutil.NewAnyBool(true)},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
cache := storage.NewGlobalCache(time.Minute)
|
||||
|
||||
cw := newCacheWarmer(cc, cache, "example.com/record")
|
||||
go cw.Run(ctx)
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
req := &databrokerpb.QueryRequest{
|
||||
Type: "example.com/record",
|
||||
Limit: 1,
|
||||
}
|
||||
req.SetFilterByIDOrIndex("e1")
|
||||
res, err := storage.NewCachingQuerier(storage.NewStaticQuerier(), cache).Query(ctx, req)
|
||||
require.NoError(t, err)
|
||||
return len(res.GetRecords()) == 1
|
||||
}, 10*time.Second, time.Millisecond*100)
|
||||
}
|
|
@ -37,10 +37,8 @@ func Test_getDataBrokerRecord(t *testing.T) {
|
|||
s1 := &session.Session{Id: "s1", Version: fmt.Sprint(tc.recordVersion)}
|
||||
|
||||
sq := storage.NewStaticQuerier(s1)
|
||||
tsq := storage.NewTracingQuerier(sq)
|
||||
cq := storage.NewCachingQuerier(tsq, storage.NewLocalCache())
|
||||
tcq := storage.NewTracingQuerier(cq)
|
||||
qctx := storage.WithQuerier(ctx, tcq)
|
||||
cq := storage.NewCachingQuerier(sq, storage.NewGlobalCache(time.Minute))
|
||||
qctx := storage.WithQuerier(ctx, cq)
|
||||
|
||||
s, err := getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion)
|
||||
assert.NoError(t, err)
|
||||
|
@ -49,11 +47,6 @@ func Test_getDataBrokerRecord(t *testing.T) {
|
|||
s, err = getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, s)
|
||||
|
||||
assert.Len(t, tsq.Traces(), tc.underlyingQueryCount,
|
||||
"should have %d traces to the underlying querier", tc.underlyingQueryCount)
|
||||
assert.Len(t, tcq.Traces(), tc.cachedQueryCount,
|
||||
"should have %d traces to the cached querier", tc.cachedQueryCount)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -241,6 +241,7 @@ var internalPathsNeedingLogin = set.From([]string{
|
|||
"/.pomerium/jwt",
|
||||
"/.pomerium/user",
|
||||
"/.pomerium/webauthn",
|
||||
"/.pomerium/api/v1/routes",
|
||||
})
|
||||
|
||||
func (e *Evaluator) evaluateInternal(_ context.Context, req *Request) (*PolicyResponse, error) {
|
||||
|
|
|
@ -31,14 +31,9 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
|
|||
ctx, span := a.tracer.Start(ctx, "authorize.grpc.Check")
|
||||
defer span.End()
|
||||
|
||||
querier := storage.NewTracingQuerier(
|
||||
storage.NewCachingQuerier(
|
||||
storage.NewCachingQuerier(
|
||||
storage.NewQuerier(a.state.Load().dataBrokerClient),
|
||||
a.globalCache,
|
||||
),
|
||||
storage.NewLocalCache(),
|
||||
),
|
||||
querier := storage.NewCachingQuerier(
|
||||
storage.NewQuerier(a.state.Load().dataBrokerClient),
|
||||
a.globalCache,
|
||||
)
|
||||
ctx = storage.WithQuerier(ctx, querier)
|
||||
|
||||
|
@ -74,10 +69,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
|
|||
return nil, err
|
||||
}
|
||||
|
||||
// take the state lock here so we don't update while evaluating
|
||||
a.stateLock.RLock()
|
||||
res, err := state.evaluator.Evaluate(ctx, req)
|
||||
a.stateLock.RUnlock()
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Str("request-id", requestID).Msg("error during OPA evaluation")
|
||||
return nil, err
|
||||
|
|
|
@ -150,8 +150,8 @@ func (b *Builder) BuildBootstrapDynamicResources(
|
|||
|
||||
// BuildBootstrapLayeredRuntime builds the layered runtime for the envoy bootstrap.
|
||||
func (b *Builder) BuildBootstrapLayeredRuntime(ctx context.Context) (*envoy_config_bootstrap_v3.LayeredRuntime, error) {
|
||||
flushIntervalMs := 5000
|
||||
minFlushSpans := 3
|
||||
flushIntervalMs := trace.BatchSpanProcessorScheduleDelay()
|
||||
minFlushSpans := trace.BatchSpanProcessorMaxExportBatchSize()
|
||||
if trace.DebugFlagsFromContext(ctx).Check(trace.EnvoyFlushEverySpan) {
|
||||
minFlushSpans = 1
|
||||
flushIntervalMs = math.MaxInt32
|
||||
|
@ -166,15 +166,12 @@ func (b *Builder) BuildBootstrapLayeredRuntime(ctx context.Context) (*envoy_conf
|
|||
"tracing": map[string]any{
|
||||
"opentelemetry": map[string]any{
|
||||
"flush_interval_ms": flushIntervalMs,
|
||||
// For most requests, envoy generates 3 spans:
|
||||
// Note: for most requests, envoy generates 3 spans:
|
||||
// - ingress (downstream->envoy)
|
||||
// - ext_authz check request (envoy->pomerium)
|
||||
// - egress (envoy->upstream)
|
||||
// The default value is 5, which usually leads to delayed exports.
|
||||
// This can be set lower, e.g. 1 to have envoy export every span
|
||||
// individually (useful for testing), but 3 is a reasonable default.
|
||||
// If set to 1, also set flush_interval_ms to a very large number to
|
||||
// effectively disable it.
|
||||
// Some requests only generate 2 spans, e.g. if there is no upstream
|
||||
// request made or auth fails.
|
||||
"min_flush_spans": minFlushSpans,
|
||||
},
|
||||
},
|
||||
|
|
|
@ -51,7 +51,7 @@ func TestBuilder_BuildBootstrapLayeredRuntime(t *testing.T) {
|
|||
"tracing": {
|
||||
"opentelemetry": {
|
||||
"flush_interval_ms": 5000,
|
||||
"min_flush_spans": 3
|
||||
"min_flush_spans": 512
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/testenv"
|
||||
"github.com/pomerium/pomerium/internal/testenv/envutil"
|
||||
"github.com/pomerium/pomerium/internal/testenv/scenarios"
|
||||
|
@ -18,18 +19,30 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
numRoutes int
|
||||
dumpErrLogs bool
|
||||
numRoutes int
|
||||
dumpErrLogs bool
|
||||
enableTracing bool
|
||||
publicRoutes bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
flag.IntVar(&numRoutes, "routes", 100, "number of routes")
|
||||
flag.BoolVar(&dumpErrLogs, "dump-err-logs", false, "if the test fails, write all captured logs to a file (testdata/<test-name>)")
|
||||
flag.BoolVar(&enableTracing, "enable-tracing", false, "enable tracing")
|
||||
flag.BoolVar(&publicRoutes, "public-routes", false, "use public unauthenticated routes")
|
||||
}
|
||||
|
||||
func TestRequestLatency(t *testing.T) {
|
||||
resume := envutil.PauseProfiling(t)
|
||||
env := testenv.New(t, testenv.Silent())
|
||||
var env testenv.Environment
|
||||
if enableTracing {
|
||||
receiver := scenarios.NewOTLPTraceReceiver()
|
||||
env = testenv.New(t, testenv.Silent(), testenv.WithTraceClient(receiver.NewGRPCClient()))
|
||||
env.Add(receiver)
|
||||
} else {
|
||||
env = testenv.New(t, testenv.Silent())
|
||||
}
|
||||
|
||||
users := []*scenarios.User{}
|
||||
for i := range numRoutes {
|
||||
users = append(users, &scenarios.User{
|
||||
|
@ -47,9 +60,12 @@ func TestRequestLatency(t *testing.T) {
|
|||
routes := make([]testenv.Route, numRoutes)
|
||||
for i := range numRoutes {
|
||||
routes[i] = up.Route().
|
||||
From(env.SubdomainURL(fmt.Sprintf("from-%d", i))).
|
||||
// Policy(func(p *config.Policy) { p.AllowPublicUnauthenticatedAccess = true })
|
||||
PPL(fmt.Sprintf(`{"allow":{"and":["email":{"is":"user%d@example.com"}]}}`, i))
|
||||
From(env.SubdomainURL(fmt.Sprintf("from-%d", i)))
|
||||
if publicRoutes {
|
||||
routes[i] = routes[i].Policy(func(p *config.Policy) { p.AllowPublicUnauthenticatedAccess = true })
|
||||
} else {
|
||||
routes[i] = routes[i].PPL(fmt.Sprintf(`{"allow":{"and":["email":{"is":"user%d@example.com"}]}}`, i))
|
||||
}
|
||||
}
|
||||
env.AddUpstream(up)
|
||||
|
||||
|
|
|
@ -2,9 +2,12 @@ package trace
|
|||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"go.opentelemetry.io/contrib/propagators/autoprop"
|
||||
"go.opentelemetry.io/otel"
|
||||
sdktrace "go.opentelemetry.io/otel/sdk/trace"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"go.opentelemetry.io/otel/trace/embedded"
|
||||
)
|
||||
|
@ -44,3 +47,25 @@ var _ trace.Tracer = panicTracer{}
|
|||
func (p panicTracer) Start(context.Context, string, ...trace.SpanStartOption) (context.Context, trace.Span) {
|
||||
panic("global tracer used")
|
||||
}
|
||||
|
||||
// functions below mimic those with the same name in otel/sdk/internal/env/env.go
|
||||
|
||||
func BatchSpanProcessorScheduleDelay() int {
|
||||
const defaultValue = sdktrace.DefaultScheduleDelay
|
||||
if v, ok := os.LookupEnv("OTEL_BSP_SCHEDULE_DELAY"); ok {
|
||||
if n, err := strconv.Atoi(v); err == nil {
|
||||
return n
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func BatchSpanProcessorMaxExportBatchSize() int {
|
||||
const defaultValue = sdktrace.DefaultMaxExportBatchSize
|
||||
if v, ok := os.LookupEnv("OTEL_BSP_MAX_EXPORT_BATCH_SIZE"); ok {
|
||||
if n, err := strconv.Atoi(v); err == nil {
|
||||
return n
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
|
|
@ -50,7 +50,7 @@ func NewServer(ctx context.Context) *ExporterServer {
|
|||
}
|
||||
|
||||
func (srv *ExporterServer) Start(ctx context.Context) {
|
||||
lis := bufconn.Listen(4096)
|
||||
lis := bufconn.Listen(2 * 1024 * 1024)
|
||||
go func() {
|
||||
if err := srv.remoteClient.Start(ctx); err != nil {
|
||||
panic(err)
|
||||
|
@ -95,5 +95,6 @@ func (srv *ExporterServer) Shutdown(ctx context.Context) error {
|
|||
if err := srv.remoteClient.Stop(ctx); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
srv.cc.Close()
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
|
|
@ -77,6 +77,7 @@ func (rec *OTLPTraceReceiver) Attach(ctx context.Context) {
|
|||
// Modify implements testenv.Modifier.
|
||||
func (rec *OTLPTraceReceiver) Modify(cfg *config.Config) {
|
||||
cfg.Options.TracingProvider = "otlp"
|
||||
cfg.Options.TracingOTLPEndpoint = rec.GRPCEndpointURL().Value()
|
||||
}
|
||||
|
||||
func (rec *OTLPTraceReceiver) handleV1Traces(w http.ResponseWriter, r *http.Request) {
|
||||
|
|
|
@ -12,13 +12,6 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
"github.com/pomerium/pomerium/internal/testenv"
|
||||
"github.com/pomerium/pomerium/internal/testenv/scenarios"
|
||||
"github.com/pomerium/pomerium/internal/testenv/snippets"
|
||||
"github.com/pomerium/pomerium/internal/testenv/upstreams"
|
||||
. "github.com/pomerium/pomerium/internal/testutil/tracetest" //nolint:revive
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
@ -27,6 +20,14 @@ import (
|
|||
sdktrace "go.opentelemetry.io/otel/sdk/trace"
|
||||
semconv "go.opentelemetry.io/otel/semconv/v1.17.0"
|
||||
oteltrace "go.opentelemetry.io/otel/trace"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
"github.com/pomerium/pomerium/internal/testenv"
|
||||
"github.com/pomerium/pomerium/internal/testenv/scenarios"
|
||||
"github.com/pomerium/pomerium/internal/testenv/snippets"
|
||||
"github.com/pomerium/pomerium/internal/testenv/upstreams"
|
||||
. "github.com/pomerium/pomerium/internal/testutil/tracetest" //nolint:revive
|
||||
)
|
||||
|
||||
func otlpTraceReceiverOrFromEnv(t *testing.T) (modifier testenv.Modifier, newRemoteClient func() otlptrace.Client, getResults func() *TraceResults) {
|
||||
|
@ -116,7 +117,7 @@ func TestOTLPTracing(t *testing.T) {
|
|||
Exact: true,
|
||||
CheckDetachedSpans: true,
|
||||
},
|
||||
Match{Name: testEnvironmentLocalTest, TraceCount: 1, Services: []string{"Test Environment", "Control Plane", "Data Broker"}},
|
||||
Match{Name: testEnvironmentLocalTest, TraceCount: 1, Services: []string{"Authorize", "Test Environment", "Control Plane", "Data Broker"}},
|
||||
Match{Name: testEnvironmentAuthenticate, TraceCount: 1, Services: allServices},
|
||||
Match{Name: authenticateOAuth2Client, TraceCount: Greater(0)},
|
||||
Match{Name: idpServerGetUserinfo, TraceCount: EqualToMatch(authenticateOAuth2Client)},
|
||||
|
|
|
@ -18,64 +18,8 @@ type Cache interface {
|
|||
update func(ctx context.Context) ([]byte, error),
|
||||
) ([]byte, error)
|
||||
Invalidate(key []byte)
|
||||
}
|
||||
|
||||
type localCache struct {
|
||||
singleflight singleflight.Group
|
||||
mu sync.RWMutex
|
||||
m map[string][]byte
|
||||
}
|
||||
|
||||
// NewLocalCache creates a new Cache backed by a map.
|
||||
func NewLocalCache() Cache {
|
||||
return &localCache{
|
||||
m: make(map[string][]byte),
|
||||
}
|
||||
}
|
||||
|
||||
func (cache *localCache) GetOrUpdate(
|
||||
ctx context.Context,
|
||||
key []byte,
|
||||
update func(ctx context.Context) ([]byte, error),
|
||||
) ([]byte, error) {
|
||||
strkey := string(key)
|
||||
|
||||
cache.mu.RLock()
|
||||
cached, ok := cache.m[strkey]
|
||||
cache.mu.RUnlock()
|
||||
if ok {
|
||||
return cached, nil
|
||||
}
|
||||
|
||||
v, err, _ := cache.singleflight.Do(strkey, func() (any, error) {
|
||||
cache.mu.RLock()
|
||||
cached, ok := cache.m[strkey]
|
||||
cache.mu.RUnlock()
|
||||
if ok {
|
||||
return cached, nil
|
||||
}
|
||||
|
||||
result, err := update(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cache.mu.Lock()
|
||||
cache.m[strkey] = result
|
||||
cache.mu.Unlock()
|
||||
|
||||
return result, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return v.([]byte), nil
|
||||
}
|
||||
|
||||
func (cache *localCache) Invalidate(key []byte) {
|
||||
cache.mu.Lock()
|
||||
delete(cache.m, string(key))
|
||||
cache.mu.Unlock()
|
||||
InvalidateAll()
|
||||
Set(expiry time.Time, key, value []byte)
|
||||
}
|
||||
|
||||
type globalCache struct {
|
||||
|
@ -115,7 +59,7 @@ func (cache *globalCache) GetOrUpdate(
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cache.set(key, value)
|
||||
cache.set(time.Now().Add(cache.ttl), key, value)
|
||||
return value, nil
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -130,6 +74,16 @@ func (cache *globalCache) Invalidate(key []byte) {
|
|||
cache.mu.Unlock()
|
||||
}
|
||||
|
||||
func (cache *globalCache) InvalidateAll() {
|
||||
cache.mu.Lock()
|
||||
cache.fastcache.Reset()
|
||||
cache.mu.Unlock()
|
||||
}
|
||||
|
||||
func (cache *globalCache) Set(expiry time.Time, key, value []byte) {
|
||||
cache.set(expiry, key, value)
|
||||
}
|
||||
|
||||
func (cache *globalCache) get(k []byte) (data []byte, expiry time.Time, ok bool) {
|
||||
cache.mu.RLock()
|
||||
item := cache.fastcache.Get(nil, k)
|
||||
|
@ -143,13 +97,13 @@ func (cache *globalCache) get(k []byte) (data []byte, expiry time.Time, ok bool)
|
|||
return data, expiry, true
|
||||
}
|
||||
|
||||
func (cache *globalCache) set(k, v []byte) {
|
||||
unix := time.Now().Add(cache.ttl).UnixMilli()
|
||||
item := make([]byte, len(v)+8)
|
||||
func (cache *globalCache) set(expiry time.Time, key, value []byte) {
|
||||
unix := expiry.UnixMilli()
|
||||
item := make([]byte, len(value)+8)
|
||||
binary.LittleEndian.PutUint64(item, uint64(unix))
|
||||
copy(item[8:], v)
|
||||
copy(item[8:], value)
|
||||
|
||||
cache.mu.Lock()
|
||||
cache.fastcache.Set(k, item)
|
||||
cache.fastcache.Set(key, item)
|
||||
cache.mu.Unlock()
|
||||
}
|
||||
|
|
|
@ -9,34 +9,6 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestLocalCache(t *testing.T) {
|
||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer clearTimeout()
|
||||
|
||||
callCount := 0
|
||||
update := func(_ context.Context) ([]byte, error) {
|
||||
callCount++
|
||||
return []byte("v1"), nil
|
||||
}
|
||||
c := NewLocalCache()
|
||||
v, err := c.GetOrUpdate(ctx, []byte("k1"), update)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte("v1"), v)
|
||||
assert.Equal(t, 1, callCount)
|
||||
|
||||
v, err = c.GetOrUpdate(ctx, []byte("k1"), update)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte("v1"), v)
|
||||
assert.Equal(t, 1, callCount)
|
||||
|
||||
c.Invalidate([]byte("k1"))
|
||||
|
||||
v, err = c.GetOrUpdate(ctx, []byte("k1"), update)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte("v1"), v)
|
||||
assert.Equal(t, 2, callCount)
|
||||
}
|
||||
|
||||
func TestGlobalCache(t *testing.T) {
|
||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer clearTimeout()
|
||||
|
@ -70,4 +42,10 @@ func TestGlobalCache(t *testing.T) {
|
|||
})
|
||||
return err != nil
|
||||
}, time.Second, time.Millisecond*10, "should honor TTL")
|
||||
|
||||
c.Set(time.Now().Add(time.Hour), []byte("k2"), []byte("v2"))
|
||||
v, err = c.GetOrUpdate(ctx, []byte("k2"), update)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte("v2"), v)
|
||||
assert.Equal(t, 2, callCount)
|
||||
}
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
grpc "google.golang.org/grpc"
|
||||
|
@ -12,7 +11,6 @@ import (
|
|||
status "google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
|
@ -162,59 +160,6 @@ func (q *clientQuerier) Query(ctx context.Context, in *databroker.QueryRequest,
|
|||
return q.client.Query(ctx, in, opts...)
|
||||
}
|
||||
|
||||
// A TracingQuerier records calls to Query.
|
||||
type TracingQuerier struct {
|
||||
underlying Querier
|
||||
|
||||
mu sync.Mutex
|
||||
traces []QueryTrace
|
||||
}
|
||||
|
||||
// A QueryTrace traces a call to Query.
|
||||
type QueryTrace struct {
|
||||
ServerVersion, RecordVersion uint64
|
||||
|
||||
RecordType string
|
||||
Query string
|
||||
Filter *structpb.Struct
|
||||
}
|
||||
|
||||
// NewTracingQuerier creates a new TracingQuerier.
|
||||
func NewTracingQuerier(q Querier) *TracingQuerier {
|
||||
return &TracingQuerier{
|
||||
underlying: q,
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidateCache invalidates the cache.
|
||||
func (q *TracingQuerier) InvalidateCache(ctx context.Context, in *databroker.QueryRequest) {
|
||||
q.underlying.InvalidateCache(ctx, in)
|
||||
}
|
||||
|
||||
// Query queries for records.
|
||||
func (q *TracingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
|
||||
res, err := q.underlying.Query(ctx, in, opts...)
|
||||
if err == nil {
|
||||
q.mu.Lock()
|
||||
q.traces = append(q.traces, QueryTrace{
|
||||
RecordType: in.GetType(),
|
||||
Query: in.GetQuery(),
|
||||
Filter: in.GetFilter(),
|
||||
})
|
||||
q.mu.Unlock()
|
||||
}
|
||||
return res, err
|
||||
}
|
||||
|
||||
// Traces returns all the traces.
|
||||
func (q *TracingQuerier) Traces() []QueryTrace {
|
||||
q.mu.Lock()
|
||||
traces := make([]QueryTrace, len(q.traces))
|
||||
copy(traces, q.traces)
|
||||
q.mu.Unlock()
|
||||
return traces
|
||||
}
|
||||
|
||||
type cachingQuerier struct {
|
||||
q Querier
|
||||
cache Cache
|
||||
|
@ -240,9 +185,7 @@ func (q *cachingQuerier) InvalidateCache(ctx context.Context, in *databroker.Que
|
|||
}
|
||||
|
||||
func (q *cachingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
|
||||
key, err := (&proto.MarshalOptions{
|
||||
Deterministic: true,
|
||||
}).Marshal(in)
|
||||
key, err := MarshalQueryRequest(in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -252,7 +195,7 @@ func (q *cachingQuerier) Query(ctx context.Context, in *databroker.QueryRequest,
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return proto.Marshal(res)
|
||||
return MarshalQueryResponse(res)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -265,3 +208,17 @@ func (q *cachingQuerier) Query(ctx context.Context, in *databroker.QueryRequest,
|
|||
}
|
||||
return &res, nil
|
||||
}
|
||||
|
||||
// MarshalQueryRequest marshales the query request.
|
||||
func MarshalQueryRequest(req *databroker.QueryRequest) ([]byte, error) {
|
||||
return (&proto.MarshalOptions{
|
||||
Deterministic: true,
|
||||
}).Marshal(req)
|
||||
}
|
||||
|
||||
// MarshalQueryResponse marshals the query response.
|
||||
func MarshalQueryResponse(res *databroker.QueryResponse) ([]byte, error) {
|
||||
return (&proto.MarshalOptions{
|
||||
Deterministic: true,
|
||||
}).Marshal(res)
|
||||
}
|
||||
|
|
|
@ -40,11 +40,32 @@ func (p *Proxy) registerDashboardHandlers(r *mux.Router, opts *config.Options) *
|
|||
c.Path("/").Handler(httputil.HandlerFunc(p.Callback)).Methods(http.MethodGet)
|
||||
|
||||
// Programmatic API handlers and middleware
|
||||
a := r.PathPrefix(dashboardPath + "/api").Subrouter()
|
||||
// login api handler generates a user-navigable login url to authenticate
|
||||
a.Path("/v1/login").Handler(httputil.HandlerFunc(p.ProgrammaticLogin)).
|
||||
Queries(urlutil.QueryRedirectURI, "").
|
||||
Methods(http.MethodGet)
|
||||
// gorilla mux has a bug that prevents HTTP 405 errors from being returned properly so we do all this manually
|
||||
// https://github.com/gorilla/mux/issues/739
|
||||
r.PathPrefix(dashboardPath + "/api").
|
||||
Handler(httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
switch r.URL.Path {
|
||||
// login api handler generates a user-navigable login url to authenticate
|
||||
case dashboardPath + "/api/v1/login":
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
|
||||
return nil
|
||||
}
|
||||
if !r.URL.Query().Has(urlutil.QueryRedirectURI) {
|
||||
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
|
||||
return nil
|
||||
}
|
||||
return p.ProgrammaticLogin(w, r)
|
||||
case dashboardPath + "/api/v1/routes":
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
|
||||
return nil
|
||||
}
|
||||
return p.routesPortalJSON(w, r)
|
||||
}
|
||||
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
|
||||
return nil
|
||||
}))
|
||||
|
||||
return r
|
||||
}
|
||||
|
|
56
proxy/handlers_portal.go
Normal file
56
proxy/handlers_portal.go
Normal file
|
@ -0,0 +1,56 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/handlers"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/proxy/portal"
|
||||
)
|
||||
|
||||
func (p *Proxy) routesPortalJSON(w http.ResponseWriter, r *http.Request) error {
|
||||
u := p.getUserInfoData(r)
|
||||
rs := p.getPortalRoutes(u)
|
||||
m := map[string]any{}
|
||||
m["routes"] = rs
|
||||
|
||||
b, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(b)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Proxy) getPortalRoutes(u handlers.UserInfoData) []portal.Route {
|
||||
options := p.currentOptions.Load()
|
||||
pu := p.getPortalUser(u)
|
||||
var routes []*config.Policy
|
||||
for route := range options.GetAllPolicies() {
|
||||
if portal.CheckRouteAccess(pu, route) {
|
||||
routes = append(routes, route)
|
||||
}
|
||||
}
|
||||
return portal.RoutesFromConfigRoutes(routes)
|
||||
}
|
||||
|
||||
func (p *Proxy) getPortalUser(u handlers.UserInfoData) portal.User {
|
||||
pu := portal.User{}
|
||||
pu.SessionID = u.Session.GetId()
|
||||
pu.UserID = u.User.GetId()
|
||||
pu.Email = u.User.GetEmail()
|
||||
for _, dg := range u.DirectoryGroups {
|
||||
if v := dg.ID; v != "" {
|
||||
pu.Groups = append(pu.Groups, dg.ID)
|
||||
}
|
||||
if v := dg.Name; v != "" {
|
||||
pu.Groups = append(pu.Groups, dg.Name)
|
||||
}
|
||||
}
|
||||
return pu
|
||||
}
|
51
proxy/handlers_portal_test.go
Normal file
51
proxy/handlers_portal_test.go
Normal file
|
@ -0,0 +1,51 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
)
|
||||
|
||||
func TestProxy_routesPortalJSON(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
cfg := &config.Config{Options: config.NewDefaultOptions()}
|
||||
to, err := config.ParseWeightedUrls("https://to.example.com")
|
||||
require.NoError(t, err)
|
||||
cfg.Options.Routes = append(cfg.Options.Routes, config.Policy{
|
||||
Name: "public",
|
||||
Description: "PUBLIC ROUTE",
|
||||
LogoURL: "https://logo.example.com",
|
||||
From: "https://from.example.com",
|
||||
To: to,
|
||||
AllowPublicUnauthenticatedAccess: true,
|
||||
})
|
||||
proxy, err := New(ctx, cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/.pomerium/api/v1/routes", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router := httputil.NewRouter()
|
||||
router = proxy.registerDashboardHandlers(router, cfg.Options)
|
||||
router.ServeHTTP(w, r)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "application/json", w.Header().Get("Content-Type"))
|
||||
assert.JSONEq(t, `{"routes":[
|
||||
{
|
||||
"id": "4e71df99c0317efb",
|
||||
"name": "public",
|
||||
"from": "https://from.example.com",
|
||||
"type": "http",
|
||||
"description": "PUBLIC ROUTE",
|
||||
"logo_url": "https://logo.example.com"
|
||||
}
|
||||
]}`, w.Body.String())
|
||||
}
|
105
proxy/portal/filter.go
Normal file
105
proxy/portal/filter.go
Normal file
|
@ -0,0 +1,105 @@
|
|||
package portal
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/pkg/policy/parser"
|
||||
)
|
||||
|
||||
// User is the computed user information needed for access decisions.
|
||||
type User struct {
|
||||
SessionID string
|
||||
UserID string
|
||||
Email string
|
||||
Groups []string
|
||||
}
|
||||
|
||||
// CheckRouteAccess checks if the user has access to the route.
|
||||
func CheckRouteAccess(user User, route *config.Policy) bool {
|
||||
// check the main policy
|
||||
ppl := route.ToPPL()
|
||||
if checkPPLAccess(user, ppl) {
|
||||
return true
|
||||
}
|
||||
|
||||
// check sub-policies
|
||||
for _, sp := range route.SubPolicies {
|
||||
if sp.SourcePPL == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
ppl, err := parser.New().ParseYAML(strings.NewReader(sp.SourcePPL))
|
||||
if err != nil {
|
||||
// ignore invalid PPL
|
||||
continue
|
||||
}
|
||||
|
||||
if checkPPLAccess(user, ppl) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// nothing matched
|
||||
return false
|
||||
}
|
||||
|
||||
func checkPPLAccess(user User, ppl *parser.Policy) bool {
|
||||
for _, r := range ppl.Rules {
|
||||
// ignore deny rules
|
||||
if r.Action != parser.ActionAllow {
|
||||
continue
|
||||
}
|
||||
|
||||
// ignore complex rules
|
||||
if len(r.Nor) > 0 || len(r.Not) > 0 || len(r.And) > 1 {
|
||||
continue
|
||||
}
|
||||
|
||||
cs := append(append([]parser.Criterion{}, r.Or...), r.And...)
|
||||
for _, c := range cs {
|
||||
ok := checkPPLCriterionAccess(user, c)
|
||||
if ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func checkPPLCriterionAccess(user User, criterion parser.Criterion) bool {
|
||||
switch criterion.Name {
|
||||
case "accept":
|
||||
return true
|
||||
}
|
||||
|
||||
// require a session
|
||||
if user.SessionID == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
switch criterion.Name {
|
||||
case "authenticated_user":
|
||||
return true
|
||||
}
|
||||
|
||||
// require a user
|
||||
if user.UserID == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
switch criterion.Name {
|
||||
case "domain":
|
||||
parts := strings.SplitN(user.Email, "@", 2)
|
||||
return len(parts) == 2 && matchString(parts[1], criterion.Data)
|
||||
case "email":
|
||||
return matchString(user.Email, criterion.Data)
|
||||
case "groups":
|
||||
return matchStringList(user.Groups, criterion.Data)
|
||||
case "user":
|
||||
return matchString(user.UserID, criterion.Data)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
67
proxy/portal/filter_test.go
Normal file
67
proxy/portal/filter_test.go
Normal file
|
@ -0,0 +1,67 @@
|
|||
package portal_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/pkg/policy/parser"
|
||||
"github.com/pomerium/pomerium/proxy/portal"
|
||||
)
|
||||
|
||||
func TestCheckRouteAccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
u1 := portal.User{}
|
||||
u2 := portal.User{SessionID: "s2", UserID: "u2", Email: "u2@example.com", Groups: []string{"g2"}}
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
user portal.User
|
||||
route *config.Policy
|
||||
}{
|
||||
{"no ppl", u1, &config.Policy{}},
|
||||
{"allow_any_authenticated_user", u1, &config.Policy{AllowAnyAuthenticatedUser: true}},
|
||||
{"allowed_domains", u2, &config.Policy{AllowedDomains: []string{"not.example.com"}}},
|
||||
{"allowed_users", u2, &config.Policy{AllowedUsers: []string{"u3"}}},
|
||||
{"not conditionals", u2, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"not": [{"accept": 1}]}}`)}},
|
||||
{"nor conditionals", u2, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"nor": [{"accept": 1}]}}`)}},
|
||||
{"and conditionals", u2, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"and": [{"accept": 1}, {"accept": 1}]}}`)}},
|
||||
{"authenticated_user", u1, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"or": [{"authenticated_user": 1}]}}`)}},
|
||||
{"domain", u2, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"or": [{"domain": "not.example.com"}]}}`)}},
|
||||
{"email", u1, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"or": [{"email": "u2@example.com"}]}}`)}},
|
||||
{"groups", u2, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"or": [{"groups": {"has": "g3"}}]}}`)}},
|
||||
} {
|
||||
assert.False(t, portal.CheckRouteAccess(tc.user, tc.route), "%s: should deny access for %v to %v",
|
||||
tc.name, tc.user, tc.route)
|
||||
}
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
user portal.User
|
||||
route *config.Policy
|
||||
}{
|
||||
{"allow_public_unauthenticated_access", u1, &config.Policy{AllowPublicUnauthenticatedAccess: true}},
|
||||
{"allow_any_authenticated_user", u2, &config.Policy{AllowAnyAuthenticatedUser: true}},
|
||||
{"allowed_domains", u2, &config.Policy{AllowedDomains: []string{"example.com"}}},
|
||||
{"allowed_users", u2, &config.Policy{AllowedUsers: []string{"u2"}}},
|
||||
{"and conditionals", u2, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"and": [{"accept": 1}]}}`)}},
|
||||
{"or conditionals", u2, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"or": [{"reject": 1}, {"accept": 1}]}}`)}},
|
||||
{"authenticated_user", u2, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"or": [{"authenticated_user": 1}]}}`)}},
|
||||
{"domain", u2, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"or": [{"domain": "example.com"}]}}`)}},
|
||||
{"email", u2, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"or": [{"email": "u2@example.com"}]}}`)}},
|
||||
{"groups", u2, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"or": [{"groups": {"has": "g2"}}]}}`)}},
|
||||
} {
|
||||
assert.True(t, portal.CheckRouteAccess(tc.user, tc.route), "%s: should grant access for %v to %v",
|
||||
tc.name, tc.user, tc.route)
|
||||
}
|
||||
}
|
||||
|
||||
func mustParsePPL(t testing.TB, raw string) *config.PPLPolicy {
|
||||
ppl, err := parser.New().ParseJSON(strings.NewReader(raw))
|
||||
require.NoError(t, err)
|
||||
return &config.PPLPolicy{Policy: ppl}
|
||||
}
|
108
proxy/portal/matchers.go
Normal file
108
proxy/portal/matchers.go
Normal file
|
@ -0,0 +1,108 @@
|
|||
package portal
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/policy/parser"
|
||||
)
|
||||
|
||||
type matcher[T any] func(left T, right parser.Value) bool
|
||||
|
||||
var stringMatchers = map[string]matcher[string]{
|
||||
"contains": matchStringContains,
|
||||
"ends_with": matchStringEndsWith,
|
||||
"is": matchStringIs,
|
||||
"starts_with": matchStringStartsWith,
|
||||
}
|
||||
|
||||
var stringListMatchers = map[string]matcher[[]string]{
|
||||
"has": matchStringListHas,
|
||||
"is": matchStringListIs,
|
||||
}
|
||||
|
||||
func matchString(left string, right parser.Value) bool {
|
||||
obj, ok := right.(parser.Object)
|
||||
if !ok {
|
||||
obj = parser.Object{
|
||||
"is": right,
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range obj {
|
||||
f, ok := stringMatchers[k]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
ok = f(left, v)
|
||||
if ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func matchStringContains(left string, right parser.Value) bool {
|
||||
str, ok := right.(parser.String)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(left, string(str))
|
||||
}
|
||||
|
||||
func matchStringEndsWith(left string, right parser.Value) bool {
|
||||
str, ok := right.(parser.String)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return strings.HasSuffix(left, string(str))
|
||||
}
|
||||
|
||||
func matchStringIs(left string, right parser.Value) bool {
|
||||
str, ok := right.(parser.String)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return left == string(str)
|
||||
}
|
||||
|
||||
func matchStringStartsWith(left string, right parser.Value) bool {
|
||||
str, ok := right.(parser.String)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return strings.HasPrefix(left, string(str))
|
||||
}
|
||||
|
||||
func matchStringList(left []string, right parser.Value) bool {
|
||||
obj, ok := right.(parser.Object)
|
||||
if !ok {
|
||||
obj = parser.Object{
|
||||
"has": right,
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range obj {
|
||||
f, ok := stringListMatchers[k]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
ok = f(left, v)
|
||||
if ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func matchStringListHas(left []string, right parser.Value) bool {
|
||||
for _, str := range left {
|
||||
if matchStringIs(str, right) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func matchStringListIs(left []string, right parser.Value) bool {
|
||||
return len(left) == 1 && matchStringListHas(left, right)
|
||||
}
|
83
proxy/portal/matchers_test.go
Normal file
83
proxy/portal/matchers_test.go
Normal file
|
@ -0,0 +1,83 @@
|
|||
package portal
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/policy/parser"
|
||||
)
|
||||
|
||||
func Test_matchString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("string", func(t *testing.T) {
|
||||
assert.True(t, matchString("TEST", mustParseValue(t, `"TEST"`)))
|
||||
})
|
||||
t.Run("bool", func(t *testing.T) {
|
||||
assert.False(t, matchString("true", mustParseValue(t, `true`)))
|
||||
})
|
||||
t.Run("number", func(t *testing.T) {
|
||||
assert.False(t, matchString("1", mustParseValue(t, `1`)))
|
||||
})
|
||||
t.Run("null", func(t *testing.T) {
|
||||
assert.False(t, matchString("null", mustParseValue(t, `null`)))
|
||||
})
|
||||
t.Run("array", func(t *testing.T) {
|
||||
assert.False(t, matchString("[]", mustParseValue(t, `[]`)))
|
||||
})
|
||||
t.Run("contains", func(t *testing.T) {
|
||||
assert.True(t, matchString("XYZ", mustParseValue(t, `{"contains":"Y"}`)))
|
||||
assert.False(t, matchString("XYZ", mustParseValue(t, `{"contains":"A"}`)))
|
||||
})
|
||||
t.Run("ends_with", func(t *testing.T) {
|
||||
assert.True(t, matchString("XYZ", mustParseValue(t, `{"ends_with":"Z"}`)))
|
||||
assert.False(t, matchString("XYZ", mustParseValue(t, `{"ends_with":"X"}`)))
|
||||
})
|
||||
t.Run("is", func(t *testing.T) {
|
||||
assert.True(t, matchString("XYZ", mustParseValue(t, `{"is":"XYZ"}`)))
|
||||
assert.False(t, matchString("XYZ", mustParseValue(t, `{"is":"X"}`)))
|
||||
})
|
||||
t.Run("starts_with", func(t *testing.T) {
|
||||
assert.True(t, matchString("XYZ", mustParseValue(t, `{"starts_with":"X"}`)))
|
||||
assert.False(t, matchString("XYZ", mustParseValue(t, `{"starts_with":"Z"}`)))
|
||||
})
|
||||
}
|
||||
|
||||
func Test_matchStringList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("string", func(t *testing.T) {
|
||||
assert.True(t, matchStringList([]string{"X", "Y", "Z"}, mustParseValue(t, `"Y"`)))
|
||||
assert.False(t, matchStringList([]string{"X", "Y", "Z"}, mustParseValue(t, `"A"`)))
|
||||
})
|
||||
t.Run("bool", func(t *testing.T) {
|
||||
assert.False(t, matchStringList([]string{"true"}, mustParseValue(t, `true`)))
|
||||
})
|
||||
t.Run("number", func(t *testing.T) {
|
||||
assert.False(t, matchStringList([]string{"1"}, mustParseValue(t, `1`)))
|
||||
})
|
||||
t.Run("null", func(t *testing.T) {
|
||||
assert.False(t, matchStringList([]string{"null"}, mustParseValue(t, `null`)))
|
||||
})
|
||||
t.Run("array", func(t *testing.T) {
|
||||
assert.False(t, matchStringList([]string{"[]"}, mustParseValue(t, `[]`)))
|
||||
})
|
||||
t.Run("has", func(t *testing.T) {
|
||||
assert.True(t, matchStringList([]string{"X", "Y", "Z"}, mustParseValue(t, `{"has":"Y"}`)))
|
||||
assert.False(t, matchStringList([]string{"X", "Y", "Z"}, mustParseValue(t, `{"has":"A"}`)))
|
||||
})
|
||||
t.Run("is", func(t *testing.T) {
|
||||
assert.True(t, matchStringList([]string{"X"}, mustParseValue(t, `{"is":"X"}`)))
|
||||
assert.False(t, matchStringList([]string{"X", "Y", "Z"}, mustParseValue(t, `{"is":"Y"}`)))
|
||||
assert.False(t, matchStringList([]string{"X", "Y", "Z"}, mustParseValue(t, `{"is":"A"}`)))
|
||||
})
|
||||
}
|
||||
|
||||
func mustParseValue(t testing.TB, raw string) parser.Value {
|
||||
v, err := parser.ParseValue(strings.NewReader(raw))
|
||||
require.NoError(t, err)
|
||||
return v
|
||||
}
|
60
proxy/portal/portal.go
Normal file
60
proxy/portal/portal.go
Normal file
|
@ -0,0 +1,60 @@
|
|||
// Package portal contains the code for the routes portal
|
||||
package portal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/zero/importutil"
|
||||
)
|
||||
|
||||
// A Route is a portal route.
|
||||
type Route struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
From string `json:"from"`
|
||||
Description string `json:"description"`
|
||||
ConnectCommand string `json:"connect_command,omitempty"`
|
||||
LogoURL string `json:"logo_url"`
|
||||
}
|
||||
|
||||
// RoutesFromConfigRoutes converts config routes into portal routes.
|
||||
func RoutesFromConfigRoutes(routes []*config.Policy) []Route {
|
||||
prs := make([]Route, len(routes))
|
||||
for i, route := range routes {
|
||||
pr := Route{}
|
||||
pr.ID = route.ID
|
||||
if pr.ID == "" {
|
||||
pr.ID = fmt.Sprintf("%x", route.MustRouteID())
|
||||
}
|
||||
pr.Name = route.Name
|
||||
pr.From = route.From
|
||||
fromURL, err := urlutil.ParseAndValidateURL(route.From)
|
||||
if err == nil {
|
||||
if strings.HasPrefix(fromURL.Scheme, "tcp+") {
|
||||
pr.Type = "tcp"
|
||||
pr.ConnectCommand = "pomerium-cli tcp " + fromURL.Host
|
||||
} else if strings.HasPrefix(fromURL.Scheme, "udp+") {
|
||||
pr.Type = "udp"
|
||||
pr.ConnectCommand = "pomerium-cli udp " + fromURL.Host
|
||||
} else {
|
||||
pr.Type = "http"
|
||||
}
|
||||
} else {
|
||||
pr.Type = "http"
|
||||
}
|
||||
pr.Description = route.Description
|
||||
pr.LogoURL = route.LogoURL
|
||||
prs[i] = pr
|
||||
}
|
||||
// generate names if they're empty
|
||||
for i, name := range importutil.GenerateRouteNames(routes) {
|
||||
if prs[i].Name == "" {
|
||||
prs[i].Name = name
|
||||
}
|
||||
}
|
||||
return prs
|
||||
}
|
71
proxy/portal/portal_test.go
Normal file
71
proxy/portal/portal_test.go
Normal file
|
@ -0,0 +1,71 @@
|
|||
package portal_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/proxy/portal"
|
||||
)
|
||||
|
||||
func TestRouteFromConfigRoute(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
to1, err := config.ParseWeightedUrls("https://to.example.com")
|
||||
require.NoError(t, err)
|
||||
to2, err := config.ParseWeightedUrls("tcp://postgres:5432")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, []portal.Route{
|
||||
{
|
||||
ID: "4e71df99c0317efb",
|
||||
Name: "from",
|
||||
Type: "http",
|
||||
From: "https://from.example.com",
|
||||
Description: "ROUTE #1",
|
||||
LogoURL: "https://logo.example.com",
|
||||
},
|
||||
{
|
||||
ID: "7c377f11cdb9700e",
|
||||
Name: "from-path",
|
||||
Type: "http",
|
||||
From: "https://from.example.com",
|
||||
},
|
||||
{
|
||||
ID: "708e3cbd0bbe8547",
|
||||
Name: "postgres",
|
||||
Type: "tcp",
|
||||
From: "tcp+https://postgres.example.com:5432",
|
||||
ConnectCommand: "pomerium-cli tcp postgres.example.com:5432",
|
||||
},
|
||||
{
|
||||
ID: "2dd08d87486e051a",
|
||||
Name: "dns",
|
||||
Type: "udp",
|
||||
From: "udp+https://dns.example.com:53",
|
||||
ConnectCommand: "pomerium-cli udp dns.example.com:53",
|
||||
},
|
||||
}, portal.RoutesFromConfigRoutes([]*config.Policy{
|
||||
{
|
||||
From: "https://from.example.com",
|
||||
To: to1,
|
||||
Description: "ROUTE #1",
|
||||
LogoURL: "https://logo.example.com",
|
||||
},
|
||||
{
|
||||
From: "https://from.example.com",
|
||||
To: to1,
|
||||
Path: "/path",
|
||||
},
|
||||
{
|
||||
From: "tcp+https://postgres.example.com:5432",
|
||||
To: to2,
|
||||
},
|
||||
{
|
||||
From: "udp+https://dns.example.com:53",
|
||||
To: to2,
|
||||
},
|
||||
}))
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue