diff --git a/authorize/authorize.go b/authorize/authorize.go index 4cd3dd496..a25da96fe 100644 --- a/authorize/authorize.go +++ b/authorize/authorize.go @@ -6,11 +6,13 @@ import ( "context" "fmt" "slices" - "sync" "time" "github.com/rs/zerolog" + oteltrace "go.opentelemetry.io/otel/trace" + "golang.org/x/sync/errgroup" + "github.com/pomerium/datasource/pkg/directory" "github.com/pomerium/pomerium/authorize/evaluator" "github.com/pomerium/pomerium/authorize/internal/store" "github.com/pomerium/pomerium/config" @@ -21,21 +23,16 @@ import ( "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/storage" - oteltrace "go.opentelemetry.io/otel/trace" ) // Authorize struct holds type Authorize struct { - state *atomicutil.Value[*authorizeState] - store *store.Store - currentOptions *atomicutil.Value[*config.Options] - accessTracker *AccessTracker - globalCache storage.Cache - - // The stateLock prevents updating the evaluator store simultaneously with an evaluation. - // This should provide a consistent view of the data at a given server/record version and - // avoid partial updates. - stateLock sync.RWMutex + state *atomicutil.Value[*authorizeState] + store *store.Store + currentOptions *atomicutil.Value[*config.Options] + accessTracker *AccessTracker + globalCache storage.Cache + groupsCacheWarmer *cacheWarmer tracerProvider oteltrace.TracerProvider tracer oteltrace.Tracer @@ -60,6 +57,7 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) { } a.state = atomicutil.NewValue(state) + a.groupsCacheWarmer = newCacheWarmer(state.dataBrokerClientConnection, a.globalCache, directory.GroupRecordType) return a, nil } @@ -70,8 +68,16 @@ func (a *Authorize) GetDataBrokerServiceClient() databroker.DataBrokerServiceCli // Run runs the authorize service. func (a *Authorize) Run(ctx context.Context) error { - a.accessTracker.Run(ctx) - return nil + eg, ctx := errgroup.WithContext(ctx) + eg.Go(func() error { + a.accessTracker.Run(ctx) + return nil + }) + eg.Go(func() error { + a.groupsCacheWarmer.Run(ctx) + return nil + }) + return eg.Wait() } func validateOptions(o *config.Options) error { @@ -150,9 +156,13 @@ func newPolicyEvaluator( func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) { currentState := a.state.Load() a.currentOptions.Store(cfg.Options) - if state, err := newAuthorizeStateFromConfig(ctx, a.tracerProvider, cfg, a.store, currentState.evaluator); err != nil { + if newState, err := newAuthorizeStateFromConfig(ctx, a.tracerProvider, cfg, a.store, currentState.evaluator); err != nil { log.Ctx(ctx).Error().Err(err).Msg("authorize: error updating state") } else { - a.state.Store(state) + a.state.Store(newState) + + if currentState.dataBrokerClientConnection != newState.dataBrokerClientConnection { + a.groupsCacheWarmer.UpdateConn(newState.dataBrokerClientConnection) + } } } diff --git a/authorize/cache_warmer.go b/authorize/cache_warmer.go new file mode 100644 index 000000000..41c1d0ae2 --- /dev/null +++ b/authorize/cache_warmer.go @@ -0,0 +1,122 @@ +package authorize + +import ( + "context" + "time" + + "google.golang.org/grpc" + + "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/storage" +) + +type cacheWarmer struct { + cc *grpc.ClientConn + cache storage.Cache + typeURL string + + updatedCC chan *grpc.ClientConn +} + +func newCacheWarmer( + cc *grpc.ClientConn, + cache storage.Cache, + typeURL string, +) *cacheWarmer { + return &cacheWarmer{ + cc: cc, + cache: cache, + typeURL: typeURL, + + updatedCC: make(chan *grpc.ClientConn, 1), + } +} + +func (cw *cacheWarmer) UpdateConn(cc *grpc.ClientConn) { + for { + select { + case cw.updatedCC <- cc: + return + default: + } + select { + case <-cw.updatedCC: + default: + } + } +} + +func (cw *cacheWarmer) Run(ctx context.Context) { + // Run a syncer for the cache warmer until the underlying databroker connection is changed. + // When that happens cancel the currently running syncer and start a new one. + + runCtx, runCancel := context.WithCancel(ctx) + go cw.run(runCtx, cw.cc) + + for { + select { + case <-ctx.Done(): + runCancel() + return + case cc := <-cw.updatedCC: + log.Ctx(ctx).Info().Msg("cache-warmer: received updated databroker client connection, restarting syncer") + cw.cc = cc + runCancel() + runCtx, runCancel = context.WithCancel(ctx) + go cw.run(runCtx, cw.cc) + } + } +} + +func (cw *cacheWarmer) run(ctx context.Context, cc *grpc.ClientConn) { + log.Ctx(ctx).Debug().Str("type-url", cw.typeURL).Msg("cache-warmer: running databroker syncer to warm cache") + _ = databroker.NewSyncer(ctx, "cache-warmer", cacheWarmerSyncerHandler{ + client: databroker.NewDataBrokerServiceClient(cc), + cache: cw.cache, + }, databroker.WithTypeURL(cw.typeURL)).Run(ctx) +} + +type cacheWarmerSyncerHandler struct { + client databroker.DataBrokerServiceClient + cache storage.Cache +} + +func (h cacheWarmerSyncerHandler) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient { + return h.client +} + +func (h cacheWarmerSyncerHandler) ClearRecords(_ context.Context) { + h.cache.InvalidateAll() +} + +func (h cacheWarmerSyncerHandler) UpdateRecords(ctx context.Context, serverVersion uint64, records []*databroker.Record) { + for _, record := range records { + req := &databroker.QueryRequest{ + Type: record.Type, + Limit: 1, + } + req.SetFilterByIDOrIndex(record.Id) + + res := &databroker.QueryResponse{ + Records: []*databroker.Record{record}, + TotalCount: 1, + ServerVersion: serverVersion, + RecordVersion: record.Version, + } + + expiry := time.Now().Add(time.Hour * 24 * 365) + key, err := storage.MarshalQueryRequest(req) + if err != nil { + log.Ctx(ctx).Error().Err(err).Msg("cache-warmer: failed to marshal query request") + continue + } + value, err := storage.MarshalQueryResponse(res) + if err != nil { + log.Ctx(ctx).Error().Err(err).Msg("cache-warmer: failed to marshal query response") + continue + } + + h.cache.Set(expiry, key, value) + } +} diff --git a/authorize/cache_warmer_test.go b/authorize/cache_warmer_test.go new file mode 100644 index 000000000..58df29018 --- /dev/null +++ b/authorize/cache_warmer_test.go @@ -0,0 +1,52 @@ +package authorize + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/trace/noop" + "google.golang.org/grpc" + + "github.com/pomerium/pomerium/internal/databroker" + "github.com/pomerium/pomerium/internal/testutil" + databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/protoutil" + "github.com/pomerium/pomerium/pkg/storage" +) + +func TestCacheWarmer(t *testing.T) { + t.Parallel() + + ctx := testutil.GetContext(t, 10*time.Minute) + cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) { + databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New(ctx, noop.NewTracerProvider())) + }) + t.Cleanup(func() { cc.Close() }) + + client := databrokerpb.NewDataBrokerServiceClient(cc) + _, err := client.Put(ctx, &databrokerpb.PutRequest{ + Records: []*databrokerpb.Record{ + {Type: "example.com/record", Id: "e1", Data: protoutil.NewAnyBool(true)}, + {Type: "example.com/record", Id: "e2", Data: protoutil.NewAnyBool(true)}, + }, + }) + require.NoError(t, err) + + cache := storage.NewGlobalCache(time.Minute) + + cw := newCacheWarmer(cc, cache, "example.com/record") + go cw.Run(ctx) + + assert.Eventually(t, func() bool { + req := &databrokerpb.QueryRequest{ + Type: "example.com/record", + Limit: 1, + } + req.SetFilterByIDOrIndex("e1") + res, err := storage.NewCachingQuerier(storage.NewStaticQuerier(), cache).Query(ctx, req) + require.NoError(t, err) + return len(res.GetRecords()) == 1 + }, 10*time.Second, time.Millisecond*100) +} diff --git a/authorize/databroker_test.go b/authorize/databroker_test.go index 1c773fc21..eefb5987b 100644 --- a/authorize/databroker_test.go +++ b/authorize/databroker_test.go @@ -37,10 +37,8 @@ func Test_getDataBrokerRecord(t *testing.T) { s1 := &session.Session{Id: "s1", Version: fmt.Sprint(tc.recordVersion)} sq := storage.NewStaticQuerier(s1) - tsq := storage.NewTracingQuerier(sq) - cq := storage.NewCachingQuerier(tsq, storage.NewLocalCache()) - tcq := storage.NewTracingQuerier(cq) - qctx := storage.WithQuerier(ctx, tcq) + cq := storage.NewCachingQuerier(sq, storage.NewGlobalCache(time.Minute)) + qctx := storage.WithQuerier(ctx, cq) s, err := getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion) assert.NoError(t, err) @@ -49,11 +47,6 @@ func Test_getDataBrokerRecord(t *testing.T) { s, err = getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion) assert.NoError(t, err) assert.NotNil(t, s) - - assert.Len(t, tsq.Traces(), tc.underlyingQueryCount, - "should have %d traces to the underlying querier", tc.underlyingQueryCount) - assert.Len(t, tcq.Traces(), tc.cachedQueryCount, - "should have %d traces to the cached querier", tc.cachedQueryCount) }) } } diff --git a/authorize/evaluator/evaluator.go b/authorize/evaluator/evaluator.go index 1523cf04c..e6316de6c 100644 --- a/authorize/evaluator/evaluator.go +++ b/authorize/evaluator/evaluator.go @@ -241,6 +241,7 @@ var internalPathsNeedingLogin = set.From([]string{ "/.pomerium/jwt", "/.pomerium/user", "/.pomerium/webauthn", + "/.pomerium/api/v1/routes", }) func (e *Evaluator) evaluateInternal(_ context.Context, req *Request) (*PolicyResponse, error) { diff --git a/authorize/grpc.go b/authorize/grpc.go index 8172226f9..f7da2ab0a 100644 --- a/authorize/grpc.go +++ b/authorize/grpc.go @@ -31,14 +31,9 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe ctx, span := a.tracer.Start(ctx, "authorize.grpc.Check") defer span.End() - querier := storage.NewTracingQuerier( - storage.NewCachingQuerier( - storage.NewCachingQuerier( - storage.NewQuerier(a.state.Load().dataBrokerClient), - a.globalCache, - ), - storage.NewLocalCache(), - ), + querier := storage.NewCachingQuerier( + storage.NewQuerier(a.state.Load().dataBrokerClient), + a.globalCache, ) ctx = storage.WithQuerier(ctx, querier) @@ -74,10 +69,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe return nil, err } - // take the state lock here so we don't update while evaluating - a.stateLock.RLock() res, err := state.evaluator.Evaluate(ctx, req) - a.stateLock.RUnlock() if err != nil { log.Ctx(ctx).Error().Err(err).Str("request-id", requestID).Msg("error during OPA evaluation") return nil, err diff --git a/config/envoyconfig/bootstrap.go b/config/envoyconfig/bootstrap.go index 665f216fb..f97338c44 100644 --- a/config/envoyconfig/bootstrap.go +++ b/config/envoyconfig/bootstrap.go @@ -150,8 +150,8 @@ func (b *Builder) BuildBootstrapDynamicResources( // BuildBootstrapLayeredRuntime builds the layered runtime for the envoy bootstrap. func (b *Builder) BuildBootstrapLayeredRuntime(ctx context.Context) (*envoy_config_bootstrap_v3.LayeredRuntime, error) { - flushIntervalMs := 5000 - minFlushSpans := 3 + flushIntervalMs := trace.BatchSpanProcessorScheduleDelay() + minFlushSpans := trace.BatchSpanProcessorMaxExportBatchSize() if trace.DebugFlagsFromContext(ctx).Check(trace.EnvoyFlushEverySpan) { minFlushSpans = 1 flushIntervalMs = math.MaxInt32 @@ -166,15 +166,12 @@ func (b *Builder) BuildBootstrapLayeredRuntime(ctx context.Context) (*envoy_conf "tracing": map[string]any{ "opentelemetry": map[string]any{ "flush_interval_ms": flushIntervalMs, - // For most requests, envoy generates 3 spans: + // Note: for most requests, envoy generates 3 spans: // - ingress (downstream->envoy) // - ext_authz check request (envoy->pomerium) // - egress (envoy->upstream) - // The default value is 5, which usually leads to delayed exports. - // This can be set lower, e.g. 1 to have envoy export every span - // individually (useful for testing), but 3 is a reasonable default. - // If set to 1, also set flush_interval_ms to a very large number to - // effectively disable it. + // Some requests only generate 2 spans, e.g. if there is no upstream + // request made or auth fails. "min_flush_spans": minFlushSpans, }, }, diff --git a/config/envoyconfig/bootstrap_test.go b/config/envoyconfig/bootstrap_test.go index ec27b7fe7..09d4ab7eb 100644 --- a/config/envoyconfig/bootstrap_test.go +++ b/config/envoyconfig/bootstrap_test.go @@ -51,7 +51,7 @@ func TestBuilder_BuildBootstrapLayeredRuntime(t *testing.T) { "tracing": { "opentelemetry": { "flush_interval_ms": 5000, - "min_flush_spans": 3 + "min_flush_spans": 512 } } } diff --git a/internal/benchmarks/latency_bench_test.go b/internal/benchmarks/latency_bench_test.go index d25185942..4fc0e3d9a 100644 --- a/internal/benchmarks/latency_bench_test.go +++ b/internal/benchmarks/latency_bench_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/testenv" "github.com/pomerium/pomerium/internal/testenv/envutil" "github.com/pomerium/pomerium/internal/testenv/scenarios" @@ -18,18 +19,30 @@ import ( ) var ( - numRoutes int - dumpErrLogs bool + numRoutes int + dumpErrLogs bool + enableTracing bool + publicRoutes bool ) func init() { flag.IntVar(&numRoutes, "routes", 100, "number of routes") flag.BoolVar(&dumpErrLogs, "dump-err-logs", false, "if the test fails, write all captured logs to a file (testdata/)") + flag.BoolVar(&enableTracing, "enable-tracing", false, "enable tracing") + flag.BoolVar(&publicRoutes, "public-routes", false, "use public unauthenticated routes") } func TestRequestLatency(t *testing.T) { resume := envutil.PauseProfiling(t) - env := testenv.New(t, testenv.Silent()) + var env testenv.Environment + if enableTracing { + receiver := scenarios.NewOTLPTraceReceiver() + env = testenv.New(t, testenv.Silent(), testenv.WithTraceClient(receiver.NewGRPCClient())) + env.Add(receiver) + } else { + env = testenv.New(t, testenv.Silent()) + } + users := []*scenarios.User{} for i := range numRoutes { users = append(users, &scenarios.User{ @@ -47,9 +60,12 @@ func TestRequestLatency(t *testing.T) { routes := make([]testenv.Route, numRoutes) for i := range numRoutes { routes[i] = up.Route(). - From(env.SubdomainURL(fmt.Sprintf("from-%d", i))). - // Policy(func(p *config.Policy) { p.AllowPublicUnauthenticatedAccess = true }) - PPL(fmt.Sprintf(`{"allow":{"and":["email":{"is":"user%d@example.com"}]}}`, i)) + From(env.SubdomainURL(fmt.Sprintf("from-%d", i))) + if publicRoutes { + routes[i] = routes[i].Policy(func(p *config.Policy) { p.AllowPublicUnauthenticatedAccess = true }) + } else { + routes[i] = routes[i].PPL(fmt.Sprintf(`{"allow":{"and":["email":{"is":"user%d@example.com"}]}}`, i)) + } } env.AddUpstream(up) diff --git a/internal/telemetry/trace/global.go b/internal/telemetry/trace/global.go index 16b7a5a4f..4f6338dc3 100644 --- a/internal/telemetry/trace/global.go +++ b/internal/telemetry/trace/global.go @@ -2,9 +2,12 @@ package trace import ( "context" + "os" + "strconv" "go.opentelemetry.io/contrib/propagators/autoprop" "go.opentelemetry.io/otel" + sdktrace "go.opentelemetry.io/otel/sdk/trace" "go.opentelemetry.io/otel/trace" "go.opentelemetry.io/otel/trace/embedded" ) @@ -44,3 +47,25 @@ var _ trace.Tracer = panicTracer{} func (p panicTracer) Start(context.Context, string, ...trace.SpanStartOption) (context.Context, trace.Span) { panic("global tracer used") } + +// functions below mimic those with the same name in otel/sdk/internal/env/env.go + +func BatchSpanProcessorScheduleDelay() int { + const defaultValue = sdktrace.DefaultScheduleDelay + if v, ok := os.LookupEnv("OTEL_BSP_SCHEDULE_DELAY"); ok { + if n, err := strconv.Atoi(v); err == nil { + return n + } + } + return defaultValue +} + +func BatchSpanProcessorMaxExportBatchSize() int { + const defaultValue = sdktrace.DefaultMaxExportBatchSize + if v, ok := os.LookupEnv("OTEL_BSP_MAX_EXPORT_BATCH_SIZE"); ok { + if n, err := strconv.Atoi(v); err == nil { + return n + } + } + return defaultValue +} diff --git a/internal/telemetry/trace/server.go b/internal/telemetry/trace/server.go index 1bd14c7d0..779455b6d 100644 --- a/internal/telemetry/trace/server.go +++ b/internal/telemetry/trace/server.go @@ -50,7 +50,7 @@ func NewServer(ctx context.Context) *ExporterServer { } func (srv *ExporterServer) Start(ctx context.Context) { - lis := bufconn.Listen(4096) + lis := bufconn.Listen(2 * 1024 * 1024) go func() { if err := srv.remoteClient.Start(ctx); err != nil { panic(err) @@ -95,5 +95,6 @@ func (srv *ExporterServer) Shutdown(ctx context.Context) error { if err := srv.remoteClient.Stop(ctx); err != nil { errs = append(errs, err) } + srv.cc.Close() return errors.Join(errs...) } diff --git a/internal/testenv/scenarios/trace_receiver.go b/internal/testenv/scenarios/trace_receiver.go index 343d6f8f3..f318d59ae 100644 --- a/internal/testenv/scenarios/trace_receiver.go +++ b/internal/testenv/scenarios/trace_receiver.go @@ -77,6 +77,7 @@ func (rec *OTLPTraceReceiver) Attach(ctx context.Context) { // Modify implements testenv.Modifier. func (rec *OTLPTraceReceiver) Modify(cfg *config.Config) { cfg.Options.TracingProvider = "otlp" + cfg.Options.TracingOTLPEndpoint = rec.GRPCEndpointURL().Value() } func (rec *OTLPTraceReceiver) handleV1Traces(w http.ResponseWriter, r *http.Request) { diff --git a/internal/testenv/selftests/tracing_test.go b/internal/testenv/selftests/tracing_test.go index 0a799d872..4e59e1d6c 100644 --- a/internal/testenv/selftests/tracing_test.go +++ b/internal/testenv/selftests/tracing_test.go @@ -12,13 +12,6 @@ import ( "testing" "time" - "github.com/pomerium/pomerium/config" - "github.com/pomerium/pomerium/internal/telemetry/trace" - "github.com/pomerium/pomerium/internal/testenv" - "github.com/pomerium/pomerium/internal/testenv/scenarios" - "github.com/pomerium/pomerium/internal/testenv/snippets" - "github.com/pomerium/pomerium/internal/testenv/upstreams" - . "github.com/pomerium/pomerium/internal/testutil/tracetest" //nolint:revive "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" @@ -27,6 +20,14 @@ import ( sdktrace "go.opentelemetry.io/otel/sdk/trace" semconv "go.opentelemetry.io/otel/semconv/v1.17.0" oteltrace "go.opentelemetry.io/otel/trace" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/telemetry/trace" + "github.com/pomerium/pomerium/internal/testenv" + "github.com/pomerium/pomerium/internal/testenv/scenarios" + "github.com/pomerium/pomerium/internal/testenv/snippets" + "github.com/pomerium/pomerium/internal/testenv/upstreams" + . "github.com/pomerium/pomerium/internal/testutil/tracetest" //nolint:revive ) func otlpTraceReceiverOrFromEnv(t *testing.T) (modifier testenv.Modifier, newRemoteClient func() otlptrace.Client, getResults func() *TraceResults) { @@ -116,7 +117,7 @@ func TestOTLPTracing(t *testing.T) { Exact: true, CheckDetachedSpans: true, }, - Match{Name: testEnvironmentLocalTest, TraceCount: 1, Services: []string{"Test Environment", "Control Plane", "Data Broker"}}, + Match{Name: testEnvironmentLocalTest, TraceCount: 1, Services: []string{"Authorize", "Test Environment", "Control Plane", "Data Broker"}}, Match{Name: testEnvironmentAuthenticate, TraceCount: 1, Services: allServices}, Match{Name: authenticateOAuth2Client, TraceCount: Greater(0)}, Match{Name: idpServerGetUserinfo, TraceCount: EqualToMatch(authenticateOAuth2Client)}, diff --git a/pkg/storage/cache.go b/pkg/storage/cache.go index ee75cb631..1f0e391b4 100644 --- a/pkg/storage/cache.go +++ b/pkg/storage/cache.go @@ -18,64 +18,8 @@ type Cache interface { update func(ctx context.Context) ([]byte, error), ) ([]byte, error) Invalidate(key []byte) -} - -type localCache struct { - singleflight singleflight.Group - mu sync.RWMutex - m map[string][]byte -} - -// NewLocalCache creates a new Cache backed by a map. -func NewLocalCache() Cache { - return &localCache{ - m: make(map[string][]byte), - } -} - -func (cache *localCache) GetOrUpdate( - ctx context.Context, - key []byte, - update func(ctx context.Context) ([]byte, error), -) ([]byte, error) { - strkey := string(key) - - cache.mu.RLock() - cached, ok := cache.m[strkey] - cache.mu.RUnlock() - if ok { - return cached, nil - } - - v, err, _ := cache.singleflight.Do(strkey, func() (any, error) { - cache.mu.RLock() - cached, ok := cache.m[strkey] - cache.mu.RUnlock() - if ok { - return cached, nil - } - - result, err := update(ctx) - if err != nil { - return nil, err - } - - cache.mu.Lock() - cache.m[strkey] = result - cache.mu.Unlock() - - return result, nil - }) - if err != nil { - return nil, err - } - return v.([]byte), nil -} - -func (cache *localCache) Invalidate(key []byte) { - cache.mu.Lock() - delete(cache.m, string(key)) - cache.mu.Unlock() + InvalidateAll() + Set(expiry time.Time, key, value []byte) } type globalCache struct { @@ -115,7 +59,7 @@ func (cache *globalCache) GetOrUpdate( if err != nil { return nil, err } - cache.set(key, value) + cache.set(time.Now().Add(cache.ttl), key, value) return value, nil }) if err != nil { @@ -130,6 +74,16 @@ func (cache *globalCache) Invalidate(key []byte) { cache.mu.Unlock() } +func (cache *globalCache) InvalidateAll() { + cache.mu.Lock() + cache.fastcache.Reset() + cache.mu.Unlock() +} + +func (cache *globalCache) Set(expiry time.Time, key, value []byte) { + cache.set(expiry, key, value) +} + func (cache *globalCache) get(k []byte) (data []byte, expiry time.Time, ok bool) { cache.mu.RLock() item := cache.fastcache.Get(nil, k) @@ -143,13 +97,13 @@ func (cache *globalCache) get(k []byte) (data []byte, expiry time.Time, ok bool) return data, expiry, true } -func (cache *globalCache) set(k, v []byte) { - unix := time.Now().Add(cache.ttl).UnixMilli() - item := make([]byte, len(v)+8) +func (cache *globalCache) set(expiry time.Time, key, value []byte) { + unix := expiry.UnixMilli() + item := make([]byte, len(value)+8) binary.LittleEndian.PutUint64(item, uint64(unix)) - copy(item[8:], v) + copy(item[8:], value) cache.mu.Lock() - cache.fastcache.Set(k, item) + cache.fastcache.Set(key, item) cache.mu.Unlock() } diff --git a/pkg/storage/cache_test.go b/pkg/storage/cache_test.go index 132dc201a..e33cf8dc7 100644 --- a/pkg/storage/cache_test.go +++ b/pkg/storage/cache_test.go @@ -9,34 +9,6 @@ import ( "github.com/stretchr/testify/assert" ) -func TestLocalCache(t *testing.T) { - ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) - defer clearTimeout() - - callCount := 0 - update := func(_ context.Context) ([]byte, error) { - callCount++ - return []byte("v1"), nil - } - c := NewLocalCache() - v, err := c.GetOrUpdate(ctx, []byte("k1"), update) - assert.NoError(t, err) - assert.Equal(t, []byte("v1"), v) - assert.Equal(t, 1, callCount) - - v, err = c.GetOrUpdate(ctx, []byte("k1"), update) - assert.NoError(t, err) - assert.Equal(t, []byte("v1"), v) - assert.Equal(t, 1, callCount) - - c.Invalidate([]byte("k1")) - - v, err = c.GetOrUpdate(ctx, []byte("k1"), update) - assert.NoError(t, err) - assert.Equal(t, []byte("v1"), v) - assert.Equal(t, 2, callCount) -} - func TestGlobalCache(t *testing.T) { ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) defer clearTimeout() @@ -70,4 +42,10 @@ func TestGlobalCache(t *testing.T) { }) return err != nil }, time.Second, time.Millisecond*10, "should honor TTL") + + c.Set(time.Now().Add(time.Hour), []byte("k2"), []byte("v2")) + v, err = c.GetOrUpdate(ctx, []byte("k2"), update) + assert.NoError(t, err) + assert.Equal(t, []byte("v2"), v) + assert.Equal(t, 2, callCount) } diff --git a/pkg/storage/querier.go b/pkg/storage/querier.go index 1de0b195d..c70247e5c 100644 --- a/pkg/storage/querier.go +++ b/pkg/storage/querier.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "strconv" - "sync" "github.com/google/uuid" grpc "google.golang.org/grpc" @@ -12,7 +11,6 @@ import ( status "google.golang.org/grpc/status" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/types/known/structpb" timestamppb "google.golang.org/protobuf/types/known/timestamppb" "github.com/pomerium/pomerium/pkg/cryptutil" @@ -162,59 +160,6 @@ func (q *clientQuerier) Query(ctx context.Context, in *databroker.QueryRequest, return q.client.Query(ctx, in, opts...) } -// A TracingQuerier records calls to Query. -type TracingQuerier struct { - underlying Querier - - mu sync.Mutex - traces []QueryTrace -} - -// A QueryTrace traces a call to Query. -type QueryTrace struct { - ServerVersion, RecordVersion uint64 - - RecordType string - Query string - Filter *structpb.Struct -} - -// NewTracingQuerier creates a new TracingQuerier. -func NewTracingQuerier(q Querier) *TracingQuerier { - return &TracingQuerier{ - underlying: q, - } -} - -// InvalidateCache invalidates the cache. -func (q *TracingQuerier) InvalidateCache(ctx context.Context, in *databroker.QueryRequest) { - q.underlying.InvalidateCache(ctx, in) -} - -// Query queries for records. -func (q *TracingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) { - res, err := q.underlying.Query(ctx, in, opts...) - if err == nil { - q.mu.Lock() - q.traces = append(q.traces, QueryTrace{ - RecordType: in.GetType(), - Query: in.GetQuery(), - Filter: in.GetFilter(), - }) - q.mu.Unlock() - } - return res, err -} - -// Traces returns all the traces. -func (q *TracingQuerier) Traces() []QueryTrace { - q.mu.Lock() - traces := make([]QueryTrace, len(q.traces)) - copy(traces, q.traces) - q.mu.Unlock() - return traces -} - type cachingQuerier struct { q Querier cache Cache @@ -240,9 +185,7 @@ func (q *cachingQuerier) InvalidateCache(ctx context.Context, in *databroker.Que } func (q *cachingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) { - key, err := (&proto.MarshalOptions{ - Deterministic: true, - }).Marshal(in) + key, err := MarshalQueryRequest(in) if err != nil { return nil, err } @@ -252,7 +195,7 @@ func (q *cachingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, if err != nil { return nil, err } - return proto.Marshal(res) + return MarshalQueryResponse(res) }) if err != nil { return nil, err @@ -265,3 +208,17 @@ func (q *cachingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, } return &res, nil } + +// MarshalQueryRequest marshales the query request. +func MarshalQueryRequest(req *databroker.QueryRequest) ([]byte, error) { + return (&proto.MarshalOptions{ + Deterministic: true, + }).Marshal(req) +} + +// MarshalQueryResponse marshals the query response. +func MarshalQueryResponse(res *databroker.QueryResponse) ([]byte, error) { + return (&proto.MarshalOptions{ + Deterministic: true, + }).Marshal(res) +} diff --git a/proxy/handlers.go b/proxy/handlers.go index db5dd2179..6e9411ac8 100644 --- a/proxy/handlers.go +++ b/proxy/handlers.go @@ -40,11 +40,32 @@ func (p *Proxy) registerDashboardHandlers(r *mux.Router, opts *config.Options) * c.Path("/").Handler(httputil.HandlerFunc(p.Callback)).Methods(http.MethodGet) // Programmatic API handlers and middleware - a := r.PathPrefix(dashboardPath + "/api").Subrouter() - // login api handler generates a user-navigable login url to authenticate - a.Path("/v1/login").Handler(httputil.HandlerFunc(p.ProgrammaticLogin)). - Queries(urlutil.QueryRedirectURI, ""). - Methods(http.MethodGet) + // gorilla mux has a bug that prevents HTTP 405 errors from being returned properly so we do all this manually + // https://github.com/gorilla/mux/issues/739 + r.PathPrefix(dashboardPath + "/api"). + Handler(httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { + switch r.URL.Path { + // login api handler generates a user-navigable login url to authenticate + case dashboardPath + "/api/v1/login": + if r.Method != http.MethodGet { + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + return nil + } + if !r.URL.Query().Has(urlutil.QueryRedirectURI) { + http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) + return nil + } + return p.ProgrammaticLogin(w, r) + case dashboardPath + "/api/v1/routes": + if r.Method != http.MethodGet { + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + return nil + } + return p.routesPortalJSON(w, r) + } + http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) + return nil + })) return r } diff --git a/proxy/handlers_portal.go b/proxy/handlers_portal.go new file mode 100644 index 000000000..7d28aed06 --- /dev/null +++ b/proxy/handlers_portal.go @@ -0,0 +1,56 @@ +package proxy + +import ( + "encoding/json" + "net/http" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/handlers" + "github.com/pomerium/pomerium/internal/httputil" + "github.com/pomerium/pomerium/proxy/portal" +) + +func (p *Proxy) routesPortalJSON(w http.ResponseWriter, r *http.Request) error { + u := p.getUserInfoData(r) + rs := p.getPortalRoutes(u) + m := map[string]any{} + m["routes"] = rs + + b, err := json.Marshal(m) + if err != nil { + return httputil.NewError(http.StatusInternalServerError, err) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write(b) + return nil +} + +func (p *Proxy) getPortalRoutes(u handlers.UserInfoData) []portal.Route { + options := p.currentOptions.Load() + pu := p.getPortalUser(u) + var routes []*config.Policy + for route := range options.GetAllPolicies() { + if portal.CheckRouteAccess(pu, route) { + routes = append(routes, route) + } + } + return portal.RoutesFromConfigRoutes(routes) +} + +func (p *Proxy) getPortalUser(u handlers.UserInfoData) portal.User { + pu := portal.User{} + pu.SessionID = u.Session.GetId() + pu.UserID = u.User.GetId() + pu.Email = u.User.GetEmail() + for _, dg := range u.DirectoryGroups { + if v := dg.ID; v != "" { + pu.Groups = append(pu.Groups, dg.ID) + } + if v := dg.Name; v != "" { + pu.Groups = append(pu.Groups, dg.Name) + } + } + return pu +} diff --git a/proxy/handlers_portal_test.go b/proxy/handlers_portal_test.go new file mode 100644 index 000000000..c0a885516 --- /dev/null +++ b/proxy/handlers_portal_test.go @@ -0,0 +1,51 @@ +package proxy + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/httputil" +) + +func TestProxy_routesPortalJSON(t *testing.T) { + ctx := context.Background() + cfg := &config.Config{Options: config.NewDefaultOptions()} + to, err := config.ParseWeightedUrls("https://to.example.com") + require.NoError(t, err) + cfg.Options.Routes = append(cfg.Options.Routes, config.Policy{ + Name: "public", + Description: "PUBLIC ROUTE", + LogoURL: "https://logo.example.com", + From: "https://from.example.com", + To: to, + AllowPublicUnauthenticatedAccess: true, + }) + proxy, err := New(ctx, cfg) + require.NoError(t, err) + + r := httptest.NewRequest(http.MethodGet, "/.pomerium/api/v1/routes", nil) + w := httptest.NewRecorder() + + router := httputil.NewRouter() + router = proxy.registerDashboardHandlers(router, cfg.Options) + router.ServeHTTP(w, r) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "application/json", w.Header().Get("Content-Type")) + assert.JSONEq(t, `{"routes":[ + { + "id": "4e71df99c0317efb", + "name": "public", + "from": "https://from.example.com", + "type": "http", + "description": "PUBLIC ROUTE", + "logo_url": "https://logo.example.com" + } + ]}`, w.Body.String()) +} diff --git a/proxy/portal/filter.go b/proxy/portal/filter.go new file mode 100644 index 000000000..f63966e50 --- /dev/null +++ b/proxy/portal/filter.go @@ -0,0 +1,105 @@ +package portal + +import ( + "strings" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/pkg/policy/parser" +) + +// User is the computed user information needed for access decisions. +type User struct { + SessionID string + UserID string + Email string + Groups []string +} + +// CheckRouteAccess checks if the user has access to the route. +func CheckRouteAccess(user User, route *config.Policy) bool { + // check the main policy + ppl := route.ToPPL() + if checkPPLAccess(user, ppl) { + return true + } + + // check sub-policies + for _, sp := range route.SubPolicies { + if sp.SourcePPL == "" { + continue + } + + ppl, err := parser.New().ParseYAML(strings.NewReader(sp.SourcePPL)) + if err != nil { + // ignore invalid PPL + continue + } + + if checkPPLAccess(user, ppl) { + return true + } + } + + // nothing matched + return false +} + +func checkPPLAccess(user User, ppl *parser.Policy) bool { + for _, r := range ppl.Rules { + // ignore deny rules + if r.Action != parser.ActionAllow { + continue + } + + // ignore complex rules + if len(r.Nor) > 0 || len(r.Not) > 0 || len(r.And) > 1 { + continue + } + + cs := append(append([]parser.Criterion{}, r.Or...), r.And...) + for _, c := range cs { + ok := checkPPLCriterionAccess(user, c) + if ok { + return true + } + } + } + + return false +} + +func checkPPLCriterionAccess(user User, criterion parser.Criterion) bool { + switch criterion.Name { + case "accept": + return true + } + + // require a session + if user.SessionID == "" { + return false + } + + switch criterion.Name { + case "authenticated_user": + return true + } + + // require a user + if user.UserID == "" { + return false + } + + switch criterion.Name { + case "domain": + parts := strings.SplitN(user.Email, "@", 2) + return len(parts) == 2 && matchString(parts[1], criterion.Data) + case "email": + return matchString(user.Email, criterion.Data) + case "groups": + return matchStringList(user.Groups, criterion.Data) + case "user": + return matchString(user.UserID, criterion.Data) + } + + return false +} diff --git a/proxy/portal/filter_test.go b/proxy/portal/filter_test.go new file mode 100644 index 000000000..9a9d1519a --- /dev/null +++ b/proxy/portal/filter_test.go @@ -0,0 +1,67 @@ +package portal_test + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/pkg/policy/parser" + "github.com/pomerium/pomerium/proxy/portal" +) + +func TestCheckRouteAccess(t *testing.T) { + t.Parallel() + + u1 := portal.User{} + u2 := portal.User{SessionID: "s2", UserID: "u2", Email: "u2@example.com", Groups: []string{"g2"}} + + for _, tc := range []struct { + name string + user portal.User + route *config.Policy + }{ + {"no ppl", u1, &config.Policy{}}, + {"allow_any_authenticated_user", u1, &config.Policy{AllowAnyAuthenticatedUser: true}}, + {"allowed_domains", u2, &config.Policy{AllowedDomains: []string{"not.example.com"}}}, + {"allowed_users", u2, &config.Policy{AllowedUsers: []string{"u3"}}}, + {"not conditionals", u2, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"not": [{"accept": 1}]}}`)}}, + {"nor conditionals", u2, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"nor": [{"accept": 1}]}}`)}}, + {"and conditionals", u2, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"and": [{"accept": 1}, {"accept": 1}]}}`)}}, + {"authenticated_user", u1, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"or": [{"authenticated_user": 1}]}}`)}}, + {"domain", u2, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"or": [{"domain": "not.example.com"}]}}`)}}, + {"email", u1, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"or": [{"email": "u2@example.com"}]}}`)}}, + {"groups", u2, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"or": [{"groups": {"has": "g3"}}]}}`)}}, + } { + assert.False(t, portal.CheckRouteAccess(tc.user, tc.route), "%s: should deny access for %v to %v", + tc.name, tc.user, tc.route) + } + + for _, tc := range []struct { + name string + user portal.User + route *config.Policy + }{ + {"allow_public_unauthenticated_access", u1, &config.Policy{AllowPublicUnauthenticatedAccess: true}}, + {"allow_any_authenticated_user", u2, &config.Policy{AllowAnyAuthenticatedUser: true}}, + {"allowed_domains", u2, &config.Policy{AllowedDomains: []string{"example.com"}}}, + {"allowed_users", u2, &config.Policy{AllowedUsers: []string{"u2"}}}, + {"and conditionals", u2, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"and": [{"accept": 1}]}}`)}}, + {"or conditionals", u2, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"or": [{"reject": 1}, {"accept": 1}]}}`)}}, + {"authenticated_user", u2, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"or": [{"authenticated_user": 1}]}}`)}}, + {"domain", u2, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"or": [{"domain": "example.com"}]}}`)}}, + {"email", u2, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"or": [{"email": "u2@example.com"}]}}`)}}, + {"groups", u2, &config.Policy{Policy: mustParsePPL(t, `{"allow": {"or": [{"groups": {"has": "g2"}}]}}`)}}, + } { + assert.True(t, portal.CheckRouteAccess(tc.user, tc.route), "%s: should grant access for %v to %v", + tc.name, tc.user, tc.route) + } +} + +func mustParsePPL(t testing.TB, raw string) *config.PPLPolicy { + ppl, err := parser.New().ParseJSON(strings.NewReader(raw)) + require.NoError(t, err) + return &config.PPLPolicy{Policy: ppl} +} diff --git a/proxy/portal/matchers.go b/proxy/portal/matchers.go new file mode 100644 index 000000000..fad178e74 --- /dev/null +++ b/proxy/portal/matchers.go @@ -0,0 +1,108 @@ +package portal + +import ( + "strings" + + "github.com/pomerium/pomerium/pkg/policy/parser" +) + +type matcher[T any] func(left T, right parser.Value) bool + +var stringMatchers = map[string]matcher[string]{ + "contains": matchStringContains, + "ends_with": matchStringEndsWith, + "is": matchStringIs, + "starts_with": matchStringStartsWith, +} + +var stringListMatchers = map[string]matcher[[]string]{ + "has": matchStringListHas, + "is": matchStringListIs, +} + +func matchString(left string, right parser.Value) bool { + obj, ok := right.(parser.Object) + if !ok { + obj = parser.Object{ + "is": right, + } + } + + for k, v := range obj { + f, ok := stringMatchers[k] + if !ok { + return false + } + ok = f(left, v) + if ok { + return true + } + } + return false +} + +func matchStringContains(left string, right parser.Value) bool { + str, ok := right.(parser.String) + if !ok { + return false + } + return strings.Contains(left, string(str)) +} + +func matchStringEndsWith(left string, right parser.Value) bool { + str, ok := right.(parser.String) + if !ok { + return false + } + return strings.HasSuffix(left, string(str)) +} + +func matchStringIs(left string, right parser.Value) bool { + str, ok := right.(parser.String) + if !ok { + return false + } + return left == string(str) +} + +func matchStringStartsWith(left string, right parser.Value) bool { + str, ok := right.(parser.String) + if !ok { + return false + } + return strings.HasPrefix(left, string(str)) +} + +func matchStringList(left []string, right parser.Value) bool { + obj, ok := right.(parser.Object) + if !ok { + obj = parser.Object{ + "has": right, + } + } + + for k, v := range obj { + f, ok := stringListMatchers[k] + if !ok { + return false + } + ok = f(left, v) + if ok { + return true + } + } + return false +} + +func matchStringListHas(left []string, right parser.Value) bool { + for _, str := range left { + if matchStringIs(str, right) { + return true + } + } + return false +} + +func matchStringListIs(left []string, right parser.Value) bool { + return len(left) == 1 && matchStringListHas(left, right) +} diff --git a/proxy/portal/matchers_test.go b/proxy/portal/matchers_test.go new file mode 100644 index 000000000..bf1a9beab --- /dev/null +++ b/proxy/portal/matchers_test.go @@ -0,0 +1,83 @@ +package portal + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/pomerium/pomerium/pkg/policy/parser" +) + +func Test_matchString(t *testing.T) { + t.Parallel() + + t.Run("string", func(t *testing.T) { + assert.True(t, matchString("TEST", mustParseValue(t, `"TEST"`))) + }) + t.Run("bool", func(t *testing.T) { + assert.False(t, matchString("true", mustParseValue(t, `true`))) + }) + t.Run("number", func(t *testing.T) { + assert.False(t, matchString("1", mustParseValue(t, `1`))) + }) + t.Run("null", func(t *testing.T) { + assert.False(t, matchString("null", mustParseValue(t, `null`))) + }) + t.Run("array", func(t *testing.T) { + assert.False(t, matchString("[]", mustParseValue(t, `[]`))) + }) + t.Run("contains", func(t *testing.T) { + assert.True(t, matchString("XYZ", mustParseValue(t, `{"contains":"Y"}`))) + assert.False(t, matchString("XYZ", mustParseValue(t, `{"contains":"A"}`))) + }) + t.Run("ends_with", func(t *testing.T) { + assert.True(t, matchString("XYZ", mustParseValue(t, `{"ends_with":"Z"}`))) + assert.False(t, matchString("XYZ", mustParseValue(t, `{"ends_with":"X"}`))) + }) + t.Run("is", func(t *testing.T) { + assert.True(t, matchString("XYZ", mustParseValue(t, `{"is":"XYZ"}`))) + assert.False(t, matchString("XYZ", mustParseValue(t, `{"is":"X"}`))) + }) + t.Run("starts_with", func(t *testing.T) { + assert.True(t, matchString("XYZ", mustParseValue(t, `{"starts_with":"X"}`))) + assert.False(t, matchString("XYZ", mustParseValue(t, `{"starts_with":"Z"}`))) + }) +} + +func Test_matchStringList(t *testing.T) { + t.Parallel() + + t.Run("string", func(t *testing.T) { + assert.True(t, matchStringList([]string{"X", "Y", "Z"}, mustParseValue(t, `"Y"`))) + assert.False(t, matchStringList([]string{"X", "Y", "Z"}, mustParseValue(t, `"A"`))) + }) + t.Run("bool", func(t *testing.T) { + assert.False(t, matchStringList([]string{"true"}, mustParseValue(t, `true`))) + }) + t.Run("number", func(t *testing.T) { + assert.False(t, matchStringList([]string{"1"}, mustParseValue(t, `1`))) + }) + t.Run("null", func(t *testing.T) { + assert.False(t, matchStringList([]string{"null"}, mustParseValue(t, `null`))) + }) + t.Run("array", func(t *testing.T) { + assert.False(t, matchStringList([]string{"[]"}, mustParseValue(t, `[]`))) + }) + t.Run("has", func(t *testing.T) { + assert.True(t, matchStringList([]string{"X", "Y", "Z"}, mustParseValue(t, `{"has":"Y"}`))) + assert.False(t, matchStringList([]string{"X", "Y", "Z"}, mustParseValue(t, `{"has":"A"}`))) + }) + t.Run("is", func(t *testing.T) { + assert.True(t, matchStringList([]string{"X"}, mustParseValue(t, `{"is":"X"}`))) + assert.False(t, matchStringList([]string{"X", "Y", "Z"}, mustParseValue(t, `{"is":"Y"}`))) + assert.False(t, matchStringList([]string{"X", "Y", "Z"}, mustParseValue(t, `{"is":"A"}`))) + }) +} + +func mustParseValue(t testing.TB, raw string) parser.Value { + v, err := parser.ParseValue(strings.NewReader(raw)) + require.NoError(t, err) + return v +} diff --git a/proxy/portal/portal.go b/proxy/portal/portal.go new file mode 100644 index 000000000..3756d3638 --- /dev/null +++ b/proxy/portal/portal.go @@ -0,0 +1,60 @@ +// Package portal contains the code for the routes portal +package portal + +import ( + "fmt" + "strings" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/urlutil" + "github.com/pomerium/pomerium/pkg/zero/importutil" +) + +// A Route is a portal route. +type Route struct { + ID string `json:"id"` + Name string `json:"name"` + Type string `json:"type"` + From string `json:"from"` + Description string `json:"description"` + ConnectCommand string `json:"connect_command,omitempty"` + LogoURL string `json:"logo_url"` +} + +// RoutesFromConfigRoutes converts config routes into portal routes. +func RoutesFromConfigRoutes(routes []*config.Policy) []Route { + prs := make([]Route, len(routes)) + for i, route := range routes { + pr := Route{} + pr.ID = route.ID + if pr.ID == "" { + pr.ID = fmt.Sprintf("%x", route.MustRouteID()) + } + pr.Name = route.Name + pr.From = route.From + fromURL, err := urlutil.ParseAndValidateURL(route.From) + if err == nil { + if strings.HasPrefix(fromURL.Scheme, "tcp+") { + pr.Type = "tcp" + pr.ConnectCommand = "pomerium-cli tcp " + fromURL.Host + } else if strings.HasPrefix(fromURL.Scheme, "udp+") { + pr.Type = "udp" + pr.ConnectCommand = "pomerium-cli udp " + fromURL.Host + } else { + pr.Type = "http" + } + } else { + pr.Type = "http" + } + pr.Description = route.Description + pr.LogoURL = route.LogoURL + prs[i] = pr + } + // generate names if they're empty + for i, name := range importutil.GenerateRouteNames(routes) { + if prs[i].Name == "" { + prs[i].Name = name + } + } + return prs +} diff --git a/proxy/portal/portal_test.go b/proxy/portal/portal_test.go new file mode 100644 index 000000000..5b6887f29 --- /dev/null +++ b/proxy/portal/portal_test.go @@ -0,0 +1,71 @@ +package portal_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/proxy/portal" +) + +func TestRouteFromConfigRoute(t *testing.T) { + t.Parallel() + + to1, err := config.ParseWeightedUrls("https://to.example.com") + require.NoError(t, err) + to2, err := config.ParseWeightedUrls("tcp://postgres:5432") + require.NoError(t, err) + + assert.Equal(t, []portal.Route{ + { + ID: "4e71df99c0317efb", + Name: "from", + Type: "http", + From: "https://from.example.com", + Description: "ROUTE #1", + LogoURL: "https://logo.example.com", + }, + { + ID: "7c377f11cdb9700e", + Name: "from-path", + Type: "http", + From: "https://from.example.com", + }, + { + ID: "708e3cbd0bbe8547", + Name: "postgres", + Type: "tcp", + From: "tcp+https://postgres.example.com:5432", + ConnectCommand: "pomerium-cli tcp postgres.example.com:5432", + }, + { + ID: "2dd08d87486e051a", + Name: "dns", + Type: "udp", + From: "udp+https://dns.example.com:53", + ConnectCommand: "pomerium-cli udp dns.example.com:53", + }, + }, portal.RoutesFromConfigRoutes([]*config.Policy{ + { + From: "https://from.example.com", + To: to1, + Description: "ROUTE #1", + LogoURL: "https://logo.example.com", + }, + { + From: "https://from.example.com", + To: to1, + Path: "/path", + }, + { + From: "tcp+https://postgres.example.com:5432", + To: to2, + }, + { + From: "udp+https://dns.example.com:53", + To: to2, + }, + })) +}