diff --git a/pkg/grpcutil/grpcutil.go b/pkg/grpcutil/grpcutil.go index 9dd664a67..baeead3b6 100644 --- a/pkg/grpcutil/grpcutil.go +++ b/pkg/grpcutil/grpcutil.go @@ -29,3 +29,26 @@ func SessionIDFromGRPCRequest(ctx context.Context) (sessionID string, ok bool) { return sessionIDs[0], true } + +// JWTMetadataKey is the key in the metadata. +const JWTMetadataKey = "jwt" + +// WithOutgoingJWT appends a metadata header for the JWT to a context. +func WithOutgoingJWT(ctx context.Context, rawjwt string) context.Context { + return metadata.AppendToOutgoingContext(ctx, JWTMetadataKey, rawjwt) +} + +// JWTFromGRPCRequest returns the JWT from the gRPC request. +func JWTFromGRPCRequest(ctx context.Context) (rawjwt string, ok bool) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return "", false + } + + rawjwts := md.Get(JWTMetadataKey) + if len(rawjwts) == 0 { + return "", false + } + + return rawjwts[0], true +} diff --git a/pkg/grpcutil/grpcutil_test.go b/pkg/grpcutil/grpcutil_test.go index ce076cfd0..f9f181c6a 100644 --- a/pkg/grpcutil/grpcutil_test.go +++ b/pkg/grpcutil/grpcutil_test.go @@ -27,3 +27,25 @@ func TestSessionIDFromGRPCRequest(t *testing.T) { assert.True(t, ok) assert.Equal(t, "EXAMPLE", sessionID) } + +func TestWithOutgoingJWT(t *testing.T) { + rawjwt := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" + ctx := context.Background() + ctx = WithOutgoingJWT(ctx, rawjwt) + md, ok := metadata.FromOutgoingContext(ctx) + if !assert.True(t, ok) { + return + } + assert.Equal(t, []string{rawjwt}, md.Get("jwt")) +} + +func TestJWTFromGRPCRequest(t *testing.T) { + rawjwt := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" + ctx := context.Background() + ctx = metadata.NewIncomingContext(ctx, metadata.MD{ + "jwt": {rawjwt}, + }) + found, ok := JWTFromGRPCRequest(ctx) + assert.True(t, ok) + assert.Equal(t, rawjwt, found) +}