From 1763f026204d623c2c38dffce1948f5bba0114df Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Tue, 20 Oct 2020 10:29:22 -0600 Subject: [PATCH] fix databroker requiring signed jwt (#1538) * add test, explicitly call RequireSignedJWT instead of using interceptor to handle combined gRPC server * register handler, handle config changes * fix nil error in tests * unexport constructor --- cache/cache.go | 10 ++-- cache/databroker.go | 95 +++++++++++++++++++++++++++++++----- cache/databroker_test.go | 3 +- pkg/grpcutil/options.go | 7 +-- pkg/grpcutil/options_test.go | 91 ++++++++++++++++++++++++++++++++++ 5 files changed, 184 insertions(+), 22 deletions(-) create mode 100644 pkg/grpcutil/options_test.go diff --git a/cache/cache.go b/cache/cache.go index c914ba6bd..d68251180 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -28,7 +28,7 @@ import ( // Cache represents the cache service. The cache service is a simple interface // for storing keyed blobs (bytes) of unstructured data. type Cache struct { - dataBrokerServer *DataBrokerServer + dataBrokerServer *dataBrokerServer manager *manager.Manager localListener net.Listener @@ -52,10 +52,7 @@ func New(cfg *config.Config) (*Cache, error) { // No metrics handler because we have one in the control plane. Add one // if we no longer register with that grpc Server - localGRPCServer := grpc.NewServer( - grpc.StreamInterceptor(grpcutil.StreamRequireSignedJWT(cfg.Options.SharedKey)), - grpc.UnaryInterceptor(grpcutil.UnaryRequireSignedJWT(cfg.Options.SharedKey)), - ) + localGRPCServer := grpc.NewServer() clientStatsHandler := telemetry.NewGRPCClientStatsHandler(cfg.Options.Services) clientDialOptions := []grpc.DialOption{ @@ -74,7 +71,7 @@ func New(cfg *config.Config) (*Cache, error) { return nil, err } - dataBrokerServer := NewDataBrokerServer(localGRPCServer, cfg) + dataBrokerServer := newDataBrokerServer(cfg) c := &Cache{ dataBrokerServer: dataBrokerServer, @@ -84,6 +81,7 @@ func New(cfg *config.Config) (*Cache, error) { deprecatedCacheClusterDomain: cfg.Options.GetDataBrokerURL().Hostname(), dataBrokerStorageType: cfg.Options.DataBrokerStorageType, } + c.Register(c.localGRPCServer) err = c.update(cfg) if err != nil { diff --git a/cache/databroker.go b/cache/databroker.go index 0e5397e8d..6f4ae20bd 100644 --- a/cache/databroker.go +++ b/cache/databroker.go @@ -1,32 +1,39 @@ package cache import ( - "google.golang.org/grpc" + "context" + "encoding/base64" + "sync/atomic" + + "github.com/golang/protobuf/ptypes/empty" "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/databroker" databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/grpcutil" ) -// A DataBrokerServer implements the data broker service interface. -type DataBrokerServer struct { - *databroker.Server +// A dataBrokerServer implements the data broker service interface. +type dataBrokerServer struct { + server *databroker.Server + sharedKey atomic.Value } -// NewDataBrokerServer creates a new databroker service server. -func NewDataBrokerServer(grpcServer *grpc.Server, cfg *config.Config) *DataBrokerServer { - srv := &DataBrokerServer{} - srv.Server = databroker.New(srv.getOptions(cfg)...) - databrokerpb.RegisterDataBrokerServiceServer(grpcServer, srv) +// newDataBrokerServer creates a new databroker service server. +func newDataBrokerServer(cfg *config.Config) *dataBrokerServer { + srv := &dataBrokerServer{} + srv.server = databroker.New(srv.getOptions(cfg)...) + srv.setKey(cfg) return srv } // OnConfigChange updates the underlying databroker server whenever configuration is changed. -func (srv *DataBrokerServer) OnConfigChange(cfg *config.Config) { - srv.UpdateConfig(srv.getOptions(cfg)...) +func (srv *dataBrokerServer) OnConfigChange(cfg *config.Config) { + srv.server.UpdateConfig(srv.getOptions(cfg)...) + srv.setKey(cfg) } -func (srv *DataBrokerServer) getOptions(cfg *config.Config) []databroker.ServerOption { +func (srv *dataBrokerServer) getOptions(cfg *config.Config) []databroker.ServerOption { return []databroker.ServerOption{ databroker.WithSharedKey(cfg.Options.SharedKey), databroker.WithStorageType(cfg.Options.DataBrokerStorageType), @@ -36,3 +43,67 @@ func (srv *DataBrokerServer) getOptions(cfg *config.Config) []databroker.ServerO databroker.WithStorageCertSkipVerify(cfg.Options.DataBrokerStorageCertSkipVerify), } } + +func (srv *dataBrokerServer) setKey(cfg *config.Config) { + bs, _ := base64.StdEncoding.DecodeString(cfg.Options.SharedKey) + if bs == nil { + bs = make([]byte, 0) + } + srv.sharedKey.Store(bs) +} + +func (srv *dataBrokerServer) Delete(ctx context.Context, req *databrokerpb.DeleteRequest) (*empty.Empty, error) { + if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil { + return nil, err + } + return srv.server.Delete(ctx, req) +} + +func (srv *dataBrokerServer) Get(ctx context.Context, req *databrokerpb.GetRequest) (*databrokerpb.GetResponse, error) { + if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil { + return nil, err + } + return srv.server.Get(ctx, req) +} + +func (srv *dataBrokerServer) GetAll(ctx context.Context, req *databrokerpb.GetAllRequest) (*databrokerpb.GetAllResponse, error) { + if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil { + return nil, err + } + return srv.server.GetAll(ctx, req) +} + +func (srv *dataBrokerServer) Query(ctx context.Context, req *databrokerpb.QueryRequest) (*databrokerpb.QueryResponse, error) { + if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil { + return nil, err + } + return srv.server.Query(ctx, req) +} + +func (srv *dataBrokerServer) Set(ctx context.Context, req *databrokerpb.SetRequest) (*databrokerpb.SetResponse, error) { + if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil { + return nil, err + } + return srv.server.Set(ctx, req) +} + +func (srv *dataBrokerServer) Sync(req *databrokerpb.SyncRequest, stream databrokerpb.DataBrokerService_SyncServer) error { + if err := grpcutil.RequireSignedJWT(stream.Context(), srv.sharedKey.Load().([]byte)); err != nil { + return err + } + return srv.server.Sync(req, stream) +} + +func (srv *dataBrokerServer) GetTypes(ctx context.Context, req *empty.Empty) (*databrokerpb.GetTypesResponse, error) { + if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil { + return nil, err + } + return srv.server.GetTypes(ctx, req) +} + +func (srv *dataBrokerServer) SyncTypes(req *empty.Empty, stream databrokerpb.DataBrokerService_SyncTypesServer) error { + if err := grpcutil.RequireSignedJWT(stream.Context(), srv.sharedKey.Load().([]byte)); err != nil { + return err + } + return srv.server.SyncTypes(req, stream) +} diff --git a/cache/databroker_test.go b/cache/databroker_test.go index d7ad260d1..cfe22dbde 100644 --- a/cache/databroker_test.go +++ b/cache/databroker_test.go @@ -27,7 +27,8 @@ func init() { lis = bufconn.Listen(bufSize) s := grpc.NewServer() internalSrv := internal_databroker.New() - srv := &DataBrokerServer{Server: internalSrv} + srv := &dataBrokerServer{server: internalSrv} + srv.sharedKey.Store([]byte{}) databroker.RegisterDataBrokerServiceServer(s, srv) go func() { diff --git a/pkg/grpcutil/options.go b/pkg/grpcutil/options.go index 16f307159..eafdf5a4b 100644 --- a/pkg/grpcutil/options.go +++ b/pkg/grpcutil/options.go @@ -66,7 +66,7 @@ func withSignedJWT(ctx context.Context, key []byte) (context.Context, error) { func UnaryRequireSignedJWT(key string) grpc.UnaryServerInterceptor { keyBS, _ := base64.StdEncoding.DecodeString(key) return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { - if err := requireSignedJWT(ctx, keyBS); err != nil { + if err := RequireSignedJWT(ctx, keyBS); err != nil { return nil, err } return handler(ctx, req) @@ -77,14 +77,15 @@ func UnaryRequireSignedJWT(key string) grpc.UnaryServerInterceptor { func StreamRequireSignedJWT(key string) grpc.StreamServerInterceptor { keyBS, _ := base64.StdEncoding.DecodeString(key) return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { - if err := requireSignedJWT(ss.Context(), keyBS); err != nil { + if err := RequireSignedJWT(ss.Context(), keyBS); err != nil { return err } return handler(srv, ss) } } -func requireSignedJWT(ctx context.Context, key []byte) error { +// RequireSignedJWT requires a JWT in the gRPC metadata and that it be signed by the given key. +func RequireSignedJWT(ctx context.Context, key []byte) error { if len(key) > 0 { rawjwt, ok := JWTFromGRPCRequest(ctx) if !ok { diff --git a/pkg/grpcutil/options_test.go b/pkg/grpcutil/options_test.go new file mode 100644 index 000000000..ede615cf5 --- /dev/null +++ b/pkg/grpcutil/options_test.go @@ -0,0 +1,91 @@ +package grpcutil + +import ( + "context" + "encoding/base64" + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/reflection" + "google.golang.org/grpc/reflection/grpc_reflection_v1alpha" + "google.golang.org/grpc/status" + + "github.com/pomerium/pomerium/pkg/cryptutil" +) + +func TestSignedJWT(t *testing.T) { + ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) + defer clearTimeout() + + li, err := net.Listen("tcp4", "127.0.0.1:0") + if !assert.NoError(t, err) { + return + } + defer li.Close() + + key := cryptutil.NewKey() + srv := grpc.NewServer( + grpc.StreamInterceptor(StreamRequireSignedJWT(base64.StdEncoding.EncodeToString(key))), + grpc.UnaryInterceptor(UnaryRequireSignedJWT(base64.StdEncoding.EncodeToString(key))), + ) + go srv.Serve(li) + + reflection.Register(srv) + + t.Run("unauthenticated", func(t *testing.T) { + cc, err := grpc.Dial(li.Addr().String(), + grpc.WithInsecure()) + if !assert.NoError(t, err) { + return + } + defer cc.Close() + + client := grpc_reflection_v1alpha.NewServerReflectionClient(cc) + stream, err := client.ServerReflectionInfo(ctx, grpc.WaitForReady(true)) + if !assert.NoError(t, err) { + return + } + + err = stream.Send(&grpc_reflection_v1alpha.ServerReflectionRequest{ + Host: "", + MessageRequest: &grpc_reflection_v1alpha.ServerReflectionRequest_ListServices{}, + }) + if !assert.NoError(t, err) { + return + } + + _, err = stream.Recv() + assert.Equal(t, codes.Unauthenticated, status.Code(err)) + }) + t.Run("authenticated", func(t *testing.T) { + cc, err := grpc.Dial(li.Addr().String(), + grpc.WithUnaryInterceptor(WithUnarySignedJWT(key)), + grpc.WithStreamInterceptor(WithStreamSignedJWT(key)), + grpc.WithInsecure()) + if !assert.NoError(t, err) { + return + } + defer cc.Close() + + client := grpc_reflection_v1alpha.NewServerReflectionClient(cc) + stream, err := client.ServerReflectionInfo(ctx, grpc.WaitForReady(true)) + if !assert.NoError(t, err) { + return + } + + err = stream.Send(&grpc_reflection_v1alpha.ServerReflectionRequest{ + Host: "", + MessageRequest: &grpc_reflection_v1alpha.ServerReflectionRequest_ListServices{}, + }) + if !assert.NoError(t, err) { + return + } + + _, err = stream.Recv() + assert.Equal(t, codes.OK, status.Code(err)) + }) +}