atomicutil: use atomicutil.Value wherever possible (#3517)

* atomicutil: use atomicutil.Value wherever possible

* fix test

* fix mux router
This commit is contained in:
Caleb Doxsey 2022-07-28 15:38:38 -06:00 committed by GitHub
parent 5c14d2c994
commit 0ac7e45a21
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
23 changed files with 121 additions and 215 deletions

View file

@ -9,6 +9,7 @@ import (
"github.com/pomerium/pomerium/authenticate/handlers/webauthn" "github.com/pomerium/pomerium/authenticate/handlers/webauthn"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/cryptutil"
) )
@ -39,8 +40,8 @@ func ValidateOptions(o *config.Options) error {
// Authenticate contains data required to run the authenticate service. // Authenticate contains data required to run the authenticate service.
type Authenticate struct { type Authenticate struct {
cfg *authenticateConfig cfg *authenticateConfig
options *config.AtomicOptions options *atomicutil.Value[*config.Options]
state *atomicAuthenticateState state *atomicutil.Value[*authenticateState]
webauthn *webauthn.Handler webauthn *webauthn.Handler
} }
@ -49,7 +50,7 @@ func New(cfg *config.Config, options ...Option) (*Authenticate, error) {
a := &Authenticate{ a := &Authenticate{
cfg: getAuthenticateConfig(options...), cfg: getAuthenticateConfig(options...),
options: config.NewAtomicOptions(), options: config.NewAtomicOptions(),
state: newAtomicAuthenticateState(newAuthenticateState()), state: atomicutil.NewValue(newAuthenticateState()),
} }
a.webauthn = webauthn.New(a.getWebauthnState) a.webauthn = webauthn.New(a.getWebauthnState)

View file

@ -26,6 +26,7 @@ import (
"github.com/pomerium/pomerium/authenticate/handlers/webauthn" "github.com/pomerium/pomerium/authenticate/handlers/webauthn"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/internal/encoding" "github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/encoding/jws" "github.com/pomerium/pomerium/internal/encoding/jws"
"github.com/pomerium/pomerium/internal/encoding/mock" "github.com/pomerium/pomerium/internal/encoding/mock"
@ -44,7 +45,7 @@ import (
func testAuthenticate() *Authenticate { func testAuthenticate() *Authenticate {
redirectURL, _ := url.Parse("https://auth.example.com/oauth/callback") redirectURL, _ := url.Parse("https://auth.example.com/oauth/callback")
var auth Authenticate var auth Authenticate
auth.state = newAtomicAuthenticateState(&authenticateState{ auth.state = atomicutil.NewValue(&authenticateState{
redirectURL: redirectURL, redirectURL: redirectURL,
cookieSecret: cryptutil.NewKey(), cookieSecret: cryptutil.NewKey(),
}) })
@ -150,7 +151,7 @@ func TestAuthenticate_SignIn(t *testing.T) {
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) { cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) {
return tt.provider, nil return tt.provider, nil
})), })),
state: newAtomicAuthenticateState(&authenticateState{ state: atomicutil.NewValue(&authenticateState{
sharedCipher: sharedCipher, sharedCipher: sharedCipher,
sessionStore: tt.session, sessionStore: tt.session,
redirectURL: uriParseHelper("https://some.example"), redirectURL: uriParseHelper("https://some.example"),
@ -306,7 +307,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) { cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) {
return tt.provider, nil return tt.provider, nil
})), })),
state: newAtomicAuthenticateState(&authenticateState{ state: atomicutil.NewValue(&authenticateState{
sessionStore: tt.sessionStore, sessionStore: tt.sessionStore,
encryptedEncoder: mock.Encoder{}, encryptedEncoder: mock.Encoder{},
sharedEncoder: mock.Encoder{}, sharedEncoder: mock.Encoder{},
@ -419,7 +420,7 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) { cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) {
return tt.provider, nil return tt.provider, nil
})), })),
state: newAtomicAuthenticateState(&authenticateState{ state: atomicutil.NewValue(&authenticateState{
dataBrokerClient: mockDataBrokerServiceClient{ dataBrokerClient: mockDataBrokerServiceClient{
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) { get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
return nil, fmt.Errorf("not implemented") return nil, fmt.Errorf("not implemented")
@ -554,7 +555,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) { cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) {
return tt.provider, nil return tt.provider, nil
})), })),
state: newAtomicAuthenticateState(&authenticateState{ state: atomicutil.NewValue(&authenticateState{
cookieSecret: cryptutil.NewKey(), cookieSecret: cryptutil.NewKey(),
redirectURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"), redirectURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"),
sessionStore: tt.session, sessionStore: tt.session,
@ -644,7 +645,7 @@ func TestAuthenticate_userInfo(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "https://authenticate.service.cluster.local/.pomerium/?pomerium_redirect_uri=https://www.example.com", nil) r := httptest.NewRequest("GET", "https://authenticate.service.cluster.local/.pomerium/?pomerium_redirect_uri=https://www.example.com", nil)
var a Authenticate var a Authenticate
a.state = newAtomicAuthenticateState(&authenticateState{ a.state = atomicutil.NewValue(&authenticateState{
cookieSecret: cryptutil.NewKey(), cookieSecret: cryptutil.NewKey(),
}) })
a.options = config.NewAtomicOptions() a.options = config.NewAtomicOptions()
@ -709,7 +710,7 @@ func TestAuthenticate_userInfo(t *testing.T) {
}) })
a := &Authenticate{ a := &Authenticate{
options: o, options: o,
state: newAtomicAuthenticateState(&authenticateState{ state: atomicutil.NewValue(&authenticateState{
sessionStore: tt.sessionStore, sessionStore: tt.sessionStore,
encryptedEncoder: signer, encryptedEncoder: signer,
sharedEncoder: signer, sharedEncoder: signer,

View file

@ -6,7 +6,6 @@ import (
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"net/url" "net/url"
"sync/atomic"
"github.com/go-jose/go-jose/v3" "github.com/go-jose/go-jose/v3"
@ -172,21 +171,3 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err
return state, nil return state, nil
} }
type atomicAuthenticateState struct {
atomic.Value
}
func newAtomicAuthenticateState(state *authenticateState) *atomicAuthenticateState {
aas := new(atomicAuthenticateState)
aas.Store(state)
return aas
}
func (aas *atomicAuthenticateState) Load() *authenticateState {
return aas.Value.Load().(*authenticateState)
}
func (aas *atomicAuthenticateState) Store(state *authenticateState) {
aas.Value.Store(state)
}

View file

@ -13,6 +13,7 @@ import (
"github.com/pomerium/pomerium/authorize/evaluator" "github.com/pomerium/pomerium/authorize/evaluator"
"github.com/pomerium/pomerium/authorize/internal/store" "github.com/pomerium/pomerium/authorize/internal/store"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/metrics" "github.com/pomerium/pomerium/internal/telemetry/metrics"
"github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/telemetry/trace"
@ -24,9 +25,9 @@ import (
// Authorize struct holds // Authorize struct holds
type Authorize struct { type Authorize struct {
state *atomicAuthorizeState state *atomicutil.Value[*authorizeState]
store *store.Store store *store.Store
currentOptions *config.AtomicOptions currentOptions *atomicutil.Value[*config.Options]
accessTracker *AccessTracker accessTracker *AccessTracker
globalCache storage.Cache globalCache storage.Cache
@ -49,7 +50,7 @@ func New(cfg *config.Config) (*Authorize, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
a.state = newAtomicAuthorizeState(state) a.state = atomicutil.NewValue(state)
return a, nil return a, nil
} }

View file

@ -18,6 +18,7 @@ import (
"github.com/pomerium/pomerium/authorize/evaluator" "github.com/pomerium/pomerium/authorize/evaluator"
"github.com/pomerium/pomerium/authorize/internal/store" "github.com/pomerium/pomerium/authorize/internal/store"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/internal/encoding/jws" "github.com/pomerium/pomerium/internal/encoding/jws"
"github.com/pomerium/pomerium/internal/testutil" "github.com/pomerium/pomerium/internal/testutil"
) )
@ -34,7 +35,7 @@ func TestAuthorize_okResponse(t *testing.T) {
}}, }},
JWTClaimsHeaders: config.NewJWTClaimHeaders("email"), JWTClaimsHeaders: config.NewJWTClaimHeaders("email"),
} }
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: newAtomicAuthorizeState(new(authorizeState))} a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0}) encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0})
a.state.Load().encoder = encoder a.state.Load().encoder = encoder
a.currentOptions.Store(opt) a.currentOptions.Store(opt)
@ -90,7 +91,7 @@ func TestAuthorize_okResponse(t *testing.T) {
} }
func TestAuthorize_deniedResponse(t *testing.T) { func TestAuthorize_deniedResponse(t *testing.T) {
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: newAtomicAuthorizeState(new(authorizeState))} a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0}) encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0})
a.state.Load().encoder = encoder a.state.Load().encoder = encoder
a.currentOptions.Store(&config.Options{ a.currentOptions.Store(&config.Options{

View file

@ -15,6 +15,7 @@ import (
"github.com/pomerium/pomerium/authorize/evaluator" "github.com/pomerium/pomerium/authorize/evaluator"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/internal/encoding/jws" "github.com/pomerium/pomerium/internal/encoding/jws"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/sessions"
@ -46,7 +47,7 @@ yE+vPxsiUkvQHdO2fojCkY8jg70jxM+gu59tPDNbw3Uh/2Ij310FgTHsnGQMyA==
-----END CERTIFICATE-----` -----END CERTIFICATE-----`
func Test_getEvaluatorRequest(t *testing.T) { func Test_getEvaluatorRequest(t *testing.T) {
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: newAtomicAuthorizeState(new(authorizeState))} a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0}) encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0})
a.state.Load().encoder = encoder a.state.Load().encoder = encoder
a.currentOptions.Store(&config.Options{ a.currentOptions.Store(&config.Options{
@ -247,7 +248,7 @@ func Test_handleForwardAuth(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
tc := tc tc := tc
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: newAtomicAuthorizeState(new(authorizeState))} a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
a.currentOptions.Store(&config.Options{ForwardAuthURLString: tc.forwardAuthURL}) a.currentOptions.Store(&config.Options{ForwardAuthURLString: tc.forwardAuthURL})
got := a.isForwardAuth(tc.checkReq) got := a.isForwardAuth(tc.checkReq)
@ -260,7 +261,7 @@ func Test_handleForwardAuth(t *testing.T) {
} }
func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) { func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) {
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: newAtomicAuthorizeState(new(authorizeState))} a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0}) encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0})
a.state.Load().encoder = encoder a.state.Load().encoder = encoder
a.currentOptions.Store(&config.Options{ a.currentOptions.Store(&config.Options{

View file

@ -3,7 +3,6 @@ package authorize
import ( import (
"context" "context"
"fmt" "fmt"
"sync/atomic"
googlegrpc "google.golang.org/grpc" googlegrpc "google.golang.org/grpc"
@ -79,21 +78,3 @@ func newAuthorizeStateFromConfig(cfg *config.Config, store *store.Store) (*autho
return state, nil return state, nil
} }
type atomicAuthorizeState struct {
value atomic.Value
}
func newAtomicAuthorizeState(state *authorizeState) *atomicAuthorizeState {
aas := new(atomicAuthorizeState)
aas.Store(state)
return aas
}
func (aas *atomicAuthorizeState) Load() *authorizeState {
return aas.value.Load().(*authorizeState)
}
func (aas *atomicAuthorizeState) Store(state *authorizeState) {
aas.value.Store(state)
}

View file

@ -12,13 +12,13 @@ import (
"path/filepath" "path/filepath"
"reflect" "reflect"
"strings" "strings"
"sync/atomic"
"time" "time"
"github.com/mitchellh/mapstructure" "github.com/mitchellh/mapstructure"
"github.com/spf13/viper" "github.com/spf13/viper"
"github.com/volatiletech/null/v9" "github.com/volatiletech/null/v9"
"github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/internal/directory/azure" "github.com/pomerium/pomerium/internal/directory/azure"
"github.com/pomerium/pomerium/internal/directory/github" "github.com/pomerium/pomerium/internal/directory/github"
"github.com/pomerium/pomerium/internal/directory/gitlab" "github.com/pomerium/pomerium/internal/directory/gitlab"
@ -1586,24 +1586,7 @@ func min(x, y int) int {
return y return y
} }
// AtomicOptions are Options that can be access atomically.
type AtomicOptions struct {
value atomic.Value
}
// NewAtomicOptions creates a new AtomicOptions. // NewAtomicOptions creates a new AtomicOptions.
func NewAtomicOptions() *AtomicOptions { func NewAtomicOptions() *atomicutil.Value[*Options] {
ao := new(AtomicOptions) return atomicutil.NewValue(new(Options))
ao.Store(new(Options))
return ao
}
// Load loads the options.
func (a *AtomicOptions) Load() *Options {
return a.value.Load().(*Options)
}
// Store stores the options.
func (a *AtomicOptions) Store(options *Options) {
a.value.Store(options)
} }

View file

@ -3,11 +3,11 @@ package databroker
import ( import (
"context" "context"
"sync/atomic"
"google.golang.org/protobuf/types/known/emptypb" "google.golang.org/protobuf/types/known/emptypb"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/internal/databroker" "github.com/pomerium/pomerium/internal/databroker"
databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker" databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker"
registrypb "github.com/pomerium/pomerium/pkg/grpc/registry" registrypb "github.com/pomerium/pomerium/pkg/grpc/registry"
@ -17,12 +17,14 @@ import (
// A dataBrokerServer implements the data broker service interface. // A dataBrokerServer implements the data broker service interface.
type dataBrokerServer struct { type dataBrokerServer struct {
server *databroker.Server server *databroker.Server
sharedKey atomic.Value sharedKey *atomicutil.Value[[]byte]
} }
// newDataBrokerServer creates a new databroker service server. // newDataBrokerServer creates a new databroker service server.
func newDataBrokerServer(cfg *config.Config) *dataBrokerServer { func newDataBrokerServer(cfg *config.Config) *dataBrokerServer {
srv := &dataBrokerServer{} srv := &dataBrokerServer{
sharedKey: atomicutil.NewValue([]byte{}),
}
srv.server = databroker.New(srv.getOptions(cfg)...) srv.server = databroker.New(srv.getOptions(cfg)...)
srv.setKey(cfg) srv.setKey(cfg)
return srv return srv
@ -57,63 +59,63 @@ func (srv *dataBrokerServer) setKey(cfg *config.Config) {
// Databroker functions // Databroker functions
func (srv *dataBrokerServer) AcquireLease(ctx context.Context, req *databrokerpb.AcquireLeaseRequest) (*databrokerpb.AcquireLeaseResponse, error) { func (srv *dataBrokerServer) AcquireLease(ctx context.Context, req *databrokerpb.AcquireLeaseRequest) (*databrokerpb.AcquireLeaseResponse, error) {
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil { if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load()); err != nil {
return nil, err return nil, err
} }
return srv.server.AcquireLease(ctx, req) return srv.server.AcquireLease(ctx, req)
} }
func (srv *dataBrokerServer) Get(ctx context.Context, req *databrokerpb.GetRequest) (*databrokerpb.GetResponse, error) { func (srv *dataBrokerServer) Get(ctx context.Context, req *databrokerpb.GetRequest) (*databrokerpb.GetResponse, error) {
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil { if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load()); err != nil {
return nil, err return nil, err
} }
return srv.server.Get(ctx, req) return srv.server.Get(ctx, req)
} }
func (srv *dataBrokerServer) Query(ctx context.Context, req *databrokerpb.QueryRequest) (*databrokerpb.QueryResponse, error) { func (srv *dataBrokerServer) Query(ctx context.Context, req *databrokerpb.QueryRequest) (*databrokerpb.QueryResponse, error) {
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil { if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load()); err != nil {
return nil, err return nil, err
} }
return srv.server.Query(ctx, req) return srv.server.Query(ctx, req)
} }
func (srv *dataBrokerServer) Put(ctx context.Context, req *databrokerpb.PutRequest) (*databrokerpb.PutResponse, error) { func (srv *dataBrokerServer) Put(ctx context.Context, req *databrokerpb.PutRequest) (*databrokerpb.PutResponse, error) {
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil { if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load()); err != nil {
return nil, err return nil, err
} }
return srv.server.Put(ctx, req) return srv.server.Put(ctx, req)
} }
func (srv *dataBrokerServer) ReleaseLease(ctx context.Context, req *databrokerpb.ReleaseLeaseRequest) (*emptypb.Empty, error) { func (srv *dataBrokerServer) ReleaseLease(ctx context.Context, req *databrokerpb.ReleaseLeaseRequest) (*emptypb.Empty, error) {
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil { if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load()); err != nil {
return nil, err return nil, err
} }
return srv.server.ReleaseLease(ctx, req) return srv.server.ReleaseLease(ctx, req)
} }
func (srv *dataBrokerServer) RenewLease(ctx context.Context, req *databrokerpb.RenewLeaseRequest) (*emptypb.Empty, error) { func (srv *dataBrokerServer) RenewLease(ctx context.Context, req *databrokerpb.RenewLeaseRequest) (*emptypb.Empty, error) {
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil { if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load()); err != nil {
return nil, err return nil, err
} }
return srv.server.RenewLease(ctx, req) return srv.server.RenewLease(ctx, req)
} }
func (srv *dataBrokerServer) SetOptions(ctx context.Context, req *databrokerpb.SetOptionsRequest) (*databrokerpb.SetOptionsResponse, error) { func (srv *dataBrokerServer) SetOptions(ctx context.Context, req *databrokerpb.SetOptionsRequest) (*databrokerpb.SetOptionsResponse, error) {
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil { if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load()); err != nil {
return nil, err return nil, err
} }
return srv.server.SetOptions(ctx, req) return srv.server.SetOptions(ctx, req)
} }
func (srv *dataBrokerServer) Sync(req *databrokerpb.SyncRequest, stream databrokerpb.DataBrokerService_SyncServer) error { func (srv *dataBrokerServer) Sync(req *databrokerpb.SyncRequest, stream databrokerpb.DataBrokerService_SyncServer) error {
if err := grpcutil.RequireSignedJWT(stream.Context(), srv.sharedKey.Load().([]byte)); err != nil { if err := grpcutil.RequireSignedJWT(stream.Context(), srv.sharedKey.Load()); err != nil {
return err return err
} }
return srv.server.Sync(req, stream) return srv.server.Sync(req, stream)
} }
func (srv *dataBrokerServer) SyncLatest(req *databrokerpb.SyncLatestRequest, stream databrokerpb.DataBrokerService_SyncLatestServer) error { func (srv *dataBrokerServer) SyncLatest(req *databrokerpb.SyncLatestRequest, stream databrokerpb.DataBrokerService_SyncLatestServer) error {
if err := grpcutil.RequireSignedJWT(stream.Context(), srv.sharedKey.Load().([]byte)); err != nil { if err := grpcutil.RequireSignedJWT(stream.Context(), srv.sharedKey.Load()); err != nil {
return err return err
} }
return srv.server.SyncLatest(req, stream) return srv.server.SyncLatest(req, stream)
@ -122,21 +124,21 @@ func (srv *dataBrokerServer) SyncLatest(req *databrokerpb.SyncLatestRequest, str
// Registry functions // Registry functions
func (srv *dataBrokerServer) Report(ctx context.Context, req *registrypb.RegisterRequest) (*registrypb.RegisterResponse, error) { func (srv *dataBrokerServer) Report(ctx context.Context, req *registrypb.RegisterRequest) (*registrypb.RegisterResponse, error) {
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil { if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load()); err != nil {
return nil, err return nil, err
} }
return srv.server.Report(ctx, req) return srv.server.Report(ctx, req)
} }
func (srv *dataBrokerServer) List(ctx context.Context, req *registrypb.ListRequest) (*registrypb.ServiceList, error) { func (srv *dataBrokerServer) List(ctx context.Context, req *registrypb.ListRequest) (*registrypb.ServiceList, error) {
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil { if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load()); err != nil {
return nil, err return nil, err
} }
return srv.server.List(ctx, req) return srv.server.List(ctx, req)
} }
func (srv *dataBrokerServer) Watch(req *registrypb.ListRequest, stream registrypb.Registry_WatchServer) error { func (srv *dataBrokerServer) Watch(req *registrypb.ListRequest, stream registrypb.Registry_WatchServer) error {
if err := grpcutil.RequireSignedJWT(stream.Context(), srv.sharedKey.Load().([]byte)); err != nil { if err := grpcutil.RequireSignedJWT(stream.Context(), srv.sharedKey.Load()); err != nil {
return err return err
} }
return srv.server.Watch(req, stream) return srv.server.Watch(req, stream)

View file

@ -12,6 +12,7 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"google.golang.org/grpc/test/bufconn" "google.golang.org/grpc/test/bufconn"
"github.com/pomerium/pomerium/internal/atomicutil"
internal_databroker "github.com/pomerium/pomerium/internal/databroker" internal_databroker "github.com/pomerium/pomerium/internal/databroker"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
@ -28,8 +29,7 @@ func init() {
lis = bufconn.Listen(bufSize) lis = bufconn.Listen(bufSize)
s := grpc.NewServer() s := grpc.NewServer()
internalSrv := internal_databroker.New() internalSrv := internal_databroker.New()
srv := &dataBrokerServer{server: internalSrv} srv := &dataBrokerServer{server: internalSrv, sharedKey: atomicutil.NewValue([]byte{})}
srv.sharedKey.Store([]byte{})
databroker.RegisterDataBrokerServiceServer(s, srv) databroker.RegisterDataBrokerServiceServer(s, srv)
go func() { go func() {

View file

@ -17,7 +17,16 @@ func NewValue[T any](init T) *Value[T] {
// Load loads the value atomically. // Load loads the value atomically.
func (v *Value[T]) Load() T { func (v *Value[T]) Load() T {
return v.value.Load().(T) var def T
if v == nil {
return def
}
cur := v.value.Load()
if cur == nil {
return def
}
return cur.(T)
} }
// Store stores the value atomically. // Store stores the value atomically.

View file

@ -0,0 +1,21 @@
package atomicutil
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestValue(t *testing.T) {
v := NewValue(5)
assert.Equal(t, 5, v.Load())
t.Run("nil", func(t *testing.T) {
var v *Value[int]
assert.Equal(t, 0, v.Load())
})
t.Run("default", func(t *testing.T) {
var v Value[int]
assert.Equal(t, 0, v.Load())
})
}

View file

@ -9,7 +9,6 @@ import (
"net/http" "net/http"
"sort" "sort"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/caddyserver/certmagic" "github.com/caddyserver/certmagic"
@ -18,6 +17,7 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/metrics" "github.com/pomerium/pomerium/internal/telemetry/metrics"
@ -46,7 +46,7 @@ type Manager struct {
mu sync.RWMutex mu sync.RWMutex
config *config.Config config *config.Config
certmagic *certmagic.Config certmagic *certmagic.Config
acmeMgr atomic.Value acmeMgr *atomicutil.Value[*certmagic.ACMEIssuer]
srv *http.Server srv *http.Server
*ocspCache *ocspCache
@ -87,6 +87,7 @@ func newManager(ctx context.Context,
mgr := &Manager{ mgr := &Manager{
src: src, src: src,
acmeTemplate: acmeTemplate, acmeTemplate: acmeTemplate,
acmeMgr: atomicutil.NewValue(new(certmagic.ACMEIssuer)),
certmagic: certmagicConfig, certmagic: certmagicConfig,
ocspCache: ocspRespCache, ocspCache: ocspRespCache,
} }
@ -324,12 +325,7 @@ func (mgr *Manager) updateServer(ctx context.Context, cfg *config.Config) {
} }
func (mgr *Manager) handleHTTPChallenge(w http.ResponseWriter, r *http.Request) bool { func (mgr *Manager) handleHTTPChallenge(w http.ResponseWriter, r *http.Request) bool {
obj := mgr.acmeMgr.Load() return mgr.acmeMgr.Load().HandleHTTPChallenge(w, r)
if obj == nil {
return false
}
acmeMgr := obj.(*certmagic.ACMEIssuer)
return acmeMgr.HandleHTTPChallenge(w, r)
} }
// GetConfig gets the config. // GetConfig gets the config.

View file

@ -12,6 +12,7 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/cryptutil"
databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker" databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/events" "github.com/pomerium/pomerium/pkg/grpc/events"
@ -73,8 +74,7 @@ func TestEvents(t *testing.T) {
srv := &Server{ srv := &Server{
haveSetCapacity: make(map[string]bool), haveSetCapacity: make(map[string]bool),
} currentConfig: atomicutil.NewValue(versionedConfig{
srv.currentConfig.Store(versionedConfig{
Config: &config.Config{ Config: &config.Config{
OutboundPort: outboundPort, OutboundPort: outboundPort,
Options: &config.Options{ Options: &config.Options{
@ -83,7 +83,8 @@ func TestEvents(t *testing.T) {
GRPCInsecure: true, GRPCInsecure: true,
}, },
}, },
}) }),
}
err := srv.storeEvent(ctx, new(events.EnvoyConfigurationEvent)) err := srv.storeEvent(ctx, new(events.EnvoyConfigurationEvent))
assert.NoError(t, err) assert.NoError(t, err)
return err return err

View file

@ -5,7 +5,6 @@ import (
"net" "net"
"net/http" "net/http"
"net/http/pprof" "net/http/pprof"
"sync/atomic"
"time" "time"
envoy_service_discovery_v3 "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" envoy_service_discovery_v3 "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3"
@ -20,6 +19,7 @@ import (
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/config/envoyconfig" "github.com/pomerium/pomerium/config/envoyconfig"
"github.com/pomerium/pomerium/config/envoyconfig/filemgr" "github.com/pomerium/pomerium/config/envoyconfig/filemgr"
"github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/internal/controlplane/xdsmgr" "github.com/pomerium/pomerium/internal/controlplane/xdsmgr"
"github.com/pomerium/pomerium/internal/events" "github.com/pomerium/pomerium/internal/events"
"github.com/pomerium/pomerium/internal/httputil/reproxy" "github.com/pomerium/pomerium/internal/httputil/reproxy"
@ -38,18 +38,6 @@ type versionedConfig struct {
version int64 version int64
} }
type atomicVersionedConfig struct {
value atomic.Value
}
func (avo *atomicVersionedConfig) Load() versionedConfig {
return avo.value.Load().(versionedConfig)
}
func (avo *atomicVersionedConfig) Store(cfg versionedConfig) {
avo.value.Store(cfg)
}
// A Service can be mounted on the control plane. // A Service can be mounted on the control plane.
type Service interface { type Service interface {
Mount(r *mux.Router) Mount(r *mux.Router)
@ -67,14 +55,14 @@ type Server struct {
Builder *envoyconfig.Builder Builder *envoyconfig.Builder
EventsMgr *events.Manager EventsMgr *events.Manager
currentConfig atomicVersionedConfig currentConfig *atomicutil.Value[versionedConfig]
name string name string
xdsmgr *xdsmgr.Manager xdsmgr *xdsmgr.Manager
filemgr *filemgr.Manager filemgr *filemgr.Manager
metricsMgr *config.MetricsManager metricsMgr *config.MetricsManager
reproxy *reproxy.Handler reproxy *reproxy.Handler
httpRouter atomic.Value httpRouter *atomicutil.Value[*mux.Router]
authenticateSvc Service authenticateSvc Service
proxySvc Service proxySvc Service
@ -88,10 +76,11 @@ func NewServer(cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr
EventsMgr: eventsMgr, EventsMgr: eventsMgr,
reproxy: reproxy.New(), reproxy: reproxy.New(),
haveSetCapacity: map[string]bool{}, haveSetCapacity: map[string]bool{},
} currentConfig: atomicutil.NewValue(versionedConfig{
srv.currentConfig.Store(versionedConfig{
Config: cfg, Config: cfg,
}) }),
httpRouter: atomicutil.NewValue(mux.NewRouter()),
}
var err error var err error
@ -227,7 +216,7 @@ func (srv *Server) Run(ctx context.Context) error {
Handler http.Handler Handler http.Handler
}{ }{
{"http", srv.HTTPListener, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { {"http", srv.HTTPListener, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
srv.httpRouter.Load().(http.Handler).ServeHTTP(w, r) srv.httpRouter.Load().ServeHTTP(w, r)
})}, })},
{"debug", srv.DebugListener, srv.DebugRouter}, {"debug", srv.DebugListener, srv.DebugRouter},
{"metrics", srv.MetricsListener, srv.MetricsRouter}, {"metrics", srv.MetricsListener, srv.MetricsRouter},
@ -307,6 +296,6 @@ func (srv *Server) updateRouter(cfg *config.Config) error {
if srv.proxySvc != nil { if srv.proxySvc != nil {
srv.proxySvc.Mount(httpRouter) srv.proxySvc.Mount(httpRouter)
} }
srv.httpRouter.Store(http.Handler(httpRouter)) srv.httpRouter.Store(httpRouter)
return nil return nil
} }

View file

@ -1,7 +1,6 @@
package manager package manager
import ( import (
"sync/atomic"
"time" "time"
"github.com/pomerium/pomerium/internal/directory" "github.com/pomerium/pomerium/internal/directory"
@ -106,21 +105,3 @@ func WithEventManager(mgr *events.Manager) Option {
c.eventMgr = mgr c.eventMgr = mgr
} }
} }
type atomicConfig struct {
value atomic.Value
}
func newAtomicConfig(cfg *config) *atomicConfig {
ac := new(atomicConfig)
ac.Store(cfg)
return ac
}
func (ac *atomicConfig) Load() *config {
return ac.value.Load().(*config)
}
func (ac *atomicConfig) Store(cfg *config) {
ac.value.Store(cfg)
}

View file

@ -14,6 +14,7 @@ import (
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/internal/directory" "github.com/pomerium/pomerium/internal/directory"
"github.com/pomerium/pomerium/internal/events" "github.com/pomerium/pomerium/internal/events"
"github.com/pomerium/pomerium/internal/identity/identity" "github.com/pomerium/pomerium/internal/identity/identity"
@ -43,7 +44,7 @@ type (
// A Manager refreshes identity information using session and user data. // A Manager refreshes identity information using session and user data.
type Manager struct { type Manager struct {
cfg *atomicConfig cfg *atomicutil.Value[*config]
sessionScheduler *scheduler.Scheduler sessionScheduler *scheduler.Scheduler
userScheduler *scheduler.Scheduler userScheduler *scheduler.Scheduler
@ -62,7 +63,7 @@ func New(
options ...Option, options ...Option,
) *Manager { ) *Manager {
mgr := &Manager{ mgr := &Manager{
cfg: newAtomicConfig(newConfig()), cfg: atomicutil.NewValue(newConfig()),
sessionScheduler: scheduler.New(), sessionScheduler: scheduler.New(),
userScheduler: scheduler.New(), userScheduler: scheduler.New(),

View file

@ -3,11 +3,12 @@ package manager
import ( import (
"context" "context"
"github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
) )
type dataBrokerSyncer struct { type dataBrokerSyncer struct {
cfg *atomicConfig cfg *atomicutil.Value[*config]
update chan<- updateRecordsMessage update chan<- updateRecordsMessage
clear chan<- struct{} clear chan<- struct{}
@ -17,7 +18,7 @@ type dataBrokerSyncer struct {
func newDataBrokerSyncer( func newDataBrokerSyncer(
ctx context.Context, ctx context.Context,
cfg *atomicConfig, cfg *atomicutil.Value[*config],
update chan<- updateRecordsMessage, update chan<- updateRecordsMessage,
clear chan<- struct{}, clear chan<- struct{},
) *dataBrokerSyncer { ) *dataBrokerSyncer {

View file

@ -6,7 +6,6 @@ import (
"context" "context"
"fmt" "fmt"
"net/url" "net/url"
"sync/atomic"
"golang.org/x/oauth2" "golang.org/x/oauth2"
@ -66,30 +65,3 @@ func NewAuthenticator(o oauth.Options) (a Authenticator, err error) {
} }
return a, nil return a, nil
} }
// wrap the Authenticator for the AtomicAuthenticator to support a nil default value.
type authenticatorValue struct {
Authenticator
}
// An AtomicAuthenticator is a strongly-typed atomic.Value for storing an authenticator.
type AtomicAuthenticator struct {
current atomic.Value
}
// NewAtomicAuthenticator creates a new AtomicAuthenticator.
func NewAtomicAuthenticator() *AtomicAuthenticator {
a := &AtomicAuthenticator{}
a.current.Store(authenticatorValue{})
return a
}
// Load loads the current authenticator.
func (a *AtomicAuthenticator) Load() Authenticator {
return a.current.Load().(authenticatorValue)
}
// Store stores the authenticator.
func (a *AtomicAuthenticator) Store(value Authenticator) {
a.current.Store(authenticatorValue{value})
}

View file

@ -5,16 +5,17 @@ import (
"context" "context"
"net/http" "net/http"
"os" "os"
"sync/atomic"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"go.uber.org/zap" "go.uber.org/zap"
"go.uber.org/zap/zapcore" "go.uber.org/zap/zapcore"
"github.com/pomerium/pomerium/internal/atomicutil"
) )
var ( var (
logger atomic.Value logger = atomicutil.NewValue(new(zerolog.Logger))
zapLogger atomic.Value zapLogger = atomicutil.NewValue(new(zap.Logger))
zapLevel zap.AtomicLevel zapLevel zap.AtomicLevel
) )
@ -55,12 +56,12 @@ func SetLogger(l *zerolog.Logger) {
// Logger returns the global logger. // Logger returns the global logger.
func Logger() *zerolog.Logger { func Logger() *zerolog.Logger {
return logger.Load().(*zerolog.Logger) return logger.Load()
} }
// ZapLogger returns the global zap logger. // ZapLogger returns the global zap logger.
func ZapLogger() *zap.Logger { func ZapLogger() *zap.Logger {
return zapLogger.Load().(*zap.Logger) return zapLogger.Load()
} }
// SetLevel sets the minimum global log level. Options are 'debug' 'info' 'warn' and 'error'. // SetLevel sets the minimum global log level. Options are 'debug' 'info' 'warn' and 'error'.

View file

@ -12,6 +12,7 @@ import (
"testing" "testing"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/internal/encoding" "github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/encoding/mock" "github.com/pomerium/pomerium/internal/encoding/mock"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
@ -538,7 +539,7 @@ func TestProxy_jwt(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxy := &Proxy{ proxy := &Proxy{
state: newAtomicProxyState(&proxyState{}), state: atomicutil.NewValue(&proxyState{}),
} }
err := proxy.jwtAssertion(w, req) err := proxy.jwtAssertion(w, req)
if !assert.Error(t, err) { if !assert.Error(t, err) {

View file

@ -8,11 +8,11 @@ import (
"context" "context"
"fmt" "fmt"
"net/http" "net/http"
"sync/atomic"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/metrics" "github.com/pomerium/pomerium/internal/telemetry/metrics"
@ -51,9 +51,9 @@ func ValidateOptions(o *config.Options) error {
// Proxy stores all the information associated with proxying a request. // Proxy stores all the information associated with proxying a request.
type Proxy struct { type Proxy struct {
state *atomicProxyState state *atomicutil.Value[*proxyState]
currentOptions *config.AtomicOptions currentOptions *atomicutil.Value[*config.Options]
currentRouter atomic.Value currentRouter *atomicutil.Value[*mux.Router]
} }
// New takes a Proxy service from options and a validation function. // New takes a Proxy service from options and a validation function.
@ -65,10 +65,10 @@ func New(cfg *config.Config) (*Proxy, error) {
} }
p := &Proxy{ p := &Proxy{
state: newAtomicProxyState(state), state: atomicutil.NewValue(state),
currentOptions: config.NewAtomicOptions(), currentOptions: config.NewAtomicOptions(),
currentRouter: atomicutil.NewValue(httputil.NewRouter()),
} }
p.currentRouter.Store(httputil.NewRouter())
metrics.AddPolicyCountCallback("pomerium-proxy", func() int64 { metrics.AddPolicyCountCallback("pomerium-proxy", func() int64 {
return int64(len(p.currentOptions.Load().GetAllPolicies())) return int64(len(p.currentOptions.Load().GetAllPolicies()))
@ -128,5 +128,5 @@ func (p *Proxy) setHandlers(opts *config.Options) error {
} }
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
p.currentRouter.Load().(*mux.Router).ServeHTTP(w, r) p.currentRouter.Load().ServeHTTP(w, r)
} }

View file

@ -3,7 +3,6 @@ package proxy
import ( import (
"crypto/cipher" "crypto/cipher"
"net/url" "net/url"
"sync/atomic"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/encoding" "github.com/pomerium/pomerium/internal/encoding"
@ -94,21 +93,3 @@ func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) {
return state, nil return state, nil
} }
type atomicProxyState struct {
value atomic.Value
}
func newAtomicProxyState(state *proxyState) *atomicProxyState {
aps := new(atomicProxyState)
aps.Store(state)
return aps
}
func (aps *atomicProxyState) Load() *proxyState {
return aps.value.Load().(*proxyState)
}
func (aps *atomicProxyState) Store(state *proxyState) {
aps.value.Store(state)
}