pomerium/pkg/storage/storagetest/storagetest.go
Kenneth Jenkins a771b82a72
storage/inmemory: fix Patch() error handling (#4838)
The Patch() method was intended to skip any records that do not
currently exist. However, currently inmemory.Backend.Patch() will return
ErrNotFound if the last record in the records slice is not found (it
will ignore any other previous records that are not found).

Update the error handling logic here to be consistent with the postgres
backend, and add a unit test to exercise this case.
2023-12-07 12:21:10 -08:00

188 lines
6 KiB
Go

// Package storagetest contains test cases for use in verifying the behavior of
// a storage.Backend implementation.
package storagetest
import (
"context"
"fmt"
"runtime"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/fieldmaskpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/storage"
)
// TestBackendPatch verifies the behavior of the backend Patch() method.
//
// It is a reusable test case: callers supply the storage.Backend under test
// and the context to use for all backend calls. Three subtests are run:
//
//   - "not found": patching a record that does not exist returns no error and
//     no updated records.
//   - "basic": only the fields named in the field mask are modified, missing
//     records are skipped, and the patched records are visible to Get().
//   - "concurrent": two goroutines patching disjoint fields of the same
//     record do not lose each other's updates.
func TestBackendPatch(t *testing.T, ctx context.Context, backend storage.Backend) { //nolint:revive
	// mkRecord wraps a session in a databroker record. It takes the *testing.T
	// of the calling (sub)test so that a failure is reported against, and
	// aborts, the test that is actually running.
	mkRecord := func(t *testing.T, s *session.Session) *databroker.Record {
		t.Helper()
		a, err := anypb.New(s)
		require.NoError(t, err)
		return &databroker.Record{
			Type: a.TypeUrl,
			Id:   s.Id,
			Data: a,
		}
	}

	t.Run("not found", func(t *testing.T) {
		mask, err := fieldmaskpb.New(&session.Session{}, "oauth_token")
		require.NoError(t, err)

		s := &session.Session{Id: "session-id-that-does-not-exist"}

		_, updated, err := backend.Patch(ctx, []*databroker.Record{mkRecord(t, s)}, mask)
		require.NoError(t, err)
		assert.Empty(t, updated)
	})

	t.Run("basic", func(t *testing.T) {
		// Populate an initial set of session records.
		s1 := &session.Session{
			Id:         "session-1",
			IdToken:    &session.IDToken{Issuer: "issuer-1"},
			OauthToken: &session.OAuthToken{AccessToken: "access-token-1"},
		}
		s2 := &session.Session{
			Id:         "session-2",
			IdToken:    &session.IDToken{Issuer: "issuer-2"},
			OauthToken: &session.OAuthToken{AccessToken: "access-token-2"},
		}
		s3 := &session.Session{
			Id:         "session-3",
			IdToken:    &session.IDToken{Issuer: "issuer-3"},
			OauthToken: &session.OAuthToken{AccessToken: "access-token-3"},
		}
		initial := []*databroker.Record{mkRecord(t, s1), mkRecord(t, s2), mkRecord(t, s3)}

		_, err := backend.Put(ctx, initial)
		require.NoError(t, err)

		// Now patch just the oauth_token field.
		u1 := &session.Session{
			Id:         "session-1",
			OauthToken: &session.OAuthToken{AccessToken: "access-token-1-new"},
		}
		u2 := &session.Session{
			Id:         "session-4-does-not-exist",
			OauthToken: &session.OAuthToken{AccessToken: "access-token-4-new"},
		}
		u3 := &session.Session{
			Id:         "session-3",
			OauthToken: &session.OAuthToken{AccessToken: "access-token-3-new"},
		}

		mask, err := fieldmaskpb.New(&session.Session{}, "oauth_token")
		require.NoError(t, err)

		_, updated, err := backend.Patch(
			ctx, []*databroker.Record{mkRecord(t, u1), mkRecord(t, u2), mkRecord(t, u3)}, mask)
		require.NoError(t, err)

		// The OAuthToken message should be updated but the IDToken message should
		// be unchanged, as it was not included in the field mask. The results
		// should indicate that only two records were updated (one did not exist).
		// Use require here: the indexing below would panic on a shorter slice.
		require.Len(t, updated, 2)
		assert.Greater(t, updated[0].Version, initial[0].Version)
		assert.Greater(t, updated[1].Version, initial[2].Version)
		testutil.AssertProtoJSONEqual(t, `{
"@type": "type.googleapis.com/session.Session",
"id": "session-1",
"idToken": {
"issuer": "issuer-1"
},
"oauthToken": {
"accessToken": "access-token-1-new"
}
}`, updated[0].Data)
		testutil.AssertProtoJSONEqual(t, `{
"@type": "type.googleapis.com/session.Session",
"id": "session-3",
"idToken": {
"issuer": "issuer-3"
},
"oauthToken": {
"accessToken": "access-token-3-new"
}
}`, updated[1].Data)

		// Verify that the updates will indeed be seen by a subsequent Get().
		// Note: first truncate the modified_at timestamps to 1 µs precision, as
		// that is the maximum precision supported by Postgres.
		r1, err := backend.Get(ctx, "type.googleapis.com/session.Session", "session-1")
		require.NoError(t, err)
		truncateTimestamps(updated[0].ModifiedAt, r1.ModifiedAt)
		testutil.AssertProtoEqual(t, updated[0], r1)

		r3, err := backend.Get(ctx, "type.googleapis.com/session.Session", "session-3")
		require.NoError(t, err)
		truncateTimestamps(updated[1].ModifiedAt, r3.ModifiedAt)
		testutil.AssertProtoEqual(t, updated[1], r3)
	})

	t.Run("concurrent", func(t *testing.T) {
		if n := runtime.GOMAXPROCS(0); n < 2 {
			t.Skipf("skipping concurrent test (GOMAXPROCS = %d)", n)
		}

		rs1 := make([]*databroker.Record, 1)
		rs2 := make([]*databroker.Record, 1)

		s1 := session.Session{Id: "concurrent", OauthToken: &session.OAuthToken{}}
		s2 := session.Session{Id: "concurrent", OauthToken: &session.OAuthToken{}}

		// Store an initial version of a session record.
		rs1[0] = mkRecord(t, &s1)
		_, err := backend.Put(ctx, rs1)
		require.NoError(t, err)

		fmAccessToken, err := fieldmaskpb.New(&session.Session{}, "oauth_token.access_token")
		require.NoError(t, err)
		fmRefreshToken, err := fieldmaskpb.New(&session.Session{}, "oauth_token.refresh_token")
		require.NoError(t, err)

		var wg sync.WaitGroup

		// Repeatedly make Patch calls to update the session from two separate
		// goroutines (one updating just the access token, the other updating
		// just the refresh token.) Verify that no updates are lost.
		for i := 0; i < 100; i++ {
			access := fmt.Sprintf("access-%d", i)
			s1.OauthToken.AccessToken = access
			rs1[0] = mkRecord(t, &s1)

			refresh := fmt.Sprintf("refresh-%d", i)
			s2.OauthToken.RefreshToken = refresh
			rs2[0] = mkRecord(t, &s2)

			// Each goroutine writes only its own error variable; both are read
			// after wg.Wait(), on the test goroutine, where require may be used.
			var errAccess, errRefresh error
			wg.Add(2)
			go func() {
				defer wg.Done()
				_, _, errAccess = backend.Patch(ctx, rs1, fmAccessToken)
			}()
			go func() {
				defer wg.Done()
				_, _, errRefresh = backend.Patch(ctx, rs2, fmRefreshToken)
			}()
			wg.Wait()
			require.NoError(t, errAccess)
			require.NoError(t, errRefresh)

			r, err := backend.Get(ctx, "type.googleapis.com/session.Session", "concurrent")
			require.NoError(t, err)
			data, err := r.Data.UnmarshalNew()
			require.NoError(t, err)
			s, ok := data.(*session.Session)
			require.True(t, ok, "unexpected message type %T", data)
			require.Equal(t, access, s.OauthToken.AccessToken)
			require.Equal(t, refresh, s.OauthToken.RefreshToken)
		}
	})
}
// truncateTimestamps modifies each given Timestamp message in place, dropping
// the sub-microsecond part of its Nanos field so that the value has 1 µs
// precision.
func truncateTimestamps(ts ...*timestamppb.Timestamp) {
	const nanosPerMicro = 1000
	for i := range ts {
		// Subtracting the remainder is the same as (n / 1000) * 1000.
		ts[i].Nanos -= ts[i].Nanos % nanosPerMicro
	}
}