diff --git a/authenticate/state.go b/authenticate/state.go index 990eee294..74d3ec199 100644 --- a/authenticate/state.go +++ b/authenticate/state.go @@ -116,6 +116,8 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err state.jwk.Keys = append(state.jwk.Keys, *jwk) } + sharedKey, _ := base64.StdEncoding.DecodeString(cfg.Options.SharedKey) + dataBrokerConn, err := grpc.GetGRPCClientConn("databroker", &grpc.Options{ Addr: cfg.Options.DataBrokerURL, OverrideCertificateName: cfg.Options.OverrideCertificateName, @@ -125,6 +127,7 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin, WithInsecure: cfg.Options.GRPCInsecure, ServiceName: cfg.Options.Services, + SignedJWTKey: sharedKey, }) if err != nil { return nil, err diff --git a/authorize/state.go b/authorize/state.go index 9e3e38590..a765dffee 100644 --- a/authorize/state.go +++ b/authorize/state.go @@ -1,6 +1,7 @@ package authorize import ( + "encoding/base64" "fmt" "sync/atomic" @@ -41,6 +42,8 @@ func newAuthorizeStateFromConfig(cfg *config.Config, store *evaluator.Store) (*a return nil, err } + sharedKey, _ := base64.StdEncoding.DecodeString(cfg.Options.SharedKey) + cc, err := grpc.GetGRPCClientConn("databroker", &grpc.Options{ Addr: cfg.Options.DataBrokerURL, OverrideCertificateName: cfg.Options.OverrideCertificateName, @@ -50,6 +53,7 @@ func newAuthorizeStateFromConfig(cfg *config.Config, store *evaluator.Store) (*a ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin, WithInsecure: cfg.Options.GRPCInsecure, ServiceName: cfg.Options.Services, + SignedJWTKey: sharedKey, }) if err != nil { return nil, fmt.Errorf("authorize: error creating databroker connection: %w", err) diff --git a/cache/cache.go b/cache/cache.go index aa23f31b4..c914ba6bd 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -5,6 +5,7 @@ package cache import ( "context" + "encoding/base64" "fmt" "net" "sync" @@ -21,6 +22,7 @@ import ( "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/grpcutil" ) // Cache represents the cache service. The cache service is a simple interface @@ -46,12 +48,22 @@ func New(cfg *config.Config) (*Cache, error) { return nil, err } + sharedKey, _ := base64.StdEncoding.DecodeString(cfg.Options.SharedKey) + // No metrics handler because we have one in the control plane. Add one // if we no longer register with that grpc Server - localGRPCServer := grpc.NewServer() + localGRPCServer := grpc.NewServer( + grpc.StreamInterceptor(grpcutil.StreamRequireSignedJWT(cfg.Options.SharedKey)), + grpc.UnaryInterceptor(grpcutil.UnaryRequireSignedJWT(cfg.Options.SharedKey)), + ) clientStatsHandler := telemetry.NewGRPCClientStatsHandler(cfg.Options.Services) - clientDialOptions := clientStatsHandler.DialOptions(grpc.WithInsecure()) + clientDialOptions := []grpc.DialOption{ + grpc.WithInsecure(), + grpc.WithChainUnaryInterceptor(clientStatsHandler.UnaryInterceptor, grpcutil.WithUnarySignedJWT(sharedKey)), + grpc.WithChainStreamInterceptor(grpcutil.WithStreamSignedJWT(sharedKey)), + grpc.WithStatsHandler(clientStatsHandler.Handler), + } localGRPCConnection, err := grpc.DialContext( context.Background(), diff --git a/cmd/pomerium-cli/cli.go b/cmd/pomerium-cli/cli.go index 238509c39..21e2d6a6b 100644 --- a/cmd/pomerium-cli/cli.go +++ b/cmd/pomerium-cli/cli.go @@ -3,6 +3,7 @@ package main import ( "bufio" "context" + "encoding/base64" "errors" "fmt" "net/url" @@ -91,23 +92,6 @@ var serviceAccountCmd = &cobra.Command{ l := zerolog.Nop() log.SetLogger(&l) - dataBrokerURL, err := url.Parse(serviceAccountOptions.dataBrokerURL) - if err != nil { - return fmt.Errorf("invalid databroker url: %w", err) - } - - cc, err := grpc.GetGRPCClientConn("databroker", &grpc.Options{ - Addr: dataBrokerURL, - OverrideCertificateName: serviceAccountOptions.overrideCertificateName, - CA: serviceAccountOptions.ca, - CAFile: serviceAccountOptions.caFile, - WithInsecure: !strings.HasSuffix(dataBrokerURL.Scheme, "s"), - }) - if err != nil { - return fmt.Errorf("error creating databroker connection: %w", err) - } - defer cc.Close() - // hydrate our session serviceAccountOptions.serviceAccount.Audience = jwt.Audience(serviceAccountOptions.aud) serviceAccountOptions.serviceAccount.Groups = []string(serviceAccountOptions.groups) @@ -144,6 +128,26 @@ var serviceAccountCmd = &cobra.Command{ return errors.New("iss is required") } + dataBrokerURL, err := url.Parse(serviceAccountOptions.dataBrokerURL) + if err != nil { + return fmt.Errorf("invalid databroker url: %w", err) + } + + rawSharedKey, _ := base64.StdEncoding.DecodeString(sharedKey) + + cc, err := grpc.GetGRPCClientConn("databroker", &grpc.Options{ + Addr: dataBrokerURL, + OverrideCertificateName: serviceAccountOptions.overrideCertificateName, + CA: serviceAccountOptions.ca, + CAFile: serviceAccountOptions.caFile, + WithInsecure: !strings.HasSuffix(dataBrokerURL.Scheme, "s"), + SignedJWTKey: rawSharedKey, + }) + if err != nil { + return fmt.Errorf("error creating databroker connection: %w", err) + } + defer cc.Close() + sa := &user.ServiceAccount{ Id: uuid.New().String(), UserId: serviceAccountOptions.serviceAccount.User, diff --git a/integration/control_plane_test.go b/integration/control_plane_test.go index 3b2b4e927..c482b09f1 100644 --- a/integration/control_plane_test.go +++ b/integration/control_plane_test.go @@ -20,7 +20,7 @@ func TestDashboard(t *testing.T) { t.Run("admin impersonate", func(t *testing.T) { client := testcluster.NewHTTPClient() - res, err := flows.Authenticate(ctx, client, mustParseURL("https://httpdetails.localhost.pomerium.io/by-user"), + _, err := flows.Authenticate(ctx, client, mustParseURL("https://httpdetails.localhost.pomerium.io/by-user"), flows.WithEmail("bob@dogs.test"), flows.WithGroups("user")) if !assert.NoError(t, err) { return @@ -31,7 +31,7 @@ func TestDashboard(t *testing.T) { t.Fatal(err) } - res, err = client.Do(req) + res, err := client.Do(req) if !assert.NoError(t, err, "unexpected http error") { return } diff --git a/internal/databroker/config_source.go b/internal/databroker/config_source.go index dacc92563..0a232fcb0 100644 --- a/internal/databroker/config_source.go +++ b/internal/databroker/config_source.go @@ -2,6 +2,7 @@ package databroker import ( "context" + "encoding/base64" "errors" "sync" "time" @@ -138,6 +139,7 @@ func (src *ConfigSource) rebuild(firstTime bool) { } func (src *ConfigSource) runUpdater(cfg *config.Config) { + sharedKey, _ := base64.StdEncoding.DecodeString(cfg.Options.SharedKey) connectionOptions := &grpc.Options{ Addr: cfg.Options.DataBrokerURL, OverrideCertificateName: cfg.Options.OverrideCertificateName, @@ -147,6 +149,7 @@ func (src *ConfigSource) runUpdater(cfg *config.Config) { ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin, WithInsecure: cfg.Options.GRPCInsecure, ServiceName: cfg.Options.Services, + SignedJWTKey: sharedKey, } h, err := hashstructure.Hash(connectionOptions, nil) if err != nil { diff --git a/internal/telemetry/grpc.go b/internal/telemetry/grpc.go index b4b4258f2..339a68d90 100644 --- a/internal/telemetry/grpc.go +++ b/internal/telemetry/grpc.go @@ -55,13 +55,3 @@ func NewGRPCClientStatsHandler(service string) *GRPCClientStatsHandler { UnaryInterceptor: metrics.GRPCClientInterceptor(ServiceName(service)), } } - -// DialOptions returns telemetry related DialOptions appended to an optional existing list -// of DialOptions -func (h *GRPCClientStatsHandler) DialOptions(o ...grpc.DialOption) []grpc.DialOption { - o = append(o, - grpc.WithUnaryInterceptor(h.UnaryInterceptor), - grpc.WithStatsHandler(h.Handler), - ) - return o -} diff --git a/internal/telemetry/grpc_test.go b/internal/telemetry/grpc_test.go index 54bdc8021..d2e208f2b 100644 --- a/internal/telemetry/grpc_test.go +++ b/internal/telemetry/grpc_test.go @@ -6,7 +6,6 @@ import ( "github.com/stretchr/testify/assert" "go.opencensus.io/plugin/ocgrpc" - "google.golang.org/grpc" grpcstats "google.golang.org/grpc/stats" ) @@ -36,27 +35,3 @@ func Test_GRPCServerStatsHandler(t *testing.T) { assert.Equal(t, ctx.Value(mockCtxTag("added")), "true") assert.Equal(t, ctx.Value(mockCtxTag("original")), "true") } - -type mockDialOption struct { - name string - grpc.EmptyDialOption -} - -func Test_NewGRPCClientStatsHandler(t *testing.T) { - t.Parallel() - - h := NewGRPCClientStatsHandler("test") - - origOpts := []grpc.DialOption{ - mockDialOption{name: "one"}, - mockDialOption{name: "two"}, - } - - newOpts := h.DialOptions(origOpts...) - - for i := range origOpts { - assert.Contains(t, newOpts, origOpts[i]) - } - - assert.Greater(t, len(newOpts), len(origOpts)) -} diff --git a/internal/telemetry/requestid/grpc_client.go b/internal/telemetry/requestid/grpc_client.go index 9d506dd40..b773f1a06 100644 --- a/internal/telemetry/requestid/grpc_client.go +++ b/internal/telemetry/requestid/grpc_client.go @@ -14,7 +14,7 @@ func StreamClientInterceptor() grpc.StreamClientInterceptor { desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption, ) (grpc.ClientStream, error) { - toMetadata(ctx) + ctx = toMetadata(ctx) return streamer(ctx, desc, cc, method, opts...) } } @@ -26,21 +26,15 @@ func UnaryClientInterceptor() grpc.UnaryClientInterceptor { method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption, ) error { - toMetadata(ctx) + ctx = toMetadata(ctx) return invoker(ctx, method, req, reply, cc, opts...) } } -func toMetadata(ctx context.Context) { +func toMetadata(ctx context.Context) context.Context { requestID := FromContext(ctx) if requestID == "" { requestID = New() } - - md, ok := metadata.FromOutgoingContext(ctx) - if !ok { - return - } - - md.Set(headerName, requestID) + return metadata.AppendToOutgoingContext(ctx, headerName, requestID) } diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go index 37074cc9e..63d678666 100644 --- a/pkg/grpc/client.go +++ b/pkg/grpc/client.go @@ -22,6 +22,7 @@ import ( "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/telemetry" "github.com/pomerium/pomerium/internal/telemetry/requestid" + "github.com/pomerium/pomerium/pkg/grpcutil" ) const ( @@ -52,6 +53,9 @@ type Options struct { // ServiceName specifies the service name for telemetry exposition ServiceName string + + // SignedJWTKey is the JWT key to use for signing a JWT attached to metadata. + SignedJWTKey []byte } // NewGRPCClientConn returns a new gRPC pomerium service client connection. @@ -70,17 +74,27 @@ func NewGRPCClientConn(opts *Options) (*grpc.ClientConn, error) { } } - dialOptions := []grpc.DialOption{ - grpc.WithChainUnaryInterceptor( - requestid.UnaryClientInterceptor(), - grpcTimeoutInterceptor(opts.RequestTimeout), - ), - grpc.WithStreamInterceptor(requestid.StreamClientInterceptor()), - grpc.WithDefaultCallOptions([]grpc.CallOption{grpc.WaitForReady(true)}...), + clientStatsHandler := telemetry.NewGRPCClientStatsHandler(opts.ServiceName) + + unaryClientInterceptors := []grpc.UnaryClientInterceptor{ + requestid.UnaryClientInterceptor(), + grpcTimeoutInterceptor(opts.RequestTimeout), + clientStatsHandler.UnaryInterceptor, + } + streamClientInterceptors := []grpc.StreamClientInterceptor{ + requestid.StreamClientInterceptor(), + } + if opts.SignedJWTKey != nil { + unaryClientInterceptors = append(unaryClientInterceptors, grpcutil.WithUnarySignedJWT(opts.SignedJWTKey)) + streamClientInterceptors = append(streamClientInterceptors, grpcutil.WithStreamSignedJWT(opts.SignedJWTKey)) } - clientStatsHandler := telemetry.NewGRPCClientStatsHandler(opts.ServiceName) - dialOptions = clientStatsHandler.DialOptions(dialOptions...) + dialOptions := []grpc.DialOption{ + grpc.WithChainUnaryInterceptor(unaryClientInterceptors...), + grpc.WithChainStreamInterceptor(streamClientInterceptors...), + grpc.WithDefaultCallOptions([]grpc.CallOption{grpc.WaitForReady(true)}...), + grpc.WithStatsHandler(clientStatsHandler.Handler), + } if opts.WithInsecure { log.Info().Str("addr", connAddr).Msg("internal/grpc: grpc with insecure") @@ -129,10 +143,8 @@ func NewGRPCClientConn(opts *Options) (*grpc.ClientConn, error) { dialOptions = append(dialOptions, grpc.WithBalancerName(roundrobin.Name), grpc.WithDisableServiceConfig()) connAddr = fmt.Sprintf("dns:///%s", connAddr) } - return grpc.Dial( - connAddr, - dialOptions..., - ) + + return grpc.Dial(connAddr, dialOptions...) } // grpcTimeoutInterceptor enforces per-RPC request timeouts diff --git a/pkg/grpcutil/options.go b/pkg/grpcutil/options.go new file mode 100644 index 000000000..16f307159 --- /dev/null +++ b/pkg/grpcutil/options.go @@ -0,0 +1,112 @@ +package grpcutil + +import ( + "context" + "encoding/base64" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "gopkg.in/square/go-jose.v2" + "gopkg.in/square/go-jose.v2/jwt" +) + +// WithStreamSignedJWT returns a StreamClientInterceptor that adds a JWT to requests. +func WithStreamSignedJWT(key []byte) grpc.StreamClientInterceptor { + return func( + ctx context.Context, + desc *grpc.StreamDesc, + cc *grpc.ClientConn, + method string, streamer grpc.Streamer, + opts ...grpc.CallOption, + ) (grpc.ClientStream, error) { + ctx, err := withSignedJWT(ctx, key) + if err != nil { + return nil, err + } + + return streamer(ctx, desc, cc, method, opts...) + } +} + +// WithUnarySignedJWT returns a UnaryClientInterceptor that adds a JWT to requests. +func WithUnarySignedJWT(key []byte) grpc.UnaryClientInterceptor { + return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + ctx, err := withSignedJWT(ctx, key) + if err != nil { + return err + } + + return invoker(ctx, method, req, reply, cc, opts...) + } +} + +func withSignedJWT(ctx context.Context, key []byte) (context.Context, error) { + if len(key) > 0 { + sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: key}, + (&jose.SignerOptions{}).WithType("JWT")) + if err != nil { + return ctx, err + } + + rawjwt, err := jwt.Signed(sig).Claims(jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + }).CompactSerialize() + if err != nil { + return ctx, err + } + + ctx = WithOutgoingJWT(ctx, rawjwt) + } + return ctx, nil +} + +// UnaryRequireSignedJWT requires a JWT in the gRPC metadata and that it be signed by the base64-encoded key. +func UnaryRequireSignedJWT(key string) grpc.UnaryServerInterceptor { + keyBS, _ := base64.StdEncoding.DecodeString(key) + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { + if err := requireSignedJWT(ctx, keyBS); err != nil { + return nil, err + } + return handler(ctx, req) + } +} + +// StreamRequireSignedJWT requires a JWT in the gRPC metadata and that it be signed by the base64-encoded key. +func StreamRequireSignedJWT(key string) grpc.StreamServerInterceptor { + keyBS, _ := base64.StdEncoding.DecodeString(key) + return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + if err := requireSignedJWT(ss.Context(), keyBS); err != nil { + return err + } + return handler(srv, ss) + } +} + +func requireSignedJWT(ctx context.Context, key []byte) error { + if len(key) > 0 { + rawjwt, ok := JWTFromGRPCRequest(ctx) + if !ok { + return status.Error(codes.Unauthenticated, "unauthenticated") + } + + tok, err := jwt.ParseSigned(rawjwt) + if err != nil { + return status.Errorf(codes.Unauthenticated, "invalid JWT: %v", err) + } + + var claims struct { + Expiry *jwt.NumericDate `json:"exp,omitempty"` + } + err = tok.Claims(key, &claims) + if err != nil { + return status.Errorf(codes.Unauthenticated, "invalid JWT: %v", err) + } + + if claims.Expiry == nil || time.Now().After(claims.Expiry.Time()) { + return status.Errorf(codes.Unauthenticated, "expired JWT: %v", err) + } + } + return nil +}