Mirror of https://github.com/pomerium/pomerium.git, synced 2025-05-23 05:57:19 +02:00
use incremental API for envoy xDS (#1732)
* use incremental API
* add test
* use backoff v4
* remove panic, add comment to exponential try, add test for HashProto
* merge master
* fix missing import
This commit is contained in:
parent a07d85b174
commit 3524697f6f
12 changed files with 511 additions and 238 deletions
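For orientation, here is a condensed, illustrative sketch (not part of the commit) of the idea behind the change: instead of streaming whole state-of-the-world DiscoveryResponses keyed by a single config version, the incremental (delta) xDS API names and versions each resource individually, so only resources whose version changed are re-sent, and names the client still tracks but that no longer exist are listed in RemovedResources. The helper deltaFor and the example resource names below are invented for illustration; the real implementation is the new xdsmgr package further down.

package main

import (
	"fmt"

	envoy_service_discovery_v3 "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3"
)

// deltaFor returns only the resources whose version differs from what the
// client last acknowledged, plus the names the client still knows about that
// no longer exist on the server.
func deltaFor(
	typeURL, nonce string,
	current []*envoy_service_discovery_v3.Resource,
	clientVersions map[string]string,
) *envoy_service_discovery_v3.DeltaDiscoveryResponse {
	res := &envoy_service_discovery_v3.DeltaDiscoveryResponse{TypeUrl: typeURL, Nonce: nonce}
	seen := map[string]struct{}{}
	for _, r := range current {
		seen[r.Name] = struct{}{}
		if clientVersions[r.Name] != r.Version {
			res.Resources = append(res.Resources, r)
		}
	}
	for name := range clientVersions {
		if _, ok := seen[name]; !ok {
			res.RemovedResources = append(res.RemovedResources, name)
		}
	}
	return res
}

func main() {
	current := []*envoy_service_discovery_v3.Resource{
		{Name: "cluster-a", Version: "2"}, // changed since the client's "1"
		{Name: "cluster-b", Version: "1"}, // new to this client
	}
	clientVersions := map[string]string{"cluster-a": "1", "cluster-c": "1"}
	res := deltaFor("type.googleapis.com/envoy.config.cluster.v3.Cluster", "nonce-1", current, clientVersions)
	fmt.Println(len(res.Resources), res.RemovedResources) // 2 [cluster-c]
}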
@@ -1,136 +0,0 @@
package controlplane

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"

	envoy_service_discovery_v3 "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3"
	"golang.org/x/sync/errgroup"

	"github.com/pomerium/pomerium/internal/log"
)

func (srv *Server) registerXDSHandlers() {
	envoy_service_discovery_v3.RegisterAggregatedDiscoveryServiceServer(srv.GRPCServer, srv)
}

// StreamAggregatedResources streams xDS resources based on incoming discovery requests.
//
// This is setup as 3 concurrent goroutines:
// - The first retrieves the requests from the client.
// - The third sends responses back to the client.
// - The second waits for either the client to request a new resource type
//   or for the config to have been updated
//   - in either case, we loop over all of the current client versions
//     and if any of them are different from the current version, we send
//     the updated resource
func (srv *Server) StreamAggregatedResources(stream envoy_service_discovery_v3.AggregatedDiscoveryService_StreamAggregatedResourcesServer) error {
	incoming := make(chan *envoy_service_discovery_v3.DiscoveryRequest)
	outgoing := make(chan *envoy_service_discovery_v3.DiscoveryResponse)

	eg, ctx := errgroup.WithContext(stream.Context())
	// receive requests
	eg.Go(func() error {
		return srv.streamAggregatedResourcesIncomingStep(ctx, stream, incoming)
	})
	eg.Go(func() error {
		return srv.streamAggregatedResourcesProcessStep(ctx, incoming, outgoing)
	})
	// send responses
	eg.Go(func() error {
		return srv.streamAggregatedResourcesOutgoingStep(ctx, stream, outgoing)
	})
	return eg.Wait()
}

func (srv *Server) streamAggregatedResourcesIncomingStep(
	ctx context.Context,
	stream envoy_service_discovery_v3.AggregatedDiscoveryService_StreamAggregatedResourcesServer,
	incoming chan<- *envoy_service_discovery_v3.DiscoveryRequest,
) error {
	for {
		req, err := stream.Recv()
		if err != nil {
			return err
		}

		select {
		case incoming <- req:
		case <-ctx.Done():
			return ctx.Err()
		}
	}
}

func (srv *Server) streamAggregatedResourcesProcessStep(
	ctx context.Context,
	incoming <-chan *envoy_service_discovery_v3.DiscoveryRequest,
	outgoing chan<- *envoy_service_discovery_v3.DiscoveryResponse,
) error {
	versions := map[string]string{}

	for {
		select {
		case req := <-incoming:
			if req.ErrorDetail != nil {
				bs, _ := json.Marshal(req.ErrorDetail.Details)
				log.Error().
					Err(errors.New(req.ErrorDetail.Message)).
					Int32("code", req.ErrorDetail.Code).
					RawJSON("details", bs).Msg("error applying configuration")
				continue
			}

			// update the currently stored version
			// if this version is different from the current version
			// we will send the response below
			versions[req.TypeUrl] = req.VersionInfo
		case <-srv.configUpdated:
		case <-ctx.Done():
			return ctx.Err()
		}

		current := srv.currentConfig.Load()
		for typeURL, version := range versions {
			// the versions are different, so the envoy config needs to be updated
			if version != fmt.Sprint(current.version) {
				res, err := srv.buildDiscoveryResponse(fmt.Sprint(current.version), typeURL, &current.Options)
				if err != nil {
					return err
				}
				select {
				case outgoing <- res:
				case <-ctx.Done():
					return ctx.Err()
				}
			}
		}
	}
}

func (srv *Server) streamAggregatedResourcesOutgoingStep(
	ctx context.Context,
	stream envoy_service_discovery_v3.AggregatedDiscoveryService_StreamAggregatedResourcesServer,
	outgoing <-chan *envoy_service_discovery_v3.DiscoveryResponse,
) error {
	for {
		var res *envoy_service_discovery_v3.DiscoveryResponse
		select {
		case res = <-outgoing:
		case <-ctx.Done():
			return ctx.Err()
		}

		err := stream.Send(res)
		if err != nil {
			return err
		}
	}
}

// DeltaAggregatedResources is not implemented.
func (srv *Server) DeltaAggregatedResources(in envoy_service_discovery_v3.AggregatedDiscoveryService_DeltaAggregatedResourcesServer) error {
	return fmt.Errorf("method DeltaAggregatedResources not implemented")
}
@@ -7,6 +7,7 @@ import (
	"sync/atomic"
	"time"

	envoy_service_discovery_v3 "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3"
	"github.com/gorilla/mux"
	"golang.org/x/sync/errgroup"
	"google.golang.org/grpc"

@@ -14,6 +15,7 @@ import (
	"google.golang.org/grpc/reflection"

	"github.com/pomerium/pomerium/config"
	"github.com/pomerium/pomerium/internal/controlplane/xdsmgr"
	"github.com/pomerium/pomerium/internal/log"
	"github.com/pomerium/pomerium/internal/telemetry"
	"github.com/pomerium/pomerium/internal/telemetry/requestid"

@@ -46,15 +48,13 @@ type Server struct {
	HTTPRouter *mux.Router

	currentConfig atomicVersionedOptions
	configUpdated chan struct{}
	name          string
	xdsmgr        *xdsmgr.Manager
}

// NewServer creates a new Server. Listener ports are chosen by the OS.
func NewServer(name string) (*Server, error) {
	srv := &Server{
		configUpdated: make(chan struct{}, 1),
	}
	srv := &Server{}
	srv.currentConfig.Store(versionedOptions{})

	var err error

@@ -73,7 +73,6 @@ func NewServer(name string) (*Server, error) {
		grpc.ChainStreamInterceptor(requestid.StreamServerInterceptor(), si),
	)
	reflection.Register(srv.GRPCServer)
	srv.registerXDSHandlers()
	srv.registerAccessLogHandlers()

	// setup HTTP

@@ -85,6 +84,9 @@ func NewServer(name string) (*Server, error) {
	srv.HTTPRouter = mux.NewRouter()
	srv.addHTTPMiddleware()

	srv.xdsmgr = xdsmgr.NewManager(srv.buildDiscoveryResources())
	envoy_service_discovery_v3.RegisterAggregatedDiscoveryServiceServer(srv.GRPCServer, srv.xdsmgr)

	return srv, nil
}

@@ -150,14 +152,10 @@ func (srv *Server) Run(ctx context.Context) error {

// OnConfigChange updates the pomerium config options.
func (srv *Server) OnConfigChange(cfg *config.Config) {
	select {
	case <-srv.configUpdated:
	default:
	}
	prev := srv.currentConfig.Load()
	srv.currentConfig.Store(versionedOptions{
		Options: *cfg.Options,
		version: prev.version + 1,
	})
	srv.configUpdated <- struct{}{}
	srv.xdsmgr.Update(srv.buildDiscoveryResources())
}
@@ -4,6 +4,7 @@ import (
	"bytes"
	"crypto/tls"
	"crypto/x509"
	"encoding/hex"
	"encoding/pem"
	"fmt"
	"io/ioutil"

@@ -19,51 +20,40 @@ import (
	envoy_extensions_access_loggers_grpc_v3 "github.com/envoyproxy/go-control-plane/envoy/extensions/access_loggers/grpc/v3"
	envoy_extensions_transport_sockets_tls_v3 "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3"
	envoy_service_discovery_v3 "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3"
	"github.com/golang/protobuf/ptypes"
	"github.com/golang/protobuf/ptypes/any"
	"golang.org/x/net/nettest"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/types/known/anypb"

	"github.com/pomerium/pomerium/config"
	"github.com/pomerium/pomerium/internal/log"
	"github.com/pomerium/pomerium/pkg/cryptutil"
)

func (srv *Server) buildDiscoveryResponse(version string, typeURL string, options *config.Options) (*envoy_service_discovery_v3.DiscoveryResponse, error) {
	switch typeURL {
	case "type.googleapis.com/envoy.config.listener.v3.Listener":
		listeners := buildListeners(options)
		anys := make([]*any.Any, len(listeners))
		for i, listener := range listeners {
			a, err := ptypes.MarshalAny(listener)
			if err != nil {
				return nil, status.Errorf(codes.Internal, "error marshaling type to any: %v", err)
			}
			anys[i] = a
		}
		return &envoy_service_discovery_v3.DiscoveryResponse{
			VersionInfo: version,
			Resources:   anys,
			TypeUrl:     typeURL,
		}, nil
	case "type.googleapis.com/envoy.config.cluster.v3.Cluster":
		clusters := srv.buildClusters(options)
		anys := make([]*any.Any, len(clusters))
		for i, cluster := range clusters {
			a, err := ptypes.MarshalAny(cluster)
			if err != nil {
				return nil, status.Errorf(codes.Internal, "error marshaling type to any: %v", err)
			}
			anys[i] = a
		}
		return &envoy_service_discovery_v3.DiscoveryResponse{
			VersionInfo: version,
			Resources:   anys,
			TypeUrl:     typeURL,
		}, nil
	default:
		return nil, status.Errorf(codes.Internal, "received request for unknown discovery request type: %s", typeURL)

const (
	clusterTypeURL  = "type.googleapis.com/envoy.config.cluster.v3.Cluster"
	listenerTypeURL = "type.googleapis.com/envoy.config.listener.v3.Listener"
)

func (srv *Server) buildDiscoveryResources() map[string][]*envoy_service_discovery_v3.Resource {
	resources := map[string][]*envoy_service_discovery_v3.Resource{}
	cfg := srv.currentConfig.Load()
	for _, cluster := range srv.buildClusters(&cfg.Options) {
		any, _ := anypb.New(cluster)
		resources[clusterTypeURL] = append(resources[clusterTypeURL], &envoy_service_discovery_v3.Resource{
			Name:     cluster.Name,
			Version:  hex.EncodeToString(cryptutil.HashProto(cluster)),
			Resource: any,
		})
	}
	for _, listener := range buildListeners(&cfg.Options) {
		any, _ := anypb.New(listener)
		resources[listenerTypeURL] = append(resources[listenerTypeURL], &envoy_service_discovery_v3.Resource{
			Name:     listener.Name,
			Version:  hex.EncodeToString(cryptutil.HashProto(listener)),
			Resource: any,
		})
	}
	return resources
}

func buildAccessLogs(options *config.Options) []*envoy_config_accesslog_v3.AccessLog {

@@ -82,7 +72,7 @@ func buildAccessLogs(options *config.Options) []*envoy_config_accesslog_v3.Acces
		return nil
	}

	tc, _ := ptypes.MarshalAny(&envoy_extensions_access_loggers_grpc_v3.HttpGrpcAccessLogConfig{
	tc := marshalAny(&envoy_extensions_access_loggers_grpc_v3.HttpGrpcAccessLogConfig{
		CommonConfig: &envoy_extensions_access_loggers_grpc_v3.CommonGrpcAccessLogConfig{
			LogName: "ingress-http",
			GrpcService: &envoy_config_core_v3.GrpcService{

@@ -235,3 +225,12 @@ func getRootCertificateAuthority() (string, error) {
	}
	return rootCABundle.value, nil
}

func marshalAny(msg proto.Message) *anypb.Any {
	any := new(anypb.Any)
	_ = anypb.MarshalFrom(any, msg, proto.MarshalOptions{
		AllowPartial:  true,
		Deterministic: true,
	})
	return any
}
@@ -104,7 +104,7 @@ func buildInternalTransportSocket(options *config.Options, endpoint *url.URL) *e
		},
		Sni: sni,
	}
	tlsConfig, _ := ptypes.MarshalAny(tlsContext)
	tlsConfig := marshalAny(tlsContext)
	return &envoy_config_core_v3.TransportSocket{
		Name: "tls",
		ConfigType: &envoy_config_core_v3.TransportSocket_TypedConfig{

@@ -144,7 +144,7 @@ func buildPolicyTransportSocket(policy *config.Policy) *envoy_config_core_v3.Tra
			envoyTLSCertificateFromGoTLSCertificate(policy.ClientCertificate))
	}

	tlsConfig, _ := ptypes.MarshalAny(tlsContext)
	tlsConfig := marshalAny(tlsContext)
	return &envoy_config_core_v3.TransportSocket{
		Name: "tls",
		ConfigType: &envoy_config_core_v3.TransportSocket_TypedConfig{
@@ -30,7 +30,7 @@ import (
var disableExtAuthz *any.Any

func init() {
	disableExtAuthz, _ = ptypes.MarshalAny(&envoy_extensions_filters_http_ext_authz_v3.ExtAuthzPerRoute{
	disableExtAuthz = marshalAny(&envoy_extensions_filters_http_ext_authz_v3.ExtAuthzPerRoute{
		Override: &envoy_extensions_filters_http_ext_authz_v3.ExtAuthzPerRoute_Disabled{
			Disabled: true,
		},

@@ -67,7 +67,7 @@ func buildMainListener(options *config.Options) *envoy_config_listener_v3.Listen
		}
	}

	tlsInspectorCfg, _ := ptypes.MarshalAny(new(emptypb.Empty))
	tlsInspectorCfg := marshalAny(new(emptypb.Empty))
	li := &envoy_config_listener_v3.Listener{
		Name:    "https-ingress",
		Address: buildAddress(options.Addr, 443),

@@ -90,7 +90,7 @@ func buildMainListener(options *config.Options) *envoy_config_listener_v3.Listen
	}
	tlsContext := buildDownstreamTLSContext(options, tlsDomain)
	if tlsContext != nil {
		tlsConfig, _ := ptypes.MarshalAny(tlsContext)
		tlsConfig := marshalAny(tlsContext)
		filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{
			Name: "tls",
			ConfigType: &envoy_config_core_v3.TransportSocket_TypedConfig{

@@ -161,7 +161,7 @@ func buildMainHTTPConnectionManagerFilter(options *config.Options, domains []str
		grpcClientTimeout = ptypes.DurationProto(30 * time.Second)
	}

	extAuthZ, _ := ptypes.MarshalAny(&envoy_extensions_filters_http_ext_authz_v3.ExtAuthz{
	extAuthZ := marshalAny(&envoy_extensions_filters_http_ext_authz_v3.ExtAuthz{
		StatusOnError: &envoy_type_v3.HttpStatus{
			Code: envoy_type_v3.StatusCode_InternalServerError,
		},

@@ -178,13 +178,13 @@ func buildMainHTTPConnectionManagerFilter(options *config.Options, domains []str
		IncludePeerCertificate: true,
	})

	extAuthzSetCookieLua, _ := ptypes.MarshalAny(&envoy_extensions_filters_http_lua_v3.Lua{
	extAuthzSetCookieLua := marshalAny(&envoy_extensions_filters_http_lua_v3.Lua{
		InlineCode: luascripts.ExtAuthzSetCookie,
	})
	cleanUpstreamLua, _ := ptypes.MarshalAny(&envoy_extensions_filters_http_lua_v3.Lua{
	cleanUpstreamLua := marshalAny(&envoy_extensions_filters_http_lua_v3.Lua{
		InlineCode: luascripts.CleanUpstream,
	})
	removeImpersonateHeadersLua, _ := ptypes.MarshalAny(&envoy_extensions_filters_http_lua_v3.Lua{
	removeImpersonateHeadersLua := marshalAny(&envoy_extensions_filters_http_lua_v3.Lua{
		InlineCode: luascripts.RemoveImpersonateHeaders,
	})

@@ -193,7 +193,7 @@ func buildMainHTTPConnectionManagerFilter(options *config.Options, domains []str
		maxStreamDuration = ptypes.DurationProto(options.WriteTimeout)
	}

	tc, _ := ptypes.MarshalAny(&envoy_http_connection_manager.HttpConnectionManager{
	tc := marshalAny(&envoy_http_connection_manager.HttpConnectionManager{
		CodecType:  envoy_http_connection_manager.HttpConnectionManager_AUTO,
		StatPrefix: "ingress",
		RouteSpecifier: &envoy_http_connection_manager.HttpConnectionManager_RouteConfig{

@@ -265,7 +265,7 @@ func buildGRPCListener(options *config.Options) *envoy_config_listener_v3.Listen
		}
	}

	tlsInspectorCfg, _ := ptypes.MarshalAny(new(emptypb.Empty))
	tlsInspectorCfg := marshalAny(new(emptypb.Empty))
	li := &envoy_config_listener_v3.Listener{
		Name:    "grpc-ingress",
		Address: buildAddress(options.GRPCAddr, 443),

@@ -287,7 +287,7 @@ func buildGRPCListener(options *config.Options) *envoy_config_listener_v3.Listen
	}
	tlsContext := buildDownstreamTLSContext(options, tlsDomain)
	if tlsContext != nil {
		tlsConfig, _ := ptypes.MarshalAny(tlsContext)
		tlsConfig := marshalAny(tlsContext)
		filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{
			Name: "tls",
			ConfigType: &envoy_config_core_v3.TransportSocket_TypedConfig{

@@ -302,7 +302,7 @@ func buildGRPCListener(options *config.Options) *envoy_config_listener_v3.Listen
}

func buildGRPCHTTPConnectionManagerFilter() *envoy_config_listener_v3.Filter {
	tc, _ := ptypes.MarshalAny(&envoy_http_connection_manager.HttpConnectionManager{
	tc := marshalAny(&envoy_http_connection_manager.HttpConnectionManager{
		CodecType:  envoy_http_connection_manager.HttpConnectionManager_AUTO,
		StatPrefix: "grpc_ingress",
		// limit request first byte to last byte time
@@ -3,6 +3,7 @@ package controlplane
import (
	"fmt"
	"net/url"
	"sort"

	envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
	envoy_config_route_v3 "github.com/envoyproxy/go-control-plane/envoy/config/route/v3"

@@ -271,9 +272,15 @@ func mkEnvoyHeader(k, v string) *envoy_config_core_v3.HeaderValueOption {
}

func toEnvoyHeaders(headers map[string]string) []*envoy_config_core_v3.HeaderValueOption {
	var ks []string
	for k := range headers {
		ks = append(ks, k)
	}
	sort.Strings(ks)

	envoyHeaders := make([]*envoy_config_core_v3.HeaderValueOption, 0, len(headers))
	for k, v := range headers {
		envoyHeaders = append(envoyHeaders, mkEnvoyHeader(k, v))
	for _, k := range ks {
		envoyHeaders = append(envoyHeaders, mkEnvoyHeader(k, headers[k]))
	}
	return envoyHeaders
}
internal/controlplane/xdsmgr/xdsmgr.go (new file, 229 lines)
@@ -0,0 +1,229 @@
// Package xdsmgr implements a resource discovery manager for envoy.
package xdsmgr

import (
	"encoding/json"
	"errors"
	"sync"

	envoy_service_discovery_v3 "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3"
	"github.com/google/uuid"
	"golang.org/x/sync/errgroup"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"

	"github.com/pomerium/pomerium/internal/log"
	"github.com/pomerium/pomerium/internal/signal"
)

type streamState struct {
	typeURL                string
	clientResourceVersions map[string]string
	unsubscribedResources  map[string]struct{}
}

var onHandleDeltaRequest = func(state *streamState) {}

// A Manager manages xDS resources.
type Manager struct {
	signal *signal.Signal

	mu        sync.Mutex
	nonce     string
	resources map[string][]*envoy_service_discovery_v3.Resource
}

// NewManager creates a new Manager.
func NewManager(resources map[string][]*envoy_service_discovery_v3.Resource) *Manager {
	return &Manager{
		signal:    signal.New(),
		nonce:     uuid.New().String(),
		resources: resources,
	}
}

// DeltaAggregatedResources implements the increment xDS server.
func (mgr *Manager) DeltaAggregatedResources(
	stream envoy_service_discovery_v3.AggregatedDiscoveryService_DeltaAggregatedResourcesServer,
) error {
	ch := mgr.signal.Bind()
	defer mgr.signal.Unbind(ch)

	stateByTypeURL := map[string]*streamState{}

	getDeltaResponse := func(typeURL string) *envoy_service_discovery_v3.DeltaDiscoveryResponse {
		mgr.mu.Lock()
		defer mgr.mu.Unlock()

		state, ok := stateByTypeURL[typeURL]
		if !ok {
			return nil
		}

		res := &envoy_service_discovery_v3.DeltaDiscoveryResponse{
			TypeUrl: typeURL,
			Nonce:   mgr.nonce,
		}
		seen := map[string]struct{}{}
		for _, resource := range mgr.resources[typeURL] {
			seen[resource.Name] = struct{}{}
			if resource.Version != state.clientResourceVersions[resource.Name] {
				res.Resources = append(res.Resources, resource)
			}
		}
		for name := range state.clientResourceVersions {
			_, ok := seen[name]
			if !ok {
				res.RemovedResources = append(res.RemovedResources, name)
			}
		}

		if len(res.Resources) == 0 && len(res.RemovedResources) == 0 {
			return nil
		}

		return res
	}

	handleDeltaRequest := func(req *envoy_service_discovery_v3.DeltaDiscoveryRequest) {
		mgr.mu.Lock()
		defer mgr.mu.Unlock()

		state, ok := stateByTypeURL[req.GetTypeUrl()]
		if !ok {
			// first time we've seen a message for this type URL.
			state = &streamState{
				typeURL:                req.GetTypeUrl(),
				clientResourceVersions: req.GetInitialResourceVersions(),
				unsubscribedResources:  make(map[string]struct{}),
			}
			if state.clientResourceVersions == nil {
				state.clientResourceVersions = make(map[string]string)
			}
			stateByTypeURL[req.GetTypeUrl()] = state
		}

		switch {
		case req.GetResponseNonce() == "":
			// neither an ACK or a NACK
		case req.GetErrorDetail() != nil:
			// a NACK
			bs, _ := json.Marshal(req.ErrorDetail.Details)
			log.Error().
				Err(errors.New(req.ErrorDetail.Message)).
				Int32("code", req.ErrorDetail.Code).
				RawJSON("details", bs).Msg("error applying configuration")
		case req.GetResponseNonce() == mgr.nonce:
			// an ACK for the last response
			// - set the client resource versions to the current resource versions
			state.clientResourceVersions = make(map[string]string)
			for _, resource := range mgr.resources[req.GetTypeUrl()] {
				state.clientResourceVersions[resource.Name] = resource.Version
			}
		default:
			// an ACK for a response that's not the last response
		}

		// update subscriptions
		for _, name := range req.GetResourceNamesSubscribe() {
			delete(state.unsubscribedResources, name)
		}
		for _, name := range req.GetResourceNamesUnsubscribe() {
			state.unsubscribedResources[name] = struct{}{}
			// from the docs:
			//   NOTE: the server must respond with all resources listed in
			//   resource_names_subscribe, even if it believes the client has
			//   the most recent version of them. The reason: the client may
			//   have dropped them, but then regained interest before it had
			//   a chance to send the unsubscribe message.
			// so we reset the version to treat it like a new version
			delete(state.clientResourceVersions, name)
		}

		onHandleDeltaRequest(state)
	}

	incoming := make(chan *envoy_service_discovery_v3.DeltaDiscoveryRequest)
	outgoing := make(chan *envoy_service_discovery_v3.DeltaDiscoveryResponse)
	eg, ctx := errgroup.WithContext(stream.Context())
	// 1. receive all incoming messages
	eg.Go(func() error {
		for {
			req, err := stream.Recv()
			if err != nil {
				return err
			}

			select {
			case <-ctx.Done():
				return ctx.Err()
			case incoming <- req:
			}
		}
	})
	// 2. handle incoming requests or resource changes
	eg.Go(func() error {
		for {
			var typeURLs []string
			select {
			case <-ctx.Done():
				return ctx.Err()
			case req := <-incoming:
				handleDeltaRequest(req)
				typeURLs = []string{req.GetTypeUrl()}
			case <-ch:
				mgr.mu.Lock()
				for typeURL := range mgr.resources {
					typeURLs = append(typeURLs, typeURL)
				}
				mgr.mu.Unlock()
			}

			for _, typeURL := range typeURLs {
				res := getDeltaResponse(typeURL)
				if res == nil {
					continue
				}

				select {
				case <-ctx.Done():
					return ctx.Err()
				case outgoing <- res:
				}
			}
		}
	})
	// 3. send all outgoing messages
	eg.Go(func() error {
		for {
			select {
			case <-ctx.Done():
				return ctx.Err()
			case res := <-outgoing:
				err := stream.Send(res)
				if err != nil {
					return err
				}
			}
		}
	})
	return eg.Wait()
}

// StreamAggregatedResources is not implemented.
func (mgr *Manager) StreamAggregatedResources(
	stream envoy_service_discovery_v3.AggregatedDiscoveryService_StreamAggregatedResourcesServer,
) error {
	return status.Errorf(codes.Unimplemented, "method StreamAggregatedResources not implemented")
}

// Update updates the state of resources. If any changes are made they will be pushed to any listening
// streams. For each TypeURL the list of resources should be the complete list of resources.
func (mgr *Manager) Update(resources map[string][]*envoy_service_discovery_v3.Resource) {
	mgr.mu.Lock()
	mgr.nonce = uuid.New().String()
	mgr.resources = resources
	mgr.mu.Unlock()

	mgr.signal.Broadcast()
}
internal/controlplane/xdsmgr/xdsmgr_test.go (new file, 116 lines)
@@ -0,0 +1,116 @@
package xdsmgr

import (
	"context"
	"net"
	"testing"
	"time"

	envoy_service_discovery_v3 "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3"
	"github.com/stretchr/testify/assert"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/test/bufconn"

	"github.com/pomerium/pomerium/internal/signal"
)

const bufSize = 1024 * 1024

func TestManager(t *testing.T) {
	ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
	defer clearTimeout()

	typeURL := "example.com/example"

	stateChanged := signal.New()
	origOnHandleDeltaRequest := onHandleDeltaRequest
	defer func() { onHandleDeltaRequest = origOnHandleDeltaRequest }()
	onHandleDeltaRequest = func(state *streamState) {
		stateChanged.Broadcast()
	}

	srv := grpc.NewServer()
	mgr := NewManager(map[string][]*envoy_service_discovery_v3.Resource{
		typeURL: {
			{Name: "r1", Version: "1"},
		},
	})
	envoy_service_discovery_v3.RegisterAggregatedDiscoveryServiceServer(srv, mgr)

	li := bufconn.Listen(bufSize)
	go func() { _ = srv.Serve(li) }()

	cc, err := grpc.Dial("test",
		grpc.WithInsecure(),
		grpc.WithContextDialer(func(ctx context.Context, target string) (net.Conn, error) {
			return li.Dial()
		}))
	if !assert.NoError(t, err) {
		return
	}
	defer func() { _ = cc.Close() }()

	client := envoy_service_discovery_v3.NewAggregatedDiscoveryServiceClient(cc)
	t.Run("stream is disabled", func(t *testing.T) {
		stream, err := client.StreamAggregatedResources(ctx)
		if !assert.NoError(t, err) {
			return
		}
		_, err = stream.Recv()
		assert.Error(t, err, "only delta should be implemented")
		assert.Equal(t, codes.Unimplemented, grpc.Code(err))
	})

	t.Run("updates", func(t *testing.T) {
		stream, err := client.DeltaAggregatedResources(ctx)
		if !assert.NoError(t, err) {
			return
		}

		ch := stateChanged.Bind()
		defer stateChanged.Unbind(ch)
		ack := func(nonce string) {
			err = stream.Send(&envoy_service_discovery_v3.DeltaDiscoveryRequest{
				TypeUrl:       typeURL,
				ResponseNonce: nonce,
			})
			assert.NoError(t, err)
			select {
			case <-ctx.Done():
				t.Fatal(ctx.Err())
			case <-ch:
			}
		}

		ack("")

		msg, err := stream.Recv()
		assert.NoError(t, err)
		assert.NotEmpty(t, msg.GetNonce(), "nonce should not be empty")
		assert.Equal(t, []*envoy_service_discovery_v3.Resource{
			{Name: "r1", Version: "1"},
		}, msg.GetResources())
		ack(msg.Nonce)

		mgr.Update(map[string][]*envoy_service_discovery_v3.Resource{
			typeURL: {{Name: "r1", Version: "2"}},
		})

		msg, err = stream.Recv()
		assert.NoError(t, err)
		assert.Equal(t, []*envoy_service_discovery_v3.Resource{
			{Name: "r1", Version: "2"},
		}, msg.GetResources())
		ack(msg.Nonce)

		mgr.Update(map[string][]*envoy_service_discovery_v3.Resource{
			typeURL: nil,
		})

		msg, err = stream.Recv()
		assert.NoError(t, err)
		assert.Equal(t, []string{"r1"}, msg.GetRemovedResources())
		ack(msg.Nonce)
	})
}
@@ -229,7 +229,7 @@ func (srv *Server) buildBootstrapConfig() ([]byte, error) {

	dynamicCfg := &envoy_config_bootstrap_v3.Bootstrap_DynamicResources{
		AdsConfig: &envoy_config_core_v3.ApiConfigSource{
			ApiType:             envoy_config_core_v3.ApiConfigSource_ApiType(envoy_config_core_v3.ApiConfigSource_ApiType_value["GRPC"]),
			ApiType:             envoy_config_core_v3.ApiConfigSource_ApiType(envoy_config_core_v3.ApiConfigSource_ApiType_value["DELTA_GRPC"]),
			TransportApiVersion: envoy_config_core_v3.ApiVersion_V3,
			GrpcServices: []*envoy_config_core_v3.GrpcService{
				{
@@ -7,6 +7,7 @@ import (
	"fmt"
	"time"

	"github.com/cenkalti/backoff/v4"
	"github.com/golang/protobuf/proto"
	"github.com/golang/protobuf/ptypes"
	"github.com/google/btree"

@@ -578,29 +579,31 @@ func (mgr *Manager) initDirectoryUsers(ctx context.Context) error {
		return err
	}

	res, err := databroker.InitialSync(ctx, mgr.cfg.Load().dataBrokerClient, &databroker.SyncRequest{
		Type: any.GetTypeUrl(),
	})
	if err != nil {
		return fmt.Errorf("error getting all directory users: %w", err)
	}

	mgr.directoryUsers = map[string]*directory.User{}
	for _, record := range res.GetRecords() {
		var pbDirectoryUser directory.User
		err := ptypes.UnmarshalAny(record.GetData(), &pbDirectoryUser)
	return exponentialTry(ctx, func() error {
		res, err := databroker.InitialSync(ctx, mgr.cfg.Load().dataBrokerClient, &databroker.SyncRequest{
			Type: any.GetTypeUrl(),
		})
		if err != nil {
			return fmt.Errorf("error unmarshaling directory user: %w", err)
			return fmt.Errorf("error getting all directory users: %w", err)
		}

		mgr.directoryUsers[pbDirectoryUser.GetId()] = &pbDirectoryUser
		mgr.directoryUsersRecordVersion = record.GetVersion()
	}
	mgr.directoryUsersServerVersion = res.GetServerVersion()
		mgr.directoryUsers = map[string]*directory.User{}
		for _, record := range res.GetRecords() {
			var pbDirectoryUser directory.User
			err := ptypes.UnmarshalAny(record.GetData(), &pbDirectoryUser)
			if err != nil {
				return fmt.Errorf("error unmarshaling directory user: %w", err)
			}

	mgr.log.Info().Int("count", len(mgr.directoryUsers)).Msg("initialized directory users")
			mgr.directoryUsers[pbDirectoryUser.GetId()] = &pbDirectoryUser
			mgr.directoryUsersRecordVersion = record.GetVersion()
		}
		mgr.directoryUsersServerVersion = res.GetServerVersion()

	return nil
		mgr.log.Info().Int("count", len(mgr.directoryUsers)).Msg("initialized directory users")

		return nil
	})
}

func (mgr *Manager) syncDirectoryUsers(ctx context.Context, ch chan<- *directory.User) error {

@@ -648,30 +651,31 @@ func (mgr *Manager) initDirectoryGroups(ctx context.Context) error {
	if err != nil {
		return err
	}

	res, err := databroker.InitialSync(ctx, mgr.cfg.Load().dataBrokerClient, &databroker.SyncRequest{
		Type: any.GetTypeUrl(),
	})
	if err != nil {
		return fmt.Errorf("error getting all directory groups: %w", err)
	}

	mgr.directoryGroups = map[string]*directory.Group{}
	for _, record := range res.GetRecords() {
		var pbDirectoryGroup directory.Group
		err := ptypes.UnmarshalAny(record.GetData(), &pbDirectoryGroup)
	return exponentialTry(ctx, func() error {
		res, err := databroker.InitialSync(ctx, mgr.cfg.Load().dataBrokerClient, &databroker.SyncRequest{
			Type: any.GetTypeUrl(),
		})
		if err != nil {
			return fmt.Errorf("error unmarshaling directory group: %w", err)
			return fmt.Errorf("error getting all directory groups: %w", err)
		}

		mgr.directoryGroups[pbDirectoryGroup.GetId()] = &pbDirectoryGroup
		mgr.directoryGroupsRecordVersion = record.GetVersion()
	}
	mgr.directoryGroupsServerVersion = res.GetServerVersion()
		mgr.directoryGroups = map[string]*directory.Group{}
		for _, record := range res.GetRecords() {
			var pbDirectoryGroup directory.Group
			err := ptypes.UnmarshalAny(record.GetData(), &pbDirectoryGroup)
			if err != nil {
				return fmt.Errorf("error unmarshaling directory group: %w", err)
			}

	mgr.log.Info().Int("count", len(mgr.directoryGroups)).Msg("initialized directory groups")
			mgr.directoryGroups[pbDirectoryGroup.GetId()] = &pbDirectoryGroup
			mgr.directoryGroupsRecordVersion = record.GetVersion()
		}
		mgr.directoryGroupsServerVersion = res.GetServerVersion()

	return nil
		mgr.log.Info().Int("count", len(mgr.directoryGroups)).Msg("initialized directory groups")

		return nil
	})
}

func (mgr *Manager) syncDirectoryGroups(ctx context.Context, ch chan<- *directory.Group) error {

@@ -775,3 +779,22 @@ func isTemporaryError(err error) bool {
	}
	return false
}

// exponentialTry executes f until it succeeds or ctx is Done.
func exponentialTry(ctx context.Context, f func() error) error {
	backoff := backoff.NewExponentialBackOff()
	backoff.MaxElapsedTime = 0
	for {
		err := f()
		if err == nil {
			return nil
		}

		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(backoff.NextBackOff()):
		}
		continue
	}
}
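The new exponentialTry helper above retries the initial databroker sync until it succeeds or the context is canceled, sleeping for an exponentially growing interval between attempts (setting MaxElapsedTime to 0 disables backoff/v4's overall deadline, so only the context stops the loop). Below is a standalone sketch of the same pattern; retryForever and the fake flaky operation are made up for illustration and are not part of the commit.

package main

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/cenkalti/backoff/v4"
)

// retryForever keeps calling op until it succeeds or ctx is done, waiting an
// exponentially increasing amount of time between attempts.
func retryForever(ctx context.Context, op func() error) error {
	bo := backoff.NewExponentialBackOff()
	bo.MaxElapsedTime = 0 // never give up on our own; only ctx stops the loop
	for {
		if err := op(); err == nil {
			return nil
		}
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(bo.NextBackOff()):
		}
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	attempts := 0
	err := retryForever(ctx, func() error {
		attempts++
		if attempts < 3 {
			return errors.New("transient failure")
		}
		return nil
	})
	fmt.Println(attempts, err) // 3 <nil>
}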
@@ -5,6 +5,7 @@ import (
	"crypto/sha512"

	"golang.org/x/crypto/bcrypt"
	"google.golang.org/protobuf/proto"
)

// Hash generates a hash of data using HMAC-SHA-512/256. The tag is intended to

@@ -28,3 +29,14 @@ func HashPassword(password []byte) ([]byte, error) {
func CheckPasswordHash(hash, password []byte) error {
	return bcrypt.CompareHashAndPassword(hash, password)
}

// HashProto hashes a protobuf message. It sets `Deterministic` to true to ensure
// the encoded message is always the same. (ie map order is lexographic)
func HashProto(msg proto.Message) []byte {
	opts := proto.MarshalOptions{
		AllowPartial:  true,
		Deterministic: true,
	}
	bs, _ := opts.Marshal(msg)
	return Hash("proto", bs)
}
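HashProto is what gives each xDS resource its version string earlier in the diff (hex.EncodeToString(cryptutil.HashProto(cluster))): because marshaling is deterministic, building the same message twice yields the same bytes and therefore the same version, so the delta server can tell "unchanged" from "changed" by comparing strings. A small sketch, assuming a pomerium module version that already contains the new helper:

package main

import (
	"encoding/hex"
	"fmt"

	"google.golang.org/protobuf/types/known/structpb"

	"github.com/pomerium/pomerium/pkg/cryptutil"
)

func main() {
	// Build the same message twice; map fields make the encoding order
	// non-deterministic unless deterministic marshaling is used.
	a, _ := structpb.NewStruct(map[string]interface{}{"1": "a", "2": "b", "3": "c"})
	b, _ := structpb.NewStruct(map[string]interface{}{"1": "a", "2": "b", "3": "c"})

	va := hex.EncodeToString(cryptutil.HashProto(a))
	vb := hex.EncodeToString(cryptutil.HashProto(b))
	fmt.Println(va == vb) // true: identical messages produce identical versions
}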
@@ -8,6 +8,9 @@ import (
	"io/ioutil"
	"os"
	"testing"

	"github.com/stretchr/testify/assert"
	"google.golang.org/protobuf/types/known/structpb"
)

func TestPasswordHashing(t *testing.T) {

@@ -79,3 +82,25 @@ func ExampleHash() {
	fmt.Println(hex.EncodeToString(digest))
	// Output: 9f4c795d8ae5c207f19184ccebee6a606c1fdfe509c793614066d613580f03e1
}

func TestHashProto(t *testing.T) {
	// This test will hash a protobuf message that has a map 1000 times
	// each attempt should result in the same hash if the output is
	// deterministic.
	var cur []byte
	for i := 0; i < 1000; i++ {
		s, err := structpb.NewStruct(map[string]interface{}{
			"1": "a", "2": "b", "3": "c", "4": "d",
			"5": "e", "6": "f", "7": "g", "8": "h",
		})
		assert.NoError(t, err)
		if i == 0 {
			cur = HashProto(s)
		} else {
			nxt := HashProto(s)
			if !assert.Equal(t, cur, nxt) {
				return
			}
		}
	}
}