support loading idp token sessions in the proxy service

This commit is contained in:
Caleb Doxsey 2025-02-19 15:39:48 -07:00
parent 4b95eda51e
commit d7c2927cfa
7 changed files with 111 additions and 46 deletions

View file

@ -34,12 +34,22 @@ func (p *Proxy) getUser(ctx context.Context, userID string) (*user.User, error)
}
func (p *Proxy) getUserInfoData(r *http.Request) handlers.UserInfoData {
options := p.currentOptions.Load()
cfg := p.currentConfig.Load()
state := p.state.Load()
data := handlers.UserInfoData{
CSRFToken: csrf.Token(r),
BrandingOptions: options.BrandingOptions,
BrandingOptions: cfg.Options.BrandingOptions,
}
if s, err := state.incomingIDPTokenSessionCreator.CreateSession(r.Context(), cfg, nil, r); err == nil {
data.Session = s
data.IsImpersonated = false
data.User, err = p.getUser(r.Context(), data.Session.GetUserId())
if err != nil {
data.User = &user.User{Id: data.Session.GetUserId()}
}
}
ss, err := p.state.Load().sessionStore.LoadSessionState(r)
@ -85,7 +95,7 @@ func (p *Proxy) fillEnterpriseUserInfoData(ctx context.Context, data *handlers.U
}
func (p *Proxy) getWebauthnState(r *http.Request) (*webauthn.State, error) {
options := p.currentOptions.Load()
options := p.currentConfig.Load().Options
state := p.state.Load()
ss, err := p.state.Load().sessionStore.LoadSessionState(r)

View file

@ -9,7 +9,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/trace"
"go.opentelemetry.io/otel/trace/noop"
"google.golang.org/grpc"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
@ -17,6 +17,7 @@ import (
"github.com/pomerium/datasource/pkg/directory"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/databroker"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/testutil"
configpb "github.com/pomerium/pomerium/pkg/grpc/config"
@ -32,8 +33,29 @@ func Test_getUserInfoData(t *testing.T) {
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
defer clearTimeout()
t.Run("incoming idp token", func(t *testing.T) {
cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) {
databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New(ctx, trace.NewNoopTracerProvider()))
databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New(ctx, noop.NewTracerProvider()))
})
t.Cleanup(func() { cc.Close() })
client := databrokerpb.NewDataBrokerServiceClient(cc)
opts := testOptions(t)
proxy, err := New(ctx, &config.Config{Options: opts})
require.NoError(t, err)
proxy.state.Load().dataBrokerClient = client
r := httptest.NewRequest(http.MethodGet, "/.pomerium/", nil)
r.Header.Set(httputil.HeaderPomeriumIDPAccessToken, "ACCESS_TOKEN")
data := proxy.getUserInfoData(r)
assert.NotNil(t, data.Session)
assert.NotNil(t, data.User)
})
t.Run("session", func(t *testing.T) {
cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) {
databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New(ctx, noop.NewTracerProvider()))
})
t.Cleanup(func() { cc.Close() })
@ -68,6 +90,7 @@ func Test_getUserInfoData(t *testing.T) {
assert.Equal(t, "U1", data.User.Id)
assert.True(t, data.IsEnterprise)
assert.Equal(t, []string{"G1", "G2", "G3"}, data.DirectoryUser.GroupIDs)
})
}
func makeRecord(object interface {

View file

@ -78,7 +78,7 @@ func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) error {
state := p.state.Load()
var redirectURL *url.URL
signOutURL, err := p.currentOptions.Load().GetSignOutRedirectURL()
signOutURL, err := p.currentConfig.Load().Options.GetSignOutRedirectURL()
if err != nil {
return httputil.NewError(http.StatusInternalServerError, err)
}
@ -126,7 +126,7 @@ func (p *Proxy) Callback(w http.ResponseWriter, r *http.Request) error {
// using the authenticate service.
func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) error {
state := p.state.Load()
options := p.currentOptions.Load()
options := p.currentConfig.Load().Options
redirectURI, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI))
if err != nil {

View file

@ -40,7 +40,7 @@ func (p *Proxy) routesPortalJSON(w http.ResponseWriter, r *http.Request) error {
}
func (p *Proxy) getPortalRoutes(ctx context.Context, u handlers.UserInfoData) []portal.Route {
options := p.currentOptions.Load()
options := p.currentConfig.Load().Options
pu := p.getPortalUser(u)
var routes []*config.Policy
for route := range options.GetAllPolicies() {

View file

@ -57,7 +57,7 @@ func ValidateOptions(o *config.Options) error {
// Proxy stores all the information associated with proxying a request.
type Proxy struct {
state *atomicutil.Value[*proxyState]
currentOptions *atomicutil.Value[*config.Options]
currentConfig *atomicutil.Value[*config.Config]
currentRouter *atomicutil.Value[*mux.Router]
webauthn *webauthn.Handler
tracerProvider oteltrace.TracerProvider
@ -76,7 +76,7 @@ func New(ctx context.Context, cfg *config.Config) (*Proxy, error) {
p := &Proxy{
tracerProvider: tracerProvider,
state: atomicutil.NewValue(state),
currentOptions: config.NewAtomicOptions(),
currentConfig: atomicutil.NewValue(&config.Config{Options: config.NewDefaultOptions()}),
currentRouter: atomicutil.NewValue(httputil.NewRouter()),
logoProvider: portal.NewLogoProvider(),
}
@ -84,7 +84,7 @@ func New(ctx context.Context, cfg *config.Config) (*Proxy, error) {
p.webauthn = webauthn.New(p.getWebauthnState)
metrics.AddPolicyCountCallback("pomerium-proxy", func() int64 {
return int64(p.currentOptions.Load().NumPolicies())
return int64(p.currentConfig.Load().Options.NumPolicies())
})
return p, nil
@ -101,7 +101,7 @@ func (p *Proxy) OnConfigChange(ctx context.Context, cfg *config.Config) {
return
}
p.currentOptions.Store(cfg.Options)
p.currentConfig.Store(cfg)
if err := p.setHandlers(ctx, cfg.Options); err != nil {
log.Ctx(ctx).Error().Err(err).Msg("proxy: failed to update proxy handlers from configuration settings")
}

View file

@ -2,6 +2,7 @@ package proxy
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
@ -11,6 +12,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/jwtutil"
hpke_handlers "github.com/pomerium/pomerium/pkg/hpke/handlers"
)
@ -32,7 +34,16 @@ func testOptions(t *testing.T) *config.Options {
hpkePrivateKey, err := opts.GetHPKEPrivateKey()
require.NoError(t, err)
authnSrv := httptest.NewServer(hpke_handlers.HPKEPublicKeyHandler(hpkePrivateKey.PublicKey()))
authnSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case hpke_handlers.HPKEPublicKeyPath:
hpke_handlers.HPKEPublicKeyHandler(hpkePrivateKey.PublicKey())
case "/.pomerium/verify-access-token":
json.NewEncoder(w).Encode(map[string]any{"valid": true, "claims": jwtutil.Claims{}})
default:
http.NotFound(w, r)
}
}))
t.Cleanup(authnSrv.Close)
opts.AuthenticateURLString = authnSrv.URL

View file

@ -5,13 +5,14 @@ import (
"net/http"
"net/url"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
oteltrace "go.opentelemetry.io/otel/trace"
googlegrpc "google.golang.org/grpc"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/authenticateflow"
"github.com/pomerium/pomerium/pkg/grpc"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
oteltrace "go.opentelemetry.io/otel/trace"
googlegrpc "google.golang.org/grpc"
)
var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn)
@ -32,6 +33,7 @@ type proxyState struct {
dataBrokerClient databroker.DataBrokerServiceClient
programmaticRedirectDomainWhitelist []string
authenticateFlow authenticateFlow
incomingIDPTokenSessionCreator config.IncomingIDPTokenSessionCreator
}
func newProxyStateFromConfig(ctx context.Context, tracerProvider oteltrace.TracerProvider, cfg *config.Config) (*proxyState, error) {
@ -83,5 +85,24 @@ func newProxyStateFromConfig(ctx context.Context, tracerProvider oteltrace.Trace
return nil, err
}
state.incomingIDPTokenSessionCreator = config.NewIncomingIDPTokenSessionCreator(
func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) {
res, err := state.dataBrokerClient.Get(ctx, &databroker.GetRequest{
Type: recordType,
Id: recordID,
})
if err != nil {
return nil, err
}
return res.GetRecord(), nil
},
func(ctx context.Context, records []*databroker.Record) error {
_, err := state.dataBrokerClient.Put(ctx, &databroker.PutRequest{
Records: records,
})
return err
},
)
return state, nil
}