mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-28 18:06:34 +02:00
117 lines
2.9 KiB
Go
117 lines
2.9 KiB
Go
// Package grpcconn provides a gRPC client connection that authenticates every RPC with a bearer token.
|
|
package grpcconn
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
|
|
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/connectivity"
|
|
|
|
"github.com/pomerium/pomerium/internal/log"
|
|
)
|
|
|
|
type client struct {
|
|
config *config
|
|
tokenProvider TokenProviderFn
|
|
}
|
|
|
|
// TokenProviderFn is a function that returns an authorization token
|
|
type TokenProviderFn func(ctx context.Context) (string, error)
|
|
|
|
// New creates a new gRPC client with authentication
|
|
func New(
|
|
ctx context.Context,
|
|
endpoint string,
|
|
tokenProvider TokenProviderFn,
|
|
dialOpts ...grpc.DialOption,
|
|
) (*grpc.ClientConn, error) {
|
|
cfg, err := getConfig(endpoint, dialOpts...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
cc := &client{
|
|
tokenProvider: tokenProvider,
|
|
config: cfg,
|
|
}
|
|
|
|
conn, err := cc.getGRPCConn(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return conn, err
|
|
}
|
|
|
|
func (c *client) getGRPCConn(ctx context.Context) (*grpc.ClientConn, error) {
|
|
opts := append(
|
|
c.config.GetDialOptions(),
|
|
grpc.WithAuthority(c.config.GetAuthority()),
|
|
grpc.WithPerRPCCredentials(c),
|
|
grpc.WithDefaultCallOptions(
|
|
grpc.UseCompressor("gzip"),
|
|
),
|
|
grpc.WithChainUnaryInterceptor(
|
|
logging.UnaryClientInterceptor(logging.LoggerFunc(interceptorLogger)),
|
|
),
|
|
grpc.WithStreamInterceptor(
|
|
logging.StreamClientInterceptor(logging.LoggerFunc(interceptorLogger)),
|
|
),
|
|
)
|
|
|
|
conn, err := grpc.NewClient(c.config.GetConnectionURI(), opts...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error dialing grpc server: %w", err)
|
|
}
|
|
|
|
go c.logConnectionState(ctx, conn)
|
|
|
|
return conn, nil
|
|
}
|
|
|
|
// GetRequestMetadata implements credentials.PerRPCCredentials
|
|
func (c *client) GetRequestMetadata(ctx context.Context, _ ...string) (map[string]string, error) {
|
|
token, err := c.tokenProvider(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return map[string]string{
|
|
"authorization": fmt.Sprintf("Bearer %s", token),
|
|
}, nil
|
|
}
|
|
|
|
// RequireTransportSecurity implements credentials.PerRPCCredentials
|
|
func (c *client) RequireTransportSecurity() bool {
|
|
return c.config.RequireTLS()
|
|
}
|
|
|
|
func (c *client) logConnectionState(ctx context.Context, conn *grpc.ClientConn) {
|
|
var state connectivity.State = -1
|
|
for ctx.Err() == nil && state != connectivity.Shutdown {
|
|
_ = conn.WaitForStateChange(ctx, state)
|
|
state = conn.GetState()
|
|
log.Ctx(ctx).Debug().
|
|
Str("endpoint", c.config.GetConnectionURI()).
|
|
Str("state", state.String()).
|
|
Msg("grpc connection state")
|
|
}
|
|
}
|
|
|
|
func interceptorLogger(ctx context.Context, lvl logging.Level, msg string, fields ...any) {
|
|
l := log.Ctx(ctx).With().Fields(fields).Logger()
|
|
|
|
switch lvl {
|
|
case logging.LevelDebug:
|
|
l.Debug().Msg(msg)
|
|
case logging.LevelInfo:
|
|
l.Debug().Msg(msg)
|
|
case logging.LevelWarn:
|
|
l.Warn().Msg(msg)
|
|
case logging.LevelError:
|
|
l.Error().Msg(msg)
|
|
default:
|
|
panic(fmt.Sprintf("unknown level %v", lvl))
|
|
}
|
|
}
|