mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-30 10:56:28 +02:00
support loading idp token sessions in the proxy service
This commit is contained in:
parent
4b95eda51e
commit
d7c2927cfa
7 changed files with 111 additions and 46 deletions
|
@ -34,12 +34,22 @@ func (p *Proxy) getUser(ctx context.Context, userID string) (*user.User, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Proxy) getUserInfoData(r *http.Request) handlers.UserInfoData {
|
func (p *Proxy) getUserInfoData(r *http.Request) handlers.UserInfoData {
|
||||||
options := p.currentOptions.Load()
|
cfg := p.currentConfig.Load()
|
||||||
state := p.state.Load()
|
state := p.state.Load()
|
||||||
|
|
||||||
data := handlers.UserInfoData{
|
data := handlers.UserInfoData{
|
||||||
CSRFToken: csrf.Token(r),
|
CSRFToken: csrf.Token(r),
|
||||||
BrandingOptions: options.BrandingOptions,
|
BrandingOptions: cfg.Options.BrandingOptions,
|
||||||
|
}
|
||||||
|
|
||||||
|
if s, err := state.incomingIDPTokenSessionCreator.CreateSession(r.Context(), cfg, nil, r); err == nil {
|
||||||
|
data.Session = s
|
||||||
|
data.IsImpersonated = false
|
||||||
|
|
||||||
|
data.User, err = p.getUser(r.Context(), data.Session.GetUserId())
|
||||||
|
if err != nil {
|
||||||
|
data.User = &user.User{Id: data.Session.GetUserId()}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ss, err := p.state.Load().sessionStore.LoadSessionState(r)
|
ss, err := p.state.Load().sessionStore.LoadSessionState(r)
|
||||||
|
@ -85,7 +95,7 @@ func (p *Proxy) fillEnterpriseUserInfoData(ctx context.Context, data *handlers.U
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Proxy) getWebauthnState(r *http.Request) (*webauthn.State, error) {
|
func (p *Proxy) getWebauthnState(r *http.Request) (*webauthn.State, error) {
|
||||||
options := p.currentOptions.Load()
|
options := p.currentConfig.Load().Options
|
||||||
state := p.state.Load()
|
state := p.state.Load()
|
||||||
|
|
||||||
ss, err := p.state.Load().sessionStore.LoadSessionState(r)
|
ss, err := p.state.Load().sessionStore.LoadSessionState(r)
|
||||||
|
|
|
@ -9,7 +9,7 @@ import (
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"go.opentelemetry.io/otel/trace"
|
"go.opentelemetry.io/otel/trace/noop"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
@ -17,6 +17,7 @@ import (
|
||||||
"github.com/pomerium/datasource/pkg/directory"
|
"github.com/pomerium/datasource/pkg/directory"
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/databroker"
|
"github.com/pomerium/pomerium/internal/databroker"
|
||||||
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
"github.com/pomerium/pomerium/internal/testutil"
|
"github.com/pomerium/pomerium/internal/testutil"
|
||||||
configpb "github.com/pomerium/pomerium/pkg/grpc/config"
|
configpb "github.com/pomerium/pomerium/pkg/grpc/config"
|
||||||
|
@ -32,42 +33,64 @@ func Test_getUserInfoData(t *testing.T) {
|
||||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||||
defer clearTimeout()
|
defer clearTimeout()
|
||||||
|
|
||||||
cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) {
|
t.Run("incoming idp token", func(t *testing.T) {
|
||||||
databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New(ctx, trace.NewNoopTracerProvider()))
|
cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) {
|
||||||
|
databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New(ctx, noop.NewTracerProvider()))
|
||||||
|
})
|
||||||
|
t.Cleanup(func() { cc.Close() })
|
||||||
|
|
||||||
|
client := databrokerpb.NewDataBrokerServiceClient(cc)
|
||||||
|
|
||||||
|
opts := testOptions(t)
|
||||||
|
proxy, err := New(ctx, &config.Config{Options: opts})
|
||||||
|
require.NoError(t, err)
|
||||||
|
proxy.state.Load().dataBrokerClient = client
|
||||||
|
|
||||||
|
r := httptest.NewRequest(http.MethodGet, "/.pomerium/", nil)
|
||||||
|
r.Header.Set(httputil.HeaderPomeriumIDPAccessToken, "ACCESS_TOKEN")
|
||||||
|
data := proxy.getUserInfoData(r)
|
||||||
|
assert.NotNil(t, data.Session)
|
||||||
|
assert.NotNil(t, data.User)
|
||||||
})
|
})
|
||||||
t.Cleanup(func() { cc.Close() })
|
|
||||||
|
|
||||||
client := databrokerpb.NewDataBrokerServiceClient(cc)
|
t.Run("session", func(t *testing.T) {
|
||||||
|
cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) {
|
||||||
|
databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New(ctx, noop.NewTracerProvider()))
|
||||||
|
})
|
||||||
|
t.Cleanup(func() { cc.Close() })
|
||||||
|
|
||||||
opts := testOptions(t)
|
client := databrokerpb.NewDataBrokerServiceClient(cc)
|
||||||
proxy, err := New(ctx, &config.Config{Options: opts})
|
|
||||||
require.NoError(t, err)
|
|
||||||
proxy.state.Load().dataBrokerClient = client
|
|
||||||
|
|
||||||
require.NoError(t, databrokerpb.PutMulti(ctx, client,
|
opts := testOptions(t)
|
||||||
makeRecord(&session.Session{
|
proxy, err := New(ctx, &config.Config{Options: opts})
|
||||||
Id: "S1",
|
require.NoError(t, err)
|
||||||
UserId: "U1",
|
proxy.state.Load().dataBrokerClient = client
|
||||||
}),
|
|
||||||
makeRecord(&user.User{
|
|
||||||
Id: "U1",
|
|
||||||
}),
|
|
||||||
makeRecord(&configpb.Config{
|
|
||||||
Name: "dashboard-settings",
|
|
||||||
}),
|
|
||||||
makeStructRecord(directory.UserRecordType, "U1", map[string]any{
|
|
||||||
"group_ids": []any{"G1", "G2", "G3"},
|
|
||||||
})))
|
|
||||||
|
|
||||||
r := httptest.NewRequest(http.MethodGet, "/.pomerium/", nil)
|
require.NoError(t, databrokerpb.PutMulti(ctx, client,
|
||||||
r.Header.Set("Authorization", "Bearer Pomerium-"+encodeSession(t, opts, &sessions.State{
|
makeRecord(&session.Session{
|
||||||
ID: "S1",
|
Id: "S1",
|
||||||
}))
|
UserId: "U1",
|
||||||
data := proxy.getUserInfoData(r)
|
}),
|
||||||
assert.Equal(t, "S1", data.Session.Id)
|
makeRecord(&user.User{
|
||||||
assert.Equal(t, "U1", data.User.Id)
|
Id: "U1",
|
||||||
assert.True(t, data.IsEnterprise)
|
}),
|
||||||
assert.Equal(t, []string{"G1", "G2", "G3"}, data.DirectoryUser.GroupIDs)
|
makeRecord(&configpb.Config{
|
||||||
|
Name: "dashboard-settings",
|
||||||
|
}),
|
||||||
|
makeStructRecord(directory.UserRecordType, "U1", map[string]any{
|
||||||
|
"group_ids": []any{"G1", "G2", "G3"},
|
||||||
|
})))
|
||||||
|
|
||||||
|
r := httptest.NewRequest(http.MethodGet, "/.pomerium/", nil)
|
||||||
|
r.Header.Set("Authorization", "Bearer Pomerium-"+encodeSession(t, opts, &sessions.State{
|
||||||
|
ID: "S1",
|
||||||
|
}))
|
||||||
|
data := proxy.getUserInfoData(r)
|
||||||
|
assert.Equal(t, "S1", data.Session.Id)
|
||||||
|
assert.Equal(t, "U1", data.User.Id)
|
||||||
|
assert.True(t, data.IsEnterprise)
|
||||||
|
assert.Equal(t, []string{"G1", "G2", "G3"}, data.DirectoryUser.GroupIDs)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeRecord(object interface {
|
func makeRecord(object interface {
|
||||||
|
|
|
@ -78,7 +78,7 @@ func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) error {
|
||||||
state := p.state.Load()
|
state := p.state.Load()
|
||||||
|
|
||||||
var redirectURL *url.URL
|
var redirectURL *url.URL
|
||||||
signOutURL, err := p.currentOptions.Load().GetSignOutRedirectURL()
|
signOutURL, err := p.currentConfig.Load().Options.GetSignOutRedirectURL()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httputil.NewError(http.StatusInternalServerError, err)
|
return httputil.NewError(http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
@ -126,7 +126,7 @@ func (p *Proxy) Callback(w http.ResponseWriter, r *http.Request) error {
|
||||||
// using the authenticate service.
|
// using the authenticate service.
|
||||||
func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) error {
|
func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) error {
|
||||||
state := p.state.Load()
|
state := p.state.Load()
|
||||||
options := p.currentOptions.Load()
|
options := p.currentConfig.Load().Options
|
||||||
|
|
||||||
redirectURI, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI))
|
redirectURI, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -40,7 +40,7 @@ func (p *Proxy) routesPortalJSON(w http.ResponseWriter, r *http.Request) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Proxy) getPortalRoutes(ctx context.Context, u handlers.UserInfoData) []portal.Route {
|
func (p *Proxy) getPortalRoutes(ctx context.Context, u handlers.UserInfoData) []portal.Route {
|
||||||
options := p.currentOptions.Load()
|
options := p.currentConfig.Load().Options
|
||||||
pu := p.getPortalUser(u)
|
pu := p.getPortalUser(u)
|
||||||
var routes []*config.Policy
|
var routes []*config.Policy
|
||||||
for route := range options.GetAllPolicies() {
|
for route := range options.GetAllPolicies() {
|
||||||
|
|
|
@ -57,7 +57,7 @@ func ValidateOptions(o *config.Options) error {
|
||||||
// Proxy stores all the information associated with proxying a request.
|
// Proxy stores all the information associated with proxying a request.
|
||||||
type Proxy struct {
|
type Proxy struct {
|
||||||
state *atomicutil.Value[*proxyState]
|
state *atomicutil.Value[*proxyState]
|
||||||
currentOptions *atomicutil.Value[*config.Options]
|
currentConfig *atomicutil.Value[*config.Config]
|
||||||
currentRouter *atomicutil.Value[*mux.Router]
|
currentRouter *atomicutil.Value[*mux.Router]
|
||||||
webauthn *webauthn.Handler
|
webauthn *webauthn.Handler
|
||||||
tracerProvider oteltrace.TracerProvider
|
tracerProvider oteltrace.TracerProvider
|
||||||
|
@ -76,7 +76,7 @@ func New(ctx context.Context, cfg *config.Config) (*Proxy, error) {
|
||||||
p := &Proxy{
|
p := &Proxy{
|
||||||
tracerProvider: tracerProvider,
|
tracerProvider: tracerProvider,
|
||||||
state: atomicutil.NewValue(state),
|
state: atomicutil.NewValue(state),
|
||||||
currentOptions: config.NewAtomicOptions(),
|
currentConfig: atomicutil.NewValue(&config.Config{Options: config.NewDefaultOptions()}),
|
||||||
currentRouter: atomicutil.NewValue(httputil.NewRouter()),
|
currentRouter: atomicutil.NewValue(httputil.NewRouter()),
|
||||||
logoProvider: portal.NewLogoProvider(),
|
logoProvider: portal.NewLogoProvider(),
|
||||||
}
|
}
|
||||||
|
@ -84,7 +84,7 @@ func New(ctx context.Context, cfg *config.Config) (*Proxy, error) {
|
||||||
p.webauthn = webauthn.New(p.getWebauthnState)
|
p.webauthn = webauthn.New(p.getWebauthnState)
|
||||||
|
|
||||||
metrics.AddPolicyCountCallback("pomerium-proxy", func() int64 {
|
metrics.AddPolicyCountCallback("pomerium-proxy", func() int64 {
|
||||||
return int64(p.currentOptions.Load().NumPolicies())
|
return int64(p.currentConfig.Load().Options.NumPolicies())
|
||||||
})
|
})
|
||||||
|
|
||||||
return p, nil
|
return p, nil
|
||||||
|
@ -101,7 +101,7 @@ func (p *Proxy) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p.currentOptions.Store(cfg.Options)
|
p.currentConfig.Store(cfg)
|
||||||
if err := p.setHandlers(ctx, cfg.Options); err != nil {
|
if err := p.setHandlers(ctx, cfg.Options); err != nil {
|
||||||
log.Ctx(ctx).Error().Err(err).Msg("proxy: failed to update proxy handlers from configuration settings")
|
log.Ctx(ctx).Error().Err(err).Msg("proxy: failed to update proxy handlers from configuration settings")
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
@ -11,6 +12,7 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/internal/jwtutil"
|
||||||
hpke_handlers "github.com/pomerium/pomerium/pkg/hpke/handlers"
|
hpke_handlers "github.com/pomerium/pomerium/pkg/hpke/handlers"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -32,7 +34,16 @@ func testOptions(t *testing.T) *config.Options {
|
||||||
hpkePrivateKey, err := opts.GetHPKEPrivateKey()
|
hpkePrivateKey, err := opts.GetHPKEPrivateKey()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
authnSrv := httptest.NewServer(hpke_handlers.HPKEPublicKeyHandler(hpkePrivateKey.PublicKey()))
|
authnSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case hpke_handlers.HPKEPublicKeyPath:
|
||||||
|
hpke_handlers.HPKEPublicKeyHandler(hpkePrivateKey.PublicKey())
|
||||||
|
case "/.pomerium/verify-access-token":
|
||||||
|
json.NewEncoder(w).Encode(map[string]any{"valid": true, "claims": jwtutil.Claims{}})
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
t.Cleanup(authnSrv.Close)
|
t.Cleanup(authnSrv.Close)
|
||||||
opts.AuthenticateURLString = authnSrv.URL
|
opts.AuthenticateURLString = authnSrv.URL
|
||||||
|
|
||||||
|
|
|
@ -5,13 +5,14 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
|
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
||||||
|
oteltrace "go.opentelemetry.io/otel/trace"
|
||||||
|
googlegrpc "google.golang.org/grpc"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/authenticateflow"
|
"github.com/pomerium/pomerium/internal/authenticateflow"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc"
|
"github.com/pomerium/pomerium/pkg/grpc"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
|
||||||
oteltrace "go.opentelemetry.io/otel/trace"
|
|
||||||
googlegrpc "google.golang.org/grpc"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn)
|
var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn)
|
||||||
|
@ -32,6 +33,7 @@ type proxyState struct {
|
||||||
dataBrokerClient databroker.DataBrokerServiceClient
|
dataBrokerClient databroker.DataBrokerServiceClient
|
||||||
programmaticRedirectDomainWhitelist []string
|
programmaticRedirectDomainWhitelist []string
|
||||||
authenticateFlow authenticateFlow
|
authenticateFlow authenticateFlow
|
||||||
|
incomingIDPTokenSessionCreator config.IncomingIDPTokenSessionCreator
|
||||||
}
|
}
|
||||||
|
|
||||||
func newProxyStateFromConfig(ctx context.Context, tracerProvider oteltrace.TracerProvider, cfg *config.Config) (*proxyState, error) {
|
func newProxyStateFromConfig(ctx context.Context, tracerProvider oteltrace.TracerProvider, cfg *config.Config) (*proxyState, error) {
|
||||||
|
@ -83,5 +85,24 @@ func newProxyStateFromConfig(ctx context.Context, tracerProvider oteltrace.Trace
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
state.incomingIDPTokenSessionCreator = config.NewIncomingIDPTokenSessionCreator(
|
||||||
|
func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) {
|
||||||
|
res, err := state.dataBrokerClient.Get(ctx, &databroker.GetRequest{
|
||||||
|
Type: recordType,
|
||||||
|
Id: recordID,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return res.GetRecord(), nil
|
||||||
|
},
|
||||||
|
func(ctx context.Context, records []*databroker.Record) error {
|
||||||
|
_, err := state.dataBrokerClient.Put(ctx, &databroker.PutRequest{
|
||||||
|
Records: records,
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
return state, nil
|
return state, nil
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue