grpc: remove ptypes references (#3078)

This commit is contained in:
Caleb Doxsey 2022-02-24 08:37:59 -07:00 committed by GitHub
parent 35f697e491
commit 1342523cda
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 83 additions and 177 deletions

View file

@ -17,7 +17,6 @@ import (
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/csrf"
"github.com/pomerium/pomerium/authenticate/handlers"
"github.com/pomerium/pomerium/authenticate/handlers/webauthn"
"github.com/pomerium/pomerium/internal/httputil"
@ -30,6 +29,7 @@ import (
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/directory"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
@ -626,7 +626,7 @@ func (a *Authenticate) saveSessionToDataBroker(
if err != nil {
return fmt.Errorf("authenticate: error retrieving user info: %w", err)
}
_, err = user.Put(ctx, state.dataBrokerClient, managerUser.User)
_, err = databroker.Put(ctx, state.dataBrokerClient, managerUser.User)
if err != nil {
return fmt.Errorf("authenticate: error saving user: %w", err)
}
@ -703,13 +703,11 @@ func (a *Authenticate) getCurrentSession(ctx context.Context) (s *session.Sessio
func (a *Authenticate) getUser(ctx context.Context, userID string) (*user.User, error) {
client := a.state.Load().dataBrokerClient
return user.Get(ctx, client, userID)
}
func (a *Authenticate) getDirectoryUser(ctx context.Context, userID string) (*directory.User, error) {
client := a.state.Load().dataBrokerClient
return directory.GetUser(ctx, client, userID)
}

View file

@ -321,7 +321,7 @@ func (h *Handler) handleRegister(w http.ResponseWriter, r *http.Request, state *
// save the user
u.DeviceCredentialIds = append(u.DeviceCredentialIds, deviceCredential.GetId())
_, err = user.Put(ctx, state.Client, u)
_, err = databroker.Put(ctx, state.Client, u)
if err != nil {
return err
}
@ -370,7 +370,7 @@ func (h *Handler) handleUnregister(w http.ResponseWriter, r *http.Request, state
// remove the credential from the user
u.DeviceCredentialIds = removeString(u.DeviceCredentialIds, deviceCredentialID)
_, err = user.Put(ctx, state.Client, u)
_, err = databroker.Put(ctx, state.Client, u)
if err != nil {
return err
}

View file

@ -14,7 +14,6 @@ import (
"github.com/go-jose/go-jose/v3/jwt"
"github.com/golang/mock/gomock"
"github.com/golang/protobuf/ptypes"
"github.com/golang/protobuf/ptypes/empty"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
@ -23,6 +22,7 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/authenticate/handlers/webauthn"
"github.com/pomerium/pomerium/config"
@ -158,20 +158,10 @@ func TestAuthenticate_SignIn(t *testing.T) {
encryptedEncoder: tt.encoder,
dataBrokerClient: mockDataBrokerServiceClient{
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
data, err := ptypes.MarshalAny(&session.Session{
Id: "SESSION_ID",
})
if err != nil {
return nil, err
}
return &databroker.GetResponse{
Record: &databroker.Record{
Version: 1,
Type: data.GetTypeUrl(),
Id: "SESSION_ID",
Data: data,
},
Record: databroker.NewRecord(&session.Session{
Id: "SESSION_ID",
}),
}, nil
},
},
@ -322,20 +312,10 @@ func TestAuthenticate_SignOut(t *testing.T) {
sharedEncoder: mock.Encoder{},
dataBrokerClient: mockDataBrokerServiceClient{
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
data, err := ptypes.MarshalAny(&session.Session{
Id: "SESSION_ID",
})
if err != nil {
return nil, err
}
return &databroker.GetResponse{
Record: &databroker.Record{
Version: 1,
Type: data.GetTypeUrl(),
Id: "SESSION_ID",
Data: data,
},
Record: databroker.NewRecord(&session.Session{
Id: "SESSION_ID",
}),
}, nil
},
put: func(ctx context.Context, in *databroker.PutRequest, opts ...grpc.CallOption) (*databroker.PutResponse, error) {
@ -583,20 +563,10 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
sharedEncoder: signer,
dataBrokerClient: mockDataBrokerServiceClient{
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
data, err := ptypes.MarshalAny(&session.Session{
Id: "SESSION_ID",
})
if err != nil {
return nil, err
}
return &databroker.GetResponse{
Record: &databroker.Record{
Version: 1,
Type: data.GetTypeUrl(),
Id: "SESSION_ID",
Data: data,
},
Record: databroker.NewRecord(&session.Session{
Id: "SESSION_ID",
}),
}, nil
},
},
@ -671,7 +641,6 @@ func TestAuthenticate_userInfo(t *testing.T) {
t.Parallel()
now := time.Now()
pbNow, _ := ptypes.TimestampProto(now)
tests := []struct {
name string
url *url.URL
@ -727,22 +696,12 @@ func TestAuthenticate_userInfo(t *testing.T) {
sharedEncoder: signer,
dataBrokerClient: mockDataBrokerServiceClient{
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
data, err := ptypes.MarshalAny(&session.Session{
Id: "SESSION_ID",
UserId: "USER_ID",
IdToken: &session.IDToken{IssuedAt: pbNow},
})
if err != nil {
return nil, err
}
return &databroker.GetResponse{
Record: &databroker.Record{
Version: 1,
Type: data.GetTypeUrl(),
Record: databroker.NewRecord(&session.Session{
Id: "SESSION_ID",
Data: data,
},
UserId: "USER_ID",
IdToken: &session.IDToken{IssuedAt: timestamppb.New(now)},
}),
}, nil
},
},

View file

@ -18,7 +18,6 @@ import (
envoy_http_connection_manager "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3"
envoy_extensions_transport_sockets_tls_v3 "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3"
envoy_type_v3 "github.com/envoyproxy/go-control-plane/envoy/type/v3"
"github.com/golang/protobuf/ptypes"
"github.com/golang/protobuf/ptypes/any"
"github.com/golang/protobuf/ptypes/wrappers"
"github.com/scylladb/go-set"
@ -432,7 +431,7 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter(
var maxStreamDuration *durationpb.Duration
if options.WriteTimeout > 0 {
maxStreamDuration = ptypes.DurationProto(options.WriteTimeout)
maxStreamDuration = durationpb.New(options.WriteTimeout)
}
rc, err := b.buildRouteConfiguration("main", virtualHosts)
@ -452,10 +451,10 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter(
HttpFilters: filters,
AccessLog: buildAccessLogs(options),
CommonHttpProtocolOptions: &envoy_config_core_v3.HttpProtocolOptions{
IdleTimeout: ptypes.DurationProto(options.IdleTimeout),
IdleTimeout: durationpb.New(options.IdleTimeout),
MaxStreamDuration: maxStreamDuration,
},
RequestTimeout: ptypes.DurationProto(options.ReadTimeout),
RequestTimeout: durationpb.New(options.ReadTimeout),
Tracing: &envoy_http_connection_manager.HttpConnectionManager_Tracing{
RandomSampling: &envoy_type_v3.Percent{Value: options.TracingSampleRate * 100},
Provider: tracingProvider,

View file

@ -6,7 +6,6 @@ import (
"strconv"
"testing"
"github.com/golang/protobuf/ptypes"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
@ -18,6 +17,7 @@ import (
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/protoutil"
)
const bufSize = 1024 * 1024
@ -49,7 +49,7 @@ func TestServerSync(t *testing.T) {
require.NoError(t, err)
defer conn.Close()
c := databroker.NewDataBrokerServiceClient(conn)
any, _ := ptypes.MarshalAny(new(user.User))
any := protoutil.NewAny(new(user.User))
numRecords := 200
var serverVersion uint64
@ -101,7 +101,7 @@ func BenchmarkSync(b *testing.B) {
}
defer conn.Close()
c := databroker.NewDataBrokerServiceClient(conn)
any, _ := ptypes.MarshalAny(new(session.Session))
any := protoutil.NewAny(new(session.Session))
numRecords := 10000
for i := 0; i < numRecords; i++ {

View file

@ -4,7 +4,6 @@ import (
"strings"
envoy_service_accesslog_v3 "github.com/envoyproxy/go-control-plane/envoy/service/accesslog/v3"
"github.com/golang/protobuf/ptypes"
"github.com/rs/zerolog"
"github.com/pomerium/pomerium/internal/log"
@ -43,7 +42,7 @@ func (srv *Server) StreamAccessLogs(stream envoy_service_accesslog_v3.AccessLogS
evt = evt.Str("forwarded-for", entry.GetRequest().GetForwardedFor())
evt = evt.Str("request-id", entry.GetRequest().GetRequestId())
// response properties
dur, _ := ptypes.Duration(entry.CommonProperties.TimeToLastDownstreamTxByte)
dur := entry.CommonProperties.TimeToLastDownstreamTxByte.AsDuration()
evt = evt.Dur("duration", dur)
evt = evt.Uint64("size", entry.Response.ResponseBodyBytes)
evt = evt.Uint32("response-code", entry.GetResponse().GetResponseCode().GetValue())

View file

@ -4,8 +4,8 @@ import (
"encoding/json"
"time"
"github.com/golang/protobuf/ptypes"
"github.com/google/btree"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/pkg/grpc/session"
@ -65,8 +65,8 @@ func (s Session) NextRefresh() time.Time {
var tm time.Time
if s.GetOauthToken().GetExpiresAt() != nil {
expiry, err := ptypes.Timestamp(s.GetOauthToken().GetExpiresAt())
if err == nil && !expiry.IsZero() {
expiry := s.GetOauthToken().GetExpiresAt().AsTime()
if s.GetOauthToken().GetExpiresAt().IsValid() && !expiry.IsZero() {
expiry = expiry.Add(-s.gracePeriod)
if tm.IsZero() || expiry.Before(tm) {
tm = expiry
@ -75,8 +75,8 @@ func (s Session) NextRefresh() time.Time {
}
if s.GetExpiresAt() != nil {
expiry, err := ptypes.Timestamp(s.GetExpiresAt())
if err == nil && !expiry.IsZero() {
expiry := s.GetExpiresAt().AsTime()
if s.GetExpiresAt().IsValid() && !expiry.IsZero() {
if tm.IsZero() || expiry.Before(tm) {
tm = expiry
}
@ -119,14 +119,14 @@ func (s *Session) UnmarshalJSON(data []byte) error {
if exp, ok := raw["exp"]; ok {
var secs int64
if err := json.Unmarshal(exp, &secs); err == nil {
s.Session.IdToken.ExpiresAt, _ = ptypes.TimestampProto(time.Unix(secs, 0))
s.Session.IdToken.ExpiresAt = timestamppb.New(time.Unix(secs, 0))
}
delete(raw, "exp")
}
if iat, ok := raw["iat"]; ok {
var secs int64
if err := json.Unmarshal(iat, &secs); err == nil {
s.Session.IdToken.IssuedAt, _ = ptypes.TimestampProto(time.Unix(secs, 0))
s.Session.IdToken.IssuedAt = timestamppb.New(time.Unix(secs, 0))
}
delete(raw, "iat")
}

View file

@ -6,9 +6,9 @@ import (
"testing"
"time"
"github.com/golang/protobuf/ptypes"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/types/known/structpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/protoutil"
@ -41,21 +41,18 @@ func TestSession_NextRefresh(t *testing.T) {
assert.Equal(t, tm1.Add(time.Minute), s.NextRefresh())
tm2 := time.Date(2020, 6, 5, 13, 0, 0, 0, time.UTC)
pbtm2, _ := ptypes.TimestampProto(tm2)
s.OauthToken = &session.OAuthToken{
ExpiresAt: pbtm2,
ExpiresAt: timestamppb.New(tm2),
}
assert.Equal(t, tm2.Add(-time.Second*10), s.NextRefresh())
tm3 := time.Date(2020, 6, 5, 12, 15, 0, 0, time.UTC)
pbtm3, _ := ptypes.TimestampProto(tm3)
s.ExpiresAt = pbtm3
s.ExpiresAt = timestamppb.New(tm3)
assert.Equal(t, tm3, s.NextRefresh())
}
func TestSession_UnmarshalJSON(t *testing.T) {
tm := time.Date(2020, 6, 5, 12, 0, 0, 0, time.UTC)
pbtm, _ := ptypes.TimestampProto(tm)
var s Session
err := json.Unmarshal([]byte(`{
"iss": "https://some.issuer.com",
@ -69,8 +66,8 @@ func TestSession_UnmarshalJSON(t *testing.T) {
assert.NotNil(t, s.Session.IdToken)
assert.Equal(t, "https://some.issuer.com", s.Session.IdToken.Issuer)
assert.Equal(t, "subject", s.Session.IdToken.Subject)
assert.Equal(t, pbtm, s.Session.IdToken.ExpiresAt)
assert.Equal(t, pbtm, s.Session.IdToken.IssuedAt)
assert.Equal(t, timestamppb.New(tm), s.Session.IdToken.ExpiresAt)
assert.Equal(t, timestamppb.New(tm), s.Session.IdToken.IssuedAt)
assert.Equal(t, map[string]*structpb.ListValue{
"some-other-claim": {Values: []*structpb.Value{protoutil.ToStruct("xyz")}},
}, s.Claims)

View file

@ -490,7 +490,7 @@ func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
continue
}
record, err := user.Put(ctx, mgr.cfg.Load().dataBrokerClient, u.User)
res, err := databroker.Put(ctx, mgr.cfg.Load().dataBrokerClient, u.User)
if err != nil {
log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()).
@ -499,7 +499,7 @@ func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
continue
}
mgr.onUpdateUser(ctx, record, u.User)
mgr.onUpdateUser(ctx, res.GetRecord(), u.User)
}
}

View file

@ -10,7 +10,6 @@ import (
"github.com/pomerium/pomerium/internal/signal"
pb "github.com/pomerium/pomerium/pkg/grpc/registry"
"github.com/golang/protobuf/ptypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb"
@ -116,10 +115,7 @@ func (s *inMemoryServer) lockAndReport(services []*pb.Service) (bool, error) {
// reportLocked updates registration and also returns an indication whether service list was updated
func (s *inMemoryServer) reportLocked(services []*pb.Service) (bool, error) {
expires, err := ptypes.TimestampProto(time.Now().Add(s.ttl))
if err != nil {
return false, err
}
expires := timestamppb.New(time.Now().Add(s.ttl))
inserted := false
for _, svc := range services {

View file

@ -5,11 +5,48 @@ import (
"context"
"fmt"
"io"
"google.golang.org/protobuf/proto"
"github.com/pomerium/pomerium/pkg/grpcutil"
"github.com/pomerium/pomerium/pkg/protoutil"
)
//go:generate go run github.com/golang/mock/mockgen -source=databroker.pb.go -destination ./mock_databroker/databroker.pb.go DataBrokerServiceClient
//go:generate go run github.com/golang/mock/mockgen -source=leaser.go -destination ./mock_databroker/leaser.go LeaserHandler
type recordObject interface {
proto.Message
GetId() string
}
// NewRecord creates a new Record.
func NewRecord(object recordObject) *Record {
return &Record{
Type: grpcutil.GetTypeURL(object),
Id: object.GetId(),
Data: protoutil.NewAny(object),
}
}
// Get gets a record from the databroker and unmarshals it into the object.
func Get(ctx context.Context, client DataBrokerServiceClient, object recordObject) error {
res, err := client.Get(ctx, &GetRequest{
Type: grpcutil.GetTypeURL(object),
Id: object.GetId(),
})
if err != nil {
return err
}
return res.GetRecord().GetData().UnmarshalTo(object)
}
// Put puts a record into the databroker.
func Put(ctx context.Context, client DataBrokerServiceClient, object recordObject) (*PutResponse, error) {
return client.Put(ctx, &PutRequest{Record: NewRecord(object)})
}
// ApplyOffsetAndLimit applies the offset and limit to the list of records.
func ApplyOffsetAndLimit(all []*Record, offset, limit int) (records []*Record, totalCount int) {
records = all

View file

@ -4,49 +4,19 @@ package directory
import (
context "context"
"github.com/golang/protobuf/ptypes"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
// GetGroup gets a directory group from the databroker.
func GetGroup(ctx context.Context, client databroker.DataBrokerServiceClient, groupID string) (*Group, error) {
any, _ := ptypes.MarshalAny(new(Group))
res, err := client.Get(ctx, &databroker.GetRequest{
Type: any.GetTypeUrl(),
Id: groupID,
})
if err != nil {
return nil, err
}
var g Group
err = ptypes.UnmarshalAny(res.GetRecord().GetData(), &g)
if err != nil {
return nil, err
}
return &g, nil
g := Group{Id: groupID}
return &g, databroker.Get(ctx, client, &g)
}
// GetUser gets a directory user from the databroker.
func GetUser(ctx context.Context, client databroker.DataBrokerServiceClient, userID string) (*User, error) {
any, _ := ptypes.MarshalAny(new(User))
res, err := client.Get(ctx, &databroker.GetRequest{
Type: any.GetTypeUrl(),
Id: userID,
})
if err != nil {
return nil, err
}
var u User
err = ptypes.UnmarshalAny(res.GetRecord().GetData(), &u)
if err != nil {
return nil, err
}
return &u, nil
u := User{Id: userID}
return &u, databroker.Get(ctx, client, &u)
}
// Options are directory provider options.

View file

@ -3,66 +3,17 @@ package user
import (
context "context"
"fmt"
"google.golang.org/protobuf/types/known/structpb"
"github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/protoutil"
)
// Get gets a user from the databroker.
func Get(ctx context.Context, client databroker.DataBrokerServiceClient, userID string) (*User, error) {
any := protoutil.NewAny(new(User))
res, err := client.Get(ctx, &databroker.GetRequest{
Type: any.GetTypeUrl(),
Id: userID,
})
if err != nil {
return nil, err
}
var u User
err = res.GetRecord().GetData().UnmarshalTo(&u)
if err != nil {
return nil, fmt.Errorf("error unmarshaling user from databroker: %w", err)
}
return &u, nil
}
// Put sets a user in the databroker.
func Put(ctx context.Context, client databroker.DataBrokerServiceClient, u *User) (*databroker.Record, error) {
any := protoutil.NewAny(u)
res, err := client.Put(ctx, &databroker.PutRequest{
Record: &databroker.Record{
Type: any.GetTypeUrl(),
Id: u.Id,
Data: any,
},
})
if err != nil {
return nil, err
}
return res.GetRecord(), nil
}
// PutServiceAccount sets a service account in the databroker.
func PutServiceAccount(ctx context.Context, client databroker.DataBrokerServiceClient, sa *ServiceAccount) (*databroker.Record, error) {
any := protoutil.NewAny(sa)
res, err := client.Put(ctx, &databroker.PutRequest{
Record: &databroker.Record{
Type: any.GetTypeUrl(),
Id: sa.GetId(),
Data: any,
},
})
if err != nil {
return nil, err
}
return res.GetRecord(), nil
u := &User{Id: userID}
return u, databroker.Get(ctx, client, u)
}
// AddClaims adds the flattened claims to the user.