mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-02 19:04:14 +02:00
IdP-token-based session creation makes requests to the authenticate service to verify tokens. A singleflight group is meant to prevent duplicate verification requests from being in flight at the same time, but it is not working as intended: the IncomingIDPTokenSessionCreator that owns it is initialized in the request path, so each request presumably gets its own group. Move the IncomingIDPTokenSessionCreator initialization out of the request path and into the main authorize state object. Add an integration test asserting that a large number of requests with the same IdP token results in only one token verification request to the authenticate service.
165 lines
3.7 KiB
Go
165 lines
3.7 KiB
Go
package authorize_test
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"sync/atomic"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/credentials/insecure"
|
|
"google.golang.org/grpc/interop"
|
|
"google.golang.org/grpc/interop/grpc_testing"
|
|
|
|
"github.com/pomerium/pomerium/config"
|
|
"github.com/pomerium/pomerium/internal/jwtutil"
|
|
"github.com/pomerium/pomerium/internal/testenv"
|
|
"github.com/pomerium/pomerium/internal/testenv/snippets"
|
|
"github.com/pomerium/pomerium/internal/testenv/upstreams"
|
|
"github.com/pomerium/pomerium/pkg/authenticateapi"
|
|
)
|
|
|
|
func TestIDPTokenRequests(t *testing.T) {
|
|
const maxRequests = 1000
|
|
|
|
var verifyRequestCount atomic.Uint32
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc("/.pomerium/verify-access-token", func(w http.ResponseWriter, _ *http.Request) {
|
|
verifyRequestCount.Add(1)
|
|
json.NewEncoder(w).Encode(&authenticateapi.VerifyTokenResponse{
|
|
Valid: true,
|
|
Claims: jwtutil.Claims{"sub": "test-user"},
|
|
})
|
|
})
|
|
authSrv := httptest.NewTLSServer(mux)
|
|
t.Cleanup(authSrv.Close)
|
|
|
|
env := testenv.New(t)
|
|
|
|
env.Add(testenv.ModifierFunc(func(_ context.Context, cfg *config.Config) {
|
|
fmt := config.BearerTokenFormatIDPAccessToken
|
|
cfg.Options.BearerTokenFormat = &fmt
|
|
cfg.Options.AuthenticateURLString = authSrv.URL
|
|
}))
|
|
|
|
up := upstreams.GRPC(insecure.NewCredentials())
|
|
srv := interop.NewTestServer()
|
|
grpc_testing.RegisterTestServiceServer(up, srv)
|
|
|
|
h2c := up.Route().
|
|
From(env.SubdomainURL("grpc-h2c")).
|
|
Policy(func(p *config.Policy) {
|
|
var ppl config.PPLPolicy
|
|
err := ppl.UnmarshalJSON([]byte(`{
|
|
"allow": {
|
|
"and": [{
|
|
"user": {"is": "test-user"}
|
|
}]
|
|
}
|
|
}`))
|
|
require.NoError(t, err)
|
|
p.Policy = &ppl
|
|
})
|
|
|
|
done := make(chan struct{})
|
|
|
|
grpcTestRunner := func(ctx context.Context, client grpc_testing.TestServiceClient) error {
|
|
ctx, span := env.Tracer().Start(ctx, "grpcTestRunner")
|
|
defer span.End()
|
|
|
|
call, err := client.FullDuplexCall(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("call: %w", err)
|
|
}
|
|
err = call.Send(&grpc_testing.StreamingOutputCallRequest{
|
|
ResponseParameters: []*grpc_testing.ResponseParameters{
|
|
{
|
|
Size: 17,
|
|
},
|
|
},
|
|
ResponseStatus: &grpc_testing.EchoStatus{
|
|
Message: "hello",
|
|
},
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("send: %w", err)
|
|
}
|
|
|
|
resp, err := call.Recv()
|
|
if err != nil {
|
|
return fmt.Errorf("recv: %w", err)
|
|
}
|
|
if n := len(resp.Payload.Body); n != 17 {
|
|
return fmt.Errorf("got %d bytes, want 17", n)
|
|
}
|
|
|
|
go func() {
|
|
<-done
|
|
call.CloseSend()
|
|
call.Recv()
|
|
}()
|
|
|
|
return nil
|
|
}
|
|
|
|
env.AddUpstream(up)
|
|
env.Start()
|
|
snippets.WaitStartupComplete(env)
|
|
|
|
cc := up.Dial(h2c, grpc.WithPerRPCCredentials(grpcBearerToken{"test-access-token"}))
|
|
client := grpc_testing.NewTestServiceClient(cc)
|
|
|
|
ctx, cancel := context.WithCancel(env.Context())
|
|
t.Cleanup(cancel)
|
|
ch := make(chan error)
|
|
for i := range maxRequests {
|
|
go func() {
|
|
if err := grpcTestRunner(ctx, client); err != nil {
|
|
ch <- fmt.Errorf("#%d: got error %w", i, err)
|
|
return
|
|
}
|
|
|
|
ch <- nil
|
|
<-ctx.Done()
|
|
}()
|
|
}
|
|
|
|
var failed int
|
|
for range maxRequests {
|
|
select {
|
|
case err := <-ch:
|
|
if !assert.NoError(t, err) {
|
|
failed++
|
|
}
|
|
case <-ctx.Done():
|
|
t.Fatal("timeout")
|
|
}
|
|
}
|
|
|
|
assert.Equal(t, verifyRequestCount.Load(), uint32(1))
|
|
|
|
close(done)
|
|
|
|
if failed > 0 {
|
|
t.Logf("\n\n\n *** %d / %d REQUESTS FAILED ***", failed, maxRequests)
|
|
}
|
|
}
|
|
|
|
// grpcBearerToken is a per-RPC gRPC credential that attaches one fixed bearer
// token to the "authorization" metadata of every outgoing call.
type grpcBearerToken struct {
	token string // sent verbatim after the "Bearer " prefix
}

// GetRequestMetadata builds the authorization metadata for an outgoing RPC.
// The context and URI arguments are ignored: every call receives the same
// token.
func (c grpcBearerToken) GetRequestMetadata(context.Context, ...string) (map[string]string, error) {
	md := make(map[string]string, 1)
	md["authorization"] = "Bearer " + c.token
	return md, nil
}

// RequireTransportSecurity reports false, so gRPC does not insist on a
// secure transport before attaching this credential.
func (grpcBearerToken) RequireTransportSecurity() bool {
	return false
}
|