mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-21 13:07:13 +02:00
core: fix graceful stop (#4865)
* core/grpc: fix graceful stop * core/http: add graceful stop serve
This commit is contained in:
parent
c9df5156d4
commit
7eb7861f2c
7 changed files with 311 additions and 48 deletions
|
@ -7,6 +7,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
@ -128,12 +129,7 @@ func (c *DataBroker) Register(grpcServer *grpc.Server) {
|
||||||
func (c *DataBroker) Run(ctx context.Context) error {
|
func (c *DataBroker) Run(ctx context.Context) error {
|
||||||
eg, ctx := errgroup.WithContext(ctx)
|
eg, ctx := errgroup.WithContext(ctx)
|
||||||
eg.Go(func() error {
|
eg.Go(func() error {
|
||||||
return c.localGRPCServer.Serve(c.localListener)
|
return grpcutil.ServeWithGracefulStop(ctx, c.localGRPCServer, c.localListener, time.Second*5)
|
||||||
})
|
|
||||||
eg.Go(func() error {
|
|
||||||
<-ctx.Done()
|
|
||||||
c.localGRPCServer.Stop()
|
|
||||||
return nil
|
|
||||||
})
|
})
|
||||||
eg.Go(func() error {
|
eg.Go(func() error {
|
||||||
return c.manager.Run(ctx)
|
return c.manager.Run(ctx)
|
||||||
|
|
|
@ -33,6 +33,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/pkg/envoy/files"
|
"github.com/pomerium/pomerium/pkg/envoy/files"
|
||||||
pom_grpc "github.com/pomerium/pomerium/pkg/grpc"
|
pom_grpc "github.com/pomerium/pomerium/pkg/grpc"
|
||||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||||
|
"github.com/pomerium/pomerium/pkg/httputil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// A Service can be mounted on the control plane.
|
// A Service can be mounted on the control plane.
|
||||||
|
@ -180,31 +181,7 @@ func (srv *Server) Run(ctx context.Context) error {
|
||||||
// start the gRPC server
|
// start the gRPC server
|
||||||
eg.Go(func() error {
|
eg.Go(func() error {
|
||||||
log.Info(ctx).Str("addr", srv.GRPCListener.Addr().String()).Msg("starting control-plane gRPC server")
|
log.Info(ctx).Str("addr", srv.GRPCListener.Addr().String()).Msg("starting control-plane gRPC server")
|
||||||
return srv.GRPCServer.Serve(srv.GRPCListener)
|
return grpcutil.ServeWithGracefulStop(ctx, srv.GRPCServer, srv.GRPCListener, time.Second*5)
|
||||||
})
|
|
||||||
|
|
||||||
// gracefully stop the gRPC server on context cancellation
|
|
||||||
eg.Go(func() error {
|
|
||||||
<-ctx.Done()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
|
||||||
ctx, cleanup := context.WithTimeout(ctx, time.Second*5)
|
|
||||||
defer cleanup()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
srv.GRPCServer.GracefulStop()
|
|
||||||
cancel()
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
<-ctx.Done()
|
|
||||||
srv.GRPCServer.Stop()
|
|
||||||
cancel()
|
|
||||||
}()
|
|
||||||
|
|
||||||
<-ctx.Done()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
})
|
||||||
|
|
||||||
for _, entry := range []struct {
|
for _, entry := range []struct {
|
||||||
|
@ -219,29 +196,13 @@ func (srv *Server) Run(ctx context.Context) error {
|
||||||
{"metrics", srv.MetricsListener, srv.MetricsRouter},
|
{"metrics", srv.MetricsListener, srv.MetricsRouter},
|
||||||
} {
|
} {
|
||||||
entry := entry
|
entry := entry
|
||||||
hsrv := (&http.Server{
|
|
||||||
BaseContext: func(li net.Listener) context.Context {
|
|
||||||
return ctx
|
|
||||||
},
|
|
||||||
Handler: entry.Handler,
|
|
||||||
})
|
|
||||||
|
|
||||||
// start the HTTP server
|
// start the HTTP server
|
||||||
eg.Go(func() error {
|
eg.Go(func() error {
|
||||||
log.Info(ctx).
|
log.Info(ctx).
|
||||||
Str("addr", entry.Listener.Addr().String()).
|
Str("addr", entry.Listener.Addr().String()).
|
||||||
Msgf("starting control-plane %s server", entry.Name)
|
Msgf("starting control-plane %s server", entry.Name)
|
||||||
return hsrv.Serve(entry.Listener)
|
return httputil.ServeWithGracefulStop(ctx, entry.Handler, entry.Listener, time.Second*5)
|
||||||
})
|
|
||||||
|
|
||||||
// gracefully stop the HTTP server on context cancellation
|
|
||||||
eg.Go(func() error {
|
|
||||||
<-ctx.Done()
|
|
||||||
|
|
||||||
ctx, cleanup := context.WithTimeout(ctx, time.Second*5)
|
|
||||||
defer cleanup()
|
|
||||||
|
|
||||||
return hsrv.Shutdown(ctx)
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -64,3 +64,25 @@ func (mc *mergedCtx) Value(key interface{}) interface{} {
|
||||||
}
|
}
|
||||||
return mc.doneCtx.Value(key)
|
return mc.doneCtx.Value(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type onlyValues struct {
|
||||||
|
context.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnlyValues returns a derived context that removes deadlines and cancellation,
|
||||||
|
// but keeps values.
|
||||||
|
func OnlyValues(ctx context.Context) context.Context {
|
||||||
|
return onlyValues{ctx}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o onlyValues) Deadline() (time.Time, bool) {
|
||||||
|
return time.Time{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o onlyValues) Done() <-chan struct{} {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o onlyValues) Err() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
38
pkg/grpcutil/serve.go
Normal file
38
pkg/grpcutil/serve.go
Normal file
|
@ -0,0 +1,38 @@
|
||||||
|
package grpcutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ServeWithGracefulStop serves the gRPC listener until ctx.Done(), and then gracefully stops and waits for gracefulTimeout
|
||||||
|
// before definitively stopping.
|
||||||
|
func ServeWithGracefulStop(ctx context.Context, srv *grpc.Server, li net.Listener, gracefulTimeout time.Duration) error {
|
||||||
|
go func() {
|
||||||
|
// wait for the context to complete
|
||||||
|
<-ctx.Done()
|
||||||
|
|
||||||
|
sctx, stopped := context.WithCancel(context.Background())
|
||||||
|
go func() {
|
||||||
|
srv.GracefulStop()
|
||||||
|
stopped()
|
||||||
|
}()
|
||||||
|
|
||||||
|
wait := time.NewTimer(gracefulTimeout)
|
||||||
|
defer wait.Stop()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-wait.C:
|
||||||
|
case <-sctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// finally stop it completely
|
||||||
|
srv.Stop()
|
||||||
|
}()
|
||||||
|
|
||||||
|
return srv.Serve(li)
|
||||||
|
}
|
100
pkg/grpcutil/serve_test.go
Normal file
100
pkg/grpcutil/serve_test.go
Normal file
|
@ -0,0 +1,100 @@
|
||||||
|
package grpcutil_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
"google.golang.org/grpc/health"
|
||||||
|
"google.golang.org/grpc/health/grpc_health_v1"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestServeWithGracefulStop(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
t.Run("immediate", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
li, err := net.Listen("tcp4", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
srv := grpc.NewServer()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
err = grpcutil.ServeWithGracefulStop(ctx, srv, li, time.Millisecond*100)
|
||||||
|
elapsed := time.Since(now)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Less(t, elapsed, time.Millisecond*100, "should complete immediately")
|
||||||
|
})
|
||||||
|
t.Run("graceful", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
li, err := net.Listen("tcp4", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
srv := grpc.NewServer()
|
||||||
|
hsrv := health.NewServer()
|
||||||
|
grpc_health_v1.RegisterHealthServer(srv, hsrv)
|
||||||
|
hsrv.SetServingStatus("test", grpc_health_v1.HealthCheckResponse_SERVING)
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
eg, ctx := errgroup.WithContext(ctx)
|
||||||
|
eg.Go(func() error {
|
||||||
|
return grpcutil.ServeWithGracefulStop(ctx, srv, li, time.Millisecond*100)
|
||||||
|
})
|
||||||
|
eg.Go(func() error {
|
||||||
|
var cc *grpc.ClientConn
|
||||||
|
for {
|
||||||
|
var err error
|
||||||
|
cc, err = grpc.Dial(li.Addr().String(),
|
||||||
|
grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
c := grpc_health_v1.NewHealthClient(cc)
|
||||||
|
|
||||||
|
// wait till the server is ready
|
||||||
|
for {
|
||||||
|
_, err := c.Check(ctx, &grpc_health_v1.HealthCheckRequest{
|
||||||
|
Service: "test",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// start streaming to hold open the server during graceful stop
|
||||||
|
_, err = c.Watch(context.Background(), &grpc_health_v1.HealthCheckRequest{
|
||||||
|
Service: "test",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
eg.Wait()
|
||||||
|
elapsed := time.Since(now)
|
||||||
|
assert.Greater(t, elapsed, time.Millisecond*100, "should complete after 100ms")
|
||||||
|
})
|
||||||
|
}
|
48
pkg/httputil/serve.go
Normal file
48
pkg/httputil/serve.go
Normal file
|
@ -0,0 +1,48 @@
|
||||||
|
// Package httputil contains additional functionality for working with http.
|
||||||
|
package httputil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/pkg/contextutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ServeWithGracefulStop serves the HTTP listener until ctx.Done(), and then gracefully stops and waits for gracefulTimeout
|
||||||
|
// before definitively stopping.
|
||||||
|
func ServeWithGracefulStop(ctx context.Context, handler http.Handler, li net.Listener, gracefulTimeout time.Duration) error {
|
||||||
|
// create a context that will be used for the http requests
|
||||||
|
// it will only be cancelled when baseCancel is called but will
|
||||||
|
// preserve the values from ctx
|
||||||
|
baseCtx, baseCancel := context.WithCancelCause(contextutil.OnlyValues(ctx))
|
||||||
|
|
||||||
|
srv := http.Server{
|
||||||
|
Handler: handler,
|
||||||
|
BaseContext: func(l net.Listener) context.Context {
|
||||||
|
return baseCtx
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
|
||||||
|
// create a context that will cancel after the graceful timeout
|
||||||
|
timeoutCtx, clearTimeout := context.WithTimeout(context.Background(), gracefulTimeout)
|
||||||
|
defer clearTimeout()
|
||||||
|
|
||||||
|
// shut the http server down
|
||||||
|
_ = srv.Shutdown(timeoutCtx)
|
||||||
|
|
||||||
|
// cancel the base context used for http requests
|
||||||
|
baseCancel(ctx.Err())
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := srv.Serve(li)
|
||||||
|
if errors.Is(err, http.ErrServerClosed) {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
98
pkg/httputil/serve_test.go
Normal file
98
pkg/httputil/serve_test.go
Normal file
|
@ -0,0 +1,98 @@
|
||||||
|
package httputil_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/pkg/httputil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestServeWithGracefulStop(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
t.Run("immediate", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
li, err := net.Listen("tcp4", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
})
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
err = httputil.ServeWithGracefulStop(ctx, h, li, time.Millisecond*100)
|
||||||
|
elapsed := time.Since(now)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Less(t, elapsed, time.Millisecond*100, "should complete immediately")
|
||||||
|
})
|
||||||
|
t.Run("graceful", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
li, err := net.Listen("tcp4", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/":
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
case "/wait":
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("\n"))
|
||||||
|
w.(http.Flusher).Flush()
|
||||||
|
select {
|
||||||
|
case <-r.Context().Done():
|
||||||
|
case <-make(chan struct{}):
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
eg, ctx := errgroup.WithContext(ctx)
|
||||||
|
eg.Go(func() error {
|
||||||
|
return httputil.ServeWithGracefulStop(ctx, h, li, time.Millisecond*100)
|
||||||
|
})
|
||||||
|
eg.Go(func() error {
|
||||||
|
// poll until the server is ready
|
||||||
|
for {
|
||||||
|
res, err := http.Get("http://" + li.Addr().String() + "/")
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
res.Body.Close()
|
||||||
|
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// issue a stream request that will last indefinitely
|
||||||
|
res, err := http.Get("http://" + li.Addr().String() + "/wait")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
// wait until the request completes (should stop after the graceful timeout)
|
||||||
|
io.ReadAll(res.Body)
|
||||||
|
res.Body.Close()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
eg.Wait()
|
||||||
|
elapsed := time.Since(now)
|
||||||
|
assert.Greater(t, elapsed, time.Millisecond*100, "should complete after 100ms")
|
||||||
|
})
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue