authorize: move IdP token session creator initialization (#5616)d

IdP-token-based session creation makes requests to the authenticate 
service to verify tokens. We have a singleflight group to avoid having 
duplicate requests in flight, but it looks like this is not working as 
intended. Move the IncomingIDPTokenSessionCreator initialization into
the main authorize state object, and out of the request path.

Add an integration test to assert that making a large number of requests 
with the same IdP token will result in only one token verification
request to the authenticate service.
This commit is contained in:
Kenneth Jenkins 2025-05-14 13:54:39 -07:00 committed by GitHub
parent f9fd52067e
commit e4dc218b81
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 183 additions and 16 deletions

View file

@ -0,0 +1,165 @@
package authorize_test
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/interop"
"google.golang.org/grpc/interop/grpc_testing"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/jwtutil"
"github.com/pomerium/pomerium/internal/testenv"
"github.com/pomerium/pomerium/internal/testenv/snippets"
"github.com/pomerium/pomerium/internal/testenv/upstreams"
"github.com/pomerium/pomerium/pkg/authenticateapi"
)
func TestIDPTokenRequests(t *testing.T) {
const maxRequests = 1000
var verifyRequestCount atomic.Uint32
mux := http.NewServeMux()
mux.HandleFunc("/.pomerium/verify-access-token", func(w http.ResponseWriter, _ *http.Request) {
verifyRequestCount.Add(1)
json.NewEncoder(w).Encode(&authenticateapi.VerifyTokenResponse{
Valid: true,
Claims: jwtutil.Claims{"sub": "test-user"},
})
})
authSrv := httptest.NewTLSServer(mux)
t.Cleanup(authSrv.Close)
env := testenv.New(t)
env.Add(testenv.ModifierFunc(func(_ context.Context, cfg *config.Config) {
fmt := config.BearerTokenFormatIDPAccessToken
cfg.Options.BearerTokenFormat = &fmt
cfg.Options.AuthenticateURLString = authSrv.URL
}))
up := upstreams.GRPC(insecure.NewCredentials())
srv := interop.NewTestServer()
grpc_testing.RegisterTestServiceServer(up, srv)
h2c := up.Route().
From(env.SubdomainURL("grpc-h2c")).
Policy(func(p *config.Policy) {
var ppl config.PPLPolicy
err := ppl.UnmarshalJSON([]byte(`{
"allow": {
"and": [{
"user": {"is": "test-user"}
}]
}
}`))
require.NoError(t, err)
p.Policy = &ppl
})
done := make(chan struct{})
grpcTestRunner := func(ctx context.Context, client grpc_testing.TestServiceClient) error {
ctx, span := env.Tracer().Start(ctx, "grpcTestRunner")
defer span.End()
call, err := client.FullDuplexCall(ctx)
if err != nil {
return fmt.Errorf("call: %w", err)
}
err = call.Send(&grpc_testing.StreamingOutputCallRequest{
ResponseParameters: []*grpc_testing.ResponseParameters{
{
Size: 17,
},
},
ResponseStatus: &grpc_testing.EchoStatus{
Message: "hello",
},
})
if err != nil {
return fmt.Errorf("send: %w", err)
}
resp, err := call.Recv()
if err != nil {
return fmt.Errorf("recv: %w", err)
}
if n := len(resp.Payload.Body); n != 17 {
return fmt.Errorf("got %d bytes, want 17", n)
}
go func() {
<-done
call.CloseSend()
call.Recv()
}()
return nil
}
env.AddUpstream(up)
env.Start()
snippets.WaitStartupComplete(env)
cc := up.Dial(h2c, grpc.WithPerRPCCredentials(grpcBearerToken{"test-access-token"}))
client := grpc_testing.NewTestServiceClient(cc)
ctx, cancel := context.WithCancel(env.Context())
t.Cleanup(cancel)
ch := make(chan error)
for i := range maxRequests {
go func() {
if err := grpcTestRunner(ctx, client); err != nil {
ch <- fmt.Errorf("#%d: got error %w", i, err)
return
}
ch <- nil
<-ctx.Done()
}()
}
var failed int
for range maxRequests {
select {
case err := <-ch:
if !assert.NoError(t, err) {
failed++
}
case <-ctx.Done():
t.Fatal("timeout")
}
}
assert.Equal(t, verifyRequestCount.Load(), uint32(1))
close(done)
if failed > 0 {
t.Logf("\n\n\n *** %d / %d REQUESTS FAILED ***", failed, maxRequests)
}
}
type grpcBearerToken struct {
token string
}
func (t grpcBearerToken) GetRequestMetadata(_ context.Context, _ ...string) (map[string]string, error) {
return map[string]string{
"authorization": "Bearer " + t.token,
}, nil
}
func (t grpcBearerToken) RequireTransportSecurity() bool {
return false
}

View file

@ -20,7 +20,6 @@ import (
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/pkg/contextutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/grpcutil"
@ -135,21 +134,8 @@ func (a *Authorize) maybeGetSessionFromRequest(
}
// attempt to create a session from an incoming idp token
return config.NewIncomingIDPTokenSessionCreator(
func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) {
return storage.GetDataBrokerRecord(ctx, recordType, recordID, 0)
},
func(ctx context.Context, records []*databroker.Record) error {
res, err := a.state.Load().dataBrokerClient.Put(ctx, &databroker.PutRequest{
Records: records,
})
if err != nil {
return err
}
storage.InvalidateCacheForDataBrokerRecords(ctx, res.Records...)
return nil
},
).CreateSession(ctx, a.currentConfig.Load(), policy, hreq)
return a.state.Load().idpTokenSessionCreator.
CreateSession(ctx, a.currentConfig.Load(), policy, hreq)
}
func (a *Authorize) getMCPSession(

View file

@ -35,6 +35,7 @@ type authorizeState struct {
dataBrokerClientConnection *googlegrpc.ClientConn
dataBrokerClient databroker.DataBrokerServiceClient
sessionStore *config.SessionStore
idpTokenSessionCreator config.IncomingIDPTokenSessionCreator
authenticateFlow authenticateFlow
syncQueriers map[string]storage.Querier
mcp *mcp.Handler
@ -100,6 +101,21 @@ func newAuthorizeStateFromConfig(
if err != nil {
return nil, fmt.Errorf("authorize: invalid session store: %w", err)
}
state.idpTokenSessionCreator = config.NewIncomingIDPTokenSessionCreator(
func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) {
return storage.GetDataBrokerRecord(ctx, recordType, recordID, 0)
},
func(ctx context.Context, records []*databroker.Record) error {
res, err := state.dataBrokerClient.Put(ctx, &databroker.PutRequest{
Records: records,
})
if err != nil {
return err
}
storage.InvalidateCacheForDataBrokerRecords(ctx, res.Records...)
return nil
},
)
if cfg.Options.UseStatelessAuthenticateFlow() {
state.authenticateFlow, err = authenticateflow.NewStateless(ctx, tracerProvider, cfg, nil, nil, nil, nil)