From c01195738961636de10413cf6e6bf735227816de Mon Sep 17 00:00:00 2001 From: Kenneth Jenkins <51246568+kenjenkins@users.noreply.github.com> Date: Mon, 23 Sep 2024 13:17:03 -0700 Subject: [PATCH] grpcutil: additional JWT validation (#5303) Add additional validation to the grpcutil.RequireSignedJWT method. Log any validation error, instead of returning error details in the gRPC status message. --- pkg/grpcutil/options.go | 40 ++++++++++++-------- pkg/grpcutil/options_test.go | 73 ++++++++++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+), 15 deletions(-) diff --git a/pkg/grpcutil/options.go b/pkg/grpcutil/options.go index d3329251e..5e1a4255b 100644 --- a/pkg/grpcutil/options.go +++ b/pkg/grpcutil/options.go @@ -3,10 +3,12 @@ package grpcutil import ( "context" "encoding/base64" + "fmt" "time" "github.com/go-jose/go-jose/v3" "github.com/go-jose/go-jose/v3/jwt" + "github.com/pomerium/pomerium/internal/log" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -92,22 +94,30 @@ func RequireSignedJWT(ctx context.Context, key []byte) error { return status.Error(codes.Unauthenticated, "unauthenticated") } - tok, err := jwt.ParseSigned(rawjwt) - if err != nil { - return status.Errorf(codes.Unauthenticated, "invalid JWT: %v", err) - } - - var claims struct { - Expiry *jwt.NumericDate `json:"exp,omitempty"` - } - err = tok.Claims(key, &claims) - if err != nil { - return status.Errorf(codes.Unauthenticated, "invalid JWT: %v", err) - } - - if claims.Expiry == nil || time.Now().After(claims.Expiry.Time()) { - return status.Errorf(codes.Unauthenticated, "expired JWT: %v", err) + if err := validateJWT(rawjwt, key); err != nil { + log.Ctx(ctx).Debug().Err(err).Msg("rejected gRPC request due to invalid JWT") + return status.Error(codes.Unauthenticated, "invalid JWT") } } return nil } + +func validateJWT(rawjwt string, key []byte) error { + tok, err := jwt.ParseSigned(rawjwt) + if err != nil { + return err + } + + var claims map[string]*jwt.NumericDate + err = tok.Claims(key, &claims) + if err != nil { + return err + } else if len(claims) != 1 || claims["exp"] == nil { + return fmt.Errorf("expected exactly one claim (exp)") + } + + if t := claims["exp"].Time(); time.Now().After(t) { + return fmt.Errorf("JWT expired at %s", t.Format(time.DateTime)) + } + return nil +} diff --git a/pkg/grpcutil/options_test.go b/pkg/grpcutil/options_test.go index 35c715cd4..0c48fd2fd 100644 --- a/pkg/grpcutil/options_test.go +++ b/pkg/grpcutil/options_test.go @@ -9,7 +9,10 @@ import ( "testing" "time" + "github.com/go-jose/go-jose/v3" + "github.com/go-jose/go-jose/v3/jwt" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/reflection" @@ -100,3 +103,73 @@ func TestSignedJWT(t *testing.T) { assert.Equal(t, codes.OK, status.Code(err)) }) } + +func TestValidateJWT(t *testing.T) { + sign := func(t *testing.T, key []byte, claims any) string { + t.Helper() + signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: key}, + (&jose.SignerOptions{}).WithType("JWT")) + require.NoError(t, err) + s, err := jwt.Signed(signer).Claims(claims).CompactSerialize() + require.NoError(t, err) + return s + } + + key := cryptutil.NewKey() + + t.Run("unexpected_format", func(t *testing.T) { + err := validateJWT("not a jwt", key) + assert.Error(t, err) + }) + t.Run("unexpected_claim_type", func(t *testing.T) { + rawjwt := sign(t, key, jwt.Claims{ + Subject: "subject", + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + }) + err := validateJWT(rawjwt, key) + assert.Error(t, err) + }) + t.Run("unexpected_claim_name", func(t *testing.T) { + rawjwt := sign(t, key, jwt.Claims{ + IssuedAt: jwt.NewNumericDate(time.Now()), + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + }) + err := validateJWT(rawjwt, key) + assert.ErrorContains(t, err, "expected exactly one claim (exp)") + }) + t.Run("no_claims", func(t *testing.T) { + rawjwt := sign(t, key, jwt.Claims{}) + err := validateJWT(rawjwt, key) + assert.ErrorContains(t, err, "expected exactly one claim (exp)") + }) + t.Run("unexpected_expiry_type", func(t *testing.T) { + rawjwt := sign(t, key, map[string]any{ + "exp": "foo", + }) + err := validateJWT(rawjwt, key) + assert.ErrorContains(t, err, "expected number value") + }) + t.Run("expired", func(t *testing.T) { + rawjwt := sign(t, key, jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().Add(-time.Minute)), + }) + err := validateJWT(rawjwt, key) + assert.ErrorContains(t, err, "JWT expired") + }) + t.Run("wrong_key", func(t *testing.T) { + otherKey := cryptutil.NewKey() + rawjwt := sign(t, otherKey, jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + }) + + err := validateJWT(rawjwt, key) + assert.Error(t, err) + }) + t.Run("ok", func(t *testing.T) { + rawjwt := sign(t, key, jwt.Claims{ + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + }) + err := validateJWT(rawjwt, key) + assert.NoError(t, err) + }) +}