diff --git a/internal/controlplane/grpc_xds.go b/internal/controlplane/grpc_xds.go deleted file mode 100644 index d404f3e53..000000000 --- a/internal/controlplane/grpc_xds.go +++ /dev/null @@ -1,136 +0,0 @@ -package controlplane - -import ( - "context" - "encoding/json" - "errors" - "fmt" - - envoy_service_discovery_v3 "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" - "golang.org/x/sync/errgroup" - - "github.com/pomerium/pomerium/internal/log" -) - -func (srv *Server) registerXDSHandlers() { - envoy_service_discovery_v3.RegisterAggregatedDiscoveryServiceServer(srv.GRPCServer, srv) -} - -// StreamAggregatedResources streams xDS resources based on incoming discovery requests. -// -// This is setup as 3 concurrent goroutines: -// - The first retrieves the requests from the client. -// - The third sends responses back to the client. -// - The second waits for either the client to request a new resource type -// or for the config to have been updated -// - in either case, we loop over all of the current client versions -// and if any of them are different from the current version, we send -// the updated resource -func (srv *Server) StreamAggregatedResources(stream envoy_service_discovery_v3.AggregatedDiscoveryService_StreamAggregatedResourcesServer) error { - incoming := make(chan *envoy_service_discovery_v3.DiscoveryRequest) - outgoing := make(chan *envoy_service_discovery_v3.DiscoveryResponse) - - eg, ctx := errgroup.WithContext(stream.Context()) - // receive requests - eg.Go(func() error { - return srv.streamAggregatedResourcesIncomingStep(ctx, stream, incoming) - }) - eg.Go(func() error { - return srv.streamAggregatedResourcesProcessStep(ctx, incoming, outgoing) - }) - // send responses - eg.Go(func() error { - return srv.streamAggregatedResourcesOutgoingStep(ctx, stream, outgoing) - }) - return eg.Wait() -} - -func (srv *Server) streamAggregatedResourcesIncomingStep( - ctx context.Context, - stream envoy_service_discovery_v3.AggregatedDiscoveryService_StreamAggregatedResourcesServer, - incoming chan<- *envoy_service_discovery_v3.DiscoveryRequest, -) error { - for { - req, err := stream.Recv() - if err != nil { - return err - } - - select { - case incoming <- req: - case <-ctx.Done(): - return ctx.Err() - } - } -} - -func (srv *Server) streamAggregatedResourcesProcessStep( - ctx context.Context, - incoming <-chan *envoy_service_discovery_v3.DiscoveryRequest, - outgoing chan<- *envoy_service_discovery_v3.DiscoveryResponse, -) error { - versions := map[string]string{} - - for { - select { - case req := <-incoming: - if req.ErrorDetail != nil { - bs, _ := json.Marshal(req.ErrorDetail.Details) - log.Error(). - Err(errors.New(req.ErrorDetail.Message)). - Int32("code", req.ErrorDetail.Code). - RawJSON("details", bs).Msg("error applying configuration") - continue - } - - // update the currently stored version - // if this version is different from the current version - // we will send the response below - versions[req.TypeUrl] = req.VersionInfo - case <-srv.configUpdated: - case <-ctx.Done(): - return ctx.Err() - } - - current := srv.currentConfig.Load() - for typeURL, version := range versions { - // the versions are different, so the envoy config needs to be updated - if version != fmt.Sprint(current.version) { - res, err := srv.buildDiscoveryResponse(fmt.Sprint(current.version), typeURL, ¤t.Options) - if err != nil { - return err - } - select { - case outgoing <- res: - case <-ctx.Done(): - return ctx.Err() - } - } - } - } -} - -func (srv *Server) streamAggregatedResourcesOutgoingStep( - ctx context.Context, - stream envoy_service_discovery_v3.AggregatedDiscoveryService_StreamAggregatedResourcesServer, - outgoing <-chan *envoy_service_discovery_v3.DiscoveryResponse, -) error { - for { - var res *envoy_service_discovery_v3.DiscoveryResponse - select { - case res = <-outgoing: - case <-ctx.Done(): - return ctx.Err() - } - - err := stream.Send(res) - if err != nil { - return err - } - } -} - -// DeltaAggregatedResources is not implemented. -func (srv *Server) DeltaAggregatedResources(in envoy_service_discovery_v3.AggregatedDiscoveryService_DeltaAggregatedResourcesServer) error { - return fmt.Errorf("method DeltaAggregatedResources not implemented") -} diff --git a/internal/controlplane/server.go b/internal/controlplane/server.go index 560f318b6..a60a16d94 100644 --- a/internal/controlplane/server.go +++ b/internal/controlplane/server.go @@ -7,6 +7,7 @@ import ( "sync/atomic" "time" + envoy_service_discovery_v3 "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" "github.com/gorilla/mux" "golang.org/x/sync/errgroup" "google.golang.org/grpc" @@ -14,6 +15,7 @@ import ( "google.golang.org/grpc/reflection" "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/controlplane/xdsmgr" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/telemetry" "github.com/pomerium/pomerium/internal/telemetry/requestid" @@ -46,15 +48,13 @@ type Server struct { HTTPRouter *mux.Router currentConfig atomicVersionedOptions - configUpdated chan struct{} name string + xdsmgr *xdsmgr.Manager } // NewServer creates a new Server. Listener ports are chosen by the OS. func NewServer(name string) (*Server, error) { - srv := &Server{ - configUpdated: make(chan struct{}, 1), - } + srv := &Server{} srv.currentConfig.Store(versionedOptions{}) var err error @@ -73,7 +73,6 @@ func NewServer(name string) (*Server, error) { grpc.ChainStreamInterceptor(requestid.StreamServerInterceptor(), si), ) reflection.Register(srv.GRPCServer) - srv.registerXDSHandlers() srv.registerAccessLogHandlers() // setup HTTP @@ -85,6 +84,9 @@ func NewServer(name string) (*Server, error) { srv.HTTPRouter = mux.NewRouter() srv.addHTTPMiddleware() + srv.xdsmgr = xdsmgr.NewManager(srv.buildDiscoveryResources()) + envoy_service_discovery_v3.RegisterAggregatedDiscoveryServiceServer(srv.GRPCServer, srv.xdsmgr) + return srv, nil } @@ -150,14 +152,10 @@ func (srv *Server) Run(ctx context.Context) error { // OnConfigChange updates the pomerium config options. func (srv *Server) OnConfigChange(cfg *config.Config) { - select { - case <-srv.configUpdated: - default: - } prev := srv.currentConfig.Load() srv.currentConfig.Store(versionedOptions{ Options: *cfg.Options, version: prev.version + 1, }) - srv.configUpdated <- struct{}{} + srv.xdsmgr.Update(srv.buildDiscoveryResources()) } diff --git a/internal/controlplane/xds.go b/internal/controlplane/xds.go index 20468ec7b..ea274fa9b 100644 --- a/internal/controlplane/xds.go +++ b/internal/controlplane/xds.go @@ -4,6 +4,7 @@ import ( "bytes" "crypto/tls" "crypto/x509" + "encoding/hex" "encoding/pem" "fmt" "io/ioutil" @@ -19,51 +20,40 @@ import ( envoy_extensions_access_loggers_grpc_v3 "github.com/envoyproxy/go-control-plane/envoy/extensions/access_loggers/grpc/v3" envoy_extensions_transport_sockets_tls_v3 "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" envoy_service_discovery_v3 "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" - "github.com/golang/protobuf/ptypes" - "github.com/golang/protobuf/ptypes/any" "golang.org/x/net/nettest" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/pkg/cryptutil" ) -func (srv *Server) buildDiscoveryResponse(version string, typeURL string, options *config.Options) (*envoy_service_discovery_v3.DiscoveryResponse, error) { - switch typeURL { - case "type.googleapis.com/envoy.config.listener.v3.Listener": - listeners := buildListeners(options) - anys := make([]*any.Any, len(listeners)) - for i, listener := range listeners { - a, err := ptypes.MarshalAny(listener) - if err != nil { - return nil, status.Errorf(codes.Internal, "error marshaling type to any: %v", err) - } - anys[i] = a - } - return &envoy_service_discovery_v3.DiscoveryResponse{ - VersionInfo: version, - Resources: anys, - TypeUrl: typeURL, - }, nil - case "type.googleapis.com/envoy.config.cluster.v3.Cluster": - clusters := srv.buildClusters(options) - anys := make([]*any.Any, len(clusters)) - for i, cluster := range clusters { - a, err := ptypes.MarshalAny(cluster) - if err != nil { - return nil, status.Errorf(codes.Internal, "error marshaling type to any: %v", err) - } - anys[i] = a - } - return &envoy_service_discovery_v3.DiscoveryResponse{ - VersionInfo: version, - Resources: anys, - TypeUrl: typeURL, - }, nil - default: - return nil, status.Errorf(codes.Internal, "received request for unknown discovery request type: %s", typeURL) +const ( + clusterTypeURL = "type.googleapis.com/envoy.config.cluster.v3.Cluster" + listenerTypeURL = "type.googleapis.com/envoy.config.listener.v3.Listener" +) + +func (srv *Server) buildDiscoveryResources() map[string][]*envoy_service_discovery_v3.Resource { + resources := map[string][]*envoy_service_discovery_v3.Resource{} + cfg := srv.currentConfig.Load() + for _, cluster := range srv.buildClusters(&cfg.Options) { + any, _ := anypb.New(cluster) + resources[clusterTypeURL] = append(resources[clusterTypeURL], &envoy_service_discovery_v3.Resource{ + Name: cluster.Name, + Version: hex.EncodeToString(cryptutil.HashProto(cluster)), + Resource: any, + }) } + for _, listener := range buildListeners(&cfg.Options) { + any, _ := anypb.New(listener) + resources[listenerTypeURL] = append(resources[listenerTypeURL], &envoy_service_discovery_v3.Resource{ + Name: listener.Name, + Version: hex.EncodeToString(cryptutil.HashProto(listener)), + Resource: any, + }) + } + return resources } func buildAccessLogs(options *config.Options) []*envoy_config_accesslog_v3.AccessLog { @@ -82,7 +72,7 @@ func buildAccessLogs(options *config.Options) []*envoy_config_accesslog_v3.Acces return nil } - tc, _ := ptypes.MarshalAny(&envoy_extensions_access_loggers_grpc_v3.HttpGrpcAccessLogConfig{ + tc := marshalAny(&envoy_extensions_access_loggers_grpc_v3.HttpGrpcAccessLogConfig{ CommonConfig: &envoy_extensions_access_loggers_grpc_v3.CommonGrpcAccessLogConfig{ LogName: "ingress-http", GrpcService: &envoy_config_core_v3.GrpcService{ @@ -235,3 +225,12 @@ func getRootCertificateAuthority() (string, error) { } return rootCABundle.value, nil } + +func marshalAny(msg proto.Message) *anypb.Any { + any := new(anypb.Any) + _ = anypb.MarshalFrom(any, msg, proto.MarshalOptions{ + AllowPartial: true, + Deterministic: true, + }) + return any +} diff --git a/internal/controlplane/xds_clusters.go b/internal/controlplane/xds_clusters.go index 066c42c7c..7cc3ef2e6 100644 --- a/internal/controlplane/xds_clusters.go +++ b/internal/controlplane/xds_clusters.go @@ -104,7 +104,7 @@ func buildInternalTransportSocket(options *config.Options, endpoint *url.URL) *e }, Sni: sni, } - tlsConfig, _ := ptypes.MarshalAny(tlsContext) + tlsConfig := marshalAny(tlsContext) return &envoy_config_core_v3.TransportSocket{ Name: "tls", ConfigType: &envoy_config_core_v3.TransportSocket_TypedConfig{ @@ -144,7 +144,7 @@ func buildPolicyTransportSocket(policy *config.Policy) *envoy_config_core_v3.Tra envoyTLSCertificateFromGoTLSCertificate(policy.ClientCertificate)) } - tlsConfig, _ := ptypes.MarshalAny(tlsContext) + tlsConfig := marshalAny(tlsContext) return &envoy_config_core_v3.TransportSocket{ Name: "tls", ConfigType: &envoy_config_core_v3.TransportSocket_TypedConfig{ diff --git a/internal/controlplane/xds_listeners.go b/internal/controlplane/xds_listeners.go index 83c9464b5..0e67a106e 100644 --- a/internal/controlplane/xds_listeners.go +++ b/internal/controlplane/xds_listeners.go @@ -30,7 +30,7 @@ import ( var disableExtAuthz *any.Any func init() { - disableExtAuthz, _ = ptypes.MarshalAny(&envoy_extensions_filters_http_ext_authz_v3.ExtAuthzPerRoute{ + disableExtAuthz = marshalAny(&envoy_extensions_filters_http_ext_authz_v3.ExtAuthzPerRoute{ Override: &envoy_extensions_filters_http_ext_authz_v3.ExtAuthzPerRoute_Disabled{ Disabled: true, }, @@ -67,7 +67,7 @@ func buildMainListener(options *config.Options) *envoy_config_listener_v3.Listen } } - tlsInspectorCfg, _ := ptypes.MarshalAny(new(emptypb.Empty)) + tlsInspectorCfg := marshalAny(new(emptypb.Empty)) li := &envoy_config_listener_v3.Listener{ Name: "https-ingress", Address: buildAddress(options.Addr, 443), @@ -90,7 +90,7 @@ func buildMainListener(options *config.Options) *envoy_config_listener_v3.Listen } tlsContext := buildDownstreamTLSContext(options, tlsDomain) if tlsContext != nil { - tlsConfig, _ := ptypes.MarshalAny(tlsContext) + tlsConfig := marshalAny(tlsContext) filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{ Name: "tls", ConfigType: &envoy_config_core_v3.TransportSocket_TypedConfig{ @@ -161,7 +161,7 @@ func buildMainHTTPConnectionManagerFilter(options *config.Options, domains []str grpcClientTimeout = ptypes.DurationProto(30 * time.Second) } - extAuthZ, _ := ptypes.MarshalAny(&envoy_extensions_filters_http_ext_authz_v3.ExtAuthz{ + extAuthZ := marshalAny(&envoy_extensions_filters_http_ext_authz_v3.ExtAuthz{ StatusOnError: &envoy_type_v3.HttpStatus{ Code: envoy_type_v3.StatusCode_InternalServerError, }, @@ -178,13 +178,13 @@ func buildMainHTTPConnectionManagerFilter(options *config.Options, domains []str IncludePeerCertificate: true, }) - extAuthzSetCookieLua, _ := ptypes.MarshalAny(&envoy_extensions_filters_http_lua_v3.Lua{ + extAuthzSetCookieLua := marshalAny(&envoy_extensions_filters_http_lua_v3.Lua{ InlineCode: luascripts.ExtAuthzSetCookie, }) - cleanUpstreamLua, _ := ptypes.MarshalAny(&envoy_extensions_filters_http_lua_v3.Lua{ + cleanUpstreamLua := marshalAny(&envoy_extensions_filters_http_lua_v3.Lua{ InlineCode: luascripts.CleanUpstream, }) - removeImpersonateHeadersLua, _ := ptypes.MarshalAny(&envoy_extensions_filters_http_lua_v3.Lua{ + removeImpersonateHeadersLua := marshalAny(&envoy_extensions_filters_http_lua_v3.Lua{ InlineCode: luascripts.RemoveImpersonateHeaders, }) @@ -193,7 +193,7 @@ func buildMainHTTPConnectionManagerFilter(options *config.Options, domains []str maxStreamDuration = ptypes.DurationProto(options.WriteTimeout) } - tc, _ := ptypes.MarshalAny(&envoy_http_connection_manager.HttpConnectionManager{ + tc := marshalAny(&envoy_http_connection_manager.HttpConnectionManager{ CodecType: envoy_http_connection_manager.HttpConnectionManager_AUTO, StatPrefix: "ingress", RouteSpecifier: &envoy_http_connection_manager.HttpConnectionManager_RouteConfig{ @@ -265,7 +265,7 @@ func buildGRPCListener(options *config.Options) *envoy_config_listener_v3.Listen } } - tlsInspectorCfg, _ := ptypes.MarshalAny(new(emptypb.Empty)) + tlsInspectorCfg := marshalAny(new(emptypb.Empty)) li := &envoy_config_listener_v3.Listener{ Name: "grpc-ingress", Address: buildAddress(options.GRPCAddr, 443), @@ -287,7 +287,7 @@ func buildGRPCListener(options *config.Options) *envoy_config_listener_v3.Listen } tlsContext := buildDownstreamTLSContext(options, tlsDomain) if tlsContext != nil { - tlsConfig, _ := ptypes.MarshalAny(tlsContext) + tlsConfig := marshalAny(tlsContext) filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{ Name: "tls", ConfigType: &envoy_config_core_v3.TransportSocket_TypedConfig{ @@ -302,7 +302,7 @@ func buildGRPCListener(options *config.Options) *envoy_config_listener_v3.Listen } func buildGRPCHTTPConnectionManagerFilter() *envoy_config_listener_v3.Filter { - tc, _ := ptypes.MarshalAny(&envoy_http_connection_manager.HttpConnectionManager{ + tc := marshalAny(&envoy_http_connection_manager.HttpConnectionManager{ CodecType: envoy_http_connection_manager.HttpConnectionManager_AUTO, StatPrefix: "grpc_ingress", // limit request first byte to last byte time diff --git a/internal/controlplane/xds_routes.go b/internal/controlplane/xds_routes.go index bbae1a97c..de64fbb6e 100644 --- a/internal/controlplane/xds_routes.go +++ b/internal/controlplane/xds_routes.go @@ -3,6 +3,7 @@ package controlplane import ( "fmt" "net/url" + "sort" envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" envoy_config_route_v3 "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" @@ -271,9 +272,15 @@ func mkEnvoyHeader(k, v string) *envoy_config_core_v3.HeaderValueOption { } func toEnvoyHeaders(headers map[string]string) []*envoy_config_core_v3.HeaderValueOption { + var ks []string + for k := range headers { + ks = append(ks, k) + } + sort.Strings(ks) + envoyHeaders := make([]*envoy_config_core_v3.HeaderValueOption, 0, len(headers)) - for k, v := range headers { - envoyHeaders = append(envoyHeaders, mkEnvoyHeader(k, v)) + for _, k := range ks { + envoyHeaders = append(envoyHeaders, mkEnvoyHeader(k, headers[k])) } return envoyHeaders } diff --git a/internal/controlplane/xdsmgr/xdsmgr.go b/internal/controlplane/xdsmgr/xdsmgr.go new file mode 100644 index 000000000..24cadbe58 --- /dev/null +++ b/internal/controlplane/xdsmgr/xdsmgr.go @@ -0,0 +1,229 @@ +// Package xdsmgr implements a resource discovery manager for envoy. +package xdsmgr + +import ( + "encoding/json" + "errors" + "sync" + + envoy_service_discovery_v3 "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" + "github.com/google/uuid" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/internal/signal" +) + +type streamState struct { + typeURL string + clientResourceVersions map[string]string + unsubscribedResources map[string]struct{} +} + +var onHandleDeltaRequest = func(state *streamState) {} + +// A Manager manages xDS resources. +type Manager struct { + signal *signal.Signal + + mu sync.Mutex + nonce string + resources map[string][]*envoy_service_discovery_v3.Resource +} + +// NewManager creates a new Manager. +func NewManager(resources map[string][]*envoy_service_discovery_v3.Resource) *Manager { + return &Manager{ + signal: signal.New(), + nonce: uuid.New().String(), + resources: resources, + } +} + +// DeltaAggregatedResources implements the increment xDS server. +func (mgr *Manager) DeltaAggregatedResources( + stream envoy_service_discovery_v3.AggregatedDiscoveryService_DeltaAggregatedResourcesServer, +) error { + ch := mgr.signal.Bind() + defer mgr.signal.Unbind(ch) + + stateByTypeURL := map[string]*streamState{} + + getDeltaResponse := func(typeURL string) *envoy_service_discovery_v3.DeltaDiscoveryResponse { + mgr.mu.Lock() + defer mgr.mu.Unlock() + + state, ok := stateByTypeURL[typeURL] + if !ok { + return nil + } + + res := &envoy_service_discovery_v3.DeltaDiscoveryResponse{ + TypeUrl: typeURL, + Nonce: mgr.nonce, + } + seen := map[string]struct{}{} + for _, resource := range mgr.resources[typeURL] { + seen[resource.Name] = struct{}{} + if resource.Version != state.clientResourceVersions[resource.Name] { + res.Resources = append(res.Resources, resource) + } + } + for name := range state.clientResourceVersions { + _, ok := seen[name] + if !ok { + res.RemovedResources = append(res.RemovedResources, name) + } + } + + if len(res.Resources) == 0 && len(res.RemovedResources) == 0 { + return nil + } + + return res + } + + handleDeltaRequest := func(req *envoy_service_discovery_v3.DeltaDiscoveryRequest) { + mgr.mu.Lock() + defer mgr.mu.Unlock() + + state, ok := stateByTypeURL[req.GetTypeUrl()] + if !ok { + // first time we've seen a message for this type URL. + state = &streamState{ + typeURL: req.GetTypeUrl(), + clientResourceVersions: req.GetInitialResourceVersions(), + unsubscribedResources: make(map[string]struct{}), + } + if state.clientResourceVersions == nil { + state.clientResourceVersions = make(map[string]string) + } + stateByTypeURL[req.GetTypeUrl()] = state + } + + switch { + case req.GetResponseNonce() == "": + // neither an ACK or a NACK + case req.GetErrorDetail() != nil: + // a NACK + bs, _ := json.Marshal(req.ErrorDetail.Details) + log.Error(). + Err(errors.New(req.ErrorDetail.Message)). + Int32("code", req.ErrorDetail.Code). + RawJSON("details", bs).Msg("error applying configuration") + case req.GetResponseNonce() == mgr.nonce: + // an ACK for the last response + // - set the client resource versions to the current resource versions + state.clientResourceVersions = make(map[string]string) + for _, resource := range mgr.resources[req.GetTypeUrl()] { + state.clientResourceVersions[resource.Name] = resource.Version + } + default: + // an ACK for a response that's not the last response + } + + // update subscriptions + for _, name := range req.GetResourceNamesSubscribe() { + delete(state.unsubscribedResources, name) + } + for _, name := range req.GetResourceNamesUnsubscribe() { + state.unsubscribedResources[name] = struct{}{} + // from the docs: + // NOTE: the server must respond with all resources listed in + // resource_names_subscribe, even if it believes the client has + // the most recent version of them. The reason: the client may + // have dropped them, but then regained interest before it had + // a chance to send the unsubscribe message. + // so we reset the version to treat it like a new version + delete(state.clientResourceVersions, name) + } + + onHandleDeltaRequest(state) + } + + incoming := make(chan *envoy_service_discovery_v3.DeltaDiscoveryRequest) + outgoing := make(chan *envoy_service_discovery_v3.DeltaDiscoveryResponse) + eg, ctx := errgroup.WithContext(stream.Context()) + // 1. receive all incoming messages + eg.Go(func() error { + for { + req, err := stream.Recv() + if err != nil { + return err + } + + select { + case <-ctx.Done(): + return ctx.Err() + case incoming <- req: + } + } + }) + // 2. handle incoming requests or resource changes + eg.Go(func() error { + for { + var typeURLs []string + select { + case <-ctx.Done(): + return ctx.Err() + case req := <-incoming: + handleDeltaRequest(req) + typeURLs = []string{req.GetTypeUrl()} + case <-ch: + mgr.mu.Lock() + for typeURL := range mgr.resources { + typeURLs = append(typeURLs, typeURL) + } + mgr.mu.Unlock() + } + + for _, typeURL := range typeURLs { + res := getDeltaResponse(typeURL) + if res == nil { + continue + } + + select { + case <-ctx.Done(): + return ctx.Err() + case outgoing <- res: + } + } + } + }) + // 3. send all outgoing messages + eg.Go(func() error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + case res := <-outgoing: + err := stream.Send(res) + if err != nil { + return err + } + } + } + }) + return eg.Wait() +} + +// StreamAggregatedResources is not implemented. +func (mgr *Manager) StreamAggregatedResources( + stream envoy_service_discovery_v3.AggregatedDiscoveryService_StreamAggregatedResourcesServer, +) error { + return status.Errorf(codes.Unimplemented, "method StreamAggregatedResources not implemented") +} + +// Update updates the state of resources. If any changes are made they will be pushed to any listening +// streams. For each TypeURL the list of resources should be the complete list of resources. +func (mgr *Manager) Update(resources map[string][]*envoy_service_discovery_v3.Resource) { + mgr.mu.Lock() + mgr.nonce = uuid.New().String() + mgr.resources = resources + mgr.mu.Unlock() + + mgr.signal.Broadcast() +} diff --git a/internal/controlplane/xdsmgr/xdsmgr_test.go b/internal/controlplane/xdsmgr/xdsmgr_test.go new file mode 100644 index 000000000..d53b8f185 --- /dev/null +++ b/internal/controlplane/xdsmgr/xdsmgr_test.go @@ -0,0 +1,116 @@ +package xdsmgr + +import ( + "context" + "net" + "testing" + "time" + + envoy_service_discovery_v3 "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/test/bufconn" + + "github.com/pomerium/pomerium/internal/signal" +) + +const bufSize = 1024 * 1024 + +func TestManager(t *testing.T) { + ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) + defer clearTimeout() + + typeURL := "example.com/example" + + stateChanged := signal.New() + origOnHandleDeltaRequest := onHandleDeltaRequest + defer func() { onHandleDeltaRequest = origOnHandleDeltaRequest }() + onHandleDeltaRequest = func(state *streamState) { + stateChanged.Broadcast() + } + + srv := grpc.NewServer() + mgr := NewManager(map[string][]*envoy_service_discovery_v3.Resource{ + typeURL: { + {Name: "r1", Version: "1"}, + }, + }) + envoy_service_discovery_v3.RegisterAggregatedDiscoveryServiceServer(srv, mgr) + + li := bufconn.Listen(bufSize) + go func() { _ = srv.Serve(li) }() + + cc, err := grpc.Dial("test", + grpc.WithInsecure(), + grpc.WithContextDialer(func(ctx context.Context, target string) (net.Conn, error) { + return li.Dial() + })) + if !assert.NoError(t, err) { + return + } + defer func() { _ = cc.Close() }() + + client := envoy_service_discovery_v3.NewAggregatedDiscoveryServiceClient(cc) + t.Run("stream is disabled", func(t *testing.T) { + stream, err := client.StreamAggregatedResources(ctx) + if !assert.NoError(t, err) { + return + } + _, err = stream.Recv() + assert.Error(t, err, "only delta should be implemented") + assert.Equal(t, codes.Unimplemented, grpc.Code(err)) + }) + + t.Run("updates", func(t *testing.T) { + stream, err := client.DeltaAggregatedResources(ctx) + if !assert.NoError(t, err) { + return + } + + ch := stateChanged.Bind() + defer stateChanged.Unbind(ch) + ack := func(nonce string) { + err = stream.Send(&envoy_service_discovery_v3.DeltaDiscoveryRequest{ + TypeUrl: typeURL, + ResponseNonce: nonce, + }) + assert.NoError(t, err) + select { + case <-ctx.Done(): + t.Fatal(ctx.Err()) + case <-ch: + } + } + + ack("") + + msg, err := stream.Recv() + assert.NoError(t, err) + assert.NotEmpty(t, msg.GetNonce(), "nonce should not be empty") + assert.Equal(t, []*envoy_service_discovery_v3.Resource{ + {Name: "r1", Version: "1"}, + }, msg.GetResources()) + ack(msg.Nonce) + + mgr.Update(map[string][]*envoy_service_discovery_v3.Resource{ + typeURL: {{Name: "r1", Version: "2"}}, + }) + + msg, err = stream.Recv() + assert.NoError(t, err) + assert.Equal(t, []*envoy_service_discovery_v3.Resource{ + {Name: "r1", Version: "2"}, + }, msg.GetResources()) + ack(msg.Nonce) + + mgr.Update(map[string][]*envoy_service_discovery_v3.Resource{ + typeURL: nil, + }) + + msg, err = stream.Recv() + assert.NoError(t, err) + assert.Equal(t, []string{"r1"}, msg.GetRemovedResources()) + ack(msg.Nonce) + }) +} diff --git a/internal/envoy/envoy.go b/internal/envoy/envoy.go index 0afb5ba14..85965b249 100644 --- a/internal/envoy/envoy.go +++ b/internal/envoy/envoy.go @@ -229,7 +229,7 @@ func (srv *Server) buildBootstrapConfig() ([]byte, error) { dynamicCfg := &envoy_config_bootstrap_v3.Bootstrap_DynamicResources{ AdsConfig: &envoy_config_core_v3.ApiConfigSource{ - ApiType: envoy_config_core_v3.ApiConfigSource_ApiType(envoy_config_core_v3.ApiConfigSource_ApiType_value["GRPC"]), + ApiType: envoy_config_core_v3.ApiConfigSource_ApiType(envoy_config_core_v3.ApiConfigSource_ApiType_value["DELTA_GRPC"]), TransportApiVersion: envoy_config_core_v3.ApiVersion_V3, GrpcServices: []*envoy_config_core_v3.GrpcService{ { diff --git a/internal/identity/manager/manager.go b/internal/identity/manager/manager.go index 97ae0562c..bd88e4360 100644 --- a/internal/identity/manager/manager.go +++ b/internal/identity/manager/manager.go @@ -7,6 +7,7 @@ import ( "fmt" "time" + "github.com/cenkalti/backoff/v4" "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes" "github.com/google/btree" @@ -578,29 +579,31 @@ func (mgr *Manager) initDirectoryUsers(ctx context.Context) error { return err } - res, err := databroker.InitialSync(ctx, mgr.cfg.Load().dataBrokerClient, &databroker.SyncRequest{ - Type: any.GetTypeUrl(), - }) - if err != nil { - return fmt.Errorf("error getting all directory users: %w", err) - } - - mgr.directoryUsers = map[string]*directory.User{} - for _, record := range res.GetRecords() { - var pbDirectoryUser directory.User - err := ptypes.UnmarshalAny(record.GetData(), &pbDirectoryUser) + return exponentialTry(ctx, func() error { + res, err := databroker.InitialSync(ctx, mgr.cfg.Load().dataBrokerClient, &databroker.SyncRequest{ + Type: any.GetTypeUrl(), + }) if err != nil { - return fmt.Errorf("error unmarshaling directory user: %w", err) + return fmt.Errorf("error getting all directory users: %w", err) } - mgr.directoryUsers[pbDirectoryUser.GetId()] = &pbDirectoryUser - mgr.directoryUsersRecordVersion = record.GetVersion() - } - mgr.directoryUsersServerVersion = res.GetServerVersion() + mgr.directoryUsers = map[string]*directory.User{} + for _, record := range res.GetRecords() { + var pbDirectoryUser directory.User + err := ptypes.UnmarshalAny(record.GetData(), &pbDirectoryUser) + if err != nil { + return fmt.Errorf("error unmarshaling directory user: %w", err) + } - mgr.log.Info().Int("count", len(mgr.directoryUsers)).Msg("initialized directory users") + mgr.directoryUsers[pbDirectoryUser.GetId()] = &pbDirectoryUser + mgr.directoryUsersRecordVersion = record.GetVersion() + } + mgr.directoryUsersServerVersion = res.GetServerVersion() - return nil + mgr.log.Info().Int("count", len(mgr.directoryUsers)).Msg("initialized directory users") + + return nil + }) } func (mgr *Manager) syncDirectoryUsers(ctx context.Context, ch chan<- *directory.User) error { @@ -648,30 +651,31 @@ func (mgr *Manager) initDirectoryGroups(ctx context.Context) error { if err != nil { return err } - - res, err := databroker.InitialSync(ctx, mgr.cfg.Load().dataBrokerClient, &databroker.SyncRequest{ - Type: any.GetTypeUrl(), - }) - if err != nil { - return fmt.Errorf("error getting all directory groups: %w", err) - } - - mgr.directoryGroups = map[string]*directory.Group{} - for _, record := range res.GetRecords() { - var pbDirectoryGroup directory.Group - err := ptypes.UnmarshalAny(record.GetData(), &pbDirectoryGroup) + return exponentialTry(ctx, func() error { + res, err := databroker.InitialSync(ctx, mgr.cfg.Load().dataBrokerClient, &databroker.SyncRequest{ + Type: any.GetTypeUrl(), + }) if err != nil { - return fmt.Errorf("error unmarshaling directory group: %w", err) + return fmt.Errorf("error getting all directory groups: %w", err) } - mgr.directoryGroups[pbDirectoryGroup.GetId()] = &pbDirectoryGroup - mgr.directoryGroupsRecordVersion = record.GetVersion() - } - mgr.directoryGroupsServerVersion = res.GetServerVersion() + mgr.directoryGroups = map[string]*directory.Group{} + for _, record := range res.GetRecords() { + var pbDirectoryGroup directory.Group + err := ptypes.UnmarshalAny(record.GetData(), &pbDirectoryGroup) + if err != nil { + return fmt.Errorf("error unmarshaling directory group: %w", err) + } - mgr.log.Info().Int("count", len(mgr.directoryGroups)).Msg("initialized directory groups") + mgr.directoryGroups[pbDirectoryGroup.GetId()] = &pbDirectoryGroup + mgr.directoryGroupsRecordVersion = record.GetVersion() + } + mgr.directoryGroupsServerVersion = res.GetServerVersion() - return nil + mgr.log.Info().Int("count", len(mgr.directoryGroups)).Msg("initialized directory groups") + + return nil + }) } func (mgr *Manager) syncDirectoryGroups(ctx context.Context, ch chan<- *directory.Group) error { @@ -775,3 +779,22 @@ func isTemporaryError(err error) bool { } return false } + +// exponentialTry executes f until it succeeds or ctx is Done. +func exponentialTry(ctx context.Context, f func() error) error { + backoff := backoff.NewExponentialBackOff() + backoff.MaxElapsedTime = 0 + for { + err := f() + if err == nil { + return nil + } + + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(backoff.NextBackOff()): + } + continue + } +} diff --git a/pkg/cryptutil/hash.go b/pkg/cryptutil/hash.go index 14561de22..4ff7e76c8 100644 --- a/pkg/cryptutil/hash.go +++ b/pkg/cryptutil/hash.go @@ -5,6 +5,7 @@ import ( "crypto/sha512" "golang.org/x/crypto/bcrypt" + "google.golang.org/protobuf/proto" ) // Hash generates a hash of data using HMAC-SHA-512/256. The tag is intended to @@ -28,3 +29,14 @@ func HashPassword(password []byte) ([]byte, error) { func CheckPasswordHash(hash, password []byte) error { return bcrypt.CompareHashAndPassword(hash, password) } + +// HashProto hashes a protobuf message. It sets `Deterministic` to true to ensure +// the encoded message is always the same. (ie map order is lexographic) +func HashProto(msg proto.Message) []byte { + opts := proto.MarshalOptions{ + AllowPartial: true, + Deterministic: true, + } + bs, _ := opts.Marshal(msg) + return Hash("proto", bs) +} diff --git a/pkg/cryptutil/hash_test.go b/pkg/cryptutil/hash_test.go index f3a8c0c2c..c22924266 100644 --- a/pkg/cryptutil/hash_test.go +++ b/pkg/cryptutil/hash_test.go @@ -8,6 +8,9 @@ import ( "io/ioutil" "os" "testing" + + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/types/known/structpb" ) func TestPasswordHashing(t *testing.T) { @@ -79,3 +82,25 @@ func ExampleHash() { fmt.Println(hex.EncodeToString(digest)) // Output: 9f4c795d8ae5c207f19184ccebee6a606c1fdfe509c793614066d613580f03e1 } + +func TestHashProto(t *testing.T) { + // This test will hash a protobuf message that has a map 1000 times + // each attempt should result in the same hash if the output is + // deterministic. + var cur []byte + for i := 0; i < 1000; i++ { + s, err := structpb.NewStruct(map[string]interface{}{ + "1": "a", "2": "b", "3": "c", "4": "d", + "5": "e", "6": "f", "7": "g", "8": "h", + }) + assert.NoError(t, err) + if i == 0 { + cur = HashProto(s) + } else { + nxt := HashProto(s) + if !assert.Equal(t, cur, nxt) { + return + } + } + } +}