fix databroker requiring signed jwt (#1538)

* add test, explicitly call RequireSignedJWT instead of using interceptor to handle combined gRPC server

* register handler, handle config changes

* fix nil error in tests

* unexport constructor
This commit is contained in:
Caleb Doxsey 2020-10-20 10:29:22 -06:00 committed by GitHub
parent a375f707f8
commit 1763f02620
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 184 additions and 22 deletions

10
cache/cache.go vendored
View file

@ -28,7 +28,7 @@ import (
// Cache represents the cache service. The cache service is a simple interface // Cache represents the cache service. The cache service is a simple interface
// for storing keyed blobs (bytes) of unstructured data. // for storing keyed blobs (bytes) of unstructured data.
type Cache struct { type Cache struct {
dataBrokerServer *DataBrokerServer dataBrokerServer *dataBrokerServer
manager *manager.Manager manager *manager.Manager
localListener net.Listener localListener net.Listener
@ -52,10 +52,7 @@ func New(cfg *config.Config) (*Cache, error) {
// No metrics handler because we have one in the control plane. Add one // No metrics handler because we have one in the control plane. Add one
// if we no longer register with that grpc Server // if we no longer register with that grpc Server
localGRPCServer := grpc.NewServer( localGRPCServer := grpc.NewServer()
grpc.StreamInterceptor(grpcutil.StreamRequireSignedJWT(cfg.Options.SharedKey)),
grpc.UnaryInterceptor(grpcutil.UnaryRequireSignedJWT(cfg.Options.SharedKey)),
)
clientStatsHandler := telemetry.NewGRPCClientStatsHandler(cfg.Options.Services) clientStatsHandler := telemetry.NewGRPCClientStatsHandler(cfg.Options.Services)
clientDialOptions := []grpc.DialOption{ clientDialOptions := []grpc.DialOption{
@ -74,7 +71,7 @@ func New(cfg *config.Config) (*Cache, error) {
return nil, err return nil, err
} }
dataBrokerServer := NewDataBrokerServer(localGRPCServer, cfg) dataBrokerServer := newDataBrokerServer(cfg)
c := &Cache{ c := &Cache{
dataBrokerServer: dataBrokerServer, dataBrokerServer: dataBrokerServer,
@ -84,6 +81,7 @@ func New(cfg *config.Config) (*Cache, error) {
deprecatedCacheClusterDomain: cfg.Options.GetDataBrokerURL().Hostname(), deprecatedCacheClusterDomain: cfg.Options.GetDataBrokerURL().Hostname(),
dataBrokerStorageType: cfg.Options.DataBrokerStorageType, dataBrokerStorageType: cfg.Options.DataBrokerStorageType,
} }
c.Register(c.localGRPCServer)
err = c.update(cfg) err = c.update(cfg)
if err != nil { if err != nil {

95
cache/databroker.go vendored
View file

@ -1,32 +1,39 @@
package cache package cache
import ( import (
"google.golang.org/grpc" "context"
"encoding/base64"
"sync/atomic"
"github.com/golang/protobuf/ptypes/empty"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/databroker" "github.com/pomerium/pomerium/internal/databroker"
databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker" databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpcutil"
) )
// A DataBrokerServer implements the data broker service interface. // A dataBrokerServer implements the data broker service interface.
type DataBrokerServer struct { type dataBrokerServer struct {
*databroker.Server server *databroker.Server
sharedKey atomic.Value
} }
// NewDataBrokerServer creates a new databroker service server. // newDataBrokerServer creates a new databroker service server.
func NewDataBrokerServer(grpcServer *grpc.Server, cfg *config.Config) *DataBrokerServer { func newDataBrokerServer(cfg *config.Config) *dataBrokerServer {
srv := &DataBrokerServer{} srv := &dataBrokerServer{}
srv.Server = databroker.New(srv.getOptions(cfg)...) srv.server = databroker.New(srv.getOptions(cfg)...)
databrokerpb.RegisterDataBrokerServiceServer(grpcServer, srv) srv.setKey(cfg)
return srv return srv
} }
// OnConfigChange updates the underlying databroker server whenever configuration is changed. // OnConfigChange updates the underlying databroker server whenever configuration is changed.
func (srv *DataBrokerServer) OnConfigChange(cfg *config.Config) { func (srv *dataBrokerServer) OnConfigChange(cfg *config.Config) {
srv.UpdateConfig(srv.getOptions(cfg)...) srv.server.UpdateConfig(srv.getOptions(cfg)...)
srv.setKey(cfg)
} }
func (srv *DataBrokerServer) getOptions(cfg *config.Config) []databroker.ServerOption { func (srv *dataBrokerServer) getOptions(cfg *config.Config) []databroker.ServerOption {
return []databroker.ServerOption{ return []databroker.ServerOption{
databroker.WithSharedKey(cfg.Options.SharedKey), databroker.WithSharedKey(cfg.Options.SharedKey),
databroker.WithStorageType(cfg.Options.DataBrokerStorageType), databroker.WithStorageType(cfg.Options.DataBrokerStorageType),
@ -36,3 +43,67 @@ func (srv *DataBrokerServer) getOptions(cfg *config.Config) []databroker.ServerO
databroker.WithStorageCertSkipVerify(cfg.Options.DataBrokerStorageCertSkipVerify), databroker.WithStorageCertSkipVerify(cfg.Options.DataBrokerStorageCertSkipVerify),
} }
} }
func (srv *dataBrokerServer) setKey(cfg *config.Config) {
bs, _ := base64.StdEncoding.DecodeString(cfg.Options.SharedKey)
if bs == nil {
bs = make([]byte, 0)
}
srv.sharedKey.Store(bs)
}
func (srv *dataBrokerServer) Delete(ctx context.Context, req *databrokerpb.DeleteRequest) (*empty.Empty, error) {
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
return nil, err
}
return srv.server.Delete(ctx, req)
}
func (srv *dataBrokerServer) Get(ctx context.Context, req *databrokerpb.GetRequest) (*databrokerpb.GetResponse, error) {
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
return nil, err
}
return srv.server.Get(ctx, req)
}
func (srv *dataBrokerServer) GetAll(ctx context.Context, req *databrokerpb.GetAllRequest) (*databrokerpb.GetAllResponse, error) {
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
return nil, err
}
return srv.server.GetAll(ctx, req)
}
func (srv *dataBrokerServer) Query(ctx context.Context, req *databrokerpb.QueryRequest) (*databrokerpb.QueryResponse, error) {
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
return nil, err
}
return srv.server.Query(ctx, req)
}
func (srv *dataBrokerServer) Set(ctx context.Context, req *databrokerpb.SetRequest) (*databrokerpb.SetResponse, error) {
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
return nil, err
}
return srv.server.Set(ctx, req)
}
func (srv *dataBrokerServer) Sync(req *databrokerpb.SyncRequest, stream databrokerpb.DataBrokerService_SyncServer) error {
if err := grpcutil.RequireSignedJWT(stream.Context(), srv.sharedKey.Load().([]byte)); err != nil {
return err
}
return srv.server.Sync(req, stream)
}
func (srv *dataBrokerServer) GetTypes(ctx context.Context, req *empty.Empty) (*databrokerpb.GetTypesResponse, error) {
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
return nil, err
}
return srv.server.GetTypes(ctx, req)
}
func (srv *dataBrokerServer) SyncTypes(req *empty.Empty, stream databrokerpb.DataBrokerService_SyncTypesServer) error {
if err := grpcutil.RequireSignedJWT(stream.Context(), srv.sharedKey.Load().([]byte)); err != nil {
return err
}
return srv.server.SyncTypes(req, stream)
}

View file

@ -27,7 +27,8 @@ func init() {
lis = bufconn.Listen(bufSize) lis = bufconn.Listen(bufSize)
s := grpc.NewServer() s := grpc.NewServer()
internalSrv := internal_databroker.New() internalSrv := internal_databroker.New()
srv := &DataBrokerServer{Server: internalSrv} srv := &dataBrokerServer{server: internalSrv}
srv.sharedKey.Store([]byte{})
databroker.RegisterDataBrokerServiceServer(s, srv) databroker.RegisterDataBrokerServiceServer(s, srv)
go func() { go func() {

View file

@ -66,7 +66,7 @@ func withSignedJWT(ctx context.Context, key []byte) (context.Context, error) {
func UnaryRequireSignedJWT(key string) grpc.UnaryServerInterceptor { func UnaryRequireSignedJWT(key string) grpc.UnaryServerInterceptor {
keyBS, _ := base64.StdEncoding.DecodeString(key) keyBS, _ := base64.StdEncoding.DecodeString(key)
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
if err := requireSignedJWT(ctx, keyBS); err != nil { if err := RequireSignedJWT(ctx, keyBS); err != nil {
return nil, err return nil, err
} }
return handler(ctx, req) return handler(ctx, req)
@ -77,14 +77,15 @@ func UnaryRequireSignedJWT(key string) grpc.UnaryServerInterceptor {
func StreamRequireSignedJWT(key string) grpc.StreamServerInterceptor { func StreamRequireSignedJWT(key string) grpc.StreamServerInterceptor {
keyBS, _ := base64.StdEncoding.DecodeString(key) keyBS, _ := base64.StdEncoding.DecodeString(key)
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
if err := requireSignedJWT(ss.Context(), keyBS); err != nil { if err := RequireSignedJWT(ss.Context(), keyBS); err != nil {
return err return err
} }
return handler(srv, ss) return handler(srv, ss)
} }
} }
func requireSignedJWT(ctx context.Context, key []byte) error { // RequireSignedJWT requires a JWT in the gRPC metadata and that it be signed by the given key.
func RequireSignedJWT(ctx context.Context, key []byte) error {
if len(key) > 0 { if len(key) > 0 {
rawjwt, ok := JWTFromGRPCRequest(ctx) rawjwt, ok := JWTFromGRPCRequest(ctx)
if !ok { if !ok {

View file

@ -0,0 +1,91 @@
package grpcutil
import (
"context"
"encoding/base64"
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/reflection"
"google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
"google.golang.org/grpc/status"
"github.com/pomerium/pomerium/pkg/cryptutil"
)
func TestSignedJWT(t *testing.T) {
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
defer clearTimeout()
li, err := net.Listen("tcp4", "127.0.0.1:0")
if !assert.NoError(t, err) {
return
}
defer li.Close()
key := cryptutil.NewKey()
srv := grpc.NewServer(
grpc.StreamInterceptor(StreamRequireSignedJWT(base64.StdEncoding.EncodeToString(key))),
grpc.UnaryInterceptor(UnaryRequireSignedJWT(base64.StdEncoding.EncodeToString(key))),
)
go srv.Serve(li)
reflection.Register(srv)
t.Run("unauthenticated", func(t *testing.T) {
cc, err := grpc.Dial(li.Addr().String(),
grpc.WithInsecure())
if !assert.NoError(t, err) {
return
}
defer cc.Close()
client := grpc_reflection_v1alpha.NewServerReflectionClient(cc)
stream, err := client.ServerReflectionInfo(ctx, grpc.WaitForReady(true))
if !assert.NoError(t, err) {
return
}
err = stream.Send(&grpc_reflection_v1alpha.ServerReflectionRequest{
Host: "",
MessageRequest: &grpc_reflection_v1alpha.ServerReflectionRequest_ListServices{},
})
if !assert.NoError(t, err) {
return
}
_, err = stream.Recv()
assert.Equal(t, codes.Unauthenticated, status.Code(err))
})
t.Run("authenticated", func(t *testing.T) {
cc, err := grpc.Dial(li.Addr().String(),
grpc.WithUnaryInterceptor(WithUnarySignedJWT(key)),
grpc.WithStreamInterceptor(WithStreamSignedJWT(key)),
grpc.WithInsecure())
if !assert.NoError(t, err) {
return
}
defer cc.Close()
client := grpc_reflection_v1alpha.NewServerReflectionClient(cc)
stream, err := client.ServerReflectionInfo(ctx, grpc.WaitForReady(true))
if !assert.NoError(t, err) {
return
}
err = stream.Send(&grpc_reflection_v1alpha.ServerReflectionRequest{
Host: "",
MessageRequest: &grpc_reflection_v1alpha.ServerReflectionRequest_ListServices{},
})
if !assert.NoError(t, err) {
return
}
_, err = stream.Recv()
assert.Equal(t, codes.OK, status.Code(err))
})
}