mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-29 00:47:17 +02:00
databroker: require JWT for access (#1503)
This commit is contained in:
parent
27d0cf180a
commit
eb79cc0957
11 changed files with 188 additions and 79 deletions
|
@ -116,6 +116,8 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err
|
||||||
state.jwk.Keys = append(state.jwk.Keys, *jwk)
|
state.jwk.Keys = append(state.jwk.Keys, *jwk)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sharedKey, _ := base64.StdEncoding.DecodeString(cfg.Options.SharedKey)
|
||||||
|
|
||||||
dataBrokerConn, err := grpc.GetGRPCClientConn("databroker", &grpc.Options{
|
dataBrokerConn, err := grpc.GetGRPCClientConn("databroker", &grpc.Options{
|
||||||
Addr: cfg.Options.DataBrokerURL,
|
Addr: cfg.Options.DataBrokerURL,
|
||||||
OverrideCertificateName: cfg.Options.OverrideCertificateName,
|
OverrideCertificateName: cfg.Options.OverrideCertificateName,
|
||||||
|
@ -125,6 +127,7 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err
|
||||||
ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin,
|
ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin,
|
||||||
WithInsecure: cfg.Options.GRPCInsecure,
|
WithInsecure: cfg.Options.GRPCInsecure,
|
||||||
ServiceName: cfg.Options.Services,
|
ServiceName: cfg.Options.Services,
|
||||||
|
SignedJWTKey: sharedKey,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package authorize
|
package authorize
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
|
@ -41,6 +42,8 @@ func newAuthorizeStateFromConfig(cfg *config.Config, store *evaluator.Store) (*a
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sharedKey, _ := base64.StdEncoding.DecodeString(cfg.Options.SharedKey)
|
||||||
|
|
||||||
cc, err := grpc.GetGRPCClientConn("databroker", &grpc.Options{
|
cc, err := grpc.GetGRPCClientConn("databroker", &grpc.Options{
|
||||||
Addr: cfg.Options.DataBrokerURL,
|
Addr: cfg.Options.DataBrokerURL,
|
||||||
OverrideCertificateName: cfg.Options.OverrideCertificateName,
|
OverrideCertificateName: cfg.Options.OverrideCertificateName,
|
||||||
|
@ -50,6 +53,7 @@ func newAuthorizeStateFromConfig(cfg *config.Config, store *evaluator.Store) (*a
|
||||||
ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin,
|
ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin,
|
||||||
WithInsecure: cfg.Options.GRPCInsecure,
|
WithInsecure: cfg.Options.GRPCInsecure,
|
||||||
ServiceName: cfg.Options.Services,
|
ServiceName: cfg.Options.Services,
|
||||||
|
SignedJWTKey: sharedKey,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("authorize: error creating databroker connection: %w", err)
|
return nil, fmt.Errorf("authorize: error creating databroker connection: %w", err)
|
||||||
|
|
16
cache/cache.go
vendored
16
cache/cache.go
vendored
|
@ -5,6 +5,7 @@ package cache
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -21,6 +22,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Cache represents the cache service. The cache service is a simple interface
|
// Cache represents the cache service. The cache service is a simple interface
|
||||||
|
@ -46,12 +48,22 @@ func New(cfg *config.Config) (*Cache, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sharedKey, _ := base64.StdEncoding.DecodeString(cfg.Options.SharedKey)
|
||||||
|
|
||||||
// No metrics handler because we have one in the control plane. Add one
|
// No metrics handler because we have one in the control plane. Add one
|
||||||
// if we no longer register with that grpc Server
|
// if we no longer register with that grpc Server
|
||||||
localGRPCServer := grpc.NewServer()
|
localGRPCServer := grpc.NewServer(
|
||||||
|
grpc.StreamInterceptor(grpcutil.StreamRequireSignedJWT(cfg.Options.SharedKey)),
|
||||||
|
grpc.UnaryInterceptor(grpcutil.UnaryRequireSignedJWT(cfg.Options.SharedKey)),
|
||||||
|
)
|
||||||
|
|
||||||
clientStatsHandler := telemetry.NewGRPCClientStatsHandler(cfg.Options.Services)
|
clientStatsHandler := telemetry.NewGRPCClientStatsHandler(cfg.Options.Services)
|
||||||
clientDialOptions := clientStatsHandler.DialOptions(grpc.WithInsecure())
|
clientDialOptions := []grpc.DialOption{
|
||||||
|
grpc.WithInsecure(),
|
||||||
|
grpc.WithChainUnaryInterceptor(clientStatsHandler.UnaryInterceptor, grpcutil.WithUnarySignedJWT(sharedKey)),
|
||||||
|
grpc.WithChainStreamInterceptor(grpcutil.WithStreamSignedJWT(sharedKey)),
|
||||||
|
grpc.WithStatsHandler(clientStatsHandler.Handler),
|
||||||
|
}
|
||||||
|
|
||||||
localGRPCConnection, err := grpc.DialContext(
|
localGRPCConnection, err := grpc.DialContext(
|
||||||
context.Background(),
|
context.Background(),
|
||||||
|
|
|
@ -3,6 +3,7 @@ package main
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
@ -91,23 +92,6 @@ var serviceAccountCmd = &cobra.Command{
|
||||||
l := zerolog.Nop()
|
l := zerolog.Nop()
|
||||||
log.SetLogger(&l)
|
log.SetLogger(&l)
|
||||||
|
|
||||||
dataBrokerURL, err := url.Parse(serviceAccountOptions.dataBrokerURL)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("invalid databroker url: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cc, err := grpc.GetGRPCClientConn("databroker", &grpc.Options{
|
|
||||||
Addr: dataBrokerURL,
|
|
||||||
OverrideCertificateName: serviceAccountOptions.overrideCertificateName,
|
|
||||||
CA: serviceAccountOptions.ca,
|
|
||||||
CAFile: serviceAccountOptions.caFile,
|
|
||||||
WithInsecure: !strings.HasSuffix(dataBrokerURL.Scheme, "s"),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error creating databroker connection: %w", err)
|
|
||||||
}
|
|
||||||
defer cc.Close()
|
|
||||||
|
|
||||||
// hydrate our session
|
// hydrate our session
|
||||||
serviceAccountOptions.serviceAccount.Audience = jwt.Audience(serviceAccountOptions.aud)
|
serviceAccountOptions.serviceAccount.Audience = jwt.Audience(serviceAccountOptions.aud)
|
||||||
serviceAccountOptions.serviceAccount.Groups = []string(serviceAccountOptions.groups)
|
serviceAccountOptions.serviceAccount.Groups = []string(serviceAccountOptions.groups)
|
||||||
|
@ -144,6 +128,26 @@ var serviceAccountCmd = &cobra.Command{
|
||||||
return errors.New("iss is required")
|
return errors.New("iss is required")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
dataBrokerURL, err := url.Parse(serviceAccountOptions.dataBrokerURL)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid databroker url: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rawSharedKey, _ := base64.StdEncoding.DecodeString(sharedKey)
|
||||||
|
|
||||||
|
cc, err := grpc.GetGRPCClientConn("databroker", &grpc.Options{
|
||||||
|
Addr: dataBrokerURL,
|
||||||
|
OverrideCertificateName: serviceAccountOptions.overrideCertificateName,
|
||||||
|
CA: serviceAccountOptions.ca,
|
||||||
|
CAFile: serviceAccountOptions.caFile,
|
||||||
|
WithInsecure: !strings.HasSuffix(dataBrokerURL.Scheme, "s"),
|
||||||
|
SignedJWTKey: rawSharedKey,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error creating databroker connection: %w", err)
|
||||||
|
}
|
||||||
|
defer cc.Close()
|
||||||
|
|
||||||
sa := &user.ServiceAccount{
|
sa := &user.ServiceAccount{
|
||||||
Id: uuid.New().String(),
|
Id: uuid.New().String(),
|
||||||
UserId: serviceAccountOptions.serviceAccount.User,
|
UserId: serviceAccountOptions.serviceAccount.User,
|
||||||
|
|
|
@ -20,7 +20,7 @@ func TestDashboard(t *testing.T) {
|
||||||
t.Run("admin impersonate", func(t *testing.T) {
|
t.Run("admin impersonate", func(t *testing.T) {
|
||||||
client := testcluster.NewHTTPClient()
|
client := testcluster.NewHTTPClient()
|
||||||
|
|
||||||
res, err := flows.Authenticate(ctx, client, mustParseURL("https://httpdetails.localhost.pomerium.io/by-user"),
|
_, err := flows.Authenticate(ctx, client, mustParseURL("https://httpdetails.localhost.pomerium.io/by-user"),
|
||||||
flows.WithEmail("bob@dogs.test"), flows.WithGroups("user"))
|
flows.WithEmail("bob@dogs.test"), flows.WithGroups("user"))
|
||||||
if !assert.NoError(t, err) {
|
if !assert.NoError(t, err) {
|
||||||
return
|
return
|
||||||
|
@ -31,7 +31,7 @@ func TestDashboard(t *testing.T) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err = client.Do(req)
|
res, err := client.Do(req)
|
||||||
if !assert.NoError(t, err, "unexpected http error") {
|
if !assert.NoError(t, err, "unexpected http error") {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package databroker
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
@ -138,6 +139,7 @@ func (src *ConfigSource) rebuild(firstTime bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (src *ConfigSource) runUpdater(cfg *config.Config) {
|
func (src *ConfigSource) runUpdater(cfg *config.Config) {
|
||||||
|
sharedKey, _ := base64.StdEncoding.DecodeString(cfg.Options.SharedKey)
|
||||||
connectionOptions := &grpc.Options{
|
connectionOptions := &grpc.Options{
|
||||||
Addr: cfg.Options.DataBrokerURL,
|
Addr: cfg.Options.DataBrokerURL,
|
||||||
OverrideCertificateName: cfg.Options.OverrideCertificateName,
|
OverrideCertificateName: cfg.Options.OverrideCertificateName,
|
||||||
|
@ -147,6 +149,7 @@ func (src *ConfigSource) runUpdater(cfg *config.Config) {
|
||||||
ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin,
|
ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin,
|
||||||
WithInsecure: cfg.Options.GRPCInsecure,
|
WithInsecure: cfg.Options.GRPCInsecure,
|
||||||
ServiceName: cfg.Options.Services,
|
ServiceName: cfg.Options.Services,
|
||||||
|
SignedJWTKey: sharedKey,
|
||||||
}
|
}
|
||||||
h, err := hashstructure.Hash(connectionOptions, nil)
|
h, err := hashstructure.Hash(connectionOptions, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -55,13 +55,3 @@ func NewGRPCClientStatsHandler(service string) *GRPCClientStatsHandler {
|
||||||
UnaryInterceptor: metrics.GRPCClientInterceptor(ServiceName(service)),
|
UnaryInterceptor: metrics.GRPCClientInterceptor(ServiceName(service)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// DialOptions returns telemetry related DialOptions appended to an optional existing list
|
|
||||||
// of DialOptions
|
|
||||||
func (h *GRPCClientStatsHandler) DialOptions(o ...grpc.DialOption) []grpc.DialOption {
|
|
||||||
o = append(o,
|
|
||||||
grpc.WithUnaryInterceptor(h.UnaryInterceptor),
|
|
||||||
grpc.WithStatsHandler(h.Handler),
|
|
||||||
)
|
|
||||||
return o
|
|
||||||
}
|
|
||||||
|
|
|
@ -6,7 +6,6 @@ import (
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"go.opencensus.io/plugin/ocgrpc"
|
"go.opencensus.io/plugin/ocgrpc"
|
||||||
"google.golang.org/grpc"
|
|
||||||
grpcstats "google.golang.org/grpc/stats"
|
grpcstats "google.golang.org/grpc/stats"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -36,27 +35,3 @@ func Test_GRPCServerStatsHandler(t *testing.T) {
|
||||||
assert.Equal(t, ctx.Value(mockCtxTag("added")), "true")
|
assert.Equal(t, ctx.Value(mockCtxTag("added")), "true")
|
||||||
assert.Equal(t, ctx.Value(mockCtxTag("original")), "true")
|
assert.Equal(t, ctx.Value(mockCtxTag("original")), "true")
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockDialOption struct {
|
|
||||||
name string
|
|
||||||
grpc.EmptyDialOption
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_NewGRPCClientStatsHandler(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
h := NewGRPCClientStatsHandler("test")
|
|
||||||
|
|
||||||
origOpts := []grpc.DialOption{
|
|
||||||
mockDialOption{name: "one"},
|
|
||||||
mockDialOption{name: "two"},
|
|
||||||
}
|
|
||||||
|
|
||||||
newOpts := h.DialOptions(origOpts...)
|
|
||||||
|
|
||||||
for i := range origOpts {
|
|
||||||
assert.Contains(t, newOpts, origOpts[i])
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Greater(t, len(newOpts), len(origOpts))
|
|
||||||
}
|
|
||||||
|
|
|
@ -14,7 +14,7 @@ func StreamClientInterceptor() grpc.StreamClientInterceptor {
|
||||||
desc *grpc.StreamDesc, cc *grpc.ClientConn,
|
desc *grpc.StreamDesc, cc *grpc.ClientConn,
|
||||||
method string, streamer grpc.Streamer, opts ...grpc.CallOption,
|
method string, streamer grpc.Streamer, opts ...grpc.CallOption,
|
||||||
) (grpc.ClientStream, error) {
|
) (grpc.ClientStream, error) {
|
||||||
toMetadata(ctx)
|
ctx = toMetadata(ctx)
|
||||||
return streamer(ctx, desc, cc, method, opts...)
|
return streamer(ctx, desc, cc, method, opts...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -26,21 +26,15 @@ func UnaryClientInterceptor() grpc.UnaryClientInterceptor {
|
||||||
method string, req, reply interface{},
|
method string, req, reply interface{},
|
||||||
cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption,
|
cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption,
|
||||||
) error {
|
) error {
|
||||||
toMetadata(ctx)
|
ctx = toMetadata(ctx)
|
||||||
return invoker(ctx, method, req, reply, cc, opts...)
|
return invoker(ctx, method, req, reply, cc, opts...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func toMetadata(ctx context.Context) {
|
func toMetadata(ctx context.Context) context.Context {
|
||||||
requestID := FromContext(ctx)
|
requestID := FromContext(ctx)
|
||||||
if requestID == "" {
|
if requestID == "" {
|
||||||
requestID = New()
|
requestID = New()
|
||||||
}
|
}
|
||||||
|
return metadata.AppendToOutgoingContext(ctx, headerName, requestID)
|
||||||
md, ok := metadata.FromOutgoingContext(ctx)
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
md.Set(headerName, requestID)
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,6 +22,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry"
|
"github.com/pomerium/pomerium/internal/telemetry"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/requestid"
|
"github.com/pomerium/pomerium/internal/telemetry/requestid"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -52,6 +53,9 @@ type Options struct {
|
||||||
|
|
||||||
// ServiceName specifies the service name for telemetry exposition
|
// ServiceName specifies the service name for telemetry exposition
|
||||||
ServiceName string
|
ServiceName string
|
||||||
|
|
||||||
|
// SignedJWTKey is the JWT key to use for signing a JWT attached to metadata.
|
||||||
|
SignedJWTKey []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewGRPCClientConn returns a new gRPC pomerium service client connection.
|
// NewGRPCClientConn returns a new gRPC pomerium service client connection.
|
||||||
|
@ -70,17 +74,27 @@ func NewGRPCClientConn(opts *Options) (*grpc.ClientConn, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
dialOptions := []grpc.DialOption{
|
clientStatsHandler := telemetry.NewGRPCClientStatsHandler(opts.ServiceName)
|
||||||
grpc.WithChainUnaryInterceptor(
|
|
||||||
requestid.UnaryClientInterceptor(),
|
unaryClientInterceptors := []grpc.UnaryClientInterceptor{
|
||||||
grpcTimeoutInterceptor(opts.RequestTimeout),
|
requestid.UnaryClientInterceptor(),
|
||||||
),
|
grpcTimeoutInterceptor(opts.RequestTimeout),
|
||||||
grpc.WithStreamInterceptor(requestid.StreamClientInterceptor()),
|
clientStatsHandler.UnaryInterceptor,
|
||||||
grpc.WithDefaultCallOptions([]grpc.CallOption{grpc.WaitForReady(true)}...),
|
}
|
||||||
|
streamClientInterceptors := []grpc.StreamClientInterceptor{
|
||||||
|
requestid.StreamClientInterceptor(),
|
||||||
|
}
|
||||||
|
if opts.SignedJWTKey != nil {
|
||||||
|
unaryClientInterceptors = append(unaryClientInterceptors, grpcutil.WithUnarySignedJWT(opts.SignedJWTKey))
|
||||||
|
streamClientInterceptors = append(streamClientInterceptors, grpcutil.WithStreamSignedJWT(opts.SignedJWTKey))
|
||||||
}
|
}
|
||||||
|
|
||||||
clientStatsHandler := telemetry.NewGRPCClientStatsHandler(opts.ServiceName)
|
dialOptions := []grpc.DialOption{
|
||||||
dialOptions = clientStatsHandler.DialOptions(dialOptions...)
|
grpc.WithChainUnaryInterceptor(unaryClientInterceptors...),
|
||||||
|
grpc.WithChainStreamInterceptor(streamClientInterceptors...),
|
||||||
|
grpc.WithDefaultCallOptions([]grpc.CallOption{grpc.WaitForReady(true)}...),
|
||||||
|
grpc.WithStatsHandler(clientStatsHandler.Handler),
|
||||||
|
}
|
||||||
|
|
||||||
if opts.WithInsecure {
|
if opts.WithInsecure {
|
||||||
log.Info().Str("addr", connAddr).Msg("internal/grpc: grpc with insecure")
|
log.Info().Str("addr", connAddr).Msg("internal/grpc: grpc with insecure")
|
||||||
|
@ -129,10 +143,8 @@ func NewGRPCClientConn(opts *Options) (*grpc.ClientConn, error) {
|
||||||
dialOptions = append(dialOptions, grpc.WithBalancerName(roundrobin.Name), grpc.WithDisableServiceConfig())
|
dialOptions = append(dialOptions, grpc.WithBalancerName(roundrobin.Name), grpc.WithDisableServiceConfig())
|
||||||
connAddr = fmt.Sprintf("dns:///%s", connAddr)
|
connAddr = fmt.Sprintf("dns:///%s", connAddr)
|
||||||
}
|
}
|
||||||
return grpc.Dial(
|
|
||||||
connAddr,
|
return grpc.Dial(connAddr, dialOptions...)
|
||||||
dialOptions...,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// grpcTimeoutInterceptor enforces per-RPC request timeouts
|
// grpcTimeoutInterceptor enforces per-RPC request timeouts
|
||||||
|
|
112
pkg/grpcutil/options.go
Normal file
112
pkg/grpcutil/options.go
Normal file
|
@ -0,0 +1,112 @@
|
||||||
|
package grpcutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
"gopkg.in/square/go-jose.v2"
|
||||||
|
"gopkg.in/square/go-jose.v2/jwt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WithStreamSignedJWT returns a StreamClientInterceptor that adds a JWT to requests.
|
||||||
|
func WithStreamSignedJWT(key []byte) grpc.StreamClientInterceptor {
|
||||||
|
return func(
|
||||||
|
ctx context.Context,
|
||||||
|
desc *grpc.StreamDesc,
|
||||||
|
cc *grpc.ClientConn,
|
||||||
|
method string, streamer grpc.Streamer,
|
||||||
|
opts ...grpc.CallOption,
|
||||||
|
) (grpc.ClientStream, error) {
|
||||||
|
ctx, err := withSignedJWT(ctx, key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return streamer(ctx, desc, cc, method, opts...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithUnarySignedJWT returns a UnaryClientInterceptor that adds a JWT to requests.
|
||||||
|
func WithUnarySignedJWT(key []byte) grpc.UnaryClientInterceptor {
|
||||||
|
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
|
||||||
|
ctx, err := withSignedJWT(ctx, key)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return invoker(ctx, method, req, reply, cc, opts...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func withSignedJWT(ctx context.Context, key []byte) (context.Context, error) {
|
||||||
|
if len(key) > 0 {
|
||||||
|
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: key},
|
||||||
|
(&jose.SignerOptions{}).WithType("JWT"))
|
||||||
|
if err != nil {
|
||||||
|
return ctx, err
|
||||||
|
}
|
||||||
|
|
||||||
|
rawjwt, err := jwt.Signed(sig).Claims(jwt.Claims{
|
||||||
|
Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||||
|
}).CompactSerialize()
|
||||||
|
if err != nil {
|
||||||
|
return ctx, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx = WithOutgoingJWT(ctx, rawjwt)
|
||||||
|
}
|
||||||
|
return ctx, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnaryRequireSignedJWT requires a JWT in the gRPC metadata and that it be signed by the base64-encoded key.
|
||||||
|
func UnaryRequireSignedJWT(key string) grpc.UnaryServerInterceptor {
|
||||||
|
keyBS, _ := base64.StdEncoding.DecodeString(key)
|
||||||
|
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
|
||||||
|
if err := requireSignedJWT(ctx, keyBS); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return handler(ctx, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// StreamRequireSignedJWT requires a JWT in the gRPC metadata and that it be signed by the base64-encoded key.
|
||||||
|
func StreamRequireSignedJWT(key string) grpc.StreamServerInterceptor {
|
||||||
|
keyBS, _ := base64.StdEncoding.DecodeString(key)
|
||||||
|
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||||
|
if err := requireSignedJWT(ss.Context(), keyBS); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return handler(srv, ss)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func requireSignedJWT(ctx context.Context, key []byte) error {
|
||||||
|
if len(key) > 0 {
|
||||||
|
rawjwt, ok := JWTFromGRPCRequest(ctx)
|
||||||
|
if !ok {
|
||||||
|
return status.Error(codes.Unauthenticated, "unauthenticated")
|
||||||
|
}
|
||||||
|
|
||||||
|
tok, err := jwt.ParseSigned(rawjwt)
|
||||||
|
if err != nil {
|
||||||
|
return status.Errorf(codes.Unauthenticated, "invalid JWT: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var claims struct {
|
||||||
|
Expiry *jwt.NumericDate `json:"exp,omitempty"`
|
||||||
|
}
|
||||||
|
err = tok.Claims(key, &claims)
|
||||||
|
if err != nil {
|
||||||
|
return status.Errorf(codes.Unauthenticated, "invalid JWT: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if claims.Expiry == nil || time.Now().After(claims.Expiry.Time()) {
|
||||||
|
return status.Errorf(codes.Unauthenticated, "expired JWT: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue