databroker: require JWT for access (#1503)

This commit is contained in:
Caleb Doxsey 2020-10-09 11:08:40 -06:00 committed by GitHub
parent 27d0cf180a
commit eb79cc0957
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 188 additions and 79 deletions

View file

@ -116,6 +116,8 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err
state.jwk.Keys = append(state.jwk.Keys, *jwk) state.jwk.Keys = append(state.jwk.Keys, *jwk)
} }
sharedKey, _ := base64.StdEncoding.DecodeString(cfg.Options.SharedKey)
dataBrokerConn, err := grpc.GetGRPCClientConn("databroker", &grpc.Options{ dataBrokerConn, err := grpc.GetGRPCClientConn("databroker", &grpc.Options{
Addr: cfg.Options.DataBrokerURL, Addr: cfg.Options.DataBrokerURL,
OverrideCertificateName: cfg.Options.OverrideCertificateName, OverrideCertificateName: cfg.Options.OverrideCertificateName,
@ -125,6 +127,7 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err
ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin, ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin,
WithInsecure: cfg.Options.GRPCInsecure, WithInsecure: cfg.Options.GRPCInsecure,
ServiceName: cfg.Options.Services, ServiceName: cfg.Options.Services,
SignedJWTKey: sharedKey,
}) })
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -1,6 +1,7 @@
package authorize package authorize
import ( import (
"encoding/base64"
"fmt" "fmt"
"sync/atomic" "sync/atomic"
@ -41,6 +42,8 @@ func newAuthorizeStateFromConfig(cfg *config.Config, store *evaluator.Store) (*a
return nil, err return nil, err
} }
sharedKey, _ := base64.StdEncoding.DecodeString(cfg.Options.SharedKey)
cc, err := grpc.GetGRPCClientConn("databroker", &grpc.Options{ cc, err := grpc.GetGRPCClientConn("databroker", &grpc.Options{
Addr: cfg.Options.DataBrokerURL, Addr: cfg.Options.DataBrokerURL,
OverrideCertificateName: cfg.Options.OverrideCertificateName, OverrideCertificateName: cfg.Options.OverrideCertificateName,
@ -50,6 +53,7 @@ func newAuthorizeStateFromConfig(cfg *config.Config, store *evaluator.Store) (*a
ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin, ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin,
WithInsecure: cfg.Options.GRPCInsecure, WithInsecure: cfg.Options.GRPCInsecure,
ServiceName: cfg.Options.Services, ServiceName: cfg.Options.Services,
SignedJWTKey: sharedKey,
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("authorize: error creating databroker connection: %w", err) return nil, fmt.Errorf("authorize: error creating databroker connection: %w", err)

16
cache/cache.go vendored
View file

@ -5,6 +5,7 @@ package cache
import ( import (
"context" "context"
"encoding/base64"
"fmt" "fmt"
"net" "net"
"sync" "sync"
@ -21,6 +22,7 @@ import (
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpcutil"
) )
// Cache represents the cache service. The cache service is a simple interface // Cache represents the cache service. The cache service is a simple interface
@ -46,12 +48,22 @@ func New(cfg *config.Config) (*Cache, error) {
return nil, err return nil, err
} }
sharedKey, _ := base64.StdEncoding.DecodeString(cfg.Options.SharedKey)
// No metrics handler because we have one in the control plane. Add one // No metrics handler because we have one in the control plane. Add one
// if we no longer register with that grpc Server // if we no longer register with that grpc Server
localGRPCServer := grpc.NewServer() localGRPCServer := grpc.NewServer(
grpc.StreamInterceptor(grpcutil.StreamRequireSignedJWT(cfg.Options.SharedKey)),
grpc.UnaryInterceptor(grpcutil.UnaryRequireSignedJWT(cfg.Options.SharedKey)),
)
clientStatsHandler := telemetry.NewGRPCClientStatsHandler(cfg.Options.Services) clientStatsHandler := telemetry.NewGRPCClientStatsHandler(cfg.Options.Services)
clientDialOptions := clientStatsHandler.DialOptions(grpc.WithInsecure()) clientDialOptions := []grpc.DialOption{
grpc.WithInsecure(),
grpc.WithChainUnaryInterceptor(clientStatsHandler.UnaryInterceptor, grpcutil.WithUnarySignedJWT(sharedKey)),
grpc.WithChainStreamInterceptor(grpcutil.WithStreamSignedJWT(sharedKey)),
grpc.WithStatsHandler(clientStatsHandler.Handler),
}
localGRPCConnection, err := grpc.DialContext( localGRPCConnection, err := grpc.DialContext(
context.Background(), context.Background(),

View file

@ -3,6 +3,7 @@ package main
import ( import (
"bufio" "bufio"
"context" "context"
"encoding/base64"
"errors" "errors"
"fmt" "fmt"
"net/url" "net/url"
@ -91,23 +92,6 @@ var serviceAccountCmd = &cobra.Command{
l := zerolog.Nop() l := zerolog.Nop()
log.SetLogger(&l) log.SetLogger(&l)
dataBrokerURL, err := url.Parse(serviceAccountOptions.dataBrokerURL)
if err != nil {
return fmt.Errorf("invalid databroker url: %w", err)
}
cc, err := grpc.GetGRPCClientConn("databroker", &grpc.Options{
Addr: dataBrokerURL,
OverrideCertificateName: serviceAccountOptions.overrideCertificateName,
CA: serviceAccountOptions.ca,
CAFile: serviceAccountOptions.caFile,
WithInsecure: !strings.HasSuffix(dataBrokerURL.Scheme, "s"),
})
if err != nil {
return fmt.Errorf("error creating databroker connection: %w", err)
}
defer cc.Close()
// hydrate our session // hydrate our session
serviceAccountOptions.serviceAccount.Audience = jwt.Audience(serviceAccountOptions.aud) serviceAccountOptions.serviceAccount.Audience = jwt.Audience(serviceAccountOptions.aud)
serviceAccountOptions.serviceAccount.Groups = []string(serviceAccountOptions.groups) serviceAccountOptions.serviceAccount.Groups = []string(serviceAccountOptions.groups)
@ -144,6 +128,26 @@ var serviceAccountCmd = &cobra.Command{
return errors.New("iss is required") return errors.New("iss is required")
} }
dataBrokerURL, err := url.Parse(serviceAccountOptions.dataBrokerURL)
if err != nil {
return fmt.Errorf("invalid databroker url: %w", err)
}
rawSharedKey, _ := base64.StdEncoding.DecodeString(sharedKey)
cc, err := grpc.GetGRPCClientConn("databroker", &grpc.Options{
Addr: dataBrokerURL,
OverrideCertificateName: serviceAccountOptions.overrideCertificateName,
CA: serviceAccountOptions.ca,
CAFile: serviceAccountOptions.caFile,
WithInsecure: !strings.HasSuffix(dataBrokerURL.Scheme, "s"),
SignedJWTKey: rawSharedKey,
})
if err != nil {
return fmt.Errorf("error creating databroker connection: %w", err)
}
defer cc.Close()
sa := &user.ServiceAccount{ sa := &user.ServiceAccount{
Id: uuid.New().String(), Id: uuid.New().String(),
UserId: serviceAccountOptions.serviceAccount.User, UserId: serviceAccountOptions.serviceAccount.User,

View file

@ -20,7 +20,7 @@ func TestDashboard(t *testing.T) {
t.Run("admin impersonate", func(t *testing.T) { t.Run("admin impersonate", func(t *testing.T) {
client := testcluster.NewHTTPClient() client := testcluster.NewHTTPClient()
res, err := flows.Authenticate(ctx, client, mustParseURL("https://httpdetails.localhost.pomerium.io/by-user"), _, err := flows.Authenticate(ctx, client, mustParseURL("https://httpdetails.localhost.pomerium.io/by-user"),
flows.WithEmail("bob@dogs.test"), flows.WithGroups("user")) flows.WithEmail("bob@dogs.test"), flows.WithGroups("user"))
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
return return
@ -31,7 +31,7 @@ func TestDashboard(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
res, err = client.Do(req) res, err := client.Do(req)
if !assert.NoError(t, err, "unexpected http error") { if !assert.NoError(t, err, "unexpected http error") {
return return
} }

View file

@ -2,6 +2,7 @@ package databroker
import ( import (
"context" "context"
"encoding/base64"
"errors" "errors"
"sync" "sync"
"time" "time"
@ -138,6 +139,7 @@ func (src *ConfigSource) rebuild(firstTime bool) {
} }
func (src *ConfigSource) runUpdater(cfg *config.Config) { func (src *ConfigSource) runUpdater(cfg *config.Config) {
sharedKey, _ := base64.StdEncoding.DecodeString(cfg.Options.SharedKey)
connectionOptions := &grpc.Options{ connectionOptions := &grpc.Options{
Addr: cfg.Options.DataBrokerURL, Addr: cfg.Options.DataBrokerURL,
OverrideCertificateName: cfg.Options.OverrideCertificateName, OverrideCertificateName: cfg.Options.OverrideCertificateName,
@ -147,6 +149,7 @@ func (src *ConfigSource) runUpdater(cfg *config.Config) {
ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin, ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin,
WithInsecure: cfg.Options.GRPCInsecure, WithInsecure: cfg.Options.GRPCInsecure,
ServiceName: cfg.Options.Services, ServiceName: cfg.Options.Services,
SignedJWTKey: sharedKey,
} }
h, err := hashstructure.Hash(connectionOptions, nil) h, err := hashstructure.Hash(connectionOptions, nil)
if err != nil { if err != nil {

View file

@ -55,13 +55,3 @@ func NewGRPCClientStatsHandler(service string) *GRPCClientStatsHandler {
UnaryInterceptor: metrics.GRPCClientInterceptor(ServiceName(service)), UnaryInterceptor: metrics.GRPCClientInterceptor(ServiceName(service)),
} }
} }
// DialOptions returns telemetry related DialOptions appended to an optional existing list
// of DialOptions
func (h *GRPCClientStatsHandler) DialOptions(o ...grpc.DialOption) []grpc.DialOption {
o = append(o,
grpc.WithUnaryInterceptor(h.UnaryInterceptor),
grpc.WithStatsHandler(h.Handler),
)
return o
}

View file

@ -6,7 +6,6 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"go.opencensus.io/plugin/ocgrpc" "go.opencensus.io/plugin/ocgrpc"
"google.golang.org/grpc"
grpcstats "google.golang.org/grpc/stats" grpcstats "google.golang.org/grpc/stats"
) )
@ -36,27 +35,3 @@ func Test_GRPCServerStatsHandler(t *testing.T) {
assert.Equal(t, ctx.Value(mockCtxTag("added")), "true") assert.Equal(t, ctx.Value(mockCtxTag("added")), "true")
assert.Equal(t, ctx.Value(mockCtxTag("original")), "true") assert.Equal(t, ctx.Value(mockCtxTag("original")), "true")
} }
type mockDialOption struct {
name string
grpc.EmptyDialOption
}
func Test_NewGRPCClientStatsHandler(t *testing.T) {
t.Parallel()
h := NewGRPCClientStatsHandler("test")
origOpts := []grpc.DialOption{
mockDialOption{name: "one"},
mockDialOption{name: "two"},
}
newOpts := h.DialOptions(origOpts...)
for i := range origOpts {
assert.Contains(t, newOpts, origOpts[i])
}
assert.Greater(t, len(newOpts), len(origOpts))
}

View file

@ -14,7 +14,7 @@ func StreamClientInterceptor() grpc.StreamClientInterceptor {
desc *grpc.StreamDesc, cc *grpc.ClientConn, desc *grpc.StreamDesc, cc *grpc.ClientConn,
method string, streamer grpc.Streamer, opts ...grpc.CallOption, method string, streamer grpc.Streamer, opts ...grpc.CallOption,
) (grpc.ClientStream, error) { ) (grpc.ClientStream, error) {
toMetadata(ctx) ctx = toMetadata(ctx)
return streamer(ctx, desc, cc, method, opts...) return streamer(ctx, desc, cc, method, opts...)
} }
} }
@ -26,21 +26,15 @@ func UnaryClientInterceptor() grpc.UnaryClientInterceptor {
method string, req, reply interface{}, method string, req, reply interface{},
cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption,
) error { ) error {
toMetadata(ctx) ctx = toMetadata(ctx)
return invoker(ctx, method, req, reply, cc, opts...) return invoker(ctx, method, req, reply, cc, opts...)
} }
} }
func toMetadata(ctx context.Context) { func toMetadata(ctx context.Context) context.Context {
requestID := FromContext(ctx) requestID := FromContext(ctx)
if requestID == "" { if requestID == "" {
requestID = New() requestID = New()
} }
return metadata.AppendToOutgoingContext(ctx, headerName, requestID)
md, ok := metadata.FromOutgoingContext(ctx)
if !ok {
return
}
md.Set(headerName, requestID)
} }

View file

@ -22,6 +22,7 @@ import (
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry" "github.com/pomerium/pomerium/internal/telemetry"
"github.com/pomerium/pomerium/internal/telemetry/requestid" "github.com/pomerium/pomerium/internal/telemetry/requestid"
"github.com/pomerium/pomerium/pkg/grpcutil"
) )
const ( const (
@ -52,6 +53,9 @@ type Options struct {
// ServiceName specifies the service name for telemetry exposition // ServiceName specifies the service name for telemetry exposition
ServiceName string ServiceName string
// SignedJWTKey is the JWT key to use for signing a JWT attached to metadata.
SignedJWTKey []byte
} }
// NewGRPCClientConn returns a new gRPC pomerium service client connection. // NewGRPCClientConn returns a new gRPC pomerium service client connection.
@ -70,17 +74,27 @@ func NewGRPCClientConn(opts *Options) (*grpc.ClientConn, error) {
} }
} }
dialOptions := []grpc.DialOption{ clientStatsHandler := telemetry.NewGRPCClientStatsHandler(opts.ServiceName)
grpc.WithChainUnaryInterceptor(
requestid.UnaryClientInterceptor(), unaryClientInterceptors := []grpc.UnaryClientInterceptor{
grpcTimeoutInterceptor(opts.RequestTimeout), requestid.UnaryClientInterceptor(),
), grpcTimeoutInterceptor(opts.RequestTimeout),
grpc.WithStreamInterceptor(requestid.StreamClientInterceptor()), clientStatsHandler.UnaryInterceptor,
grpc.WithDefaultCallOptions([]grpc.CallOption{grpc.WaitForReady(true)}...), }
streamClientInterceptors := []grpc.StreamClientInterceptor{
requestid.StreamClientInterceptor(),
}
if opts.SignedJWTKey != nil {
unaryClientInterceptors = append(unaryClientInterceptors, grpcutil.WithUnarySignedJWT(opts.SignedJWTKey))
streamClientInterceptors = append(streamClientInterceptors, grpcutil.WithStreamSignedJWT(opts.SignedJWTKey))
} }
clientStatsHandler := telemetry.NewGRPCClientStatsHandler(opts.ServiceName) dialOptions := []grpc.DialOption{
dialOptions = clientStatsHandler.DialOptions(dialOptions...) grpc.WithChainUnaryInterceptor(unaryClientInterceptors...),
grpc.WithChainStreamInterceptor(streamClientInterceptors...),
grpc.WithDefaultCallOptions([]grpc.CallOption{grpc.WaitForReady(true)}...),
grpc.WithStatsHandler(clientStatsHandler.Handler),
}
if opts.WithInsecure { if opts.WithInsecure {
log.Info().Str("addr", connAddr).Msg("internal/grpc: grpc with insecure") log.Info().Str("addr", connAddr).Msg("internal/grpc: grpc with insecure")
@ -129,10 +143,8 @@ func NewGRPCClientConn(opts *Options) (*grpc.ClientConn, error) {
dialOptions = append(dialOptions, grpc.WithBalancerName(roundrobin.Name), grpc.WithDisableServiceConfig()) dialOptions = append(dialOptions, grpc.WithBalancerName(roundrobin.Name), grpc.WithDisableServiceConfig())
connAddr = fmt.Sprintf("dns:///%s", connAddr) connAddr = fmt.Sprintf("dns:///%s", connAddr)
} }
return grpc.Dial(
connAddr, return grpc.Dial(connAddr, dialOptions...)
dialOptions...,
)
} }
// grpcTimeoutInterceptor enforces per-RPC request timeouts // grpcTimeoutInterceptor enforces per-RPC request timeouts

112
pkg/grpcutil/options.go Normal file
View file

@ -0,0 +1,112 @@
package grpcutil
import (
"context"
"encoding/base64"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"gopkg.in/square/go-jose.v2"
"gopkg.in/square/go-jose.v2/jwt"
)
// WithStreamSignedJWT returns a StreamClientInterceptor that adds a JWT to requests.
func WithStreamSignedJWT(key []byte) grpc.StreamClientInterceptor {
return func(
ctx context.Context,
desc *grpc.StreamDesc,
cc *grpc.ClientConn,
method string, streamer grpc.Streamer,
opts ...grpc.CallOption,
) (grpc.ClientStream, error) {
ctx, err := withSignedJWT(ctx, key)
if err != nil {
return nil, err
}
return streamer(ctx, desc, cc, method, opts...)
}
}
// WithUnarySignedJWT returns a UnaryClientInterceptor that adds a JWT to requests.
func WithUnarySignedJWT(key []byte) grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
ctx, err := withSignedJWT(ctx, key)
if err != nil {
return err
}
return invoker(ctx, method, req, reply, cc, opts...)
}
}
func withSignedJWT(ctx context.Context, key []byte) (context.Context, error) {
if len(key) > 0 {
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: key},
(&jose.SignerOptions{}).WithType("JWT"))
if err != nil {
return ctx, err
}
rawjwt, err := jwt.Signed(sig).Claims(jwt.Claims{
Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)),
}).CompactSerialize()
if err != nil {
return ctx, err
}
ctx = WithOutgoingJWT(ctx, rawjwt)
}
return ctx, nil
}
// UnaryRequireSignedJWT requires a JWT in the gRPC metadata and that it be signed by the base64-encoded key.
func UnaryRequireSignedJWT(key string) grpc.UnaryServerInterceptor {
keyBS, _ := base64.StdEncoding.DecodeString(key)
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
if err := requireSignedJWT(ctx, keyBS); err != nil {
return nil, err
}
return handler(ctx, req)
}
}
// StreamRequireSignedJWT requires a JWT in the gRPC metadata and that it be signed by the base64-encoded key.
func StreamRequireSignedJWT(key string) grpc.StreamServerInterceptor {
keyBS, _ := base64.StdEncoding.DecodeString(key)
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
if err := requireSignedJWT(ss.Context(), keyBS); err != nil {
return err
}
return handler(srv, ss)
}
}
func requireSignedJWT(ctx context.Context, key []byte) error {
if len(key) > 0 {
rawjwt, ok := JWTFromGRPCRequest(ctx)
if !ok {
return status.Error(codes.Unauthenticated, "unauthenticated")
}
tok, err := jwt.ParseSigned(rawjwt)
if err != nil {
return status.Errorf(codes.Unauthenticated, "invalid JWT: %v", err)
}
var claims struct {
Expiry *jwt.NumericDate `json:"exp,omitempty"`
}
err = tok.Claims(key, &claims)
if err != nil {
return status.Errorf(codes.Unauthenticated, "invalid JWT: %v", err)
}
if claims.Expiry == nil || time.Now().After(claims.Expiry.Time()) {
return status.Errorf(codes.Unauthenticated, "expired JWT: %v", err)
}
}
return nil
}