// Mirror of https://github.com/pomerium/pomerium.git
// Synced 2025-05-02 20:06:03 +02:00 (113 lines, 3.3 KiB, Go).

package grpcutil
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"time"
|
|
|
|
"github.com/go-jose/go-jose/v3"
|
|
"github.com/go-jose/go-jose/v3/jwt"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/status"
|
|
)
|
|
|
|
// WithStreamSignedJWT returns a StreamClientInterceptor that adds a JWT to requests.
|
|
func WithStreamSignedJWT(key []byte) grpc.StreamClientInterceptor {
|
|
return func(
|
|
ctx context.Context,
|
|
desc *grpc.StreamDesc,
|
|
cc *grpc.ClientConn,
|
|
method string, streamer grpc.Streamer,
|
|
opts ...grpc.CallOption,
|
|
) (grpc.ClientStream, error) {
|
|
ctx, err := withSignedJWT(ctx, key)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return streamer(ctx, desc, cc, method, opts...)
|
|
}
|
|
}
|
|
|
|
// WithUnarySignedJWT returns a UnaryClientInterceptor that adds a JWT to requests.
|
|
func WithUnarySignedJWT(key []byte) grpc.UnaryClientInterceptor {
|
|
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
|
|
ctx, err := withSignedJWT(ctx, key)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return invoker(ctx, method, req, reply, cc, opts...)
|
|
}
|
|
}
|
|
|
|
func withSignedJWT(ctx context.Context, key []byte) (context.Context, error) {
|
|
if len(key) > 0 {
|
|
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: key},
|
|
(&jose.SignerOptions{}).WithType("JWT"))
|
|
if err != nil {
|
|
return ctx, err
|
|
}
|
|
|
|
rawjwt, err := jwt.Signed(sig).Claims(jwt.Claims{
|
|
Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
|
}).CompactSerialize()
|
|
if err != nil {
|
|
return ctx, err
|
|
}
|
|
|
|
ctx = WithOutgoingJWT(ctx, rawjwt)
|
|
}
|
|
return ctx, nil
|
|
}
|
|
|
|
// UnaryRequireSignedJWT requires a JWT in the gRPC metadata and that it be signed by the base64-encoded key.
|
|
func UnaryRequireSignedJWT(key string) grpc.UnaryServerInterceptor {
|
|
keyBS, _ := base64.StdEncoding.DecodeString(key)
|
|
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
|
|
if err := RequireSignedJWT(ctx, keyBS); err != nil {
|
|
return nil, err
|
|
}
|
|
return handler(ctx, req)
|
|
}
|
|
}
|
|
|
|
// StreamRequireSignedJWT requires a JWT in the gRPC metadata and that it be signed by the base64-encoded key.
|
|
func StreamRequireSignedJWT(key string) grpc.StreamServerInterceptor {
|
|
keyBS, _ := base64.StdEncoding.DecodeString(key)
|
|
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
|
if err := RequireSignedJWT(ss.Context(), keyBS); err != nil {
|
|
return err
|
|
}
|
|
return handler(srv, ss)
|
|
}
|
|
}
|
|
|
|
// RequireSignedJWT requires a JWT in the gRPC metadata and that it be signed by the given key.
|
|
func RequireSignedJWT(ctx context.Context, key []byte) error {
|
|
if len(key) > 0 {
|
|
rawjwt, ok := JWTFromGRPCRequest(ctx)
|
|
if !ok {
|
|
return status.Error(codes.Unauthenticated, "unauthenticated")
|
|
}
|
|
|
|
tok, err := jwt.ParseSigned(rawjwt)
|
|
if err != nil {
|
|
return status.Errorf(codes.Unauthenticated, "invalid JWT: %v", err)
|
|
}
|
|
|
|
var claims struct {
|
|
Expiry *jwt.NumericDate `json:"exp,omitempty"`
|
|
}
|
|
err = tok.Claims(key, &claims)
|
|
if err != nil {
|
|
return status.Errorf(codes.Unauthenticated, "invalid JWT: %v", err)
|
|
}
|
|
|
|
if claims.Expiry == nil || time.Now().After(claims.Expiry.Time()) {
|
|
return status.Errorf(codes.Unauthenticated, "expired JWT: %v", err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|