From 7eb7861f2ca9fd7e76ada9cbdaf7536b4c8ddfbf Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Fri, 29 Dec 2023 10:18:08 -0700 Subject: [PATCH] core: fix graceful stop (#4865) * core/grpc: fix graceful stop * core/http: add graceful stop serve --- databroker/cache.go | 8 +-- internal/controlplane/server.go | 45 +------------- pkg/contextutil/contextutil.go | 22 +++++++ pkg/grpcutil/serve.go | 38 ++++++++++++ pkg/grpcutil/serve_test.go | 100 ++++++++++++++++++++++++++++++++ pkg/httputil/serve.go | 48 +++++++++++++++ pkg/httputil/serve_test.go | 98 +++++++++++++++++++++++++++++++ 7 files changed, 311 insertions(+), 48 deletions(-) create mode 100644 pkg/grpcutil/serve.go create mode 100644 pkg/grpcutil/serve_test.go create mode 100644 pkg/httputil/serve.go create mode 100644 pkg/httputil/serve_test.go diff --git a/databroker/cache.go b/databroker/cache.go index e4cbc9536..1762e5108 100644 --- a/databroker/cache.go +++ b/databroker/cache.go @@ -7,6 +7,7 @@ import ( "context" "fmt" "net" + "time" "github.com/rs/zerolog" "golang.org/x/sync/errgroup" @@ -128,12 +129,7 @@ func (c *DataBroker) Register(grpcServer *grpc.Server) { func (c *DataBroker) Run(ctx context.Context) error { eg, ctx := errgroup.WithContext(ctx) eg.Go(func() error { - return c.localGRPCServer.Serve(c.localListener) - }) - eg.Go(func() error { - <-ctx.Done() - c.localGRPCServer.Stop() - return nil + return grpcutil.ServeWithGracefulStop(ctx, c.localGRPCServer, c.localListener, time.Second*5) }) eg.Go(func() error { return c.manager.Run(ctx) diff --git a/internal/controlplane/server.go b/internal/controlplane/server.go index 3cc6cbb3c..cbdf92a0f 100644 --- a/internal/controlplane/server.go +++ b/internal/controlplane/server.go @@ -33,6 +33,7 @@ import ( "github.com/pomerium/pomerium/pkg/envoy/files" pom_grpc "github.com/pomerium/pomerium/pkg/grpc" "github.com/pomerium/pomerium/pkg/grpcutil" + "github.com/pomerium/pomerium/pkg/httputil" ) // A Service can be mounted on the control plane. @@ -180,31 +181,7 @@ func (srv *Server) Run(ctx context.Context) error { // start the gRPC server eg.Go(func() error { log.Info(ctx).Str("addr", srv.GRPCListener.Addr().String()).Msg("starting control-plane gRPC server") - return srv.GRPCServer.Serve(srv.GRPCListener) - }) - - // gracefully stop the gRPC server on context cancellation - eg.Go(func() error { - <-ctx.Done() - - ctx, cancel := context.WithCancel(ctx) - ctx, cleanup := context.WithTimeout(ctx, time.Second*5) - defer cleanup() - - go func() { - srv.GRPCServer.GracefulStop() - cancel() - }() - - go func() { - <-ctx.Done() - srv.GRPCServer.Stop() - cancel() - }() - - <-ctx.Done() - - return nil + return grpcutil.ServeWithGracefulStop(ctx, srv.GRPCServer, srv.GRPCListener, time.Second*5) }) for _, entry := range []struct { @@ -219,29 +196,13 @@ func (srv *Server) Run(ctx context.Context) error { {"metrics", srv.MetricsListener, srv.MetricsRouter}, } { entry := entry - hsrv := (&http.Server{ - BaseContext: func(li net.Listener) context.Context { - return ctx - }, - Handler: entry.Handler, - }) // start the HTTP server eg.Go(func() error { log.Info(ctx). Str("addr", entry.Listener.Addr().String()). Msgf("starting control-plane %s server", entry.Name) - return hsrv.Serve(entry.Listener) - }) - - // gracefully stop the HTTP server on context cancellation - eg.Go(func() error { - <-ctx.Done() - - ctx, cleanup := context.WithTimeout(ctx, time.Second*5) - defer cleanup() - - return hsrv.Shutdown(ctx) + return httputil.ServeWithGracefulStop(ctx, entry.Handler, entry.Listener, time.Second*5) }) } diff --git a/pkg/contextutil/contextutil.go b/pkg/contextutil/contextutil.go index cd12f17e6..a4770a5d7 100644 --- a/pkg/contextutil/contextutil.go +++ b/pkg/contextutil/contextutil.go @@ -64,3 +64,25 @@ func (mc *mergedCtx) Value(key interface{}) interface{} { } return mc.doneCtx.Value(key) } + +type onlyValues struct { + context.Context +} + +// OnlyValues returns a derived context that removes deadlines and cancellation, +// but keeps values. +func OnlyValues(ctx context.Context) context.Context { + return onlyValues{ctx} +} + +func (o onlyValues) Deadline() (time.Time, bool) { + return time.Time{}, false +} + +func (o onlyValues) Done() <-chan struct{} { + return nil +} + +func (o onlyValues) Err() error { + return nil +} diff --git a/pkg/grpcutil/serve.go b/pkg/grpcutil/serve.go new file mode 100644 index 000000000..16ed15b96 --- /dev/null +++ b/pkg/grpcutil/serve.go @@ -0,0 +1,38 @@ +package grpcutil + +import ( + "context" + "net" + "time" + + "google.golang.org/grpc" +) + +// ServeWithGracefulStop serves the gRPC listener until ctx.Done(), and then gracefully stops and waits for gracefulTimeout +// before definitively stopping. +func ServeWithGracefulStop(ctx context.Context, srv *grpc.Server, li net.Listener, gracefulTimeout time.Duration) error { + go func() { + // wait for the context to complete + <-ctx.Done() + + sctx, stopped := context.WithCancel(context.Background()) + go func() { + srv.GracefulStop() + stopped() + }() + + wait := time.NewTimer(gracefulTimeout) + defer wait.Stop() + + select { + case <-wait.C: + case <-sctx.Done(): + return + } + + // finally stop it completely + srv.Stop() + }() + + return srv.Serve(li) +} diff --git a/pkg/grpcutil/serve_test.go b/pkg/grpcutil/serve_test.go new file mode 100644 index 000000000..c071b4129 --- /dev/null +++ b/pkg/grpcutil/serve_test.go @@ -0,0 +1,100 @@ +package grpcutil_test + +import ( + "context" + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/health" + "google.golang.org/grpc/health/grpc_health_v1" + + "github.com/pomerium/pomerium/pkg/grpcutil" +) + +func TestServeWithGracefulStop(t *testing.T) { + t.Parallel() + + t.Run("immediate", func(t *testing.T) { + t.Parallel() + + li, err := net.Listen("tcp4", "127.0.0.1:0") + require.NoError(t, err) + + srv := grpc.NewServer() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + now := time.Now() + err = grpcutil.ServeWithGracefulStop(ctx, srv, li, time.Millisecond*100) + elapsed := time.Since(now) + assert.Nil(t, err) + assert.Less(t, elapsed, time.Millisecond*100, "should complete immediately") + }) + t.Run("graceful", func(t *testing.T) { + t.Parallel() + + li, err := net.Listen("tcp4", "127.0.0.1:0") + require.NoError(t, err) + + srv := grpc.NewServer() + hsrv := health.NewServer() + grpc_health_v1.RegisterHealthServer(srv, hsrv) + hsrv.SetServingStatus("test", grpc_health_v1.HealthCheckResponse_SERVING) + + now := time.Now() + ctx, cancel := context.WithCancel(context.Background()) + eg, ctx := errgroup.WithContext(ctx) + eg.Go(func() error { + return grpcutil.ServeWithGracefulStop(ctx, srv, li, time.Millisecond*100) + }) + eg.Go(func() error { + var cc *grpc.ClientConn + for { + var err error + cc, err = grpc.Dial(li.Addr().String(), + grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + continue + } + + break + } + + c := grpc_health_v1.NewHealthClient(cc) + + // wait till the server is ready + for { + _, err := c.Check(ctx, &grpc_health_v1.HealthCheckRequest{ + Service: "test", + }) + if err != nil { + return err + } + + break + } + + // start streaming to hold open the server during graceful stop + _, err = c.Watch(context.Background(), &grpc_health_v1.HealthCheckRequest{ + Service: "test", + }) + if err != nil { + return err + } + + cancel() + + return nil + }) + eg.Wait() + elapsed := time.Since(now) + assert.Greater(t, elapsed, time.Millisecond*100, "should complete after 100ms") + }) +} diff --git a/pkg/httputil/serve.go b/pkg/httputil/serve.go new file mode 100644 index 000000000..19fc72f5f --- /dev/null +++ b/pkg/httputil/serve.go @@ -0,0 +1,48 @@ +// Package httputil contains additional functionality for working with http. +package httputil + +import ( + "context" + "errors" + "net" + "net/http" + "time" + + "github.com/pomerium/pomerium/pkg/contextutil" +) + +// ServeWithGracefulStop serves the HTTP listener until ctx.Done(), and then gracefully stops and waits for gracefulTimeout +// before definitively stopping. +func ServeWithGracefulStop(ctx context.Context, handler http.Handler, li net.Listener, gracefulTimeout time.Duration) error { + // create a context that will be used for the http requests + // it will only be cancelled when baseCancel is called but will + // preserve the values from ctx + baseCtx, baseCancel := context.WithCancelCause(contextutil.OnlyValues(ctx)) + + srv := http.Server{ + Handler: handler, + BaseContext: func(l net.Listener) context.Context { + return baseCtx + }, + } + + go func() { + <-ctx.Done() + + // create a context that will cancel after the graceful timeout + timeoutCtx, clearTimeout := context.WithTimeout(context.Background(), gracefulTimeout) + defer clearTimeout() + + // shut the http server down + _ = srv.Shutdown(timeoutCtx) + + // cancel the base context used for http requests + baseCancel(ctx.Err()) + }() + + err := srv.Serve(li) + if errors.Is(err, http.ErrServerClosed) { + err = nil + } + return err +} diff --git a/pkg/httputil/serve_test.go b/pkg/httputil/serve_test.go new file mode 100644 index 000000000..553f29e71 --- /dev/null +++ b/pkg/httputil/serve_test.go @@ -0,0 +1,98 @@ +package httputil_test + +import ( + "context" + "io" + "net" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + + "github.com/pomerium/pomerium/pkg/httputil" +) + +func TestServeWithGracefulStop(t *testing.T) { + t.Parallel() + + t.Run("immediate", func(t *testing.T) { + t.Parallel() + + li, err := net.Listen("tcp4", "127.0.0.1:0") + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + }) + + now := time.Now() + err = httputil.ServeWithGracefulStop(ctx, h, li, time.Millisecond*100) + elapsed := time.Since(now) + assert.Nil(t, err) + assert.Less(t, elapsed, time.Millisecond*100, "should complete immediately") + }) + t.Run("graceful", func(t *testing.T) { + t.Parallel() + + li, err := net.Listen("tcp4", "127.0.0.1:0") + require.NoError(t, err) + + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/": + w.WriteHeader(http.StatusNoContent) + case "/wait": + w.WriteHeader(http.StatusOK) + w.Write([]byte("\n")) + w.(http.Flusher).Flush() + select { + case <-r.Context().Done(): + case <-make(chan struct{}): + } + default: + http.NotFound(w, r) + } + }) + + now := time.Now() + ctx, cancel := context.WithCancel(context.Background()) + eg, ctx := errgroup.WithContext(ctx) + eg.Go(func() error { + return httputil.ServeWithGracefulStop(ctx, h, li, time.Millisecond*100) + }) + eg.Go(func() error { + // poll until the server is ready + for { + res, err := http.Get("http://" + li.Addr().String() + "/") + if err != nil { + continue + } + res.Body.Close() + + break + } + + // issue a stream request that will last indefinitely + res, err := http.Get("http://" + li.Addr().String() + "/wait") + if err != nil { + return err + } + + cancel() + + // wait until the request completes (should stop after the graceful timeout) + io.ReadAll(res.Body) + res.Body.Close() + + return nil + }) + eg.Wait() + elapsed := time.Since(now) + assert.Greater(t, elapsed, time.Millisecond*100, "should complete after 100ms") + }) +}