pomerium/pkg/grpcutil/options.go
2021-06-10 09:35:44 -06:00

113 lines
3.3 KiB
Go

package grpcutil
import (
"context"
"encoding/base64"
"time"
"github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v3/jwt"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// WithStreamSignedJWT returns a StreamClientInterceptor that signs a fresh JWT
// with key each time a client stream is opened and attaches it to the stream's
// context (via withSignedJWT) before delegating to streamer. When key is empty
// the stream is opened with the caller's context unchanged. If signing fails,
// no stream is opened and the signing error is returned.
func WithStreamSignedJWT(key []byte) grpc.StreamClientInterceptor {
	var interceptor grpc.StreamClientInterceptor = func(
		ctx context.Context,
		desc *grpc.StreamDesc,
		cc *grpc.ClientConn,
		method string,
		streamer grpc.Streamer,
		opts ...grpc.CallOption,
	) (grpc.ClientStream, error) {
		signedCtx, signErr := withSignedJWT(ctx, key)
		if signErr != nil {
			return nil, signErr
		}
		return streamer(signedCtx, desc, cc, method, opts...)
	}
	return interceptor
}
// WithUnarySignedJWT returns a UnaryClientInterceptor that signs a fresh JWT
// with key for every outgoing unary call and attaches it to the call's context
// (via withSignedJWT) before invoking it. When key is empty the call proceeds
// with the caller's context unchanged. A signing failure aborts the call and
// is returned to the caller.
func WithUnarySignedJWT(key []byte) grpc.UnaryClientInterceptor {
	interceptor := func(
		ctx context.Context,
		method string,
		req, reply interface{},
		cc *grpc.ClientConn,
		invoker grpc.UnaryInvoker,
		opts ...grpc.CallOption,
	) error {
		signedCtx, signErr := withSignedJWT(ctx, key)
		if signErr != nil {
			return signErr
		}
		return invoker(signedCtx, method, req, reply, cc, opts...)
	}
	return interceptor
}
// withSignedJWT mints an HS256-signed JWT (header "typ": "JWT") using key and
// returns the context produced by WithOutgoingJWT for that token. The token's
// only claim is "exp", set one hour after the moment it is created.
//
// When key is empty no token is created and ctx is returned unchanged with a
// nil error. On any signing/serialization error, the original ctx is returned
// together with the error.
//
// NOTE(review): WithOutgoingJWT is defined elsewhere in this package; it
// presumably stores the token in outgoing gRPC metadata — confirm there.
func withSignedJWT(ctx context.Context, key []byte) (context.Context, error) {
	if len(key) == 0 {
		return ctx, nil
	}

	signerOpts := (&jose.SignerOptions{}).WithType("JWT")
	signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: key}, signerOpts)
	if err != nil {
		return ctx, err
	}

	expiresAt := time.Now().Add(time.Hour)
	token, err := jwt.Signed(signer).
		Claims(jwt.Claims{Expiry: jwt.NewNumericDate(expiresAt)}).
		CompactSerialize()
	if err != nil {
		return ctx, err
	}

	return WithOutgoingJWT(ctx, token), nil
}
// UnaryRequireSignedJWT returns a UnaryServerInterceptor that requires a JWT in
// the gRPC metadata and that it be signed by key, which must be a standard
// (padded) base64 encoding of the shared secret. Verification is delegated to
// RequireSignedJWT; an empty key disables the check.
//
// If key is not valid base64 the interceptor fails closed and rejects every
// request with codes.Unauthenticated. Previously the decode error was ignored.
// base64 decoding yields only the bytes before the corrupt input, so that meant
// verifying against a truncated secret. Worse, an empty prefix silently turned
// authentication off.
func UnaryRequireSignedJWT(key string) grpc.UnaryServerInterceptor {
	keyBS, decodeErr := base64.StdEncoding.DecodeString(key)
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
		if decodeErr != nil {
			// Misconfigured secret: refuse rather than fall back to a partial/empty key.
			return nil, status.Error(codes.Unauthenticated, "invalid JWT signing key")
		}
		if err := RequireSignedJWT(ctx, keyBS); err != nil {
			return nil, err
		}
		return handler(ctx, req)
	}
}
// StreamRequireSignedJWT returns a StreamServerInterceptor that requires a JWT
// in the gRPC metadata of the stream's context and that it be signed by key,
// which must be a standard (padded) base64 encoding of the shared secret.
// Verification is delegated to RequireSignedJWT; an empty key disables the
// check.
//
// If key is not valid base64 the interceptor fails closed and rejects every
// stream with codes.Unauthenticated. Previously the decode error was ignored.
// base64 decoding yields only the bytes before the corrupt input, so that meant
// verifying against a truncated secret. Worse, an empty prefix silently turned
// authentication off.
func StreamRequireSignedJWT(key string) grpc.StreamServerInterceptor {
	keyBS, decodeErr := base64.StdEncoding.DecodeString(key)
	return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		if decodeErr != nil {
			// Misconfigured secret: refuse rather than fall back to a partial/empty key.
			return status.Error(codes.Unauthenticated, "invalid JWT signing key")
		}
		if err := RequireSignedJWT(ss.Context(), keyBS); err != nil {
			return err
		}
		return handler(srv, ss)
	}
}
// RequireSignedJWT requires a JWT in the gRPC metadata and that it be signed by
// the given key.
//
// If key is empty the check is skipped and nil is returned. Otherwise:
//   - the raw token is obtained with JWTFromGRPCRequest; if none is present
//     the request is rejected,
//   - the token is parsed and its signature verified against key,
//   - its "exp" claim must be present and not in the past.
//
// Every failure is returned as a gRPC status error with codes.Unauthenticated.
func RequireSignedJWT(ctx context.Context, key []byte) error {
	if len(key) == 0 {
		return nil
	}

	rawjwt, ok := JWTFromGRPCRequest(ctx)
	if !ok {
		return status.Error(codes.Unauthenticated, "unauthenticated")
	}

	tok, err := jwt.ParseSigned(rawjwt)
	if err != nil {
		return status.Errorf(codes.Unauthenticated, "invalid JWT: %v", err)
	}

	var claims struct {
		Expiry *jwt.NumericDate `json:"exp,omitempty"`
	}
	// Claims verifies the signature with key before decoding the payload.
	if err := tok.Claims(key, &claims); err != nil {
		return status.Errorf(codes.Unauthenticated, "invalid JWT: %v", err)
	}

	// A token without "exp" would be valid forever, so it is rejected outright.
	if claims.Expiry == nil {
		return status.Error(codes.Unauthenticated, "invalid JWT: missing expiry")
	}
	// The previous message formatted a nil error here ("expired JWT: <nil>").
	if time.Now().After(claims.Expiry.Time()) {
		return status.Error(codes.Unauthenticated, "expired JWT")
	}
	return nil
}