diff --git a/proxy/data.go b/proxy/data.go index 8abfc7205..3a63c3d3d 100644 --- a/proxy/data.go +++ b/proxy/data.go @@ -34,12 +34,22 @@ func (p *Proxy) getUser(ctx context.Context, userID string) (*user.User, error) } func (p *Proxy) getUserInfoData(r *http.Request) handlers.UserInfoData { - options := p.currentOptions.Load() + cfg := p.currentConfig.Load() state := p.state.Load() data := handlers.UserInfoData{ CSRFToken: csrf.Token(r), - BrandingOptions: options.BrandingOptions, + BrandingOptions: cfg.Options.BrandingOptions, + } + + if s, err := state.incomingIDPTokenSessionCreator.CreateSession(r.Context(), cfg, nil, r); err == nil { + data.Session = s + data.IsImpersonated = false + + data.User, err = p.getUser(r.Context(), data.Session.GetUserId()) + if err != nil { + data.User = &user.User{Id: data.Session.GetUserId()} + } } ss, err := p.state.Load().sessionStore.LoadSessionState(r) @@ -85,7 +95,7 @@ func (p *Proxy) fillEnterpriseUserInfoData(ctx context.Context, data *handlers.U } func (p *Proxy) getWebauthnState(r *http.Request) (*webauthn.State, error) { - options := p.currentOptions.Load() + options := p.currentConfig.Load().Options state := p.state.Load() ss, err := p.state.Load().sessionStore.LoadSessionState(r) diff --git a/proxy/data_test.go b/proxy/data_test.go index b15a02539..8f9042c12 100644 --- a/proxy/data_test.go +++ b/proxy/data_test.go @@ -9,7 +9,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "go.opentelemetry.io/otel/trace" + "go.opentelemetry.io/otel/trace/noop" "google.golang.org/grpc" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" @@ -17,6 +17,7 @@ import ( "github.com/pomerium/datasource/pkg/directory" "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/databroker" + "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/testutil" configpb "github.com/pomerium/pomerium/pkg/grpc/config" @@ -32,42 +33,64 @@ func Test_getUserInfoData(t *testing.T) { ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) defer clearTimeout() - cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) { - databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New(ctx, trace.NewNoopTracerProvider())) + t.Run("incoming idp token", func(t *testing.T) { + cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) { + databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New(ctx, noop.NewTracerProvider())) + }) + t.Cleanup(func() { cc.Close() }) + + client := databrokerpb.NewDataBrokerServiceClient(cc) + + opts := testOptions(t) + proxy, err := New(ctx, &config.Config{Options: opts}) + require.NoError(t, err) + proxy.state.Load().dataBrokerClient = client + + r := httptest.NewRequest(http.MethodGet, "/.pomerium/", nil) + r.Header.Set(httputil.HeaderPomeriumIDPAccessToken, "ACCESS_TOKEN") + data := proxy.getUserInfoData(r) + assert.NotNil(t, data.Session) + assert.NotNil(t, data.User) }) - t.Cleanup(func() { cc.Close() }) - client := databrokerpb.NewDataBrokerServiceClient(cc) + t.Run("session", func(t *testing.T) { + cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) { + databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New(ctx, noop.NewTracerProvider())) + }) + t.Cleanup(func() { cc.Close() }) - opts := testOptions(t) - proxy, err := New(ctx, &config.Config{Options: opts}) - require.NoError(t, err) - proxy.state.Load().dataBrokerClient = client + client := databrokerpb.NewDataBrokerServiceClient(cc) - require.NoError(t, databrokerpb.PutMulti(ctx, client, - makeRecord(&session.Session{ - Id: "S1", - UserId: "U1", - }), - makeRecord(&user.User{ - Id: "U1", - }), - makeRecord(&configpb.Config{ - Name: "dashboard-settings", - }), - makeStructRecord(directory.UserRecordType, "U1", map[string]any{ - "group_ids": []any{"G1", "G2", "G3"}, - }))) + opts := testOptions(t) + proxy, err := New(ctx, &config.Config{Options: opts}) + require.NoError(t, err) + proxy.state.Load().dataBrokerClient = client - r := httptest.NewRequest(http.MethodGet, "/.pomerium/", nil) - r.Header.Set("Authorization", "Bearer Pomerium-"+encodeSession(t, opts, &sessions.State{ - ID: "S1", - })) - data := proxy.getUserInfoData(r) - assert.Equal(t, "S1", data.Session.Id) - assert.Equal(t, "U1", data.User.Id) - assert.True(t, data.IsEnterprise) - assert.Equal(t, []string{"G1", "G2", "G3"}, data.DirectoryUser.GroupIDs) + require.NoError(t, databrokerpb.PutMulti(ctx, client, + makeRecord(&session.Session{ + Id: "S1", + UserId: "U1", + }), + makeRecord(&user.User{ + Id: "U1", + }), + makeRecord(&configpb.Config{ + Name: "dashboard-settings", + }), + makeStructRecord(directory.UserRecordType, "U1", map[string]any{ + "group_ids": []any{"G1", "G2", "G3"}, + }))) + + r := httptest.NewRequest(http.MethodGet, "/.pomerium/", nil) + r.Header.Set("Authorization", "Bearer Pomerium-"+encodeSession(t, opts, &sessions.State{ + ID: "S1", + })) + data := proxy.getUserInfoData(r) + assert.Equal(t, "S1", data.Session.Id) + assert.Equal(t, "U1", data.User.Id) + assert.True(t, data.IsEnterprise) + assert.Equal(t, []string{"G1", "G2", "G3"}, data.DirectoryUser.GroupIDs) + }) } func makeRecord(object interface { diff --git a/proxy/handlers.go b/proxy/handlers.go index 3b0b19ec4..dda079a88 100644 --- a/proxy/handlers.go +++ b/proxy/handlers.go @@ -78,7 +78,7 @@ func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) error { state := p.state.Load() var redirectURL *url.URL - signOutURL, err := p.currentOptions.Load().GetSignOutRedirectURL() + signOutURL, err := p.currentConfig.Load().Options.GetSignOutRedirectURL() if err != nil { return httputil.NewError(http.StatusInternalServerError, err) } @@ -126,7 +126,7 @@ func (p *Proxy) Callback(w http.ResponseWriter, r *http.Request) error { // using the authenticate service. func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) error { state := p.state.Load() - options := p.currentOptions.Load() + options := p.currentConfig.Load().Options redirectURI, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI)) if err != nil { diff --git a/proxy/handlers_portal.go b/proxy/handlers_portal.go index 66b016577..34523f542 100644 --- a/proxy/handlers_portal.go +++ b/proxy/handlers_portal.go @@ -40,7 +40,7 @@ func (p *Proxy) routesPortalJSON(w http.ResponseWriter, r *http.Request) error { } func (p *Proxy) getPortalRoutes(ctx context.Context, u handlers.UserInfoData) []portal.Route { - options := p.currentOptions.Load() + options := p.currentConfig.Load().Options pu := p.getPortalUser(u) var routes []*config.Policy for route := range options.GetAllPolicies() { diff --git a/proxy/proxy.go b/proxy/proxy.go index 07050271f..3926d3589 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -57,7 +57,7 @@ func ValidateOptions(o *config.Options) error { // Proxy stores all the information associated with proxying a request. type Proxy struct { state *atomicutil.Value[*proxyState] - currentOptions *atomicutil.Value[*config.Options] + currentConfig *atomicutil.Value[*config.Config] currentRouter *atomicutil.Value[*mux.Router] webauthn *webauthn.Handler tracerProvider oteltrace.TracerProvider @@ -76,7 +76,7 @@ func New(ctx context.Context, cfg *config.Config) (*Proxy, error) { p := &Proxy{ tracerProvider: tracerProvider, state: atomicutil.NewValue(state), - currentOptions: config.NewAtomicOptions(), + currentConfig: atomicutil.NewValue(&config.Config{Options: config.NewDefaultOptions()}), currentRouter: atomicutil.NewValue(httputil.NewRouter()), logoProvider: portal.NewLogoProvider(), } @@ -84,7 +84,7 @@ func New(ctx context.Context, cfg *config.Config) (*Proxy, error) { p.webauthn = webauthn.New(p.getWebauthnState) metrics.AddPolicyCountCallback("pomerium-proxy", func() int64 { - return int64(p.currentOptions.Load().NumPolicies()) + return int64(p.currentConfig.Load().Options.NumPolicies()) }) return p, nil @@ -101,7 +101,7 @@ func (p *Proxy) OnConfigChange(ctx context.Context, cfg *config.Config) { return } - p.currentOptions.Store(cfg.Options) + p.currentConfig.Store(cfg) if err := p.setHandlers(ctx, cfg.Options); err != nil { log.Ctx(ctx).Error().Err(err).Msg("proxy: failed to update proxy handlers from configuration settings") } diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index da49e0161..260dd2795 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -2,6 +2,7 @@ package proxy import ( "context" + "encoding/json" "net/http" "net/http/httptest" "net/url" @@ -11,6 +12,7 @@ import ( "github.com/stretchr/testify/require" "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/jwtutil" hpke_handlers "github.com/pomerium/pomerium/pkg/hpke/handlers" ) @@ -32,7 +34,16 @@ func testOptions(t *testing.T) *config.Options { hpkePrivateKey, err := opts.GetHPKEPrivateKey() require.NoError(t, err) - authnSrv := httptest.NewServer(hpke_handlers.HPKEPublicKeyHandler(hpkePrivateKey.PublicKey())) + authnSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case hpke_handlers.HPKEPublicKeyPath: + hpke_handlers.HPKEPublicKeyHandler(hpkePrivateKey.PublicKey()) + case "/.pomerium/verify-access-token": + json.NewEncoder(w).Encode(map[string]any{"valid": true, "claims": jwtutil.Claims{}}) + default: + http.NotFound(w, r) + } + })) t.Cleanup(authnSrv.Close) opts.AuthenticateURLString = authnSrv.URL diff --git a/proxy/state.go b/proxy/state.go index c07e1c51b..dab2563d4 100644 --- a/proxy/state.go +++ b/proxy/state.go @@ -5,13 +5,14 @@ import ( "net/http" "net/url" + "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" + oteltrace "go.opentelemetry.io/otel/trace" + googlegrpc "google.golang.org/grpc" + "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/authenticateflow" "github.com/pomerium/pomerium/pkg/grpc" "github.com/pomerium/pomerium/pkg/grpc/databroker" - "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" - oteltrace "go.opentelemetry.io/otel/trace" - googlegrpc "google.golang.org/grpc" ) var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn) @@ -32,6 +33,7 @@ type proxyState struct { dataBrokerClient databroker.DataBrokerServiceClient programmaticRedirectDomainWhitelist []string authenticateFlow authenticateFlow + incomingIDPTokenSessionCreator config.IncomingIDPTokenSessionCreator } func newProxyStateFromConfig(ctx context.Context, tracerProvider oteltrace.TracerProvider, cfg *config.Config) (*proxyState, error) { @@ -83,5 +85,24 @@ func newProxyStateFromConfig(ctx context.Context, tracerProvider oteltrace.Trace return nil, err } + state.incomingIDPTokenSessionCreator = config.NewIncomingIDPTokenSessionCreator( + func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) { + res, err := state.dataBrokerClient.Get(ctx, &databroker.GetRequest{ + Type: recordType, + Id: recordID, + }) + if err != nil { + return nil, err + } + return res.GetRecord(), nil + }, + func(ctx context.Context, records []*databroker.Record) error { + _, err := state.dataBrokerClient.Put(ctx, &databroker.PutRequest{ + Records: records, + }) + return err + }, + ) + return state, nil }