mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 18:36:30 +02:00
authorize: use atomic state for properties (#1290)
This commit is contained in:
parent
c0e230acbb
commit
6dee647a16
10 changed files with 131 additions and 90 deletions
|
@ -7,87 +7,43 @@ import (
|
|||
"fmt"
|
||||
"html/template"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/pomerium/pomerium/authorize/evaluator"
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/encoding/jws"
|
||||
"github.com/pomerium/pomerium/internal/frontend"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
)
|
||||
|
||||
type atomicMarshalUnmarshaler struct {
|
||||
value atomic.Value
|
||||
}
|
||||
|
||||
func (a *atomicMarshalUnmarshaler) Load() encoding.MarshalUnmarshaler {
|
||||
return a.value.Load().(encoding.MarshalUnmarshaler)
|
||||
}
|
||||
|
||||
func (a *atomicMarshalUnmarshaler) Store(encoder encoding.MarshalUnmarshaler) {
|
||||
a.value.Store(encoder)
|
||||
}
|
||||
|
||||
// Authorize struct holds
|
||||
type Authorize struct {
|
||||
pe *evaluator.Evaluator
|
||||
store *evaluator.Store
|
||||
|
||||
state *atomicAuthorizeState
|
||||
store *evaluator.Store
|
||||
currentOptions *config.AtomicOptions
|
||||
currentEncoder atomicMarshalUnmarshaler
|
||||
templates *template.Template
|
||||
|
||||
dataBrokerClient databroker.DataBrokerServiceClient
|
||||
|
||||
dataBrokerDataLock sync.RWMutex
|
||||
dataBrokerData evaluator.DataBrokerData
|
||||
}
|
||||
|
||||
// New validates and creates a new Authorize service from a set of config options.
|
||||
func New(opts *config.Options) (*Authorize, error) {
|
||||
if err := validateOptions(opts); err != nil {
|
||||
return nil, fmt.Errorf("authorize: bad options: %w", err)
|
||||
}
|
||||
|
||||
dataBrokerConn, err := grpc.NewGRPCClientConn(
|
||||
&grpc.Options{
|
||||
Addr: opts.DataBrokerURL,
|
||||
OverrideCertificateName: opts.OverrideCertificateName,
|
||||
CA: opts.CA,
|
||||
CAFile: opts.CAFile,
|
||||
RequestTimeout: opts.GRPCClientTimeout,
|
||||
ClientDNSRoundRobin: opts.GRPCClientDNSRoundRobin,
|
||||
WithInsecure: opts.GRPCInsecure,
|
||||
ServiceName: opts.Services,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authorize: error creating cache connection: %w", err)
|
||||
}
|
||||
|
||||
func New(cfg *config.Config) (*Authorize, error) {
|
||||
a := Authorize{
|
||||
currentOptions: config.NewAtomicOptions(),
|
||||
store: evaluator.NewStore(),
|
||||
templates: template.Must(frontend.NewTemplates()),
|
||||
dataBrokerClient: databroker.NewDataBrokerServiceClient(dataBrokerConn),
|
||||
dataBrokerData: make(evaluator.DataBrokerData),
|
||||
currentOptions: config.NewAtomicOptions(),
|
||||
store: evaluator.NewStore(),
|
||||
templates: template.Must(frontend.NewTemplates()),
|
||||
dataBrokerData: make(evaluator.DataBrokerData),
|
||||
}
|
||||
|
||||
var host string
|
||||
if opts.AuthenticateURL != nil {
|
||||
host = opts.AuthenticateURL.Host
|
||||
}
|
||||
encoder, err := jws.NewHS256Signer([]byte(opts.SharedKey), host)
|
||||
state, err := newAuthorizeStateFromConfig(cfg, a.store)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
a.currentEncoder.Store(encoder)
|
||||
a.state = newAtomicAuthorizeState(state)
|
||||
|
||||
return &a, nil
|
||||
}
|
||||
|
||||
|
@ -119,10 +75,9 @@ func newPolicyEvaluator(opts *config.Options, store *evaluator.Store) (*evaluato
|
|||
func (a *Authorize) OnConfigChange(cfg *config.Config) {
|
||||
log.Info().Str("checksum", fmt.Sprintf("%x", cfg.Options.Checksum())).Msg("authorize: updating options")
|
||||
a.currentOptions.Store(cfg.Options)
|
||||
pe, err := newPolicyEvaluator(cfg.Options, a.store)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("authorize: failed to update policy with options")
|
||||
return
|
||||
if state, err := newAuthorizeStateFromConfig(cfg, a.store); err != nil {
|
||||
log.Error().Err(err).Msg("authorize: error updating state")
|
||||
} else {
|
||||
a.state.Store(state)
|
||||
}
|
||||
a.pe = pe
|
||||
}
|
||||
|
|
|
@ -56,7 +56,7 @@ func TestNew(t *testing.T) {
|
|||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, err := New(&tt.config)
|
||||
_, err := New(&config.Config{Options: &tt.config})
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
|
@ -87,11 +87,11 @@ func TestAuthorize_OnConfigChange(t *testing.T) {
|
|||
SharedKey: tc.SharedKey,
|
||||
Policies: tc.Policies,
|
||||
}
|
||||
a, err := New(o)
|
||||
a, err := New(&config.Config{Options: o})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, a)
|
||||
|
||||
oldPe := a.pe
|
||||
oldPe := a.state.Load().evaluator
|
||||
cfg := &config.Config{Options: o}
|
||||
assertFunc := assert.True
|
||||
o.SigningKey = "bad-share-key"
|
||||
|
@ -100,7 +100,7 @@ func TestAuthorize_OnConfigChange(t *testing.T) {
|
|||
assertFunc = assert.False
|
||||
}
|
||||
a.OnConfigChange(cfg)
|
||||
assertFunc(t, oldPe == a.pe)
|
||||
assertFunc(t, oldPe == a.state.Load().evaluator)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -35,15 +35,15 @@ func TestAuthorize_okResponse(t *testing.T) {
|
|||
}},
|
||||
JWTClaimsHeaders: []string{"email"},
|
||||
}
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions()}
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: newAtomicAuthorizeState(new(authorizeState))}
|
||||
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0}, "")
|
||||
a.currentEncoder.Store(encoder)
|
||||
a.state.Load().encoder = encoder
|
||||
a.currentOptions.Store(opt)
|
||||
a.store = evaluator.NewStore()
|
||||
pe, err := newPolicyEvaluator(opt, a.store)
|
||||
require.NoError(t, err)
|
||||
a.pe = pe
|
||||
validJWT, _ := a.pe.SignedJWT(a.pe.JWTPayload(&evaluator.Request{
|
||||
a.state.Load().evaluator = pe
|
||||
validJWT, _ := pe.SignedJWT(pe.JWTPayload(&evaluator.Request{
|
||||
DataBrokerData: evaluator.DataBrokerData{
|
||||
"type.googleapis.com/session.Session": map[string]interface{}{
|
||||
"SESSION_ID": &session.Session{
|
||||
|
@ -204,9 +204,9 @@ func TestAuthorize_okResponse(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestAuthorize_deniedResponse(t *testing.T) {
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions()}
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: newAtomicAuthorizeState(new(authorizeState))}
|
||||
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0}, "")
|
||||
a.currentEncoder.Store(encoder)
|
||||
a.state.Load().encoder = encoder
|
||||
a.currentOptions.Store(&config.Options{
|
||||
Policies: []config.Policy{{
|
||||
Source: &config.StringURL{URL: &url.URL{Host: "example.com"}},
|
||||
|
|
|
@ -43,11 +43,13 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v2.CheckRe
|
|||
ctx, span := trace.StartSpan(ctx, "authorize.grpc.Check")
|
||||
defer span.End()
|
||||
|
||||
state := a.state.Load()
|
||||
|
||||
// maybe rewrite http request for forward auth
|
||||
isForwardAuth := a.handleForwardAuth(in)
|
||||
hreq := getHTTPRequestFromCheckRequest(in)
|
||||
rawJWT, _ := loadRawSession(hreq, a.currentOptions.Load(), a.currentEncoder.Load())
|
||||
sessionState, _ := loadSession(a.currentEncoder.Load(), rawJWT)
|
||||
rawJWT, _ := loadRawSession(hreq, a.currentOptions.Load(), state.encoder)
|
||||
sessionState, _ := loadSession(state.encoder, rawJWT)
|
||||
|
||||
if err := a.forceSync(ctx, sessionState); err != nil {
|
||||
log.Warn().Err(err).Msg("clearing session due to force sync failed")
|
||||
|
@ -58,7 +60,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v2.CheckRe
|
|||
defer a.dataBrokerDataLock.RUnlock()
|
||||
|
||||
req := a.getEvaluatorRequestFromCheckRequest(in, sessionState)
|
||||
reply, err := a.pe.Evaluate(ctx, req)
|
||||
reply, err := state.evaluator.Evaluate(ctx, req)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("error during OPA evaluation")
|
||||
return nil, err
|
||||
|
@ -95,6 +97,8 @@ func (a *Authorize) forceSyncSession(ctx context.Context, sessionID string) *ses
|
|||
ctx, span := trace.StartSpan(ctx, "authorize.forceSyncSession")
|
||||
defer span.End()
|
||||
|
||||
state := a.state.Load()
|
||||
|
||||
a.dataBrokerDataLock.RLock()
|
||||
s, ok := a.dataBrokerData.Get(sessionTypeURL, sessionID).(*session.Session)
|
||||
a.dataBrokerDataLock.RUnlock()
|
||||
|
@ -102,7 +106,7 @@ func (a *Authorize) forceSyncSession(ctx context.Context, sessionID string) *ses
|
|||
return s
|
||||
}
|
||||
|
||||
res, err := a.dataBrokerClient.Get(ctx, &databroker.GetRequest{
|
||||
res, err := state.dataBrokerClient.Get(ctx, &databroker.GetRequest{
|
||||
Type: sessionTypeURL,
|
||||
Id: sessionID,
|
||||
})
|
||||
|
@ -125,6 +129,8 @@ func (a *Authorize) forceSyncUser(ctx context.Context, userID string) *user.User
|
|||
ctx, span := trace.StartSpan(ctx, "authorize.forceSyncUser")
|
||||
defer span.End()
|
||||
|
||||
state := a.state.Load()
|
||||
|
||||
a.dataBrokerDataLock.RLock()
|
||||
u, ok := a.dataBrokerData.Get(userTypeURL, userID).(*user.User)
|
||||
a.dataBrokerDataLock.RUnlock()
|
||||
|
@ -132,7 +138,7 @@ func (a *Authorize) forceSyncUser(ctx context.Context, userID string) *user.User
|
|||
return u
|
||||
}
|
||||
|
||||
res, err := a.dataBrokerClient.Get(ctx, &databroker.GetRequest{
|
||||
res, err := state.dataBrokerClient.Get(ctx, &databroker.GetRequest{
|
||||
Type: userTypeURL,
|
||||
Id: userID,
|
||||
})
|
||||
|
|
|
@ -47,9 +47,9 @@ yE+vPxsiUkvQHdO2fojCkY8jg70jxM+gu59tPDNbw3Uh/2Ij310FgTHsnGQMyA==
|
|||
-----END CERTIFICATE-----`
|
||||
|
||||
func Test_getEvaluatorRequest(t *testing.T) {
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions()}
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: newAtomicAuthorizeState(new(authorizeState))}
|
||||
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0}, "")
|
||||
a.currentEncoder.Store(encoder)
|
||||
a.state.Load().encoder = encoder
|
||||
a.currentOptions.Store(&config.Options{
|
||||
Policies: []config.Policy{{
|
||||
Source: &config.StringURL{URL: &url.URL{Host: "example.com"}},
|
||||
|
@ -273,7 +273,7 @@ func Test_handleForwardAuth(t *testing.T) {
|
|||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions()}
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: newAtomicAuthorizeState(new(authorizeState))}
|
||||
var fau *url.URL
|
||||
if tc.forwardAuthURL != "" {
|
||||
fau = mustParseURL(tc.forwardAuthURL)
|
||||
|
@ -288,9 +288,9 @@ func Test_handleForwardAuth(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) {
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions()}
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: newAtomicAuthorizeState(new(authorizeState))}
|
||||
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0}, "")
|
||||
a.currentEncoder.Store(encoder)
|
||||
a.state.Load().encoder = encoder
|
||||
a.currentOptions.Store(&config.Options{
|
||||
Policies: []config.Policy{{
|
||||
Source: &config.StringURL{URL: &url.URL{Host: "example.com"}},
|
||||
|
@ -454,7 +454,7 @@ func TestSync(t *testing.T) {
|
|||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
a, err := New(o)
|
||||
a, err := New(&config.Config{Options: o})
|
||||
require.NoError(t, err)
|
||||
a.dataBrokerData = evaluator.DataBrokerData{
|
||||
"type.googleapis.com/session.Session": map[string]interface{}{
|
||||
|
@ -464,7 +464,7 @@ func TestSync(t *testing.T) {
|
|||
"dbd_user1": &user.User{Id: "dbd_user1"},
|
||||
},
|
||||
}
|
||||
a.dataBrokerClient = tc.databrokerClient
|
||||
a.state.Load().dataBrokerClient = tc.databrokerClient
|
||||
assert.True(t, (a.forceSync(ctx, tc.sessionState) != nil) == tc.wantErr)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -36,7 +36,7 @@ func (a *Authorize) runTypesSyncer(ctx context.Context, updateTypes chan<- []str
|
|||
return tryForever(ctx, func(backoff interface{ Reset() }) error {
|
||||
ctx, span := trace.StartSpan(ctx, "authorize.dataBrokerClient.Sync")
|
||||
defer span.End()
|
||||
stream, err := a.dataBrokerClient.SyncTypes(ctx, new(emptypb.Empty))
|
||||
stream, err := a.state.Load().dataBrokerClient.SyncTypes(ctx, new(emptypb.Empty))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -91,7 +91,7 @@ func (a *Authorize) runDataTypeSyncer(ctx context.Context, typeURL string) error
|
|||
ctx, span := trace.StartSpan(ctx, "authorize.dataBrokerClient.GetAll")
|
||||
backoff := backoff.NewExponentialBackOff()
|
||||
for {
|
||||
res, err := a.dataBrokerClient.GetAll(ctx, &databroker.GetAllRequest{
|
||||
res, err := a.state.Load().dataBrokerClient.GetAll(ctx, &databroker.GetAllRequest{
|
||||
Type: typeURL,
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -119,7 +119,7 @@ func (a *Authorize) runDataTypeSyncer(ctx context.Context, typeURL string) error
|
|||
return tryForever(ctx, func(backoff interface{ Reset() }) error {
|
||||
ctx, span := trace.StartSpan(ctx, "authorize.dataBrokerClient.Sync")
|
||||
defer span.End()
|
||||
stream, err := a.dataBrokerClient.Sync(ctx, &databroker.SyncRequest{
|
||||
stream, err := a.state.Load().dataBrokerClient.Sync(ctx, &databroker.SyncRequest{
|
||||
ServerVersion: serverVersion,
|
||||
RecordVersion: recordVersion,
|
||||
Type: typeURL,
|
||||
|
|
|
@ -91,8 +91,10 @@ func (a *Authorize) getJWTClaimHeaders(options *config.Options, signedJWT string
|
|||
return make(map[string]string), nil
|
||||
}
|
||||
|
||||
state := a.state.Load()
|
||||
|
||||
var claims map[string]interface{}
|
||||
payload, err := a.pe.ParseSignedJWT(signedJWT)
|
||||
payload, err := state.evaluator.ParseSignedJWT(signedJWT)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -116,15 +116,15 @@ func TestAuthorize_getJWTClaimHeaders(t *testing.T) {
|
|||
}},
|
||||
}},
|
||||
}
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions()}
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: newAtomicAuthorizeState(new(authorizeState))}
|
||||
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0}, "")
|
||||
a.currentEncoder.Store(encoder)
|
||||
a.state.Load().encoder = encoder
|
||||
a.currentOptions.Store(opt)
|
||||
a.store = evaluator.NewStore()
|
||||
pe, err := newPolicyEvaluator(opt, a.store)
|
||||
require.NoError(t, err)
|
||||
a.pe = pe
|
||||
signedJWT, _ := a.pe.SignedJWT(a.pe.JWTPayload(&evaluator.Request{
|
||||
a.state.Load().evaluator = pe
|
||||
signedJWT, _ := pe.SignedJWT(pe.JWTPayload(&evaluator.Request{
|
||||
DataBrokerData: evaluator.DataBrokerData{
|
||||
"type.googleapis.com/session.Session": map[string]interface{}{
|
||||
"SESSION_ID": &session.Session{
|
||||
|
|
78
authorize/state.go
Normal file
78
authorize/state.go
Normal file
|
@ -0,0 +1,78 @@
|
|||
package authorize
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/pomerium/pomerium/authorize/evaluator"
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/encoding/jws"
|
||||
"github.com/pomerium/pomerium/pkg/grpc"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
)
|
||||
|
||||
type authorizeState struct {
|
||||
evaluator *evaluator.Evaluator
|
||||
encoder encoding.MarshalUnmarshaler
|
||||
dataBrokerClient databroker.DataBrokerServiceClient
|
||||
}
|
||||
|
||||
func newAuthorizeStateFromConfig(cfg *config.Config, store *evaluator.Store) (*authorizeState, error) {
|
||||
if err := validateOptions(cfg.Options); err != nil {
|
||||
return nil, fmt.Errorf("authorize: bad options: %w", err)
|
||||
}
|
||||
|
||||
state := new(authorizeState)
|
||||
|
||||
var err error
|
||||
|
||||
state.evaluator, err = newPolicyEvaluator(cfg.Options, store)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authorize: failed to update policy with options: %w", err)
|
||||
}
|
||||
|
||||
var host string
|
||||
if cfg.Options.AuthenticateURL != nil {
|
||||
host = cfg.Options.AuthenticateURL.Host
|
||||
}
|
||||
state.encoder, err = jws.NewHS256Signer([]byte(cfg.Options.SharedKey), host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cc, err := grpc.GetGRPCClientConn("databroker", &grpc.Options{
|
||||
Addr: cfg.Options.DataBrokerURL,
|
||||
OverrideCertificateName: cfg.Options.OverrideCertificateName,
|
||||
CA: cfg.Options.CA,
|
||||
CAFile: cfg.Options.CAFile,
|
||||
RequestTimeout: cfg.Options.GRPCClientTimeout,
|
||||
ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin,
|
||||
WithInsecure: cfg.Options.GRPCInsecure,
|
||||
ServiceName: cfg.Options.Services,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authorize: error creating databroker connection: %w", err)
|
||||
}
|
||||
state.dataBrokerClient = databroker.NewDataBrokerServiceClient(cc)
|
||||
|
||||
return state, nil
|
||||
}
|
||||
|
||||
type atomicAuthorizeState struct {
|
||||
value atomic.Value
|
||||
}
|
||||
|
||||
func newAtomicAuthorizeState(state *authorizeState) *atomicAuthorizeState {
|
||||
aas := new(atomicAuthorizeState)
|
||||
aas.Store(state)
|
||||
return aas
|
||||
}
|
||||
|
||||
func (aas *atomicAuthorizeState) Load() *authorizeState {
|
||||
return aas.value.Load().(*authorizeState)
|
||||
}
|
||||
|
||||
func (aas *atomicAuthorizeState) Store(state *authorizeState) {
|
||||
aas.value.Store(state)
|
||||
}
|
|
@ -150,7 +150,7 @@ func setupAuthenticate(src config.Source, cfg *config.Config, controlPlane *cont
|
|||
}
|
||||
|
||||
func setupAuthorize(src config.Source, cfg *config.Config, controlPlane *controlplane.Server) (*authorize.Authorize, error) {
|
||||
svc, err := authorize.New(cfg.Options)
|
||||
svc, err := authorize.New(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating authorize service: %w", err)
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue