mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-28 18:06:34 +02:00
This also replaces instances where we manually write "return ctx.Err()" with "return context.Cause(ctx)" which is functionally identical, but will also correctly propagate cause errors if present.
113 lines
2.8 KiB
Go
113 lines
2.8 KiB
Go
package controller_test
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"errors"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/mock"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/pomerium/pomerium/config"
|
|
"github.com/pomerium/pomerium/internal/zero/controller"
|
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
)
|
|
|
|
type mockConfigSource struct {
|
|
mock.Mock
|
|
config.Source
|
|
}
|
|
|
|
func (s *mockConfigSource) GetConfig() *config.Config {
|
|
args := s.Called()
|
|
return args.Get(0).(*config.Config)
|
|
}
|
|
|
|
func (s *mockConfigSource) OnConfigChange(ctx context.Context, cl config.ChangeListener) {
|
|
s.Called(ctx, cl)
|
|
}
|
|
|
|
func TestDatabrokerRestart(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
newConfig := func() *config.Config {
|
|
return &config.Config{
|
|
Options: &config.Options{
|
|
SharedKey: base64.StdEncoding.EncodeToString(cryptutil.NewKey()),
|
|
},
|
|
GRPCPort: ":12345",
|
|
}
|
|
}
|
|
|
|
t.Run("no error", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
src := new(mockConfigSource)
|
|
src.On("OnConfigChange", mock.Anything, mock.Anything).Once()
|
|
src.On("GetConfig").Once().Return(newConfig())
|
|
|
|
ctx := context.Background()
|
|
r := controller.NewDatabrokerRestartRunner(ctx, src)
|
|
defer r.Close()
|
|
|
|
err := r.Run(ctx, func(_ context.Context, _ databroker.DataBrokerServiceClient) error {
|
|
return nil
|
|
})
|
|
require.NoError(t, err)
|
|
})
|
|
t.Run("error, retry", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
src := new(mockConfigSource)
|
|
src.On("OnConfigChange", mock.Anything, mock.Anything).Once()
|
|
src.On("GetConfig").Once().Return(newConfig())
|
|
|
|
ctx := context.Background()
|
|
r := controller.NewDatabrokerRestartRunner(ctx, src)
|
|
defer r.Close()
|
|
|
|
count := 0
|
|
err := r.Run(ctx, func(_ context.Context, _ databroker.DataBrokerServiceClient) error {
|
|
count++
|
|
if count == 1 {
|
|
return errors.New("simulated error")
|
|
}
|
|
return nil
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, 2, count)
|
|
})
|
|
t.Run("config changed, execution restarted", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
src := new(mockConfigSource)
|
|
var cl config.ChangeListener
|
|
src.On("OnConfigChange", mock.Anything, mock.Anything).Once().Run(func(args mock.Arguments) {
|
|
cl = args.Get(1).(config.ChangeListener)
|
|
})
|
|
src.On("GetConfig").Once().Return(newConfig())
|
|
|
|
ctx := context.Background()
|
|
r := controller.NewDatabrokerRestartRunner(ctx, src)
|
|
defer r.Close()
|
|
|
|
count := 0
|
|
var clients [2]databroker.DataBrokerServiceClient
|
|
err := r.Run(ctx, func(ctx context.Context, client databroker.DataBrokerServiceClient) error {
|
|
clients[count] = client
|
|
count++
|
|
if count == 1 {
|
|
cl(context.Background(), newConfig())
|
|
<-ctx.Done()
|
|
require.ErrorIs(t, context.Cause(ctx), controller.ErrBootstrapConfigurationChanged)
|
|
return context.Cause(ctx)
|
|
}
|
|
return nil
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, 2, count)
|
|
require.NotSame(t, clients[0], clients[1])
|
|
})
|
|
}
|