mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 18:36:30 +02:00
* authorize: add databroker server and record version to result, force sync via polling * wrap inmem store to take read lock when grabbing databroker versions * address code review comments * reset max to 0
116 lines
3.6 KiB
Go
116 lines
3.6 KiB
Go
package authorize
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/status"
|
|
"google.golang.org/protobuf/proto"
|
|
"google.golang.org/protobuf/types/known/anypb"
|
|
|
|
"github.com/pomerium/pomerium/config"
|
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
|
"github.com/pomerium/pomerium/pkg/grpcutil"
|
|
)
|
|
|
|
func TestAuthorize_waitForRecordSync(t *testing.T) {
|
|
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*30)
|
|
defer clearTimeout()
|
|
|
|
o := &config.Options{
|
|
AuthenticateURLString: "https://authN.example.com",
|
|
DataBrokerURLString: "https://databroker.example.com",
|
|
SharedKey: "gXK6ggrlIW2HyKyUF9rUO4azrDgxhDPWqw9y+lJU7B8=",
|
|
Policies: testPolicies(t),
|
|
}
|
|
t.Run("skip if exists", func(t *testing.T) {
|
|
a, err := New(&config.Config{Options: o})
|
|
require.NoError(t, err)
|
|
|
|
a.store.UpdateRecord(0, newRecord(&session.Session{
|
|
Id: "SESSION_ID",
|
|
}))
|
|
a.state.Load().dataBrokerClient = mockDataBrokerServiceClient{
|
|
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
|
|
panic("should never be called")
|
|
},
|
|
}
|
|
a.waitForRecordSync(ctx, grpcutil.GetTypeURL(new(session.Session)), "SESSION_ID")
|
|
})
|
|
t.Run("skip if not found", func(t *testing.T) {
|
|
a, err := New(&config.Config{Options: o})
|
|
require.NoError(t, err)
|
|
|
|
callCount := 0
|
|
a.state.Load().dataBrokerClient = mockDataBrokerServiceClient{
|
|
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
|
|
callCount++
|
|
return nil, status.Error(codes.NotFound, "not found")
|
|
},
|
|
}
|
|
a.waitForRecordSync(ctx, grpcutil.GetTypeURL(new(session.Session)), "SESSION_ID")
|
|
assert.Equal(t, 1, callCount, "should be called once")
|
|
})
|
|
t.Run("poll", func(t *testing.T) {
|
|
a, err := New(&config.Config{Options: o})
|
|
require.NoError(t, err)
|
|
|
|
callCount := 0
|
|
a.state.Load().dataBrokerClient = mockDataBrokerServiceClient{
|
|
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
|
|
callCount++
|
|
switch callCount {
|
|
case 1:
|
|
s := &session.Session{Id: "SESSION_ID"}
|
|
a.store.UpdateRecord(0, newRecord(s))
|
|
return &databroker.GetResponse{Record: newRecord(s)}, nil
|
|
default:
|
|
panic("should never be called")
|
|
}
|
|
},
|
|
}
|
|
a.waitForRecordSync(ctx, grpcutil.GetTypeURL(new(session.Session)), "SESSION_ID")
|
|
})
|
|
t.Run("timeout", func(t *testing.T) {
|
|
a, err := New(&config.Config{Options: o})
|
|
require.NoError(t, err)
|
|
|
|
tctx, clearTimeout := context.WithTimeout(ctx, time.Millisecond*100)
|
|
defer clearTimeout()
|
|
|
|
callCount := 0
|
|
a.state.Load().dataBrokerClient = mockDataBrokerServiceClient{
|
|
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
|
|
callCount++
|
|
s := &session.Session{Id: "SESSION_ID"}
|
|
return &databroker.GetResponse{Record: newRecord(s)}, nil
|
|
},
|
|
}
|
|
a.waitForRecordSync(tctx, grpcutil.GetTypeURL(new(session.Session)), "SESSION_ID")
|
|
assert.Greater(t, callCount, 5) // should be ~ 20, but allow for non-determinism
|
|
})
|
|
}
|
|
|
|
type storableMessage interface {
|
|
proto.Message
|
|
GetId() string
|
|
}
|
|
|
|
func newRecord(msg storableMessage) *databroker.Record {
|
|
any, err := anypb.New(msg)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return &databroker.Record{
|
|
Version: 1,
|
|
Type: any.GetTypeUrl(),
|
|
Id: msg.GetId(),
|
|
Data: any,
|
|
}
|
|
}
|