mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 02:16:28 +02:00
atomicutil: use atomicutil.Value wherever possible (#3517)
* atomicutil: use atomicutil.Value wherever possible * fix test * fix mux router
This commit is contained in:
parent
5c14d2c994
commit
0ac7e45a21
23 changed files with 121 additions and 215 deletions
|
@ -9,6 +9,7 @@ import (
|
|||
|
||||
"github.com/pomerium/pomerium/authenticate/handlers/webauthn"
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
)
|
||||
|
@ -39,8 +40,8 @@ func ValidateOptions(o *config.Options) error {
|
|||
// Authenticate contains data required to run the authenticate service.
|
||||
type Authenticate struct {
|
||||
cfg *authenticateConfig
|
||||
options *config.AtomicOptions
|
||||
state *atomicAuthenticateState
|
||||
options *atomicutil.Value[*config.Options]
|
||||
state *atomicutil.Value[*authenticateState]
|
||||
webauthn *webauthn.Handler
|
||||
}
|
||||
|
||||
|
@ -49,7 +50,7 @@ func New(cfg *config.Config, options ...Option) (*Authenticate, error) {
|
|||
a := &Authenticate{
|
||||
cfg: getAuthenticateConfig(options...),
|
||||
options: config.NewAtomicOptions(),
|
||||
state: newAtomicAuthenticateState(newAuthenticateState()),
|
||||
state: atomicutil.NewValue(newAuthenticateState()),
|
||||
}
|
||||
a.webauthn = webauthn.New(a.getWebauthnState)
|
||||
|
||||
|
|
|
@ -26,6 +26,7 @@ import (
|
|||
|
||||
"github.com/pomerium/pomerium/authenticate/handlers/webauthn"
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/encoding/jws"
|
||||
"github.com/pomerium/pomerium/internal/encoding/mock"
|
||||
|
@ -44,7 +45,7 @@ import (
|
|||
func testAuthenticate() *Authenticate {
|
||||
redirectURL, _ := url.Parse("https://auth.example.com/oauth/callback")
|
||||
var auth Authenticate
|
||||
auth.state = newAtomicAuthenticateState(&authenticateState{
|
||||
auth.state = atomicutil.NewValue(&authenticateState{
|
||||
redirectURL: redirectURL,
|
||||
cookieSecret: cryptutil.NewKey(),
|
||||
})
|
||||
|
@ -150,7 +151,7 @@ func TestAuthenticate_SignIn(t *testing.T) {
|
|||
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) {
|
||||
return tt.provider, nil
|
||||
})),
|
||||
state: newAtomicAuthenticateState(&authenticateState{
|
||||
state: atomicutil.NewValue(&authenticateState{
|
||||
sharedCipher: sharedCipher,
|
||||
sessionStore: tt.session,
|
||||
redirectURL: uriParseHelper("https://some.example"),
|
||||
|
@ -306,7 +307,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
|||
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) {
|
||||
return tt.provider, nil
|
||||
})),
|
||||
state: newAtomicAuthenticateState(&authenticateState{
|
||||
state: atomicutil.NewValue(&authenticateState{
|
||||
sessionStore: tt.sessionStore,
|
||||
encryptedEncoder: mock.Encoder{},
|
||||
sharedEncoder: mock.Encoder{},
|
||||
|
@ -419,7 +420,7 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
|
|||
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) {
|
||||
return tt.provider, nil
|
||||
})),
|
||||
state: newAtomicAuthenticateState(&authenticateState{
|
||||
state: atomicutil.NewValue(&authenticateState{
|
||||
dataBrokerClient: mockDataBrokerServiceClient{
|
||||
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
|
@ -554,7 +555,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
|||
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) {
|
||||
return tt.provider, nil
|
||||
})),
|
||||
state: newAtomicAuthenticateState(&authenticateState{
|
||||
state: atomicutil.NewValue(&authenticateState{
|
||||
cookieSecret: cryptutil.NewKey(),
|
||||
redirectURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"),
|
||||
sessionStore: tt.session,
|
||||
|
@ -644,7 +645,7 @@ func TestAuthenticate_userInfo(t *testing.T) {
|
|||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "https://authenticate.service.cluster.local/.pomerium/?pomerium_redirect_uri=https://www.example.com", nil)
|
||||
var a Authenticate
|
||||
a.state = newAtomicAuthenticateState(&authenticateState{
|
||||
a.state = atomicutil.NewValue(&authenticateState{
|
||||
cookieSecret: cryptutil.NewKey(),
|
||||
})
|
||||
a.options = config.NewAtomicOptions()
|
||||
|
@ -709,7 +710,7 @@ func TestAuthenticate_userInfo(t *testing.T) {
|
|||
})
|
||||
a := &Authenticate{
|
||||
options: o,
|
||||
state: newAtomicAuthenticateState(&authenticateState{
|
||||
state: atomicutil.NewValue(&authenticateState{
|
||||
sessionStore: tt.sessionStore,
|
||||
encryptedEncoder: signer,
|
||||
sharedEncoder: signer,
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
|
||||
|
@ -172,21 +171,3 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err
|
|||
|
||||
return state, nil
|
||||
}
|
||||
|
||||
type atomicAuthenticateState struct {
|
||||
atomic.Value
|
||||
}
|
||||
|
||||
func newAtomicAuthenticateState(state *authenticateState) *atomicAuthenticateState {
|
||||
aas := new(atomicAuthenticateState)
|
||||
aas.Store(state)
|
||||
return aas
|
||||
}
|
||||
|
||||
func (aas *atomicAuthenticateState) Load() *authenticateState {
|
||||
return aas.Value.Load().(*authenticateState)
|
||||
}
|
||||
|
||||
func (aas *atomicAuthenticateState) Store(state *authenticateState) {
|
||||
aas.Value.Store(state)
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/pomerium/pomerium/authorize/evaluator"
|
||||
"github.com/pomerium/pomerium/authorize/internal/store"
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
|
@ -24,9 +25,9 @@ import (
|
|||
|
||||
// Authorize struct holds
|
||||
type Authorize struct {
|
||||
state *atomicAuthorizeState
|
||||
state *atomicutil.Value[*authorizeState]
|
||||
store *store.Store
|
||||
currentOptions *config.AtomicOptions
|
||||
currentOptions *atomicutil.Value[*config.Options]
|
||||
accessTracker *AccessTracker
|
||||
globalCache storage.Cache
|
||||
|
||||
|
@ -49,7 +50,7 @@ func New(cfg *config.Config) (*Authorize, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
a.state = newAtomicAuthorizeState(state)
|
||||
a.state = atomicutil.NewValue(state)
|
||||
|
||||
return a, nil
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@ import (
|
|||
"github.com/pomerium/pomerium/authorize/evaluator"
|
||||
"github.com/pomerium/pomerium/authorize/internal/store"
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||
"github.com/pomerium/pomerium/internal/encoding/jws"
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
)
|
||||
|
@ -34,7 +35,7 @@ func TestAuthorize_okResponse(t *testing.T) {
|
|||
}},
|
||||
JWTClaimsHeaders: config.NewJWTClaimHeaders("email"),
|
||||
}
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: newAtomicAuthorizeState(new(authorizeState))}
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
|
||||
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0})
|
||||
a.state.Load().encoder = encoder
|
||||
a.currentOptions.Store(opt)
|
||||
|
@ -90,7 +91,7 @@ func TestAuthorize_okResponse(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestAuthorize_deniedResponse(t *testing.T) {
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: newAtomicAuthorizeState(new(authorizeState))}
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
|
||||
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0})
|
||||
a.state.Load().encoder = encoder
|
||||
a.currentOptions.Store(&config.Options{
|
||||
|
|
|
@ -15,6 +15,7 @@ import (
|
|||
|
||||
"github.com/pomerium/pomerium/authorize/evaluator"
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||
"github.com/pomerium/pomerium/internal/encoding/jws"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
|
@ -46,7 +47,7 @@ yE+vPxsiUkvQHdO2fojCkY8jg70jxM+gu59tPDNbw3Uh/2Ij310FgTHsnGQMyA==
|
|||
-----END CERTIFICATE-----`
|
||||
|
||||
func Test_getEvaluatorRequest(t *testing.T) {
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: newAtomicAuthorizeState(new(authorizeState))}
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
|
||||
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0})
|
||||
a.state.Load().encoder = encoder
|
||||
a.currentOptions.Store(&config.Options{
|
||||
|
@ -247,7 +248,7 @@ func Test_handleForwardAuth(t *testing.T) {
|
|||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: newAtomicAuthorizeState(new(authorizeState))}
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
|
||||
a.currentOptions.Store(&config.Options{ForwardAuthURLString: tc.forwardAuthURL})
|
||||
|
||||
got := a.isForwardAuth(tc.checkReq)
|
||||
|
@ -260,7 +261,7 @@ func Test_handleForwardAuth(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) {
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: newAtomicAuthorizeState(new(authorizeState))}
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
|
||||
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0})
|
||||
a.state.Load().encoder = encoder
|
||||
a.currentOptions.Store(&config.Options{
|
||||
|
|
|
@ -3,7 +3,6 @@ package authorize
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
|
||||
googlegrpc "google.golang.org/grpc"
|
||||
|
||||
|
@ -79,21 +78,3 @@ func newAuthorizeStateFromConfig(cfg *config.Config, store *store.Store) (*autho
|
|||
|
||||
return state, nil
|
||||
}
|
||||
|
||||
type atomicAuthorizeState struct {
|
||||
value atomic.Value
|
||||
}
|
||||
|
||||
func newAtomicAuthorizeState(state *authorizeState) *atomicAuthorizeState {
|
||||
aas := new(atomicAuthorizeState)
|
||||
aas.Store(state)
|
||||
return aas
|
||||
}
|
||||
|
||||
func (aas *atomicAuthorizeState) Load() *authorizeState {
|
||||
return aas.value.Load().(*authorizeState)
|
||||
}
|
||||
|
||||
func (aas *atomicAuthorizeState) Store(state *authorizeState) {
|
||||
aas.value.Store(state)
|
||||
}
|
||||
|
|
|
@ -12,13 +12,13 @@ import (
|
|||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/spf13/viper"
|
||||
"github.com/volatiletech/null/v9"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||
"github.com/pomerium/pomerium/internal/directory/azure"
|
||||
"github.com/pomerium/pomerium/internal/directory/github"
|
||||
"github.com/pomerium/pomerium/internal/directory/gitlab"
|
||||
|
@ -1586,24 +1586,7 @@ func min(x, y int) int {
|
|||
return y
|
||||
}
|
||||
|
||||
// AtomicOptions are Options that can be access atomically.
|
||||
type AtomicOptions struct {
|
||||
value atomic.Value
|
||||
}
|
||||
|
||||
// NewAtomicOptions creates a new AtomicOptions.
|
||||
func NewAtomicOptions() *AtomicOptions {
|
||||
ao := new(AtomicOptions)
|
||||
ao.Store(new(Options))
|
||||
return ao
|
||||
}
|
||||
|
||||
// Load loads the options.
|
||||
func (a *AtomicOptions) Load() *Options {
|
||||
return a.value.Load().(*Options)
|
||||
}
|
||||
|
||||
// Store stores the options.
|
||||
func (a *AtomicOptions) Store(options *Options) {
|
||||
a.value.Store(options)
|
||||
func NewAtomicOptions() *atomicutil.Value[*Options] {
|
||||
return atomicutil.NewValue(new(Options))
|
||||
}
|
||||
|
|
|
@ -3,11 +3,11 @@ package databroker
|
|||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||
"github.com/pomerium/pomerium/internal/databroker"
|
||||
databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
registrypb "github.com/pomerium/pomerium/pkg/grpc/registry"
|
||||
|
@ -17,12 +17,14 @@ import (
|
|||
// A dataBrokerServer implements the data broker service interface.
|
||||
type dataBrokerServer struct {
|
||||
server *databroker.Server
|
||||
sharedKey atomic.Value
|
||||
sharedKey *atomicutil.Value[[]byte]
|
||||
}
|
||||
|
||||
// newDataBrokerServer creates a new databroker service server.
|
||||
func newDataBrokerServer(cfg *config.Config) *dataBrokerServer {
|
||||
srv := &dataBrokerServer{}
|
||||
srv := &dataBrokerServer{
|
||||
sharedKey: atomicutil.NewValue([]byte{}),
|
||||
}
|
||||
srv.server = databroker.New(srv.getOptions(cfg)...)
|
||||
srv.setKey(cfg)
|
||||
return srv
|
||||
|
@ -57,63 +59,63 @@ func (srv *dataBrokerServer) setKey(cfg *config.Config) {
|
|||
// Databroker functions
|
||||
|
||||
func (srv *dataBrokerServer) AcquireLease(ctx context.Context, req *databrokerpb.AcquireLeaseRequest) (*databrokerpb.AcquireLeaseResponse, error) {
|
||||
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
|
||||
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return srv.server.AcquireLease(ctx, req)
|
||||
}
|
||||
|
||||
func (srv *dataBrokerServer) Get(ctx context.Context, req *databrokerpb.GetRequest) (*databrokerpb.GetResponse, error) {
|
||||
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
|
||||
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return srv.server.Get(ctx, req)
|
||||
}
|
||||
|
||||
func (srv *dataBrokerServer) Query(ctx context.Context, req *databrokerpb.QueryRequest) (*databrokerpb.QueryResponse, error) {
|
||||
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
|
||||
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return srv.server.Query(ctx, req)
|
||||
}
|
||||
|
||||
func (srv *dataBrokerServer) Put(ctx context.Context, req *databrokerpb.PutRequest) (*databrokerpb.PutResponse, error) {
|
||||
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
|
||||
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return srv.server.Put(ctx, req)
|
||||
}
|
||||
|
||||
func (srv *dataBrokerServer) ReleaseLease(ctx context.Context, req *databrokerpb.ReleaseLeaseRequest) (*emptypb.Empty, error) {
|
||||
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
|
||||
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return srv.server.ReleaseLease(ctx, req)
|
||||
}
|
||||
|
||||
func (srv *dataBrokerServer) RenewLease(ctx context.Context, req *databrokerpb.RenewLeaseRequest) (*emptypb.Empty, error) {
|
||||
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
|
||||
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return srv.server.RenewLease(ctx, req)
|
||||
}
|
||||
|
||||
func (srv *dataBrokerServer) SetOptions(ctx context.Context, req *databrokerpb.SetOptionsRequest) (*databrokerpb.SetOptionsResponse, error) {
|
||||
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
|
||||
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return srv.server.SetOptions(ctx, req)
|
||||
}
|
||||
|
||||
func (srv *dataBrokerServer) Sync(req *databrokerpb.SyncRequest, stream databrokerpb.DataBrokerService_SyncServer) error {
|
||||
if err := grpcutil.RequireSignedJWT(stream.Context(), srv.sharedKey.Load().([]byte)); err != nil {
|
||||
if err := grpcutil.RequireSignedJWT(stream.Context(), srv.sharedKey.Load()); err != nil {
|
||||
return err
|
||||
}
|
||||
return srv.server.Sync(req, stream)
|
||||
}
|
||||
|
||||
func (srv *dataBrokerServer) SyncLatest(req *databrokerpb.SyncLatestRequest, stream databrokerpb.DataBrokerService_SyncLatestServer) error {
|
||||
if err := grpcutil.RequireSignedJWT(stream.Context(), srv.sharedKey.Load().([]byte)); err != nil {
|
||||
if err := grpcutil.RequireSignedJWT(stream.Context(), srv.sharedKey.Load()); err != nil {
|
||||
return err
|
||||
}
|
||||
return srv.server.SyncLatest(req, stream)
|
||||
|
@ -122,21 +124,21 @@ func (srv *dataBrokerServer) SyncLatest(req *databrokerpb.SyncLatestRequest, str
|
|||
// Registry functions
|
||||
|
||||
func (srv *dataBrokerServer) Report(ctx context.Context, req *registrypb.RegisterRequest) (*registrypb.RegisterResponse, error) {
|
||||
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
|
||||
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return srv.server.Report(ctx, req)
|
||||
}
|
||||
|
||||
func (srv *dataBrokerServer) List(ctx context.Context, req *registrypb.ListRequest) (*registrypb.ServiceList, error) {
|
||||
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
|
||||
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return srv.server.List(ctx, req)
|
||||
}
|
||||
|
||||
func (srv *dataBrokerServer) Watch(req *registrypb.ListRequest, stream registrypb.Registry_WatchServer) error {
|
||||
if err := grpcutil.RequireSignedJWT(stream.Context(), srv.sharedKey.Load().([]byte)); err != nil {
|
||||
if err := grpcutil.RequireSignedJWT(stream.Context(), srv.sharedKey.Load()); err != nil {
|
||||
return err
|
||||
}
|
||||
return srv.server.Watch(req, stream)
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/grpc/test/bufconn"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||
internal_databroker "github.com/pomerium/pomerium/internal/databroker"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
|
@ -28,8 +29,7 @@ func init() {
|
|||
lis = bufconn.Listen(bufSize)
|
||||
s := grpc.NewServer()
|
||||
internalSrv := internal_databroker.New()
|
||||
srv := &dataBrokerServer{server: internalSrv}
|
||||
srv.sharedKey.Store([]byte{})
|
||||
srv := &dataBrokerServer{server: internalSrv, sharedKey: atomicutil.NewValue([]byte{})}
|
||||
databroker.RegisterDataBrokerServiceServer(s, srv)
|
||||
|
||||
go func() {
|
||||
|
|
|
@ -17,7 +17,16 @@ func NewValue[T any](init T) *Value[T] {
|
|||
|
||||
// Load loads the value atomically.
|
||||
func (v *Value[T]) Load() T {
|
||||
return v.value.Load().(T)
|
||||
var def T
|
||||
if v == nil {
|
||||
return def
|
||||
}
|
||||
|
||||
cur := v.value.Load()
|
||||
if cur == nil {
|
||||
return def
|
||||
}
|
||||
return cur.(T)
|
||||
}
|
||||
|
||||
// Store stores the value atomically.
|
||||
|
|
21
internal/atomicutil/value_test.go
Normal file
21
internal/atomicutil/value_test.go
Normal file
|
@ -0,0 +1,21 @@
|
|||
package atomicutil
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestValue(t *testing.T) {
|
||||
v := NewValue(5)
|
||||
assert.Equal(t, 5, v.Load())
|
||||
|
||||
t.Run("nil", func(t *testing.T) {
|
||||
var v *Value[int]
|
||||
assert.Equal(t, 0, v.Load())
|
||||
})
|
||||
t.Run("default", func(t *testing.T) {
|
||||
var v Value[int]
|
||||
assert.Equal(t, 0, v.Load())
|
||||
})
|
||||
}
|
|
@ -9,7 +9,6 @@ import (
|
|||
"net/http"
|
||||
"sort"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/caddyserver/certmagic"
|
||||
|
@ -18,6 +17,7 @@ import (
|
|||
"go.uber.org/zap"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||
|
@ -46,7 +46,7 @@ type Manager struct {
|
|||
mu sync.RWMutex
|
||||
config *config.Config
|
||||
certmagic *certmagic.Config
|
||||
acmeMgr atomic.Value
|
||||
acmeMgr *atomicutil.Value[*certmagic.ACMEIssuer]
|
||||
srv *http.Server
|
||||
|
||||
*ocspCache
|
||||
|
@ -87,6 +87,7 @@ func newManager(ctx context.Context,
|
|||
mgr := &Manager{
|
||||
src: src,
|
||||
acmeTemplate: acmeTemplate,
|
||||
acmeMgr: atomicutil.NewValue(new(certmagic.ACMEIssuer)),
|
||||
certmagic: certmagicConfig,
|
||||
ocspCache: ocspRespCache,
|
||||
}
|
||||
|
@ -324,12 +325,7 @@ func (mgr *Manager) updateServer(ctx context.Context, cfg *config.Config) {
|
|||
}
|
||||
|
||||
func (mgr *Manager) handleHTTPChallenge(w http.ResponseWriter, r *http.Request) bool {
|
||||
obj := mgr.acmeMgr.Load()
|
||||
if obj == nil {
|
||||
return false
|
||||
}
|
||||
acmeMgr := obj.(*certmagic.ACMEIssuer)
|
||||
return acmeMgr.HandleHTTPChallenge(w, r)
|
||||
return mgr.acmeMgr.Load().HandleHTTPChallenge(w, r)
|
||||
}
|
||||
|
||||
// GetConfig gets the config.
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/events"
|
||||
|
@ -73,17 +74,17 @@ func TestEvents(t *testing.T) {
|
|||
|
||||
srv := &Server{
|
||||
haveSetCapacity: make(map[string]bool),
|
||||
}
|
||||
srv.currentConfig.Store(versionedConfig{
|
||||
Config: &config.Config{
|
||||
OutboundPort: outboundPort,
|
||||
Options: &config.Options{
|
||||
SharedKey: cryptutil.NewBase64Key(),
|
||||
DataBrokerURLString: "http://" + li.Addr().String(),
|
||||
GRPCInsecure: true,
|
||||
currentConfig: atomicutil.NewValue(versionedConfig{
|
||||
Config: &config.Config{
|
||||
OutboundPort: outboundPort,
|
||||
Options: &config.Options{
|
||||
SharedKey: cryptutil.NewBase64Key(),
|
||||
DataBrokerURLString: "http://" + li.Addr().String(),
|
||||
GRPCInsecure: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}),
|
||||
}
|
||||
err := srv.storeEvent(ctx, new(events.EnvoyConfigurationEvent))
|
||||
assert.NoError(t, err)
|
||||
return err
|
||||
|
|
|
@ -5,7 +5,6 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
"net/http/pprof"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
envoy_service_discovery_v3 "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3"
|
||||
|
@ -20,6 +19,7 @@ import (
|
|||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/config/envoyconfig"
|
||||
"github.com/pomerium/pomerium/config/envoyconfig/filemgr"
|
||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||
"github.com/pomerium/pomerium/internal/controlplane/xdsmgr"
|
||||
"github.com/pomerium/pomerium/internal/events"
|
||||
"github.com/pomerium/pomerium/internal/httputil/reproxy"
|
||||
|
@ -38,18 +38,6 @@ type versionedConfig struct {
|
|||
version int64
|
||||
}
|
||||
|
||||
type atomicVersionedConfig struct {
|
||||
value atomic.Value
|
||||
}
|
||||
|
||||
func (avo *atomicVersionedConfig) Load() versionedConfig {
|
||||
return avo.value.Load().(versionedConfig)
|
||||
}
|
||||
|
||||
func (avo *atomicVersionedConfig) Store(cfg versionedConfig) {
|
||||
avo.value.Store(cfg)
|
||||
}
|
||||
|
||||
// A Service can be mounted on the control plane.
|
||||
type Service interface {
|
||||
Mount(r *mux.Router)
|
||||
|
@ -67,14 +55,14 @@ type Server struct {
|
|||
Builder *envoyconfig.Builder
|
||||
EventsMgr *events.Manager
|
||||
|
||||
currentConfig atomicVersionedConfig
|
||||
currentConfig *atomicutil.Value[versionedConfig]
|
||||
name string
|
||||
xdsmgr *xdsmgr.Manager
|
||||
filemgr *filemgr.Manager
|
||||
metricsMgr *config.MetricsManager
|
||||
reproxy *reproxy.Handler
|
||||
|
||||
httpRouter atomic.Value
|
||||
httpRouter *atomicutil.Value[*mux.Router]
|
||||
authenticateSvc Service
|
||||
proxySvc Service
|
||||
|
||||
|
@ -88,10 +76,11 @@ func NewServer(cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr
|
|||
EventsMgr: eventsMgr,
|
||||
reproxy: reproxy.New(),
|
||||
haveSetCapacity: map[string]bool{},
|
||||
currentConfig: atomicutil.NewValue(versionedConfig{
|
||||
Config: cfg,
|
||||
}),
|
||||
httpRouter: atomicutil.NewValue(mux.NewRouter()),
|
||||
}
|
||||
srv.currentConfig.Store(versionedConfig{
|
||||
Config: cfg,
|
||||
})
|
||||
|
||||
var err error
|
||||
|
||||
|
@ -227,7 +216,7 @@ func (srv *Server) Run(ctx context.Context) error {
|
|||
Handler http.Handler
|
||||
}{
|
||||
{"http", srv.HTTPListener, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
srv.httpRouter.Load().(http.Handler).ServeHTTP(w, r)
|
||||
srv.httpRouter.Load().ServeHTTP(w, r)
|
||||
})},
|
||||
{"debug", srv.DebugListener, srv.DebugRouter},
|
||||
{"metrics", srv.MetricsListener, srv.MetricsRouter},
|
||||
|
@ -307,6 +296,6 @@ func (srv *Server) updateRouter(cfg *config.Config) error {
|
|||
if srv.proxySvc != nil {
|
||||
srv.proxySvc.Mount(httpRouter)
|
||||
}
|
||||
srv.httpRouter.Store(http.Handler(httpRouter))
|
||||
srv.httpRouter.Store(httpRouter)
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package manager
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/directory"
|
||||
|
@ -106,21 +105,3 @@ func WithEventManager(mgr *events.Manager) Option {
|
|||
c.eventMgr = mgr
|
||||
}
|
||||
}
|
||||
|
||||
type atomicConfig struct {
|
||||
value atomic.Value
|
||||
}
|
||||
|
||||
func newAtomicConfig(cfg *config) *atomicConfig {
|
||||
ac := new(atomicConfig)
|
||||
ac.Store(cfg)
|
||||
return ac
|
||||
}
|
||||
|
||||
func (ac *atomicConfig) Load() *config {
|
||||
return ac.value.Load().(*config)
|
||||
}
|
||||
|
||||
func (ac *atomicConfig) Store(cfg *config) {
|
||||
ac.value.Store(cfg)
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||
"github.com/pomerium/pomerium/internal/directory"
|
||||
"github.com/pomerium/pomerium/internal/events"
|
||||
"github.com/pomerium/pomerium/internal/identity/identity"
|
||||
|
@ -43,7 +44,7 @@ type (
|
|||
|
||||
// A Manager refreshes identity information using session and user data.
|
||||
type Manager struct {
|
||||
cfg *atomicConfig
|
||||
cfg *atomicutil.Value[*config]
|
||||
|
||||
sessionScheduler *scheduler.Scheduler
|
||||
userScheduler *scheduler.Scheduler
|
||||
|
@ -62,7 +63,7 @@ func New(
|
|||
options ...Option,
|
||||
) *Manager {
|
||||
mgr := &Manager{
|
||||
cfg: newAtomicConfig(newConfig()),
|
||||
cfg: atomicutil.NewValue(newConfig()),
|
||||
|
||||
sessionScheduler: scheduler.New(),
|
||||
userScheduler: scheduler.New(),
|
||||
|
|
|
@ -3,11 +3,12 @@ package manager
|
|||
import (
|
||||
"context"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
)
|
||||
|
||||
type dataBrokerSyncer struct {
|
||||
cfg *atomicConfig
|
||||
cfg *atomicutil.Value[*config]
|
||||
|
||||
update chan<- updateRecordsMessage
|
||||
clear chan<- struct{}
|
||||
|
@ -17,7 +18,7 @@ type dataBrokerSyncer struct {
|
|||
|
||||
func newDataBrokerSyncer(
|
||||
ctx context.Context,
|
||||
cfg *atomicConfig,
|
||||
cfg *atomicutil.Value[*config],
|
||||
update chan<- updateRecordsMessage,
|
||||
clear chan<- struct{},
|
||||
) *dataBrokerSyncer {
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"sync/atomic"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
|
@ -66,30 +65,3 @@ func NewAuthenticator(o oauth.Options) (a Authenticator, err error) {
|
|||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// wrap the Authenticator for the AtomicAuthenticator to support a nil default value.
|
||||
type authenticatorValue struct {
|
||||
Authenticator
|
||||
}
|
||||
|
||||
// An AtomicAuthenticator is a strongly-typed atomic.Value for storing an authenticator.
|
||||
type AtomicAuthenticator struct {
|
||||
current atomic.Value
|
||||
}
|
||||
|
||||
// NewAtomicAuthenticator creates a new AtomicAuthenticator.
|
||||
func NewAtomicAuthenticator() *AtomicAuthenticator {
|
||||
a := &AtomicAuthenticator{}
|
||||
a.current.Store(authenticatorValue{})
|
||||
return a
|
||||
}
|
||||
|
||||
// Load loads the current authenticator.
|
||||
func (a *AtomicAuthenticator) Load() Authenticator {
|
||||
return a.current.Load().(authenticatorValue)
|
||||
}
|
||||
|
||||
// Store stores the authenticator.
|
||||
func (a *AtomicAuthenticator) Store(value Authenticator) {
|
||||
a.current.Store(authenticatorValue{value})
|
||||
}
|
||||
|
|
|
@ -5,16 +5,17 @@ import (
|
|||
"context"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||
)
|
||||
|
||||
var (
|
||||
logger atomic.Value
|
||||
zapLogger atomic.Value
|
||||
logger = atomicutil.NewValue(new(zerolog.Logger))
|
||||
zapLogger = atomicutil.NewValue(new(zap.Logger))
|
||||
zapLevel zap.AtomicLevel
|
||||
)
|
||||
|
||||
|
@ -55,12 +56,12 @@ func SetLogger(l *zerolog.Logger) {
|
|||
|
||||
// Logger returns the global logger.
|
||||
func Logger() *zerolog.Logger {
|
||||
return logger.Load().(*zerolog.Logger)
|
||||
return logger.Load()
|
||||
}
|
||||
|
||||
// ZapLogger returns the global zap logger.
|
||||
func ZapLogger() *zap.Logger {
|
||||
return zapLogger.Load().(*zap.Logger)
|
||||
return zapLogger.Load()
|
||||
}
|
||||
|
||||
// SetLevel sets the minimum global log level. Options are 'debug' 'info' 'warn' and 'error'.
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/encoding/mock"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
|
@ -538,7 +539,7 @@ func TestProxy_jwt(t *testing.T) {
|
|||
w := httptest.NewRecorder()
|
||||
|
||||
proxy := &Proxy{
|
||||
state: newAtomicProxyState(&proxyState{}),
|
||||
state: atomicutil.NewValue(&proxyState{}),
|
||||
}
|
||||
err := proxy.jwtAssertion(w, req)
|
||||
if !assert.Error(t, err) {
|
||||
|
|
|
@ -8,11 +8,11 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||
|
@ -51,9 +51,9 @@ func ValidateOptions(o *config.Options) error {
|
|||
|
||||
// Proxy stores all the information associated with proxying a request.
|
||||
type Proxy struct {
|
||||
state *atomicProxyState
|
||||
currentOptions *config.AtomicOptions
|
||||
currentRouter atomic.Value
|
||||
state *atomicutil.Value[*proxyState]
|
||||
currentOptions *atomicutil.Value[*config.Options]
|
||||
currentRouter *atomicutil.Value[*mux.Router]
|
||||
}
|
||||
|
||||
// New takes a Proxy service from options and a validation function.
|
||||
|
@ -65,10 +65,10 @@ func New(cfg *config.Config) (*Proxy, error) {
|
|||
}
|
||||
|
||||
p := &Proxy{
|
||||
state: newAtomicProxyState(state),
|
||||
state: atomicutil.NewValue(state),
|
||||
currentOptions: config.NewAtomicOptions(),
|
||||
currentRouter: atomicutil.NewValue(httputil.NewRouter()),
|
||||
}
|
||||
p.currentRouter.Store(httputil.NewRouter())
|
||||
|
||||
metrics.AddPolicyCountCallback("pomerium-proxy", func() int64 {
|
||||
return int64(len(p.currentOptions.Load().GetAllPolicies()))
|
||||
|
@ -128,5 +128,5 @@ func (p *Proxy) setHandlers(opts *config.Options) error {
|
|||
}
|
||||
|
||||
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
p.currentRouter.Load().(*mux.Router).ServeHTTP(w, r)
|
||||
p.currentRouter.Load().ServeHTTP(w, r)
|
||||
}
|
||||
|
|
|
@ -3,7 +3,6 @@ package proxy
|
|||
import (
|
||||
"crypto/cipher"
|
||||
"net/url"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
|
@ -94,21 +93,3 @@ func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) {
|
|||
|
||||
return state, nil
|
||||
}
|
||||
|
||||
type atomicProxyState struct {
|
||||
value atomic.Value
|
||||
}
|
||||
|
||||
func newAtomicProxyState(state *proxyState) *atomicProxyState {
|
||||
aps := new(atomicProxyState)
|
||||
aps.Store(state)
|
||||
return aps
|
||||
}
|
||||
|
||||
func (aps *atomicProxyState) Load() *proxyState {
|
||||
return aps.value.Load().(*proxyState)
|
||||
}
|
||||
|
||||
func (aps *atomicProxyState) Store(state *proxyState) {
|
||||
aps.value.Store(state)
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue