support loading idp token sessions in the proxy service

This commit is contained in:
Caleb Doxsey 2025-02-19 15:39:48 -07:00
parent 4b95eda51e
commit d7c2927cfa
7 changed files with 111 additions and 46 deletions

View file

@ -34,12 +34,22 @@ func (p *Proxy) getUser(ctx context.Context, userID string) (*user.User, error)
} }
func (p *Proxy) getUserInfoData(r *http.Request) handlers.UserInfoData { func (p *Proxy) getUserInfoData(r *http.Request) handlers.UserInfoData {
options := p.currentOptions.Load() cfg := p.currentConfig.Load()
state := p.state.Load() state := p.state.Load()
data := handlers.UserInfoData{ data := handlers.UserInfoData{
CSRFToken: csrf.Token(r), CSRFToken: csrf.Token(r),
BrandingOptions: options.BrandingOptions, BrandingOptions: cfg.Options.BrandingOptions,
}
if s, err := state.incomingIDPTokenSessionCreator.CreateSession(r.Context(), cfg, nil, r); err == nil {
data.Session = s
data.IsImpersonated = false
data.User, err = p.getUser(r.Context(), data.Session.GetUserId())
if err != nil {
data.User = &user.User{Id: data.Session.GetUserId()}
}
} }
ss, err := p.state.Load().sessionStore.LoadSessionState(r) ss, err := p.state.Load().sessionStore.LoadSessionState(r)
@ -85,7 +95,7 @@ func (p *Proxy) fillEnterpriseUserInfoData(ctx context.Context, data *handlers.U
} }
func (p *Proxy) getWebauthnState(r *http.Request) (*webauthn.State, error) { func (p *Proxy) getWebauthnState(r *http.Request) (*webauthn.State, error) {
options := p.currentOptions.Load() options := p.currentConfig.Load().Options
state := p.state.Load() state := p.state.Load()
ss, err := p.state.Load().sessionStore.LoadSessionState(r) ss, err := p.state.Load().sessionStore.LoadSessionState(r)

View file

@ -9,7 +9,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/trace" "go.opentelemetry.io/otel/trace/noop"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
@ -17,6 +17,7 @@ import (
"github.com/pomerium/datasource/pkg/directory" "github.com/pomerium/datasource/pkg/directory"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/databroker" "github.com/pomerium/pomerium/internal/databroker"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/testutil" "github.com/pomerium/pomerium/internal/testutil"
configpb "github.com/pomerium/pomerium/pkg/grpc/config" configpb "github.com/pomerium/pomerium/pkg/grpc/config"
@ -32,42 +33,64 @@ func Test_getUserInfoData(t *testing.T) {
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
defer clearTimeout() defer clearTimeout()
cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) { t.Run("incoming idp token", func(t *testing.T) {
databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New(ctx, trace.NewNoopTracerProvider())) cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) {
databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New(ctx, noop.NewTracerProvider()))
})
t.Cleanup(func() { cc.Close() })
client := databrokerpb.NewDataBrokerServiceClient(cc)
opts := testOptions(t)
proxy, err := New(ctx, &config.Config{Options: opts})
require.NoError(t, err)
proxy.state.Load().dataBrokerClient = client
r := httptest.NewRequest(http.MethodGet, "/.pomerium/", nil)
r.Header.Set(httputil.HeaderPomeriumIDPAccessToken, "ACCESS_TOKEN")
data := proxy.getUserInfoData(r)
assert.NotNil(t, data.Session)
assert.NotNil(t, data.User)
}) })
t.Cleanup(func() { cc.Close() })
client := databrokerpb.NewDataBrokerServiceClient(cc) t.Run("session", func(t *testing.T) {
cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) {
databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New(ctx, noop.NewTracerProvider()))
})
t.Cleanup(func() { cc.Close() })
opts := testOptions(t) client := databrokerpb.NewDataBrokerServiceClient(cc)
proxy, err := New(ctx, &config.Config{Options: opts})
require.NoError(t, err)
proxy.state.Load().dataBrokerClient = client
require.NoError(t, databrokerpb.PutMulti(ctx, client, opts := testOptions(t)
makeRecord(&session.Session{ proxy, err := New(ctx, &config.Config{Options: opts})
Id: "S1", require.NoError(t, err)
UserId: "U1", proxy.state.Load().dataBrokerClient = client
}),
makeRecord(&user.User{
Id: "U1",
}),
makeRecord(&configpb.Config{
Name: "dashboard-settings",
}),
makeStructRecord(directory.UserRecordType, "U1", map[string]any{
"group_ids": []any{"G1", "G2", "G3"},
})))
r := httptest.NewRequest(http.MethodGet, "/.pomerium/", nil) require.NoError(t, databrokerpb.PutMulti(ctx, client,
r.Header.Set("Authorization", "Bearer Pomerium-"+encodeSession(t, opts, &sessions.State{ makeRecord(&session.Session{
ID: "S1", Id: "S1",
})) UserId: "U1",
data := proxy.getUserInfoData(r) }),
assert.Equal(t, "S1", data.Session.Id) makeRecord(&user.User{
assert.Equal(t, "U1", data.User.Id) Id: "U1",
assert.True(t, data.IsEnterprise) }),
assert.Equal(t, []string{"G1", "G2", "G3"}, data.DirectoryUser.GroupIDs) makeRecord(&configpb.Config{
Name: "dashboard-settings",
}),
makeStructRecord(directory.UserRecordType, "U1", map[string]any{
"group_ids": []any{"G1", "G2", "G3"},
})))
r := httptest.NewRequest(http.MethodGet, "/.pomerium/", nil)
r.Header.Set("Authorization", "Bearer Pomerium-"+encodeSession(t, opts, &sessions.State{
ID: "S1",
}))
data := proxy.getUserInfoData(r)
assert.Equal(t, "S1", data.Session.Id)
assert.Equal(t, "U1", data.User.Id)
assert.True(t, data.IsEnterprise)
assert.Equal(t, []string{"G1", "G2", "G3"}, data.DirectoryUser.GroupIDs)
})
} }
func makeRecord(object interface { func makeRecord(object interface {

View file

@ -78,7 +78,7 @@ func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) error {
state := p.state.Load() state := p.state.Load()
var redirectURL *url.URL var redirectURL *url.URL
signOutURL, err := p.currentOptions.Load().GetSignOutRedirectURL() signOutURL, err := p.currentConfig.Load().Options.GetSignOutRedirectURL()
if err != nil { if err != nil {
return httputil.NewError(http.StatusInternalServerError, err) return httputil.NewError(http.StatusInternalServerError, err)
} }
@ -126,7 +126,7 @@ func (p *Proxy) Callback(w http.ResponseWriter, r *http.Request) error {
// using the authenticate service. // using the authenticate service.
func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) error { func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) error {
state := p.state.Load() state := p.state.Load()
options := p.currentOptions.Load() options := p.currentConfig.Load().Options
redirectURI, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI)) redirectURI, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI))
if err != nil { if err != nil {

View file

@ -40,7 +40,7 @@ func (p *Proxy) routesPortalJSON(w http.ResponseWriter, r *http.Request) error {
} }
func (p *Proxy) getPortalRoutes(ctx context.Context, u handlers.UserInfoData) []portal.Route { func (p *Proxy) getPortalRoutes(ctx context.Context, u handlers.UserInfoData) []portal.Route {
options := p.currentOptions.Load() options := p.currentConfig.Load().Options
pu := p.getPortalUser(u) pu := p.getPortalUser(u)
var routes []*config.Policy var routes []*config.Policy
for route := range options.GetAllPolicies() { for route := range options.GetAllPolicies() {

View file

@ -57,7 +57,7 @@ func ValidateOptions(o *config.Options) error {
// Proxy stores all the information associated with proxying a request. // Proxy stores all the information associated with proxying a request.
type Proxy struct { type Proxy struct {
state *atomicutil.Value[*proxyState] state *atomicutil.Value[*proxyState]
currentOptions *atomicutil.Value[*config.Options] currentConfig *atomicutil.Value[*config.Config]
currentRouter *atomicutil.Value[*mux.Router] currentRouter *atomicutil.Value[*mux.Router]
webauthn *webauthn.Handler webauthn *webauthn.Handler
tracerProvider oteltrace.TracerProvider tracerProvider oteltrace.TracerProvider
@ -76,7 +76,7 @@ func New(ctx context.Context, cfg *config.Config) (*Proxy, error) {
p := &Proxy{ p := &Proxy{
tracerProvider: tracerProvider, tracerProvider: tracerProvider,
state: atomicutil.NewValue(state), state: atomicutil.NewValue(state),
currentOptions: config.NewAtomicOptions(), currentConfig: atomicutil.NewValue(&config.Config{Options: config.NewDefaultOptions()}),
currentRouter: atomicutil.NewValue(httputil.NewRouter()), currentRouter: atomicutil.NewValue(httputil.NewRouter()),
logoProvider: portal.NewLogoProvider(), logoProvider: portal.NewLogoProvider(),
} }
@ -84,7 +84,7 @@ func New(ctx context.Context, cfg *config.Config) (*Proxy, error) {
p.webauthn = webauthn.New(p.getWebauthnState) p.webauthn = webauthn.New(p.getWebauthnState)
metrics.AddPolicyCountCallback("pomerium-proxy", func() int64 { metrics.AddPolicyCountCallback("pomerium-proxy", func() int64 {
return int64(p.currentOptions.Load().NumPolicies()) return int64(p.currentConfig.Load().Options.NumPolicies())
}) })
return p, nil return p, nil
@ -101,7 +101,7 @@ func (p *Proxy) OnConfigChange(ctx context.Context, cfg *config.Config) {
return return
} }
p.currentOptions.Store(cfg.Options) p.currentConfig.Store(cfg)
if err := p.setHandlers(ctx, cfg.Options); err != nil { if err := p.setHandlers(ctx, cfg.Options); err != nil {
log.Ctx(ctx).Error().Err(err).Msg("proxy: failed to update proxy handlers from configuration settings") log.Ctx(ctx).Error().Err(err).Msg("proxy: failed to update proxy handlers from configuration settings")
} }

View file

@ -2,6 +2,7 @@ package proxy
import ( import (
"context" "context"
"encoding/json"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@ -11,6 +12,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/jwtutil"
hpke_handlers "github.com/pomerium/pomerium/pkg/hpke/handlers" hpke_handlers "github.com/pomerium/pomerium/pkg/hpke/handlers"
) )
@ -32,7 +34,16 @@ func testOptions(t *testing.T) *config.Options {
hpkePrivateKey, err := opts.GetHPKEPrivateKey() hpkePrivateKey, err := opts.GetHPKEPrivateKey()
require.NoError(t, err) require.NoError(t, err)
authnSrv := httptest.NewServer(hpke_handlers.HPKEPublicKeyHandler(hpkePrivateKey.PublicKey())) authnSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case hpke_handlers.HPKEPublicKeyPath:
hpke_handlers.HPKEPublicKeyHandler(hpkePrivateKey.PublicKey())
case "/.pomerium/verify-access-token":
json.NewEncoder(w).Encode(map[string]any{"valid": true, "claims": jwtutil.Claims{}})
default:
http.NotFound(w, r)
}
}))
t.Cleanup(authnSrv.Close) t.Cleanup(authnSrv.Close)
opts.AuthenticateURLString = authnSrv.URL opts.AuthenticateURLString = authnSrv.URL

View file

@ -5,13 +5,14 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
oteltrace "go.opentelemetry.io/otel/trace"
googlegrpc "google.golang.org/grpc"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/authenticateflow" "github.com/pomerium/pomerium/internal/authenticateflow"
"github.com/pomerium/pomerium/pkg/grpc" "github.com/pomerium/pomerium/pkg/grpc"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
oteltrace "go.opentelemetry.io/otel/trace"
googlegrpc "google.golang.org/grpc"
) )
var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn) var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn)
@ -32,6 +33,7 @@ type proxyState struct {
dataBrokerClient databroker.DataBrokerServiceClient dataBrokerClient databroker.DataBrokerServiceClient
programmaticRedirectDomainWhitelist []string programmaticRedirectDomainWhitelist []string
authenticateFlow authenticateFlow authenticateFlow authenticateFlow
incomingIDPTokenSessionCreator config.IncomingIDPTokenSessionCreator
} }
func newProxyStateFromConfig(ctx context.Context, tracerProvider oteltrace.TracerProvider, cfg *config.Config) (*proxyState, error) { func newProxyStateFromConfig(ctx context.Context, tracerProvider oteltrace.TracerProvider, cfg *config.Config) (*proxyState, error) {
@ -83,5 +85,24 @@ func newProxyStateFromConfig(ctx context.Context, tracerProvider oteltrace.Trace
return nil, err return nil, err
} }
state.incomingIDPTokenSessionCreator = config.NewIncomingIDPTokenSessionCreator(
func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) {
res, err := state.dataBrokerClient.Get(ctx, &databroker.GetRequest{
Type: recordType,
Id: recordID,
})
if err != nil {
return nil, err
}
return res.GetRecord(), nil
},
func(ctx context.Context, records []*databroker.Record) error {
_, err := state.dataBrokerClient.Put(ctx, &databroker.PutRequest{
Records: records,
})
return err
},
)
return state, nil return state, nil
} }