mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 10:26:29 +02:00
grpc: remove ptypes references (#3078)
This commit is contained in:
parent
35f697e491
commit
1342523cda
13 changed files with 83 additions and 177 deletions
|
@ -17,7 +17,6 @@ import (
|
|||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/csrf"
|
||||
|
||||
"github.com/pomerium/pomerium/authenticate/handlers"
|
||||
"github.com/pomerium/pomerium/authenticate/handlers/webauthn"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
|
@ -30,6 +29,7 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
|
@ -626,7 +626,7 @@ func (a *Authenticate) saveSessionToDataBroker(
|
|||
if err != nil {
|
||||
return fmt.Errorf("authenticate: error retrieving user info: %w", err)
|
||||
}
|
||||
_, err = user.Put(ctx, state.dataBrokerClient, managerUser.User)
|
||||
_, err = databroker.Put(ctx, state.dataBrokerClient, managerUser.User)
|
||||
if err != nil {
|
||||
return fmt.Errorf("authenticate: error saving user: %w", err)
|
||||
}
|
||||
|
@ -703,13 +703,11 @@ func (a *Authenticate) getCurrentSession(ctx context.Context) (s *session.Sessio
|
|||
|
||||
func (a *Authenticate) getUser(ctx context.Context, userID string) (*user.User, error) {
|
||||
client := a.state.Load().dataBrokerClient
|
||||
|
||||
return user.Get(ctx, client, userID)
|
||||
}
|
||||
|
||||
func (a *Authenticate) getDirectoryUser(ctx context.Context, userID string) (*directory.User, error) {
|
||||
client := a.state.Load().dataBrokerClient
|
||||
|
||||
return directory.GetUser(ctx, client, userID)
|
||||
}
|
||||
|
||||
|
|
|
@ -321,7 +321,7 @@ func (h *Handler) handleRegister(w http.ResponseWriter, r *http.Request, state *
|
|||
|
||||
// save the user
|
||||
u.DeviceCredentialIds = append(u.DeviceCredentialIds, deviceCredential.GetId())
|
||||
_, err = user.Put(ctx, state.Client, u)
|
||||
_, err = databroker.Put(ctx, state.Client, u)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -370,7 +370,7 @@ func (h *Handler) handleUnregister(w http.ResponseWriter, r *http.Request, state
|
|||
|
||||
// remove the credential from the user
|
||||
u.DeviceCredentialIds = removeString(u.DeviceCredentialIds, deviceCredentialID)
|
||||
_, err = user.Put(ctx, state.Client, u)
|
||||
_, err = databroker.Put(ctx, state.Client, u)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -14,7 +14,6 @@ import (
|
|||
|
||||
"github.com/go-jose/go-jose/v3/jwt"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
"github.com/golang/protobuf/ptypes/empty"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -23,6 +22,7 @@ import (
|
|||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/authenticate/handlers/webauthn"
|
||||
"github.com/pomerium/pomerium/config"
|
||||
|
@ -158,20 +158,10 @@ func TestAuthenticate_SignIn(t *testing.T) {
|
|||
encryptedEncoder: tt.encoder,
|
||||
dataBrokerClient: mockDataBrokerServiceClient{
|
||||
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
|
||||
data, err := ptypes.MarshalAny(&session.Session{
|
||||
Id: "SESSION_ID",
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &databroker.GetResponse{
|
||||
Record: &databroker.Record{
|
||||
Version: 1,
|
||||
Type: data.GetTypeUrl(),
|
||||
Id: "SESSION_ID",
|
||||
Data: data,
|
||||
},
|
||||
Record: databroker.NewRecord(&session.Session{
|
||||
Id: "SESSION_ID",
|
||||
}),
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
|
@ -322,20 +312,10 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
|||
sharedEncoder: mock.Encoder{},
|
||||
dataBrokerClient: mockDataBrokerServiceClient{
|
||||
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
|
||||
data, err := ptypes.MarshalAny(&session.Session{
|
||||
Id: "SESSION_ID",
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &databroker.GetResponse{
|
||||
Record: &databroker.Record{
|
||||
Version: 1,
|
||||
Type: data.GetTypeUrl(),
|
||||
Id: "SESSION_ID",
|
||||
Data: data,
|
||||
},
|
||||
Record: databroker.NewRecord(&session.Session{
|
||||
Id: "SESSION_ID",
|
||||
}),
|
||||
}, nil
|
||||
},
|
||||
put: func(ctx context.Context, in *databroker.PutRequest, opts ...grpc.CallOption) (*databroker.PutResponse, error) {
|
||||
|
@ -583,20 +563,10 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
|||
sharedEncoder: signer,
|
||||
dataBrokerClient: mockDataBrokerServiceClient{
|
||||
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
|
||||
data, err := ptypes.MarshalAny(&session.Session{
|
||||
Id: "SESSION_ID",
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &databroker.GetResponse{
|
||||
Record: &databroker.Record{
|
||||
Version: 1,
|
||||
Type: data.GetTypeUrl(),
|
||||
Id: "SESSION_ID",
|
||||
Data: data,
|
||||
},
|
||||
Record: databroker.NewRecord(&session.Session{
|
||||
Id: "SESSION_ID",
|
||||
}),
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
|
@ -671,7 +641,6 @@ func TestAuthenticate_userInfo(t *testing.T) {
|
|||
t.Parallel()
|
||||
|
||||
now := time.Now()
|
||||
pbNow, _ := ptypes.TimestampProto(now)
|
||||
tests := []struct {
|
||||
name string
|
||||
url *url.URL
|
||||
|
@ -727,22 +696,12 @@ func TestAuthenticate_userInfo(t *testing.T) {
|
|||
sharedEncoder: signer,
|
||||
dataBrokerClient: mockDataBrokerServiceClient{
|
||||
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
|
||||
data, err := ptypes.MarshalAny(&session.Session{
|
||||
Id: "SESSION_ID",
|
||||
UserId: "USER_ID",
|
||||
IdToken: &session.IDToken{IssuedAt: pbNow},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &databroker.GetResponse{
|
||||
Record: &databroker.Record{
|
||||
Version: 1,
|
||||
Type: data.GetTypeUrl(),
|
||||
Record: databroker.NewRecord(&session.Session{
|
||||
Id: "SESSION_ID",
|
||||
Data: data,
|
||||
},
|
||||
UserId: "USER_ID",
|
||||
IdToken: &session.IDToken{IssuedAt: timestamppb.New(now)},
|
||||
}),
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
|
|
|
@ -18,7 +18,6 @@ import (
|
|||
envoy_http_connection_manager "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3"
|
||||
envoy_extensions_transport_sockets_tls_v3 "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3"
|
||||
envoy_type_v3 "github.com/envoyproxy/go-control-plane/envoy/type/v3"
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
"github.com/golang/protobuf/ptypes/any"
|
||||
"github.com/golang/protobuf/ptypes/wrappers"
|
||||
"github.com/scylladb/go-set"
|
||||
|
@ -432,7 +431,7 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter(
|
|||
|
||||
var maxStreamDuration *durationpb.Duration
|
||||
if options.WriteTimeout > 0 {
|
||||
maxStreamDuration = ptypes.DurationProto(options.WriteTimeout)
|
||||
maxStreamDuration = durationpb.New(options.WriteTimeout)
|
||||
}
|
||||
|
||||
rc, err := b.buildRouteConfiguration("main", virtualHosts)
|
||||
|
@ -452,10 +451,10 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter(
|
|||
HttpFilters: filters,
|
||||
AccessLog: buildAccessLogs(options),
|
||||
CommonHttpProtocolOptions: &envoy_config_core_v3.HttpProtocolOptions{
|
||||
IdleTimeout: ptypes.DurationProto(options.IdleTimeout),
|
||||
IdleTimeout: durationpb.New(options.IdleTimeout),
|
||||
MaxStreamDuration: maxStreamDuration,
|
||||
},
|
||||
RequestTimeout: ptypes.DurationProto(options.ReadTimeout),
|
||||
RequestTimeout: durationpb.New(options.ReadTimeout),
|
||||
Tracing: &envoy_http_connection_manager.HttpConnectionManager_Tracing{
|
||||
RandomSampling: &envoy_type_v3.Percent{Value: options.TracingSampleRate * 100},
|
||||
Provider: tracingProvider,
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
|
@ -18,6 +17,7 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
)
|
||||
|
||||
const bufSize = 1024 * 1024
|
||||
|
@ -49,7 +49,7 @@ func TestServerSync(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
c := databroker.NewDataBrokerServiceClient(conn)
|
||||
any, _ := ptypes.MarshalAny(new(user.User))
|
||||
any := protoutil.NewAny(new(user.User))
|
||||
numRecords := 200
|
||||
|
||||
var serverVersion uint64
|
||||
|
@ -101,7 +101,7 @@ func BenchmarkSync(b *testing.B) {
|
|||
}
|
||||
defer conn.Close()
|
||||
c := databroker.NewDataBrokerServiceClient(conn)
|
||||
any, _ := ptypes.MarshalAny(new(session.Session))
|
||||
any := protoutil.NewAny(new(session.Session))
|
||||
numRecords := 10000
|
||||
|
||||
for i := 0; i < numRecords; i++ {
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"strings"
|
||||
|
||||
envoy_service_accesslog_v3 "github.com/envoyproxy/go-control-plane/envoy/service/accesslog/v3"
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
|
@ -43,7 +42,7 @@ func (srv *Server) StreamAccessLogs(stream envoy_service_accesslog_v3.AccessLogS
|
|||
evt = evt.Str("forwarded-for", entry.GetRequest().GetForwardedFor())
|
||||
evt = evt.Str("request-id", entry.GetRequest().GetRequestId())
|
||||
// response properties
|
||||
dur, _ := ptypes.Duration(entry.CommonProperties.TimeToLastDownstreamTxByte)
|
||||
dur := entry.CommonProperties.TimeToLastDownstreamTxByte.AsDuration()
|
||||
evt = evt.Dur("duration", dur)
|
||||
evt = evt.Uint64("size", entry.Response.ResponseBodyBytes)
|
||||
evt = evt.Uint32("response-code", entry.GetResponse().GetResponseCode().GetValue())
|
||||
|
|
|
@ -4,8 +4,8 @@ import (
|
|||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
"github.com/google/btree"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/identity"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
|
@ -65,8 +65,8 @@ func (s Session) NextRefresh() time.Time {
|
|||
var tm time.Time
|
||||
|
||||
if s.GetOauthToken().GetExpiresAt() != nil {
|
||||
expiry, err := ptypes.Timestamp(s.GetOauthToken().GetExpiresAt())
|
||||
if err == nil && !expiry.IsZero() {
|
||||
expiry := s.GetOauthToken().GetExpiresAt().AsTime()
|
||||
if s.GetOauthToken().GetExpiresAt().IsValid() && !expiry.IsZero() {
|
||||
expiry = expiry.Add(-s.gracePeriod)
|
||||
if tm.IsZero() || expiry.Before(tm) {
|
||||
tm = expiry
|
||||
|
@ -75,8 +75,8 @@ func (s Session) NextRefresh() time.Time {
|
|||
}
|
||||
|
||||
if s.GetExpiresAt() != nil {
|
||||
expiry, err := ptypes.Timestamp(s.GetExpiresAt())
|
||||
if err == nil && !expiry.IsZero() {
|
||||
expiry := s.GetExpiresAt().AsTime()
|
||||
if s.GetExpiresAt().IsValid() && !expiry.IsZero() {
|
||||
if tm.IsZero() || expiry.Before(tm) {
|
||||
tm = expiry
|
||||
}
|
||||
|
@ -119,14 +119,14 @@ func (s *Session) UnmarshalJSON(data []byte) error {
|
|||
if exp, ok := raw["exp"]; ok {
|
||||
var secs int64
|
||||
if err := json.Unmarshal(exp, &secs); err == nil {
|
||||
s.Session.IdToken.ExpiresAt, _ = ptypes.TimestampProto(time.Unix(secs, 0))
|
||||
s.Session.IdToken.ExpiresAt = timestamppb.New(time.Unix(secs, 0))
|
||||
}
|
||||
delete(raw, "exp")
|
||||
}
|
||||
if iat, ok := raw["iat"]; ok {
|
||||
var secs int64
|
||||
if err := json.Unmarshal(iat, &secs); err == nil {
|
||||
s.Session.IdToken.IssuedAt, _ = ptypes.TimestampProto(time.Unix(secs, 0))
|
||||
s.Session.IdToken.IssuedAt = timestamppb.New(time.Unix(secs, 0))
|
||||
}
|
||||
delete(raw, "iat")
|
||||
}
|
||||
|
|
|
@ -6,9 +6,9 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
|
@ -41,21 +41,18 @@ func TestSession_NextRefresh(t *testing.T) {
|
|||
assert.Equal(t, tm1.Add(time.Minute), s.NextRefresh())
|
||||
|
||||
tm2 := time.Date(2020, 6, 5, 13, 0, 0, 0, time.UTC)
|
||||
pbtm2, _ := ptypes.TimestampProto(tm2)
|
||||
s.OauthToken = &session.OAuthToken{
|
||||
ExpiresAt: pbtm2,
|
||||
ExpiresAt: timestamppb.New(tm2),
|
||||
}
|
||||
assert.Equal(t, tm2.Add(-time.Second*10), s.NextRefresh())
|
||||
|
||||
tm3 := time.Date(2020, 6, 5, 12, 15, 0, 0, time.UTC)
|
||||
pbtm3, _ := ptypes.TimestampProto(tm3)
|
||||
s.ExpiresAt = pbtm3
|
||||
s.ExpiresAt = timestamppb.New(tm3)
|
||||
assert.Equal(t, tm3, s.NextRefresh())
|
||||
}
|
||||
|
||||
func TestSession_UnmarshalJSON(t *testing.T) {
|
||||
tm := time.Date(2020, 6, 5, 12, 0, 0, 0, time.UTC)
|
||||
pbtm, _ := ptypes.TimestampProto(tm)
|
||||
var s Session
|
||||
err := json.Unmarshal([]byte(`{
|
||||
"iss": "https://some.issuer.com",
|
||||
|
@ -69,8 +66,8 @@ func TestSession_UnmarshalJSON(t *testing.T) {
|
|||
assert.NotNil(t, s.Session.IdToken)
|
||||
assert.Equal(t, "https://some.issuer.com", s.Session.IdToken.Issuer)
|
||||
assert.Equal(t, "subject", s.Session.IdToken.Subject)
|
||||
assert.Equal(t, pbtm, s.Session.IdToken.ExpiresAt)
|
||||
assert.Equal(t, pbtm, s.Session.IdToken.IssuedAt)
|
||||
assert.Equal(t, timestamppb.New(tm), s.Session.IdToken.ExpiresAt)
|
||||
assert.Equal(t, timestamppb.New(tm), s.Session.IdToken.IssuedAt)
|
||||
assert.Equal(t, map[string]*structpb.ListValue{
|
||||
"some-other-claim": {Values: []*structpb.Value{protoutil.ToStruct("xyz")}},
|
||||
}, s.Claims)
|
||||
|
|
|
@ -490,7 +490,7 @@ func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
|
|||
continue
|
||||
}
|
||||
|
||||
record, err := user.Put(ctx, mgr.cfg.Load().dataBrokerClient, u.User)
|
||||
res, err := databroker.Put(ctx, mgr.cfg.Load().dataBrokerClient, u.User)
|
||||
if err != nil {
|
||||
log.Error(ctx).Err(err).
|
||||
Str("user_id", s.GetUserId()).
|
||||
|
@ -499,7 +499,7 @@ func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
|
|||
continue
|
||||
}
|
||||
|
||||
mgr.onUpdateUser(ctx, record, u.User)
|
||||
mgr.onUpdateUser(ctx, res.GetRecord(), u.User)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -10,7 +10,6 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/signal"
|
||||
pb "github.com/pomerium/pomerium/pkg/grpc/registry"
|
||||
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
@ -116,10 +115,7 @@ func (s *inMemoryServer) lockAndReport(services []*pb.Service) (bool, error) {
|
|||
|
||||
// reportLocked updates registration and also returns an indication whether service list was updated
|
||||
func (s *inMemoryServer) reportLocked(services []*pb.Service) (bool, error) {
|
||||
expires, err := ptypes.TimestampProto(time.Now().Add(s.ttl))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
expires := timestamppb.New(time.Now().Add(s.ttl))
|
||||
|
||||
inserted := false
|
||||
for _, svc := range services {
|
||||
|
|
|
@ -5,11 +5,48 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
)
|
||||
|
||||
//go:generate go run github.com/golang/mock/mockgen -source=databroker.pb.go -destination ./mock_databroker/databroker.pb.go DataBrokerServiceClient
|
||||
//go:generate go run github.com/golang/mock/mockgen -source=leaser.go -destination ./mock_databroker/leaser.go LeaserHandler
|
||||
|
||||
type recordObject interface {
|
||||
proto.Message
|
||||
GetId() string
|
||||
}
|
||||
|
||||
// NewRecord creates a new Record.
|
||||
func NewRecord(object recordObject) *Record {
|
||||
return &Record{
|
||||
Type: grpcutil.GetTypeURL(object),
|
||||
Id: object.GetId(),
|
||||
Data: protoutil.NewAny(object),
|
||||
}
|
||||
}
|
||||
|
||||
// Get gets a record from the databroker and unmarshals it into the object.
|
||||
func Get(ctx context.Context, client DataBrokerServiceClient, object recordObject) error {
|
||||
res, err := client.Get(ctx, &GetRequest{
|
||||
Type: grpcutil.GetTypeURL(object),
|
||||
Id: object.GetId(),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return res.GetRecord().GetData().UnmarshalTo(object)
|
||||
}
|
||||
|
||||
// Put puts a record into the databroker.
|
||||
func Put(ctx context.Context, client DataBrokerServiceClient, object recordObject) (*PutResponse, error) {
|
||||
return client.Put(ctx, &PutRequest{Record: NewRecord(object)})
|
||||
}
|
||||
|
||||
// ApplyOffsetAndLimit applies the offset and limit to the list of records.
|
||||
func ApplyOffsetAndLimit(all []*Record, offset, limit int) (records []*Record, totalCount int) {
|
||||
records = all
|
||||
|
|
|
@ -4,49 +4,19 @@ package directory
|
|||
import (
|
||||
context "context"
|
||||
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
)
|
||||
|
||||
// GetGroup gets a directory group from the databroker.
|
||||
func GetGroup(ctx context.Context, client databroker.DataBrokerServiceClient, groupID string) (*Group, error) {
|
||||
any, _ := ptypes.MarshalAny(new(Group))
|
||||
|
||||
res, err := client.Get(ctx, &databroker.GetRequest{
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: groupID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var g Group
|
||||
err = ptypes.UnmarshalAny(res.GetRecord().GetData(), &g)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &g, nil
|
||||
g := Group{Id: groupID}
|
||||
return &g, databroker.Get(ctx, client, &g)
|
||||
}
|
||||
|
||||
// GetUser gets a directory user from the databroker.
|
||||
func GetUser(ctx context.Context, client databroker.DataBrokerServiceClient, userID string) (*User, error) {
|
||||
any, _ := ptypes.MarshalAny(new(User))
|
||||
|
||||
res, err := client.Get(ctx, &databroker.GetRequest{
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: userID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var u User
|
||||
err = ptypes.UnmarshalAny(res.GetRecord().GetData(), &u)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &u, nil
|
||||
u := User{Id: userID}
|
||||
return &u, databroker.Get(ctx, client, &u)
|
||||
}
|
||||
|
||||
// Options are directory provider options.
|
||||
|
|
|
@ -3,66 +3,17 @@ package user
|
|||
|
||||
import (
|
||||
context "context"
|
||||
"fmt"
|
||||
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/identity"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
)
|
||||
|
||||
// Get gets a user from the databroker.
|
||||
func Get(ctx context.Context, client databroker.DataBrokerServiceClient, userID string) (*User, error) {
|
||||
any := protoutil.NewAny(new(User))
|
||||
|
||||
res, err := client.Get(ctx, &databroker.GetRequest{
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: userID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var u User
|
||||
err = res.GetRecord().GetData().UnmarshalTo(&u)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error unmarshaling user from databroker: %w", err)
|
||||
}
|
||||
|
||||
return &u, nil
|
||||
}
|
||||
|
||||
// Put sets a user in the databroker.
|
||||
func Put(ctx context.Context, client databroker.DataBrokerServiceClient, u *User) (*databroker.Record, error) {
|
||||
any := protoutil.NewAny(u)
|
||||
res, err := client.Put(ctx, &databroker.PutRequest{
|
||||
Record: &databroker.Record{
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: u.Id,
|
||||
Data: any,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return res.GetRecord(), nil
|
||||
}
|
||||
|
||||
// PutServiceAccount sets a service account in the databroker.
|
||||
func PutServiceAccount(ctx context.Context, client databroker.DataBrokerServiceClient, sa *ServiceAccount) (*databroker.Record, error) {
|
||||
any := protoutil.NewAny(sa)
|
||||
res, err := client.Put(ctx, &databroker.PutRequest{
|
||||
Record: &databroker.Record{
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: sa.GetId(),
|
||||
Data: any,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return res.GetRecord(), nil
|
||||
u := &User{Id: userID}
|
||||
return u, databroker.Get(ctx, client, u)
|
||||
}
|
||||
|
||||
// AddClaims adds the flattened claims to the user.
|
||||
|
|
Loading…
Add table
Reference in a new issue