diff --git a/internal/databroker/config_source.go b/internal/databroker/config_source.go index 8fc910852..6e62271ea 100644 --- a/internal/databroker/config_source.go +++ b/internal/databroker/config_source.go @@ -195,7 +195,8 @@ func (src *ConfigSource) runUpdater(cfg *config.Config) { syncer := databroker.NewSyncer("databroker", &syncerHandler{ client: client, src: src, - }, databroker.WithTypeURL(grpcutil.GetTypeURL(new(configpb.Config)))) + }, databroker.WithTypeURL(grpcutil.GetTypeURL(new(configpb.Config))), + databroker.WithFastForward()) go func() { _ = syncer.Run(ctx) }() } diff --git a/internal/databroker/config_source_test.go b/internal/databroker/config_source_test.go index 0a43c7b3d..6d26c7a2e 100644 --- a/internal/databroker/config_source_test.go +++ b/internal/databroker/config_source_test.go @@ -17,7 +17,7 @@ import ( ) func TestConfigSource(t *testing.T) { - ctx, clearTimeout := context.WithTimeout(context.Background(), 5*time.Second) + ctx, clearTimeout := context.WithTimeout(context.Background(), 50*time.Second) defer clearTimeout() li, err := net.Listen("tcp", "127.0.0.1:0") diff --git a/internal/tests/xdserr/cmd/main.go b/internal/tests/xdserr/cmd/main.go new file mode 100644 index 000000000..4bb9e7e9c --- /dev/null +++ b/internal/tests/xdserr/cmd/main.go @@ -0,0 +1,197 @@ +package main + +import ( + "context" + "crypto/tls" + "encoding/base64" + "flag" + "fmt" + "math/rand" + "net/http" + "net/url" + "time" + + _ "github.com/envoyproxy/go-control-plane/envoy/api/v2" + _ "github.com/envoyproxy/go-control-plane/envoy/extensions/access_loggers/grpc/v3" + _ "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/ext_authz/v3" + _ "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/lua/v3" + _ "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" + _ "github.com/envoyproxy/go-control-plane/envoy/extensions/upstreams/http/v3" + "github.com/google/uuid" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc" + + "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/internal/tests/xdserr" + "github.com/pomerium/pomerium/pkg/grpc/config" + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/protoutil" +) + +var httpClient = &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + }, +} + +func main() { + ctx := context.Background() + + graceful := flag.Bool("graceful", false, "gracefully grow") + domain := flag.String("domain", "localhost.pomerium.io", "domain to create routes in") + routes := flag.Int("routes", 100, "number of routes") + cycles := flag.Int("cycles", 1, "number of cycles") + change := flag.Int("change", 1, "number of change per cycle") + addr := flag.String("db-url", "http://localhost:5443", "databroker url") + key := flag.String("key", "", "databroker connection key") + to := flag.String("to", "", "route To url") + + flag.Parse() + + toURL, err := url.Parse(*to) + if err != nil { + log.Error(ctx).Err(err).Msg(*to) + return + } + + eg, ctx := errgroup.WithContext(ctx) + conn, err := grpcConn(ctx, *addr, *key) + if err != nil { + log.Error(ctx).Err(err).Msg("databroker grpc conn") + return + } + defer conn.Close() + + if *to == "" { + *to, err = xdserr.RunEcho(ctx) + if err != nil { + log.Error(ctx).Err(err).Msg("echo server") + return + } + } + log.Info(ctx).Str("url", *to).Msg("echo server") + + eg.Go(func() error { + return run(ctx, conn, *toURL, *domain, opts{ + graceful: *graceful, + nRoutes: *routes, + nIter: *cycles, + nMod: *change, + }) + }) + if err := eg.Wait(); err != nil { + log.Error(ctx).Err(err).Msg("altering config") + } +} + +type opts struct { + nRoutes, nIter, nMod int + graceful bool +} + +func run(ctx context.Context, conn *grpc.ClientConn, to url.URL, domain string, o opts) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + dbc := databroker.NewDataBrokerServiceClient(conn) + cfg := new(config.Config) + + for i := 0; i < o.nRoutes; i++ { + cfg.Routes = append(cfg.Routes, makeRoute(domain, to)) + } + + rand.Seed(time.Now().Unix()) + + changed := make([]int, o.nMod) + for i := 0; i < o.nIter; i++ { + for j := 0; j < o.nMod; j++ { + // nolint: gosec + idx := rand.Intn(o.nRoutes) + changed[j] = idx + cfg.Routes[idx] = makeRoute(domain, to) + } + log.Info(ctx).Ints("changed", changed).Msg("changed") + if err := saveAndLogConfig(ctx, dbc, cfg, o.graceful); err != nil { + return err + } + } + + if !o.graceful { + return waitHealthy(ctx, httpClient, cfg.Routes) + } + + return nil +} + +func grpcConn(ctx context.Context, addr, keyTxt string) (*grpc.ClientConn, error) { + u, err := url.Parse(addr) + if err != nil { + return nil, err + } + + key, err := base64.StdEncoding.DecodeString(keyTxt) + if err != nil { + return nil, err + } + fmt.Println(keyTxt) + return xdserr.NewGRPCClientConn(ctx, &xdserr.Options{ + Address: u, + WithInsecure: u.Scheme == "http", + InsecureSkipVerify: true, + SignedJWTKey: key, + }) +} + +func makeRoute(domain string, to url.URL) *config.Route { + id := fmt.Sprintf("r-%s", uuid.NewString()) + return &config.Route{ + Name: id, + From: fmt.Sprintf("https://%s.%s", id, domain), + Path: "/", + PrefixRewrite: to.Path, + To: []string{to.String()}, + AllowPublicUnauthenticatedAccess: true, + } +} + +func saveAndLogConfig(ctx context.Context, client databroker.DataBrokerServiceClient, cfg *config.Config, graceful bool) error { + if err := saveConfig(ctx, client, cfg); err != nil { + return err + } + + if graceful { + return waitHealthy(ctx, httpClient, cfg.Routes) + } + + return nil +} + +func waitHealthy(ctx context.Context, client *http.Client, routes []*config.Route) error { + now := time.Now() + if err := xdserr.WaitForHealthy(ctx, httpClient, routes); err != nil { + return err + } + + log.Info(ctx). + Int("routes", len(routes)). + Str("elapsed", time.Since(now).String()). + Msg("ok") + + return nil +} +func saveConfig(ctx context.Context, client databroker.DataBrokerServiceClient, cfg *config.Config) error { + any := protoutil.NewAny(cfg) + r, err := client.Put(ctx, &databroker.PutRequest{ + Record: &databroker.Record{ + Type: any.GetTypeUrl(), + Id: "test_config", + Data: any, + }}) + if err != nil { + return err + } + log.Info(ctx).Uint64("version", r.GetRecord().GetVersion()).Msg("set config") + return nil +} diff --git a/internal/tests/xdserr/config.go b/internal/tests/xdserr/config.go new file mode 100644 index 000000000..571f6ea49 --- /dev/null +++ b/internal/tests/xdserr/config.go @@ -0,0 +1,63 @@ +// Package xdserr to load test configuration updates +package xdserr + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + + adminv3 "github.com/envoyproxy/go-control-plane/envoy/admin/v3" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/emptypb" + + "github.com/pomerium/pomerium/internal/log" +) + +type cfgDump struct { + Configs []json.RawMessage `json:"configs"` +} + +// DumpConfig acquires current config from admin endpoint +func DumpConfig(ctx context.Context, adminURL string) (*adminv3.RoutesConfigDump, error) { + u, err := url.Parse(adminURL) + if err != nil { + return nil, err + } + u.Path = "/config_dump" + + req := http.Request{ + Method: http.MethodGet, + URL: u, + } + resp, err := http.DefaultClient.Do(req.WithContext(ctx)) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + cfg := cfgDump{} + if err := json.NewDecoder(resp.Body).Decode(&cfg); err != nil { + return nil, err + } + + any, _ := anypb.New(&emptypb.Empty{}) + fmt.Println(protojson.Format(any)) + opts := &protojson.UnmarshalOptions{ + AllowPartial: true, + DiscardUnknown: true, + } + for i, data := range cfg.Configs { + any := new(anypb.Any) + if err = opts.Unmarshal(data, any); err != nil { + log.Error(ctx).Err(err).Int("config", i). + //RawJSON("data", data). + Msg("decode") + } else { + log.Info(ctx).Msg(any.TypeUrl) + } + } + return nil, err +} diff --git a/internal/tests/xdserr/echo.go b/internal/tests/xdserr/echo.go new file mode 100644 index 000000000..fcb3729ea --- /dev/null +++ b/internal/tests/xdserr/echo.go @@ -0,0 +1,34 @@ +package xdserr + +import ( + "context" + "fmt" + "net" + "net/http" + + "golang.org/x/sync/errgroup" +) + +func echo(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "pong") +} + +// RunEcho runs a test echo http server +func RunEcho(ctx context.Context) (string, error) { + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + return "", err + } + mux := http.NewServeMux() + mux.HandleFunc("/", echo) + srv := http.Server{ + Handler: mux, + } + eg, ctx := errgroup.WithContext(ctx) + eg.Go(func() error { return srv.Serve(l) }) + eg.Go(func() error { + <-ctx.Done() + return srv.Close() + }) + return l.Addr().String(), nil +} diff --git a/internal/tests/xdserr/grpc.go b/internal/tests/xdserr/grpc.go new file mode 100644 index 000000000..03a00a6cb --- /dev/null +++ b/internal/tests/xdserr/grpc.go @@ -0,0 +1,123 @@ +package xdserr + +import ( + "context" + "crypto/tls" + "net" + "net/url" + "strconv" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + + "github.com/pomerium/pomerium/pkg/cryptutil" + "github.com/pomerium/pomerium/pkg/grpcutil" +) + +const ( + defaultGRPCSecurePort = 443 + defaultGRPCInsecurePort = 80 +) + +// Options contains options for connecting to a pomerium rpc service. +type Options struct { + // Address is the location of the service. e.g. "service.corp.example:8443" + Address *url.URL + // OverrideCertificateName overrides the server name used to verify the hostname on the + // returned certificates from the server. gRPC internals also use it to override the virtual + // hosting name if it is set. + OverrideCertificateName string + // CA specifies the base64 encoded TLS certificate authority to use. + CA string + // CAFile specifies the TLS certificate authority file to use. + CAFile string + // RequestTimeout specifies the timeout for individual RPC calls + RequestTimeout time.Duration + // ClientDNSRoundRobin enables or disables DNS resolver based load balancing + ClientDNSRoundRobin bool + + // WithInsecure disables transport security for this ClientConn. + // Note that transport security is required unless WithInsecure is set. + WithInsecure bool + + // InsecureSkipVerify skips destination hostname and ca check + InsecureSkipVerify bool + + // ServiceName specifies the service name for telemetry exposition + ServiceName string + + // SignedJWTKey is the JWT key to use for signing a JWT attached to metadata. + SignedJWTKey []byte +} + +// NewGRPCClientConn returns a new gRPC pomerium service client connection. +func NewGRPCClientConn(ctx context.Context, opts *Options, other ...grpc.DialOption) (*grpc.ClientConn, error) { + hostport := opts.Address.Host + // no colon exists in the connection string, assume one must be added manually + if _, _, err := net.SplitHostPort(hostport); err != nil { + if opts.Address.Scheme == "https" { + hostport = net.JoinHostPort(hostport, strconv.Itoa(defaultGRPCSecurePort)) + } else { + hostport = net.JoinHostPort(hostport, strconv.Itoa(defaultGRPCInsecurePort)) + } + } + + unaryClientInterceptors := []grpc.UnaryClientInterceptor{ + grpcTimeoutInterceptor(opts.RequestTimeout), + } + streamClientInterceptors := []grpc.StreamClientInterceptor{} + if opts.SignedJWTKey != nil { + unaryClientInterceptors = append(unaryClientInterceptors, grpcutil.WithUnarySignedJWT(opts.SignedJWTKey)) + streamClientInterceptors = append(streamClientInterceptors, grpcutil.WithStreamSignedJWT(opts.SignedJWTKey)) + } + + dialOptions := []grpc.DialOption{ + grpc.WithChainUnaryInterceptor(unaryClientInterceptors...), + grpc.WithChainStreamInterceptor(streamClientInterceptors...), + grpc.WithDefaultCallOptions([]grpc.CallOption{grpc.WaitForReady(true)}...), + grpc.WithDisableServiceConfig(), + } + + dialOptions = append(dialOptions, other...) + + if opts.WithInsecure { + dialOptions = append(dialOptions, grpc.WithInsecure()) + } else { + rootCAs, err := cryptutil.GetCertPool(opts.CA, opts.CAFile) + if err != nil { + return nil, err + } + + cert := credentials.NewTLS(&tls.Config{ + // nolint: gosec + InsecureSkipVerify: opts.InsecureSkipVerify, + RootCAs: rootCAs, + MinVersion: tls.VersionTLS12, + }) + + // override allowed certificate name string, typically used when doing behind ingress connection + if opts.OverrideCertificateName != "" { + err := cert.OverrideServerName(opts.OverrideCertificateName) + if err != nil { + return nil, err + } + } + // finally add our credential + dialOptions = append(dialOptions, grpc.WithTransportCredentials(cert)) + } + + return grpc.DialContext(ctx, hostport, dialOptions...) +} + +// grpcTimeoutInterceptor enforces per-RPC request timeouts +func grpcTimeoutInterceptor(timeout time.Duration) grpc.UnaryClientInterceptor { + return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + if timeout <= 0 { + return invoker(ctx, method, req, reply, cc, opts...) + } + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + return invoker(ctx, method, req, reply, cc, opts...) + } +} diff --git a/internal/tests/xdserr/health.go b/internal/tests/xdserr/health.go new file mode 100644 index 000000000..e441b9e1a --- /dev/null +++ b/internal/tests/xdserr/health.go @@ -0,0 +1,50 @@ +package xdserr + +import ( + "context" + "errors" + "io/ioutil" + "net/http" + "net/url" + + "github.com/pomerium/pomerium/pkg/grpc/config" +) + +// WaitForHealthy waits until all routes are up +func WaitForHealthy(ctx context.Context, client *http.Client, routes []*config.Route) error { + healthy := 0 + for healthy != len(routes) && ctx.Err() == nil { + healthy = 0 + for _, r := range routes { + if err := checkHealth(ctx, client, r.From); err != nil { + continue + } + healthy++ + } + } + return ctx.Err() +} + +func checkHealth(ctx context.Context, client *http.Client, addr string) error { + u, err := url.Parse(addr) + if err != nil { + return err + } + req := http.Request{ + Method: http.MethodGet, + URL: u, + } + resp, err := client.Do(req.WithContext(ctx)) + if err != nil { + return err + } + defer resp.Body.Close() + + if _, err = ioutil.ReadAll(resp.Body); err != nil { + return err + } + if resp.StatusCode != http.StatusOK { + return errors.New(resp.Status) + } + return nil +} diff --git a/pkg/grpc/databroker/fast_forward.go b/pkg/grpc/databroker/fast_forward.go new file mode 100644 index 000000000..5d8d19ad7 --- /dev/null +++ b/pkg/grpc/databroker/fast_forward.go @@ -0,0 +1,107 @@ +package databroker + +import ( + "context" + "time" + + "github.com/pomerium/pomerium/internal/log" +) + +// fastForwardHandler will skip +type fastForwardHandler struct { + handler SyncerHandler + in chan *ffCmd + exec chan *ffCmd +} + +type ffCmd struct { + clearRecords bool + serverVersion uint64 + records []*Record +} + +func newFastForwardHandler(ctx context.Context, handler SyncerHandler) SyncerHandler { + ff := &fastForwardHandler{ + handler: handler, + in: make(chan *ffCmd, 20), + exec: make(chan *ffCmd), + } + go ff.runSelect(ctx) + go ff.runExec(ctx) + + return ff +} + +func (ff *fastForwardHandler) update(ctx context.Context, c *ffCmd) { + versions := make([]uint64, len(c.records)) + for i, r := range c.records { + versions[i] = r.Version + } + + now := time.Now() + ff.handler.UpdateRecords(ctx, c.serverVersion, c.records) + log.Info(ctx). + Dur("elapsed", time.Since(now)). + Uint64("server_version", c.serverVersion). + Uints64("versions", versions). + Msg("UpdateRecords") +} + +func (ff *fastForwardHandler) runSelect(ctx context.Context) { + var update *ffCmd + + for { + if update == nil { + select { + case <-ctx.Done(): + return + case update = <-ff.in: + } + } else { + select { + case <-ctx.Done(): + return + case update = <-ff.in: + case ff.exec <- update: + update = nil + } + } + } +} + +func (ff *fastForwardHandler) runExec(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case update := <-ff.exec: + if update.clearRecords { + ff.handler.ClearRecords(ctx) + continue + } + ff.update(ctx, update) + } + } +} + +func (ff *fastForwardHandler) GetDataBrokerServiceClient() DataBrokerServiceClient { + return ff.handler.GetDataBrokerServiceClient() +} + +func (ff *fastForwardHandler) ClearRecords(ctx context.Context) { + select { + case <-ctx.Done(): + log.Error(ctx). + Msg("ff_handler: ClearRecords: context canceled") + case ff.exec <- &ffCmd{clearRecords: true}: + } +} + +func (ff *fastForwardHandler) UpdateRecords(ctx context.Context, serverVersion uint64, records []*Record) { + select { + case <-ctx.Done(): + log.Error(ctx). + Msg("ff_handler: UpdateRecords: context canceled") + case ff.in <- &ffCmd{serverVersion: serverVersion, records: records}: + } +} diff --git a/pkg/grpc/databroker/fast_forward_test.go b/pkg/grpc/databroker/fast_forward_test.go new file mode 100644 index 000000000..626076fca --- /dev/null +++ b/pkg/grpc/databroker/fast_forward_test.go @@ -0,0 +1,74 @@ +package databroker + +import ( + "context" + "math/rand" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type mockFF struct { + clear chan struct{} + update chan uint64 +} + +func (ff *mockFF) ClearRecords(ctx context.Context) { + ff.clear <- struct{}{} +} + +func (ff *mockFF) UpdateRecords(ctx context.Context, sv uint64, records []*Record) { + time.Sleep(time.Millisecond * time.Duration(rand.Intn(5))) + ff.update <- sv +} + +func (ff *mockFF) GetDataBrokerServiceClient() DataBrokerServiceClient { + return nil +} + +func (ff *mockFF) getUpdate(ctx context.Context) (uint64, error) { + select { + case <-ctx.Done(): + return 0, ctx.Err() + case sv := <-ff.update: + return sv, nil + } +} + +func TestFastForward(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) + defer cancel() + + m := &mockFF{ + clear: make(chan struct{}), + update: make(chan uint64), + } + + f := newFastForwardHandler(ctx, m) + + for x := 0; x < 100; x++ { + n := rand.Intn(100) + 1 + for i := 1; i <= n; i++ { + f.UpdateRecords(ctx, uint64(i), nil) + } + + var prev uint64 + assert.Eventually(t, func() bool { + sv, err := m.getUpdate(ctx) + require.NoError(t, err) + assert.Less(t, prev, sv) + prev = sv + t.Log(x, sv) + return int(sv) == n + }, time.Second, time.Millisecond*10) + + f.ClearRecords(ctx) + select { + case <-ctx.Done(): + t.Error("timed out") + case <-m.clear: + } + } +} diff --git a/pkg/grpc/databroker/syncer.go b/pkg/grpc/databroker/syncer.go index 0554fe1db..4abd16bd5 100644 --- a/pkg/grpc/databroker/syncer.go +++ b/pkg/grpc/databroker/syncer.go @@ -15,7 +15,8 @@ import ( ) type syncerConfig struct { - typeURL string + typeURL string + withFastForward bool } // A SyncerOption customizes the syncer configuration. @@ -36,6 +37,15 @@ func WithTypeURL(typeURL string) SyncerOption { } } +// WithFastForward in case updates are coming faster then Update can process them, +// will skip older records to maintain an update rate. +// Use for entries that represent a full state snapshot i.e. Config +func WithFastForward() SyncerOption { + return func(cfg *syncerConfig) { + cfg.withFastForward = true + } +} + // A SyncerHandler receives sync events from the Syncer. type SyncerHandler interface { GetDataBrokerServiceClient() DataBrokerServiceClient @@ -67,7 +77,7 @@ func NewSyncer(id string, handler SyncerHandler, options ...SyncerOption) *Synce bo := backoff.NewExponentialBackOff() bo.MaxElapsedTime = 0 - return &Syncer{ + s := &Syncer{ cfg: getSyncerConfig(options...), handler: handler, backoff: bo, @@ -77,6 +87,10 @@ func NewSyncer(id string, handler SyncerHandler, options ...SyncerOption) *Synce id: id, } + if s.cfg.withFastForward { + s.handler = newFastForwardHandler(s.logCtx(closeCtx), handler) + } + return s } // Close closes the Syncer. @@ -169,7 +183,6 @@ func (syncer *Syncer) sync(ctx context.Context) error { syncer.recordVersion = res.GetRecord().GetVersion() if syncer.cfg.typeURL == "" || syncer.cfg.typeURL == res.GetRecord().GetType() { ctx := logCtxRec(ctx, rec) - log.Debug(ctx).Msg("update records") syncer.handler.UpdateRecords( context.WithValue(ctx, contextkeys.UpdateRecordsVersion, rec.GetVersion()), syncer.serverVersion, []*Record{rec})