mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-03 19:32:48 +02:00
Merge remote-tracking branch 'origin/main' into experimental/ssh
This commit is contained in:
commit
8eff4a48a4
94 changed files with 1563 additions and 468 deletions
|
@ -10,8 +10,8 @@ import (
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
oteltrace "go.opentelemetry.io/otel/trace"
|
oteltrace "go.opentelemetry.io/otel/trace"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -26,11 +26,11 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/middleware"
|
"github.com/pomerium/pomerium/internal/middleware"
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
"github.com/pomerium/pomerium/pkg/identity"
|
"github.com/pomerium/pomerium/pkg/identity"
|
||||||
"github.com/pomerium/pomerium/pkg/identity/oidc"
|
"github.com/pomerium/pomerium/pkg/identity/oidc"
|
||||||
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Handler returns the authenticate service's handler chain.
|
// Handler returns the authenticate service's handler chain.
|
||||||
|
|
|
@ -6,7 +6,6 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"slices"
|
"slices"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
oteltrace "go.opentelemetry.io/otel/trace"
|
oteltrace "go.opentelemetry.io/otel/trace"
|
||||||
|
@ -22,7 +21,6 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
||||||
"github.com/pomerium/pomerium/pkg/contextutil"
|
"github.com/pomerium/pomerium/pkg/contextutil"
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
@ -30,6 +28,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/pkg/policy/criteria"
|
"github.com/pomerium/pomerium/pkg/policy/criteria"
|
||||||
"github.com/pomerium/pomerium/pkg/storage"
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
"github.com/pomerium/pomerium/pkg/telemetry/requestid"
|
"github.com/pomerium/pomerium/pkg/telemetry/requestid"
|
||||||
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Authorize struct holds
|
// Authorize struct holds
|
||||||
|
@ -38,7 +37,6 @@ type Authorize struct {
|
||||||
store *store.Store
|
store *store.Store
|
||||||
currentConfig *atomicutil.Value[*config.Config]
|
currentConfig *atomicutil.Value[*config.Config]
|
||||||
accessTracker *AccessTracker
|
accessTracker *AccessTracker
|
||||||
globalCache storage.Cache
|
|
||||||
groupsCacheWarmer *cacheWarmer
|
groupsCacheWarmer *cacheWarmer
|
||||||
|
|
||||||
tracerProvider oteltrace.TracerProvider
|
tracerProvider oteltrace.TracerProvider
|
||||||
|
@ -54,7 +52,6 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
|
||||||
a := &Authorize{
|
a := &Authorize{
|
||||||
currentConfig: atomicutil.NewValue(&config.Config{Options: new(config.Options)}),
|
currentConfig: atomicutil.NewValue(&config.Config{Options: new(config.Options)}),
|
||||||
store: store.New(),
|
store: store.New(),
|
||||||
globalCache: storage.NewGlobalCache(time.Minute),
|
|
||||||
tracerProvider: tracerProvider,
|
tracerProvider: tracerProvider,
|
||||||
tracer: tracer,
|
tracer: tracer,
|
||||||
activeStreams: ActiveStreams{
|
activeStreams: ActiveStreams{
|
||||||
|
@ -69,7 +66,7 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
|
||||||
}
|
}
|
||||||
a.state = atomicutil.NewValue(state)
|
a.state = atomicutil.NewValue(state)
|
||||||
|
|
||||||
a.groupsCacheWarmer = newCacheWarmer(state.dataBrokerClientConnection, a.globalCache, directory.GroupRecordType)
|
a.groupsCacheWarmer = newCacheWarmer(state.dataBrokerClientConnection, storage.GlobalCache, directory.GroupRecordType)
|
||||||
return a, nil
|
return a, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -198,7 +195,7 @@ func (a *Authorize) evaluate(
|
||||||
) (*evaluator.Result, error) {
|
) (*evaluator.Result, error) {
|
||||||
querier := storage.NewCachingQuerier(
|
querier := storage.NewCachingQuerier(
|
||||||
storage.NewQuerier(a.state.Load().dataBrokerClient),
|
storage.NewQuerier(a.state.Load().dataBrokerClient),
|
||||||
a.globalCache,
|
storage.GlobalCache,
|
||||||
)
|
)
|
||||||
ctx = storage.WithQuerier(ctx, querier)
|
ctx = storage.WithQuerier(ctx, querier)
|
||||||
|
|
||||||
|
|
|
@ -3,9 +3,6 @@ package authorize
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
"google.golang.org/grpc"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||||
|
@ -18,47 +15,6 @@ type sessionOrServiceAccount interface {
|
||||||
Validate() error
|
Validate() error
|
||||||
}
|
}
|
||||||
|
|
||||||
func getDataBrokerRecord(
|
|
||||||
ctx context.Context,
|
|
||||||
recordType string,
|
|
||||||
recordID string,
|
|
||||||
lowestRecordVersion uint64,
|
|
||||||
) (*databroker.Record, error) {
|
|
||||||
q := storage.GetQuerier(ctx)
|
|
||||||
|
|
||||||
req := &databroker.QueryRequest{
|
|
||||||
Type: recordType,
|
|
||||||
Limit: 1,
|
|
||||||
}
|
|
||||||
req.SetFilterByIDOrIndex(recordID)
|
|
||||||
|
|
||||||
res, err := q.Query(ctx, req, grpc.WaitForReady(true))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if len(res.GetRecords()) == 0 {
|
|
||||||
return nil, storage.ErrNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
// if the current record version is less than the lowest we'll accept, invalidate the cache
|
|
||||||
if res.GetRecords()[0].GetVersion() < lowestRecordVersion {
|
|
||||||
q.InvalidateCache(ctx, req)
|
|
||||||
} else {
|
|
||||||
return res.GetRecords()[0], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// retry with an up to date cache
|
|
||||||
res, err = q.Query(ctx, req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if len(res.GetRecords()) == 0 {
|
|
||||||
return nil, storage.ErrNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
return res.GetRecords()[0], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Authorize) getDataBrokerSessionOrServiceAccount(
|
func (a *Authorize) getDataBrokerSessionOrServiceAccount(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
sessionID string,
|
sessionID string,
|
||||||
|
@ -67,9 +23,9 @@ func (a *Authorize) getDataBrokerSessionOrServiceAccount(
|
||||||
ctx, span := a.tracer.Start(ctx, "authorize.getDataBrokerSessionOrServiceAccount")
|
ctx, span := a.tracer.Start(ctx, "authorize.getDataBrokerSessionOrServiceAccount")
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
record, err := getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID, dataBrokerRecordVersion)
|
record, err := storage.GetDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID, dataBrokerRecordVersion)
|
||||||
if storage.IsNotFound(err) {
|
if storage.IsNotFound(err) {
|
||||||
record, err = getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.ServiceAccount)), sessionID, dataBrokerRecordVersion)
|
record, err = storage.GetDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.ServiceAccount)), sessionID, dataBrokerRecordVersion)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -100,7 +56,7 @@ func (a *Authorize) getDataBrokerUser(
|
||||||
ctx, span := a.tracer.Start(ctx, "authorize.getDataBrokerUser")
|
ctx, span := a.tracer.Start(ctx, "authorize.getDataBrokerUser")
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
record, err := getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.User)), userID, 0)
|
record, err := storage.GetDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.User)), userID, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,7 +2,6 @@ package authorize
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -12,45 +11,9 @@ import (
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
|
||||||
"github.com/pomerium/pomerium/pkg/storage"
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_getDataBrokerRecord(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
|
||||||
t.Cleanup(clearTimeout)
|
|
||||||
|
|
||||||
for _, tc := range []struct {
|
|
||||||
name string
|
|
||||||
recordVersion, queryVersion uint64
|
|
||||||
underlyingQueryCount, cachedQueryCount int
|
|
||||||
}{
|
|
||||||
{"cached", 1, 1, 1, 2},
|
|
||||||
{"invalidated", 1, 2, 3, 4},
|
|
||||||
} {
|
|
||||||
tc := tc
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
s1 := &session.Session{Id: "s1", Version: fmt.Sprint(tc.recordVersion)}
|
|
||||||
|
|
||||||
sq := storage.NewStaticQuerier(s1)
|
|
||||||
cq := storage.NewCachingQuerier(sq, storage.NewGlobalCache(time.Minute))
|
|
||||||
qctx := storage.WithQuerier(ctx, cq)
|
|
||||||
|
|
||||||
s, err := getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.NotNil(t, s)
|
|
||||||
|
|
||||||
s, err = getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.NotNil(t, s)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthorize_getDataBrokerSessionOrServiceAccount(t *testing.T) {
|
func TestAuthorize_getDataBrokerSessionOrServiceAccount(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|
|
@ -19,10 +19,10 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/errgrouputil"
|
"github.com/pomerium/pomerium/internal/errgrouputil"
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
||||||
"github.com/pomerium/pomerium/pkg/contextutil"
|
"github.com/pomerium/pomerium/pkg/contextutil"
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
"github.com/pomerium/pomerium/pkg/policy/criteria"
|
"github.com/pomerium/pomerium/pkg/policy/criteria"
|
||||||
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Request contains the inputs needed for evaluation.
|
// Request contains the inputs needed for evaluation.
|
||||||
|
|
|
@ -9,7 +9,7 @@ import (
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/authorize/internal/store"
|
"github.com/pomerium/pomerium/authorize/internal/store"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
)
|
)
|
||||||
|
|
||||||
// HeadersResponse is the output from the headers.rego script.
|
// HeadersResponse is the output from the headers.rego script.
|
||||||
|
|
|
@ -11,11 +11,11 @@ import (
|
||||||
"github.com/pomerium/pomerium/authorize/internal/store"
|
"github.com/pomerium/pomerium/authorize/internal/store"
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
||||||
"github.com/pomerium/pomerium/pkg/contextutil"
|
"github.com/pomerium/pomerium/pkg/contextutil"
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
"github.com/pomerium/pomerium/pkg/policy"
|
"github.com/pomerium/pomerium/pkg/policy"
|
||||||
"github.com/pomerium/pomerium/pkg/policy/criteria"
|
"github.com/pomerium/pomerium/pkg/policy/criteria"
|
||||||
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PolicyRequest is the input to policy evaluation.
|
// PolicyRequest is the input to policy evaluation.
|
||||||
|
|
|
@ -31,6 +31,12 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
|
||||||
ctx, span := a.tracer.Start(ctx, "authorize.grpc.Check")
|
ctx, span := a.tracer.Start(ctx, "authorize.grpc.Check")
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
|
querier := storage.NewCachingQuerier(
|
||||||
|
storage.NewQuerier(a.state.Load().dataBrokerClient),
|
||||||
|
storage.GlobalCache,
|
||||||
|
)
|
||||||
|
ctx = storage.WithQuerier(ctx, querier)
|
||||||
|
|
||||||
state := a.state.Load()
|
state := a.state.Load()
|
||||||
|
|
||||||
// convert the incoming envoy-style http request into a go-style http request
|
// convert the incoming envoy-style http request into a go-style http request
|
||||||
|
@ -69,7 +75,7 @@ func (a *Authorize) loadSession(
|
||||||
// attempt to create a session from an incoming idp token
|
// attempt to create a session from an incoming idp token
|
||||||
s, err = config.NewIncomingIDPTokenSessionCreator(
|
s, err = config.NewIncomingIDPTokenSessionCreator(
|
||||||
func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) {
|
func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) {
|
||||||
return getDataBrokerRecord(ctx, recordType, recordID, 0)
|
return storage.GetDataBrokerRecord(ctx, recordType, recordID, 0)
|
||||||
},
|
},
|
||||||
func(ctx context.Context, records []*databroker.Record) error {
|
func(ctx context.Context, records []*databroker.Record) error {
|
||||||
_, err := a.state.Load().dataBrokerClient.Put(ctx, &databroker.PutRequest{
|
_, err := a.state.Load().dataBrokerClient.Put(ctx, &databroker.PutRequest{
|
||||||
|
@ -78,15 +84,7 @@ func (a *Authorize) loadSession(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// invalidate cache
|
storage.InvalidateCacheForDataBrokerRecords(ctx, records...)
|
||||||
for _, record := range records {
|
|
||||||
q := &databroker.QueryRequest{
|
|
||||||
Type: record.GetType(),
|
|
||||||
Limit: 1,
|
|
||||||
}
|
|
||||||
q.SetFilterByIDOrIndex(record.GetId())
|
|
||||||
storage.GetQuerier(ctx).InvalidateCache(ctx, q)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
).CreateSession(ctx, a.currentConfig.Load(), req.Policy, hreq)
|
).CreateSession(ctx, a.currentConfig.Load(), req.Policy, hreq)
|
||||||
|
@ -97,6 +95,7 @@ func (a *Authorize) loadSession(
|
||||||
Str("request-id", requestID).
|
Str("request-id", requestID).
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("error creating session for incoming idp token")
|
Msg("error creating session for incoming idp token")
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionState, _ := a.state.Load().sessionStore.LoadSessionStateAndCheckIDP(hreq)
|
sessionState, _ := a.state.Load().sessionStore.LoadSessionStateAndCheckIDP(hreq)
|
||||||
|
|
|
@ -21,9 +21,9 @@ import (
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
"github.com/pomerium/pomerium/pkg/storage"
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
)
|
)
|
||||||
|
|
||||||
// A Store stores data for the OPA rego policy evaluation.
|
// A Store stores data for the OPA rego policy evaluation.
|
||||||
|
|
|
@ -184,7 +184,7 @@ func (a *Authorize) ManageStream(
|
||||||
// XXX
|
// XXX
|
||||||
querier := storage.NewCachingQuerier(
|
querier := storage.NewCachingQuerier(
|
||||||
storage.NewQuerier(a.state.Load().dataBrokerClient),
|
storage.NewQuerier(a.state.Load().dataBrokerClient),
|
||||||
a.globalCache,
|
storage.GlobalCache,
|
||||||
)
|
)
|
||||||
ctx = storage.WithQuerier(ctx, querier)
|
ctx = storage.WithQuerier(ctx, querier)
|
||||||
|
|
||||||
|
|
|
@ -12,13 +12,13 @@ import (
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
||||||
"github.com/pomerium/pomerium/internal/version"
|
"github.com/pomerium/pomerium/internal/version"
|
||||||
_ "github.com/pomerium/pomerium/internal/zero/bootstrap/writers/filesystem"
|
_ "github.com/pomerium/pomerium/internal/zero/bootstrap/writers/filesystem"
|
||||||
_ "github.com/pomerium/pomerium/internal/zero/bootstrap/writers/k8s"
|
_ "github.com/pomerium/pomerium/internal/zero/bootstrap/writers/k8s"
|
||||||
zero_cmd "github.com/pomerium/pomerium/internal/zero/cmd"
|
zero_cmd "github.com/pomerium/pomerium/internal/zero/cmd"
|
||||||
"github.com/pomerium/pomerium/pkg/cmd/pomerium"
|
"github.com/pomerium/pomerium/pkg/cmd/pomerium"
|
||||||
"github.com/pomerium/pomerium/pkg/envoy/files"
|
"github.com/pomerium/pomerium/pkg/envoy/files"
|
||||||
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
|
|
@ -8,7 +8,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestBuilder_buildACMETLSALPNCluster(t *testing.T) {
|
func TestBuilder_buildACMETLSALPNCluster(t *testing.T) {
|
||||||
b := New("local-grpc", "local-http", "local-metrics", nil, nil)
|
b := New("local-grpc", "local-http", "local-metrics", nil, nil, true)
|
||||||
testutil.AssertProtoJSONEqual(t,
|
testutil.AssertProtoJSONEqual(t,
|
||||||
`{
|
`{
|
||||||
"name": "pomerium-acme-tls-alpn",
|
"name": "pomerium-acme-tls-alpn",
|
||||||
|
@ -34,7 +34,7 @@ func TestBuilder_buildACMETLSALPNCluster(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBuilder_buildACMETLSALPNFilterChain(t *testing.T) {
|
func TestBuilder_buildACMETLSALPNFilterChain(t *testing.T) {
|
||||||
b := New("local-grpc", "local-http", "local-metrics", nil, nil)
|
b := New("local-grpc", "local-http", "local-metrics", nil, nil, true)
|
||||||
testutil.AssertProtoJSONEqual(t,
|
testutil.AssertProtoJSONEqual(t,
|
||||||
`{
|
`{
|
||||||
"filterChainMatch": {
|
"filterChainMatch": {
|
||||||
|
|
|
@ -19,7 +19,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/config/otelconfig"
|
"github.com/pomerium/pomerium/config/otelconfig"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry"
|
"github.com/pomerium/pomerium/internal/telemetry"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
"google.golang.org/protobuf/types/known/durationpb"
|
"google.golang.org/protobuf/types/known/durationpb"
|
||||||
"google.golang.org/protobuf/types/known/structpb"
|
"google.golang.org/protobuf/types/known/structpb"
|
||||||
)
|
)
|
||||||
|
|
|
@ -13,7 +13,7 @@ import (
|
||||||
|
|
||||||
func TestBuilder_BuildBootstrapAdmin(t *testing.T) {
|
func TestBuilder_BuildBootstrapAdmin(t *testing.T) {
|
||||||
t.Setenv("TMPDIR", "/tmp")
|
t.Setenv("TMPDIR", "/tmp")
|
||||||
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
|
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil, true)
|
||||||
t.Run("valid", func(t *testing.T) {
|
t.Run("valid", func(t *testing.T) {
|
||||||
adminCfg, err := b.BuildBootstrapAdmin(&config.Config{
|
adminCfg, err := b.BuildBootstrapAdmin(&config.Config{
|
||||||
Options: &config.Options{
|
Options: &config.Options{
|
||||||
|
@ -35,7 +35,7 @@ func TestBuilder_BuildBootstrapAdmin(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBuilder_BuildBootstrapLayeredRuntime(t *testing.T) {
|
func TestBuilder_BuildBootstrapLayeredRuntime(t *testing.T) {
|
||||||
b := New("localhost:1111", "localhost:2222", "localhost:3333", filemgr.NewManager(), nil)
|
b := New("localhost:1111", "localhost:2222", "localhost:3333", filemgr.NewManager(), nil, true)
|
||||||
staticCfg, err := b.BuildBootstrapLayeredRuntime(context.Background(), &config.Config{})
|
staticCfg, err := b.BuildBootstrapLayeredRuntime(context.Background(), &config.Config{})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
testutil.AssertProtoJSONEqual(t, `
|
testutil.AssertProtoJSONEqual(t, `
|
||||||
|
@ -61,7 +61,7 @@ func TestBuilder_BuildBootstrapLayeredRuntime(t *testing.T) {
|
||||||
|
|
||||||
func TestBuilder_BuildBootstrapStaticResources(t *testing.T) {
|
func TestBuilder_BuildBootstrapStaticResources(t *testing.T) {
|
||||||
t.Run("valid", func(t *testing.T) {
|
t.Run("valid", func(t *testing.T) {
|
||||||
b := New("localhost:1111", "localhost:2222", "localhost:3333", filemgr.NewManager(), nil)
|
b := New("localhost:1111", "localhost:2222", "localhost:3333", filemgr.NewManager(), nil, true)
|
||||||
staticCfg, err := b.BuildBootstrapStaticResources(context.Background(), &config.Config{}, false)
|
staticCfg, err := b.BuildBootstrapStaticResources(context.Background(), &config.Config{}, false)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
testutil.AssertProtoJSONEqual(t, `
|
testutil.AssertProtoJSONEqual(t, `
|
||||||
|
@ -105,14 +105,14 @@ func TestBuilder_BuildBootstrapStaticResources(t *testing.T) {
|
||||||
`, staticCfg)
|
`, staticCfg)
|
||||||
})
|
})
|
||||||
t.Run("bad gRPC address", func(t *testing.T) {
|
t.Run("bad gRPC address", func(t *testing.T) {
|
||||||
b := New("xyz:zyx", "localhost:2222", "localhost:3333", filemgr.NewManager(), nil)
|
b := New("xyz:zyx", "localhost:2222", "localhost:3333", filemgr.NewManager(), nil, true)
|
||||||
_, err := b.BuildBootstrapStaticResources(context.Background(), &config.Config{}, false)
|
_, err := b.BuildBootstrapStaticResources(context.Background(), &config.Config{}, false)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBuilder_BuildBootstrapStatsConfig(t *testing.T) {
|
func TestBuilder_BuildBootstrapStatsConfig(t *testing.T) {
|
||||||
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
|
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil, true)
|
||||||
t.Run("valid", func(t *testing.T) {
|
t.Run("valid", func(t *testing.T) {
|
||||||
statsCfg, err := b.BuildBootstrapStatsConfig(&config.Config{
|
statsCfg, err := b.BuildBootstrapStatsConfig(&config.Config{
|
||||||
Options: &config.Options{
|
Options: &config.Options{
|
||||||
|
@ -132,7 +132,7 @@ func TestBuilder_BuildBootstrapStatsConfig(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBuilder_BuildBootstrap(t *testing.T) {
|
func TestBuilder_BuildBootstrap(t *testing.T) {
|
||||||
b := New("localhost:1111", "localhost:2222", "localhost:3333", filemgr.NewManager(), nil)
|
b := New("localhost:1111", "localhost:2222", "localhost:3333", filemgr.NewManager(), nil, true)
|
||||||
t.Run("OverloadManager", func(t *testing.T) {
|
t.Run("OverloadManager", func(t *testing.T) {
|
||||||
bootstrap, err := b.BuildBootstrap(context.Background(), &config.Config{
|
bootstrap, err := b.BuildBootstrap(context.Background(), &config.Config{
|
||||||
Options: &config.Options{
|
Options: &config.Options{
|
||||||
|
|
|
@ -7,11 +7,12 @@ import (
|
||||||
|
|
||||||
// A Builder builds envoy config from pomerium config.
|
// A Builder builds envoy config from pomerium config.
|
||||||
type Builder struct {
|
type Builder struct {
|
||||||
localGRPCAddress string
|
localGRPCAddress string
|
||||||
localHTTPAddress string
|
localHTTPAddress string
|
||||||
localMetricsAddress string
|
localMetricsAddress string
|
||||||
filemgr *filemgr.Manager
|
filemgr *filemgr.Manager
|
||||||
reproxy *reproxy.Handler
|
reproxy *reproxy.Handler
|
||||||
|
addIPV6InternalRanges bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new Builder.
|
// New creates a new Builder.
|
||||||
|
@ -21,15 +22,17 @@ func New(
|
||||||
localMetricsAddress string,
|
localMetricsAddress string,
|
||||||
fileManager *filemgr.Manager,
|
fileManager *filemgr.Manager,
|
||||||
reproxyHandler *reproxy.Handler,
|
reproxyHandler *reproxy.Handler,
|
||||||
|
addIPV6InternalRanges bool,
|
||||||
) *Builder {
|
) *Builder {
|
||||||
if reproxyHandler == nil {
|
if reproxyHandler == nil {
|
||||||
reproxyHandler = reproxy.New()
|
reproxyHandler = reproxy.New()
|
||||||
}
|
}
|
||||||
return &Builder{
|
return &Builder{
|
||||||
localGRPCAddress: localGRPCAddress,
|
localGRPCAddress: localGRPCAddress,
|
||||||
localHTTPAddress: localHTTPAddress,
|
localHTTPAddress: localHTTPAddress,
|
||||||
localMetricsAddress: localMetricsAddress,
|
localMetricsAddress: localMetricsAddress,
|
||||||
filemgr: fileManager,
|
filemgr: fileManager,
|
||||||
reproxy: reproxyHandler,
|
reproxy: reproxyHandler,
|
||||||
|
addIPV6InternalRanges: addIPV6InternalRanges,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,8 +20,8 @@ import (
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
)
|
)
|
||||||
|
|
||||||
// BuildClusters builds envoy clusters from the given config.
|
// BuildClusters builds envoy clusters from the given config.
|
||||||
|
|
|
@ -27,7 +27,7 @@ func Test_BuildClusters(t *testing.T) {
|
||||||
|
|
||||||
opts := config.NewDefaultOptions()
|
opts := config.NewDefaultOptions()
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
|
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil, true)
|
||||||
clusters, err := b.BuildClusters(ctx, &config.Config{Options: opts})
|
clusters, err := b.BuildClusters(ctx, &config.Config{Options: opts})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
testutil.AssertProtoJSONFileEqual(t, "testdata/clusters.json", clusters)
|
testutil.AssertProtoJSONFileEqual(t, "testdata/clusters.json", clusters)
|
||||||
|
@ -38,7 +38,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
||||||
cacheDir, _ := os.UserCacheDir()
|
cacheDir, _ := os.UserCacheDir()
|
||||||
customCA := filepath.Join(cacheDir, "pomerium", "envoy", "files", "custom-ca-3133535332543131503345494c.pem")
|
customCA := filepath.Join(cacheDir, "pomerium", "envoy", "files", "custom-ca-3133535332543131503345494c.pem")
|
||||||
|
|
||||||
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
|
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil, true)
|
||||||
rootCABytes, _ := getCombinedCertificateAuthority(ctx, &config.Config{Options: &config.Options{}})
|
rootCABytes, _ := getCombinedCertificateAuthority(ctx, &config.Config{Options: &config.Options{}})
|
||||||
rootCA := b.filemgr.BytesDataSource("ca.pem", rootCABytes).GetFilename()
|
rootCA := b.filemgr.BytesDataSource("ca.pem", rootCABytes).GetFilename()
|
||||||
|
|
||||||
|
@ -517,7 +517,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
||||||
|
|
||||||
func Test_buildCluster(t *testing.T) {
|
func Test_buildCluster(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
|
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil, true)
|
||||||
rootCABytes, _ := getCombinedCertificateAuthority(ctx, &config.Config{Options: &config.Options{}})
|
rootCABytes, _ := getCombinedCertificateAuthority(ctx, &config.Config{Options: &config.Options{}})
|
||||||
rootCA := b.filemgr.BytesDataSource("ca.pem", rootCABytes).GetFilename()
|
rootCA := b.filemgr.BytesDataSource("ca.pem", rootCABytes).GetFilename()
|
||||||
o1 := config.NewDefaultOptions()
|
o1 := config.NewDefaultOptions()
|
||||||
|
@ -1012,7 +1012,7 @@ func Test_bindConfig(t *testing.T) {
|
||||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||||
defer clearTimeout()
|
defer clearTimeout()
|
||||||
|
|
||||||
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
|
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil, true)
|
||||||
t.Run("no bind config", func(t *testing.T) {
|
t.Run("no bind config", func(t *testing.T) {
|
||||||
cluster, err := b.buildPolicyCluster(ctx, &config.Config{Options: &config.Options{}}, &config.Policy{
|
cluster, err := b.buildPolicyCluster(ctx, &config.Config{Options: &config.Options{}}, &config.Policy{
|
||||||
From: "https://from.example.com",
|
From: "https://from.example.com",
|
||||||
|
|
|
@ -44,10 +44,10 @@ func ExtAuthzFilter(grpcClientTimeout *durationpb.Duration) *envoy_extensions_fi
|
||||||
}
|
}
|
||||||
|
|
||||||
// HTTPConnectionManagerFilter creates a new HTTP connection manager filter.
|
// HTTPConnectionManagerFilter creates a new HTTP connection manager filter.
|
||||||
func HTTPConnectionManagerFilter(
|
func (b *Builder) HTTPConnectionManagerFilter(
|
||||||
httpConnectionManager *envoy_extensions_filters_network_http_connection_manager.HttpConnectionManager,
|
httpConnectionManager *envoy_extensions_filters_network_http_connection_manager.HttpConnectionManager,
|
||||||
) *envoy_config_listener_v3.Filter {
|
) *envoy_config_listener_v3.Filter {
|
||||||
applyGlobalHTTPConnectionManagerOptions(httpConnectionManager)
|
b.applyGlobalHTTPConnectionManagerOptions(httpConnectionManager)
|
||||||
return &envoy_config_listener_v3.Filter{
|
return &envoy_config_listener_v3.Filter{
|
||||||
Name: "envoy.filters.network.http_connection_manager",
|
Name: "envoy.filters.network.http_connection_manager",
|
||||||
ConfigType: &envoy_config_listener_v3.Filter_TypedConfig{
|
ConfigType: &envoy_config_listener_v3.Filter_TypedConfig{
|
||||||
|
|
|
@ -128,23 +128,29 @@ func (b *Builder) buildLocalReplyConfig(
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func applyGlobalHTTPConnectionManagerOptions(hcm *envoy_http_connection_manager.HttpConnectionManager) {
|
func (b *Builder) applyGlobalHTTPConnectionManagerOptions(hcm *envoy_http_connection_manager.HttpConnectionManager) {
|
||||||
if hcm.InternalAddressConfig == nil {
|
if hcm.InternalAddressConfig == nil {
|
||||||
// see doc comment on InternalAddressConfig for details
|
ranges := []*envoy_config_core_v3.CidrRange{
|
||||||
hcm.InternalAddressConfig = &envoy_http_connection_manager.HttpConnectionManager_InternalAddressConfig{
|
// localhost
|
||||||
CidrRanges: []*envoy_config_core_v3.CidrRange{
|
{AddressPrefix: "127.0.0.1", PrefixLen: wrapperspb.UInt32(32)},
|
||||||
// localhost
|
|
||||||
{AddressPrefix: "127.0.0.1", PrefixLen: wrapperspb.UInt32(32)},
|
// RFC1918
|
||||||
|
{AddressPrefix: "10.0.0.0", PrefixLen: wrapperspb.UInt32(8)},
|
||||||
|
{AddressPrefix: "192.168.0.0", PrefixLen: wrapperspb.UInt32(16)},
|
||||||
|
{AddressPrefix: "172.16.0.0", PrefixLen: wrapperspb.UInt32(12)},
|
||||||
|
}
|
||||||
|
if b.addIPV6InternalRanges {
|
||||||
|
ranges = append(ranges, []*envoy_config_core_v3.CidrRange{
|
||||||
|
// Localhost IPv6
|
||||||
{AddressPrefix: "::1", PrefixLen: wrapperspb.UInt32(128)},
|
{AddressPrefix: "::1", PrefixLen: wrapperspb.UInt32(128)},
|
||||||
|
|
||||||
// RFC1918
|
|
||||||
{AddressPrefix: "10.0.0.0", PrefixLen: wrapperspb.UInt32(8)},
|
|
||||||
{AddressPrefix: "192.168.0.0", PrefixLen: wrapperspb.UInt32(16)},
|
|
||||||
{AddressPrefix: "172.16.0.0", PrefixLen: wrapperspb.UInt32(12)},
|
|
||||||
|
|
||||||
// RFC4193
|
// RFC4193
|
||||||
{AddressPrefix: "fd00::", PrefixLen: wrapperspb.UInt32(8)},
|
{AddressPrefix: "fd00::", PrefixLen: wrapperspb.UInt32(8)},
|
||||||
},
|
}...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// see doc comment on InternalAddressConfig for details
|
||||||
|
hcm.InternalAddressConfig = &envoy_http_connection_manager.HttpConnectionManager_InternalAddressConfig{
|
||||||
|
CidrRanges: ranges,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,7 +9,7 @@ import (
|
||||||
"google.golang.org/protobuf/types/known/wrapperspb"
|
"google.golang.org/protobuf/types/known/wrapperspb"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
)
|
)
|
||||||
|
|
||||||
const listenerBufferLimit uint32 = 32 * 1024
|
const listenerBufferLimit uint32 = 32 * 1024
|
||||||
|
|
|
@ -51,7 +51,7 @@ func (b *Builder) buildEnvoyAdminHTTPConnectionManagerFilter() *envoy_config_lis
|
||||||
},
|
},
|
||||||
}})
|
}})
|
||||||
|
|
||||||
return HTTPConnectionManagerFilter(&envoy_http_connection_manager.HttpConnectionManager{
|
return b.HTTPConnectionManagerFilter(&envoy_http_connection_manager.HttpConnectionManager{
|
||||||
CodecType: envoy_http_connection_manager.HttpConnectionManager_AUTO,
|
CodecType: envoy_http_connection_manager.HttpConnectionManager_AUTO,
|
||||||
StatPrefix: "envoy-admin",
|
StatPrefix: "envoy-admin",
|
||||||
RouteSpecifier: &envoy_http_connection_manager.HttpConnectionManager_RouteConfig{
|
RouteSpecifier: &envoy_http_connection_manager.HttpConnectionManager_RouteConfig{
|
||||||
|
|
|
@ -98,7 +98,7 @@ func (b *Builder) buildGRPCHTTPConnectionManagerFilter() *envoy_config_listener_
|
||||||
Routes: routes,
|
Routes: routes,
|
||||||
}})
|
}})
|
||||||
|
|
||||||
return HTTPConnectionManagerFilter(&envoy_http_connection_manager.HttpConnectionManager{
|
return b.HTTPConnectionManagerFilter(&envoy_http_connection_manager.HttpConnectionManager{
|
||||||
CodecType: envoy_http_connection_manager.HttpConnectionManager_AUTO,
|
CodecType: envoy_http_connection_manager.HttpConnectionManager_AUTO,
|
||||||
StatPrefix: "grpc_ingress",
|
StatPrefix: "grpc_ingress",
|
||||||
// limit request first byte to last byte time
|
// limit request first byte to last byte time
|
||||||
|
|
|
@ -233,7 +233,7 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return HTTPConnectionManagerFilter(mgr), nil
|
return b.HTTPConnectionManagerFilter(mgr), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newListenerAccessLog() *envoy_config_accesslog_v3.AccessLog {
|
func newListenerAccessLog() *envoy_config_accesslog_v3.AccessLog {
|
||||||
|
|
|
@ -12,7 +12,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_requireProxyProtocol(t *testing.T) {
|
func Test_requireProxyProtocol(t *testing.T) {
|
||||||
b := New("local-grpc", "local-http", "local-metrics", nil, nil)
|
b := New("local-grpc", "local-http", "local-metrics", nil, nil, true)
|
||||||
t.Run("required", func(t *testing.T) {
|
t.Run("required", func(t *testing.T) {
|
||||||
li, err := b.buildMainListener(context.Background(), &config.Config{Options: &config.Options{
|
li, err := b.buildMainListener(context.Background(), &config.Config{Options: &config.Options{
|
||||||
UseProxyProtocol: true,
|
UseProxyProtocol: true,
|
||||||
|
|
|
@ -121,7 +121,7 @@ func (b *Builder) buildMetricsHTTPConnectionManagerFilter() *envoy_config_listen
|
||||||
},
|
},
|
||||||
}})
|
}})
|
||||||
|
|
||||||
return HTTPConnectionManagerFilter(&envoy_http_connection_manager.HttpConnectionManager{
|
return b.HTTPConnectionManagerFilter(&envoy_http_connection_manager.HttpConnectionManager{
|
||||||
CodecType: envoy_http_connection_manager.HttpConnectionManager_AUTO,
|
CodecType: envoy_http_connection_manager.HttpConnectionManager_AUTO,
|
||||||
StatPrefix: "metrics",
|
StatPrefix: "metrics",
|
||||||
RouteSpecifier: &envoy_http_connection_manager.HttpConnectionManager_RouteConfig{
|
RouteSpecifier: &envoy_http_connection_manager.HttpConnectionManager_RouteConfig{
|
||||||
|
|
|
@ -51,7 +51,7 @@ func TestBuildListeners(t *testing.T) {
|
||||||
OutboundPort: "10003",
|
OutboundPort: "10003",
|
||||||
MetricsPort: "10004",
|
MetricsPort: "10004",
|
||||||
}
|
}
|
||||||
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
|
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil, true)
|
||||||
t.Run("enable grpc by default", func(t *testing.T) {
|
t.Run("enable grpc by default", func(t *testing.T) {
|
||||||
cfg := cfg.Clone()
|
cfg := cfg.Clone()
|
||||||
lis, err := b.BuildListeners(ctx, cfg, false)
|
lis, err := b.BuildListeners(ctx, cfg, false)
|
||||||
|
@ -125,7 +125,7 @@ func Test_buildMetricsHTTPConnectionManagerFilter(t *testing.T) {
|
||||||
certFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-crt-5a353247453159375849565a.pem")
|
certFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-crt-5a353247453159375849565a.pem")
|
||||||
keyFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-key-3159554e32473758435257364b.pem")
|
keyFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-key-3159554e32473758435257364b.pem")
|
||||||
|
|
||||||
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
|
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil, true)
|
||||||
li, err := b.buildMetricsListener(&config.Config{
|
li, err := b.buildMetricsListener(&config.Config{
|
||||||
Options: &config.Options{
|
Options: &config.Options{
|
||||||
MetricsAddr: "127.0.0.1:9902",
|
MetricsAddr: "127.0.0.1:9902",
|
||||||
|
@ -143,7 +143,7 @@ func Test_buildMetricsHTTPConnectionManagerFilter(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_buildMainHTTPConnectionManagerFilter(t *testing.T) {
|
func Test_buildMainHTTPConnectionManagerFilter(t *testing.T) {
|
||||||
b := New("local-grpc", "local-http", "local-metrics", nil, nil)
|
b := New("local-grpc", "local-http", "local-metrics", nil, nil, true)
|
||||||
|
|
||||||
options := config.NewDefaultOptions()
|
options := config.NewDefaultOptions()
|
||||||
options.SkipXffAppend = true
|
options.SkipXffAppend = true
|
||||||
|
|
|
@ -42,7 +42,7 @@ func (b *Builder) buildOutboundListener(cfg *config.Config) (*envoy_config_liste
|
||||||
func (b *Builder) buildOutboundHTTPConnectionManager() *envoy_config_listener_v3.Filter {
|
func (b *Builder) buildOutboundHTTPConnectionManager() *envoy_config_listener_v3.Filter {
|
||||||
rc := b.buildOutboundRouteConfiguration()
|
rc := b.buildOutboundRouteConfiguration()
|
||||||
|
|
||||||
return HTTPConnectionManagerFilter(&envoy_http_connection_manager.HttpConnectionManager{
|
return b.HTTPConnectionManagerFilter(&envoy_http_connection_manager.HttpConnectionManager{
|
||||||
CodecType: envoy_http_connection_manager.HttpConnectionManager_AUTO,
|
CodecType: envoy_http_connection_manager.HttpConnectionManager_AUTO,
|
||||||
StatPrefix: "grpc_egress",
|
StatPrefix: "grpc_egress",
|
||||||
// limit request first byte to last byte time
|
// limit request first byte to last byte time
|
||||||
|
|
|
@ -7,7 +7,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_buildOutboundRoutes(t *testing.T) {
|
func Test_buildOutboundRoutes(t *testing.T) {
|
||||||
b := New("local-grpc", "local-http", "local-metrics", nil, nil)
|
b := New("local-grpc", "local-http", "local-metrics", nil, nil, true)
|
||||||
routes := b.buildOutboundRoutes()
|
routes := b.buildOutboundRoutes()
|
||||||
testutil.AssertProtoJSONEqual(t, `[
|
testutil.AssertProtoJSONEqual(t, `[
|
||||||
{
|
{
|
||||||
|
|
|
@ -1,11 +1,17 @@
|
||||||
package envoyconfig_test
|
package envoyconfig_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
@ -28,9 +34,9 @@ func TestH2C(t *testing.T) {
|
||||||
|
|
||||||
http := up.Route().
|
http := up.Route().
|
||||||
From(env.SubdomainURL("grpc-http")).
|
From(env.SubdomainURL("grpc-http")).
|
||||||
To(values.Bind(up.Port(), func(port int) string {
|
To(values.Bind(up.Addr(), func(addr string) string {
|
||||||
// override the target protocol to use http://
|
// override the target protocol to use http://
|
||||||
return fmt.Sprintf("http://127.0.0.1:%d", port)
|
return fmt.Sprintf("http://%s", addr)
|
||||||
})).
|
})).
|
||||||
Policy(func(p *config.Policy) { p.AllowPublicUnauthenticatedAccess = true })
|
Policy(func(p *config.Policy) { p.AllowPublicUnauthenticatedAccess = true })
|
||||||
|
|
||||||
|
@ -118,6 +124,234 @@ func TestHTTP(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTCPTunnel(t *testing.T) {
|
||||||
|
env := testenv.New(t, testenv.Debug())
|
||||||
|
|
||||||
|
env.Add(scenarios.NewIDP([]*scenarios.User{{Email: "test@example.com"}}))
|
||||||
|
up := upstreams.TCP()
|
||||||
|
routeH1 := up.Route().
|
||||||
|
From(env.SubdomainURL("h1")).
|
||||||
|
PPL(`{"allow":{"and":["email":{"is":"test@example.com"}]}}`)
|
||||||
|
routeH2 := up.Route().
|
||||||
|
From(env.SubdomainURL("h2")).
|
||||||
|
Policy(func(p *config.Policy) {
|
||||||
|
p.AllowWebsockets = true
|
||||||
|
}).
|
||||||
|
PPL(`{"allow":{"and":["email":{"is":"test@example.com"}]}}`)
|
||||||
|
|
||||||
|
up.Handle(func(_ context.Context, c net.Conn) error {
|
||||||
|
c.SetReadDeadline(time.Now().Add(1 * time.Second))
|
||||||
|
buf := make([]byte, 8)
|
||||||
|
n, err := c.Read(buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, string(buf[:n]), "hello")
|
||||||
|
c.SetWriteDeadline(time.Now().Add(1 * time.Second))
|
||||||
|
_, err = c.Write([]byte("world"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
env.AddUpstream(up)
|
||||||
|
env.Start()
|
||||||
|
snippets.WaitStartupComplete(env)
|
||||||
|
|
||||||
|
t.Run("http1", func(t *testing.T) {
|
||||||
|
assert.NoError(t, up.Dial(routeH1, func(_ context.Context, c net.Conn) error {
|
||||||
|
c.SetWriteDeadline(time.Now().Add(1 * time.Second))
|
||||||
|
_, err := c.Write([]byte("hello"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
buf := make([]byte, 8)
|
||||||
|
c.SetReadDeadline(time.Now().Add(1 * time.Second))
|
||||||
|
n, err := c.Read(buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, string(buf[:n]), "world")
|
||||||
|
return nil
|
||||||
|
}, upstreams.AuthenticateAs("test@example.com"), upstreams.DialProtocol(upstreams.DialHTTP1)))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("http2", func(t *testing.T) {
|
||||||
|
assert.NoError(t, up.Dial(routeH2, func(_ context.Context, c net.Conn) error {
|
||||||
|
c.SetWriteDeadline(time.Now().Add(1 * time.Second))
|
||||||
|
_, err := c.Write([]byte("hello"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
buf := make([]byte, 8)
|
||||||
|
c.SetReadDeadline(time.Now().Add(1 * time.Second))
|
||||||
|
n, err := c.Read(buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, string(buf[:n]), "world")
|
||||||
|
return nil
|
||||||
|
}, upstreams.AuthenticateAs("test@example.com"), upstreams.DialProtocol(upstreams.DialHTTP2)))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkHTTP1TCPTunnel(b *testing.B) {
|
||||||
|
env := testenv.New(b, testenv.Silent())
|
||||||
|
env.Add(scenarios.NewIDP([]*scenarios.User{{Email: "test@example.com"}}))
|
||||||
|
up := upstreams.TCP()
|
||||||
|
h1 := up.Route().
|
||||||
|
From(env.SubdomainURL("bench-h1")).
|
||||||
|
PPL(`{"allow":{"and":["email":{"is":"test@example.com"}]}}`)
|
||||||
|
|
||||||
|
env.AddUpstream(up)
|
||||||
|
env.Start()
|
||||||
|
snippets.WaitStartupComplete(env)
|
||||||
|
|
||||||
|
b.Run("http1", func(b *testing.B) {
|
||||||
|
benchmarkTCP(b, up, h1, tcpBenchmarkParams{
|
||||||
|
msgLen: 512,
|
||||||
|
protocol: upstreams.DialHTTP1,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkHTTP2TCPTunnel(b *testing.B) {
|
||||||
|
env := testenv.New(b, testenv.Silent())
|
||||||
|
env.Add(scenarios.NewIDP([]*scenarios.User{{Email: "test@example.com"}}))
|
||||||
|
up := upstreams.TCP()
|
||||||
|
|
||||||
|
h2 := up.Route().
|
||||||
|
From(env.SubdomainURL("bench-h2")).
|
||||||
|
Policy(func(p *config.Policy) {
|
||||||
|
p.AllowWebsockets = true
|
||||||
|
}).
|
||||||
|
PPL(`{"allow":{"and":["email":{"is":"test@example.com"}]}}`)
|
||||||
|
|
||||||
|
env.AddUpstream(up)
|
||||||
|
env.Start()
|
||||||
|
snippets.WaitStartupComplete(env)
|
||||||
|
|
||||||
|
b.Run("http2", func(b *testing.B) {
|
||||||
|
benchmarkTCP(b, up, h2, tcpBenchmarkParams{
|
||||||
|
msgLen: 512,
|
||||||
|
protocol: upstreams.DialHTTP2,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type tcpBenchmarkParams struct {
|
||||||
|
msgLen int
|
||||||
|
protocol upstreams.Protocol
|
||||||
|
}
|
||||||
|
|
||||||
|
func benchmarkTCP(b *testing.B, up upstreams.TCPUpstream, route testenv.Route, params tcpBenchmarkParams) {
|
||||||
|
sendMsg := func(c net.Conn, buf []byte) error {
|
||||||
|
c.SetWriteDeadline(time.Now().Add(1 * time.Second))
|
||||||
|
_, err := c.Write(buf)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, net.ErrClosed) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
recvMsg := func(c net.Conn, buf []byte) error {
|
||||||
|
c.SetReadDeadline(time.Now().Add(1 * time.Second))
|
||||||
|
for read := 0; read != len(buf); {
|
||||||
|
n, err := c.Read(buf)
|
||||||
|
read += n
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, net.ErrClosed) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
up.Handle(func(_ context.Context, c net.Conn) error {
|
||||||
|
for {
|
||||||
|
buf := make([]byte, params.msgLen)
|
||||||
|
if err := recvMsg(c, buf[:]); err != nil {
|
||||||
|
if errors.Is(err, net.ErrClosed) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := sendMsg(c, buf[:]); err != nil {
|
||||||
|
if errors.Is(err, net.ErrClosed) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
var threads atomic.Int32
|
||||||
|
var requests atomic.Int32
|
||||||
|
var bytes atomic.Int64
|
||||||
|
start := time.Now()
|
||||||
|
b.RunParallel(func(p *testing.PB) {
|
||||||
|
threads.Add(1)
|
||||||
|
require.NoError(b, up.Dial(route, func(_ context.Context, c net.Conn) error {
|
||||||
|
buf := make([]byte, params.msgLen)
|
||||||
|
for p.Next() {
|
||||||
|
requests.Add(1)
|
||||||
|
bytes.Add(int64(params.msgLen))
|
||||||
|
require.NoError(b, sendMsg(c, buf[:]))
|
||||||
|
require.NoError(b, recvMsg(c, buf[:]))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}, upstreams.AuthenticateAs("test@example.com"), upstreams.DialProtocol(params.protocol)))
|
||||||
|
})
|
||||||
|
duration := time.Since(start)
|
||||||
|
b.Logf("sent %d requests over %d parallel connections in %s", requests.Load(), threads.Load(), duration)
|
||||||
|
b.Logf("throughput: %f bytes/s", float64(bytes.Load())/duration.Seconds())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHttp1Websocket(t *testing.T) {
|
||||||
|
env := testenv.New(t)
|
||||||
|
|
||||||
|
up := upstreams.HTTP(nil)
|
||||||
|
up.HandleWS("/ws", websocket.Upgrader{}, func(conn *websocket.Conn) error {
|
||||||
|
for {
|
||||||
|
mt, message, err := conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// echo the message back
|
||||||
|
err = conn.WriteMessage(mt, message)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
route := up.Route().
|
||||||
|
From(env.SubdomainURL("ws-test")).
|
||||||
|
Policy(func(p *config.Policy) {
|
||||||
|
p.AllowPublicUnauthenticatedAccess = true
|
||||||
|
p.AllowWebsockets = true
|
||||||
|
})
|
||||||
|
|
||||||
|
env.AddUpstream(up)
|
||||||
|
env.Start()
|
||||||
|
snippets.WaitStartupComplete(env)
|
||||||
|
|
||||||
|
assert.NoError(t, up.DialWS(route, func(conn *websocket.Conn) error {
|
||||||
|
if err := conn.SetWriteDeadline(time.Now().Add(1 * time.Second)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := conn.WriteMessage(websocket.TextMessage, []byte("hello world")); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := conn.SetReadDeadline(time.Now().Add(1 * time.Second)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
mt, bytes, err := conn.ReadMessage()
|
||||||
|
if err := err; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
assert.Equal(t, websocket.TextMessage, mt)
|
||||||
|
assert.Equal(t, "hello world", string(bytes))
|
||||||
|
return nil
|
||||||
|
}, upstreams.Path("/ws")))
|
||||||
|
}
|
||||||
|
|
||||||
func TestClientCert(t *testing.T) {
|
func TestClientCert(t *testing.T) {
|
||||||
env := testenv.New(t)
|
env := testenv.New(t)
|
||||||
env.Add(scenarios.DownstreamMTLS(config.MTLSEnforcementRejectConnection))
|
env.Add(scenarios.DownstreamMTLS(config.MTLSEnforcementRejectConnection))
|
||||||
|
|
|
@ -11,8 +11,8 @@ import (
|
||||||
"google.golang.org/protobuf/types/known/wrapperspb"
|
"google.golang.org/protobuf/types/known/wrapperspb"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
)
|
)
|
||||||
|
|
||||||
// BuildRouteConfigurations builds the route configurations for the RDS service.
|
// BuildRouteConfigurations builds the route configurations for the RDS service.
|
||||||
|
|
|
@ -32,7 +32,7 @@ func TestBuilder_buildMainRouteConfiguration(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}}
|
}}
|
||||||
b := New("grpc", "http", "metrics", filemgr.NewManager(), nil)
|
b := New("grpc", "http", "metrics", filemgr.NewManager(), nil, true)
|
||||||
routeConfiguration, err := b.buildMainRouteConfiguration(ctx, cfg)
|
routeConfiguration, err := b.buildMainRouteConfiguration(ctx, cfg)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
testutil.AssertProtoJSONEqual(t, `{
|
testutil.AssertProtoJSONEqual(t, `{
|
||||||
|
|
|
@ -231,10 +231,6 @@
|
||||||
"addressPrefix": "127.0.0.1",
|
"addressPrefix": "127.0.0.1",
|
||||||
"prefixLen": 32
|
"prefixLen": 32
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"addressPrefix": "::1",
|
|
||||||
"prefixLen": 128
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"addressPrefix": "10.0.0.0",
|
"addressPrefix": "10.0.0.0",
|
||||||
"prefixLen": 8
|
"prefixLen": 8
|
||||||
|
@ -247,6 +243,10 @@
|
||||||
"addressPrefix": "172.16.0.0",
|
"addressPrefix": "172.16.0.0",
|
||||||
"prefixLen": 12
|
"prefixLen": 12
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"addressPrefix": "::1",
|
||||||
|
"prefixLen": 128
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"addressPrefix": "fd00::",
|
"addressPrefix": "fd00::",
|
||||||
"prefixLen": 8
|
"prefixLen": 8
|
||||||
|
|
|
@ -61,10 +61,6 @@
|
||||||
"addressPrefix": "127.0.0.1",
|
"addressPrefix": "127.0.0.1",
|
||||||
"prefixLen": 32
|
"prefixLen": 32
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"addressPrefix": "::1",
|
|
||||||
"prefixLen": 128
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"addressPrefix": "10.0.0.0",
|
"addressPrefix": "10.0.0.0",
|
||||||
"prefixLen": 8
|
"prefixLen": 8
|
||||||
|
@ -77,6 +73,10 @@
|
||||||
"addressPrefix": "172.16.0.0",
|
"addressPrefix": "172.16.0.0",
|
||||||
"prefixLen": 12
|
"prefixLen": 12
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"addressPrefix": "::1",
|
||||||
|
"prefixLen": 128
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"addressPrefix": "fd00::",
|
"addressPrefix": "fd00::",
|
||||||
"prefixLen": 8
|
"prefixLen": 8
|
||||||
|
|
|
@ -82,7 +82,7 @@ func TestValidateCertificate(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_buildDownstreamTLSContext(t *testing.T) {
|
func Test_buildDownstreamTLSContext(t *testing.T) {
|
||||||
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
|
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil, true)
|
||||||
|
|
||||||
cacheDir, _ := os.UserCacheDir()
|
cacheDir, _ := os.UserCacheDir()
|
||||||
clientCAFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "client-ca-4e4c564e5a36544a4a33385a.pem")
|
clientCAFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "client-ca-4e4c564e5a36544a4a33385a.pem")
|
||||||
|
|
|
@ -17,7 +17,7 @@ import (
|
||||||
extensions_uuidx "github.com/pomerium/envoy-custom/api/extensions/request_id/uuidx"
|
extensions_uuidx "github.com/pomerium/envoy-custom/api/extensions/request_id/uuidx"
|
||||||
extensions_pomerium_otel "github.com/pomerium/envoy-custom/api/extensions/tracers/pomerium_otel"
|
extensions_pomerium_otel "github.com/pomerium/envoy-custom/api/extensions/tracers/pomerium_otel"
|
||||||
"github.com/pomerium/pomerium/config/otelconfig"
|
"github.com/pomerium/pomerium/config/otelconfig"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
"google.golang.org/protobuf/types/known/durationpb"
|
"google.golang.org/protobuf/types/known/durationpb"
|
||||||
"google.golang.org/protobuf/types/known/wrapperspb"
|
"google.golang.org/protobuf/types/known/wrapperspb"
|
||||||
)
|
)
|
||||||
|
|
|
@ -202,7 +202,7 @@ func (c *incomingIDPTokenSessionCreator) createSessionAccessToken(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error verifying access token: %w", err)
|
return nil, fmt.Errorf("error verifying access token: %w", err)
|
||||||
} else if !res.Valid {
|
} else if !res.Valid {
|
||||||
return nil, fmt.Errorf("invalid access token")
|
return nil, fmt.Errorf("%w: invalid access token", sessions.ErrInvalidSession)
|
||||||
}
|
}
|
||||||
|
|
||||||
s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims)
|
s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims)
|
||||||
|
@ -265,7 +265,7 @@ func (c *incomingIDPTokenSessionCreator) createSessionForIdentityToken(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error verifying identity token: %w", err)
|
return nil, fmt.Errorf("error verifying identity token: %w", err)
|
||||||
} else if !res.Valid {
|
} else if !res.Valid {
|
||||||
return nil, fmt.Errorf("invalid identity token")
|
return nil, fmt.Errorf("%w: invalid identity token", sessions.ErrInvalidSession)
|
||||||
}
|
}
|
||||||
|
|
||||||
s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims)
|
s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims)
|
||||||
|
|
|
@ -19,7 +19,6 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||||
"github.com/pomerium/pomerium/internal/events"
|
"github.com/pomerium/pomerium/internal/events"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
||||||
"github.com/pomerium/pomerium/internal/version"
|
"github.com/pomerium/pomerium/internal/version"
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
"github.com/pomerium/pomerium/pkg/envoy/files"
|
"github.com/pomerium/pomerium/pkg/envoy/files"
|
||||||
|
@ -28,6 +27,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||||
"github.com/pomerium/pomerium/pkg/identity"
|
"github.com/pomerium/pomerium/pkg/identity"
|
||||||
"github.com/pomerium/pomerium/pkg/identity/manager"
|
"github.com/pomerium/pomerium/pkg/identity/manager"
|
||||||
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
oteltrace "go.opentelemetry.io/otel/trace"
|
oteltrace "go.opentelemetry.io/otel/trace"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
10
go.mod
10
go.mod
|
@ -92,11 +92,11 @@ require (
|
||||||
go.uber.org/automaxprocs v1.6.0
|
go.uber.org/automaxprocs v1.6.0
|
||||||
go.uber.org/mock v0.5.0
|
go.uber.org/mock v0.5.0
|
||||||
go.uber.org/zap v1.27.0
|
go.uber.org/zap v1.27.0
|
||||||
golang.org/x/crypto v0.35.0
|
golang.org/x/crypto v0.36.0
|
||||||
golang.org/x/net v0.36.0
|
golang.org/x/net v0.37.0
|
||||||
golang.org/x/oauth2 v0.27.0
|
golang.org/x/oauth2 v0.27.0
|
||||||
golang.org/x/sync v0.11.0
|
golang.org/x/sync v0.12.0
|
||||||
golang.org/x/sys v0.30.0
|
golang.org/x/sys v0.31.0
|
||||||
golang.org/x/time v0.10.0
|
golang.org/x/time v0.10.0
|
||||||
google.golang.org/api v0.223.0
|
google.golang.org/api v0.223.0
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250219182151-9fdb1cabc7b2
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20250219182151-9fdb1cabc7b2
|
||||||
|
@ -255,7 +255,7 @@ require (
|
||||||
go.uber.org/zap/exp v0.3.0 // indirect
|
go.uber.org/zap/exp v0.3.0 // indirect
|
||||||
golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa // indirect
|
golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa // indirect
|
||||||
golang.org/x/mod v0.20.0 // indirect
|
golang.org/x/mod v0.20.0 // indirect
|
||||||
golang.org/x/text v0.22.0 // indirect
|
golang.org/x/text v0.23.0 // indirect
|
||||||
golang.org/x/tools v0.24.0 // indirect
|
golang.org/x/tools v0.24.0 // indirect
|
||||||
google.golang.org/genproto v0.0.0-20241118233622-e639e219e697 // indirect
|
google.golang.org/genproto v0.0.0-20241118233622-e639e219e697 // indirect
|
||||||
google.golang.org/genproto/googleapis/api v0.0.0-20250115164207-1a7da9e5054f // indirect
|
google.golang.org/genproto/googleapis/api v0.0.0-20250115164207-1a7da9e5054f // indirect
|
||||||
|
|
24
go.sum
24
go.sum
|
@ -781,8 +781,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
|
||||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||||
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
||||||
golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs=
|
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||||
golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ=
|
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
||||||
|
@ -858,8 +858,8 @@ golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su
|
||||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||||
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||||
golang.org/x/net v0.36.0 h1:vWF2fRbw4qslQsQzgFqZff+BItCvGFQqKzKIzx1rmoA=
|
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
|
||||||
golang.org/x/net v0.36.0/go.mod h1:bFmbeoIPfrw4sMHNhb4J9f6+tPziuGjq7Jk/38fxi1I=
|
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
||||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||||
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||||
|
@ -882,8 +882,8 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ
|
||||||
golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
|
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
|
||||||
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
|
@ -944,15 +944,15 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
|
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
||||||
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||||
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
|
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
|
||||||
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
|
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
|
||||||
golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU=
|
golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y=
|
||||||
golang.org/x/term v0.29.0/go.mod h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s=
|
golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g=
|
||||||
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
|
@ -963,8 +963,8 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||||
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
|
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
|
||||||
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
|
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
|
||||||
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||||
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||||
|
|
|
@ -14,10 +14,10 @@ import (
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
"google.golang.org/protobuf/types/known/structpb"
|
"google.golang.org/protobuf/types/known/structpb"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc"
|
"github.com/pomerium/pomerium/pkg/grpc"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||||
"github.com/pomerium/pomerium/pkg/identity"
|
"github.com/pomerium/pomerium/pkg/identity"
|
||||||
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
)
|
)
|
||||||
|
|
||||||
// timeNow is time.Now but pulled out as a variable for tests.
|
// timeNow is time.Now but pulled out as a variable for tests.
|
||||||
|
|
|
@ -23,7 +23,6 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc"
|
"github.com/pomerium/pomerium/pkg/grpc"
|
||||||
|
@ -33,6 +32,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||||
"github.com/pomerium/pomerium/pkg/identity"
|
"github.com/pomerium/pomerium/pkg/identity"
|
||||||
"github.com/pomerium/pomerium/pkg/identity/manager"
|
"github.com/pomerium/pomerium/pkg/identity/manager"
|
||||||
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Stateful implements the stateful authentication flow. In this flow, the
|
// Stateful implements the stateful authentication flow. In this flow, the
|
||||||
|
|
|
@ -21,7 +21,6 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc"
|
"github.com/pomerium/pomerium/pkg/grpc"
|
||||||
|
@ -31,6 +30,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||||
"github.com/pomerium/pomerium/pkg/hpke"
|
"github.com/pomerium/pomerium/pkg/hpke"
|
||||||
"github.com/pomerium/pomerium/pkg/identity"
|
"github.com/pomerium/pomerium/pkg/identity"
|
||||||
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
||||||
"go.opentelemetry.io/otel"
|
"go.opentelemetry.io/otel"
|
||||||
oteltrace "go.opentelemetry.io/otel/trace"
|
oteltrace "go.opentelemetry.io/otel/trace"
|
||||||
|
|
|
@ -17,10 +17,10 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/middleware"
|
"github.com/pomerium/pomerium/internal/middleware"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry"
|
"github.com/pomerium/pomerium/internal/telemetry"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
hpke_handlers "github.com/pomerium/pomerium/pkg/hpke/handlers"
|
hpke_handlers "github.com/pomerium/pomerium/pkg/hpke/handlers"
|
||||||
"github.com/pomerium/pomerium/pkg/telemetry/requestid"
|
"github.com/pomerium/pomerium/pkg/telemetry/requestid"
|
||||||
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (srv *Server) addHTTPMiddleware(ctx context.Context, root *mux.Router, _ *config.Config) {
|
func (srv *Server) addHTTPMiddleware(ctx context.Context, root *mux.Router, _ *config.Config) {
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
||||||
coltracepb "go.opentelemetry.io/proto/otlp/collector/trace/v1"
|
coltracepb "go.opentelemetry.io/proto/otlp/collector/trace/v1"
|
||||||
|
"golang.org/x/net/nettest"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/health/grpc_health_v1"
|
"google.golang.org/grpc/health/grpc_health_v1"
|
||||||
|
@ -27,7 +28,6 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/events"
|
"github.com/pomerium/pomerium/internal/events"
|
||||||
"github.com/pomerium/pomerium/internal/httputil/reproxy"
|
"github.com/pomerium/pomerium/internal/httputil/reproxy"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
"github.com/pomerium/pomerium/internal/version"
|
"github.com/pomerium/pomerium/internal/version"
|
||||||
"github.com/pomerium/pomerium/pkg/envoy/files"
|
"github.com/pomerium/pomerium/pkg/envoy/files"
|
||||||
|
@ -35,6 +35,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||||
"github.com/pomerium/pomerium/pkg/httputil"
|
"github.com/pomerium/pomerium/pkg/httputil"
|
||||||
"github.com/pomerium/pomerium/pkg/telemetry/requestid"
|
"github.com/pomerium/pomerium/pkg/telemetry/requestid"
|
||||||
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
oteltrace "go.opentelemetry.io/otel/trace"
|
oteltrace "go.opentelemetry.io/otel/trace"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -177,6 +178,7 @@ func NewServer(
|
||||||
srv.MetricsListener.Addr().String(),
|
srv.MetricsListener.Addr().String(),
|
||||||
srv.filemgr,
|
srv.filemgr,
|
||||||
srv.reproxy,
|
srv.reproxy,
|
||||||
|
nettest.SupportsIPv6(),
|
||||||
)
|
)
|
||||||
|
|
||||||
res, err := srv.buildDiscoveryResources(ctx)
|
res, err := srv.buildDiscoveryResources(ctx)
|
||||||
|
|
|
@ -15,13 +15,13 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/hashutil"
|
"github.com/pomerium/pomerium/internal/hashutil"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc"
|
"github.com/pomerium/pomerium/pkg/grpc"
|
||||||
configpb "github.com/pomerium/pomerium/pkg/grpc/config"
|
configpb "github.com/pomerium/pomerium/pkg/grpc/config"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||||
"github.com/pomerium/pomerium/pkg/health"
|
"github.com/pomerium/pomerium/pkg/health"
|
||||||
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
||||||
oteltrace "go.opentelemetry.io/otel/trace"
|
oteltrace "go.opentelemetry.io/otel/trace"
|
||||||
googlegrpc "google.golang.org/grpc"
|
googlegrpc "google.golang.org/grpc"
|
||||||
|
|
|
@ -17,11 +17,11 @@ import (
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/registry"
|
"github.com/pomerium/pomerium/internal/registry"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
"github.com/pomerium/pomerium/pkg/storage"
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
"github.com/pomerium/pomerium/pkg/storage/inmemory"
|
"github.com/pomerium/pomerium/pkg/storage/inmemory"
|
||||||
"github.com/pomerium/pomerium/pkg/storage/postgres"
|
"github.com/pomerium/pomerium/pkg/storage/postgres"
|
||||||
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
oteltrace "go.opentelemetry.io/otel/trace"
|
oteltrace "go.opentelemetry.io/otel/trace"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -8,6 +8,9 @@ var (
|
||||||
// ErrNoSessionFound is the error for when no session is found.
|
// ErrNoSessionFound is the error for when no session is found.
|
||||||
ErrNoSessionFound = errors.New("internal/sessions: session is not found")
|
ErrNoSessionFound = errors.New("internal/sessions: session is not found")
|
||||||
|
|
||||||
|
// ErrInvalidSession is the error for when a session is invalid.
|
||||||
|
ErrInvalidSession = errors.New("internal/sessions: invalid session")
|
||||||
|
|
||||||
// ErrMalformed is the error for when a session is found but is malformed.
|
// ErrMalformed is the error for when a session is found but is malformed.
|
||||||
ErrMalformed = errors.New("internal/sessions: session is malformed")
|
ErrMalformed = errors.New("internal/sessions: session is malformed")
|
||||||
|
|
||||||
|
|
|
@ -240,11 +240,11 @@ var (
|
||||||
Measure: identityManagerLastSessionRefreshSuccess,
|
Measure: identityManagerLastSessionRefreshSuccess,
|
||||||
Aggregation: view.Count(),
|
Aggregation: view.Count(),
|
||||||
}
|
}
|
||||||
// IdentityManagerLastSessionRefreshErrorView contains user refresh errors counter
|
// IdentityManagerLastSessionRefreshErrorView contains session refresh errors counter
|
||||||
IdentityManagerLastSessionRefreshErrorView = &view.View{
|
IdentityManagerLastSessionRefreshErrorView = &view.View{
|
||||||
Name: identityManagerLastUserRefreshError.Name(),
|
Name: identityManagerLastSessionRefreshError.Name(),
|
||||||
Description: identityManagerLastUserRefreshError.Description(),
|
Description: identityManagerLastSessionRefreshError.Description(),
|
||||||
Measure: identityManagerLastUserRefreshError,
|
Measure: identityManagerLastSessionRefreshError,
|
||||||
Aggregation: view.Count(),
|
Aggregation: view.Count(),
|
||||||
}
|
}
|
||||||
// IdentityManagerLastSessionRefreshSuccessTimestampView contains successful session refresh counter
|
// IdentityManagerLastSessionRefreshSuccessTimestampView contains successful session refresh counter
|
||||||
|
|
|
@ -23,6 +23,7 @@ import (
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"path"
|
"path"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -34,11 +35,12 @@ import (
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/config/envoyconfig/filemgr"
|
"github.com/pomerium/pomerium/config/envoyconfig/filemgr"
|
||||||
|
"github.com/pomerium/pomerium/config/otelconfig"
|
||||||
databroker_service "github.com/pomerium/pomerium/databroker"
|
databroker_service "github.com/pomerium/pomerium/databroker"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
||||||
"github.com/pomerium/pomerium/internal/testenv/envutil"
|
"github.com/pomerium/pomerium/internal/testenv/envutil"
|
||||||
"github.com/pomerium/pomerium/internal/testenv/values"
|
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||||
|
"github.com/pomerium/pomerium/internal/version"
|
||||||
"github.com/pomerium/pomerium/pkg/cmd/pomerium"
|
"github.com/pomerium/pomerium/pkg/cmd/pomerium"
|
||||||
"github.com/pomerium/pomerium/pkg/envoy"
|
"github.com/pomerium/pomerium/pkg/envoy"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
@ -46,6 +48,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/pkg/identity/manager"
|
"github.com/pomerium/pomerium/pkg/identity/manager"
|
||||||
"github.com/pomerium/pomerium/pkg/netutil"
|
"github.com/pomerium/pomerium/pkg/netutil"
|
||||||
"github.com/pomerium/pomerium/pkg/slices"
|
"github.com/pomerium/pomerium/pkg/slices"
|
||||||
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
@ -97,6 +100,7 @@ type Environment interface {
|
||||||
AuthenticateURL() values.Value[string]
|
AuthenticateURL() values.Value[string]
|
||||||
DatabrokerURL() values.Value[string]
|
DatabrokerURL() values.Value[string]
|
||||||
Ports() Ports
|
Ports() Ports
|
||||||
|
Host() string
|
||||||
SharedSecret() []byte
|
SharedSecret() []byte
|
||||||
CookieSecret() []byte
|
CookieSecret() []byte
|
||||||
|
|
||||||
|
@ -244,6 +248,8 @@ type EnvironmentOptions struct {
|
||||||
forceSilent bool
|
forceSilent bool
|
||||||
traceDebugFlags trace.DebugFlags
|
traceDebugFlags trace.DebugFlags
|
||||||
traceClient otlptrace.Client
|
traceClient otlptrace.Client
|
||||||
|
traceConfig *otelconfig.Config
|
||||||
|
host string
|
||||||
}
|
}
|
||||||
|
|
||||||
type EnvironmentOption func(*EnvironmentOptions)
|
type EnvironmentOption func(*EnvironmentOptions)
|
||||||
|
@ -300,15 +306,23 @@ func WithTraceClient(traceClient otlptrace.Client) EnvironmentOption {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithTraceConfig(traceConfig *otelconfig.Config) EnvironmentOption {
|
||||||
|
return func(o *EnvironmentOptions) {
|
||||||
|
o.traceConfig = traceConfig
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var setGrpcLoggerOnce sync.Once
|
var setGrpcLoggerOnce sync.Once
|
||||||
|
|
||||||
const defaultTraceDebugFlags = trace.TrackSpanCallers | trace.TrackSpanReferences
|
const defaultTraceDebugFlags = trace.TrackSpanCallers | trace.TrackSpanReferences
|
||||||
|
|
||||||
var (
|
var (
|
||||||
flagDebug = flag.Bool("env.debug", false, "enables test environment debug logging (equivalent to Debug() option)")
|
flagDebug = flag.Bool("env.debug", false, "enables test environment debug logging (equivalent to Debug() option)")
|
||||||
flagPauseOnFailure = flag.Bool("env.pause-on-failure", false, "enables pausing the test environment on failure (equivalent to PauseOnFailure() option)")
|
flagPauseOnFailure = flag.Bool("env.pause-on-failure", false, "enables pausing the test environment on failure (equivalent to PauseOnFailure() option)")
|
||||||
flagSilent = flag.Bool("env.silent", false, "suppresses all test environment output (equivalent to Silent() option)")
|
flagSilent = flag.Bool("env.silent", false, "suppresses all test environment output (equivalent to Silent() option)")
|
||||||
flagTraceDebugFlags = flag.String("env.trace-debug-flags", strconv.Itoa(defaultTraceDebugFlags), "trace debug flags (equivalent to TraceDebugFlags() option)")
|
flagTraceDebugFlags = flag.String("env.trace-debug-flags", strconv.Itoa(defaultTraceDebugFlags), "trace debug flags (equivalent to TraceDebugFlags() option)")
|
||||||
|
flagBindAddress = flag.String("env.bind-address", "127.0.0.1", "bind address for local services")
|
||||||
|
flagTraceEnvironConfig = flag.Bool("env.use-trace-environ", false, "if true, will configure a trace client from environment variables if no trace client has been set")
|
||||||
)
|
)
|
||||||
|
|
||||||
func New(t testing.TB, opts ...EnvironmentOption) Environment {
|
func New(t testing.TB, opts ...EnvironmentOption) Environment {
|
||||||
|
@ -323,6 +337,7 @@ func New(t testing.TB, opts ...EnvironmentOption) Environment {
|
||||||
forceSilent: *flagSilent,
|
forceSilent: *flagSilent,
|
||||||
traceDebugFlags: trace.DebugFlags(defaultTraceDebugFlags),
|
traceDebugFlags: trace.DebugFlags(defaultTraceDebugFlags),
|
||||||
traceClient: trace.NoopClient{},
|
traceClient: trace.NoopClient{},
|
||||||
|
host: *flagBindAddress,
|
||||||
}
|
}
|
||||||
options.apply(opts...)
|
options.apply(opts...)
|
||||||
if testing.Short() {
|
if testing.Short() {
|
||||||
|
@ -332,6 +347,17 @@ func New(t testing.TB, opts ...EnvironmentOption) Environment {
|
||||||
if addTraceDebugFlags {
|
if addTraceDebugFlags {
|
||||||
options.traceDebugFlags |= trace.DebugFlags(defaultTraceDebugFlags)
|
options.traceDebugFlags |= trace.DebugFlags(defaultTraceDebugFlags)
|
||||||
}
|
}
|
||||||
|
if *flagTraceEnvironConfig && options.traceConfig == nil &&
|
||||||
|
(reflect.TypeOf(options.traceClient) == reflect.TypeFor[trace.NoopClient]()) {
|
||||||
|
cfg := newOtelConfigFromEnv(t)
|
||||||
|
options.traceConfig = &cfg
|
||||||
|
client, err := trace.NewTraceClientFromConfig(cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log("tracing configured from environment")
|
||||||
|
options.traceClient = client
|
||||||
|
}
|
||||||
trace.UseGlobalPanicTracer()
|
trace.UseGlobalPanicTracer()
|
||||||
databroker.DebugUseFasterBackoff.Store(true)
|
databroker.DebugUseFasterBackoff.Store(true)
|
||||||
workspaceFolder, err := os.Getwd()
|
workspaceFolder, err := os.Getwd()
|
||||||
|
@ -495,7 +521,7 @@ func (e *environment) AuthenticateURL() values.Value[string] {
|
||||||
|
|
||||||
func (e *environment) DatabrokerURL() values.Value[string] {
|
func (e *environment) DatabrokerURL() values.Value[string] {
|
||||||
return values.Bind(e.ports.Outbound, func(port int) string {
|
return values.Bind(e.ports.Outbound, func(port int) string {
|
||||||
return fmt.Sprintf("127.0.0.1:%d", port)
|
return fmt.Sprintf("%s:%d", e.host, port)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -503,6 +529,13 @@ func (e *environment) Ports() Ports {
|
||||||
return e.ports
|
return e.ports
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *environment) Host() string {
|
||||||
|
if e.host == "" {
|
||||||
|
return "127.0.0.1"
|
||||||
|
}
|
||||||
|
return e.host
|
||||||
|
}
|
||||||
|
|
||||||
func (e *environment) CACert() *tls.Certificate {
|
func (e *environment) CACert() *tls.Certificate {
|
||||||
caCert, err := tls.LoadX509KeyPair(
|
caCert, err := tls.LoadX509KeyPair(
|
||||||
filepath.Join(e.tempDir, "certs", "ca.pem"),
|
filepath.Join(e.tempDir, "certs", "ca.pem"),
|
||||||
|
@ -571,9 +604,9 @@ func (e *environment) Start() {
|
||||||
cfg.Options.Services = "all"
|
cfg.Options.Services = "all"
|
||||||
cfg.Options.LogLevel = config.LogLevelDebug
|
cfg.Options.LogLevel = config.LogLevelDebug
|
||||||
cfg.Options.ProxyLogLevel = config.LogLevelInfo
|
cfg.Options.ProxyLogLevel = config.LogLevelInfo
|
||||||
cfg.Options.Addr = fmt.Sprintf("127.0.0.1:%d", e.ports.ProxyHTTP.Value())
|
cfg.Options.Addr = fmt.Sprintf("%s:%d", e.host, e.ports.ProxyHTTP.Value())
|
||||||
cfg.Options.GRPCAddr = fmt.Sprintf("127.0.0.1:%d", e.ports.ProxyGRPC.Value())
|
cfg.Options.GRPCAddr = fmt.Sprintf("%s:%d", e.host, e.ports.ProxyGRPC.Value())
|
||||||
cfg.Options.MetricsAddr = fmt.Sprintf("127.0.0.1:%d", e.ports.ProxyMetrics.Value())
|
cfg.Options.MetricsAddr = fmt.Sprintf("%s:%d", e.host, e.ports.ProxyMetrics.Value())
|
||||||
cfg.Options.CAFile = filepath.Join(e.tempDir, "certs", "ca.pem")
|
cfg.Options.CAFile = filepath.Join(e.tempDir, "certs", "ca.pem")
|
||||||
cfg.Options.CertFile = filepath.Join(e.tempDir, "certs", "trusted.pem")
|
cfg.Options.CertFile = filepath.Join(e.tempDir, "certs", "trusted.pem")
|
||||||
cfg.Options.KeyFile = filepath.Join(e.tempDir, "certs", "trusted-key.pem")
|
cfg.Options.KeyFile = filepath.Join(e.tempDir, "certs", "trusted-key.pem")
|
||||||
|
@ -598,6 +631,9 @@ func (e *environment) Start() {
|
||||||
log.AccessLogFieldUserAgent,
|
log.AccessLogFieldUserAgent,
|
||||||
log.AccessLogFieldClientCertificate,
|
log.AccessLogFieldClientCertificate,
|
||||||
}
|
}
|
||||||
|
if e.traceConfig != nil {
|
||||||
|
cfg.Options.Tracing = *e.traceConfig
|
||||||
|
}
|
||||||
|
|
||||||
e.src = &configSource{cfg: cfg}
|
e.src = &configSource{cfg: cfg}
|
||||||
e.AddTask(TaskFunc(func(ctx context.Context) error {
|
e.AddTask(TaskFunc(func(ctx context.Context) error {
|
||||||
|
@ -799,6 +835,7 @@ func (e *environment) Pause() {
|
||||||
c := make(chan os.Signal, 1)
|
c := make(chan os.Signal, 1)
|
||||||
signal.Notify(c, syscall.SIGINT)
|
signal.Notify(c, syscall.SIGINT)
|
||||||
<-c
|
<-c
|
||||||
|
signal.Stop(c)
|
||||||
e.t.Log("\x1b[31mctrl+c received, continuing\x1b[0m")
|
e.t.Log("\x1b[31mctrl+c received, continuing\x1b[0m")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -816,6 +853,7 @@ func (e *environment) cleanup(cancelCause error) {
|
||||||
c := make(chan os.Signal, 1)
|
c := make(chan os.Signal, 1)
|
||||||
signal.Notify(c, syscall.SIGINT)
|
signal.Notify(c, syscall.SIGINT)
|
||||||
<-c
|
<-c
|
||||||
|
signal.Stop(c)
|
||||||
e.t.Log("\x1b[31mctrl+c received, continuing\x1b[0m")
|
e.t.Log("\x1b[31mctrl+c received, continuing\x1b[0m")
|
||||||
signal.Stop(c)
|
signal.Stop(c)
|
||||||
}
|
}
|
||||||
|
@ -1043,3 +1081,13 @@ func (src *configSource) ModifyConfig(ctx context.Context, m Modifier) {
|
||||||
li(ctx, src.cfg)
|
li(ctx, src.cfg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newOtelConfigFromEnv(t testing.TB) otelconfig.Config {
|
||||||
|
f, err := os.CreateTemp("", "tmp-config-*.yaml")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer os.Remove(f.Name())
|
||||||
|
f.Close()
|
||||||
|
cfg, err := config.NewFileOrEnvironmentSource(context.Background(), f.Name(), version.FullVersion())
|
||||||
|
require.NoError(t, err)
|
||||||
|
return cfg.GetConfig().Options.Tracing
|
||||||
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package testenv
|
package testenv
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
@ -74,3 +75,19 @@ func (b *PolicyRoute) PPL(ppl string) Route {
|
||||||
func (b *PolicyRoute) URL() values.Value[string] {
|
func (b *PolicyRoute) URL() values.Value[string] {
|
||||||
return b.from
|
return b.from
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type TCPRoute struct {
|
||||||
|
PolicyRoute
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *TCPRoute) From(fromURL values.Value[string]) Route {
|
||||||
|
b.from = values.Bind(fromURL, func(urlStr string) string {
|
||||||
|
from, _ := url.Parse(urlStr)
|
||||||
|
from.Scheme = "tcp+https"
|
||||||
|
from.Host = fmt.Sprintf("%s:%s", from.Hostname(), from.Port())
|
||||||
|
return from.String()
|
||||||
|
})
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Route = (*TCPRoute)(nil)
|
||||||
|
|
|
@ -6,6 +6,8 @@ import (
|
||||||
"crypto/ecdsa"
|
"crypto/ecdsa"
|
||||||
"crypto/elliptic"
|
"crypto/elliptic"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
@ -32,6 +34,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type IDP struct {
|
type IDP struct {
|
||||||
|
IDPOptions
|
||||||
id values.Value[string]
|
id values.Value[string]
|
||||||
url values.Value[string]
|
url values.Value[string]
|
||||||
publicJWK jose.JSONWebKey
|
publicJWK jose.JSONWebKey
|
||||||
|
@ -41,18 +44,56 @@ type IDP struct {
|
||||||
userLookup map[string]*User
|
userLookup map[string]*User
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type IDPOptions struct {
|
||||||
|
enableTLS bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type IDPOption func(*IDPOptions)
|
||||||
|
|
||||||
|
func (o *IDPOptions) apply(opts ...IDPOption) {
|
||||||
|
for _, op := range opts {
|
||||||
|
op(o)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithEnableTLS(enableTLS bool) IDPOption {
|
||||||
|
return func(o *IDPOptions) {
|
||||||
|
o.enableTLS = enableTLS
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Attach implements testenv.Modifier.
|
// Attach implements testenv.Modifier.
|
||||||
func (idp *IDP) Attach(ctx context.Context) {
|
func (idp *IDP) Attach(ctx context.Context) {
|
||||||
env := testenv.EnvFromContext(ctx)
|
env := testenv.EnvFromContext(ctx)
|
||||||
|
|
||||||
router := upstreams.HTTP(nil, upstreams.WithDisplayName("IDP"))
|
idpURL := env.SubdomainURL("mock-idp")
|
||||||
|
|
||||||
idp.url = values.Bind2(env.SubdomainURL("mock-idp"), router.Port(), func(urlStr string, port int) string {
|
var tlsConfig values.Value[*tls.Config]
|
||||||
|
if idp.enableTLS {
|
||||||
|
tlsConfig = values.Bind(idpURL, func(urlStr string) *tls.Config {
|
||||||
|
u, _ := url.Parse(urlStr)
|
||||||
|
cert := env.NewServerCert(&x509.Certificate{
|
||||||
|
DNSNames: []string{u.Hostname()},
|
||||||
|
})
|
||||||
|
return &tls.Config{
|
||||||
|
RootCAs: env.ServerCAs(),
|
||||||
|
Certificates: []tls.Certificate{tls.Certificate(*cert)},
|
||||||
|
NextProtos: []string{"http/1.1", "h2"},
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
router := upstreams.HTTP(tlsConfig, upstreams.WithDisplayName("IDP"))
|
||||||
|
|
||||||
|
idp.url = values.Bind2(idpURL, router.Addr(), func(urlStr string, addr string) string {
|
||||||
u, _ := url.Parse(urlStr)
|
u, _ := url.Parse(urlStr)
|
||||||
host, _, _ := net.SplitHostPort(u.Host)
|
host, _, _ := net.SplitHostPort(u.Host)
|
||||||
|
_, port, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
panic("bug: " + err.Error())
|
||||||
|
}
|
||||||
return u.ResolveReference(&url.URL{
|
return u.ResolveReference(&url.URL{
|
||||||
Scheme: "http",
|
Host: fmt.Sprintf("%s:%s", host, port),
|
||||||
Host: fmt.Sprintf("%s:%d", host, port),
|
|
||||||
}).String()
|
}).String()
|
||||||
})
|
})
|
||||||
var err error
|
var err error
|
||||||
|
@ -108,7 +149,12 @@ func (idp *IDP) Modify(cfg *config.Config) {
|
||||||
|
|
||||||
var _ testenv.Modifier = (*IDP)(nil)
|
var _ testenv.Modifier = (*IDP)(nil)
|
||||||
|
|
||||||
func NewIDP(users []*User) *IDP {
|
func NewIDP(users []*User, opts ...IDPOption) *IDP {
|
||||||
|
options := IDPOptions{
|
||||||
|
enableTLS: true,
|
||||||
|
}
|
||||||
|
options.apply(opts...)
|
||||||
|
|
||||||
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
|
@ -136,6 +182,7 @@ func NewIDP(users []*User) *IDP {
|
||||||
userLookup[user.ID] = user
|
userLookup[user.ID] = user
|
||||||
}
|
}
|
||||||
return &IDP{
|
return &IDP{
|
||||||
|
IDPOptions: options,
|
||||||
publicJWK: publicJWK,
|
publicJWK: publicJWK,
|
||||||
signingKey: signingKey,
|
signingKey: signingKey,
|
||||||
userLookup: userLookup,
|
userLookup: userLookup,
|
||||||
|
|
|
@ -174,16 +174,16 @@ func (rec *OTLPTraceReceiver) FlushResourceSpans() []*tracev1.ResourceSpans {
|
||||||
// GRPCEndpointURL returns a url suitable for use with the environment variable
|
// GRPCEndpointURL returns a url suitable for use with the environment variable
|
||||||
// $OTEL_EXPORTER_OTLP_TRACES_ENDPOINT or with [otlptracegrpc.WithEndpointURL].
|
// $OTEL_EXPORTER_OTLP_TRACES_ENDPOINT or with [otlptracegrpc.WithEndpointURL].
|
||||||
func (rec *OTLPTraceReceiver) GRPCEndpointURL() values.Value[string] {
|
func (rec *OTLPTraceReceiver) GRPCEndpointURL() values.Value[string] {
|
||||||
return values.Chain(rec.grpcUpstream, upstreams.GRPCUpstream.Port, func(port int) string {
|
return values.Chain(rec.grpcUpstream, upstreams.GRPCUpstream.Addr, func(addr string) string {
|
||||||
return fmt.Sprintf("http://127.0.0.1:%d", port)
|
return fmt.Sprintf("http://%s", addr)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GRPCEndpointURL returns a url suitable for use with the environment variable
|
// GRPCEndpointURL returns a url suitable for use with the environment variable
|
||||||
// $OTEL_EXPORTER_OTLP_TRACES_ENDPOINT or with [otlptracehttp.WithEndpointURL].
|
// $OTEL_EXPORTER_OTLP_TRACES_ENDPOINT or with [otlptracehttp.WithEndpointURL].
|
||||||
func (rec *OTLPTraceReceiver) HTTPEndpointURL() values.Value[string] {
|
func (rec *OTLPTraceReceiver) HTTPEndpointURL() values.Value[string] {
|
||||||
return values.Chain(rec.httpUpstream, upstreams.HTTPUpstream.Port, func(port int) string {
|
return values.Chain(rec.httpUpstream, upstreams.HTTPUpstream.Addr, func(addr string) string {
|
||||||
return fmt.Sprintf("http://127.0.0.1:%d/v1/traces", port)
|
return fmt.Sprintf("http://%s/v1/traces", addr)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -200,9 +200,9 @@ func (rec *OTLPTraceReceiver) NewGRPCClient(opts ...otlptracegrpc.Option) otlptr
|
||||||
|
|
||||||
func (rec *OTLPTraceReceiver) NewHTTPClient(opts ...otlptracehttp.Option) otlptrace.Client {
|
func (rec *OTLPTraceReceiver) NewHTTPClient(opts ...otlptracehttp.Option) otlptrace.Client {
|
||||||
return &deferredClient{
|
return &deferredClient{
|
||||||
client: values.Chain(rec.httpUpstream, upstreams.HTTPUpstream.Port, func(port int) otlptrace.Client {
|
client: values.Chain(rec.httpUpstream, upstreams.HTTPUpstream.Addr, func(addr string) otlptrace.Client {
|
||||||
return otlptracehttp.NewClient(append(opts,
|
return otlptracehttp.NewClient(append(opts,
|
||||||
otlptracehttp.WithEndpointURL(fmt.Sprintf("http://127.0.0.1:%d/v1/traces", port)),
|
otlptracehttp.WithEndpointURL(fmt.Sprintf("http://%s/v1/traces", addr)),
|
||||||
otlptracehttp.WithTimeout(1*time.Minute),
|
otlptracehttp.WithTimeout(1*time.Minute),
|
||||||
)...)
|
)...)
|
||||||
}),
|
}),
|
||||||
|
|
|
@ -21,12 +21,12 @@ import (
|
||||||
oteltrace "go.opentelemetry.io/otel/trace"
|
oteltrace "go.opentelemetry.io/otel/trace"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
||||||
"github.com/pomerium/pomerium/internal/testenv"
|
"github.com/pomerium/pomerium/internal/testenv"
|
||||||
"github.com/pomerium/pomerium/internal/testenv/scenarios"
|
"github.com/pomerium/pomerium/internal/testenv/scenarios"
|
||||||
"github.com/pomerium/pomerium/internal/testenv/snippets"
|
"github.com/pomerium/pomerium/internal/testenv/snippets"
|
||||||
"github.com/pomerium/pomerium/internal/testenv/upstreams"
|
"github.com/pomerium/pomerium/internal/testenv/upstreams"
|
||||||
. "github.com/pomerium/pomerium/internal/testutil/tracetest" //nolint:revive
|
. "github.com/pomerium/pomerium/internal/testutil/tracetest" //nolint:revive
|
||||||
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
)
|
)
|
||||||
|
|
||||||
var allServices = []string{
|
var allServices = []string{
|
||||||
|
|
|
@ -4,9 +4,9 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
||||||
"github.com/pomerium/pomerium/internal/testenv"
|
"github.com/pomerium/pomerium/internal/testenv"
|
||||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||||
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/connectivity"
|
"google.golang.org/grpc/connectivity"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
|
|
@ -204,7 +204,8 @@ func (f *taskFunc) Run(ctx context.Context) error {
|
||||||
type Upstream interface {
|
type Upstream interface {
|
||||||
Modifier
|
Modifier
|
||||||
Task
|
Task
|
||||||
Port() values.Value[int]
|
|
||||||
|
Addr() values.Value[string]
|
||||||
Route() RouteStub
|
Route() RouteStub
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -6,10 +6,10 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
||||||
"github.com/pomerium/pomerium/internal/testenv"
|
"github.com/pomerium/pomerium/internal/testenv"
|
||||||
"github.com/pomerium/pomerium/internal/testenv/snippets"
|
"github.com/pomerium/pomerium/internal/testenv/snippets"
|
||||||
"github.com/pomerium/pomerium/internal/testenv/values"
|
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||||
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
||||||
oteltrace "go.opentelemetry.io/otel/trace"
|
oteltrace "go.opentelemetry.io/otel/trace"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
@ -97,8 +97,10 @@ type service struct {
|
||||||
impl any
|
impl any
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *grpcUpstream) Port() values.Value[int] {
|
func (g *grpcUpstream) Addr() values.Value[string] {
|
||||||
return g.serverPort
|
return values.Bind(g.serverPort, func(port int) string {
|
||||||
|
return fmt.Sprintf("%s:%d", g.Env().Host(), port)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterService implements grpc.ServiceRegistrar.
|
// RegisterService implements grpc.ServiceRegistrar.
|
||||||
|
@ -117,7 +119,7 @@ func (g *grpcUpstream) Route() testenv.RouteStub {
|
||||||
protocol = "https"
|
protocol = "https"
|
||||||
}
|
}
|
||||||
r.To(values.Bind(g.serverPort, func(port int) string {
|
r.To(values.Bind(g.serverPort, func(port int) string {
|
||||||
return fmt.Sprintf("%s://127.0.0.1:%d", protocol, port)
|
return fmt.Sprintf("%s://%s:%d", protocol, g.Env().Host(), port)
|
||||||
}))
|
}))
|
||||||
g.Add(r)
|
g.Add(r)
|
||||||
return r
|
return r
|
||||||
|
@ -125,7 +127,7 @@ func (g *grpcUpstream) Route() testenv.RouteStub {
|
||||||
|
|
||||||
// Start implements testenv.Upstream.
|
// Start implements testenv.Upstream.
|
||||||
func (g *grpcUpstream) Run(ctx context.Context) error {
|
func (g *grpcUpstream) Run(ctx context.Context) error {
|
||||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
listener, err := net.Listen("tcp", fmt.Sprintf("%s:0", g.Env().Host()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -187,7 +189,7 @@ func (g *grpcUpstream) Dial(r testenv.Route, dialOpts ...grpc.DialOption) *grpc.
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *grpcUpstream) DirectConnect(dialOpts ...grpc.DialOption) *grpc.ClientConn {
|
func (g *grpcUpstream) DirectConnect(dialOpts ...grpc.DialOption) *grpc.ClientConn {
|
||||||
cc, err := grpc.NewClient(fmt.Sprintf("127.0.0.1:%d", g.Port().Value()),
|
cc, err := grpc.NewClient(g.Addr().Value(),
|
||||||
append(g.withDefaultDialOpts(dialOpts), grpc.WithTransportCredentials(insecure.NewCredentials()))...)
|
append(g.withDefaultDialOpts(dialOpts), grpc.WithTransportCredentials(insecure.NewCredentials()))...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
|
|
|
@ -1,10 +1,8 @@
|
||||||
package upstreams
|
package upstreams
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
@ -13,25 +11,29 @@ import (
|
||||||
"net/http/cookiejar"
|
"net/http/cookiejar"
|
||||||
"net/http/httptrace"
|
"net/http/httptrace"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"os"
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/pomerium/pomerium/integration/forms"
|
"github.com/gorilla/websocket"
|
||||||
"github.com/pomerium/pomerium/internal/retry"
|
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
||||||
"github.com/pomerium/pomerium/internal/testenv"
|
"github.com/pomerium/pomerium/internal/testenv"
|
||||||
"github.com/pomerium/pomerium/internal/testenv/snippets"
|
"github.com/pomerium/pomerium/internal/testenv/snippets"
|
||||||
"github.com/pomerium/pomerium/internal/testenv/values"
|
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||||
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
|
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
|
||||||
|
|
||||||
"go.opentelemetry.io/otel/attribute"
|
"go.opentelemetry.io/otel/attribute"
|
||||||
"go.opentelemetry.io/otel/codes"
|
"go.opentelemetry.io/otel/codes"
|
||||||
semconv "go.opentelemetry.io/otel/semconv/v1.26.0"
|
|
||||||
oteltrace "go.opentelemetry.io/otel/trace"
|
oteltrace "go.opentelemetry.io/otel/trace"
|
||||||
"google.golang.org/protobuf/proto"
|
)
|
||||||
|
|
||||||
|
type Protocol string
|
||||||
|
|
||||||
|
const (
|
||||||
|
DialHTTP1 Protocol = "http/1.1"
|
||||||
|
DialHTTP2 Protocol = "h2"
|
||||||
|
DialHTTP3 Protocol = "h3"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RequestOptions struct {
|
type RequestOptions struct {
|
||||||
|
@ -42,12 +44,18 @@ type RequestOptions struct {
|
||||||
authenticateAs string
|
authenticateAs string
|
||||||
body any
|
body any
|
||||||
clientCerts []tls.Certificate
|
clientCerts []tls.Certificate
|
||||||
client *http.Client
|
clientHook func(*http.Client) *http.Client
|
||||||
|
dialerHook func(*websocket.Dialer, *url.URL) (*websocket.Dialer, *url.URL)
|
||||||
|
dialProtocol Protocol
|
||||||
trace *httptrace.ClientTrace
|
trace *httptrace.ClientTrace
|
||||||
}
|
}
|
||||||
|
|
||||||
type RequestOption func(*RequestOptions)
|
type RequestOption func(*RequestOptions)
|
||||||
|
|
||||||
|
func (ro RequestOption) Format(fmt.State, rune) {
|
||||||
|
panic("test bug: request option mistakenly passed to assert function")
|
||||||
|
}
|
||||||
|
|
||||||
func (o *RequestOptions) apply(opts ...RequestOption) {
|
func (o *RequestOptions) apply(opts ...RequestOption) {
|
||||||
for _, op := range opts {
|
for _, op := range opts {
|
||||||
op(o)
|
op(o)
|
||||||
|
@ -82,9 +90,38 @@ func AuthenticateAs(email string) RequestOption {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Client(c *http.Client) RequestOption {
|
// ClientHook allows editing or replacing the http client before it is used.
|
||||||
|
// When any request is about to start, this function will be called with the
|
||||||
|
// client that would be used to make the request. The returned client will
|
||||||
|
// be the actual client used for that request. It can be the same as the input
|
||||||
|
// (with or without modification), or replaced entirely.
|
||||||
|
//
|
||||||
|
// Note: the Transport of the client passed to the hook will always be a
|
||||||
|
// [*Transport]. That transport's underlying transport will always be
|
||||||
|
// a [*otelhttp.Transport].
|
||||||
|
func ClientHook(f func(*http.Client) *http.Client) RequestOption {
|
||||||
return func(o *RequestOptions) {
|
return func(o *RequestOptions) {
|
||||||
o.client = c
|
o.clientHook = f
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialerHook allows editing or replacing the websocket dialer before it is
|
||||||
|
// used. When a websocket request is about to start (using the DialWS method),
|
||||||
|
// this function will be called with the dialer that would be used, and the
|
||||||
|
// destination URL (including wss:// scheme, and path if one is present). The
|
||||||
|
// returned dialer+URL will be the actual dialer+URL used for that request.
|
||||||
|
//
|
||||||
|
// If ClientHook is also set, both will be called. The dialer passed to this
|
||||||
|
// hook will have its TLSClientConfig and Jar fields set from the client.
|
||||||
|
func DialerHook(f func(*websocket.Dialer, *url.URL) (*websocket.Dialer, *url.URL)) RequestOption {
|
||||||
|
return func(o *RequestOptions) {
|
||||||
|
o.dialerHook = f
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func DialProtocol(protocol Protocol) RequestOption {
|
||||||
|
return func(o *RequestOptions) {
|
||||||
|
o.dialProtocol = protocol
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -143,10 +180,12 @@ type HTTPUpstream interface {
|
||||||
testenv.Upstream
|
testenv.Upstream
|
||||||
|
|
||||||
Handle(path string, f func(http.ResponseWriter, *http.Request)) *mux.Route
|
Handle(path string, f func(http.ResponseWriter, *http.Request)) *mux.Route
|
||||||
|
HandleWS(path string, upgrader websocket.Upgrader, f func(conn *websocket.Conn) error) *mux.Route
|
||||||
|
|
||||||
Get(r testenv.Route, opts ...RequestOption) (*http.Response, error)
|
Get(r testenv.Route, opts ...RequestOption) (*http.Response, error)
|
||||||
Post(r testenv.Route, opts ...RequestOption) (*http.Response, error)
|
Post(r testenv.Route, opts ...RequestOption) (*http.Response, error)
|
||||||
Do(method string, r testenv.Route, opts ...RequestOption) (*http.Response, error)
|
Do(method string, r testenv.Route, opts ...RequestOption) (*http.Response, error)
|
||||||
|
DialWS(r testenv.Route, f func(conn *websocket.Conn) error, opts ...RequestOption) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type httpUpstream struct {
|
type httpUpstream struct {
|
||||||
|
@ -194,8 +233,10 @@ func HTTP(tlsConfig values.Value[*tls.Config], opts ...HTTPUpstreamOption) HTTPU
|
||||||
}
|
}
|
||||||
|
|
||||||
// Port implements HTTPUpstream.
|
// Port implements HTTPUpstream.
|
||||||
func (h *httpUpstream) Port() values.Value[int] {
|
func (h *httpUpstream) Addr() values.Value[string] {
|
||||||
return h.serverPort
|
return values.Bind(h.serverPort, func(port int) string {
|
||||||
|
return fmt.Sprintf("%s:%d", h.Env().Host(), port)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Router implements HTTPUpstream.
|
// Router implements HTTPUpstream.
|
||||||
|
@ -203,12 +244,37 @@ func (h *httpUpstream) Handle(path string, f func(http.ResponseWriter, *http.Req
|
||||||
return h.router.HandleFunc(path, f)
|
return h.router.HandleFunc(path, f)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Router implements HTTPUpstream.
|
||||||
|
func (h *httpUpstream) HandleWS(path string, upgrader websocket.Upgrader, f func(*websocket.Conn) error) *mux.Route {
|
||||||
|
return h.router.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx, span := trace.Continue(r.Context(), "HandleWS")
|
||||||
|
defer span.End()
|
||||||
|
c, err := upgrader.Upgrade(w, r.WithContext(ctx), nil)
|
||||||
|
if err != nil {
|
||||||
|
span.SetStatus(codes.Error, err.Error())
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
_, _ = w.Write([]byte(err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer c.Close()
|
||||||
|
|
||||||
|
err = f(c)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
span.SetStatus(codes.Error, err.Error())
|
||||||
|
fmt.Fprintf(os.Stderr, "websocket error: %s\n", err.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Route implements HTTPUpstream.
|
// Route implements HTTPUpstream.
|
||||||
func (h *httpUpstream) Route() testenv.RouteStub {
|
func (h *httpUpstream) Route() testenv.RouteStub {
|
||||||
r := &testenv.PolicyRoute{}
|
r := &testenv.PolicyRoute{}
|
||||||
protocol := "http"
|
protocol := "http"
|
||||||
r.To(values.Bind(h.serverPort, func(port int) string {
|
r.To(values.Bind(h.serverPort, func(port int) string {
|
||||||
return fmt.Sprintf("%s://127.0.0.1:%d", protocol, port)
|
return fmt.Sprintf("%s://%s:%d", protocol, h.Env().Host(), port)
|
||||||
}))
|
}))
|
||||||
h.Add(r)
|
h.Add(r)
|
||||||
return r
|
return r
|
||||||
|
@ -216,15 +282,21 @@ func (h *httpUpstream) Route() testenv.RouteStub {
|
||||||
|
|
||||||
// Run implements HTTPUpstream.
|
// Run implements HTTPUpstream.
|
||||||
func (h *httpUpstream) Run(ctx context.Context) error {
|
func (h *httpUpstream) Run(ctx context.Context) error {
|
||||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
var listener net.Listener
|
||||||
if err != nil {
|
if h.tlsConfig != nil {
|
||||||
return err
|
var err error
|
||||||
|
listener, err = tls.Listen("tcp", fmt.Sprintf("%s:0", h.Env().Host()), h.tlsConfig.Value())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
var err error
|
||||||
|
listener, err = net.Listen("tcp", fmt.Sprintf("%s:0", h.Env().Host()))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
h.serverPort.Resolve(listener.Addr().(*net.TCPAddr).Port)
|
h.serverPort.Resolve(listener.Addr().(*net.TCPAddr).Port)
|
||||||
var tlsConfig *tls.Config
|
|
||||||
if h.tlsConfig != nil {
|
|
||||||
tlsConfig = h.tlsConfig.Value()
|
|
||||||
}
|
|
||||||
if h.serverTracerProviderOverride != nil {
|
if h.serverTracerProviderOverride != nil {
|
||||||
h.serverTracerProvider.Resolve(h.serverTracerProviderOverride)
|
h.serverTracerProvider.Resolve(h.serverTracerProviderOverride)
|
||||||
} else {
|
} else {
|
||||||
|
@ -238,8 +310,7 @@ func (h *httpUpstream) Run(ctx context.Context) error {
|
||||||
h.router.Use(trace.NewHTTPMiddleware(otelhttp.WithTracerProvider(h.serverTracerProvider.Value())))
|
h.router.Use(trace.NewHTTPMiddleware(otelhttp.WithTracerProvider(h.serverTracerProvider.Value())))
|
||||||
|
|
||||||
server := &http.Server{
|
server := &http.Server{
|
||||||
Handler: h.router,
|
Handler: h.router,
|
||||||
TLSConfig: tlsConfig,
|
|
||||||
// BaseContext: func(net.Listener) context.Context {
|
// BaseContext: func(net.Listener) context.Context {
|
||||||
// return ctx
|
// return ctx
|
||||||
// },
|
// },
|
||||||
|
@ -277,6 +348,53 @@ func (h *httpUpstream) Post(r testenv.Route, opts ...RequestOption) (*http.Respo
|
||||||
return h.Do(http.MethodPost, r, opts...)
|
return h.Do(http.MethodPost, r, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Transport struct {
|
||||||
|
*otelhttp.Transport
|
||||||
|
// The underlying http.Transport instance wrapped by the otelhttp.Transport.
|
||||||
|
Base *http.Transport
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ http.RoundTripper = Transport{}
|
||||||
|
|
||||||
|
func (h *httpUpstream) newClient(options *RequestOptions) *http.Client {
|
||||||
|
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
|
transport.TLSClientConfig = &tls.Config{
|
||||||
|
RootCAs: h.Env().ServerCAs(),
|
||||||
|
Certificates: options.clientCerts,
|
||||||
|
}
|
||||||
|
transport.DialTLSContext = nil
|
||||||
|
c := http.Client{
|
||||||
|
Transport: &Transport{
|
||||||
|
Transport: otelhttp.NewTransport(transport,
|
||||||
|
otelhttp.WithTracerProvider(h.clientTracerProvider.Value()),
|
||||||
|
otelhttp.WithSpanNameFormatter(func(_ string, r *http.Request) string {
|
||||||
|
return fmt.Sprintf("Client: %s %s", r.Method, r.URL.Path)
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
Base: transport,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
c.Jar, _ = cookiejar.New(&cookiejar.Options{})
|
||||||
|
return &c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpUpstream) getRouteClient(r testenv.Route, options *RequestOptions) *http.Client {
|
||||||
|
span := oteltrace.SpanFromContext(options.requestCtx)
|
||||||
|
var cachedClient any
|
||||||
|
var ok bool
|
||||||
|
if cachedClient, ok = h.clientCache.Load(r); !ok {
|
||||||
|
span.AddEvent("creating new http client")
|
||||||
|
cachedClient, _ = h.clientCache.LoadOrStore(r, h.newClient(options))
|
||||||
|
} else {
|
||||||
|
span.AddEvent("using cached http client")
|
||||||
|
}
|
||||||
|
client := cachedClient.(*http.Client)
|
||||||
|
if options.clientHook != nil {
|
||||||
|
client = options.clientHook(client)
|
||||||
|
}
|
||||||
|
return client
|
||||||
|
}
|
||||||
|
|
||||||
// Do implements HTTPUpstream.
|
// Do implements HTTPUpstream.
|
||||||
func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption) (*http.Response, error) {
|
func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption) (*http.Response, error) {
|
||||||
options := RequestOptions{
|
options := RequestOptions{
|
||||||
|
@ -303,141 +421,54 @@ func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption)
|
||||||
options.requestCtx = ctx
|
options.requestCtx = ctx
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
newClient := func() *http.Client {
|
return doAuthenticatedRequest(options.requestCtx,
|
||||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
func(ctx context.Context) (*http.Request, error) {
|
||||||
transport.TLSClientConfig = &tls.Config{
|
return http.NewRequestWithContext(ctx, method, u.String(), nil)
|
||||||
RootCAs: h.Env().ServerCAs(),
|
},
|
||||||
Certificates: options.clientCerts,
|
func(context.Context) *http.Client {
|
||||||
}
|
return h.getRouteClient(r, &options)
|
||||||
transport.DialTLSContext = nil
|
},
|
||||||
c := http.Client{
|
&options,
|
||||||
Transport: otelhttp.NewTransport(transport,
|
)
|
||||||
otelhttp.WithTracerProvider(h.clientTracerProvider.Value()),
|
|
||||||
otelhttp.WithSpanNameFormatter(func(_ string, r *http.Request) string {
|
|
||||||
return fmt.Sprintf("Client: %s %s", r.Method, r.URL.Path)
|
|
||||||
}),
|
|
||||||
),
|
|
||||||
}
|
|
||||||
c.Jar, _ = cookiejar.New(&cookiejar.Options{})
|
|
||||||
return &c
|
|
||||||
}
|
|
||||||
var client *http.Client
|
|
||||||
if options.client != nil {
|
|
||||||
client = options.client
|
|
||||||
} else {
|
|
||||||
var cachedClient any
|
|
||||||
var ok bool
|
|
||||||
if cachedClient, ok = h.clientCache.Load(r); !ok {
|
|
||||||
span.AddEvent("creating new http client")
|
|
||||||
cachedClient, _ = h.clientCache.LoadOrStore(r, newClient())
|
|
||||||
} else {
|
|
||||||
span.AddEvent("using cached http client")
|
|
||||||
}
|
|
||||||
client = cachedClient.(*http.Client)
|
|
||||||
}
|
|
||||||
|
|
||||||
var resp *http.Response
|
|
||||||
resendCount := 0
|
|
||||||
if err := retry.Retry(ctx, "http", func(ctx context.Context) error {
|
|
||||||
req, err := http.NewRequestWithContext(ctx, method, u.String(), nil)
|
|
||||||
if err != nil {
|
|
||||||
return retry.NewTerminalError(err)
|
|
||||||
}
|
|
||||||
switch body := options.body.(type) {
|
|
||||||
case string:
|
|
||||||
req.Body = io.NopCloser(strings.NewReader(body))
|
|
||||||
case []byte:
|
|
||||||
req.Body = io.NopCloser(bytes.NewReader(body))
|
|
||||||
case io.Reader:
|
|
||||||
req.Body = io.NopCloser(body)
|
|
||||||
case proto.Message:
|
|
||||||
buf, err := proto.Marshal(body)
|
|
||||||
if err != nil {
|
|
||||||
return retry.NewTerminalError(err)
|
|
||||||
}
|
|
||||||
req.Body = io.NopCloser(bytes.NewReader(buf))
|
|
||||||
req.Header.Set("Content-Type", "application/octet-stream")
|
|
||||||
default:
|
|
||||||
buf, err := json.Marshal(body)
|
|
||||||
if err != nil {
|
|
||||||
panic(fmt.Sprintf("unsupported body type: %T", body))
|
|
||||||
}
|
|
||||||
req.Body = io.NopCloser(bytes.NewReader(buf))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
case nil:
|
|
||||||
}
|
|
||||||
|
|
||||||
if options.authenticateAs != "" {
|
|
||||||
resp, err = authenticateFlow(ctx, client, req, options.authenticateAs) //nolint:bodyclose
|
|
||||||
} else {
|
|
||||||
resp, err = client.Do(req) //nolint:bodyclose
|
|
||||||
}
|
|
||||||
// retry on connection refused
|
|
||||||
if err != nil {
|
|
||||||
span.RecordError(err)
|
|
||||||
var opErr *net.OpError
|
|
||||||
if errors.As(err, &opErr) && opErr.Op == "dial" && opErr.Err.Error() == "connect: connection refused" {
|
|
||||||
span.AddEvent("Retrying on dial error")
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return retry.NewTerminalError(err)
|
|
||||||
}
|
|
||||||
if resp.StatusCode/100 == 5 {
|
|
||||||
resendCount++
|
|
||||||
_, _ = io.ReadAll(resp.Body)
|
|
||||||
_ = resp.Body.Close()
|
|
||||||
span.SetAttributes(semconv.HTTPRequestResendCount(resendCount))
|
|
||||||
span.AddEvent("Retrying on 5xx error", oteltrace.WithAttributes(
|
|
||||||
attribute.String("status", resp.Status),
|
|
||||||
))
|
|
||||||
return errors.New(http.StatusText(resp.StatusCode))
|
|
||||||
}
|
|
||||||
span.SetStatus(codes.Ok, "request completed successfully")
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
retry.WithInitialInterval(1*time.Millisecond),
|
|
||||||
retry.WithMaxInterval(100*time.Millisecond),
|
|
||||||
); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return resp, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func authenticateFlow(ctx context.Context, client *http.Client, req *http.Request, email string) (*http.Response, error) {
|
func (h *httpUpstream) DialWS(r testenv.Route, f func(conn *websocket.Conn) error, opts ...RequestOption) error {
|
||||||
span := oteltrace.SpanFromContext(ctx)
|
options := RequestOptions{
|
||||||
var res *http.Response
|
requestCtx: h.Env().Context(),
|
||||||
originalHostname := req.URL.Hostname()
|
}
|
||||||
res, err := client.Do(req)
|
options.apply(opts...)
|
||||||
|
u, err := url.Parse(r.URL().Value())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
span.RecordError(err)
|
return err
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
location := res.Request.URL
|
u.Scheme = "wss"
|
||||||
if location.Hostname() == originalHostname {
|
if options.path != "" || options.query != nil {
|
||||||
// already authenticated
|
u = u.ResolveReference(&url.URL{
|
||||||
span.SetStatus(codes.Ok, "already authenticated")
|
Path: options.path,
|
||||||
return res, nil
|
RawQuery: options.query.Encode(),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
fs := forms.Parse(res.Body)
|
ctx, span := h.clientTracer.Value().Start(options.requestCtx, "httpUpstream.Dial", oteltrace.WithAttributes(
|
||||||
_, _ = io.ReadAll(res.Body)
|
attribute.String("url", u.String()),
|
||||||
_ = res.Body.Close()
|
))
|
||||||
if len(fs) > 0 {
|
options.requestCtx = ctx
|
||||||
f := fs[0]
|
defer span.End()
|
||||||
f.Inputs["email"] = email
|
|
||||||
f.Inputs["token_expiration"] = strconv.Itoa(int((time.Hour * 24).Seconds()))
|
client := h.getRouteClient(r, &options)
|
||||||
span.AddEvent("submitting form", oteltrace.WithAttributes(attribute.String("location", location.String())))
|
d := &websocket.Dialer{
|
||||||
formReq, err := f.NewRequestWithContext(ctx, location)
|
HandshakeTimeout: 10 * time.Second,
|
||||||
if err != nil {
|
TLSClientConfig: client.Transport.(*Transport).Base.TLSClientConfig,
|
||||||
span.RecordError(err)
|
Jar: client.Jar,
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
resp, err := client.Do(formReq)
|
|
||||||
if err != nil {
|
|
||||||
span.RecordError(err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
span.SetStatus(codes.Ok, "form submitted successfully")
|
|
||||||
return resp, nil
|
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("test bug: expected IDP login form")
|
if options.dialerHook != nil {
|
||||||
|
d, u = options.dialerHook(d, u)
|
||||||
|
}
|
||||||
|
conn, resp, err := d.DialContext(options.requestCtx, u.String(), nil)
|
||||||
|
if err != nil {
|
||||||
|
resp.Body.Close()
|
||||||
|
return fmt.Errorf("DialContext: %w", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
return f(conn)
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@ type CommonUpstreamOptions struct {
|
||||||
type CommonUpstreamOption interface {
|
type CommonUpstreamOption interface {
|
||||||
GRPCUpstreamOption
|
GRPCUpstreamOption
|
||||||
HTTPUpstreamOption
|
HTTPUpstreamOption
|
||||||
|
TCPUpstreamOption
|
||||||
}
|
}
|
||||||
|
|
||||||
type commonUpstreamOption func(o *CommonUpstreamOptions)
|
type commonUpstreamOption func(o *CommonUpstreamOptions)
|
||||||
|
@ -25,6 +26,9 @@ func (c commonUpstreamOption) applyGRPC(o *GRPCUpstreamOptions) { c(&o.CommonUps
|
||||||
// applyHTTP implements CommonUpstreamOption.
|
// applyHTTP implements CommonUpstreamOption.
|
||||||
func (c commonUpstreamOption) applyHTTP(o *HTTPUpstreamOptions) { c(&o.CommonUpstreamOptions) }
|
func (c commonUpstreamOption) applyHTTP(o *HTTPUpstreamOptions) { c(&o.CommonUpstreamOptions) }
|
||||||
|
|
||||||
|
// applyTCP implements CommonUpstreamOption.
|
||||||
|
func (c commonUpstreamOption) applyTCP(o *TCPUpstreamOptions) { c(&o.CommonUpstreamOptions) }
|
||||||
|
|
||||||
func WithDisplayName(displayName string) CommonUpstreamOption {
|
func WithDisplayName(displayName string) CommonUpstreamOption {
|
||||||
return commonUpstreamOption(func(o *CommonUpstreamOptions) {
|
return commonUpstreamOption(func(o *CommonUpstreamOptions) {
|
||||||
o.displayName = displayName
|
o.displayName = displayName
|
||||||
|
|
343
internal/testenv/upstreams/tcp.go
Normal file
343
internal/testenv/upstreams/tcp.go
Normal file
|
@ -0,0 +1,343 @@
|
||||||
|
package upstreams
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/cookiejar"
|
||||||
|
"net/http/httptrace"
|
||||||
|
"net/url"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||||
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
|
"go.opentelemetry.io/otel/attribute"
|
||||||
|
oteltrace "go.opentelemetry.io/otel/trace"
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TCPUpstream interface {
|
||||||
|
testenv.Upstream
|
||||||
|
|
||||||
|
Handle(fn func(context.Context, net.Conn) error)
|
||||||
|
|
||||||
|
Dial(r testenv.Route, fn func(context.Context, net.Conn) error, opts ...RequestOption) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type TCPUpstreamOptions struct {
|
||||||
|
CommonUpstreamOptions
|
||||||
|
}
|
||||||
|
|
||||||
|
type TCPUpstreamOption interface {
|
||||||
|
applyTCP(*TCPUpstreamOptions)
|
||||||
|
}
|
||||||
|
|
||||||
|
type tcpUpstream struct {
|
||||||
|
TCPUpstreamOptions
|
||||||
|
testenv.Aggregate
|
||||||
|
serverPort values.MutableValue[int]
|
||||||
|
serverHandler func(context.Context, net.Conn) error
|
||||||
|
|
||||||
|
serverTracerProvider values.MutableValue[oteltrace.TracerProvider]
|
||||||
|
clientTracerProvider values.MutableValue[oteltrace.TracerProvider]
|
||||||
|
clientTracer values.Value[oteltrace.Tracer]
|
||||||
|
}
|
||||||
|
|
||||||
|
func TCP(opts ...TCPUpstreamOption) TCPUpstream {
|
||||||
|
options := TCPUpstreamOptions{
|
||||||
|
CommonUpstreamOptions: CommonUpstreamOptions{
|
||||||
|
displayName: "TCP Upstream",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, op := range opts {
|
||||||
|
op.applyTCP(&options)
|
||||||
|
}
|
||||||
|
up := &tcpUpstream{
|
||||||
|
TCPUpstreamOptions: options,
|
||||||
|
serverPort: values.Deferred[int](),
|
||||||
|
|
||||||
|
serverTracerProvider: values.Deferred[oteltrace.TracerProvider](),
|
||||||
|
clientTracerProvider: values.Deferred[oteltrace.TracerProvider](),
|
||||||
|
}
|
||||||
|
up.clientTracer = values.Bind(up.clientTracerProvider, func(tp oteltrace.TracerProvider) oteltrace.Tracer {
|
||||||
|
return tp.Tracer(trace.PomeriumCoreTracer)
|
||||||
|
})
|
||||||
|
up.RecordCaller()
|
||||||
|
return up
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dial implements TCPUpstream.
|
||||||
|
func (t *tcpUpstream) Dial(r testenv.Route, clientHandler func(context.Context, net.Conn) error, opts ...RequestOption) error {
|
||||||
|
options := RequestOptions{
|
||||||
|
requestCtx: t.Env().Context(),
|
||||||
|
dialProtocol: DialHTTP1,
|
||||||
|
}
|
||||||
|
options.apply(opts...)
|
||||||
|
u, err := url.Parse(r.URL().Value())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, span := t.clientTracer.Value().Start(options.requestCtx, "tcpUpstream.Do", oteltrace.WithAttributes(
|
||||||
|
attribute.String("protocol", string(options.dialProtocol)),
|
||||||
|
attribute.String("url", u.String()),
|
||||||
|
))
|
||||||
|
if options.path != "" || options.query != nil {
|
||||||
|
u = u.ResolveReference(&url.URL{
|
||||||
|
Path: options.path,
|
||||||
|
RawQuery: options.query.Encode(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if options.trace != nil {
|
||||||
|
ctx = httptrace.WithClientTrace(ctx, options.trace)
|
||||||
|
}
|
||||||
|
options.requestCtx = ctx
|
||||||
|
defer span.End()
|
||||||
|
|
||||||
|
var remoteConn *tls.Conn
|
||||||
|
remoteWriter := make(chan *io.PipeWriter, 1)
|
||||||
|
|
||||||
|
connectURL := &url.URL{Scheme: "https", Host: u.Host, Path: u.Path}
|
||||||
|
|
||||||
|
var getClientFn func(context.Context) *http.Client
|
||||||
|
var newRequestFn func(ctx context.Context) (*http.Request, error)
|
||||||
|
switch options.dialProtocol {
|
||||||
|
case DialHTTP1:
|
||||||
|
getClientFn = t.h1Dialer(&options, connectURL, &remoteConn)
|
||||||
|
newRequestFn = func(ctx context.Context) (*http.Request, error) {
|
||||||
|
req := (&http.Request{
|
||||||
|
Method: http.MethodConnect,
|
||||||
|
URL: connectURL,
|
||||||
|
Host: u.Host,
|
||||||
|
}).WithContext(ctx)
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
case DialHTTP2:
|
||||||
|
getClientFn = t.h2Dialer(&options, connectURL, &remoteConn, remoteWriter)
|
||||||
|
newRequestFn = func(ctx context.Context) (*http.Request, error) {
|
||||||
|
req := (&http.Request{
|
||||||
|
Method: http.MethodConnect,
|
||||||
|
URL: connectURL,
|
||||||
|
Host: u.Host,
|
||||||
|
Proto: "HTTP/2",
|
||||||
|
}).WithContext(ctx)
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
case DialHTTP3:
|
||||||
|
panic("not implemented")
|
||||||
|
}
|
||||||
|
resp, err := doAuthenticatedRequest(options.requestCtx, newRequestFn, getClientFn, &options)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
resp.Body.Close()
|
||||||
|
return errors.New(resp.Status)
|
||||||
|
}
|
||||||
|
if resp.Request.URL.Path == "/oidc/auth" {
|
||||||
|
if options.authenticateAs == "" {
|
||||||
|
return errors.New("test bug: unexpected IDP redirect; missing AuthenticateAs option to Dial()")
|
||||||
|
}
|
||||||
|
return errors.New("internal test bug: unexpected IDP redirect")
|
||||||
|
}
|
||||||
|
|
||||||
|
var w io.WriteCloser = remoteConn
|
||||||
|
if options.dialProtocol == DialHTTP2 {
|
||||||
|
w = <-remoteWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := NewRWConn(resp.Body, w)
|
||||||
|
defer conn.Close()
|
||||||
|
return clientHandler(resp.Request.Context(), conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tcpUpstream) h1Dialer(
|
||||||
|
options *RequestOptions,
|
||||||
|
connectURL *url.URL,
|
||||||
|
remoteConn **tls.Conn,
|
||||||
|
) func(context.Context) *http.Client {
|
||||||
|
jar, _ := cookiejar.New(nil)
|
||||||
|
return func(context.Context) *http.Client {
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
RootCAs: t.Env().ServerCAs(),
|
||||||
|
Certificates: options.clientCerts,
|
||||||
|
NextProtos: []string{"http/1.1"},
|
||||||
|
}
|
||||||
|
client := &http.Client{
|
||||||
|
Transport: &http.Transport{
|
||||||
|
DisableKeepAlives: true,
|
||||||
|
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
if *remoteConn != nil {
|
||||||
|
(*remoteConn).Close()
|
||||||
|
*remoteConn = nil
|
||||||
|
}
|
||||||
|
dialer := &tls.Dialer{
|
||||||
|
Config: tlsConfig,
|
||||||
|
}
|
||||||
|
cc, err := dialer.DialContext(ctx, network, addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("%w: %w", ErrRetry, err)
|
||||||
|
}
|
||||||
|
protocol := cc.(*tls.Conn).ConnectionState().NegotiatedProtocol
|
||||||
|
if protocol != "http/1.1" {
|
||||||
|
cc.Close()
|
||||||
|
return nil, fmt.Errorf("error: unexpected TLS protocol: %s", protocol)
|
||||||
|
}
|
||||||
|
*remoteConn = cc.(*tls.Conn)
|
||||||
|
return cc, nil
|
||||||
|
},
|
||||||
|
TLSClientConfig: tlsConfig, // important
|
||||||
|
},
|
||||||
|
CheckRedirect: func(req *http.Request, _ []*http.Request) error {
|
||||||
|
if req.URL.String() == connectURL.String() && req.Method == http.MethodGet {
|
||||||
|
req.Method = http.MethodConnect
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
Jar: jar,
|
||||||
|
}
|
||||||
|
return client
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tcpUpstream) h2Dialer(
|
||||||
|
options *RequestOptions,
|
||||||
|
connectURL *url.URL,
|
||||||
|
remoteConn **tls.Conn,
|
||||||
|
writer chan<- *io.PipeWriter,
|
||||||
|
) func(context.Context) *http.Client {
|
||||||
|
jar, _ := cookiejar.New(nil)
|
||||||
|
return func(context.Context) *http.Client {
|
||||||
|
h1 := &http.Transport{
|
||||||
|
ForceAttemptHTTP2: true,
|
||||||
|
DisableKeepAlives: true,
|
||||||
|
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
if *remoteConn != nil {
|
||||||
|
(*remoteConn).Close()
|
||||||
|
*remoteConn = nil
|
||||||
|
}
|
||||||
|
dialer := &tls.Dialer{
|
||||||
|
Config: &tls.Config{
|
||||||
|
RootCAs: t.Env().ServerCAs(),
|
||||||
|
Certificates: options.clientCerts,
|
||||||
|
NextProtos: []string{"h2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cc, err := dialer.DialContext(ctx, network, addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("%w: %w", ErrRetry, err)
|
||||||
|
}
|
||||||
|
protocol := cc.(*tls.Conn).ConnectionState().NegotiatedProtocol
|
||||||
|
if protocol != "h2" {
|
||||||
|
cc.Close()
|
||||||
|
return nil, fmt.Errorf("error: unexpected TLS protocol: %s", protocol)
|
||||||
|
}
|
||||||
|
*remoteConn = cc.(*tls.Conn)
|
||||||
|
|
||||||
|
return cc, nil
|
||||||
|
},
|
||||||
|
TLSClientConfig: &tls.Config{
|
||||||
|
RootCAs: t.Env().ServerCAs(),
|
||||||
|
Certificates: options.clientCerts,
|
||||||
|
NextProtos: []string{"h2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if err := http2.ConfigureTransport(h1); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
client := &http.Client{
|
||||||
|
Transport: h1,
|
||||||
|
CheckRedirect: func(req *http.Request, _ []*http.Request) error {
|
||||||
|
if req.URL.String() == connectURL.String() && req.Method == http.MethodGet {
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
req.Method = http.MethodConnect
|
||||||
|
req.Body = pr
|
||||||
|
req.ContentLength = -1
|
||||||
|
writer <- pw
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
Jar: jar,
|
||||||
|
}
|
||||||
|
return client
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle implements TCPUpstream.
|
||||||
|
func (t *tcpUpstream) Handle(fn func(context.Context, net.Conn) error) {
|
||||||
|
t.serverHandler = fn
|
||||||
|
}
|
||||||
|
|
||||||
|
// Port implements TCPUpstream.
|
||||||
|
func (t *tcpUpstream) Addr() values.Value[string] {
|
||||||
|
return values.Bind(t.serverPort, func(port int) string {
|
||||||
|
return fmt.Sprintf("%s:%d", t.Env().Host(), port)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Route implements TCPUpstream.
|
||||||
|
func (t *tcpUpstream) Route() testenv.RouteStub {
|
||||||
|
r := &testenv.TCPRoute{}
|
||||||
|
r.To(values.Bind(t.serverPort, func(port int) string {
|
||||||
|
return fmt.Sprintf("tcp://%s:%d", t.Env().Host(), port)
|
||||||
|
}))
|
||||||
|
t.Add(r)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run implements TCPUpstream.
|
||||||
|
func (t *tcpUpstream) Run(ctx context.Context) error {
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
listener, err := (&net.ListenConfig{}).Listen(ctx, "tcp", fmt.Sprintf("%s:0", t.Env().Host()))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
context.AfterFunc(ctx, func() {
|
||||||
|
listener.Close()
|
||||||
|
})
|
||||||
|
t.serverPort.Resolve(listener.Addr().(*net.TCPAddr).Port)
|
||||||
|
if t.serverTracerProviderOverride != nil {
|
||||||
|
t.serverTracerProvider.Resolve(t.serverTracerProviderOverride)
|
||||||
|
} else {
|
||||||
|
t.serverTracerProvider.Resolve(trace.NewTracerProvider(ctx, t.displayName))
|
||||||
|
}
|
||||||
|
if t.clientTracerProviderOverride != nil {
|
||||||
|
t.clientTracerProvider.Resolve(t.clientTracerProviderOverride)
|
||||||
|
} else {
|
||||||
|
t.clientTracerProvider.Resolve(trace.NewTracerProvider(ctx, "TCP Client"))
|
||||||
|
}
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
defer wg.Wait()
|
||||||
|
|
||||||
|
for {
|
||||||
|
conn, err := listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, net.ErrClosed) {
|
||||||
|
cancel()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
if err := t.serverHandler(ctx, conn); err != nil {
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
panic("server handler error: " + err.Error())
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ testenv.Upstream = (*tcpUpstream)(nil)
|
||||||
|
_ TCPUpstream = (*tcpUpstream)(nil)
|
||||||
|
)
|
195
internal/testenv/upstreams/util.go
Normal file
195
internal/testenv/upstreams/util.go
Normal file
|
@ -0,0 +1,195 @@
|
||||||
|
package upstreams
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/integration/forms"
|
||||||
|
"github.com/pomerium/pomerium/internal/retry"
|
||||||
|
"go.opentelemetry.io/otel/attribute"
|
||||||
|
"go.opentelemetry.io/otel/codes"
|
||||||
|
semconv "go.opentelemetry.io/otel/semconv/v1.26.0"
|
||||||
|
oteltrace "go.opentelemetry.io/otel/trace"
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrRetry = errors.New("error")
|
||||||
|
|
||||||
|
func doAuthenticatedRequest(
|
||||||
|
ctx context.Context,
|
||||||
|
newRequest func(context.Context) (*http.Request, error),
|
||||||
|
getClient func(context.Context) *http.Client,
|
||||||
|
options *RequestOptions,
|
||||||
|
) (*http.Response, error) {
|
||||||
|
var resp *http.Response
|
||||||
|
resendCount := 0
|
||||||
|
client := getClient(ctx)
|
||||||
|
|
||||||
|
if err := retry.Retry(ctx, "http", func(ctx context.Context) error {
|
||||||
|
req, err := newRequest(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return retry.NewTerminalError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch body := options.body.(type) {
|
||||||
|
case string:
|
||||||
|
req.Body = io.NopCloser(strings.NewReader(body))
|
||||||
|
case []byte:
|
||||||
|
req.Body = io.NopCloser(bytes.NewReader(body))
|
||||||
|
case io.Reader:
|
||||||
|
req.Body = io.NopCloser(body)
|
||||||
|
case proto.Message:
|
||||||
|
buf, err := proto.Marshal(body)
|
||||||
|
if err != nil {
|
||||||
|
return retry.NewTerminalError(err)
|
||||||
|
}
|
||||||
|
req.Body = io.NopCloser(bytes.NewReader(buf))
|
||||||
|
req.Header.Set("Content-Type", "application/octet-stream")
|
||||||
|
default:
|
||||||
|
buf, err := json.Marshal(body)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("unsupported body type: %T", body))
|
||||||
|
}
|
||||||
|
req.Body = io.NopCloser(bytes.NewReader(buf))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
case nil:
|
||||||
|
}
|
||||||
|
|
||||||
|
if options.headers != nil && req.Header == nil {
|
||||||
|
req.Header = http.Header{}
|
||||||
|
}
|
||||||
|
for k, v := range options.headers {
|
||||||
|
req.Header.Add(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
if options.authenticateAs != "" {
|
||||||
|
resp, err = authenticateFlow(ctx, client, req, options.authenticateAs, true) //nolint:bodyclose
|
||||||
|
} else {
|
||||||
|
resp, err = client.Do(req) //nolint:bodyclose
|
||||||
|
}
|
||||||
|
// retry on connection refused
|
||||||
|
span := oteltrace.SpanFromContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
span.RecordError(err)
|
||||||
|
var opErr *net.OpError
|
||||||
|
if errors.As(err, &opErr) && opErr.Op == "dial" && opErr.Err.Error() == "connect: connection refused" {
|
||||||
|
span.AddEvent("Retrying on dial error")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return retry.NewTerminalError(err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode/100 == 5 {
|
||||||
|
resendCount++
|
||||||
|
_, _ = io.ReadAll(resp.Body)
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
span.SetAttributes(semconv.HTTPRequestResendCount(resendCount))
|
||||||
|
span.AddEvent("Retrying on 5xx error", oteltrace.WithAttributes(
|
||||||
|
attribute.String("status", resp.Status),
|
||||||
|
))
|
||||||
|
return errors.New(http.StatusText(resp.StatusCode))
|
||||||
|
}
|
||||||
|
span.SetStatus(codes.Ok, "request completed successfully")
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
retry.WithInitialInterval(1*time.Millisecond),
|
||||||
|
retry.WithMaxInterval(100*time.Millisecond),
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func authenticateFlow(ctx context.Context, client *http.Client, req *http.Request, email string, checkLocation bool) (*http.Response, error) {
|
||||||
|
span := oteltrace.SpanFromContext(ctx)
|
||||||
|
var res *http.Response
|
||||||
|
originalHostname := req.URL.Hostname()
|
||||||
|
res, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
span.RecordError(err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
location := res.Request.URL
|
||||||
|
if checkLocation && location.Hostname() == originalHostname {
|
||||||
|
// already authenticated
|
||||||
|
span.SetStatus(codes.Ok, "already authenticated")
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
fs := forms.Parse(res.Body)
|
||||||
|
_, _ = io.ReadAll(res.Body)
|
||||||
|
_ = res.Body.Close()
|
||||||
|
if len(fs) > 0 {
|
||||||
|
f := fs[0]
|
||||||
|
f.Inputs["email"] = email
|
||||||
|
f.Inputs["token_expiration"] = strconv.Itoa(int((time.Hour * 24).Seconds()))
|
||||||
|
span.AddEvent("submitting form", oteltrace.WithAttributes(attribute.String("location", location.String())))
|
||||||
|
formReq, err := f.NewRequestWithContext(ctx, location)
|
||||||
|
if err != nil {
|
||||||
|
span.RecordError(err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
resp, err := client.Do(formReq)
|
||||||
|
if err != nil {
|
||||||
|
span.RecordError(err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
span.SetStatus(codes.Ok, "form submitted successfully")
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("test bug: expected IDP login form")
|
||||||
|
}
|
||||||
|
|
||||||
|
type rwConn struct {
|
||||||
|
serverReader io.ReadCloser
|
||||||
|
serverWriter io.WriteCloser
|
||||||
|
|
||||||
|
net.Conn
|
||||||
|
remote net.Conn
|
||||||
|
|
||||||
|
closeOnce sync.Once
|
||||||
|
wg *sync.WaitGroup
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRWConn(reader io.ReadCloser, writer io.WriteCloser) net.Conn {
|
||||||
|
rwc := &rwConn{
|
||||||
|
serverReader: reader,
|
||||||
|
serverWriter: writer,
|
||||||
|
wg: &sync.WaitGroup{},
|
||||||
|
}
|
||||||
|
|
||||||
|
rwc.Conn, rwc.remote = net.Pipe()
|
||||||
|
rwc.wg.Add(2)
|
||||||
|
go func() {
|
||||||
|
defer rwc.wg.Done()
|
||||||
|
_, _ = io.Copy(rwc.remote, rwc.serverReader)
|
||||||
|
rwc.remote.Close()
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
defer rwc.wg.Done()
|
||||||
|
_, _ = io.Copy(rwc.serverWriter, rwc.remote)
|
||||||
|
rwc.serverWriter.Close()
|
||||||
|
}()
|
||||||
|
return rwc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rwc *rwConn) Close() error {
|
||||||
|
var err error
|
||||||
|
rwc.closeOnce.Do(func() {
|
||||||
|
readerErr := rwc.serverReader.Close()
|
||||||
|
localErr := rwc.Conn.Close()
|
||||||
|
rwc.wg.Wait()
|
||||||
|
err = errors.Join(localErr, readerErr)
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ net.Conn = (*rwConn)(nil)
|
|
@ -6,7 +6,7 @@ import (
|
||||||
|
|
||||||
"github.com/minio/minio-go/v7"
|
"github.com/minio/minio-go/v7"
|
||||||
"github.com/minio/minio-go/v7/pkg/credentials"
|
"github.com/minio/minio-go/v7/pkg/credentials"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
"github.com/testcontainers/testcontainers-go"
|
"github.com/testcontainers/testcontainers-go"
|
||||||
"github.com/testcontainers/testcontainers-go/wait"
|
"github.com/testcontainers/testcontainers-go/wait"
|
||||||
oteltrace "go.opentelemetry.io/otel/trace"
|
oteltrace "go.opentelemetry.io/otel/trace"
|
||||||
|
|
|
@ -8,7 +8,7 @@ import (
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/jackc/pgx/v5"
|
"github.com/jackc/pgx/v5"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
"github.com/testcontainers/testcontainers-go"
|
"github.com/testcontainers/testcontainers-go"
|
||||||
"github.com/testcontainers/testcontainers-go/wait"
|
"github.com/testcontainers/testcontainers-go/wait"
|
||||||
oteltrace "go.opentelemetry.io/otel/trace"
|
oteltrace "go.opentelemetry.io/otel/trace"
|
||||||
|
|
|
@ -16,7 +16,7 @@ import (
|
||||||
"unique"
|
"unique"
|
||||||
|
|
||||||
gocmp "github.com/google/go-cmp/cmp"
|
gocmp "github.com/google/go-cmp/cmp"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
oteltrace "go.opentelemetry.io/otel/trace"
|
oteltrace "go.opentelemetry.io/otel/trace"
|
||||||
|
|
|
@ -26,11 +26,11 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/events"
|
"github.com/pomerium/pomerium/internal/events"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/registry"
|
"github.com/pomerium/pomerium/internal/registry"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
||||||
"github.com/pomerium/pomerium/internal/version"
|
"github.com/pomerium/pomerium/internal/version"
|
||||||
derivecert_config "github.com/pomerium/pomerium/pkg/derivecert/config"
|
derivecert_config "github.com/pomerium/pomerium/pkg/derivecert/config"
|
||||||
"github.com/pomerium/pomerium/pkg/envoy"
|
"github.com/pomerium/pomerium/pkg/envoy"
|
||||||
"github.com/pomerium/pomerium/pkg/envoy/files"
|
"github.com/pomerium/pomerium/pkg/envoy/files"
|
||||||
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
"github.com/pomerium/pomerium/proxy"
|
"github.com/pomerium/pomerium/proxy"
|
||||||
oteltrace "go.opentelemetry.io/otel/trace"
|
oteltrace "go.opentelemetry.io/otel/trace"
|
||||||
)
|
)
|
||||||
|
|
|
@ -713,7 +713,7 @@ func TestSharedResourceMonitor(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBootstrapConfig(t *testing.T) {
|
func TestBootstrapConfig(t *testing.T) {
|
||||||
b := envoyconfig.New("localhost:1111", "localhost:2222", "localhost:3333", filemgr.NewManager(), nil)
|
b := envoyconfig.New("localhost:1111", "localhost:2222", "localhost:3333", filemgr.NewManager(), nil, true)
|
||||||
testEnvoyPid := 99
|
testEnvoyPid := 99
|
||||||
tempDir := t.TempDir()
|
tempDir := t.TempDir()
|
||||||
monitor, err := NewSharedResourceMonitor(context.Background(), config.NewStaticSource(nil), tempDir, WithCgroupDriver(&cgroupV2Driver{
|
monitor, err := NewSharedResourceMonitor(context.Background(), config.NewStaticSource(nil), tempDir, WithCgroupDriver(&cgroupV2Driver{
|
||||||
|
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"google.golang.org/protobuf/types/known/anypb"
|
"google.golang.org/protobuf/proto"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -14,7 +14,7 @@ func GetChangeSet(current, target RecordSetBundle, cmpFn RecordCompareFn) []*Rec
|
||||||
cs := &changeSet{now: timestamppb.Now()}
|
cs := &changeSet{now: timestamppb.Now()}
|
||||||
|
|
||||||
for _, rec := range current.GetRemoved(target).Flatten() {
|
for _, rec := range current.GetRemoved(target).Flatten() {
|
||||||
cs.Remove(rec.GetType(), rec.GetId())
|
cs.Remove(rec)
|
||||||
}
|
}
|
||||||
for _, rec := range current.GetModified(target, cmpFn).Flatten() {
|
for _, rec := range current.GetModified(target, cmpFn).Flatten() {
|
||||||
cs.Upsert(rec)
|
cs.Upsert(rec)
|
||||||
|
@ -33,13 +33,10 @@ type changeSet struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove adds a record to the change set.
|
// Remove adds a record to the change set.
|
||||||
func (cs *changeSet) Remove(typ string, id string) {
|
func (cs *changeSet) Remove(record *Record) {
|
||||||
cs.updates = append(cs.updates, &Record{
|
record = proto.Clone(record).(*Record)
|
||||||
Type: typ,
|
record.DeletedAt = cs.now
|
||||||
Id: id,
|
cs.updates = append(cs.updates, record)
|
||||||
DeletedAt: cs.now,
|
|
||||||
Data: &anypb.Any{TypeUrl: typ},
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Upsert adds a record to the change set.
|
// Upsert adds a record to the change set.
|
||||||
|
|
52
pkg/grpc/databroker/changeset_test.go
Normal file
52
pkg/grpc/databroker/changeset_test.go
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
package databroker_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"google.golang.org/protobuf/testing/protocmp"
|
||||||
|
"google.golang.org/protobuf/types/known/structpb"
|
||||||
|
|
||||||
|
"github.com/pomerium/datasource/pkg/directory"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetChangeset(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
rsb1 := databroker.RecordSetBundle{}
|
||||||
|
rsb2 := databroker.RecordSetBundle{}
|
||||||
|
updates := databroker.GetChangeSet(rsb1, rsb2, func(record1, record2 *databroker.Record) bool {
|
||||||
|
return cmp.Equal(record1, record2, protocmp.Transform())
|
||||||
|
})
|
||||||
|
assert.Len(t, updates, 0)
|
||||||
|
|
||||||
|
rsb1 = databroker.RecordSetBundle{}
|
||||||
|
rsb1.Add(&databroker.Record{
|
||||||
|
Type: directory.UserRecordType,
|
||||||
|
Id: "user-1",
|
||||||
|
Data: protoutil.NewAny(mustNewStruct(map[string]any{
|
||||||
|
"email": "user-1@example.com",
|
||||||
|
})),
|
||||||
|
})
|
||||||
|
rsb2 = databroker.RecordSetBundle{}
|
||||||
|
updates = databroker.GetChangeSet(rsb1, rsb2, func(record1, record2 *databroker.Record) bool {
|
||||||
|
return cmp.Equal(record1, record2, protocmp.Transform())
|
||||||
|
})
|
||||||
|
if assert.Len(t, updates, 1) {
|
||||||
|
assert.Equal(t, directory.UserRecordType, updates[0].GetType())
|
||||||
|
assert.Equal(t, "type.googleapis.com/google.protobuf.Struct", updates[0].GetData().GetTypeUrl(),
|
||||||
|
"should preserve data type")
|
||||||
|
assert.NotNil(t, updates[0].GetDeletedAt())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustNewStruct(m map[string]any) *structpb.Struct {
|
||||||
|
s, err := structpb.NewStruct(m)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
|
@ -3,14 +3,12 @@ package databroker
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
status "google.golang.org/grpc/status"
|
status "google.golang.org/grpc/status"
|
||||||
"google.golang.org/protobuf/encoding/protojson"
|
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
structpb "google.golang.org/protobuf/types/known/structpb"
|
structpb "google.golang.org/protobuf/types/known/structpb"
|
||||||
|
|
||||||
|
@ -53,34 +51,6 @@ func Get(ctx context.Context, client DataBrokerServiceClient, object recordObjec
|
||||||
return res.GetRecord().GetData().UnmarshalTo(object)
|
return res.GetRecord().GetData().UnmarshalTo(object)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetViaJSON gets a record from the databroker, marshals it to JSON, and then unmarshals it to the given type.
|
|
||||||
func GetViaJSON[T any](ctx context.Context, client DataBrokerServiceClient, recordType, recordID string) (*T, error) {
|
|
||||||
res, err := client.Get(ctx, &GetRequest{
|
|
||||||
Type: recordType,
|
|
||||||
Id: recordID,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
msg, err := res.GetRecord().GetData().UnmarshalNew()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
bs, err := protojson.Marshal(msg)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var obj T
|
|
||||||
err = json.Unmarshal(bs, &obj)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &obj, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Put puts a record into the databroker.
|
// Put puts a record into the databroker.
|
||||||
func Put(ctx context.Context, client DataBrokerServiceClient, objects ...recordObject) (*PutResponse, error) {
|
func Put(ctx context.Context, client DataBrokerServiceClient, objects ...recordObject) (*PutResponse, error) {
|
||||||
records := make([]*Record, len(objects))
|
records := make([]*Record, len(objects))
|
||||||
|
|
|
@ -16,11 +16,11 @@ import (
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
"github.com/pomerium/pomerium/internal/version"
|
"github.com/pomerium/pomerium/internal/version"
|
||||||
"github.com/pomerium/pomerium/pkg/identity/identity"
|
"github.com/pomerium/pomerium/pkg/identity/identity"
|
||||||
"github.com/pomerium/pomerium/pkg/identity/oauth"
|
"github.com/pomerium/pomerium/pkg/identity/oauth"
|
||||||
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Name identifies the generic OpenID Connect provider.
|
// Name identifies the generic OpenID Connect provider.
|
||||||
|
|
|
@ -107,3 +107,6 @@ func (cache *globalCache) set(expiry time.Time, key, value []byte) {
|
||||||
cache.fastcache.Set(key, item)
|
cache.fastcache.Set(key, item)
|
||||||
cache.mu.Unlock()
|
cache.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GlobalCache is a global cache with a TTL of one minute.
|
||||||
|
var GlobalCache = NewGlobalCache(time.Minute)
|
||||||
|
|
|
@ -48,7 +48,7 @@ func TestQueryTracing(t *testing.T) {
|
||||||
snippets.WaitStartupComplete(env)
|
snippets.WaitStartupComplete(env)
|
||||||
|
|
||||||
resp, err := up.Get(route, upstreams.AuthenticateAs("user@example.com"), upstreams.Path("/foo"))
|
resp, err := up.Get(route, upstreams.AuthenticateAs("user@example.com"), upstreams.Path("/foo"))
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
io.ReadAll(resp.Body)
|
io.ReadAll(resp.Body)
|
||||||
resp.Body.Close()
|
resp.Body.Close()
|
||||||
|
|
||||||
|
|
|
@ -15,6 +15,7 @@ import (
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -222,3 +223,114 @@ func MarshalQueryResponse(res *databroker.QueryResponse) ([]byte, error) {
|
||||||
Deterministic: true,
|
Deterministic: true,
|
||||||
}).Marshal(res)
|
}).Marshal(res)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetDataBrokerRecord uses a querier to get a databroker record.
|
||||||
|
func GetDataBrokerRecord(
|
||||||
|
ctx context.Context,
|
||||||
|
recordType string,
|
||||||
|
recordID string,
|
||||||
|
lowestRecordVersion uint64,
|
||||||
|
) (*databroker.Record, error) {
|
||||||
|
q := GetQuerier(ctx)
|
||||||
|
|
||||||
|
req := &databroker.QueryRequest{
|
||||||
|
Type: recordType,
|
||||||
|
Limit: 1,
|
||||||
|
}
|
||||||
|
req.SetFilterByIDOrIndex(recordID)
|
||||||
|
|
||||||
|
res, err := q.Query(ctx, req, grpc.WaitForReady(true))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(res.GetRecords()) == 0 {
|
||||||
|
return nil, ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
// if the current record version is less than the lowest we'll accept, invalidate the cache
|
||||||
|
if res.GetRecords()[0].GetVersion() < lowestRecordVersion {
|
||||||
|
q.InvalidateCache(ctx, req)
|
||||||
|
} else {
|
||||||
|
return res.GetRecords()[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// retry with an up to date cache
|
||||||
|
res, err = q.Query(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(res.GetRecords()) == 0 {
|
||||||
|
return nil, ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return res.GetRecords()[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDataBrokerMessage gets a databroker record and converts it into the message type.
|
||||||
|
func GetDataBrokerMessage[T any, TMessage interface {
|
||||||
|
*T
|
||||||
|
proto.Message
|
||||||
|
}](
|
||||||
|
ctx context.Context,
|
||||||
|
recordID string,
|
||||||
|
lowestRecordVersion uint64,
|
||||||
|
) (TMessage, error) {
|
||||||
|
var msg T
|
||||||
|
|
||||||
|
record, err := GetDataBrokerRecord(ctx, grpcutil.GetTypeURL(TMessage(&msg)), recordID, lowestRecordVersion)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = record.GetData().UnmarshalTo(TMessage(&msg))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return TMessage(&msg), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDataBrokerObjectViaJSON gets a databroker record and converts it into the object type by going through protojson.
|
||||||
|
func GetDataBrokerObjectViaJSON[T any](
|
||||||
|
ctx context.Context,
|
||||||
|
recordType string,
|
||||||
|
recordID string,
|
||||||
|
lowestRecordVersion uint64,
|
||||||
|
) (*T, error) {
|
||||||
|
record, err := GetDataBrokerRecord(ctx, recordType, recordID, lowestRecordVersion)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, err := record.GetData().UnmarshalNew()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
bs, err := protojson.Marshal(msg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var obj T
|
||||||
|
err = json.Unmarshal(bs, &obj)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &obj, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateCacheForDataBrokerRecords invalidates the cache of the querier for the databroker records.
|
||||||
|
func InvalidateCacheForDataBrokerRecords(
|
||||||
|
ctx context.Context,
|
||||||
|
records ...*databroker.Record,
|
||||||
|
) {
|
||||||
|
for _, record := range records {
|
||||||
|
q := &databroker.QueryRequest{
|
||||||
|
Type: record.GetType(),
|
||||||
|
Limit: 1,
|
||||||
|
}
|
||||||
|
q.SetFilterByIDOrIndex(record.GetId())
|
||||||
|
GetQuerier(ctx).InvalidateCache(ctx, q)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
101
pkg/storage/querier_test.go
Normal file
101
pkg/storage/querier_test.go
Normal file
|
@ -0,0 +1,101 @@
|
||||||
|
package storage_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"google.golang.org/protobuf/testing/protocmp"
|
||||||
|
"google.golang.org/protobuf/types/known/structpb"
|
||||||
|
|
||||||
|
"github.com/pomerium/datasource/pkg/directory"
|
||||||
|
"github.com/pomerium/pomerium/internal/testutil"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||||
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetDataBrokerRecord(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||||
|
t.Cleanup(clearTimeout)
|
||||||
|
|
||||||
|
for _, tc := range []struct {
|
||||||
|
name string
|
||||||
|
recordVersion, queryVersion uint64
|
||||||
|
underlyingQueryCount, cachedQueryCount int
|
||||||
|
}{
|
||||||
|
{"cached", 1, 1, 1, 2},
|
||||||
|
{"invalidated", 1, 2, 3, 4},
|
||||||
|
} {
|
||||||
|
tc := tc
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
s1 := &session.Session{Id: "s1", Version: fmt.Sprint(tc.recordVersion)}
|
||||||
|
|
||||||
|
sq := storage.NewStaticQuerier(s1)
|
||||||
|
cq := storage.NewCachingQuerier(sq, storage.NewGlobalCache(time.Minute))
|
||||||
|
qctx := storage.WithQuerier(ctx, cq)
|
||||||
|
|
||||||
|
s, err := storage.GetDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, s)
|
||||||
|
|
||||||
|
s, err = storage.GetDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, s)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetDataBrokerMessage(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := testutil.GetContext(t, time.Minute)
|
||||||
|
|
||||||
|
s1 := &session.Session{Id: "s1"}
|
||||||
|
sq := storage.NewStaticQuerier(s1)
|
||||||
|
cq := storage.NewCachingQuerier(sq, storage.NewGlobalCache(time.Minute))
|
||||||
|
qctx := storage.WithQuerier(ctx, cq)
|
||||||
|
|
||||||
|
s2, err := storage.GetDataBrokerMessage[session.Session](qctx, "s1", 0)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Empty(t, cmp.Diff(s1, s2, protocmp.Transform()))
|
||||||
|
|
||||||
|
_, err = storage.GetDataBrokerMessage[session.Session](qctx, "s2", 0)
|
||||||
|
assert.ErrorIs(t, err, storage.ErrNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetDataBrokerObjectViaJSON(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := testutil.GetContext(t, time.Minute)
|
||||||
|
|
||||||
|
du1 := &directory.User{
|
||||||
|
ID: "u1",
|
||||||
|
Email: "u1@example.com",
|
||||||
|
DisplayName: "User 1!",
|
||||||
|
}
|
||||||
|
sq := storage.NewStaticQuerier(newDirectoryUserRecord(du1))
|
||||||
|
cq := storage.NewCachingQuerier(sq, storage.NewGlobalCache(time.Minute))
|
||||||
|
qctx := storage.WithQuerier(ctx, cq)
|
||||||
|
|
||||||
|
du2, err := storage.GetDataBrokerObjectViaJSON[directory.User](qctx, directory.UserRecordType, "u1", 0)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Empty(t, cmp.Diff(du1, du2, protocmp.Transform()))
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDirectoryUserRecord(directoryUser *directory.User) *databroker.Record {
|
||||||
|
m := map[string]any{}
|
||||||
|
bs, _ := json.Marshal(directoryUser)
|
||||||
|
_ = json.Unmarshal(bs, &m)
|
||||||
|
s, _ := structpb.NewStruct(m)
|
||||||
|
return storage.NewStaticRecord(directory.UserRecordType, s)
|
||||||
|
}
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
|
@ -13,13 +13,13 @@ import (
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
||||||
"github.com/pomerium/pomerium/internal/testenv"
|
"github.com/pomerium/pomerium/internal/testenv"
|
||||||
"github.com/pomerium/pomerium/internal/testenv/scenarios"
|
"github.com/pomerium/pomerium/internal/testenv/scenarios"
|
||||||
"github.com/pomerium/pomerium/internal/testenv/snippets"
|
"github.com/pomerium/pomerium/internal/testenv/snippets"
|
||||||
. "github.com/pomerium/pomerium/internal/testutil/tracetest" //nolint:revive
|
. "github.com/pomerium/pomerium/internal/testutil/tracetest" //nolint:revive
|
||||||
"github.com/pomerium/pomerium/internal/testutil/tracetest/mock_otlptrace"
|
"github.com/pomerium/pomerium/internal/testutil/tracetest/mock_otlptrace"
|
||||||
"github.com/pomerium/pomerium/internal/version"
|
"github.com/pomerium/pomerium/internal/version"
|
||||||
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"go.opentelemetry.io/otel"
|
"go.opentelemetry.io/otel"
|
|
@ -9,8 +9,8 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
||||||
. "github.com/pomerium/pomerium/internal/testutil/tracetest" //nolint:revive
|
. "github.com/pomerium/pomerium/internal/testutil/tracetest" //nolint:revive
|
||||||
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
sdktrace "go.opentelemetry.io/otel/sdk/trace"
|
sdktrace "go.opentelemetry.io/otel/sdk/trace"
|
||||||
oteltrace "go.opentelemetry.io/otel/trace"
|
oteltrace "go.opentelemetry.io/otel/trace"
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"go.opentelemetry.io/otel"
|
"go.opentelemetry.io/otel"
|
||||||
"go.opentelemetry.io/otel/trace/noop"
|
"go.opentelemetry.io/otel/trace/noop"
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
|
@ -9,7 +9,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
|
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
|
||||||
sdktrace "go.opentelemetry.io/otel/sdk/trace"
|
sdktrace "go.opentelemetry.io/otel/sdk/trace"
|
|
@ -9,28 +9,24 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/handlers"
|
"github.com/pomerium/pomerium/internal/handlers"
|
||||||
"github.com/pomerium/pomerium/internal/handlers/webauthn"
|
"github.com/pomerium/pomerium/internal/handlers/webauthn"
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||||
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
"github.com/pomerium/pomerium/pkg/webauthnutil"
|
"github.com/pomerium/pomerium/pkg/webauthnutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (p *Proxy) getSession(ctx context.Context, sessionID string) (s *session.Session, isImpersonated bool, err error) {
|
func (p *Proxy) getSession(ctx context.Context, sessionID string) (s *session.Session, isImpersonated bool, err error) {
|
||||||
client := p.state.Load().dataBrokerClient
|
|
||||||
|
|
||||||
isImpersonated = false
|
isImpersonated = false
|
||||||
s, err = session.Get(ctx, client, sessionID)
|
s, err = storage.GetDataBrokerMessage[session.Session](ctx, sessionID, 0)
|
||||||
if s.GetImpersonateSessionId() != "" {
|
if s.GetImpersonateSessionId() != "" {
|
||||||
s, err = session.Get(ctx, client, s.GetImpersonateSessionId())
|
s, err = storage.GetDataBrokerMessage[session.Session](ctx, s.GetImpersonateSessionId(), 0)
|
||||||
isImpersonated = true
|
isImpersonated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
return s, isImpersonated, err
|
return s, isImpersonated, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Proxy) getUser(ctx context.Context, userID string) (*user.User, error) {
|
func (p *Proxy) getUser(ctx context.Context, userID string) (*user.User, error) {
|
||||||
client := p.state.Load().dataBrokerClient
|
return storage.GetDataBrokerMessage[user.User](ctx, userID, 0)
|
||||||
return user.Get(ctx, client, userID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Proxy) getUserInfoData(r *http.Request) handlers.UserInfoData {
|
func (p *Proxy) getUserInfoData(r *http.Request) handlers.UserInfoData {
|
||||||
|
@ -72,21 +68,16 @@ func (p *Proxy) getUserInfoData(r *http.Request) handlers.UserInfoData {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Proxy) fillEnterpriseUserInfoData(ctx context.Context, data *handlers.UserInfoData) {
|
func (p *Proxy) fillEnterpriseUserInfoData(ctx context.Context, data *handlers.UserInfoData) {
|
||||||
client := p.state.Load().dataBrokerClient
|
record, _ := storage.GetDataBrokerRecord(ctx, "type.googleapis.com/pomerium.config.Config", "dashboard-settings", 0)
|
||||||
|
data.IsEnterprise = record != nil
|
||||||
res, _ := client.Get(ctx, &databroker.GetRequest{
|
|
||||||
Type: "type.googleapis.com/pomerium.config.Config",
|
|
||||||
Id: "dashboard-settings",
|
|
||||||
})
|
|
||||||
data.IsEnterprise = res.GetRecord() != nil
|
|
||||||
if !data.IsEnterprise {
|
if !data.IsEnterprise {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
data.DirectoryUser, _ = databroker.GetViaJSON[directory.User](ctx, client, directory.UserRecordType, data.Session.GetUserId())
|
data.DirectoryUser, _ = storage.GetDataBrokerObjectViaJSON[directory.User](ctx, directory.UserRecordType, data.Session.GetUserId(), 0)
|
||||||
if data.DirectoryUser != nil {
|
if data.DirectoryUser != nil {
|
||||||
for _, groupID := range data.DirectoryUser.GroupIDs {
|
for _, groupID := range data.DirectoryUser.GroupIDs {
|
||||||
directoryGroup, _ := databroker.GetViaJSON[directory.Group](ctx, client, directory.GroupRecordType, groupID)
|
directoryGroup, _ := storage.GetDataBrokerObjectViaJSON[directory.Group](ctx, directory.GroupRecordType, groupID, 0)
|
||||||
if directoryGroup != nil {
|
if directoryGroup != nil {
|
||||||
data.DirectoryGroups = append(data.DirectoryGroups, directoryGroup)
|
data.DirectoryGroups = append(data.DirectoryGroups, directoryGroup)
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,6 +25,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||||
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_getUserInfoData(t *testing.T) {
|
func Test_getUserInfoData(t *testing.T) {
|
||||||
|
@ -65,6 +66,7 @@ func Test_getUserInfoData(t *testing.T) {
|
||||||
proxy, err := New(ctx, &config.Config{Options: opts})
|
proxy, err := New(ctx, &config.Config{Options: opts})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
proxy.state.Load().dataBrokerClient = client
|
proxy.state.Load().dataBrokerClient = client
|
||||||
|
ctx = storage.WithQuerier(ctx, storage.NewQuerier(client))
|
||||||
|
|
||||||
require.NoError(t, databrokerpb.PutMulti(ctx, client,
|
require.NoError(t, databrokerpb.PutMulti(ctx, client,
|
||||||
makeRecord(&session.Session{
|
makeRecord(&session.Session{
|
||||||
|
@ -81,7 +83,7 @@ func Test_getUserInfoData(t *testing.T) {
|
||||||
"group_ids": []any{"G1", "G2", "G3"},
|
"group_ids": []any{"G1", "G2", "G3"},
|
||||||
})))
|
})))
|
||||||
|
|
||||||
r := httptest.NewRequest(http.MethodGet, "/.pomerium/", nil)
|
r := httptest.NewRequestWithContext(ctx, http.MethodGet, "/.pomerium/", nil)
|
||||||
r.Header.Set("Authorization", "Bearer Pomerium-"+encodeSession(t, opts, &sessions.State{
|
r.Header.Set("Authorization", "Bearer Pomerium-"+encodeSession(t, opts, &sessions.State{
|
||||||
ID: "S1",
|
ID: "S1",
|
||||||
}))
|
}))
|
||||||
|
@ -89,7 +91,9 @@ func Test_getUserInfoData(t *testing.T) {
|
||||||
assert.Equal(t, "S1", data.Session.Id)
|
assert.Equal(t, "S1", data.Session.Id)
|
||||||
assert.Equal(t, "U1", data.User.Id)
|
assert.Equal(t, "U1", data.User.Id)
|
||||||
assert.True(t, data.IsEnterprise)
|
assert.True(t, data.IsEnterprise)
|
||||||
assert.Equal(t, []string{"G1", "G2", "G3"}, data.DirectoryUser.GroupIDs)
|
if assert.NotNil(t, data.DirectoryUser) {
|
||||||
|
assert.Equal(t, []string{"G1", "G2", "G3"}, data.DirectoryUser.GroupIDs)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -15,8 +15,8 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/handlers"
|
"github.com/pomerium/pomerium/internal/handlers"
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/middleware"
|
"github.com/pomerium/pomerium/internal/middleware"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
)
|
)
|
||||||
|
|
||||||
// registerDashboardHandlers returns the proxy service's ServeMux
|
// registerDashboardHandlers returns the proxy service's ServeMux
|
||||||
|
|
|
@ -19,8 +19,9 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
||||||
"github.com/pomerium/pomerium/proxy/portal"
|
"github.com/pomerium/pomerium/proxy/portal"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -124,6 +125,8 @@ func (p *Proxy) setHandlers(ctx context.Context, opts *config.Options) error {
|
||||||
r.StrictSlash(true)
|
r.StrictSlash(true)
|
||||||
// dashboard handlers are registered to all routes
|
// dashboard handlers are registered to all routes
|
||||||
r = p.registerDashboardHandlers(r, opts)
|
r = p.registerDashboardHandlers(r, opts)
|
||||||
|
// attach the querier to the context
|
||||||
|
r.Use(p.querierMiddleware)
|
||||||
r.Use(trace.NewHTTPMiddleware(otelhttp.WithTracerProvider(p.tracerProvider)))
|
r.Use(trace.NewHTTPMiddleware(otelhttp.WithTracerProvider(p.tracerProvider)))
|
||||||
|
|
||||||
p.currentRouter.Store(r)
|
p.currentRouter.Store(r)
|
||||||
|
@ -133,3 +136,16 @@ func (p *Proxy) setHandlers(ctx context.Context, opts *config.Options) error {
|
||||||
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
p.currentRouter.Load().ServeHTTP(w, r)
|
p.currentRouter.Load().ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *Proxy) querierMiddleware(h http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
ctx = storage.WithQuerier(ctx, storage.NewCachingQuerier(
|
||||||
|
storage.NewQuerier(p.state.Load().dataBrokerClient),
|
||||||
|
storage.GlobalCache,
|
||||||
|
))
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
h.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/authenticateflow"
|
"github.com/pomerium/pomerium/internal/authenticateflow"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc"
|
"github.com/pomerium/pomerium/pkg/grpc"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
)
|
)
|
||||||
|
|
||||||
var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn)
|
var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn)
|
||||||
|
@ -88,19 +89,16 @@ func newProxyStateFromConfig(ctx context.Context, tracerProvider oteltrace.Trace
|
||||||
|
|
||||||
state.incomingIDPTokenSessionCreator = config.NewIncomingIDPTokenSessionCreator(
|
state.incomingIDPTokenSessionCreator = config.NewIncomingIDPTokenSessionCreator(
|
||||||
func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) {
|
func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) {
|
||||||
res, err := state.dataBrokerClient.Get(ctx, &databroker.GetRequest{
|
return storage.GetDataBrokerRecord(ctx, recordType, recordID, 0)
|
||||||
Type: recordType,
|
|
||||||
Id: recordID,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return res.GetRecord(), nil
|
|
||||||
},
|
},
|
||||||
func(ctx context.Context, records []*databroker.Record) error {
|
func(ctx context.Context, records []*databroker.Record) error {
|
||||||
_, err := state.dataBrokerClient.Put(ctx, &databroker.PutRequest{
|
_, err := state.dataBrokerClient.Put(ctx, &databroker.PutRequest{
|
||||||
Records: records,
|
Records: records,
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
storage.InvalidateCacheForDataBrokerRecords(ctx, records...)
|
||||||
return err
|
return err
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue