grpcutil: additional JWT validation (#5303)

Add additional validation to the grpcutil.RequireSignedJWT method. Log
any validation error, instead of returning error details in the gRPC
status message.
This commit is contained in:
Kenneth Jenkins 2024-09-23 13:17:03 -07:00 committed by GitHub
parent 753b24dd7b
commit c011957389
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 98 additions and 15 deletions

View file

@ -3,10 +3,12 @@ package grpcutil
import (
"context"
"encoding/base64"
"fmt"
"time"
"github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v3/jwt"
"github.com/pomerium/pomerium/internal/log"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
@ -92,22 +94,30 @@ func RequireSignedJWT(ctx context.Context, key []byte) error {
return status.Error(codes.Unauthenticated, "unauthenticated")
}
tok, err := jwt.ParseSigned(rawjwt)
if err != nil {
return status.Errorf(codes.Unauthenticated, "invalid JWT: %v", err)
}
var claims struct {
Expiry *jwt.NumericDate `json:"exp,omitempty"`
}
err = tok.Claims(key, &claims)
if err != nil {
return status.Errorf(codes.Unauthenticated, "invalid JWT: %v", err)
}
if claims.Expiry == nil || time.Now().After(claims.Expiry.Time()) {
return status.Errorf(codes.Unauthenticated, "expired JWT: %v", err)
if err := validateJWT(rawjwt, key); err != nil {
log.Ctx(ctx).Debug().Err(err).Msg("rejected gRPC request due to invalid JWT")
return status.Error(codes.Unauthenticated, "invalid JWT")
}
}
return nil
}
func validateJWT(rawjwt string, key []byte) error {
tok, err := jwt.ParseSigned(rawjwt)
if err != nil {
return err
}
var claims map[string]*jwt.NumericDate
err = tok.Claims(key, &claims)
if err != nil {
return err
} else if len(claims) != 1 || claims["exp"] == nil {
return fmt.Errorf("expected exactly one claim (exp)")
}
if t := claims["exp"].Time(); time.Now().After(t) {
return fmt.Errorf("JWT expired at %s", t.Format(time.DateTime))
}
return nil
}

View file

@ -9,7 +9,10 @@ import (
"testing"
"time"
"github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v3/jwt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/reflection"
@ -100,3 +103,73 @@ func TestSignedJWT(t *testing.T) {
assert.Equal(t, codes.OK, status.Code(err))
})
}
func TestValidateJWT(t *testing.T) {
sign := func(t *testing.T, key []byte, claims any) string {
t.Helper()
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: key},
(&jose.SignerOptions{}).WithType("JWT"))
require.NoError(t, err)
s, err := jwt.Signed(signer).Claims(claims).CompactSerialize()
require.NoError(t, err)
return s
}
key := cryptutil.NewKey()
t.Run("unexpected_format", func(t *testing.T) {
err := validateJWT("not a jwt", key)
assert.Error(t, err)
})
t.Run("unexpected_claim_type", func(t *testing.T) {
rawjwt := sign(t, key, jwt.Claims{
Subject: "subject",
Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)),
})
err := validateJWT(rawjwt, key)
assert.Error(t, err)
})
t.Run("unexpected_claim_name", func(t *testing.T) {
rawjwt := sign(t, key, jwt.Claims{
IssuedAt: jwt.NewNumericDate(time.Now()),
Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)),
})
err := validateJWT(rawjwt, key)
assert.ErrorContains(t, err, "expected exactly one claim (exp)")
})
t.Run("no_claims", func(t *testing.T) {
rawjwt := sign(t, key, jwt.Claims{})
err := validateJWT(rawjwt, key)
assert.ErrorContains(t, err, "expected exactly one claim (exp)")
})
t.Run("unexpected_expiry_type", func(t *testing.T) {
rawjwt := sign(t, key, map[string]any{
"exp": "foo",
})
err := validateJWT(rawjwt, key)
assert.ErrorContains(t, err, "expected number value")
})
t.Run("expired", func(t *testing.T) {
rawjwt := sign(t, key, jwt.Claims{
Expiry: jwt.NewNumericDate(time.Now().Add(-time.Minute)),
})
err := validateJWT(rawjwt, key)
assert.ErrorContains(t, err, "JWT expired")
})
t.Run("wrong_key", func(t *testing.T) {
otherKey := cryptutil.NewKey()
rawjwt := sign(t, otherKey, jwt.Claims{
Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)),
})
err := validateJWT(rawjwt, key)
assert.Error(t, err)
})
t.Run("ok", func(t *testing.T) {
rawjwt := sign(t, key, jwt.Claims{
Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)),
})
err := validateJWT(rawjwt, key)
assert.NoError(t, err)
})
}