From 0ac7e45a212b1692df541e01a7ee7a7a3ab0bbdb Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Thu, 28 Jul 2022 15:38:38 -0600 Subject: [PATCH] atomicutil: use atomicutil.Value wherever possible (#3517) * atomicutil: use atomicutil.Value wherever possible * fix test * fix mux router --- authenticate/authenticate.go | 7 +++--- authenticate/handlers_test.go | 15 +++++++------ authenticate/state.go | 19 ----------------- authorize/authorize.go | 7 +++--- authorize/check_response_test.go | 5 +++-- authorize/grpc_test.go | 7 +++--- authorize/state.go | 19 ----------------- config/options.go | 23 +++----------------- databroker/databroker.go | 32 +++++++++++++++------------- databroker/databroker_test.go | 4 ++-- internal/atomicutil/value.go | 11 +++++++++- internal/atomicutil/value_test.go | 21 ++++++++++++++++++ internal/autocert/manager.go | 12 ++++------- internal/controlplane/events_test.go | 21 +++++++++--------- internal/controlplane/server.go | 29 ++++++++----------------- internal/identity/manager/config.go | 19 ----------------- internal/identity/manager/manager.go | 5 +++-- internal/identity/manager/sync.go | 5 +++-- internal/identity/providers.go | 28 ------------------------ internal/log/log.go | 11 +++++----- proxy/handlers_test.go | 3 ++- proxy/proxy.go | 14 ++++++------ proxy/state.go | 19 ----------------- 23 files changed, 121 insertions(+), 215 deletions(-) create mode 100644 internal/atomicutil/value_test.go diff --git a/authenticate/authenticate.go b/authenticate/authenticate.go index 578bd1f0e..206c0c3b6 100644 --- a/authenticate/authenticate.go +++ b/authenticate/authenticate.go @@ -9,6 +9,7 @@ import ( "github.com/pomerium/pomerium/authenticate/handlers/webauthn" "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/atomicutil" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/pkg/cryptutil" ) @@ -39,8 +40,8 @@ func ValidateOptions(o *config.Options) error { // Authenticate contains data required to run the authenticate service. type Authenticate struct { cfg *authenticateConfig - options *config.AtomicOptions - state *atomicAuthenticateState + options *atomicutil.Value[*config.Options] + state *atomicutil.Value[*authenticateState] webauthn *webauthn.Handler } @@ -49,7 +50,7 @@ func New(cfg *config.Config, options ...Option) (*Authenticate, error) { a := &Authenticate{ cfg: getAuthenticateConfig(options...), options: config.NewAtomicOptions(), - state: newAtomicAuthenticateState(newAuthenticateState()), + state: atomicutil.NewValue(newAuthenticateState()), } a.webauthn = webauthn.New(a.getWebauthnState) diff --git a/authenticate/handlers_test.go b/authenticate/handlers_test.go index 111e087e0..71fadd942 100644 --- a/authenticate/handlers_test.go +++ b/authenticate/handlers_test.go @@ -26,6 +26,7 @@ import ( "github.com/pomerium/pomerium/authenticate/handlers/webauthn" "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/atomicutil" "github.com/pomerium/pomerium/internal/encoding" "github.com/pomerium/pomerium/internal/encoding/jws" "github.com/pomerium/pomerium/internal/encoding/mock" @@ -44,7 +45,7 @@ import ( func testAuthenticate() *Authenticate { redirectURL, _ := url.Parse("https://auth.example.com/oauth/callback") var auth Authenticate - auth.state = newAtomicAuthenticateState(&authenticateState{ + auth.state = atomicutil.NewValue(&authenticateState{ redirectURL: redirectURL, cookieSecret: cryptutil.NewKey(), }) @@ -150,7 +151,7 @@ func TestAuthenticate_SignIn(t *testing.T) { cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) { return tt.provider, nil })), - state: newAtomicAuthenticateState(&authenticateState{ + state: atomicutil.NewValue(&authenticateState{ sharedCipher: sharedCipher, sessionStore: tt.session, redirectURL: uriParseHelper("https://some.example"), @@ -306,7 +307,7 @@ func TestAuthenticate_SignOut(t *testing.T) { cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) { return tt.provider, nil })), - state: newAtomicAuthenticateState(&authenticateState{ + state: atomicutil.NewValue(&authenticateState{ sessionStore: tt.sessionStore, encryptedEncoder: mock.Encoder{}, sharedEncoder: mock.Encoder{}, @@ -419,7 +420,7 @@ func TestAuthenticate_OAuthCallback(t *testing.T) { cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) { return tt.provider, nil })), - state: newAtomicAuthenticateState(&authenticateState{ + state: atomicutil.NewValue(&authenticateState{ dataBrokerClient: mockDataBrokerServiceClient{ get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) { return nil, fmt.Errorf("not implemented") @@ -554,7 +555,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) { cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) { return tt.provider, nil })), - state: newAtomicAuthenticateState(&authenticateState{ + state: atomicutil.NewValue(&authenticateState{ cookieSecret: cryptutil.NewKey(), redirectURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"), sessionStore: tt.session, @@ -644,7 +645,7 @@ func TestAuthenticate_userInfo(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("GET", "https://authenticate.service.cluster.local/.pomerium/?pomerium_redirect_uri=https://www.example.com", nil) var a Authenticate - a.state = newAtomicAuthenticateState(&authenticateState{ + a.state = atomicutil.NewValue(&authenticateState{ cookieSecret: cryptutil.NewKey(), }) a.options = config.NewAtomicOptions() @@ -709,7 +710,7 @@ func TestAuthenticate_userInfo(t *testing.T) { }) a := &Authenticate{ options: o, - state: newAtomicAuthenticateState(&authenticateState{ + state: atomicutil.NewValue(&authenticateState{ sessionStore: tt.sessionStore, encryptedEncoder: signer, sharedEncoder: signer, diff --git a/authenticate/state.go b/authenticate/state.go index 0f580bca2..0ef549cb7 100644 --- a/authenticate/state.go +++ b/authenticate/state.go @@ -6,7 +6,6 @@ import ( "encoding/base64" "fmt" "net/url" - "sync/atomic" "github.com/go-jose/go-jose/v3" @@ -172,21 +171,3 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err return state, nil } - -type atomicAuthenticateState struct { - atomic.Value -} - -func newAtomicAuthenticateState(state *authenticateState) *atomicAuthenticateState { - aas := new(atomicAuthenticateState) - aas.Store(state) - return aas -} - -func (aas *atomicAuthenticateState) Load() *authenticateState { - return aas.Value.Load().(*authenticateState) -} - -func (aas *atomicAuthenticateState) Store(state *authenticateState) { - aas.Value.Store(state) -} diff --git a/authorize/authorize.go b/authorize/authorize.go index ccee02be6..e0d152cf1 100644 --- a/authorize/authorize.go +++ b/authorize/authorize.go @@ -13,6 +13,7 @@ import ( "github.com/pomerium/pomerium/authorize/evaluator" "github.com/pomerium/pomerium/authorize/internal/store" "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/atomicutil" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/telemetry/metrics" "github.com/pomerium/pomerium/internal/telemetry/trace" @@ -24,9 +25,9 @@ import ( // Authorize struct holds type Authorize struct { - state *atomicAuthorizeState + state *atomicutil.Value[*authorizeState] store *store.Store - currentOptions *config.AtomicOptions + currentOptions *atomicutil.Value[*config.Options] accessTracker *AccessTracker globalCache storage.Cache @@ -49,7 +50,7 @@ func New(cfg *config.Config) (*Authorize, error) { if err != nil { return nil, err } - a.state = newAtomicAuthorizeState(state) + a.state = atomicutil.NewValue(state) return a, nil } diff --git a/authorize/check_response_test.go b/authorize/check_response_test.go index d868461fa..1e201edd3 100644 --- a/authorize/check_response_test.go +++ b/authorize/check_response_test.go @@ -18,6 +18,7 @@ import ( "github.com/pomerium/pomerium/authorize/evaluator" "github.com/pomerium/pomerium/authorize/internal/store" "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/atomicutil" "github.com/pomerium/pomerium/internal/encoding/jws" "github.com/pomerium/pomerium/internal/testutil" ) @@ -34,7 +35,7 @@ func TestAuthorize_okResponse(t *testing.T) { }}, JWTClaimsHeaders: config.NewJWTClaimHeaders("email"), } - a := &Authorize{currentOptions: config.NewAtomicOptions(), state: newAtomicAuthorizeState(new(authorizeState))} + a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))} encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0}) a.state.Load().encoder = encoder a.currentOptions.Store(opt) @@ -90,7 +91,7 @@ func TestAuthorize_okResponse(t *testing.T) { } func TestAuthorize_deniedResponse(t *testing.T) { - a := &Authorize{currentOptions: config.NewAtomicOptions(), state: newAtomicAuthorizeState(new(authorizeState))} + a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))} encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0}) a.state.Load().encoder = encoder a.currentOptions.Store(&config.Options{ diff --git a/authorize/grpc_test.go b/authorize/grpc_test.go index 6ffb6258d..28a8cdbbb 100644 --- a/authorize/grpc_test.go +++ b/authorize/grpc_test.go @@ -15,6 +15,7 @@ import ( "github.com/pomerium/pomerium/authorize/evaluator" "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/atomicutil" "github.com/pomerium/pomerium/internal/encoding/jws" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/sessions" @@ -46,7 +47,7 @@ yE+vPxsiUkvQHdO2fojCkY8jg70jxM+gu59tPDNbw3Uh/2Ij310FgTHsnGQMyA== -----END CERTIFICATE-----` func Test_getEvaluatorRequest(t *testing.T) { - a := &Authorize{currentOptions: config.NewAtomicOptions(), state: newAtomicAuthorizeState(new(authorizeState))} + a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))} encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0}) a.state.Load().encoder = encoder a.currentOptions.Store(&config.Options{ @@ -247,7 +248,7 @@ func Test_handleForwardAuth(t *testing.T) { for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { - a := &Authorize{currentOptions: config.NewAtomicOptions(), state: newAtomicAuthorizeState(new(authorizeState))} + a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))} a.currentOptions.Store(&config.Options{ForwardAuthURLString: tc.forwardAuthURL}) got := a.isForwardAuth(tc.checkReq) @@ -260,7 +261,7 @@ func Test_handleForwardAuth(t *testing.T) { } func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) { - a := &Authorize{currentOptions: config.NewAtomicOptions(), state: newAtomicAuthorizeState(new(authorizeState))} + a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))} encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0}) a.state.Load().encoder = encoder a.currentOptions.Store(&config.Options{ diff --git a/authorize/state.go b/authorize/state.go index 9214dfe42..6440c4a10 100644 --- a/authorize/state.go +++ b/authorize/state.go @@ -3,7 +3,6 @@ package authorize import ( "context" "fmt" - "sync/atomic" googlegrpc "google.golang.org/grpc" @@ -79,21 +78,3 @@ func newAuthorizeStateFromConfig(cfg *config.Config, store *store.Store) (*autho return state, nil } - -type atomicAuthorizeState struct { - value atomic.Value -} - -func newAtomicAuthorizeState(state *authorizeState) *atomicAuthorizeState { - aas := new(atomicAuthorizeState) - aas.Store(state) - return aas -} - -func (aas *atomicAuthorizeState) Load() *authorizeState { - return aas.value.Load().(*authorizeState) -} - -func (aas *atomicAuthorizeState) Store(state *authorizeState) { - aas.value.Store(state) -} diff --git a/config/options.go b/config/options.go index e65e3d035..d4d5637f4 100644 --- a/config/options.go +++ b/config/options.go @@ -12,13 +12,13 @@ import ( "path/filepath" "reflect" "strings" - "sync/atomic" "time" "github.com/mitchellh/mapstructure" "github.com/spf13/viper" "github.com/volatiletech/null/v9" + "github.com/pomerium/pomerium/internal/atomicutil" "github.com/pomerium/pomerium/internal/directory/azure" "github.com/pomerium/pomerium/internal/directory/github" "github.com/pomerium/pomerium/internal/directory/gitlab" @@ -1586,24 +1586,7 @@ func min(x, y int) int { return y } -// AtomicOptions are Options that can be access atomically. -type AtomicOptions struct { - value atomic.Value -} - // NewAtomicOptions creates a new AtomicOptions. -func NewAtomicOptions() *AtomicOptions { - ao := new(AtomicOptions) - ao.Store(new(Options)) - return ao -} - -// Load loads the options. -func (a *AtomicOptions) Load() *Options { - return a.value.Load().(*Options) -} - -// Store stores the options. -func (a *AtomicOptions) Store(options *Options) { - a.value.Store(options) +func NewAtomicOptions() *atomicutil.Value[*Options] { + return atomicutil.NewValue(new(Options)) } diff --git a/databroker/databroker.go b/databroker/databroker.go index 17ce54f61..7cab0e276 100644 --- a/databroker/databroker.go +++ b/databroker/databroker.go @@ -3,11 +3,11 @@ package databroker import ( "context" - "sync/atomic" "google.golang.org/protobuf/types/known/emptypb" "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/atomicutil" "github.com/pomerium/pomerium/internal/databroker" databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker" registrypb "github.com/pomerium/pomerium/pkg/grpc/registry" @@ -17,12 +17,14 @@ import ( // A dataBrokerServer implements the data broker service interface. type dataBrokerServer struct { server *databroker.Server - sharedKey atomic.Value + sharedKey *atomicutil.Value[[]byte] } // newDataBrokerServer creates a new databroker service server. func newDataBrokerServer(cfg *config.Config) *dataBrokerServer { - srv := &dataBrokerServer{} + srv := &dataBrokerServer{ + sharedKey: atomicutil.NewValue([]byte{}), + } srv.server = databroker.New(srv.getOptions(cfg)...) srv.setKey(cfg) return srv @@ -57,63 +59,63 @@ func (srv *dataBrokerServer) setKey(cfg *config.Config) { // Databroker functions func (srv *dataBrokerServer) AcquireLease(ctx context.Context, req *databrokerpb.AcquireLeaseRequest) (*databrokerpb.AcquireLeaseResponse, error) { - if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil { + if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load()); err != nil { return nil, err } return srv.server.AcquireLease(ctx, req) } func (srv *dataBrokerServer) Get(ctx context.Context, req *databrokerpb.GetRequest) (*databrokerpb.GetResponse, error) { - if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil { + if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load()); err != nil { return nil, err } return srv.server.Get(ctx, req) } func (srv *dataBrokerServer) Query(ctx context.Context, req *databrokerpb.QueryRequest) (*databrokerpb.QueryResponse, error) { - if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil { + if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load()); err != nil { return nil, err } return srv.server.Query(ctx, req) } func (srv *dataBrokerServer) Put(ctx context.Context, req *databrokerpb.PutRequest) (*databrokerpb.PutResponse, error) { - if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil { + if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load()); err != nil { return nil, err } return srv.server.Put(ctx, req) } func (srv *dataBrokerServer) ReleaseLease(ctx context.Context, req *databrokerpb.ReleaseLeaseRequest) (*emptypb.Empty, error) { - if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil { + if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load()); err != nil { return nil, err } return srv.server.ReleaseLease(ctx, req) } func (srv *dataBrokerServer) RenewLease(ctx context.Context, req *databrokerpb.RenewLeaseRequest) (*emptypb.Empty, error) { - if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil { + if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load()); err != nil { return nil, err } return srv.server.RenewLease(ctx, req) } func (srv *dataBrokerServer) SetOptions(ctx context.Context, req *databrokerpb.SetOptionsRequest) (*databrokerpb.SetOptionsResponse, error) { - if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil { + if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load()); err != nil { return nil, err } return srv.server.SetOptions(ctx, req) } func (srv *dataBrokerServer) Sync(req *databrokerpb.SyncRequest, stream databrokerpb.DataBrokerService_SyncServer) error { - if err := grpcutil.RequireSignedJWT(stream.Context(), srv.sharedKey.Load().([]byte)); err != nil { + if err := grpcutil.RequireSignedJWT(stream.Context(), srv.sharedKey.Load()); err != nil { return err } return srv.server.Sync(req, stream) } func (srv *dataBrokerServer) SyncLatest(req *databrokerpb.SyncLatestRequest, stream databrokerpb.DataBrokerService_SyncLatestServer) error { - if err := grpcutil.RequireSignedJWT(stream.Context(), srv.sharedKey.Load().([]byte)); err != nil { + if err := grpcutil.RequireSignedJWT(stream.Context(), srv.sharedKey.Load()); err != nil { return err } return srv.server.SyncLatest(req, stream) @@ -122,21 +124,21 @@ func (srv *dataBrokerServer) SyncLatest(req *databrokerpb.SyncLatestRequest, str // Registry functions func (srv *dataBrokerServer) Report(ctx context.Context, req *registrypb.RegisterRequest) (*registrypb.RegisterResponse, error) { - if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil { + if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load()); err != nil { return nil, err } return srv.server.Report(ctx, req) } func (srv *dataBrokerServer) List(ctx context.Context, req *registrypb.ListRequest) (*registrypb.ServiceList, error) { - if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil { + if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load()); err != nil { return nil, err } return srv.server.List(ctx, req) } func (srv *dataBrokerServer) Watch(req *registrypb.ListRequest, stream registrypb.Registry_WatchServer) error { - if err := grpcutil.RequireSignedJWT(stream.Context(), srv.sharedKey.Load().([]byte)); err != nil { + if err := grpcutil.RequireSignedJWT(stream.Context(), srv.sharedKey.Load()); err != nil { return err } return srv.server.Watch(req, stream) diff --git a/databroker/databroker_test.go b/databroker/databroker_test.go index ee6b75821..2db752384 100644 --- a/databroker/databroker_test.go +++ b/databroker/databroker_test.go @@ -12,6 +12,7 @@ import ( "google.golang.org/grpc/status" "google.golang.org/grpc/test/bufconn" + "github.com/pomerium/pomerium/internal/atomicutil" internal_databroker "github.com/pomerium/pomerium/internal/databroker" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/pkg/grpc/databroker" @@ -28,8 +29,7 @@ func init() { lis = bufconn.Listen(bufSize) s := grpc.NewServer() internalSrv := internal_databroker.New() - srv := &dataBrokerServer{server: internalSrv} - srv.sharedKey.Store([]byte{}) + srv := &dataBrokerServer{server: internalSrv, sharedKey: atomicutil.NewValue([]byte{})} databroker.RegisterDataBrokerServiceServer(s, srv) go func() { diff --git a/internal/atomicutil/value.go b/internal/atomicutil/value.go index 58dc71c6b..665ff3241 100644 --- a/internal/atomicutil/value.go +++ b/internal/atomicutil/value.go @@ -17,7 +17,16 @@ func NewValue[T any](init T) *Value[T] { // Load loads the value atomically. func (v *Value[T]) Load() T { - return v.value.Load().(T) + var def T + if v == nil { + return def + } + + cur := v.value.Load() + if cur == nil { + return def + } + return cur.(T) } // Store stores the value atomically. diff --git a/internal/atomicutil/value_test.go b/internal/atomicutil/value_test.go new file mode 100644 index 000000000..9522f425c --- /dev/null +++ b/internal/atomicutil/value_test.go @@ -0,0 +1,21 @@ +package atomicutil + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestValue(t *testing.T) { + v := NewValue(5) + assert.Equal(t, 5, v.Load()) + + t.Run("nil", func(t *testing.T) { + var v *Value[int] + assert.Equal(t, 0, v.Load()) + }) + t.Run("default", func(t *testing.T) { + var v Value[int] + assert.Equal(t, 0, v.Load()) + }) +} diff --git a/internal/autocert/manager.go b/internal/autocert/manager.go index 605ae7411..c38497986 100644 --- a/internal/autocert/manager.go +++ b/internal/autocert/manager.go @@ -9,7 +9,6 @@ import ( "net/http" "sort" "sync" - "sync/atomic" "time" "github.com/caddyserver/certmagic" @@ -18,6 +17,7 @@ import ( "go.uber.org/zap" "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/atomicutil" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/telemetry/metrics" @@ -46,7 +46,7 @@ type Manager struct { mu sync.RWMutex config *config.Config certmagic *certmagic.Config - acmeMgr atomic.Value + acmeMgr *atomicutil.Value[*certmagic.ACMEIssuer] srv *http.Server *ocspCache @@ -87,6 +87,7 @@ func newManager(ctx context.Context, mgr := &Manager{ src: src, acmeTemplate: acmeTemplate, + acmeMgr: atomicutil.NewValue(new(certmagic.ACMEIssuer)), certmagic: certmagicConfig, ocspCache: ocspRespCache, } @@ -324,12 +325,7 @@ func (mgr *Manager) updateServer(ctx context.Context, cfg *config.Config) { } func (mgr *Manager) handleHTTPChallenge(w http.ResponseWriter, r *http.Request) bool { - obj := mgr.acmeMgr.Load() - if obj == nil { - return false - } - acmeMgr := obj.(*certmagic.ACMEIssuer) - return acmeMgr.HandleHTTPChallenge(w, r) + return mgr.acmeMgr.Load().HandleHTTPChallenge(w, r) } // GetConfig gets the config. diff --git a/internal/controlplane/events_test.go b/internal/controlplane/events_test.go index 8425df512..8bc628a2b 100644 --- a/internal/controlplane/events_test.go +++ b/internal/controlplane/events_test.go @@ -12,6 +12,7 @@ import ( "google.golang.org/grpc" "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/atomicutil" "github.com/pomerium/pomerium/pkg/cryptutil" databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/events" @@ -73,17 +74,17 @@ func TestEvents(t *testing.T) { srv := &Server{ haveSetCapacity: make(map[string]bool), - } - srv.currentConfig.Store(versionedConfig{ - Config: &config.Config{ - OutboundPort: outboundPort, - Options: &config.Options{ - SharedKey: cryptutil.NewBase64Key(), - DataBrokerURLString: "http://" + li.Addr().String(), - GRPCInsecure: true, + currentConfig: atomicutil.NewValue(versionedConfig{ + Config: &config.Config{ + OutboundPort: outboundPort, + Options: &config.Options{ + SharedKey: cryptutil.NewBase64Key(), + DataBrokerURLString: "http://" + li.Addr().String(), + GRPCInsecure: true, + }, }, - }, - }) + }), + } err := srv.storeEvent(ctx, new(events.EnvoyConfigurationEvent)) assert.NoError(t, err) return err diff --git a/internal/controlplane/server.go b/internal/controlplane/server.go index b0f007ab1..d73301df7 100644 --- a/internal/controlplane/server.go +++ b/internal/controlplane/server.go @@ -5,7 +5,6 @@ import ( "net" "net/http" "net/http/pprof" - "sync/atomic" "time" envoy_service_discovery_v3 "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" @@ -20,6 +19,7 @@ import ( "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config/envoyconfig" "github.com/pomerium/pomerium/config/envoyconfig/filemgr" + "github.com/pomerium/pomerium/internal/atomicutil" "github.com/pomerium/pomerium/internal/controlplane/xdsmgr" "github.com/pomerium/pomerium/internal/events" "github.com/pomerium/pomerium/internal/httputil/reproxy" @@ -38,18 +38,6 @@ type versionedConfig struct { version int64 } -type atomicVersionedConfig struct { - value atomic.Value -} - -func (avo *atomicVersionedConfig) Load() versionedConfig { - return avo.value.Load().(versionedConfig) -} - -func (avo *atomicVersionedConfig) Store(cfg versionedConfig) { - avo.value.Store(cfg) -} - // A Service can be mounted on the control plane. type Service interface { Mount(r *mux.Router) @@ -67,14 +55,14 @@ type Server struct { Builder *envoyconfig.Builder EventsMgr *events.Manager - currentConfig atomicVersionedConfig + currentConfig *atomicutil.Value[versionedConfig] name string xdsmgr *xdsmgr.Manager filemgr *filemgr.Manager metricsMgr *config.MetricsManager reproxy *reproxy.Handler - httpRouter atomic.Value + httpRouter *atomicutil.Value[*mux.Router] authenticateSvc Service proxySvc Service @@ -88,10 +76,11 @@ func NewServer(cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr EventsMgr: eventsMgr, reproxy: reproxy.New(), haveSetCapacity: map[string]bool{}, + currentConfig: atomicutil.NewValue(versionedConfig{ + Config: cfg, + }), + httpRouter: atomicutil.NewValue(mux.NewRouter()), } - srv.currentConfig.Store(versionedConfig{ - Config: cfg, - }) var err error @@ -227,7 +216,7 @@ func (srv *Server) Run(ctx context.Context) error { Handler http.Handler }{ {"http", srv.HTTPListener, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - srv.httpRouter.Load().(http.Handler).ServeHTTP(w, r) + srv.httpRouter.Load().ServeHTTP(w, r) })}, {"debug", srv.DebugListener, srv.DebugRouter}, {"metrics", srv.MetricsListener, srv.MetricsRouter}, @@ -307,6 +296,6 @@ func (srv *Server) updateRouter(cfg *config.Config) error { if srv.proxySvc != nil { srv.proxySvc.Mount(httpRouter) } - srv.httpRouter.Store(http.Handler(httpRouter)) + srv.httpRouter.Store(httpRouter) return nil } diff --git a/internal/identity/manager/config.go b/internal/identity/manager/config.go index 4620a67d6..f566fb91c 100644 --- a/internal/identity/manager/config.go +++ b/internal/identity/manager/config.go @@ -1,7 +1,6 @@ package manager import ( - "sync/atomic" "time" "github.com/pomerium/pomerium/internal/directory" @@ -106,21 +105,3 @@ func WithEventManager(mgr *events.Manager) Option { c.eventMgr = mgr } } - -type atomicConfig struct { - value atomic.Value -} - -func newAtomicConfig(cfg *config) *atomicConfig { - ac := new(atomicConfig) - ac.Store(cfg) - return ac -} - -func (ac *atomicConfig) Load() *config { - return ac.value.Load().(*config) -} - -func (ac *atomicConfig) Store(cfg *config) { - ac.value.Store(cfg) -} diff --git a/internal/identity/manager/manager.go b/internal/identity/manager/manager.go index d651e6b9b..92c9098f8 100644 --- a/internal/identity/manager/manager.go +++ b/internal/identity/manager/manager.go @@ -14,6 +14,7 @@ import ( "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" + "github.com/pomerium/pomerium/internal/atomicutil" "github.com/pomerium/pomerium/internal/directory" "github.com/pomerium/pomerium/internal/events" "github.com/pomerium/pomerium/internal/identity/identity" @@ -43,7 +44,7 @@ type ( // A Manager refreshes identity information using session and user data. type Manager struct { - cfg *atomicConfig + cfg *atomicutil.Value[*config] sessionScheduler *scheduler.Scheduler userScheduler *scheduler.Scheduler @@ -62,7 +63,7 @@ func New( options ...Option, ) *Manager { mgr := &Manager{ - cfg: newAtomicConfig(newConfig()), + cfg: atomicutil.NewValue(newConfig()), sessionScheduler: scheduler.New(), userScheduler: scheduler.New(), diff --git a/internal/identity/manager/sync.go b/internal/identity/manager/sync.go index 6fe4a71cb..e99300c68 100644 --- a/internal/identity/manager/sync.go +++ b/internal/identity/manager/sync.go @@ -3,11 +3,12 @@ package manager import ( "context" + "github.com/pomerium/pomerium/internal/atomicutil" "github.com/pomerium/pomerium/pkg/grpc/databroker" ) type dataBrokerSyncer struct { - cfg *atomicConfig + cfg *atomicutil.Value[*config] update chan<- updateRecordsMessage clear chan<- struct{} @@ -17,7 +18,7 @@ type dataBrokerSyncer struct { func newDataBrokerSyncer( ctx context.Context, - cfg *atomicConfig, + cfg *atomicutil.Value[*config], update chan<- updateRecordsMessage, clear chan<- struct{}, ) *dataBrokerSyncer { diff --git a/internal/identity/providers.go b/internal/identity/providers.go index 49902b2d4..48d14bef3 100644 --- a/internal/identity/providers.go +++ b/internal/identity/providers.go @@ -6,7 +6,6 @@ import ( "context" "fmt" "net/url" - "sync/atomic" "golang.org/x/oauth2" @@ -66,30 +65,3 @@ func NewAuthenticator(o oauth.Options) (a Authenticator, err error) { } return a, nil } - -// wrap the Authenticator for the AtomicAuthenticator to support a nil default value. -type authenticatorValue struct { - Authenticator -} - -// An AtomicAuthenticator is a strongly-typed atomic.Value for storing an authenticator. -type AtomicAuthenticator struct { - current atomic.Value -} - -// NewAtomicAuthenticator creates a new AtomicAuthenticator. -func NewAtomicAuthenticator() *AtomicAuthenticator { - a := &AtomicAuthenticator{} - a.current.Store(authenticatorValue{}) - return a -} - -// Load loads the current authenticator. -func (a *AtomicAuthenticator) Load() Authenticator { - return a.current.Load().(authenticatorValue) -} - -// Store stores the authenticator. -func (a *AtomicAuthenticator) Store(value Authenticator) { - a.current.Store(authenticatorValue{value}) -} diff --git a/internal/log/log.go b/internal/log/log.go index e81140781..2bf68a773 100644 --- a/internal/log/log.go +++ b/internal/log/log.go @@ -5,16 +5,17 @@ import ( "context" "net/http" "os" - "sync/atomic" "github.com/rs/zerolog" "go.uber.org/zap" "go.uber.org/zap/zapcore" + + "github.com/pomerium/pomerium/internal/atomicutil" ) var ( - logger atomic.Value - zapLogger atomic.Value + logger = atomicutil.NewValue(new(zerolog.Logger)) + zapLogger = atomicutil.NewValue(new(zap.Logger)) zapLevel zap.AtomicLevel ) @@ -55,12 +56,12 @@ func SetLogger(l *zerolog.Logger) { // Logger returns the global logger. func Logger() *zerolog.Logger { - return logger.Load().(*zerolog.Logger) + return logger.Load() } // ZapLogger returns the global zap logger. func ZapLogger() *zap.Logger { - return zapLogger.Load().(*zap.Logger) + return zapLogger.Load() } // SetLevel sets the minimum global log level. Options are 'debug' 'info' 'warn' and 'error'. diff --git a/proxy/handlers_test.go b/proxy/handlers_test.go index 5f5351abf..f842e7615 100644 --- a/proxy/handlers_test.go +++ b/proxy/handlers_test.go @@ -12,6 +12,7 @@ import ( "testing" "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/atomicutil" "github.com/pomerium/pomerium/internal/encoding" "github.com/pomerium/pomerium/internal/encoding/mock" "github.com/pomerium/pomerium/internal/httputil" @@ -538,7 +539,7 @@ func TestProxy_jwt(t *testing.T) { w := httptest.NewRecorder() proxy := &Proxy{ - state: newAtomicProxyState(&proxyState{}), + state: atomicutil.NewValue(&proxyState{}), } err := proxy.jwtAssertion(w, req) if !assert.Error(t, err) { diff --git a/proxy/proxy.go b/proxy/proxy.go index becfb1e04..85f382be4 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -8,11 +8,11 @@ import ( "context" "fmt" "net/http" - "sync/atomic" "github.com/gorilla/mux" "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/atomicutil" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/telemetry/metrics" @@ -51,9 +51,9 @@ func ValidateOptions(o *config.Options) error { // Proxy stores all the information associated with proxying a request. type Proxy struct { - state *atomicProxyState - currentOptions *config.AtomicOptions - currentRouter atomic.Value + state *atomicutil.Value[*proxyState] + currentOptions *atomicutil.Value[*config.Options] + currentRouter *atomicutil.Value[*mux.Router] } // New takes a Proxy service from options and a validation function. @@ -65,10 +65,10 @@ func New(cfg *config.Config) (*Proxy, error) { } p := &Proxy{ - state: newAtomicProxyState(state), + state: atomicutil.NewValue(state), currentOptions: config.NewAtomicOptions(), + currentRouter: atomicutil.NewValue(httputil.NewRouter()), } - p.currentRouter.Store(httputil.NewRouter()) metrics.AddPolicyCountCallback("pomerium-proxy", func() int64 { return int64(len(p.currentOptions.Load().GetAllPolicies())) @@ -128,5 +128,5 @@ func (p *Proxy) setHandlers(opts *config.Options) error { } func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { - p.currentRouter.Load().(*mux.Router).ServeHTTP(w, r) + p.currentRouter.Load().ServeHTTP(w, r) } diff --git a/proxy/state.go b/proxy/state.go index 6fe95cf8c..f8bc827fd 100644 --- a/proxy/state.go +++ b/proxy/state.go @@ -3,7 +3,6 @@ package proxy import ( "crypto/cipher" "net/url" - "sync/atomic" "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/encoding" @@ -94,21 +93,3 @@ func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) { return state, nil } - -type atomicProxyState struct { - value atomic.Value -} - -func newAtomicProxyState(state *proxyState) *atomicProxyState { - aps := new(atomicProxyState) - aps.Store(state) - return aps -} - -func (aps *atomicProxyState) Load() *proxyState { - return aps.value.Load().(*proxyState) -} - -func (aps *atomicProxyState) Store(state *proxyState) { - aps.value.Store(state) -}