use incremental API for envoy xDS (#1732)

* use incremental API

* add test

* use backoff v4

* remove panic, add comment to exponential try, add test for HashProto

* merge master

* fix missing import
This commit is contained in:
Caleb Doxsey 2021-01-05 12:45:55 -07:00 committed by GitHub
parent a07d85b174
commit 3524697f6f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 511 additions and 238 deletions

View file

@ -1,136 +0,0 @@
package controlplane
import (
"context"
"encoding/json"
"errors"
"fmt"
envoy_service_discovery_v3 "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3"
"golang.org/x/sync/errgroup"
"github.com/pomerium/pomerium/internal/log"
)
func (srv *Server) registerXDSHandlers() {
envoy_service_discovery_v3.RegisterAggregatedDiscoveryServiceServer(srv.GRPCServer, srv)
}
// StreamAggregatedResources streams xDS resources based on incoming discovery requests.
//
// This is setup as 3 concurrent goroutines:
// - The first retrieves the requests from the client.
// - The third sends responses back to the client.
// - The second waits for either the client to request a new resource type
// or for the config to have been updated
// - in either case, we loop over all of the current client versions
// and if any of them are different from the current version, we send
// the updated resource
func (srv *Server) StreamAggregatedResources(stream envoy_service_discovery_v3.AggregatedDiscoveryService_StreamAggregatedResourcesServer) error {
incoming := make(chan *envoy_service_discovery_v3.DiscoveryRequest)
outgoing := make(chan *envoy_service_discovery_v3.DiscoveryResponse)
eg, ctx := errgroup.WithContext(stream.Context())
// receive requests
eg.Go(func() error {
return srv.streamAggregatedResourcesIncomingStep(ctx, stream, incoming)
})
eg.Go(func() error {
return srv.streamAggregatedResourcesProcessStep(ctx, incoming, outgoing)
})
// send responses
eg.Go(func() error {
return srv.streamAggregatedResourcesOutgoingStep(ctx, stream, outgoing)
})
return eg.Wait()
}
func (srv *Server) streamAggregatedResourcesIncomingStep(
ctx context.Context,
stream envoy_service_discovery_v3.AggregatedDiscoveryService_StreamAggregatedResourcesServer,
incoming chan<- *envoy_service_discovery_v3.DiscoveryRequest,
) error {
for {
req, err := stream.Recv()
if err != nil {
return err
}
select {
case incoming <- req:
case <-ctx.Done():
return ctx.Err()
}
}
}
func (srv *Server) streamAggregatedResourcesProcessStep(
ctx context.Context,
incoming <-chan *envoy_service_discovery_v3.DiscoveryRequest,
outgoing chan<- *envoy_service_discovery_v3.DiscoveryResponse,
) error {
versions := map[string]string{}
for {
select {
case req := <-incoming:
if req.ErrorDetail != nil {
bs, _ := json.Marshal(req.ErrorDetail.Details)
log.Error().
Err(errors.New(req.ErrorDetail.Message)).
Int32("code", req.ErrorDetail.Code).
RawJSON("details", bs).Msg("error applying configuration")
continue
}
// update the currently stored version
// if this version is different from the current version
// we will send the response below
versions[req.TypeUrl] = req.VersionInfo
case <-srv.configUpdated:
case <-ctx.Done():
return ctx.Err()
}
current := srv.currentConfig.Load()
for typeURL, version := range versions {
// the versions are different, so the envoy config needs to be updated
if version != fmt.Sprint(current.version) {
res, err := srv.buildDiscoveryResponse(fmt.Sprint(current.version), typeURL, &current.Options)
if err != nil {
return err
}
select {
case outgoing <- res:
case <-ctx.Done():
return ctx.Err()
}
}
}
}
}
func (srv *Server) streamAggregatedResourcesOutgoingStep(
ctx context.Context,
stream envoy_service_discovery_v3.AggregatedDiscoveryService_StreamAggregatedResourcesServer,
outgoing <-chan *envoy_service_discovery_v3.DiscoveryResponse,
) error {
for {
var res *envoy_service_discovery_v3.DiscoveryResponse
select {
case res = <-outgoing:
case <-ctx.Done():
return ctx.Err()
}
err := stream.Send(res)
if err != nil {
return err
}
}
}
// DeltaAggregatedResources is not implemented.
func (srv *Server) DeltaAggregatedResources(in envoy_service_discovery_v3.AggregatedDiscoveryService_DeltaAggregatedResourcesServer) error {
return fmt.Errorf("method DeltaAggregatedResources not implemented")
}

View file

@ -7,6 +7,7 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
envoy_service_discovery_v3 "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"google.golang.org/grpc" "google.golang.org/grpc"
@ -14,6 +15,7 @@ import (
"google.golang.org/grpc/reflection" "google.golang.org/grpc/reflection"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/controlplane/xdsmgr"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry" "github.com/pomerium/pomerium/internal/telemetry"
"github.com/pomerium/pomerium/internal/telemetry/requestid" "github.com/pomerium/pomerium/internal/telemetry/requestid"
@ -46,15 +48,13 @@ type Server struct {
HTTPRouter *mux.Router HTTPRouter *mux.Router
currentConfig atomicVersionedOptions currentConfig atomicVersionedOptions
configUpdated chan struct{}
name string name string
xdsmgr *xdsmgr.Manager
} }
// NewServer creates a new Server. Listener ports are chosen by the OS. // NewServer creates a new Server. Listener ports are chosen by the OS.
func NewServer(name string) (*Server, error) { func NewServer(name string) (*Server, error) {
srv := &Server{ srv := &Server{}
configUpdated: make(chan struct{}, 1),
}
srv.currentConfig.Store(versionedOptions{}) srv.currentConfig.Store(versionedOptions{})
var err error var err error
@ -73,7 +73,6 @@ func NewServer(name string) (*Server, error) {
grpc.ChainStreamInterceptor(requestid.StreamServerInterceptor(), si), grpc.ChainStreamInterceptor(requestid.StreamServerInterceptor(), si),
) )
reflection.Register(srv.GRPCServer) reflection.Register(srv.GRPCServer)
srv.registerXDSHandlers()
srv.registerAccessLogHandlers() srv.registerAccessLogHandlers()
// setup HTTP // setup HTTP
@ -85,6 +84,9 @@ func NewServer(name string) (*Server, error) {
srv.HTTPRouter = mux.NewRouter() srv.HTTPRouter = mux.NewRouter()
srv.addHTTPMiddleware() srv.addHTTPMiddleware()
srv.xdsmgr = xdsmgr.NewManager(srv.buildDiscoveryResources())
envoy_service_discovery_v3.RegisterAggregatedDiscoveryServiceServer(srv.GRPCServer, srv.xdsmgr)
return srv, nil return srv, nil
} }
@ -150,14 +152,10 @@ func (srv *Server) Run(ctx context.Context) error {
// OnConfigChange updates the pomerium config options. // OnConfigChange updates the pomerium config options.
func (srv *Server) OnConfigChange(cfg *config.Config) { func (srv *Server) OnConfigChange(cfg *config.Config) {
select {
case <-srv.configUpdated:
default:
}
prev := srv.currentConfig.Load() prev := srv.currentConfig.Load()
srv.currentConfig.Store(versionedOptions{ srv.currentConfig.Store(versionedOptions{
Options: *cfg.Options, Options: *cfg.Options,
version: prev.version + 1, version: prev.version + 1,
}) })
srv.configUpdated <- struct{}{} srv.xdsmgr.Update(srv.buildDiscoveryResources())
} }

View file

@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/hex"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
@ -19,51 +20,40 @@ import (
envoy_extensions_access_loggers_grpc_v3 "github.com/envoyproxy/go-control-plane/envoy/extensions/access_loggers/grpc/v3" envoy_extensions_access_loggers_grpc_v3 "github.com/envoyproxy/go-control-plane/envoy/extensions/access_loggers/grpc/v3"
envoy_extensions_transport_sockets_tls_v3 "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" envoy_extensions_transport_sockets_tls_v3 "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3"
envoy_service_discovery_v3 "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" envoy_service_discovery_v3 "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3"
"github.com/golang/protobuf/ptypes"
"github.com/golang/protobuf/ptypes/any"
"golang.org/x/net/nettest" "golang.org/x/net/nettest"
"google.golang.org/grpc/codes" "google.golang.org/protobuf/proto"
"google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/anypb"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/cryptutil"
) )
func (srv *Server) buildDiscoveryResponse(version string, typeURL string, options *config.Options) (*envoy_service_discovery_v3.DiscoveryResponse, error) { const (
switch typeURL { clusterTypeURL = "type.googleapis.com/envoy.config.cluster.v3.Cluster"
case "type.googleapis.com/envoy.config.listener.v3.Listener": listenerTypeURL = "type.googleapis.com/envoy.config.listener.v3.Listener"
listeners := buildListeners(options) )
anys := make([]*any.Any, len(listeners))
for i, listener := range listeners { func (srv *Server) buildDiscoveryResources() map[string][]*envoy_service_discovery_v3.Resource {
a, err := ptypes.MarshalAny(listener) resources := map[string][]*envoy_service_discovery_v3.Resource{}
if err != nil { cfg := srv.currentConfig.Load()
return nil, status.Errorf(codes.Internal, "error marshaling type to any: %v", err) for _, cluster := range srv.buildClusters(&cfg.Options) {
} any, _ := anypb.New(cluster)
anys[i] = a resources[clusterTypeURL] = append(resources[clusterTypeURL], &envoy_service_discovery_v3.Resource{
} Name: cluster.Name,
return &envoy_service_discovery_v3.DiscoveryResponse{ Version: hex.EncodeToString(cryptutil.HashProto(cluster)),
VersionInfo: version, Resource: any,
Resources: anys, })
TypeUrl: typeURL,
}, nil
case "type.googleapis.com/envoy.config.cluster.v3.Cluster":
clusters := srv.buildClusters(options)
anys := make([]*any.Any, len(clusters))
for i, cluster := range clusters {
a, err := ptypes.MarshalAny(cluster)
if err != nil {
return nil, status.Errorf(codes.Internal, "error marshaling type to any: %v", err)
}
anys[i] = a
}
return &envoy_service_discovery_v3.DiscoveryResponse{
VersionInfo: version,
Resources: anys,
TypeUrl: typeURL,
}, nil
default:
return nil, status.Errorf(codes.Internal, "received request for unknown discovery request type: %s", typeURL)
} }
for _, listener := range buildListeners(&cfg.Options) {
any, _ := anypb.New(listener)
resources[listenerTypeURL] = append(resources[listenerTypeURL], &envoy_service_discovery_v3.Resource{
Name: listener.Name,
Version: hex.EncodeToString(cryptutil.HashProto(listener)),
Resource: any,
})
}
return resources
} }
func buildAccessLogs(options *config.Options) []*envoy_config_accesslog_v3.AccessLog { func buildAccessLogs(options *config.Options) []*envoy_config_accesslog_v3.AccessLog {
@ -82,7 +72,7 @@ func buildAccessLogs(options *config.Options) []*envoy_config_accesslog_v3.Acces
return nil return nil
} }
tc, _ := ptypes.MarshalAny(&envoy_extensions_access_loggers_grpc_v3.HttpGrpcAccessLogConfig{ tc := marshalAny(&envoy_extensions_access_loggers_grpc_v3.HttpGrpcAccessLogConfig{
CommonConfig: &envoy_extensions_access_loggers_grpc_v3.CommonGrpcAccessLogConfig{ CommonConfig: &envoy_extensions_access_loggers_grpc_v3.CommonGrpcAccessLogConfig{
LogName: "ingress-http", LogName: "ingress-http",
GrpcService: &envoy_config_core_v3.GrpcService{ GrpcService: &envoy_config_core_v3.GrpcService{
@ -235,3 +225,12 @@ func getRootCertificateAuthority() (string, error) {
} }
return rootCABundle.value, nil return rootCABundle.value, nil
} }
func marshalAny(msg proto.Message) *anypb.Any {
any := new(anypb.Any)
_ = anypb.MarshalFrom(any, msg, proto.MarshalOptions{
AllowPartial: true,
Deterministic: true,
})
return any
}

View file

@ -104,7 +104,7 @@ func buildInternalTransportSocket(options *config.Options, endpoint *url.URL) *e
}, },
Sni: sni, Sni: sni,
} }
tlsConfig, _ := ptypes.MarshalAny(tlsContext) tlsConfig := marshalAny(tlsContext)
return &envoy_config_core_v3.TransportSocket{ return &envoy_config_core_v3.TransportSocket{
Name: "tls", Name: "tls",
ConfigType: &envoy_config_core_v3.TransportSocket_TypedConfig{ ConfigType: &envoy_config_core_v3.TransportSocket_TypedConfig{
@ -144,7 +144,7 @@ func buildPolicyTransportSocket(policy *config.Policy) *envoy_config_core_v3.Tra
envoyTLSCertificateFromGoTLSCertificate(policy.ClientCertificate)) envoyTLSCertificateFromGoTLSCertificate(policy.ClientCertificate))
} }
tlsConfig, _ := ptypes.MarshalAny(tlsContext) tlsConfig := marshalAny(tlsContext)
return &envoy_config_core_v3.TransportSocket{ return &envoy_config_core_v3.TransportSocket{
Name: "tls", Name: "tls",
ConfigType: &envoy_config_core_v3.TransportSocket_TypedConfig{ ConfigType: &envoy_config_core_v3.TransportSocket_TypedConfig{

View file

@ -30,7 +30,7 @@ import (
var disableExtAuthz *any.Any var disableExtAuthz *any.Any
func init() { func init() {
disableExtAuthz, _ = ptypes.MarshalAny(&envoy_extensions_filters_http_ext_authz_v3.ExtAuthzPerRoute{ disableExtAuthz = marshalAny(&envoy_extensions_filters_http_ext_authz_v3.ExtAuthzPerRoute{
Override: &envoy_extensions_filters_http_ext_authz_v3.ExtAuthzPerRoute_Disabled{ Override: &envoy_extensions_filters_http_ext_authz_v3.ExtAuthzPerRoute_Disabled{
Disabled: true, Disabled: true,
}, },
@ -67,7 +67,7 @@ func buildMainListener(options *config.Options) *envoy_config_listener_v3.Listen
} }
} }
tlsInspectorCfg, _ := ptypes.MarshalAny(new(emptypb.Empty)) tlsInspectorCfg := marshalAny(new(emptypb.Empty))
li := &envoy_config_listener_v3.Listener{ li := &envoy_config_listener_v3.Listener{
Name: "https-ingress", Name: "https-ingress",
Address: buildAddress(options.Addr, 443), Address: buildAddress(options.Addr, 443),
@ -90,7 +90,7 @@ func buildMainListener(options *config.Options) *envoy_config_listener_v3.Listen
} }
tlsContext := buildDownstreamTLSContext(options, tlsDomain) tlsContext := buildDownstreamTLSContext(options, tlsDomain)
if tlsContext != nil { if tlsContext != nil {
tlsConfig, _ := ptypes.MarshalAny(tlsContext) tlsConfig := marshalAny(tlsContext)
filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{ filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{
Name: "tls", Name: "tls",
ConfigType: &envoy_config_core_v3.TransportSocket_TypedConfig{ ConfigType: &envoy_config_core_v3.TransportSocket_TypedConfig{
@ -161,7 +161,7 @@ func buildMainHTTPConnectionManagerFilter(options *config.Options, domains []str
grpcClientTimeout = ptypes.DurationProto(30 * time.Second) grpcClientTimeout = ptypes.DurationProto(30 * time.Second)
} }
extAuthZ, _ := ptypes.MarshalAny(&envoy_extensions_filters_http_ext_authz_v3.ExtAuthz{ extAuthZ := marshalAny(&envoy_extensions_filters_http_ext_authz_v3.ExtAuthz{
StatusOnError: &envoy_type_v3.HttpStatus{ StatusOnError: &envoy_type_v3.HttpStatus{
Code: envoy_type_v3.StatusCode_InternalServerError, Code: envoy_type_v3.StatusCode_InternalServerError,
}, },
@ -178,13 +178,13 @@ func buildMainHTTPConnectionManagerFilter(options *config.Options, domains []str
IncludePeerCertificate: true, IncludePeerCertificate: true,
}) })
extAuthzSetCookieLua, _ := ptypes.MarshalAny(&envoy_extensions_filters_http_lua_v3.Lua{ extAuthzSetCookieLua := marshalAny(&envoy_extensions_filters_http_lua_v3.Lua{
InlineCode: luascripts.ExtAuthzSetCookie, InlineCode: luascripts.ExtAuthzSetCookie,
}) })
cleanUpstreamLua, _ := ptypes.MarshalAny(&envoy_extensions_filters_http_lua_v3.Lua{ cleanUpstreamLua := marshalAny(&envoy_extensions_filters_http_lua_v3.Lua{
InlineCode: luascripts.CleanUpstream, InlineCode: luascripts.CleanUpstream,
}) })
removeImpersonateHeadersLua, _ := ptypes.MarshalAny(&envoy_extensions_filters_http_lua_v3.Lua{ removeImpersonateHeadersLua := marshalAny(&envoy_extensions_filters_http_lua_v3.Lua{
InlineCode: luascripts.RemoveImpersonateHeaders, InlineCode: luascripts.RemoveImpersonateHeaders,
}) })
@ -193,7 +193,7 @@ func buildMainHTTPConnectionManagerFilter(options *config.Options, domains []str
maxStreamDuration = ptypes.DurationProto(options.WriteTimeout) maxStreamDuration = ptypes.DurationProto(options.WriteTimeout)
} }
tc, _ := ptypes.MarshalAny(&envoy_http_connection_manager.HttpConnectionManager{ tc := marshalAny(&envoy_http_connection_manager.HttpConnectionManager{
CodecType: envoy_http_connection_manager.HttpConnectionManager_AUTO, CodecType: envoy_http_connection_manager.HttpConnectionManager_AUTO,
StatPrefix: "ingress", StatPrefix: "ingress",
RouteSpecifier: &envoy_http_connection_manager.HttpConnectionManager_RouteConfig{ RouteSpecifier: &envoy_http_connection_manager.HttpConnectionManager_RouteConfig{
@ -265,7 +265,7 @@ func buildGRPCListener(options *config.Options) *envoy_config_listener_v3.Listen
} }
} }
tlsInspectorCfg, _ := ptypes.MarshalAny(new(emptypb.Empty)) tlsInspectorCfg := marshalAny(new(emptypb.Empty))
li := &envoy_config_listener_v3.Listener{ li := &envoy_config_listener_v3.Listener{
Name: "grpc-ingress", Name: "grpc-ingress",
Address: buildAddress(options.GRPCAddr, 443), Address: buildAddress(options.GRPCAddr, 443),
@ -287,7 +287,7 @@ func buildGRPCListener(options *config.Options) *envoy_config_listener_v3.Listen
} }
tlsContext := buildDownstreamTLSContext(options, tlsDomain) tlsContext := buildDownstreamTLSContext(options, tlsDomain)
if tlsContext != nil { if tlsContext != nil {
tlsConfig, _ := ptypes.MarshalAny(tlsContext) tlsConfig := marshalAny(tlsContext)
filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{ filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{
Name: "tls", Name: "tls",
ConfigType: &envoy_config_core_v3.TransportSocket_TypedConfig{ ConfigType: &envoy_config_core_v3.TransportSocket_TypedConfig{
@ -302,7 +302,7 @@ func buildGRPCListener(options *config.Options) *envoy_config_listener_v3.Listen
} }
func buildGRPCHTTPConnectionManagerFilter() *envoy_config_listener_v3.Filter { func buildGRPCHTTPConnectionManagerFilter() *envoy_config_listener_v3.Filter {
tc, _ := ptypes.MarshalAny(&envoy_http_connection_manager.HttpConnectionManager{ tc := marshalAny(&envoy_http_connection_manager.HttpConnectionManager{
CodecType: envoy_http_connection_manager.HttpConnectionManager_AUTO, CodecType: envoy_http_connection_manager.HttpConnectionManager_AUTO,
StatPrefix: "grpc_ingress", StatPrefix: "grpc_ingress",
// limit request first byte to last byte time // limit request first byte to last byte time

View file

@ -3,6 +3,7 @@ package controlplane
import ( import (
"fmt" "fmt"
"net/url" "net/url"
"sort"
envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
envoy_config_route_v3 "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" envoy_config_route_v3 "github.com/envoyproxy/go-control-plane/envoy/config/route/v3"
@ -271,9 +272,15 @@ func mkEnvoyHeader(k, v string) *envoy_config_core_v3.HeaderValueOption {
} }
func toEnvoyHeaders(headers map[string]string) []*envoy_config_core_v3.HeaderValueOption { func toEnvoyHeaders(headers map[string]string) []*envoy_config_core_v3.HeaderValueOption {
var ks []string
for k := range headers {
ks = append(ks, k)
}
sort.Strings(ks)
envoyHeaders := make([]*envoy_config_core_v3.HeaderValueOption, 0, len(headers)) envoyHeaders := make([]*envoy_config_core_v3.HeaderValueOption, 0, len(headers))
for k, v := range headers { for _, k := range ks {
envoyHeaders = append(envoyHeaders, mkEnvoyHeader(k, v)) envoyHeaders = append(envoyHeaders, mkEnvoyHeader(k, headers[k]))
} }
return envoyHeaders return envoyHeaders
} }

View file

@ -0,0 +1,229 @@
// Package xdsmgr implements a resource discovery manager for envoy.
package xdsmgr
import (
"encoding/json"
"errors"
"sync"
envoy_service_discovery_v3 "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3"
"github.com/google/uuid"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/signal"
)
type streamState struct {
typeURL string
clientResourceVersions map[string]string
unsubscribedResources map[string]struct{}
}
var onHandleDeltaRequest = func(state *streamState) {}
// A Manager manages xDS resources.
type Manager struct {
signal *signal.Signal
mu sync.Mutex
nonce string
resources map[string][]*envoy_service_discovery_v3.Resource
}
// NewManager creates a new Manager.
func NewManager(resources map[string][]*envoy_service_discovery_v3.Resource) *Manager {
return &Manager{
signal: signal.New(),
nonce: uuid.New().String(),
resources: resources,
}
}
// DeltaAggregatedResources implements the increment xDS server.
func (mgr *Manager) DeltaAggregatedResources(
stream envoy_service_discovery_v3.AggregatedDiscoveryService_DeltaAggregatedResourcesServer,
) error {
ch := mgr.signal.Bind()
defer mgr.signal.Unbind(ch)
stateByTypeURL := map[string]*streamState{}
getDeltaResponse := func(typeURL string) *envoy_service_discovery_v3.DeltaDiscoveryResponse {
mgr.mu.Lock()
defer mgr.mu.Unlock()
state, ok := stateByTypeURL[typeURL]
if !ok {
return nil
}
res := &envoy_service_discovery_v3.DeltaDiscoveryResponse{
TypeUrl: typeURL,
Nonce: mgr.nonce,
}
seen := map[string]struct{}{}
for _, resource := range mgr.resources[typeURL] {
seen[resource.Name] = struct{}{}
if resource.Version != state.clientResourceVersions[resource.Name] {
res.Resources = append(res.Resources, resource)
}
}
for name := range state.clientResourceVersions {
_, ok := seen[name]
if !ok {
res.RemovedResources = append(res.RemovedResources, name)
}
}
if len(res.Resources) == 0 && len(res.RemovedResources) == 0 {
return nil
}
return res
}
handleDeltaRequest := func(req *envoy_service_discovery_v3.DeltaDiscoveryRequest) {
mgr.mu.Lock()
defer mgr.mu.Unlock()
state, ok := stateByTypeURL[req.GetTypeUrl()]
if !ok {
// first time we've seen a message for this type URL.
state = &streamState{
typeURL: req.GetTypeUrl(),
clientResourceVersions: req.GetInitialResourceVersions(),
unsubscribedResources: make(map[string]struct{}),
}
if state.clientResourceVersions == nil {
state.clientResourceVersions = make(map[string]string)
}
stateByTypeURL[req.GetTypeUrl()] = state
}
switch {
case req.GetResponseNonce() == "":
// neither an ACK or a NACK
case req.GetErrorDetail() != nil:
// a NACK
bs, _ := json.Marshal(req.ErrorDetail.Details)
log.Error().
Err(errors.New(req.ErrorDetail.Message)).
Int32("code", req.ErrorDetail.Code).
RawJSON("details", bs).Msg("error applying configuration")
case req.GetResponseNonce() == mgr.nonce:
// an ACK for the last response
// - set the client resource versions to the current resource versions
state.clientResourceVersions = make(map[string]string)
for _, resource := range mgr.resources[req.GetTypeUrl()] {
state.clientResourceVersions[resource.Name] = resource.Version
}
default:
// an ACK for a response that's not the last response
}
// update subscriptions
for _, name := range req.GetResourceNamesSubscribe() {
delete(state.unsubscribedResources, name)
}
for _, name := range req.GetResourceNamesUnsubscribe() {
state.unsubscribedResources[name] = struct{}{}
// from the docs:
// NOTE: the server must respond with all resources listed in
// resource_names_subscribe, even if it believes the client has
// the most recent version of them. The reason: the client may
// have dropped them, but then regained interest before it had
// a chance to send the unsubscribe message.
// so we reset the version to treat it like a new version
delete(state.clientResourceVersions, name)
}
onHandleDeltaRequest(state)
}
incoming := make(chan *envoy_service_discovery_v3.DeltaDiscoveryRequest)
outgoing := make(chan *envoy_service_discovery_v3.DeltaDiscoveryResponse)
eg, ctx := errgroup.WithContext(stream.Context())
// 1. receive all incoming messages
eg.Go(func() error {
for {
req, err := stream.Recv()
if err != nil {
return err
}
select {
case <-ctx.Done():
return ctx.Err()
case incoming <- req:
}
}
})
// 2. handle incoming requests or resource changes
eg.Go(func() error {
for {
var typeURLs []string
select {
case <-ctx.Done():
return ctx.Err()
case req := <-incoming:
handleDeltaRequest(req)
typeURLs = []string{req.GetTypeUrl()}
case <-ch:
mgr.mu.Lock()
for typeURL := range mgr.resources {
typeURLs = append(typeURLs, typeURL)
}
mgr.mu.Unlock()
}
for _, typeURL := range typeURLs {
res := getDeltaResponse(typeURL)
if res == nil {
continue
}
select {
case <-ctx.Done():
return ctx.Err()
case outgoing <- res:
}
}
}
})
// 3. send all outgoing messages
eg.Go(func() error {
for {
select {
case <-ctx.Done():
return ctx.Err()
case res := <-outgoing:
err := stream.Send(res)
if err != nil {
return err
}
}
}
})
return eg.Wait()
}
// StreamAggregatedResources is not implemented.
func (mgr *Manager) StreamAggregatedResources(
stream envoy_service_discovery_v3.AggregatedDiscoveryService_StreamAggregatedResourcesServer,
) error {
return status.Errorf(codes.Unimplemented, "method StreamAggregatedResources not implemented")
}
// Update updates the state of resources. If any changes are made they will be pushed to any listening
// streams. For each TypeURL the list of resources should be the complete list of resources.
func (mgr *Manager) Update(resources map[string][]*envoy_service_discovery_v3.Resource) {
mgr.mu.Lock()
mgr.nonce = uuid.New().String()
mgr.resources = resources
mgr.mu.Unlock()
mgr.signal.Broadcast()
}

View file

@ -0,0 +1,116 @@
package xdsmgr
import (
"context"
"net"
"testing"
"time"
envoy_service_discovery_v3 "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/test/bufconn"
"github.com/pomerium/pomerium/internal/signal"
)
const bufSize = 1024 * 1024
func TestManager(t *testing.T) {
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
defer clearTimeout()
typeURL := "example.com/example"
stateChanged := signal.New()
origOnHandleDeltaRequest := onHandleDeltaRequest
defer func() { onHandleDeltaRequest = origOnHandleDeltaRequest }()
onHandleDeltaRequest = func(state *streamState) {
stateChanged.Broadcast()
}
srv := grpc.NewServer()
mgr := NewManager(map[string][]*envoy_service_discovery_v3.Resource{
typeURL: {
{Name: "r1", Version: "1"},
},
})
envoy_service_discovery_v3.RegisterAggregatedDiscoveryServiceServer(srv, mgr)
li := bufconn.Listen(bufSize)
go func() { _ = srv.Serve(li) }()
cc, err := grpc.Dial("test",
grpc.WithInsecure(),
grpc.WithContextDialer(func(ctx context.Context, target string) (net.Conn, error) {
return li.Dial()
}))
if !assert.NoError(t, err) {
return
}
defer func() { _ = cc.Close() }()
client := envoy_service_discovery_v3.NewAggregatedDiscoveryServiceClient(cc)
t.Run("stream is disabled", func(t *testing.T) {
stream, err := client.StreamAggregatedResources(ctx)
if !assert.NoError(t, err) {
return
}
_, err = stream.Recv()
assert.Error(t, err, "only delta should be implemented")
assert.Equal(t, codes.Unimplemented, grpc.Code(err))
})
t.Run("updates", func(t *testing.T) {
stream, err := client.DeltaAggregatedResources(ctx)
if !assert.NoError(t, err) {
return
}
ch := stateChanged.Bind()
defer stateChanged.Unbind(ch)
ack := func(nonce string) {
err = stream.Send(&envoy_service_discovery_v3.DeltaDiscoveryRequest{
TypeUrl: typeURL,
ResponseNonce: nonce,
})
assert.NoError(t, err)
select {
case <-ctx.Done():
t.Fatal(ctx.Err())
case <-ch:
}
}
ack("")
msg, err := stream.Recv()
assert.NoError(t, err)
assert.NotEmpty(t, msg.GetNonce(), "nonce should not be empty")
assert.Equal(t, []*envoy_service_discovery_v3.Resource{
{Name: "r1", Version: "1"},
}, msg.GetResources())
ack(msg.Nonce)
mgr.Update(map[string][]*envoy_service_discovery_v3.Resource{
typeURL: {{Name: "r1", Version: "2"}},
})
msg, err = stream.Recv()
assert.NoError(t, err)
assert.Equal(t, []*envoy_service_discovery_v3.Resource{
{Name: "r1", Version: "2"},
}, msg.GetResources())
ack(msg.Nonce)
mgr.Update(map[string][]*envoy_service_discovery_v3.Resource{
typeURL: nil,
})
msg, err = stream.Recv()
assert.NoError(t, err)
assert.Equal(t, []string{"r1"}, msg.GetRemovedResources())
ack(msg.Nonce)
})
}

View file

@ -229,7 +229,7 @@ func (srv *Server) buildBootstrapConfig() ([]byte, error) {
dynamicCfg := &envoy_config_bootstrap_v3.Bootstrap_DynamicResources{ dynamicCfg := &envoy_config_bootstrap_v3.Bootstrap_DynamicResources{
AdsConfig: &envoy_config_core_v3.ApiConfigSource{ AdsConfig: &envoy_config_core_v3.ApiConfigSource{
ApiType: envoy_config_core_v3.ApiConfigSource_ApiType(envoy_config_core_v3.ApiConfigSource_ApiType_value["GRPC"]), ApiType: envoy_config_core_v3.ApiConfigSource_ApiType(envoy_config_core_v3.ApiConfigSource_ApiType_value["DELTA_GRPC"]),
TransportApiVersion: envoy_config_core_v3.ApiVersion_V3, TransportApiVersion: envoy_config_core_v3.ApiVersion_V3,
GrpcServices: []*envoy_config_core_v3.GrpcService{ GrpcServices: []*envoy_config_core_v3.GrpcService{
{ {

View file

@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/cenkalti/backoff/v4"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/golang/protobuf/ptypes" "github.com/golang/protobuf/ptypes"
"github.com/google/btree" "github.com/google/btree"
@ -578,29 +579,31 @@ func (mgr *Manager) initDirectoryUsers(ctx context.Context) error {
return err return err
} }
res, err := databroker.InitialSync(ctx, mgr.cfg.Load().dataBrokerClient, &databroker.SyncRequest{ return exponentialTry(ctx, func() error {
Type: any.GetTypeUrl(), res, err := databroker.InitialSync(ctx, mgr.cfg.Load().dataBrokerClient, &databroker.SyncRequest{
}) Type: any.GetTypeUrl(),
if err != nil { })
return fmt.Errorf("error getting all directory users: %w", err)
}
mgr.directoryUsers = map[string]*directory.User{}
for _, record := range res.GetRecords() {
var pbDirectoryUser directory.User
err := ptypes.UnmarshalAny(record.GetData(), &pbDirectoryUser)
if err != nil { if err != nil {
return fmt.Errorf("error unmarshaling directory user: %w", err) return fmt.Errorf("error getting all directory users: %w", err)
} }
mgr.directoryUsers[pbDirectoryUser.GetId()] = &pbDirectoryUser mgr.directoryUsers = map[string]*directory.User{}
mgr.directoryUsersRecordVersion = record.GetVersion() for _, record := range res.GetRecords() {
} var pbDirectoryUser directory.User
mgr.directoryUsersServerVersion = res.GetServerVersion() err := ptypes.UnmarshalAny(record.GetData(), &pbDirectoryUser)
if err != nil {
return fmt.Errorf("error unmarshaling directory user: %w", err)
}
mgr.log.Info().Int("count", len(mgr.directoryUsers)).Msg("initialized directory users") mgr.directoryUsers[pbDirectoryUser.GetId()] = &pbDirectoryUser
mgr.directoryUsersRecordVersion = record.GetVersion()
}
mgr.directoryUsersServerVersion = res.GetServerVersion()
return nil mgr.log.Info().Int("count", len(mgr.directoryUsers)).Msg("initialized directory users")
return nil
})
} }
func (mgr *Manager) syncDirectoryUsers(ctx context.Context, ch chan<- *directory.User) error { func (mgr *Manager) syncDirectoryUsers(ctx context.Context, ch chan<- *directory.User) error {
@ -648,30 +651,31 @@ func (mgr *Manager) initDirectoryGroups(ctx context.Context) error {
if err != nil { if err != nil {
return err return err
} }
return exponentialTry(ctx, func() error {
res, err := databroker.InitialSync(ctx, mgr.cfg.Load().dataBrokerClient, &databroker.SyncRequest{ res, err := databroker.InitialSync(ctx, mgr.cfg.Load().dataBrokerClient, &databroker.SyncRequest{
Type: any.GetTypeUrl(), Type: any.GetTypeUrl(),
}) })
if err != nil {
return fmt.Errorf("error getting all directory groups: %w", err)
}
mgr.directoryGroups = map[string]*directory.Group{}
for _, record := range res.GetRecords() {
var pbDirectoryGroup directory.Group
err := ptypes.UnmarshalAny(record.GetData(), &pbDirectoryGroup)
if err != nil { if err != nil {
return fmt.Errorf("error unmarshaling directory group: %w", err) return fmt.Errorf("error getting all directory groups: %w", err)
} }
mgr.directoryGroups[pbDirectoryGroup.GetId()] = &pbDirectoryGroup mgr.directoryGroups = map[string]*directory.Group{}
mgr.directoryGroupsRecordVersion = record.GetVersion() for _, record := range res.GetRecords() {
} var pbDirectoryGroup directory.Group
mgr.directoryGroupsServerVersion = res.GetServerVersion() err := ptypes.UnmarshalAny(record.GetData(), &pbDirectoryGroup)
if err != nil {
return fmt.Errorf("error unmarshaling directory group: %w", err)
}
mgr.log.Info().Int("count", len(mgr.directoryGroups)).Msg("initialized directory groups") mgr.directoryGroups[pbDirectoryGroup.GetId()] = &pbDirectoryGroup
mgr.directoryGroupsRecordVersion = record.GetVersion()
}
mgr.directoryGroupsServerVersion = res.GetServerVersion()
return nil mgr.log.Info().Int("count", len(mgr.directoryGroups)).Msg("initialized directory groups")
return nil
})
} }
func (mgr *Manager) syncDirectoryGroups(ctx context.Context, ch chan<- *directory.Group) error { func (mgr *Manager) syncDirectoryGroups(ctx context.Context, ch chan<- *directory.Group) error {
@ -775,3 +779,22 @@ func isTemporaryError(err error) bool {
} }
return false return false
} }
// exponentialTry executes f until it succeeds or ctx is Done.
func exponentialTry(ctx context.Context, f func() error) error {
backoff := backoff.NewExponentialBackOff()
backoff.MaxElapsedTime = 0
for {
err := f()
if err == nil {
return nil
}
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(backoff.NextBackOff()):
}
continue
}
}

View file

@ -5,6 +5,7 @@ import (
"crypto/sha512" "crypto/sha512"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"google.golang.org/protobuf/proto"
) )
// Hash generates a hash of data using HMAC-SHA-512/256. The tag is intended to // Hash generates a hash of data using HMAC-SHA-512/256. The tag is intended to
@ -28,3 +29,14 @@ func HashPassword(password []byte) ([]byte, error) {
func CheckPasswordHash(hash, password []byte) error { func CheckPasswordHash(hash, password []byte) error {
return bcrypt.CompareHashAndPassword(hash, password) return bcrypt.CompareHashAndPassword(hash, password)
} }
// HashProto hashes a protobuf message. It sets `Deterministic` to true to ensure
// the encoded message is always the same. (ie map order is lexographic)
func HashProto(msg proto.Message) []byte {
opts := proto.MarshalOptions{
AllowPartial: true,
Deterministic: true,
}
bs, _ := opts.Marshal(msg)
return Hash("proto", bs)
}

View file

@ -8,6 +8,9 @@ import (
"io/ioutil" "io/ioutil"
"os" "os"
"testing" "testing"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/types/known/structpb"
) )
func TestPasswordHashing(t *testing.T) { func TestPasswordHashing(t *testing.T) {
@ -79,3 +82,25 @@ func ExampleHash() {
fmt.Println(hex.EncodeToString(digest)) fmt.Println(hex.EncodeToString(digest))
// Output: 9f4c795d8ae5c207f19184ccebee6a606c1fdfe509c793614066d613580f03e1 // Output: 9f4c795d8ae5c207f19184ccebee6a606c1fdfe509c793614066d613580f03e1
} }
func TestHashProto(t *testing.T) {
// This test will hash a protobuf message that has a map 1000 times
// each attempt should result in the same hash if the output is
// deterministic.
var cur []byte
for i := 0; i < 1000; i++ {
s, err := structpb.NewStruct(map[string]interface{}{
"1": "a", "2": "b", "3": "c", "4": "d",
"5": "e", "6": "f", "7": "g", "8": "h",
})
assert.NoError(t, err)
if i == 0 {
cur = HashProto(s)
} else {
nxt := HashProto(s)
if !assert.Equal(t, cur, nxt) {
return
}
}
}
}