pomerium/pkg/storage/inmemory/backend_test.go
Caleb Doxsey 2dc778035d
databroker: add support for field masks on Put (#3210)
* databroker: add support for field masks on Put

* return errors

* clean up go.mod
2022-03-29 16:36:40 -06:00

292 lines
7.4 KiB
Go

package inmemory
import (
"context"
"fmt"
"testing"
"time"
envoy_type_v3 "github.com/envoyproxy/go-control-plane/envoy/type/v3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/fieldmaskpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/protoutil"
"github.com/pomerium/pomerium/pkg/storage"
)
// TestBackend walks a single shared backend through the basic record
// lifecycle: lookup of a missing record, create + read, delete, and bulk
// insert followed by GetAll. The subtests run sequentially and build on each
// other's state, so their order matters.
func TestBackend(t *testing.T) {
	const (
		recordType = "TYPE"
		recordID   = "abcd"
		bulkCount  = 1000
	)

	ctx := context.Background()
	backend := New()
	defer func() { _ = backend.Close() }()

	t.Run("get missing record", func(t *testing.T) {
		got, err := backend.Get(ctx, recordType, recordID)
		require.Error(t, err)
		assert.Nil(t, got)
	})

	t.Run("get record", func(t *testing.T) {
		payload := new(anypb.Any)
		serverVersion, err := backend.Put(ctx, &databroker.Record{
			Type: recordType,
			Id:   recordID,
			Data: payload,
		}, nil)
		assert.NoError(t, err)
		assert.Equal(t, backend.serverVersion, serverVersion)

		got, err := backend.Get(ctx, recordType, recordID)
		require.NoError(t, err)
		// Nothing follows these checks, so stopping the subtest on a nil
		// record is equivalent to skipping the field assertions.
		require.NotNil(t, got)
		assert.Equal(t, payload, got.Data)
		assert.Nil(t, got.DeletedAt)
		assert.Equal(t, recordID, got.Id)
		assert.NotNil(t, got.ModifiedAt)
		assert.Equal(t, recordType, got.Type)
		assert.Equal(t, uint64(1), got.Version)
	})

	t.Run("delete record", func(t *testing.T) {
		// A Put carrying DeletedAt is how this test deletes the record; the
		// subsequent Get is expected to fail.
		serverVersion, err := backend.Put(ctx, &databroker.Record{
			Type:      recordType,
			Id:        recordID,
			DeletedAt: timestamppb.Now(),
		}, nil)
		assert.NoError(t, err)
		assert.Equal(t, backend.serverVersion, serverVersion)

		got, err := backend.Get(ctx, recordType, recordID)
		assert.Error(t, err)
		assert.Nil(t, got)
	})

	t.Run("get all records", func(t *testing.T) {
		for n := 0; n < bulkCount; n++ {
			serverVersion, err := backend.Put(ctx, &databroker.Record{
				Type: recordType,
				Id:   fmt.Sprint(n),
			}, nil)
			assert.NoError(t, err)
			assert.Equal(t, backend.serverVersion, serverVersion)
		}

		all, versions, err := backend.GetAll(ctx)
		assert.NoError(t, err)
		assert.Len(t, all, bulkCount)
		// Two Puts in the earlier subtests (create, delete) plus the bulk
		// inserts above.
		assert.Equal(t, uint64(bulkCount+2), versions.LatestRecordVersion)
	})
}
// TestExpiry writes 1000 records, confirms a sync from version 0 returns all
// of them, then calls removeChangesBefore with a cutoff one second in the
// future and confirms a fresh sync returns nothing.
func TestExpiry(t *testing.T) {
	const total = 1000

	ctx := context.Background()
	backend := New(WithExpiry(0))
	defer func() { _ = backend.Close() }()

	for n := 0; n < total; n++ {
		serverVersion, err := backend.Put(ctx, &databroker.Record{
			Type: "TYPE",
			Id:   fmt.Sprint(n),
		}, nil)
		assert.NoError(t, err)
		assert.Equal(t, backend.serverVersion, serverVersion)
	}

	// drain opens a sync stream at record version 0, reads every record that
	// is available without waiting, closes the stream and returns the records.
	drain := func() []*databroker.Record {
		stream, err := backend.Sync(ctx, backend.serverVersion, 0)
		require.NoError(t, err)

		var collected []*databroker.Record
		for stream.Next(false) {
			collected = append(collected, stream.Record())
		}
		_ = stream.Close()
		return collected
	}

	require.Len(t, drain(), total)

	backend.removeChangesBefore(time.Now().Add(time.Second))

	require.Len(t, drain(), 0)
}
// TestConcurrency runs 1000 GetAll calls and 1000 Put calls against the same
// backend from two goroutines at once. Return values of those calls are
// discarded and both goroutines always return nil, so the test is mainly
// useful as a target for the race detector (go test -race).
func TestConcurrency(t *testing.T) {
	ctx := context.Background()
	backend := New()
	// Release the backend when the test finishes, as the other tests in this
	// file do.
	defer func() { _ = backend.Close() }()

	eg, ctx := errgroup.WithContext(ctx)
	eg.Go(func() error {
		for i := 0; i < 1000; i++ {
			_, _, _ = backend.GetAll(ctx)
		}
		return nil
	})
	eg.Go(func() error {
		for i := 0; i < 1000; i++ {
			_, _ = backend.Put(ctx, &databroker.Record{
				Id: fmt.Sprint(i),
			}, nil)
		}
		return nil
	})
	assert.NoError(t, eg.Wait())
}
// TestStream checks that records written by one goroutine are delivered, in
// order, to a sync stream being read by another goroutine. The reader expects
// record i to have id fmt.Sprint(i) and version i+1.
func TestStream(t *testing.T) {
	const total = 10000

	ctx := context.Background()
	backend := New()
	defer func() { _ = backend.Close() }()

	// The stream is bound to the root context, not the errgroup context
	// created below.
	stream, err := backend.Sync(ctx, backend.serverVersion, 0)
	require.NoError(t, err)
	defer func() { _ = stream.Close() }()

	group, groupCtx := errgroup.WithContext(ctx)

	// Reader: Next(true) presumably waits for the next record to arrive.
	group.Go(func() error {
		for want := 0; want < total; want++ {
			assert.True(t, stream.Next(true))
			assert.Nil(t, stream.Err())
			assert.Equal(t, "TYPE", stream.Record().GetType())
			assert.Equal(t, fmt.Sprint(want), stream.Record().GetId())
			assert.Equal(t, uint64(want+1), stream.Record().GetVersion())
		}
		return nil
	})

	// Writer.
	group.Go(func() error {
		for n := 0; n < total; n++ {
			_, putErr := backend.Put(groupCtx, &databroker.Record{
				Type: "TYPE",
				Id:   fmt.Sprint(n),
			}, nil)
			assert.NoError(t, putErr)
		}
		return nil
	})

	require.NoError(t, group.Wait())
}
// TestStreamClose verifies the three ways a sync stream can be terminated
// and the error each one reports from stream.Err():
//   - closing the backend      -> storage.ErrStreamClosed
//   - closing the stream       -> storage.ErrStreamClosed
//   - cancelling the context   -> context.Canceled
//
// In every case Next(true) must return false.
func TestStreamClose(t *testing.T) {
	ctx := context.Background()

	t.Run("by backend", func(t *testing.T) {
		backend := New()
		stream, err := backend.Sync(ctx, backend.serverVersion, 0)
		require.NoError(t, err)
		require.NoError(t, backend.Close())
		assert.False(t, stream.Next(true))
		assert.Equal(t, storage.ErrStreamClosed, stream.Err())
	})

	t.Run("by stream", func(t *testing.T) {
		backend := New()
		// The stream is closed explicitly below; the backend still has to be
		// released once the subtest is done.
		defer func() { _ = backend.Close() }()

		stream, err := backend.Sync(ctx, backend.serverVersion, 0)
		require.NoError(t, err)
		require.NoError(t, stream.Close())
		assert.False(t, stream.Next(true))
		assert.Equal(t, storage.ErrStreamClosed, stream.Err())
	})

	t.Run("by context", func(t *testing.T) {
		ctx, cancel := context.WithCancel(ctx)
		// cancel is invoked explicitly below; the deferred call only covers
		// an early exit via require and is a no-op otherwise.
		defer cancel()

		backend := New()
		defer func() { _ = backend.Close() }()

		stream, err := backend.Sync(ctx, backend.serverVersion, 0)
		require.NoError(t, err)
		// Cancelling the context ends iteration, but neither the stream nor
		// the backend is closed by it here, so release both on exit.
		defer func() { _ = stream.Close() }()

		cancel()
		assert.False(t, stream.Next(true))
		assert.Equal(t, context.Canceled, stream.Err())
	})
}
// TestCapacity sets a capacity of 3 on the "EXAMPLE" record type, writes ten
// records with ids "0".."9", and expects GetAll to return only the last three
// written, in insertion order.
func TestCapacity(t *testing.T) {
	const (
		recordType = "EXAMPLE"
		written    = 10
	)

	ctx := context.Background()
	backend := New()
	defer func() { _ = backend.Close() }()

	opts := &databroker.Options{Capacity: proto.Uint64(3)}
	err := backend.SetOptions(ctx, recordType, opts)
	require.NoError(t, err)

	for n := 0; n < written; n++ {
		_, err = backend.Put(ctx, &databroker.Record{
			Type: recordType,
			Id:   fmt.Sprint(n),
		}, nil)
		require.NoError(t, err)
	}

	remaining, _, err := backend.GetAll(ctx)
	require.NoError(t, err)
	assert.Len(t, remaining, 3)

	var gotIDs []string
	for _, rec := range remaining {
		gotIDs = append(gotIDs, rec.GetId())
	}
	assert.Equal(t, []string{"7", "8", "9"}, gotIDs, "should contain recent records")
}
// TestLease checks lease hand-over between two holders, "a" and "b", on the
// lease named "test":
//  1. a acquires the lease for 30s.
//  2. b's attempt is rejected while a holds it.
//  3. a calls Lease with a zero duration, which reports false.
//  4. b can now acquire the lease.
func TestLease(t *testing.T) {
	ctx := context.Background()
	backend := New()
	// Release the backend when the test finishes, as the other tests in this
	// file do.
	defer func() { _ = backend.Close() }()

	{
		ok, err := backend.Lease(ctx, "test", "a", time.Second*30)
		require.NoError(t, err)
		assert.True(t, ok, "expected a to acquire the lease")
	}
	{
		ok, err := backend.Lease(ctx, "test", "b", time.Second*30)
		require.NoError(t, err)
		assert.False(t, ok, "expected b to fail to acquire the lease")
	}
	{
		ok, err := backend.Lease(ctx, "test", "a", 0)
		require.NoError(t, err)
		assert.False(t, ok, "expected a to clear the lease")
	}
	{
		ok, err := backend.Lease(ctx, "test", "b", time.Second*30)
		require.NoError(t, err)
		assert.True(t, ok, "expected b to acquire the lease")
	}
}
// TestFieldMask verifies that a Put with a field mask only overwrites the
// masked fields of the stored data. A SemanticVersion 1.1.1 is stored first,
// then 2.2.2 is Put with the mask {major_number, patch}; the stored record is
// expected to end up as 2.1.2 at record version 2.
func TestFieldMask(t *testing.T) {
	ctx := context.Background()
	backend := New()
	defer func() { _ = backend.Close() }()

	// Initial record, written without a mask.
	_, err := backend.Put(ctx, &databroker.Record{
		Type: "example",
		Id:   "example",
		Data: protoutil.NewAny(&envoy_type_v3.SemanticVersion{
			MajorNumber: 1,
			MinorNumber: 1,
			Patch:       1,
		}),
	}, nil)
	require.NoError(t, err)

	// Masked update: minor_number is not in the mask and must keep its
	// previous value.
	_, err = backend.Put(ctx, &databroker.Record{
		Type: "example",
		Id:   "example",
		Data: protoutil.NewAny(&envoy_type_v3.SemanticVersion{
			MajorNumber: 2,
			MinorNumber: 2,
			Patch:       2,
		}),
	}, &fieldmaskpb.FieldMask{
		Paths: []string{"major_number", "patch"},
	})
	require.NoError(t, err)

	// Fail the test cleanly instead of dereferencing a nil record below.
	record, err := backend.Get(ctx, "example", "example")
	require.NoError(t, err)
	require.NotNil(t, record)

	// ModifiedAt is a timestamp set at write time, so drop it before the
	// JSON comparison.
	record.ModifiedAt = nil
	testutil.AssertProtoJSONEqual(t, `{
		"data": {
			"@type": "type.googleapis.com/envoy.type.v3.SemanticVersion",
			"majorNumber": 2,
			"minorNumber": 1,
			"patch": 2
		},
		"id": "example",
		"type": "example",
		"version": "2"
	}`, record)
}