mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-29 00:47:17 +02:00
databroker: require JWT for access (#1503)
This commit is contained in:
parent
27d0cf180a
commit
eb79cc0957
11 changed files with 188 additions and 79 deletions
|
@ -116,6 +116,8 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err
|
|||
state.jwk.Keys = append(state.jwk.Keys, *jwk)
|
||||
}
|
||||
|
||||
sharedKey, _ := base64.StdEncoding.DecodeString(cfg.Options.SharedKey)
|
||||
|
||||
dataBrokerConn, err := grpc.GetGRPCClientConn("databroker", &grpc.Options{
|
||||
Addr: cfg.Options.DataBrokerURL,
|
||||
OverrideCertificateName: cfg.Options.OverrideCertificateName,
|
||||
|
@ -125,6 +127,7 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err
|
|||
ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin,
|
||||
WithInsecure: cfg.Options.GRPCInsecure,
|
||||
ServiceName: cfg.Options.Services,
|
||||
SignedJWTKey: sharedKey,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package authorize
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
|
||||
|
@ -41,6 +42,8 @@ func newAuthorizeStateFromConfig(cfg *config.Config, store *evaluator.Store) (*a
|
|||
return nil, err
|
||||
}
|
||||
|
||||
sharedKey, _ := base64.StdEncoding.DecodeString(cfg.Options.SharedKey)
|
||||
|
||||
cc, err := grpc.GetGRPCClientConn("databroker", &grpc.Options{
|
||||
Addr: cfg.Options.DataBrokerURL,
|
||||
OverrideCertificateName: cfg.Options.OverrideCertificateName,
|
||||
|
@ -50,6 +53,7 @@ func newAuthorizeStateFromConfig(cfg *config.Config, store *evaluator.Store) (*a
|
|||
ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin,
|
||||
WithInsecure: cfg.Options.GRPCInsecure,
|
||||
ServiceName: cfg.Options.Services,
|
||||
SignedJWTKey: sharedKey,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authorize: error creating databroker connection: %w", err)
|
||||
|
|
16
cache/cache.go
vendored
16
cache/cache.go
vendored
|
@ -5,6 +5,7 @@ package cache
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
@ -21,6 +22,7 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||
)
|
||||
|
||||
// Cache represents the cache service. The cache service is a simple interface
|
||||
|
@ -46,12 +48,22 @@ func New(cfg *config.Config) (*Cache, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
sharedKey, _ := base64.StdEncoding.DecodeString(cfg.Options.SharedKey)
|
||||
|
||||
// No metrics handler because we have one in the control plane. Add one
|
||||
// if we no longer register with that grpc Server
|
||||
localGRPCServer := grpc.NewServer()
|
||||
localGRPCServer := grpc.NewServer(
|
||||
grpc.StreamInterceptor(grpcutil.StreamRequireSignedJWT(cfg.Options.SharedKey)),
|
||||
grpc.UnaryInterceptor(grpcutil.UnaryRequireSignedJWT(cfg.Options.SharedKey)),
|
||||
)
|
||||
|
||||
clientStatsHandler := telemetry.NewGRPCClientStatsHandler(cfg.Options.Services)
|
||||
clientDialOptions := clientStatsHandler.DialOptions(grpc.WithInsecure())
|
||||
clientDialOptions := []grpc.DialOption{
|
||||
grpc.WithInsecure(),
|
||||
grpc.WithChainUnaryInterceptor(clientStatsHandler.UnaryInterceptor, grpcutil.WithUnarySignedJWT(sharedKey)),
|
||||
grpc.WithChainStreamInterceptor(grpcutil.WithStreamSignedJWT(sharedKey)),
|
||||
grpc.WithStatsHandler(clientStatsHandler.Handler),
|
||||
}
|
||||
|
||||
localGRPCConnection, err := grpc.DialContext(
|
||||
context.Background(),
|
||||
|
|
|
@ -3,6 +3,7 @@ package main
|
|||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
@ -91,23 +92,6 @@ var serviceAccountCmd = &cobra.Command{
|
|||
l := zerolog.Nop()
|
||||
log.SetLogger(&l)
|
||||
|
||||
dataBrokerURL, err := url.Parse(serviceAccountOptions.dataBrokerURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid databroker url: %w", err)
|
||||
}
|
||||
|
||||
cc, err := grpc.GetGRPCClientConn("databroker", &grpc.Options{
|
||||
Addr: dataBrokerURL,
|
||||
OverrideCertificateName: serviceAccountOptions.overrideCertificateName,
|
||||
CA: serviceAccountOptions.ca,
|
||||
CAFile: serviceAccountOptions.caFile,
|
||||
WithInsecure: !strings.HasSuffix(dataBrokerURL.Scheme, "s"),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating databroker connection: %w", err)
|
||||
}
|
||||
defer cc.Close()
|
||||
|
||||
// hydrate our session
|
||||
serviceAccountOptions.serviceAccount.Audience = jwt.Audience(serviceAccountOptions.aud)
|
||||
serviceAccountOptions.serviceAccount.Groups = []string(serviceAccountOptions.groups)
|
||||
|
@ -144,6 +128,26 @@ var serviceAccountCmd = &cobra.Command{
|
|||
return errors.New("iss is required")
|
||||
}
|
||||
|
||||
dataBrokerURL, err := url.Parse(serviceAccountOptions.dataBrokerURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid databroker url: %w", err)
|
||||
}
|
||||
|
||||
rawSharedKey, _ := base64.StdEncoding.DecodeString(sharedKey)
|
||||
|
||||
cc, err := grpc.GetGRPCClientConn("databroker", &grpc.Options{
|
||||
Addr: dataBrokerURL,
|
||||
OverrideCertificateName: serviceAccountOptions.overrideCertificateName,
|
||||
CA: serviceAccountOptions.ca,
|
||||
CAFile: serviceAccountOptions.caFile,
|
||||
WithInsecure: !strings.HasSuffix(dataBrokerURL.Scheme, "s"),
|
||||
SignedJWTKey: rawSharedKey,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating databroker connection: %w", err)
|
||||
}
|
||||
defer cc.Close()
|
||||
|
||||
sa := &user.ServiceAccount{
|
||||
Id: uuid.New().String(),
|
||||
UserId: serviceAccountOptions.serviceAccount.User,
|
||||
|
|
|
@ -20,7 +20,7 @@ func TestDashboard(t *testing.T) {
|
|||
t.Run("admin impersonate", func(t *testing.T) {
|
||||
client := testcluster.NewHTTPClient()
|
||||
|
||||
res, err := flows.Authenticate(ctx, client, mustParseURL("https://httpdetails.localhost.pomerium.io/by-user"),
|
||||
_, err := flows.Authenticate(ctx, client, mustParseURL("https://httpdetails.localhost.pomerium.io/by-user"),
|
||||
flows.WithEmail("bob@dogs.test"), flows.WithGroups("user"))
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
|
@ -31,7 +31,7 @@ func TestDashboard(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
res, err = client.Do(req)
|
||||
res, err := client.Do(req)
|
||||
if !assert.NoError(t, err, "unexpected http error") {
|
||||
return
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package databroker
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -138,6 +139,7 @@ func (src *ConfigSource) rebuild(firstTime bool) {
|
|||
}
|
||||
|
||||
func (src *ConfigSource) runUpdater(cfg *config.Config) {
|
||||
sharedKey, _ := base64.StdEncoding.DecodeString(cfg.Options.SharedKey)
|
||||
connectionOptions := &grpc.Options{
|
||||
Addr: cfg.Options.DataBrokerURL,
|
||||
OverrideCertificateName: cfg.Options.OverrideCertificateName,
|
||||
|
@ -147,6 +149,7 @@ func (src *ConfigSource) runUpdater(cfg *config.Config) {
|
|||
ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin,
|
||||
WithInsecure: cfg.Options.GRPCInsecure,
|
||||
ServiceName: cfg.Options.Services,
|
||||
SignedJWTKey: sharedKey,
|
||||
}
|
||||
h, err := hashstructure.Hash(connectionOptions, nil)
|
||||
if err != nil {
|
||||
|
|
|
@ -55,13 +55,3 @@ func NewGRPCClientStatsHandler(service string) *GRPCClientStatsHandler {
|
|||
UnaryInterceptor: metrics.GRPCClientInterceptor(ServiceName(service)),
|
||||
}
|
||||
}
|
||||
|
||||
// DialOptions returns telemetry related DialOptions appended to an optional existing list
|
||||
// of DialOptions
|
||||
func (h *GRPCClientStatsHandler) DialOptions(o ...grpc.DialOption) []grpc.DialOption {
|
||||
o = append(o,
|
||||
grpc.WithUnaryInterceptor(h.UnaryInterceptor),
|
||||
grpc.WithStatsHandler(h.Handler),
|
||||
)
|
||||
return o
|
||||
}
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.opencensus.io/plugin/ocgrpc"
|
||||
"google.golang.org/grpc"
|
||||
grpcstats "google.golang.org/grpc/stats"
|
||||
)
|
||||
|
||||
|
@ -36,27 +35,3 @@ func Test_GRPCServerStatsHandler(t *testing.T) {
|
|||
assert.Equal(t, ctx.Value(mockCtxTag("added")), "true")
|
||||
assert.Equal(t, ctx.Value(mockCtxTag("original")), "true")
|
||||
}
|
||||
|
||||
type mockDialOption struct {
|
||||
name string
|
||||
grpc.EmptyDialOption
|
||||
}
|
||||
|
||||
func Test_NewGRPCClientStatsHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
h := NewGRPCClientStatsHandler("test")
|
||||
|
||||
origOpts := []grpc.DialOption{
|
||||
mockDialOption{name: "one"},
|
||||
mockDialOption{name: "two"},
|
||||
}
|
||||
|
||||
newOpts := h.DialOptions(origOpts...)
|
||||
|
||||
for i := range origOpts {
|
||||
assert.Contains(t, newOpts, origOpts[i])
|
||||
}
|
||||
|
||||
assert.Greater(t, len(newOpts), len(origOpts))
|
||||
}
|
||||
|
|
|
@ -14,7 +14,7 @@ func StreamClientInterceptor() grpc.StreamClientInterceptor {
|
|||
desc *grpc.StreamDesc, cc *grpc.ClientConn,
|
||||
method string, streamer grpc.Streamer, opts ...grpc.CallOption,
|
||||
) (grpc.ClientStream, error) {
|
||||
toMetadata(ctx)
|
||||
ctx = toMetadata(ctx)
|
||||
return streamer(ctx, desc, cc, method, opts...)
|
||||
}
|
||||
}
|
||||
|
@ -26,21 +26,15 @@ func UnaryClientInterceptor() grpc.UnaryClientInterceptor {
|
|||
method string, req, reply interface{},
|
||||
cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption,
|
||||
) error {
|
||||
toMetadata(ctx)
|
||||
ctx = toMetadata(ctx)
|
||||
return invoker(ctx, method, req, reply, cc, opts...)
|
||||
}
|
||||
}
|
||||
|
||||
func toMetadata(ctx context.Context) {
|
||||
func toMetadata(ctx context.Context) context.Context {
|
||||
requestID := FromContext(ctx)
|
||||
if requestID == "" {
|
||||
requestID = New()
|
||||
}
|
||||
|
||||
md, ok := metadata.FromOutgoingContext(ctx)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
md.Set(headerName, requestID)
|
||||
return metadata.AppendToOutgoingContext(ctx, headerName, requestID)
|
||||
}
|
||||
|
|
|
@ -22,6 +22,7 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/telemetry"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/requestid"
|
||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -52,6 +53,9 @@ type Options struct {
|
|||
|
||||
// ServiceName specifies the service name for telemetry exposition
|
||||
ServiceName string
|
||||
|
||||
// SignedJWTKey is the JWT key to use for signing a JWT attached to metadata.
|
||||
SignedJWTKey []byte
|
||||
}
|
||||
|
||||
// NewGRPCClientConn returns a new gRPC pomerium service client connection.
|
||||
|
@ -70,17 +74,27 @@ func NewGRPCClientConn(opts *Options) (*grpc.ClientConn, error) {
|
|||
}
|
||||
}
|
||||
|
||||
dialOptions := []grpc.DialOption{
|
||||
grpc.WithChainUnaryInterceptor(
|
||||
requestid.UnaryClientInterceptor(),
|
||||
grpcTimeoutInterceptor(opts.RequestTimeout),
|
||||
),
|
||||
grpc.WithStreamInterceptor(requestid.StreamClientInterceptor()),
|
||||
grpc.WithDefaultCallOptions([]grpc.CallOption{grpc.WaitForReady(true)}...),
|
||||
clientStatsHandler := telemetry.NewGRPCClientStatsHandler(opts.ServiceName)
|
||||
|
||||
unaryClientInterceptors := []grpc.UnaryClientInterceptor{
|
||||
requestid.UnaryClientInterceptor(),
|
||||
grpcTimeoutInterceptor(opts.RequestTimeout),
|
||||
clientStatsHandler.UnaryInterceptor,
|
||||
}
|
||||
streamClientInterceptors := []grpc.StreamClientInterceptor{
|
||||
requestid.StreamClientInterceptor(),
|
||||
}
|
||||
if opts.SignedJWTKey != nil {
|
||||
unaryClientInterceptors = append(unaryClientInterceptors, grpcutil.WithUnarySignedJWT(opts.SignedJWTKey))
|
||||
streamClientInterceptors = append(streamClientInterceptors, grpcutil.WithStreamSignedJWT(opts.SignedJWTKey))
|
||||
}
|
||||
|
||||
clientStatsHandler := telemetry.NewGRPCClientStatsHandler(opts.ServiceName)
|
||||
dialOptions = clientStatsHandler.DialOptions(dialOptions...)
|
||||
dialOptions := []grpc.DialOption{
|
||||
grpc.WithChainUnaryInterceptor(unaryClientInterceptors...),
|
||||
grpc.WithChainStreamInterceptor(streamClientInterceptors...),
|
||||
grpc.WithDefaultCallOptions([]grpc.CallOption{grpc.WaitForReady(true)}...),
|
||||
grpc.WithStatsHandler(clientStatsHandler.Handler),
|
||||
}
|
||||
|
||||
if opts.WithInsecure {
|
||||
log.Info().Str("addr", connAddr).Msg("internal/grpc: grpc with insecure")
|
||||
|
@ -129,10 +143,8 @@ func NewGRPCClientConn(opts *Options) (*grpc.ClientConn, error) {
|
|||
dialOptions = append(dialOptions, grpc.WithBalancerName(roundrobin.Name), grpc.WithDisableServiceConfig())
|
||||
connAddr = fmt.Sprintf("dns:///%s", connAddr)
|
||||
}
|
||||
return grpc.Dial(
|
||||
connAddr,
|
||||
dialOptions...,
|
||||
)
|
||||
|
||||
return grpc.Dial(connAddr, dialOptions...)
|
||||
}
|
||||
|
||||
// grpcTimeoutInterceptor enforces per-RPC request timeouts
|
||||
|
|
112
pkg/grpcutil/options.go
Normal file
112
pkg/grpcutil/options.go
Normal file
|
@ -0,0 +1,112 @@
|
|||
package grpcutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
)
|
||||
|
||||
// WithStreamSignedJWT returns a StreamClientInterceptor that adds a JWT to requests.
|
||||
func WithStreamSignedJWT(key []byte) grpc.StreamClientInterceptor {
|
||||
return func(
|
||||
ctx context.Context,
|
||||
desc *grpc.StreamDesc,
|
||||
cc *grpc.ClientConn,
|
||||
method string, streamer grpc.Streamer,
|
||||
opts ...grpc.CallOption,
|
||||
) (grpc.ClientStream, error) {
|
||||
ctx, err := withSignedJWT(ctx, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return streamer(ctx, desc, cc, method, opts...)
|
||||
}
|
||||
}
|
||||
|
||||
// WithUnarySignedJWT returns a UnaryClientInterceptor that adds a JWT to requests.
|
||||
func WithUnarySignedJWT(key []byte) grpc.UnaryClientInterceptor {
|
||||
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
|
||||
ctx, err := withSignedJWT(ctx, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return invoker(ctx, method, req, reply, cc, opts...)
|
||||
}
|
||||
}
|
||||
|
||||
func withSignedJWT(ctx context.Context, key []byte) (context.Context, error) {
|
||||
if len(key) > 0 {
|
||||
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: key},
|
||||
(&jose.SignerOptions{}).WithType("JWT"))
|
||||
if err != nil {
|
||||
return ctx, err
|
||||
}
|
||||
|
||||
rawjwt, err := jwt.Signed(sig).Claims(jwt.Claims{
|
||||
Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||
}).CompactSerialize()
|
||||
if err != nil {
|
||||
return ctx, err
|
||||
}
|
||||
|
||||
ctx = WithOutgoingJWT(ctx, rawjwt)
|
||||
}
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
// UnaryRequireSignedJWT requires a JWT in the gRPC metadata and that it be signed by the base64-encoded key.
|
||||
func UnaryRequireSignedJWT(key string) grpc.UnaryServerInterceptor {
|
||||
keyBS, _ := base64.StdEncoding.DecodeString(key)
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
|
||||
if err := requireSignedJWT(ctx, keyBS); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return handler(ctx, req)
|
||||
}
|
||||
}
|
||||
|
||||
// StreamRequireSignedJWT requires a JWT in the gRPC metadata and that it be signed by the base64-encoded key.
|
||||
func StreamRequireSignedJWT(key string) grpc.StreamServerInterceptor {
|
||||
keyBS, _ := base64.StdEncoding.DecodeString(key)
|
||||
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
if err := requireSignedJWT(ss.Context(), keyBS); err != nil {
|
||||
return err
|
||||
}
|
||||
return handler(srv, ss)
|
||||
}
|
||||
}
|
||||
|
||||
func requireSignedJWT(ctx context.Context, key []byte) error {
|
||||
if len(key) > 0 {
|
||||
rawjwt, ok := JWTFromGRPCRequest(ctx)
|
||||
if !ok {
|
||||
return status.Error(codes.Unauthenticated, "unauthenticated")
|
||||
}
|
||||
|
||||
tok, err := jwt.ParseSigned(rawjwt)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Unauthenticated, "invalid JWT: %v", err)
|
||||
}
|
||||
|
||||
var claims struct {
|
||||
Expiry *jwt.NumericDate `json:"exp,omitempty"`
|
||||
}
|
||||
err = tok.Claims(key, &claims)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Unauthenticated, "invalid JWT: %v", err)
|
||||
}
|
||||
|
||||
if claims.Expiry == nil || time.Now().After(claims.Expiry.Time()) {
|
||||
return status.Errorf(codes.Unauthenticated, "expired JWT: %v", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue