mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-02 00:10:45 +02:00
rework session updates to use new patch method
Update the AccessTracker, WebAuthn handlers, and identity manager refresh loop to perform their session record updates using the databroker Patch() method. This should prevent any of these updates from conflicting.
This commit is contained in:
parent
d5da872157
commit
cef0c89070
6 changed files with 88 additions and 15 deletions
|
@ -2,11 +2,13 @@ package authorize
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/fieldmaskpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
|
@ -158,13 +160,12 @@ func (tracker *AccessTracker) updateSession(
|
|||
ctx, clearTimeout := context.WithTimeout(ctx, accessTrackerUpdateTimeout)
|
||||
defer clearTimeout()
|
||||
|
||||
s, err := session.Get(ctx, client, sessionID)
|
||||
if status.Code(err) == codes.NotFound {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
s := &session.Session{Id: sessionID, AccessedAt: timestamppb.Now()}
|
||||
m, err := fieldmaskpb.New(s, "accessed_at")
|
||||
if err != nil {
|
||||
return fmt.Errorf("internal error: %w", err)
|
||||
}
|
||||
s.AccessedAt = timestamppb.Now()
|
||||
_, err = session.Put(ctx, client, s)
|
||||
|
||||
_, err = session.Patch(ctx, client, s, m)
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -13,6 +13,8 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
|
||||
"github.com/pomerium/pomerium/authorize/evaluator"
|
||||
|
@ -21,6 +23,7 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
||||
const certPEM = `
|
||||
|
@ -289,6 +292,37 @@ func (m mockDataBrokerServiceClient) Put(ctx context.Context, in *databroker.Put
|
|||
return m.put(ctx, in, opts...)
|
||||
}
|
||||
|
||||
// Patch emulates the patch operation using Get and Put. (This is not atomic.)
|
||||
func (m mockDataBrokerServiceClient) Patch(ctx context.Context, in *databroker.PatchRequest, opts ...grpc.CallOption) (*databroker.PatchResponse, error) {
|
||||
var records []*databroker.Record
|
||||
for _, record := range in.GetRecords() {
|
||||
getResponse, err := m.Get(ctx, &databroker.GetRequest{
|
||||
Type: record.GetType(),
|
||||
Id: record.GetId(),
|
||||
})
|
||||
if storage.IsNotFound(err) {
|
||||
continue
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
existing := getResponse.GetRecord()
|
||||
if err := storage.PatchRecord(existing, record, in.GetFieldMask()); err != nil {
|
||||
return nil, status.Errorf(codes.Unknown, err.Error())
|
||||
}
|
||||
|
||||
records = append(records, record)
|
||||
}
|
||||
putResponse, err := m.Put(ctx, &databroker.PutRequest{Records: records})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &databroker.PatchResponse{
|
||||
ServerVersion: putResponse.GetServerVersion(),
|
||||
Records: putResponse.GetRecords(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func mustParseURL(rawURL string) url.URL {
|
||||
u, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/google/uuid"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/fieldmaskpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
|
@ -405,8 +406,13 @@ func (h *Handler) handleView(w http.ResponseWriter, r *http.Request, state *Stat
|
|||
}
|
||||
|
||||
func (h *Handler) saveSessionAndRedirect(w http.ResponseWriter, r *http.Request, state *State, rawRedirectURI string) error {
|
||||
fm, err := fieldmaskpb.New(state.Session, "device_credentials")
|
||||
if err != nil {
|
||||
return fmt.Errorf("internal error: %w", err)
|
||||
}
|
||||
|
||||
// save the session to the databroker
|
||||
res, err := session.Put(r.Context(), state.Client, state.Session)
|
||||
res, err := session.Patch(r.Context(), state.Client, state.Session, fm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/fieldmaskpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||
|
@ -288,7 +289,13 @@ func (mgr *Manager) refreshSessionInternal(
|
|||
return false
|
||||
}
|
||||
|
||||
if _, err := session.Put(ctx, mgr.cfg.Load().dataBrokerClient, s.Session); err != nil {
|
||||
fm, err := fieldmaskpb.New(s.Session, "oauth_token", "claims")
|
||||
if err != nil {
|
||||
log.Error(ctx).Err(err).Msg("internal error")
|
||||
return false
|
||||
}
|
||||
|
||||
if _, err := session.Patch(ctx, mgr.cfg.Load().dataBrokerClient, s.Session, fm); err != nil {
|
||||
log.Error(ctx).Err(err).
|
||||
Str("user_id", s.GetUserId()).
|
||||
Str("session_id", s.GetId()).
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/fieldmaskpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/events"
|
||||
|
@ -241,12 +242,17 @@ func TestManager_refreshSession(t *testing.T) {
|
|||
ExpiresAt: timestamppb.New(now.Add(5 * time.Minute)),
|
||||
RefreshToken: "new-refresh-token",
|
||||
}
|
||||
client.EXPECT().Put(gomock.Any(),
|
||||
objectsAreEqualMatcher{&databroker.PutRequest{Records: []*databroker.Record{{
|
||||
Type: "type.googleapis.com/session.Session",
|
||||
Id: "session-id",
|
||||
Data: protoutil.NewAny(expectedSession),
|
||||
}}}}).
|
||||
client.EXPECT().Patch(gomock.Any(), objectsAreEqualMatcher{
|
||||
&databroker.PatchRequest{
|
||||
Records: []*databroker.Record{{
|
||||
Type: "type.googleapis.com/session.Session",
|
||||
Id: "session-id",
|
||||
Data: protoutil.NewAny(expectedSession),
|
||||
}},
|
||||
FieldMask: &fieldmaskpb.FieldMask{
|
||||
Paths: []string{"oauth_token", "claims"},
|
||||
},
|
||||
}}).
|
||||
Return(nil /* this result is currently unused */, nil)
|
||||
mgr.refreshSession(context.Background(), "user-id", "session-id")
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"time"
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/fieldmaskpb"
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
|
@ -63,6 +64,24 @@ func Put(ctx context.Context, client databroker.DataBrokerServiceClient, s *Sess
|
|||
return res, err
|
||||
}
|
||||
|
||||
// Patch updates specific fields of an existing session in the databroker.
|
||||
func Patch(
|
||||
ctx context.Context, client databroker.DataBrokerServiceClient,
|
||||
s *Session, fields *fieldmaskpb.FieldMask,
|
||||
) (*databroker.PatchResponse, error) {
|
||||
s = proto.Clone(s).(*Session)
|
||||
data := protoutil.NewAny(s)
|
||||
res, err := client.Patch(ctx, &databroker.PatchRequest{
|
||||
Records: []*databroker.Record{{
|
||||
Type: data.GetTypeUrl(),
|
||||
Id: s.Id,
|
||||
Data: data,
|
||||
}},
|
||||
FieldMask: fields,
|
||||
})
|
||||
return res, err
|
||||
}
|
||||
|
||||
// AddClaims adds the flattened claims to the session.
|
||||
func (x *Session) AddClaims(claims identity.FlattenedClaims) {
|
||||
if x.Claims == nil {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue