mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 02:16:28 +02:00
grpcutil: additional JWT validation (#5303)
Add additional validation to the grpcutil.RequireSignedJWT method. Log any validation error, instead of returning error details in the gRPC status message.
This commit is contained in:
parent
753b24dd7b
commit
c011957389
2 changed files with 98 additions and 15 deletions
|
@ -3,10 +3,12 @@ package grpcutil
|
|||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"github.com/go-jose/go-jose/v3/jwt"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
@ -92,22 +94,30 @@ func RequireSignedJWT(ctx context.Context, key []byte) error {
|
|||
return status.Error(codes.Unauthenticated, "unauthenticated")
|
||||
}
|
||||
|
||||
tok, err := jwt.ParseSigned(rawjwt)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Unauthenticated, "invalid JWT: %v", err)
|
||||
}
|
||||
|
||||
var claims struct {
|
||||
Expiry *jwt.NumericDate `json:"exp,omitempty"`
|
||||
}
|
||||
err = tok.Claims(key, &claims)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Unauthenticated, "invalid JWT: %v", err)
|
||||
}
|
||||
|
||||
if claims.Expiry == nil || time.Now().After(claims.Expiry.Time()) {
|
||||
return status.Errorf(codes.Unauthenticated, "expired JWT: %v", err)
|
||||
if err := validateJWT(rawjwt, key); err != nil {
|
||||
log.Ctx(ctx).Debug().Err(err).Msg("rejected gRPC request due to invalid JWT")
|
||||
return status.Error(codes.Unauthenticated, "invalid JWT")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateJWT(rawjwt string, key []byte) error {
|
||||
tok, err := jwt.ParseSigned(rawjwt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var claims map[string]*jwt.NumericDate
|
||||
err = tok.Claims(key, &claims)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if len(claims) != 1 || claims["exp"] == nil {
|
||||
return fmt.Errorf("expected exactly one claim (exp)")
|
||||
}
|
||||
|
||||
if t := claims["exp"].Time(); time.Now().After(t) {
|
||||
return fmt.Errorf("JWT expired at %s", t.Format(time.DateTime))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -9,7 +9,10 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"github.com/go-jose/go-jose/v3/jwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/reflection"
|
||||
|
@ -100,3 +103,73 @@ func TestSignedJWT(t *testing.T) {
|
|||
assert.Equal(t, codes.OK, status.Code(err))
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidateJWT(t *testing.T) {
|
||||
sign := func(t *testing.T, key []byte, claims any) string {
|
||||
t.Helper()
|
||||
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: key},
|
||||
(&jose.SignerOptions{}).WithType("JWT"))
|
||||
require.NoError(t, err)
|
||||
s, err := jwt.Signed(signer).Claims(claims).CompactSerialize()
|
||||
require.NoError(t, err)
|
||||
return s
|
||||
}
|
||||
|
||||
key := cryptutil.NewKey()
|
||||
|
||||
t.Run("unexpected_format", func(t *testing.T) {
|
||||
err := validateJWT("not a jwt", key)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
t.Run("unexpected_claim_type", func(t *testing.T) {
|
||||
rawjwt := sign(t, key, jwt.Claims{
|
||||
Subject: "subject",
|
||||
Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||
})
|
||||
err := validateJWT(rawjwt, key)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
t.Run("unexpected_claim_name", func(t *testing.T) {
|
||||
rawjwt := sign(t, key, jwt.Claims{
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||
})
|
||||
err := validateJWT(rawjwt, key)
|
||||
assert.ErrorContains(t, err, "expected exactly one claim (exp)")
|
||||
})
|
||||
t.Run("no_claims", func(t *testing.T) {
|
||||
rawjwt := sign(t, key, jwt.Claims{})
|
||||
err := validateJWT(rawjwt, key)
|
||||
assert.ErrorContains(t, err, "expected exactly one claim (exp)")
|
||||
})
|
||||
t.Run("unexpected_expiry_type", func(t *testing.T) {
|
||||
rawjwt := sign(t, key, map[string]any{
|
||||
"exp": "foo",
|
||||
})
|
||||
err := validateJWT(rawjwt, key)
|
||||
assert.ErrorContains(t, err, "expected number value")
|
||||
})
|
||||
t.Run("expired", func(t *testing.T) {
|
||||
rawjwt := sign(t, key, jwt.Claims{
|
||||
Expiry: jwt.NewNumericDate(time.Now().Add(-time.Minute)),
|
||||
})
|
||||
err := validateJWT(rawjwt, key)
|
||||
assert.ErrorContains(t, err, "JWT expired")
|
||||
})
|
||||
t.Run("wrong_key", func(t *testing.T) {
|
||||
otherKey := cryptutil.NewKey()
|
||||
rawjwt := sign(t, otherKey, jwt.Claims{
|
||||
Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||
})
|
||||
|
||||
err := validateJWT(rawjwt, key)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
t.Run("ok", func(t *testing.T) {
|
||||
rawjwt := sign(t, key, jwt.Claims{
|
||||
Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||
})
|
||||
err := validateJWT(rawjwt, key)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue