mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 02:16:28 +02:00
Add additional validation to the grpcutil.RequireSignedJWT method. Log any validation error instead of returning error details in the gRPC status message.
123 lines
3.5 KiB
Go
123 lines
3.5 KiB
Go
package grpcutil
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/go-jose/go-jose/v3"
|
|
"github.com/go-jose/go-jose/v3/jwt"
|
|
"github.com/pomerium/pomerium/internal/log"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/status"
|
|
)
|
|
|
|
// WithStreamSignedJWT returns a StreamClientInterceptor that adds a JWT to requests.
|
|
func WithStreamSignedJWT(getKey func() []byte) grpc.StreamClientInterceptor {
|
|
return func(
|
|
ctx context.Context,
|
|
desc *grpc.StreamDesc,
|
|
cc *grpc.ClientConn,
|
|
method string, streamer grpc.Streamer,
|
|
opts ...grpc.CallOption,
|
|
) (grpc.ClientStream, error) {
|
|
ctx, err := withSignedJWT(ctx, getKey())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return streamer(ctx, desc, cc, method, opts...)
|
|
}
|
|
}
|
|
|
|
// WithUnarySignedJWT returns a UnaryClientInterceptor that adds a JWT to requests.
|
|
func WithUnarySignedJWT(getKey func() []byte) grpc.UnaryClientInterceptor {
|
|
return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
|
|
ctx, err := withSignedJWT(ctx, getKey())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return invoker(ctx, method, req, reply, cc, opts...)
|
|
}
|
|
}
|
|
|
|
func withSignedJWT(ctx context.Context, key []byte) (context.Context, error) {
|
|
if len(key) > 0 {
|
|
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: key},
|
|
(&jose.SignerOptions{}).WithType("JWT"))
|
|
if err != nil {
|
|
return ctx, err
|
|
}
|
|
|
|
rawjwt, err := jwt.Signed(sig).Claims(jwt.Claims{
|
|
Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
|
}).CompactSerialize()
|
|
if err != nil {
|
|
return ctx, err
|
|
}
|
|
|
|
ctx = WithOutgoingJWT(ctx, rawjwt)
|
|
}
|
|
return ctx, nil
|
|
}
|
|
|
|
// UnaryRequireSignedJWT requires a JWT in the gRPC metadata and that it be signed by the base64-encoded key.
|
|
func UnaryRequireSignedJWT(key string) grpc.UnaryServerInterceptor {
|
|
keyBS, _ := base64.StdEncoding.DecodeString(key)
|
|
return func(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
|
|
if err := RequireSignedJWT(ctx, keyBS); err != nil {
|
|
return nil, err
|
|
}
|
|
return handler(ctx, req)
|
|
}
|
|
}
|
|
|
|
// StreamRequireSignedJWT requires a JWT in the gRPC metadata and that it be signed by the base64-encoded key.
|
|
func StreamRequireSignedJWT(key string) grpc.StreamServerInterceptor {
|
|
keyBS, _ := base64.StdEncoding.DecodeString(key)
|
|
return func(srv any, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
|
if err := RequireSignedJWT(ss.Context(), keyBS); err != nil {
|
|
return err
|
|
}
|
|
return handler(srv, ss)
|
|
}
|
|
}
|
|
|
|
// RequireSignedJWT requires a JWT in the gRPC metadata and that it be signed by the given key.
|
|
func RequireSignedJWT(ctx context.Context, key []byte) error {
|
|
if len(key) > 0 {
|
|
rawjwt, ok := JWTFromGRPCRequest(ctx)
|
|
if !ok {
|
|
return status.Error(codes.Unauthenticated, "unauthenticated")
|
|
}
|
|
|
|
if err := validateJWT(rawjwt, key); err != nil {
|
|
log.Ctx(ctx).Debug().Err(err).Msg("rejected gRPC request due to invalid JWT")
|
|
return status.Error(codes.Unauthenticated, "invalid JWT")
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func validateJWT(rawjwt string, key []byte) error {
|
|
tok, err := jwt.ParseSigned(rawjwt)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var claims map[string]*jwt.NumericDate
|
|
err = tok.Claims(key, &claims)
|
|
if err != nil {
|
|
return err
|
|
} else if len(claims) != 1 || claims["exp"] == nil {
|
|
return fmt.Errorf("expected exactly one claim (exp)")
|
|
}
|
|
|
|
if t := claims["exp"].Time(); time.Now().After(t) {
|
|
return fmt.Errorf("JWT expired at %s", t.Format(time.DateTime))
|
|
}
|
|
return nil
|
|
}
|