package grpcutil

import (
	"context"
	"encoding/base64"
	"time"

	"github.com/go-jose/go-jose/v3"
	"github.com/go-jose/go-jose/v3/jwt"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// WithStreamSignedJWT returns a StreamClientInterceptor that adds a JWT to requests.
//
// Each new stream gets a freshly signed token via withSignedJWT. If signing
// fails, the error is returned and the underlying streamer is never invoked.
func WithStreamSignedJWT(key []byte) grpc.StreamClientInterceptor {
	interceptor := func(
		ctx context.Context,
		desc *grpc.StreamDesc,
		cc *grpc.ClientConn,
		method string,
		streamer grpc.Streamer,
		opts ...grpc.CallOption,
	) (grpc.ClientStream, error) {
		signedCtx, signErr := withSignedJWT(ctx, key)
		if signErr != nil {
			return nil, signErr
		}
		return streamer(signedCtx, desc, cc, method, opts...)
	}
	return interceptor
}

// WithUnarySignedJWT returns a UnaryClientInterceptor that adds a JWT to requests.
//
// Every call is given a freshly signed token via withSignedJWT. If signing
// fails, the error is returned and the invoker is never called.
func WithUnarySignedJWT(key []byte) grpc.UnaryClientInterceptor {
	interceptor := func(
		ctx context.Context,
		method string,
		req, reply interface{},
		cc *grpc.ClientConn,
		invoker grpc.UnaryInvoker,
		opts ...grpc.CallOption,
	) error {
		signedCtx, signErr := withSignedJWT(ctx, key)
		if signErr != nil {
			return signErr
		}
		return invoker(signedCtx, method, req, reply, cc, opts...)
	}
	return interceptor
}

// withSignedJWT signs a short-lived HS256 JWT with key and attaches it to ctx
// via WithOutgoingJWT. The token's only claim is an expiry one hour from now.
//
// An empty key disables signing and ctx is returned unchanged. On any signing
// error the original ctx is returned together with the error.
func withSignedJWT(ctx context.Context, key []byte) (context.Context, error) {
	if len(key) == 0 {
		return ctx, nil
	}

	signingKey := jose.SigningKey{Algorithm: jose.HS256, Key: key}
	signerOpts := (&jose.SignerOptions{}).WithType("JWT")
	signer, err := jose.NewSigner(signingKey, signerOpts)
	if err != nil {
		return ctx, err
	}

	claims := jwt.Claims{
		Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)),
	}
	token, err := jwt.Signed(signer).Claims(claims).CompactSerialize()
	if err != nil {
		return ctx, err
	}

	return WithOutgoingJWT(ctx, token), nil
}

// UnaryRequireSignedJWT requires a JWT in the gRPC metadata and that it be signed by the base64-encoded key.
//
// An empty key disables the check (RequireSignedJWT skips empty keys). If key is
// not valid base64, the interceptor fails closed and rejects every request with
// codes.Internal. Without this, DecodeString's partial output would be used as a
// truncated key, or as an empty one that silently disables authentication.
func UnaryRequireSignedJWT(key string) grpc.UnaryServerInterceptor {
	keyBS, decodeErr := base64.StdEncoding.DecodeString(key)
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		if decodeErr != nil {
			// Deliberately generic: do not leak details about the server's key to clients.
			return nil, status.Error(codes.Internal, "invalid JWT signing key")
		}
		if err := RequireSignedJWT(ctx, keyBS); err != nil {
			return nil, err
		}
		return handler(ctx, req)
	}
}

// StreamRequireSignedJWT requires a JWT in the gRPC metadata and that it be signed by the base64-encoded key.
//
// An empty key disables the check (RequireSignedJWT skips empty keys). If key is
// not valid base64, the interceptor fails closed and rejects every stream with
// codes.Internal. Without this, DecodeString's partial output would be used as a
// truncated key, or as an empty one that silently disables authentication.
func StreamRequireSignedJWT(key string) grpc.StreamServerInterceptor {
	keyBS, decodeErr := base64.StdEncoding.DecodeString(key)
	return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		if decodeErr != nil {
			// Deliberately generic: do not leak details about the server's key to clients.
			return status.Error(codes.Internal, "invalid JWT signing key")
		}
		if err := RequireSignedJWT(ss.Context(), keyBS); err != nil {
			return err
		}
		return handler(srv, ss)
	}
}

// RequireSignedJWT requires a JWT in the gRPC metadata and that it be signed by the given key.
//
// If key is empty the check is disabled and nil is returned. Otherwise all of
// the following must hold:
//   - the request carries a token;
//   - the token parses and verifies against key;
//   - the token has an "exp" claim that has not yet been reached.
//
// Every failure is returned as a codes.Unauthenticated status error.
func RequireSignedJWT(ctx context.Context, key []byte) error {
	if len(key) == 0 {
		return nil
	}

	rawjwt, ok := JWTFromGRPCRequest(ctx)
	if !ok {
		return status.Error(codes.Unauthenticated, "unauthenticated")
	}

	tok, err := jwt.ParseSigned(rawjwt)
	if err != nil {
		return status.Errorf(codes.Unauthenticated, "invalid JWT: %v", err)
	}

	var claims struct {
		Expiry *jwt.NumericDate `json:"exp,omitempty"`
	}
	if err := tok.Claims(key, &claims); err != nil {
		return status.Errorf(codes.Unauthenticated, "invalid JWT: %v", err)
	}

	// A token without an exp claim would never expire, so reject it outright.
	if claims.Expiry == nil {
		return status.Error(codes.Unauthenticated, "invalid JWT: missing exp claim")
	}
	// RFC 7519 §4.1.4: the current time must be strictly before exp.
	if !time.Now().Before(claims.Expiry.Time()) {
		return status.Error(codes.Unauthenticated, "expired JWT")
	}
	return nil
}