mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-01 11:26:29 +02:00
grpcutil: additional JWT validation (#5303)
Add additional validation to the grpcutil.RequireSignedJWT method. Log any validation error, instead of returning error details in the gRPC status message.
This commit is contained in:
parent
753b24dd7b
commit
c011957389
2 changed files with 98 additions and 15 deletions
|
@ -3,10 +3,12 @@ package grpcutil
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-jose/go-jose/v3"
|
"github.com/go-jose/go-jose/v3"
|
||||||
"github.com/go-jose/go-jose/v3/jwt"
|
"github.com/go-jose/go-jose/v3/jwt"
|
||||||
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
|
@ -92,22 +94,30 @@ func RequireSignedJWT(ctx context.Context, key []byte) error {
|
||||||
return status.Error(codes.Unauthenticated, "unauthenticated")
|
return status.Error(codes.Unauthenticated, "unauthenticated")
|
||||||
}
|
}
|
||||||
|
|
||||||
tok, err := jwt.ParseSigned(rawjwt)
|
if err := validateJWT(rawjwt, key); err != nil {
|
||||||
if err != nil {
|
log.Ctx(ctx).Debug().Err(err).Msg("rejected gRPC request due to invalid JWT")
|
||||||
return status.Errorf(codes.Unauthenticated, "invalid JWT: %v", err)
|
return status.Error(codes.Unauthenticated, "invalid JWT")
|
||||||
}
|
|
||||||
|
|
||||||
var claims struct {
|
|
||||||
Expiry *jwt.NumericDate `json:"exp,omitempty"`
|
|
||||||
}
|
|
||||||
err = tok.Claims(key, &claims)
|
|
||||||
if err != nil {
|
|
||||||
return status.Errorf(codes.Unauthenticated, "invalid JWT: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if claims.Expiry == nil || time.Now().After(claims.Expiry.Time()) {
|
|
||||||
return status.Errorf(codes.Unauthenticated, "expired JWT: %v", err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validateJWT(rawjwt string, key []byte) error {
|
||||||
|
tok, err := jwt.ParseSigned(rawjwt)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var claims map[string]*jwt.NumericDate
|
||||||
|
err = tok.Claims(key, &claims)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
} else if len(claims) != 1 || claims["exp"] == nil {
|
||||||
|
return fmt.Errorf("expected exactly one claim (exp)")
|
||||||
|
}
|
||||||
|
|
||||||
|
if t := claims["exp"].Time(); time.Now().After(t) {
|
||||||
|
return fmt.Errorf("JWT expired at %s", t.Format(time.DateTime))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -9,7 +9,10 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-jose/go-jose/v3"
|
||||||
|
"github.com/go-jose/go-jose/v3/jwt"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/reflection"
|
"google.golang.org/grpc/reflection"
|
||||||
|
@ -100,3 +103,73 @@ func TestSignedJWT(t *testing.T) {
|
||||||
assert.Equal(t, codes.OK, status.Code(err))
|
assert.Equal(t, codes.OK, status.Code(err))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestValidateJWT(t *testing.T) {
|
||||||
|
sign := func(t *testing.T, key []byte, claims any) string {
|
||||||
|
t.Helper()
|
||||||
|
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: key},
|
||||||
|
(&jose.SignerOptions{}).WithType("JWT"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
s, err := jwt.Signed(signer).Claims(claims).CompactSerialize()
|
||||||
|
require.NoError(t, err)
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
key := cryptutil.NewKey()
|
||||||
|
|
||||||
|
t.Run("unexpected_format", func(t *testing.T) {
|
||||||
|
err := validateJWT("not a jwt", key)
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
t.Run("unexpected_claim_type", func(t *testing.T) {
|
||||||
|
rawjwt := sign(t, key, jwt.Claims{
|
||||||
|
Subject: "subject",
|
||||||
|
Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||||
|
})
|
||||||
|
err := validateJWT(rawjwt, key)
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
t.Run("unexpected_claim_name", func(t *testing.T) {
|
||||||
|
rawjwt := sign(t, key, jwt.Claims{
|
||||||
|
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||||
|
Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||||
|
})
|
||||||
|
err := validateJWT(rawjwt, key)
|
||||||
|
assert.ErrorContains(t, err, "expected exactly one claim (exp)")
|
||||||
|
})
|
||||||
|
t.Run("no_claims", func(t *testing.T) {
|
||||||
|
rawjwt := sign(t, key, jwt.Claims{})
|
||||||
|
err := validateJWT(rawjwt, key)
|
||||||
|
assert.ErrorContains(t, err, "expected exactly one claim (exp)")
|
||||||
|
})
|
||||||
|
t.Run("unexpected_expiry_type", func(t *testing.T) {
|
||||||
|
rawjwt := sign(t, key, map[string]any{
|
||||||
|
"exp": "foo",
|
||||||
|
})
|
||||||
|
err := validateJWT(rawjwt, key)
|
||||||
|
assert.ErrorContains(t, err, "expected number value")
|
||||||
|
})
|
||||||
|
t.Run("expired", func(t *testing.T) {
|
||||||
|
rawjwt := sign(t, key, jwt.Claims{
|
||||||
|
Expiry: jwt.NewNumericDate(time.Now().Add(-time.Minute)),
|
||||||
|
})
|
||||||
|
err := validateJWT(rawjwt, key)
|
||||||
|
assert.ErrorContains(t, err, "JWT expired")
|
||||||
|
})
|
||||||
|
t.Run("wrong_key", func(t *testing.T) {
|
||||||
|
otherKey := cryptutil.NewKey()
|
||||||
|
rawjwt := sign(t, otherKey, jwt.Claims{
|
||||||
|
Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||||
|
})
|
||||||
|
|
||||||
|
err := validateJWT(rawjwt, key)
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
t.Run("ok", func(t *testing.T) {
|
||||||
|
rawjwt := sign(t, key, jwt.Claims{
|
||||||
|
Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||||
|
})
|
||||||
|
err := validateJWT(rawjwt, key)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue