grpc: send client traffic through envoy (#2469)

* wip

* wip

* handle wildcards in override name

* remove wait for ready, add comment about sync, force initial sync complete in test

* address comments
This commit is contained in:
Caleb Doxsey 2021-08-16 16:12:22 -06:00 committed by GitHub
parent 87c3c675d2
commit bbec2cae9f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
26 changed files with 391 additions and 480 deletions

View file

@ -2,51 +2,23 @@ package grpc
import (
"context"
"crypto/tls"
"errors"
"net"
"net/url"
"strconv"
"strings"
"sync"
"time"
"github.com/google/go-cmp/cmp"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry"
"github.com/pomerium/pomerium/internal/telemetry/requestid"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpcutil"
)
const (
defaultGRPCSecurePort = 443
defaultGRPCInsecurePort = 80
)
// Options contains options for connecting to a pomerium rpc service.
type Options struct {
// Addrs is the location of the service. e.g. "service.corp.example:8443"
Addrs []*url.URL
// OverrideCertificateName overrides the server name used to verify the hostname on the
// returned certificates from the server. gRPC internals also use it to override the virtual
// hosting name if it is set.
OverrideCertificateName string
// CA specifies the base64 encoded TLS certificate authority to use.
CA string
// CAFile specifies the TLS certificate authority file to use.
CAFile string
// RequestTimeout specifies the timeout for individual RPC calls
RequestTimeout time.Duration
// ClientDNSRoundRobin enables or disables DNS resolver based load balancing
ClientDNSRoundRobin bool
// WithInsecure disables transport security for this ClientConn.
// Note that transport security is required unless WithInsecure is set.
WithInsecure bool
// Address is the location of the service. e.g. "service.corp.example:8443"
Address string
// InstallationID specifies the installation id for telemetry exposition.
InstallationID string
@ -60,31 +32,10 @@ type Options struct {
// NewGRPCClientConn returns a new gRPC pomerium service client connection.
func NewGRPCClientConn(ctx context.Context, opts *Options, other ...grpc.DialOption) (*grpc.ClientConn, error) {
if len(opts.Addrs) == 0 {
return nil, errors.New("internal/grpc: connection address required")
}
var addrs []string
for _, u := range opts.Addrs {
hostport := u.Host
// no colon exists in the connection string, assume one must be added manually
if _, _, err := net.SplitHostPort(hostport); err != nil {
if u.Scheme == "https" {
hostport = net.JoinHostPort(hostport, strconv.Itoa(defaultGRPCSecurePort))
} else {
hostport = net.JoinHostPort(hostport, strconv.Itoa(defaultGRPCInsecurePort))
}
}
addrs = append(addrs, hostport)
}
connAddr := "pomerium:///" + strings.Join(addrs, ",")
clientStatsHandler := telemetry.NewGRPCClientStatsHandler(opts.ServiceName)
unaryClientInterceptors := []grpc.UnaryClientInterceptor{
requestid.UnaryClientInterceptor(),
grpcTimeoutInterceptor(opts.RequestTimeout),
clientStatsHandler.UnaryInterceptor,
}
streamClientInterceptors := []grpc.StreamClientInterceptor{
@ -98,38 +49,13 @@ func NewGRPCClientConn(ctx context.Context, opts *Options, other ...grpc.DialOpt
dialOptions := []grpc.DialOption{
grpc.WithChainUnaryInterceptor(unaryClientInterceptors...),
grpc.WithChainStreamInterceptor(streamClientInterceptors...),
grpc.WithDefaultCallOptions([]grpc.CallOption{grpc.WaitForReady(true)}...),
grpc.WithStatsHandler(clientStatsHandler.Handler),
grpc.WithDefaultServiceConfig(roundRobinServiceConfig),
grpc.WithDisableServiceConfig(),
grpc.WithInsecure(),
}
dialOptions = append(dialOptions, other...)
if opts.WithInsecure {
log.Info(ctx).Str("addr", connAddr).Msg("internal/grpc: grpc with insecure")
dialOptions = append(dialOptions, grpc.WithInsecure())
} else {
rootCAs, err := cryptutil.GetCertPool(opts.CA, opts.CAFile)
if err != nil {
return nil, err
}
cert := credentials.NewTLS(&tls.Config{RootCAs: rootCAs, MinVersion: tls.VersionTLS12})
// override allowed certificate name string, typically used when doing behind ingress connection
if opts.OverrideCertificateName != "" {
log.Debug(ctx).Str("cert-override-name", opts.OverrideCertificateName).Msg("internal/grpc: grpc")
err := cert.OverrideServerName(opts.OverrideCertificateName)
if err != nil {
return nil, err
}
}
// finally add our credential
dialOptions = append(dialOptions, grpc.WithTransportCredentials(cert))
}
return grpc.DialContext(ctx, connAddr, dialOptions...)
log.Info(ctx).Str("address", opts.Address).Msg("dialing")
return grpc.DialContext(ctx, opts.Address, dialOptions...)
}
// grpcTimeoutInterceptor enforces per-RPC request timeouts
@ -186,3 +112,28 @@ func GetGRPCClientConn(ctx context.Context, name string, opts *Options) (*grpc.C
}
return cc, nil
}
// OutboundOptions are the options for the outbound gRPC client.
type OutboundOptions struct {
// OutboundPort is the port for the outbound gRPC listener.
OutboundPort string
// InstallationID specifies the installation id for telemetry exposition.
InstallationID string
// ServiceName specifies the service name for telemetry exposition
ServiceName string
// SignedJWTKey is the JWT key to use for signing a JWT attached to metadata.
SignedJWTKey []byte
}
// GetOutboundGRPCClientConn gets the outbound gRPC client.
func GetOutboundGRPCClientConn(ctx context.Context, opts *OutboundOptions) (*grpc.ClientConn, error) {
return GetGRPCClientConn(ctx, "outbound", &Options{
Address: net.JoinHostPort("127.0.0.1", opts.OutboundPort),
InstallationID: opts.InstallationID,
ServiceName: opts.ServiceName,
SignedJWTKey: opts.SignedJWTKey,
})
}

View file

@ -2,12 +2,9 @@ package grpc
import (
"context"
"net/url"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
)
@ -37,81 +34,3 @@ func Test_grpcTimeoutInterceptor(t *testing.T) {
to(context.Background(), "test", nil, nil, nil, mockInvoker(timeOut*2, true))
to(context.Background(), "test", nil, nil, nil, mockInvoker(timeOut/2, false))
}
func TestNewGRPC(t *testing.T) {
t.Parallel()
tests := []struct {
name string
opts *Options
wantErr bool
wantErrStr string
wantTarget string
}{
{"empty connection", &Options{Addrs: nil}, true, "proxy/authenticator: connection address required", ""},
{"both internal and addr empty", &Options{Addrs: nil}, true, "proxy/authenticator: connection address required", ""},
{"addr with port", &Options{Addrs: []*url.URL{{Scheme: "https", Host: "localhost.example:8443"}}}, false, "", "pomerium:///localhost.example:8443"},
{"secure addr without port", &Options{Addrs: []*url.URL{{Scheme: "https", Host: "localhost.example"}}}, false, "", "pomerium:///localhost.example:443"},
{"insecure addr without port", &Options{Addrs: []*url.URL{{Scheme: "http", Host: "localhost.example"}}}, false, "", "pomerium:///localhost.example:80"},
{"cert override", &Options{Addrs: []*url.URL{{Scheme: "https", Host: "localhost.example:443"}}, OverrideCertificateName: "*.local"}, false, "", "pomerium:///localhost.example:443"},
{"custom ca", &Options{Addrs: []*url.URL{{Scheme: "https", Host: "localhost.example:443"}}, OverrideCertificateName: "*.local", CA: "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSURFVENDQWZrQ0ZBWHhneFg5K0hjWlBVVVBEK0laV0NGNUEvVTdNQTBHQ1NxR1NJYjNEUUVCQ3dVQU1FVXgKQ3pBSkJnTlZCQVlUQWtGVk1STXdFUVlEVlFRSURBcFRiMjFsTFZOMFlYUmxNU0V3SHdZRFZRUUtEQmhKYm5SbApjbTVsZENCWGFXUm5hWFJ6SUZCMGVTQk1kR1F3SGhjTk1Ua3dNakk0TVRnMU1EQTNXaGNOTWprd01qSTFNVGcxCk1EQTNXakJGTVFzd0NRWURWUVFHRXdKQlZURVRNQkVHQTFVRUNBd0tVMjl0WlMxVGRHRjBaVEVoTUI4R0ExVUUKQ2d3WVNXNTBaWEp1WlhRZ1YybGtaMmwwY3lCUWRIa2dUSFJrTUlJQklqQU5CZ2txaGtpRzl3MEJBUUVGQUFPQwpBUThBTUlJQkNnS0NBUUVBOVRFMEFiaTdnMHhYeURkVUtEbDViNTBCT05ZVVVSc3F2THQrSWkwdlpjMzRRTHhOClJrT0hrOFZEVUgzcUt1N2UrNGVubUdLVVNUdzRPNFlkQktiSWRJTFpnb3o0YitNL3FVOG5adVpiN2pBVTdOYWkKajMzVDVrbXB3L2d4WHNNUzNzdUpXUE1EUDB3Z1BUZUVRK2J1bUxVWmpLdUVIaWNTL0l5dmtaVlBzRlE4NWlaUwpkNXE2a0ZGUUdjWnFXeFg0dlhDV25Sd3E3cHY3TThJd1RYc1pYSVRuNXB5Z3VTczNKb29GQkg5U3ZNTjRKU25GCmJMK0t6ekduMy9ScXFrTXpMN3FUdkMrNWxVT3UxUmNES21mZXBuVGVaN1IyVnJUQm42NndWMjVHRnBkSDIzN00KOXhJVkJrWEd1U2NvWHVPN1lDcWFrZkt6aXdoRTV4UmRaa3gweXdJREFRQUJNQTBHQ1NxR1NJYjNEUUVCQ3dVQQpBNElCQVFCaHRWUEI0OCs4eFZyVmRxM1BIY3k5QkxtVEtrRFl6N2Q0ODJzTG1HczBuVUdGSTFZUDdmaFJPV3ZxCktCTlpkNEI5MUpwU1NoRGUrMHpoNno4WG5Ha01mYnRSYWx0NHEwZ3lKdk9hUWhqQ3ZCcSswTFk5d2NLbXpFdnMKcTRiNUZ5NXNpRUZSekJLTmZtTGwxTTF2cW1hNmFCVnNYUUhPREdzYS83dE5MalZ2ay9PYm52cFg3UFhLa0E3cQpLMTQvV0tBRFBJWm9mb00xMzB4Q1RTYXVpeXROajlnWkx1WU9leEZhblVwNCt2MHBYWS81OFFSNTk2U0ROVTlKClJaeDhwTzBTaUYvZXkxVUZXbmpzdHBjbTQzTFVQKzFwU1hFeVhZOFJrRTI2QzNvdjNaTFNKc2pMbC90aXVqUlgKZUJPOWorWDdzS0R4amdtajBPbWdpVkpIM0YrUAotLS0tLUVORCBDRVJUSUZJQ0FURS0tLS0tCg=="}, false, "", "pomerium:///localhost.example:443"},
{"bad ca encoding", &Options{Addrs: []*url.URL{{Scheme: "https", Host: "localhost.example:443"}}, OverrideCertificateName: "*.local", CA: "^"}, true, "", "pomerium:///localhost.example:443"},
{"custom ca file", &Options{Addrs: []*url.URL{{Scheme: "https", Host: "localhost.example:443"}}, OverrideCertificateName: "*.local", CAFile: "testdata/example.crt"}, false, "", "pomerium:///localhost.example:443"},
{"bad custom ca file", &Options{Addrs: []*url.URL{{Scheme: "https", Host: "localhost.example:443"}}, OverrideCertificateName: "*.local", CAFile: "testdata/example.crt2"}, true, "", "pomerium:///localhost.example:443"},
{"valid with insecure", &Options{Addrs: []*url.URL{{Scheme: "https", Host: "localhost.example:8443"}}, WithInsecure: true}, false, "", "pomerium:///localhost.example:8443"},
{"valid client round robin", &Options{Addrs: []*url.URL{{Scheme: "https", Host: "localhost.example:8443"}}, ClientDNSRoundRobin: true}, false, "", "pomerium:///localhost.example:8443"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewGRPCClientConn(context.Background(), tt.opts)
if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
if !strings.EqualFold(err.Error(), tt.wantErrStr) {
t.Errorf("New() error = %v did not contain wantErr %v", err, tt.wantErrStr)
}
}
if got != nil && got.Target() != tt.wantTarget {
t.Errorf("New() target = %v expected %v", got.Target(), tt.wantTarget)
}
})
}
}
func TestGetGRPC(t *testing.T) {
cc1, err := GetGRPCClientConn(context.Background(), "example", &Options{
Addrs: mustParseURLs("https://localhost.example"),
})
if !assert.NoError(t, err) {
return
}
cc2, err := GetGRPCClientConn(context.Background(), "example", &Options{
Addrs: mustParseURLs("https://localhost.example"),
})
if !assert.NoError(t, err) {
return
}
assert.Same(t, cc1, cc2, "GetGRPCClientConn should return the same connection when there are no changes")
cc3, err := GetGRPCClientConn(context.Background(), "example", &Options{
Addrs: mustParseURLs("http://localhost.example"),
WithInsecure: true,
})
if !assert.NoError(t, err) {
return
}
assert.NotSame(t, cc1, cc3, "GetGRPCClientConn should return a new connection when there are changes")
}
func mustParseURLs(rawurls ...string) []*url.URL {
var urls []*url.URL
for _, rawurl := range rawurls {
u, err := url.Parse(rawurl)
if err != nil {
panic(err)
}
urls = append(urls, u)
}
return urls
}

View file

@ -1,11 +0,0 @@
package grpc
//go:generate ./protoc.bash
const roundRobinServiceConfig = `{
"loadBalancingConfig": [
{
"round_robin": {}
}
]
}`

View file

@ -1,104 +0,0 @@
package grpc
import (
"strings"
"sync"
"google.golang.org/grpc/resolver"
)
func init() {
resolver.Register(&pomeriumBuilder{})
}
type pomeriumBuilder struct {
}
func (*pomeriumBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (resolver.Resolver, error) {
endpoints := strings.Split(target.Endpoint, ",")
pccd := &pomeriumClientConnData{
states: make([]resolver.State, len(endpoints)),
}
pr := &pomeriumResolver{}
for i, endpoint := range endpoints {
subTarget := parseTarget(endpoint)
b := resolver.Get(subTarget.Scheme)
pcc := &pomeriumClientConn{
data: pccd,
idx: i,
ClientConn: cc,
}
r, err := b.Build(subTarget, pcc, opts)
if err != nil {
return nil, err
}
pr.resolvers = append(pr.resolvers, r)
}
return pr, nil
}
func (*pomeriumBuilder) Scheme() string {
return "pomerium"
}
type pomeriumResolver struct {
resolvers []resolver.Resolver
}
func (pr *pomeriumResolver) ResolveNow(options resolver.ResolveNowOptions) {
for _, r := range pr.resolvers {
r.ResolveNow(options)
}
}
func (pr *pomeriumResolver) Close() {
for _, r := range pr.resolvers {
r.Close()
}
}
type pomeriumClientConn struct {
data *pomeriumClientConnData
idx int
resolver.ClientConn
}
func (pcc *pomeriumClientConn) UpdateState(state resolver.State) error {
return pcc.ClientConn.UpdateState(pcc.data.updateState(pcc.idx, state))
}
type pomeriumClientConnData struct {
mu sync.Mutex
states []resolver.State
}
func (pccd *pomeriumClientConnData) updateState(idx int, state resolver.State) resolver.State {
pccd.mu.Lock()
defer pccd.mu.Unlock()
pccd.states[idx] = state
merged := resolver.State{}
for _, s := range pccd.states {
merged.Addresses = append(merged.Addresses, s.Addresses...)
merged.ServiceConfig = s.ServiceConfig
merged.Attributes = s.Attributes
}
return merged
}
func parseTarget(raw string) resolver.Target {
target := resolver.Target{
Scheme: resolver.GetDefaultScheme(),
}
if idx := strings.Index(raw, "://"); idx >= 0 {
target.Scheme = raw[:idx]
raw = raw[idx+3:]
}
if idx := strings.Index(raw, "/"); idx >= 0 {
target.Authority = raw[:idx]
raw = raw[idx+1:]
}
target.Endpoint = raw
return target
}

View file

@ -1,68 +0,0 @@
package grpc
import (
"context"
"net"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
"google.golang.org/grpc/test/grpc_testing"
)
type resolverTestServer struct {
grpc_testing.UnimplementedTestServiceServer
username string
}
func (srv *resolverTestServer) UnaryCall(context.Context, *grpc_testing.SimpleRequest) (*grpc_testing.SimpleResponse, error) {
return &grpc_testing.SimpleResponse{
Username: srv.username,
}, nil
}
func TestResolver(t *testing.T) {
li1, err := net.Listen("tcp", "127.0.0.1:0")
if !assert.NoError(t, err) {
return
}
defer func() { _ = li1.Close() }()
srv1 := grpc.NewServer()
grpc_testing.RegisterTestServiceServer(srv1, &resolverTestServer{
username: "srv1",
})
go func() { _ = srv1.Serve(li1) }()
li2, err := net.Listen("tcp", "127.0.0.1:0")
if !assert.NoError(t, err) {
return
}
defer func() { _ = li2.Close() }()
srv2 := grpc.NewServer()
grpc_testing.RegisterTestServiceServer(srv2, &resolverTestServer{
username: "srv2",
})
go func() { _ = srv2.Serve(li2) }()
cc, err := grpc.Dial("pomerium:///"+strings.Join([]string{
"dns:///" + li1.Addr().String(),
li2.Addr().String(),
}, ","), grpc.WithInsecure(), grpc.WithDefaultServiceConfig(roundRobinServiceConfig))
if !assert.NoError(t, err) {
return
}
defer func() { _ = cc.Close() }()
c := grpc_testing.NewTestServiceClient(cc)
usernames := map[string]int{}
for i := 0; i < 1000; i++ {
res, err := c.UnaryCall(context.Background(), new(grpc_testing.SimpleRequest))
assert.NoError(t, err)
usernames[res.GetUsername()]++
}
assert.Greater(t, usernames["srv1"], 0)
assert.Greater(t, usernames["srv2"], 0)
}