mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-21 13:07:13 +02:00
fix databroker requiring signed jwt (#1538)
* add test, explicitly call RequireSignedJWT instead of using interceptor to handle combined gRPC server * register handler, handle config changes * fix nil error in tests * unexport constructor
This commit is contained in:
parent
a375f707f8
commit
1763f02620
5 changed files with 184 additions and 22 deletions
10
cache/cache.go
vendored
10
cache/cache.go
vendored
|
@ -28,7 +28,7 @@ import (
|
||||||
// Cache represents the cache service. The cache service is a simple interface
|
// Cache represents the cache service. The cache service is a simple interface
|
||||||
// for storing keyed blobs (bytes) of unstructured data.
|
// for storing keyed blobs (bytes) of unstructured data.
|
||||||
type Cache struct {
|
type Cache struct {
|
||||||
dataBrokerServer *DataBrokerServer
|
dataBrokerServer *dataBrokerServer
|
||||||
manager *manager.Manager
|
manager *manager.Manager
|
||||||
|
|
||||||
localListener net.Listener
|
localListener net.Listener
|
||||||
|
@ -52,10 +52,7 @@ func New(cfg *config.Config) (*Cache, error) {
|
||||||
|
|
||||||
// No metrics handler because we have one in the control plane. Add one
|
// No metrics handler because we have one in the control plane. Add one
|
||||||
// if we no longer register with that grpc Server
|
// if we no longer register with that grpc Server
|
||||||
localGRPCServer := grpc.NewServer(
|
localGRPCServer := grpc.NewServer()
|
||||||
grpc.StreamInterceptor(grpcutil.StreamRequireSignedJWT(cfg.Options.SharedKey)),
|
|
||||||
grpc.UnaryInterceptor(grpcutil.UnaryRequireSignedJWT(cfg.Options.SharedKey)),
|
|
||||||
)
|
|
||||||
|
|
||||||
clientStatsHandler := telemetry.NewGRPCClientStatsHandler(cfg.Options.Services)
|
clientStatsHandler := telemetry.NewGRPCClientStatsHandler(cfg.Options.Services)
|
||||||
clientDialOptions := []grpc.DialOption{
|
clientDialOptions := []grpc.DialOption{
|
||||||
|
@ -74,7 +71,7 @@ func New(cfg *config.Config) (*Cache, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
dataBrokerServer := NewDataBrokerServer(localGRPCServer, cfg)
|
dataBrokerServer := newDataBrokerServer(cfg)
|
||||||
|
|
||||||
c := &Cache{
|
c := &Cache{
|
||||||
dataBrokerServer: dataBrokerServer,
|
dataBrokerServer: dataBrokerServer,
|
||||||
|
@ -84,6 +81,7 @@ func New(cfg *config.Config) (*Cache, error) {
|
||||||
deprecatedCacheClusterDomain: cfg.Options.GetDataBrokerURL().Hostname(),
|
deprecatedCacheClusterDomain: cfg.Options.GetDataBrokerURL().Hostname(),
|
||||||
dataBrokerStorageType: cfg.Options.DataBrokerStorageType,
|
dataBrokerStorageType: cfg.Options.DataBrokerStorageType,
|
||||||
}
|
}
|
||||||
|
c.Register(c.localGRPCServer)
|
||||||
|
|
||||||
err = c.update(cfg)
|
err = c.update(cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
95
cache/databroker.go
vendored
95
cache/databroker.go
vendored
|
@ -1,32 +1,39 @@
|
||||||
package cache
|
package cache
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"google.golang.org/grpc"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"github.com/golang/protobuf/ptypes/empty"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/databroker"
|
"github.com/pomerium/pomerium/internal/databroker"
|
||||||
databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker"
|
databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// A DataBrokerServer implements the data broker service interface.
|
// A dataBrokerServer implements the data broker service interface.
|
||||||
type DataBrokerServer struct {
|
type dataBrokerServer struct {
|
||||||
*databroker.Server
|
server *databroker.Server
|
||||||
|
sharedKey atomic.Value
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDataBrokerServer creates a new databroker service server.
|
// newDataBrokerServer creates a new databroker service server.
|
||||||
func NewDataBrokerServer(grpcServer *grpc.Server, cfg *config.Config) *DataBrokerServer {
|
func newDataBrokerServer(cfg *config.Config) *dataBrokerServer {
|
||||||
srv := &DataBrokerServer{}
|
srv := &dataBrokerServer{}
|
||||||
srv.Server = databroker.New(srv.getOptions(cfg)...)
|
srv.server = databroker.New(srv.getOptions(cfg)...)
|
||||||
databrokerpb.RegisterDataBrokerServiceServer(grpcServer, srv)
|
srv.setKey(cfg)
|
||||||
return srv
|
return srv
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnConfigChange updates the underlying databroker server whenever configuration is changed.
|
// OnConfigChange updates the underlying databroker server whenever configuration is changed.
|
||||||
func (srv *DataBrokerServer) OnConfigChange(cfg *config.Config) {
|
func (srv *dataBrokerServer) OnConfigChange(cfg *config.Config) {
|
||||||
srv.UpdateConfig(srv.getOptions(cfg)...)
|
srv.server.UpdateConfig(srv.getOptions(cfg)...)
|
||||||
|
srv.setKey(cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *DataBrokerServer) getOptions(cfg *config.Config) []databroker.ServerOption {
|
func (srv *dataBrokerServer) getOptions(cfg *config.Config) []databroker.ServerOption {
|
||||||
return []databroker.ServerOption{
|
return []databroker.ServerOption{
|
||||||
databroker.WithSharedKey(cfg.Options.SharedKey),
|
databroker.WithSharedKey(cfg.Options.SharedKey),
|
||||||
databroker.WithStorageType(cfg.Options.DataBrokerStorageType),
|
databroker.WithStorageType(cfg.Options.DataBrokerStorageType),
|
||||||
|
@ -36,3 +43,67 @@ func (srv *DataBrokerServer) getOptions(cfg *config.Config) []databroker.ServerO
|
||||||
databroker.WithStorageCertSkipVerify(cfg.Options.DataBrokerStorageCertSkipVerify),
|
databroker.WithStorageCertSkipVerify(cfg.Options.DataBrokerStorageCertSkipVerify),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (srv *dataBrokerServer) setKey(cfg *config.Config) {
|
||||||
|
bs, _ := base64.StdEncoding.DecodeString(cfg.Options.SharedKey)
|
||||||
|
if bs == nil {
|
||||||
|
bs = make([]byte, 0)
|
||||||
|
}
|
||||||
|
srv.sharedKey.Store(bs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *dataBrokerServer) Delete(ctx context.Context, req *databrokerpb.DeleteRequest) (*empty.Empty, error) {
|
||||||
|
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return srv.server.Delete(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *dataBrokerServer) Get(ctx context.Context, req *databrokerpb.GetRequest) (*databrokerpb.GetResponse, error) {
|
||||||
|
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return srv.server.Get(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *dataBrokerServer) GetAll(ctx context.Context, req *databrokerpb.GetAllRequest) (*databrokerpb.GetAllResponse, error) {
|
||||||
|
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return srv.server.GetAll(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *dataBrokerServer) Query(ctx context.Context, req *databrokerpb.QueryRequest) (*databrokerpb.QueryResponse, error) {
|
||||||
|
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return srv.server.Query(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *dataBrokerServer) Set(ctx context.Context, req *databrokerpb.SetRequest) (*databrokerpb.SetResponse, error) {
|
||||||
|
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return srv.server.Set(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *dataBrokerServer) Sync(req *databrokerpb.SyncRequest, stream databrokerpb.DataBrokerService_SyncServer) error {
|
||||||
|
if err := grpcutil.RequireSignedJWT(stream.Context(), srv.sharedKey.Load().([]byte)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return srv.server.Sync(req, stream)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *dataBrokerServer) GetTypes(ctx context.Context, req *empty.Empty) (*databrokerpb.GetTypesResponse, error) {
|
||||||
|
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return srv.server.GetTypes(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *dataBrokerServer) SyncTypes(req *empty.Empty, stream databrokerpb.DataBrokerService_SyncTypesServer) error {
|
||||||
|
if err := grpcutil.RequireSignedJWT(stream.Context(), srv.sharedKey.Load().([]byte)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return srv.server.SyncTypes(req, stream)
|
||||||
|
}
|
||||||
|
|
3
cache/databroker_test.go
vendored
3
cache/databroker_test.go
vendored
|
@ -27,7 +27,8 @@ func init() {
|
||||||
lis = bufconn.Listen(bufSize)
|
lis = bufconn.Listen(bufSize)
|
||||||
s := grpc.NewServer()
|
s := grpc.NewServer()
|
||||||
internalSrv := internal_databroker.New()
|
internalSrv := internal_databroker.New()
|
||||||
srv := &DataBrokerServer{Server: internalSrv}
|
srv := &dataBrokerServer{server: internalSrv}
|
||||||
|
srv.sharedKey.Store([]byte{})
|
||||||
databroker.RegisterDataBrokerServiceServer(s, srv)
|
databroker.RegisterDataBrokerServiceServer(s, srv)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
|
|
@ -66,7 +66,7 @@ func withSignedJWT(ctx context.Context, key []byte) (context.Context, error) {
|
||||||
func UnaryRequireSignedJWT(key string) grpc.UnaryServerInterceptor {
|
func UnaryRequireSignedJWT(key string) grpc.UnaryServerInterceptor {
|
||||||
keyBS, _ := base64.StdEncoding.DecodeString(key)
|
keyBS, _ := base64.StdEncoding.DecodeString(key)
|
||||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
|
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
|
||||||
if err := requireSignedJWT(ctx, keyBS); err != nil {
|
if err := RequireSignedJWT(ctx, keyBS); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return handler(ctx, req)
|
return handler(ctx, req)
|
||||||
|
@ -77,14 +77,15 @@ func UnaryRequireSignedJWT(key string) grpc.UnaryServerInterceptor {
|
||||||
func StreamRequireSignedJWT(key string) grpc.StreamServerInterceptor {
|
func StreamRequireSignedJWT(key string) grpc.StreamServerInterceptor {
|
||||||
keyBS, _ := base64.StdEncoding.DecodeString(key)
|
keyBS, _ := base64.StdEncoding.DecodeString(key)
|
||||||
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||||
if err := requireSignedJWT(ss.Context(), keyBS); err != nil {
|
if err := RequireSignedJWT(ss.Context(), keyBS); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return handler(srv, ss)
|
return handler(srv, ss)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func requireSignedJWT(ctx context.Context, key []byte) error {
|
// RequireSignedJWT requires a JWT in the gRPC metadata and that it be signed by the given key.
|
||||||
|
func RequireSignedJWT(ctx context.Context, key []byte) error {
|
||||||
if len(key) > 0 {
|
if len(key) > 0 {
|
||||||
rawjwt, ok := JWTFromGRPCRequest(ctx)
|
rawjwt, ok := JWTFromGRPCRequest(ctx)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
|
91
pkg/grpcutil/options_test.go
Normal file
91
pkg/grpcutil/options_test.go
Normal file
|
@ -0,0 +1,91 @@
|
||||||
|
package grpcutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/reflection"
|
||||||
|
"google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSignedJWT(t *testing.T) {
|
||||||
|
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||||
|
defer clearTimeout()
|
||||||
|
|
||||||
|
li, err := net.Listen("tcp4", "127.0.0.1:0")
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer li.Close()
|
||||||
|
|
||||||
|
key := cryptutil.NewKey()
|
||||||
|
srv := grpc.NewServer(
|
||||||
|
grpc.StreamInterceptor(StreamRequireSignedJWT(base64.StdEncoding.EncodeToString(key))),
|
||||||
|
grpc.UnaryInterceptor(UnaryRequireSignedJWT(base64.StdEncoding.EncodeToString(key))),
|
||||||
|
)
|
||||||
|
go srv.Serve(li)
|
||||||
|
|
||||||
|
reflection.Register(srv)
|
||||||
|
|
||||||
|
t.Run("unauthenticated", func(t *testing.T) {
|
||||||
|
cc, err := grpc.Dial(li.Addr().String(),
|
||||||
|
grpc.WithInsecure())
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer cc.Close()
|
||||||
|
|
||||||
|
client := grpc_reflection_v1alpha.NewServerReflectionClient(cc)
|
||||||
|
stream, err := client.ServerReflectionInfo(ctx, grpc.WaitForReady(true))
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = stream.Send(&grpc_reflection_v1alpha.ServerReflectionRequest{
|
||||||
|
Host: "",
|
||||||
|
MessageRequest: &grpc_reflection_v1alpha.ServerReflectionRequest_ListServices{},
|
||||||
|
})
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = stream.Recv()
|
||||||
|
assert.Equal(t, codes.Unauthenticated, status.Code(err))
|
||||||
|
})
|
||||||
|
t.Run("authenticated", func(t *testing.T) {
|
||||||
|
cc, err := grpc.Dial(li.Addr().String(),
|
||||||
|
grpc.WithUnaryInterceptor(WithUnarySignedJWT(key)),
|
||||||
|
grpc.WithStreamInterceptor(WithStreamSignedJWT(key)),
|
||||||
|
grpc.WithInsecure())
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer cc.Close()
|
||||||
|
|
||||||
|
client := grpc_reflection_v1alpha.NewServerReflectionClient(cc)
|
||||||
|
stream, err := client.ServerReflectionInfo(ctx, grpc.WaitForReady(true))
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = stream.Send(&grpc_reflection_v1alpha.ServerReflectionRequest{
|
||||||
|
Host: "",
|
||||||
|
MessageRequest: &grpc_reflection_v1alpha.ServerReflectionRequest_ListServices{},
|
||||||
|
})
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = stream.Recv()
|
||||||
|
assert.Equal(t, codes.OK, status.Code(err))
|
||||||
|
})
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue