mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-03 16:59:22 +02:00
authorize: move IdP token session creator initialization (#5616)d
IdP-token-based session creation makes requests to the authenticate service to verify tokens. We have a singleflight group to avoid having duplicate requests in flight, but it looks like this is not working as intended. Move the IncomingIDPTokenSessionCreator initialization into the main authorize state object, and out of the request path. Add an integration test to assert that making a large number of requests with the same IdP token will result in only one token verification request to the authenticate service.
This commit is contained in:
parent
f9fd52067e
commit
e4dc218b81
3 changed files with 183 additions and 16 deletions
165
authorize/authorize_int_test.go
Normal file
165
authorize/authorize_int_test.go
Normal file
|
@ -0,0 +1,165 @@
|
|||
package authorize_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/interop"
|
||||
"google.golang.org/grpc/interop/grpc_testing"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/jwtutil"
|
||||
"github.com/pomerium/pomerium/internal/testenv"
|
||||
"github.com/pomerium/pomerium/internal/testenv/snippets"
|
||||
"github.com/pomerium/pomerium/internal/testenv/upstreams"
|
||||
"github.com/pomerium/pomerium/pkg/authenticateapi"
|
||||
)
|
||||
|
||||
func TestIDPTokenRequests(t *testing.T) {
|
||||
const maxRequests = 1000
|
||||
|
||||
var verifyRequestCount atomic.Uint32
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/.pomerium/verify-access-token", func(w http.ResponseWriter, _ *http.Request) {
|
||||
verifyRequestCount.Add(1)
|
||||
json.NewEncoder(w).Encode(&authenticateapi.VerifyTokenResponse{
|
||||
Valid: true,
|
||||
Claims: jwtutil.Claims{"sub": "test-user"},
|
||||
})
|
||||
})
|
||||
authSrv := httptest.NewTLSServer(mux)
|
||||
t.Cleanup(authSrv.Close)
|
||||
|
||||
env := testenv.New(t)
|
||||
|
||||
env.Add(testenv.ModifierFunc(func(_ context.Context, cfg *config.Config) {
|
||||
fmt := config.BearerTokenFormatIDPAccessToken
|
||||
cfg.Options.BearerTokenFormat = &fmt
|
||||
cfg.Options.AuthenticateURLString = authSrv.URL
|
||||
}))
|
||||
|
||||
up := upstreams.GRPC(insecure.NewCredentials())
|
||||
srv := interop.NewTestServer()
|
||||
grpc_testing.RegisterTestServiceServer(up, srv)
|
||||
|
||||
h2c := up.Route().
|
||||
From(env.SubdomainURL("grpc-h2c")).
|
||||
Policy(func(p *config.Policy) {
|
||||
var ppl config.PPLPolicy
|
||||
err := ppl.UnmarshalJSON([]byte(`{
|
||||
"allow": {
|
||||
"and": [{
|
||||
"user": {"is": "test-user"}
|
||||
}]
|
||||
}
|
||||
}`))
|
||||
require.NoError(t, err)
|
||||
p.Policy = &ppl
|
||||
})
|
||||
|
||||
done := make(chan struct{})
|
||||
|
||||
grpcTestRunner := func(ctx context.Context, client grpc_testing.TestServiceClient) error {
|
||||
ctx, span := env.Tracer().Start(ctx, "grpcTestRunner")
|
||||
defer span.End()
|
||||
|
||||
call, err := client.FullDuplexCall(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("call: %w", err)
|
||||
}
|
||||
err = call.Send(&grpc_testing.StreamingOutputCallRequest{
|
||||
ResponseParameters: []*grpc_testing.ResponseParameters{
|
||||
{
|
||||
Size: 17,
|
||||
},
|
||||
},
|
||||
ResponseStatus: &grpc_testing.EchoStatus{
|
||||
Message: "hello",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("send: %w", err)
|
||||
}
|
||||
|
||||
resp, err := call.Recv()
|
||||
if err != nil {
|
||||
return fmt.Errorf("recv: %w", err)
|
||||
}
|
||||
if n := len(resp.Payload.Body); n != 17 {
|
||||
return fmt.Errorf("got %d bytes, want 17", n)
|
||||
}
|
||||
|
||||
go func() {
|
||||
<-done
|
||||
call.CloseSend()
|
||||
call.Recv()
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
env.AddUpstream(up)
|
||||
env.Start()
|
||||
snippets.WaitStartupComplete(env)
|
||||
|
||||
cc := up.Dial(h2c, grpc.WithPerRPCCredentials(grpcBearerToken{"test-access-token"}))
|
||||
client := grpc_testing.NewTestServiceClient(cc)
|
||||
|
||||
ctx, cancel := context.WithCancel(env.Context())
|
||||
t.Cleanup(cancel)
|
||||
ch := make(chan error)
|
||||
for i := range maxRequests {
|
||||
go func() {
|
||||
if err := grpcTestRunner(ctx, client); err != nil {
|
||||
ch <- fmt.Errorf("#%d: got error %w", i, err)
|
||||
return
|
||||
}
|
||||
|
||||
ch <- nil
|
||||
<-ctx.Done()
|
||||
}()
|
||||
}
|
||||
|
||||
var failed int
|
||||
for range maxRequests {
|
||||
select {
|
||||
case err := <-ch:
|
||||
if !assert.NoError(t, err) {
|
||||
failed++
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, verifyRequestCount.Load(), uint32(1))
|
||||
|
||||
close(done)
|
||||
|
||||
if failed > 0 {
|
||||
t.Logf("\n\n\n *** %d / %d REQUESTS FAILED ***", failed, maxRequests)
|
||||
}
|
||||
}
|
||||
|
||||
type grpcBearerToken struct {
|
||||
token string
|
||||
}
|
||||
|
||||
func (t grpcBearerToken) GetRequestMetadata(_ context.Context, _ ...string) (map[string]string, error) {
|
||||
return map[string]string{
|
||||
"authorization": "Bearer " + t.token,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t grpcBearerToken) RequireTransportSecurity() bool {
|
||||
return false
|
||||
}
|
|
@ -20,7 +20,6 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/pkg/contextutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||
|
@ -135,21 +134,8 @@ func (a *Authorize) maybeGetSessionFromRequest(
|
|||
}
|
||||
|
||||
// attempt to create a session from an incoming idp token
|
||||
return config.NewIncomingIDPTokenSessionCreator(
|
||||
func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) {
|
||||
return storage.GetDataBrokerRecord(ctx, recordType, recordID, 0)
|
||||
},
|
||||
func(ctx context.Context, records []*databroker.Record) error {
|
||||
res, err := a.state.Load().dataBrokerClient.Put(ctx, &databroker.PutRequest{
|
||||
Records: records,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
storage.InvalidateCacheForDataBrokerRecords(ctx, res.Records...)
|
||||
return nil
|
||||
},
|
||||
).CreateSession(ctx, a.currentConfig.Load(), policy, hreq)
|
||||
return a.state.Load().idpTokenSessionCreator.
|
||||
CreateSession(ctx, a.currentConfig.Load(), policy, hreq)
|
||||
}
|
||||
|
||||
func (a *Authorize) getMCPSession(
|
||||
|
|
|
@ -35,6 +35,7 @@ type authorizeState struct {
|
|||
dataBrokerClientConnection *googlegrpc.ClientConn
|
||||
dataBrokerClient databroker.DataBrokerServiceClient
|
||||
sessionStore *config.SessionStore
|
||||
idpTokenSessionCreator config.IncomingIDPTokenSessionCreator
|
||||
authenticateFlow authenticateFlow
|
||||
syncQueriers map[string]storage.Querier
|
||||
mcp *mcp.Handler
|
||||
|
@ -100,6 +101,21 @@ func newAuthorizeStateFromConfig(
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("authorize: invalid session store: %w", err)
|
||||
}
|
||||
state.idpTokenSessionCreator = config.NewIncomingIDPTokenSessionCreator(
|
||||
func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) {
|
||||
return storage.GetDataBrokerRecord(ctx, recordType, recordID, 0)
|
||||
},
|
||||
func(ctx context.Context, records []*databroker.Record) error {
|
||||
res, err := state.dataBrokerClient.Put(ctx, &databroker.PutRequest{
|
||||
Records: records,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
storage.InvalidateCacheForDataBrokerRecords(ctx, res.Records...)
|
||||
return nil
|
||||
},
|
||||
)
|
||||
|
||||
if cfg.Options.UseStatelessAuthenticateFlow() {
|
||||
state.authenticateFlow, err = authenticateflow.NewStateless(ctx, tracerProvider, cfg, nil, nil, nil, nil)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue