mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 18:36:30 +02:00
support loading idp token sessions in the proxy service
This commit is contained in:
parent
4b95eda51e
commit
d7c2927cfa
7 changed files with 111 additions and 46 deletions
|
@ -34,12 +34,22 @@ func (p *Proxy) getUser(ctx context.Context, userID string) (*user.User, error)
|
|||
}
|
||||
|
||||
func (p *Proxy) getUserInfoData(r *http.Request) handlers.UserInfoData {
|
||||
options := p.currentOptions.Load()
|
||||
cfg := p.currentConfig.Load()
|
||||
state := p.state.Load()
|
||||
|
||||
data := handlers.UserInfoData{
|
||||
CSRFToken: csrf.Token(r),
|
||||
BrandingOptions: options.BrandingOptions,
|
||||
BrandingOptions: cfg.Options.BrandingOptions,
|
||||
}
|
||||
|
||||
if s, err := state.incomingIDPTokenSessionCreator.CreateSession(r.Context(), cfg, nil, r); err == nil {
|
||||
data.Session = s
|
||||
data.IsImpersonated = false
|
||||
|
||||
data.User, err = p.getUser(r.Context(), data.Session.GetUserId())
|
||||
if err != nil {
|
||||
data.User = &user.User{Id: data.Session.GetUserId()}
|
||||
}
|
||||
}
|
||||
|
||||
ss, err := p.state.Load().sessionStore.LoadSessionState(r)
|
||||
|
@ -85,7 +95,7 @@ func (p *Proxy) fillEnterpriseUserInfoData(ctx context.Context, data *handlers.U
|
|||
}
|
||||
|
||||
func (p *Proxy) getWebauthnState(r *http.Request) (*webauthn.State, error) {
|
||||
options := p.currentOptions.Load()
|
||||
options := p.currentConfig.Load().Options
|
||||
state := p.state.Load()
|
||||
|
||||
ss, err := p.state.Load().sessionStore.LoadSessionState(r)
|
||||
|
|
|
@ -9,7 +9,7 @@ import (
|
|||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"go.opentelemetry.io/otel/trace/noop"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
@ -17,6 +17,7 @@ import (
|
|||
"github.com/pomerium/datasource/pkg/directory"
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/databroker"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
configpb "github.com/pomerium/pomerium/pkg/grpc/config"
|
||||
|
@ -32,8 +33,29 @@ func Test_getUserInfoData(t *testing.T) {
|
|||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer clearTimeout()
|
||||
|
||||
t.Run("incoming idp token", func(t *testing.T) {
|
||||
cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) {
|
||||
databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New(ctx, trace.NewNoopTracerProvider()))
|
||||
databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New(ctx, noop.NewTracerProvider()))
|
||||
})
|
||||
t.Cleanup(func() { cc.Close() })
|
||||
|
||||
client := databrokerpb.NewDataBrokerServiceClient(cc)
|
||||
|
||||
opts := testOptions(t)
|
||||
proxy, err := New(ctx, &config.Config{Options: opts})
|
||||
require.NoError(t, err)
|
||||
proxy.state.Load().dataBrokerClient = client
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/.pomerium/", nil)
|
||||
r.Header.Set(httputil.HeaderPomeriumIDPAccessToken, "ACCESS_TOKEN")
|
||||
data := proxy.getUserInfoData(r)
|
||||
assert.NotNil(t, data.Session)
|
||||
assert.NotNil(t, data.User)
|
||||
})
|
||||
|
||||
t.Run("session", func(t *testing.T) {
|
||||
cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) {
|
||||
databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New(ctx, noop.NewTracerProvider()))
|
||||
})
|
||||
t.Cleanup(func() { cc.Close() })
|
||||
|
||||
|
@ -68,6 +90,7 @@ func Test_getUserInfoData(t *testing.T) {
|
|||
assert.Equal(t, "U1", data.User.Id)
|
||||
assert.True(t, data.IsEnterprise)
|
||||
assert.Equal(t, []string{"G1", "G2", "G3"}, data.DirectoryUser.GroupIDs)
|
||||
})
|
||||
}
|
||||
|
||||
func makeRecord(object interface {
|
||||
|
|
|
@ -78,7 +78,7 @@ func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) error {
|
|||
state := p.state.Load()
|
||||
|
||||
var redirectURL *url.URL
|
||||
signOutURL, err := p.currentOptions.Load().GetSignOutRedirectURL()
|
||||
signOutURL, err := p.currentConfig.Load().Options.GetSignOutRedirectURL()
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusInternalServerError, err)
|
||||
}
|
||||
|
@ -126,7 +126,7 @@ func (p *Proxy) Callback(w http.ResponseWriter, r *http.Request) error {
|
|||
// using the authenticate service.
|
||||
func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) error {
|
||||
state := p.state.Load()
|
||||
options := p.currentOptions.Load()
|
||||
options := p.currentConfig.Load().Options
|
||||
|
||||
redirectURI, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI))
|
||||
if err != nil {
|
||||
|
|
|
@ -40,7 +40,7 @@ func (p *Proxy) routesPortalJSON(w http.ResponseWriter, r *http.Request) error {
|
|||
}
|
||||
|
||||
func (p *Proxy) getPortalRoutes(ctx context.Context, u handlers.UserInfoData) []portal.Route {
|
||||
options := p.currentOptions.Load()
|
||||
options := p.currentConfig.Load().Options
|
||||
pu := p.getPortalUser(u)
|
||||
var routes []*config.Policy
|
||||
for route := range options.GetAllPolicies() {
|
||||
|
|
|
@ -57,7 +57,7 @@ func ValidateOptions(o *config.Options) error {
|
|||
// Proxy stores all the information associated with proxying a request.
|
||||
type Proxy struct {
|
||||
state *atomicutil.Value[*proxyState]
|
||||
currentOptions *atomicutil.Value[*config.Options]
|
||||
currentConfig *atomicutil.Value[*config.Config]
|
||||
currentRouter *atomicutil.Value[*mux.Router]
|
||||
webauthn *webauthn.Handler
|
||||
tracerProvider oteltrace.TracerProvider
|
||||
|
@ -76,7 +76,7 @@ func New(ctx context.Context, cfg *config.Config) (*Proxy, error) {
|
|||
p := &Proxy{
|
||||
tracerProvider: tracerProvider,
|
||||
state: atomicutil.NewValue(state),
|
||||
currentOptions: config.NewAtomicOptions(),
|
||||
currentConfig: atomicutil.NewValue(&config.Config{Options: config.NewDefaultOptions()}),
|
||||
currentRouter: atomicutil.NewValue(httputil.NewRouter()),
|
||||
logoProvider: portal.NewLogoProvider(),
|
||||
}
|
||||
|
@ -84,7 +84,7 @@ func New(ctx context.Context, cfg *config.Config) (*Proxy, error) {
|
|||
p.webauthn = webauthn.New(p.getWebauthnState)
|
||||
|
||||
metrics.AddPolicyCountCallback("pomerium-proxy", func() int64 {
|
||||
return int64(p.currentOptions.Load().NumPolicies())
|
||||
return int64(p.currentConfig.Load().Options.NumPolicies())
|
||||
})
|
||||
|
||||
return p, nil
|
||||
|
@ -101,7 +101,7 @@ func (p *Proxy) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
|||
return
|
||||
}
|
||||
|
||||
p.currentOptions.Store(cfg.Options)
|
||||
p.currentConfig.Store(cfg)
|
||||
if err := p.setHandlers(ctx, cfg.Options); err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("proxy: failed to update proxy handlers from configuration settings")
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package proxy
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
|
@ -11,6 +12,7 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/jwtutil"
|
||||
hpke_handlers "github.com/pomerium/pomerium/pkg/hpke/handlers"
|
||||
)
|
||||
|
||||
|
@ -32,7 +34,16 @@ func testOptions(t *testing.T) *config.Options {
|
|||
hpkePrivateKey, err := opts.GetHPKEPrivateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
authnSrv := httptest.NewServer(hpke_handlers.HPKEPublicKeyHandler(hpkePrivateKey.PublicKey()))
|
||||
authnSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case hpke_handlers.HPKEPublicKeyPath:
|
||||
hpke_handlers.HPKEPublicKeyHandler(hpkePrivateKey.PublicKey())
|
||||
case "/.pomerium/verify-access-token":
|
||||
json.NewEncoder(w).Encode(map[string]any{"valid": true, "claims": jwtutil.Claims{}})
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
t.Cleanup(authnSrv.Close)
|
||||
opts.AuthenticateURLString = authnSrv.URL
|
||||
|
||||
|
|
|
@ -5,13 +5,14 @@ import (
|
|||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
||||
oteltrace "go.opentelemetry.io/otel/trace"
|
||||
googlegrpc "google.golang.org/grpc"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/authenticateflow"
|
||||
"github.com/pomerium/pomerium/pkg/grpc"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
||||
oteltrace "go.opentelemetry.io/otel/trace"
|
||||
googlegrpc "google.golang.org/grpc"
|
||||
)
|
||||
|
||||
var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn)
|
||||
|
@ -32,6 +33,7 @@ type proxyState struct {
|
|||
dataBrokerClient databroker.DataBrokerServiceClient
|
||||
programmaticRedirectDomainWhitelist []string
|
||||
authenticateFlow authenticateFlow
|
||||
incomingIDPTokenSessionCreator config.IncomingIDPTokenSessionCreator
|
||||
}
|
||||
|
||||
func newProxyStateFromConfig(ctx context.Context, tracerProvider oteltrace.TracerProvider, cfg *config.Config) (*proxyState, error) {
|
||||
|
@ -83,5 +85,24 @@ func newProxyStateFromConfig(ctx context.Context, tracerProvider oteltrace.Trace
|
|||
return nil, err
|
||||
}
|
||||
|
||||
state.incomingIDPTokenSessionCreator = config.NewIncomingIDPTokenSessionCreator(
|
||||
func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) {
|
||||
res, err := state.dataBrokerClient.Get(ctx, &databroker.GetRequest{
|
||||
Type: recordType,
|
||||
Id: recordID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return res.GetRecord(), nil
|
||||
},
|
||||
func(ctx context.Context, records []*databroker.Record) error {
|
||||
_, err := state.dataBrokerClient.Put(ctx, &databroker.PutRequest{
|
||||
Records: records,
|
||||
})
|
||||
return err
|
||||
},
|
||||
)
|
||||
|
||||
return state, nil
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue