mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-23 05:57:19 +02:00
authorize: add databroker server and record version to result, force sync via polling (#2024)
* authorize: add databroker server and record version to result, force sync via polling * wrap inmem store to take read lock when grabbing databroker versions * address code review comments * reset max to 0
This commit is contained in:
parent
8f97b0d6ee
commit
d7ab817de7
17 changed files with 467 additions and 362 deletions
|
@ -39,7 +39,7 @@ func TestAuthorize_okResponse(t *testing.T) {
|
||||||
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0})
|
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0})
|
||||||
a.state.Load().encoder = encoder
|
a.state.Load().encoder = encoder
|
||||||
a.currentOptions.Store(opt)
|
a.currentOptions.Store(opt)
|
||||||
a.store = evaluator.NewStoreFromProtos(
|
a.store = evaluator.NewStoreFromProtos(0,
|
||||||
&session.Session{
|
&session.Session{
|
||||||
Id: "SESSION_ID",
|
Id: "SESSION_ID",
|
||||||
UserId: "USER_ID",
|
UserId: "USER_ID",
|
||||||
|
|
|
@ -14,7 +14,7 @@ func TestCustomEvaluator(t *testing.T) {
|
||||||
|
|
||||||
store := NewStore()
|
store := NewStore()
|
||||||
t.Run("bool deny", func(t *testing.T) {
|
t.Run("bool deny", func(t *testing.T) {
|
||||||
ce := NewCustomEvaluator(store.opaStore)
|
ce := NewCustomEvaluator(store)
|
||||||
res, err := ce.Evaluate(ctx, &CustomEvaluatorRequest{
|
res, err := ce.Evaluate(ctx, &CustomEvaluatorRequest{
|
||||||
RegoPolicy: `
|
RegoPolicy: `
|
||||||
package pomerium.custom_policy
|
package pomerium.custom_policy
|
||||||
|
@ -29,7 +29,7 @@ func TestCustomEvaluator(t *testing.T) {
|
||||||
assert.Empty(t, res.Reason)
|
assert.Empty(t, res.Reason)
|
||||||
})
|
})
|
||||||
t.Run("set deny", func(t *testing.T) {
|
t.Run("set deny", func(t *testing.T) {
|
||||||
ce := NewCustomEvaluator(store.opaStore)
|
ce := NewCustomEvaluator(store)
|
||||||
res, err := ce.Evaluate(ctx, &CustomEvaluatorRequest{
|
res, err := ce.Evaluate(ctx, &CustomEvaluatorRequest{
|
||||||
RegoPolicy: `
|
RegoPolicy: `
|
||||||
package pomerium.custom_policy
|
package pomerium.custom_policy
|
||||||
|
@ -44,7 +44,7 @@ func TestCustomEvaluator(t *testing.T) {
|
||||||
assert.Equal(t, "test", res.Reason)
|
assert.Equal(t, "test", res.Reason)
|
||||||
})
|
})
|
||||||
t.Run("missing package", func(t *testing.T) {
|
t.Run("missing package", func(t *testing.T) {
|
||||||
ce := NewCustomEvaluator(store.opaStore)
|
ce := NewCustomEvaluator(store)
|
||||||
res, err := ce.Evaluate(ctx, &CustomEvaluatorRequest{
|
res, err := ce.Evaluate(ctx, &CustomEvaluatorRequest{
|
||||||
RegoPolicy: `allow = true`,
|
RegoPolicy: `allow = true`,
|
||||||
})
|
})
|
||||||
|
|
|
@ -7,7 +7,6 @@ import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
|
||||||
|
|
||||||
"github.com/open-policy-agent/opa/rego"
|
"github.com/open-policy-agent/opa/rego"
|
||||||
"gopkg.in/square/go-jose.v2"
|
"gopkg.in/square/go-jose.v2"
|
||||||
|
@ -29,7 +28,7 @@ type Evaluator struct {
|
||||||
// New creates a new Evaluator.
|
// New creates a new Evaluator.
|
||||||
func New(options *config.Options, store *Store) (*Evaluator, error) {
|
func New(options *config.Options, store *Store) (*Evaluator, error) {
|
||||||
e := &Evaluator{
|
e := &Evaluator{
|
||||||
custom: NewCustomEvaluator(store.opaStore),
|
custom: NewCustomEvaluator(store),
|
||||||
policies: options.GetAllPolicies(),
|
policies: options.GetAllPolicies(),
|
||||||
store: store,
|
store: store,
|
||||||
}
|
}
|
||||||
|
@ -55,7 +54,7 @@ func New(options *config.Options, store *Store) (*Evaluator, error) {
|
||||||
store.UpdateSigningKey(jwk)
|
store.UpdateSigningKey(jwk)
|
||||||
|
|
||||||
e.rego = rego.New(
|
e.rego = rego.New(
|
||||||
rego.Store(store.opaStore),
|
rego.Store(store),
|
||||||
rego.Module("pomerium.authz", string(authzPolicy)),
|
rego.Module("pomerium.authz", string(authzPolicy)),
|
||||||
rego.Query("result = data.pomerium.authz"),
|
rego.Query("result = data.pomerium.authz"),
|
||||||
getGoogleCloudServerlessHeadersRegoOption,
|
getGoogleCloudServerlessHeadersRegoOption,
|
||||||
|
@ -91,6 +90,9 @@ func (e *Evaluator) Evaluate(ctx context.Context, req *Request) (*Result, error)
|
||||||
MatchingPolicy: getMatchingPolicy(res[0].Bindings.WithoutWildcards(), e.policies),
|
MatchingPolicy: getMatchingPolicy(res[0].Bindings.WithoutWildcards(), e.policies),
|
||||||
Headers: getHeadersVar(res[0].Bindings.WithoutWildcards()),
|
Headers: getHeadersVar(res[0].Bindings.WithoutWildcards()),
|
||||||
}
|
}
|
||||||
|
evalResult.DataBrokerServerVersion, evalResult.DataBrokerRecordVersion = getDataBrokerVersions(
|
||||||
|
res[0].Bindings,
|
||||||
|
)
|
||||||
|
|
||||||
allow := getAllowVar(res[0].Bindings.WithoutWildcards())
|
allow := getAllowVar(res[0].Bindings.WithoutWildcards())
|
||||||
// evaluate any custom policies
|
// evaluate any custom policies
|
||||||
|
@ -181,95 +183,3 @@ func (e *Evaluator) newInput(req *Request, isValidClientCertificate bool) *input
|
||||||
i.IsValidClientCertificate = isValidClientCertificate
|
i.IsValidClientCertificate = isValidClientCertificate
|
||||||
return i
|
return i
|
||||||
}
|
}
|
||||||
|
|
||||||
// Result is the result of evaluation.
|
|
||||||
type Result struct {
|
|
||||||
Status int
|
|
||||||
Message string
|
|
||||||
Headers map[string]string
|
|
||||||
MatchingPolicy *config.Policy
|
|
||||||
}
|
|
||||||
|
|
||||||
func getMatchingPolicy(vars rego.Vars, policies []config.Policy) *config.Policy {
|
|
||||||
result, ok := vars["result"].(map[string]interface{})
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
idx, err := strconv.Atoi(fmt.Sprint(result["route_policy_idx"]))
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if idx >= len(policies) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return &policies[idx]
|
|
||||||
}
|
|
||||||
|
|
||||||
func getAllowVar(vars rego.Vars) bool {
|
|
||||||
result, ok := vars["result"].(map[string]interface{})
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
allow, ok := result["allow"].(bool)
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return allow
|
|
||||||
}
|
|
||||||
|
|
||||||
func getDenyVar(vars rego.Vars) []Result {
|
|
||||||
result, ok := vars["result"].(map[string]interface{})
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
denials, ok := result["deny"].([]interface{})
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
results := make([]Result, 0, len(denials))
|
|
||||||
for _, denial := range denials {
|
|
||||||
denial, ok := denial.([]interface{})
|
|
||||||
if !ok || len(denial) != 2 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
status, err := strconv.Atoi(fmt.Sprint(denial[0]))
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msg("invalid type in deny")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
msg := fmt.Sprint(denial[1])
|
|
||||||
|
|
||||||
results = append(results, Result{
|
|
||||||
Status: status,
|
|
||||||
Message: msg,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return results
|
|
||||||
}
|
|
||||||
|
|
||||||
func getHeadersVar(vars rego.Vars) map[string]string {
|
|
||||||
headers := make(map[string]string)
|
|
||||||
|
|
||||||
result, ok := vars["result"].(map[string]interface{})
|
|
||||||
if !ok {
|
|
||||||
return headers
|
|
||||||
}
|
|
||||||
|
|
||||||
m, ok := result["identity_headers"].(map[string]interface{})
|
|
||||||
if !ok {
|
|
||||||
return headers
|
|
||||||
}
|
|
||||||
|
|
||||||
for k, v := range m {
|
|
||||||
headers[k] = fmt.Sprint(v)
|
|
||||||
}
|
|
||||||
|
|
||||||
return headers
|
|
||||||
}
|
|
||||||
|
|
|
@ -25,7 +25,7 @@ import (
|
||||||
func TestJSONMarshal(t *testing.T) {
|
func TestJSONMarshal(t *testing.T) {
|
||||||
opt := config.NewDefaultOptions()
|
opt := config.NewDefaultOptions()
|
||||||
opt.AuthenticateURLString = "https://authenticate.example.com"
|
opt.AuthenticateURLString = "https://authenticate.example.com"
|
||||||
e, err := New(opt, NewStoreFromProtos(
|
e, err := New(opt, NewStoreFromProtos(0,
|
||||||
&session.Session{
|
&session.Session{
|
||||||
UserId: "user1",
|
UserId: "user1",
|
||||||
},
|
},
|
||||||
|
@ -100,7 +100,7 @@ func TestEvaluator_Evaluate(t *testing.T) {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
store := NewStoreFromProtos()
|
store := NewStoreFromProtos(0)
|
||||||
data, _ := ptypes.MarshalAny(&session.Session{
|
data, _ := ptypes.MarshalAny(&session.Session{
|
||||||
Version: "1",
|
Version: "1",
|
||||||
Id: sessionID,
|
Id: sessionID,
|
||||||
|
@ -116,7 +116,7 @@ func TestEvaluator_Evaluate(t *testing.T) {
|
||||||
RefreshToken: "REFRESH TOKEN",
|
RefreshToken: "REFRESH TOKEN",
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
store.UpdateRecord(&databroker.Record{
|
store.UpdateRecord(0, &databroker.Record{
|
||||||
Version: 1,
|
Version: 1,
|
||||||
Type: "type.googleapis.com/session.Session",
|
Type: "type.googleapis.com/session.Session",
|
||||||
Id: sessionID,
|
Id: sessionID,
|
||||||
|
@ -127,7 +127,7 @@ func TestEvaluator_Evaluate(t *testing.T) {
|
||||||
Id: userID,
|
Id: userID,
|
||||||
Email: "foo@example.com",
|
Email: "foo@example.com",
|
||||||
})
|
})
|
||||||
store.UpdateRecord(&databroker.Record{
|
store.UpdateRecord(0, &databroker.Record{
|
||||||
Version: 1,
|
Version: 1,
|
||||||
Type: "type.googleapis.com/user.User",
|
Type: "type.googleapis.com/user.User",
|
||||||
Id: userID,
|
Id: userID,
|
||||||
|
@ -189,7 +189,7 @@ func BenchmarkEvaluator_Evaluate(b *testing.B) {
|
||||||
RefreshToken: "REFRESH TOKEN",
|
RefreshToken: "REFRESH TOKEN",
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
store.UpdateRecord(&databroker.Record{
|
store.UpdateRecord(0, &databroker.Record{
|
||||||
Version: uint64(i),
|
Version: uint64(i),
|
||||||
Type: "type.googleapis.com/session.Session",
|
Type: "type.googleapis.com/session.Session",
|
||||||
Id: sessionID,
|
Id: sessionID,
|
||||||
|
@ -199,7 +199,7 @@ func BenchmarkEvaluator_Evaluate(b *testing.B) {
|
||||||
Version: fmt.Sprint(i),
|
Version: fmt.Sprint(i),
|
||||||
Id: userID,
|
Id: userID,
|
||||||
})
|
})
|
||||||
store.UpdateRecord(&databroker.Record{
|
store.UpdateRecord(0, &databroker.Record{
|
||||||
Version: uint64(i),
|
Version: uint64(i),
|
||||||
Type: "type.googleapis.com/user.User",
|
Type: "type.googleapis.com/user.User",
|
||||||
Id: userID,
|
Id: userID,
|
||||||
|
@ -211,7 +211,7 @@ func BenchmarkEvaluator_Evaluate(b *testing.B) {
|
||||||
Id: userID,
|
Id: userID,
|
||||||
GroupIds: []string{"1", "2", "3", "4"},
|
GroupIds: []string{"1", "2", "3", "4"},
|
||||||
})
|
})
|
||||||
store.UpdateRecord(&databroker.Record{
|
store.UpdateRecord(0, &databroker.Record{
|
||||||
Version: uint64(i),
|
Version: uint64(i),
|
||||||
Type: data.TypeUrl,
|
Type: data.TypeUrl,
|
||||||
Id: userID,
|
Id: userID,
|
||||||
|
@ -222,7 +222,7 @@ func BenchmarkEvaluator_Evaluate(b *testing.B) {
|
||||||
Version: fmt.Sprint(i),
|
Version: fmt.Sprint(i),
|
||||||
Id: fmt.Sprint(i),
|
Id: fmt.Sprint(i),
|
||||||
})
|
})
|
||||||
store.UpdateRecord(&databroker.Record{
|
store.UpdateRecord(0, &databroker.Record{
|
||||||
Version: uint64(i),
|
Version: uint64(i),
|
||||||
Type: data.TypeUrl,
|
Type: data.TypeUrl,
|
||||||
Id: fmt.Sprint(i),
|
Id: fmt.Sprint(i),
|
||||||
|
|
|
@ -5,6 +5,10 @@ default allow = false
|
||||||
# 5 minutes from now in seconds
|
# 5 minutes from now in seconds
|
||||||
five_minutes := (time.now_ns() / 1e9) + (60 * 5)
|
five_minutes := (time.now_ns() / 1e9) + (60 * 5)
|
||||||
|
|
||||||
|
# databroker versions to know which version of the data was evaluated
|
||||||
|
databroker_server_version := data.databroker_server_version
|
||||||
|
databroker_record_version := data.databroker_record_version
|
||||||
|
|
||||||
route_policy_idx := first_allowed_route_policy_idx(input.http.url)
|
route_policy_idx := first_allowed_route_policy_idx(input.http.url)
|
||||||
|
|
||||||
route_policy := data.route_policies[route_policy_idx]
|
route_policy := data.route_policies[route_policy_idx]
|
||||||
|
|
|
@ -3,6 +3,7 @@ package evaluator
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"math"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -11,6 +12,7 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
"google.golang.org/protobuf/types/known/wrapperspb"
|
||||||
"gopkg.in/square/go-jose.v2"
|
"gopkg.in/square/go-jose.v2"
|
||||||
"gopkg.in/square/go-jose.v2/jwt"
|
"gopkg.in/square/go-jose.v2/jwt"
|
||||||
|
|
||||||
|
@ -37,13 +39,13 @@ func TestOPA(t *testing.T) {
|
||||||
eval := func(t *testing.T, policies []config.Policy, data []proto.Message, req *Request, isValidClientCertificate bool) rego.Result {
|
eval := func(t *testing.T, policies []config.Policy, data []proto.Message, req *Request, isValidClientCertificate bool) rego.Result {
|
||||||
authzPolicy, err := readPolicy()
|
authzPolicy, err := readPolicy()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
store := NewStoreFromProtos(data...)
|
store := NewStoreFromProtos(math.MaxUint64, data...)
|
||||||
store.UpdateIssuer("authenticate.example.com")
|
store.UpdateIssuer("authenticate.example.com")
|
||||||
store.UpdateJWTClaimHeaders(config.NewJWTClaimHeaders("email", "groups", "user"))
|
store.UpdateJWTClaimHeaders(config.NewJWTClaimHeaders("email", "groups", "user"))
|
||||||
store.UpdateRoutePolicies(policies)
|
store.UpdateRoutePolicies(policies)
|
||||||
store.UpdateSigningKey(privateJWK)
|
store.UpdateSigningKey(privateJWK)
|
||||||
r := rego.New(
|
r := rego.New(
|
||||||
rego.Store(store.opaStore),
|
rego.Store(store),
|
||||||
rego.Module("pomerium.authz", string(authzPolicy)),
|
rego.Module("pomerium.authz", string(authzPolicy)),
|
||||||
rego.Query("result = data.pomerium.authz"),
|
rego.Query("result = data.pomerium.authz"),
|
||||||
getGoogleCloudServerlessHeadersRegoOption,
|
getGoogleCloudServerlessHeadersRegoOption,
|
||||||
|
@ -646,4 +648,12 @@ func TestOPA(t *testing.T) {
|
||||||
}, true)
|
}, true)
|
||||||
assert.True(t, res.Bindings["result"].(M)["allow"].(bool))
|
assert.True(t, res.Bindings["result"].(M)["allow"].(bool))
|
||||||
})
|
})
|
||||||
|
t.Run("databroker versions", func(t *testing.T) {
|
||||||
|
res := eval(t, nil, []proto.Message{
|
||||||
|
wrapperspb.String("test"),
|
||||||
|
}, &Request{}, false)
|
||||||
|
serverVersion, recordVersion := getDataBrokerVersions(res.Bindings)
|
||||||
|
assert.Equal(t, uint64(math.MaxUint64), serverVersion)
|
||||||
|
assert.NotEqual(t, uint64(0), recordVersion) // random
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
115
authorize/evaluator/result.go
Normal file
115
authorize/evaluator/result.go
Normal file
|
@ -0,0 +1,115 @@
|
||||||
|
package evaluator
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/open-policy-agent/opa/rego"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Result is the result of evaluation.
|
||||||
|
type Result struct {
|
||||||
|
Status int
|
||||||
|
Message string
|
||||||
|
Headers map[string]string
|
||||||
|
MatchingPolicy *config.Policy
|
||||||
|
|
||||||
|
DataBrokerServerVersion, DataBrokerRecordVersion uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func getMatchingPolicy(vars rego.Vars, policies []config.Policy) *config.Policy {
|
||||||
|
result, ok := vars["result"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
idx, err := strconv.Atoi(fmt.Sprint(result["route_policy_idx"]))
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if idx >= len(policies) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &policies[idx]
|
||||||
|
}
|
||||||
|
|
||||||
|
func getAllowVar(vars rego.Vars) bool {
|
||||||
|
result, ok := vars["result"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
allow, ok := result["allow"].(bool)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return allow
|
||||||
|
}
|
||||||
|
|
||||||
|
func getDenyVar(vars rego.Vars) []Result {
|
||||||
|
result, ok := vars["result"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
denials, ok := result["deny"].([]interface{})
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
results := make([]Result, 0, len(denials))
|
||||||
|
for _, denial := range denials {
|
||||||
|
denial, ok := denial.([]interface{})
|
||||||
|
if !ok || len(denial) != 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
status, err := strconv.Atoi(fmt.Sprint(denial[0]))
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("invalid type in deny")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
msg := fmt.Sprint(denial[1])
|
||||||
|
|
||||||
|
results = append(results, Result{
|
||||||
|
Status: status,
|
||||||
|
Message: msg,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return results
|
||||||
|
}
|
||||||
|
|
||||||
|
func getHeadersVar(vars rego.Vars) map[string]string {
|
||||||
|
headers := make(map[string]string)
|
||||||
|
|
||||||
|
result, ok := vars["result"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
return headers
|
||||||
|
}
|
||||||
|
|
||||||
|
m, ok := result["identity_headers"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
return headers
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range m {
|
||||||
|
headers[k] = fmt.Sprint(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
return headers
|
||||||
|
}
|
||||||
|
|
||||||
|
func getDataBrokerVersions(vars rego.Vars) (serverVersion, recordVersion uint64) {
|
||||||
|
result, ok := vars["result"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
return 0, 0
|
||||||
|
}
|
||||||
|
serverVersion, _ = strconv.ParseUint(fmt.Sprint(result["databroker_server_version"]), 10, 64)
|
||||||
|
recordVersion, _ = strconv.ParseUint(fmt.Sprint(result["databroker_record_version"]), 10, 64)
|
||||||
|
return serverVersion, recordVersion
|
||||||
|
}
|
|
@ -25,7 +25,7 @@ import (
|
||||||
|
|
||||||
// A Store stores data for the OPA rego policy evaluation.
|
// A Store stores data for the OPA rego policy evaluation.
|
||||||
type Store struct {
|
type Store struct {
|
||||||
opaStore storage.Store
|
storage.Store
|
||||||
|
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
dataBrokerData map[string]map[string]proto.Message
|
dataBrokerData map[string]map[string]proto.Message
|
||||||
|
@ -34,13 +34,13 @@ type Store struct {
|
||||||
// NewStore creates a new Store.
|
// NewStore creates a new Store.
|
||||||
func NewStore() *Store {
|
func NewStore() *Store {
|
||||||
return &Store{
|
return &Store{
|
||||||
opaStore: inmem.New(),
|
Store: inmem.New(),
|
||||||
dataBrokerData: make(map[string]map[string]proto.Message),
|
dataBrokerData: make(map[string]map[string]proto.Message),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewStoreFromProtos creates a new Store from an existing set of protobuf messages.
|
// NewStoreFromProtos creates a new Store from an existing set of protobuf messages.
|
||||||
func NewStoreFromProtos(msgs ...proto.Message) *Store {
|
func NewStoreFromProtos(serverVersion uint64, msgs ...proto.Message) *Store {
|
||||||
s := NewStore()
|
s := NewStore()
|
||||||
for _, msg := range msgs {
|
for _, msg := range msgs {
|
||||||
any, err := anypb.New(msg)
|
any, err := anypb.New(msg)
|
||||||
|
@ -58,11 +58,34 @@ func NewStoreFromProtos(msgs ...proto.Message) *Store {
|
||||||
record.Id = hasID.GetId()
|
record.Id = hasID.GetId()
|
||||||
}
|
}
|
||||||
|
|
||||||
s.UpdateRecord(record)
|
s.UpdateRecord(serverVersion, record)
|
||||||
}
|
}
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewTransaction calls the underlying store NewTransaction and takes the transaction lock.
|
||||||
|
func (s *Store) NewTransaction(ctx context.Context, params ...storage.TransactionParams) (storage.Transaction, error) {
|
||||||
|
txn, err := s.Store.NewTransaction(ctx, params...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
s.mu.RLock()
|
||||||
|
return txn, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Commit calls the underlying store Commit and releases the transaction lock.
|
||||||
|
func (s *Store) Commit(ctx context.Context, txn storage.Transaction) error {
|
||||||
|
err := s.Store.Commit(ctx, txn)
|
||||||
|
s.mu.RUnlock()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Abort calls the underlying store Abort and releases the transaction lock.
|
||||||
|
func (s *Store) Abort(ctx context.Context, txn storage.Transaction) {
|
||||||
|
s.Store.Abort(ctx, txn)
|
||||||
|
s.mu.RUnlock()
|
||||||
|
}
|
||||||
|
|
||||||
// ClearRecords removes all the records from the store.
|
// ClearRecords removes all the records from the store.
|
||||||
func (s *Store) ClearRecords() {
|
func (s *Store) ClearRecords() {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
|
@ -107,10 +130,13 @@ func (s *Store) UpdateRoutePolicies(routePolicies []config.Policy) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateRecord updates a record in the store.
|
// UpdateRecord updates a record in the store.
|
||||||
func (s *Store) UpdateRecord(record *databroker.Record) {
|
func (s *Store) UpdateRecord(serverVersion uint64, record *databroker.Record) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
s.write("/databroker_server_version", fmt.Sprint(serverVersion))
|
||||||
|
s.write("/databroker_record_version", fmt.Sprint(record.GetVersion()))
|
||||||
|
|
||||||
m, ok := s.dataBrokerData[record.GetType()]
|
m, ok := s.dataBrokerData[record.GetType()]
|
||||||
if !ok {
|
if !ok {
|
||||||
m = make(map[string]proto.Message)
|
m = make(map[string]proto.Message)
|
||||||
|
@ -130,31 +156,8 @@ func (s *Store) UpdateSigningKey(signingKey *jose.JSONWebKey) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) write(rawPath string, value interface{}) {
|
func (s *Store) write(rawPath string, value interface{}) {
|
||||||
p, ok := storage.ParsePath(rawPath)
|
err := storage.Txn(context.Background(), s.Store, storage.WriteParams, func(txn storage.Transaction) error {
|
||||||
if !ok {
|
return s.writeTxn(txn, rawPath, value)
|
||||||
log.Error().
|
|
||||||
Str("path", rawPath).
|
|
||||||
Msg("opa-store: invalid path, ignoring data")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err := storage.Txn(context.Background(), s.opaStore, storage.WriteParams, func(txn storage.Transaction) error {
|
|
||||||
if len(p) > 1 {
|
|
||||||
err := storage.MakeDir(context.Background(), s.opaStore, txn, p[:len(p)-1])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var op storage.PatchOp = storage.ReplaceOp
|
|
||||||
_, err := s.opaStore.Read(context.Background(), txn, p)
|
|
||||||
if storage.IsNotFound(err) {
|
|
||||||
op = storage.AddOp
|
|
||||||
} else if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.opaStore.Write(context.Background(), txn, op, p, value)
|
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("opa-store: error writing data")
|
log.Error().Err(err).Msg("opa-store: error writing data")
|
||||||
|
@ -162,6 +165,30 @@ func (s *Store) write(rawPath string, value interface{}) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Store) writeTxn(txn storage.Transaction, rawPath string, value interface{}) error {
|
||||||
|
p, ok := storage.ParsePath(rawPath)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("invalid path")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(p) > 1 {
|
||||||
|
err := storage.MakeDir(context.Background(), s, txn, p[:len(p)-1])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var op storage.PatchOp = storage.ReplaceOp
|
||||||
|
_, err := s.Read(context.Background(), txn, p)
|
||||||
|
if storage.IsNotFound(err) {
|
||||||
|
op = storage.AddOp
|
||||||
|
} else if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.Write(context.Background(), txn, op, p, value)
|
||||||
|
}
|
||||||
|
|
||||||
// GetDataBrokerRecordOption returns a function option that can retrieve databroker data.
|
// GetDataBrokerRecordOption returns a function option that can retrieve databroker data.
|
||||||
func (s *Store) GetDataBrokerRecordOption() func(*rego.Rego) {
|
func (s *Store) GetDataBrokerRecordOption() func(*rego.Rego) {
|
||||||
return rego.Function2(®o.Function{
|
return rego.Function2(®o.Function{
|
||||||
|
|
|
@ -21,7 +21,7 @@ func TestStore(t *testing.T) {
|
||||||
Email: "name@example.com",
|
Email: "name@example.com",
|
||||||
}
|
}
|
||||||
any, _ := anypb.New(u)
|
any, _ := anypb.New(u)
|
||||||
s.UpdateRecord(&databroker.Record{
|
s.UpdateRecord(0, &databroker.Record{
|
||||||
Version: 1,
|
Version: 1,
|
||||||
Type: any.GetTypeUrl(),
|
Type: any.GetTypeUrl(),
|
||||||
Id: u.GetId(),
|
Id: u.GetId(),
|
||||||
|
@ -36,7 +36,7 @@ func TestStore(t *testing.T) {
|
||||||
"email": "name@example.com",
|
"email": "name@example.com",
|
||||||
}, toMap(v))
|
}, toMap(v))
|
||||||
|
|
||||||
s.UpdateRecord(&databroker.Record{
|
s.UpdateRecord(0, &databroker.Record{
|
||||||
Version: 2,
|
Version: 2,
|
||||||
Type: any.GetTypeUrl(),
|
Type: any.GetTypeUrl(),
|
||||||
Id: u.GetId(),
|
Id: u.GetId(),
|
||||||
|
@ -47,7 +47,7 @@ func TestStore(t *testing.T) {
|
||||||
v = s.GetRecordData(any.GetTypeUrl(), u.GetId())
|
v = s.GetRecordData(any.GetTypeUrl(), u.GetId())
|
||||||
assert.Nil(t, v)
|
assert.Nil(t, v)
|
||||||
|
|
||||||
s.UpdateRecord(&databroker.Record{
|
s.UpdateRecord(0, &databroker.Record{
|
||||||
Version: 3,
|
Version: 3,
|
||||||
Type: any.GetTypeUrl(),
|
Type: any.GetTypeUrl(),
|
||||||
Id: u.GetId(),
|
Id: u.GetId(),
|
||||||
|
|
|
@ -3,7 +3,6 @@ package authorize
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
@ -19,10 +18,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/requestid"
|
"github.com/pomerium/pomerium/internal/telemetry/requestid"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
|
||||||
|
|
||||||
envoy_service_auth_v3 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3"
|
envoy_service_auth_v3 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3"
|
||||||
)
|
)
|
||||||
|
@ -83,81 +79,6 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
|
||||||
return a.deniedResponse(in, int32(reply.Status), reply.Message, nil)
|
return a.deniedResponse(in, int32(reply.Status), reply.Message, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authorize) forceSync(ctx context.Context, ss *sessions.State) (*user.User, error) {
|
|
||||||
ctx, span := trace.StartSpan(ctx, "authorize.forceSync")
|
|
||||||
defer span.End()
|
|
||||||
if ss == nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
s := a.forceSyncSession(ctx, ss.ID)
|
|
||||||
if s == nil {
|
|
||||||
return nil, errors.New("session not found")
|
|
||||||
}
|
|
||||||
u := a.forceSyncUser(ctx, s.GetUserId())
|
|
||||||
return u, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Authorize) forceSyncSession(ctx context.Context, sessionID string) interface{ GetUserId() string } {
|
|
||||||
ctx, span := trace.StartSpan(ctx, "authorize.forceSyncSession")
|
|
||||||
defer span.End()
|
|
||||||
|
|
||||||
state := a.state.Load()
|
|
||||||
|
|
||||||
s, ok := a.store.GetRecordData(grpcutil.GetTypeURL(new(session.Session)), sessionID).(*session.Session)
|
|
||||||
if ok {
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
sa, ok := a.store.GetRecordData(grpcutil.GetTypeURL(new(user.ServiceAccount)), sessionID).(*user.ServiceAccount)
|
|
||||||
if ok {
|
|
||||||
return sa
|
|
||||||
}
|
|
||||||
|
|
||||||
res, err := state.dataBrokerClient.Get(ctx, &databroker.GetRequest{
|
|
||||||
Type: grpcutil.GetTypeURL(new(session.Session)),
|
|
||||||
Id: sessionID,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
log.Warn().Err(err).Msg("failed to get session from databroker")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if current := a.store.GetRecordData(grpcutil.GetTypeURL(new(session.Session)), sessionID); current == nil {
|
|
||||||
a.store.UpdateRecord(res.GetRecord())
|
|
||||||
}
|
|
||||||
s, _ = a.store.GetRecordData(grpcutil.GetTypeURL(new(session.Session)), sessionID).(*session.Session)
|
|
||||||
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Authorize) forceSyncUser(ctx context.Context, userID string) *user.User {
|
|
||||||
ctx, span := trace.StartSpan(ctx, "authorize.forceSyncUser")
|
|
||||||
defer span.End()
|
|
||||||
|
|
||||||
state := a.state.Load()
|
|
||||||
|
|
||||||
u, ok := a.store.GetRecordData(grpcutil.GetTypeURL(new(user.User)), userID).(*user.User)
|
|
||||||
if ok {
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
res, err := state.dataBrokerClient.Get(ctx, &databroker.GetRequest{
|
|
||||||
Type: grpcutil.GetTypeURL(new(user.User)),
|
|
||||||
Id: userID,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
log.Warn().Err(err).Msg("failed to get user from databroker")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if current := a.store.GetRecordData(grpcutil.GetTypeURL(new(user.User)), userID); current == nil {
|
|
||||||
a.store.UpdateRecord(res.GetRecord())
|
|
||||||
}
|
|
||||||
u, _ = a.store.GetRecordData(grpcutil.GetTypeURL(new(user.User)), userID).(*user.User)
|
|
||||||
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
func getForwardAuthURL(r *http.Request) *url.URL {
|
func getForwardAuthURL(r *http.Request) *url.URL {
|
||||||
urqQuery := r.URL.Query().Get("uri")
|
urqQuery := r.URL.Query().Get("uri")
|
||||||
u, _ := urlutil.ParseAndValidateURL(urqQuery)
|
u, _ := urlutil.ParseAndValidateURL(urqQuery)
|
||||||
|
@ -329,6 +250,8 @@ func logAuthorizeCheck(
|
||||||
evt = evt.Str("message", reply.Message)
|
evt = evt.Str("message", reply.Message)
|
||||||
evt = evt.Str("user", u.GetId())
|
evt = evt.Str("user", u.GetId())
|
||||||
evt = evt.Str("email", u.GetEmail())
|
evt = evt.Str("email", u.GetEmail())
|
||||||
|
evt = evt.Uint64("databroker_server_version", reply.DataBrokerServerVersion)
|
||||||
|
evt = evt.Uint64("databroker_record_version", reply.DataBrokerRecordVersion)
|
||||||
}
|
}
|
||||||
|
|
||||||
// potentially sensitive, only log if debug mode
|
// potentially sensitive, only log if debug mode
|
||||||
|
|
|
@ -2,12 +2,10 @@ package authorize
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
envoy_service_auth_v3 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3"
|
envoy_service_auth_v3 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3"
|
||||||
"github.com/golang/protobuf/ptypes"
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/google/go-cmp/cmp/cmpopts"
|
"github.com/google/go-cmp/cmp/cmpopts"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
@ -21,8 +19,6 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const certPEM = `
|
const certPEM = `
|
||||||
|
@ -313,132 +309,6 @@ func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) {
|
||||||
assert.Equal(t, expect, actual)
|
assert.Equal(t, expect, actual)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSync(t *testing.T) {
|
|
||||||
mockSession := func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
|
|
||||||
data, _ := ptypes.MarshalAny(&session.Session{
|
|
||||||
Id: in.GetId(),
|
|
||||||
UserId: "user1",
|
|
||||||
})
|
|
||||||
return &databroker.GetResponse{
|
|
||||||
Record: &databroker.Record{
|
|
||||||
Version: 1,
|
|
||||||
Type: data.GetTypeUrl(),
|
|
||||||
Id: in.GetId(),
|
|
||||||
Data: data,
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
mockUser := func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
|
|
||||||
data, _ := ptypes.MarshalAny(&user.User{Id: in.GetId()})
|
|
||||||
return &databroker.GetResponse{
|
|
||||||
Record: &databroker.Record{
|
|
||||||
Version: 1,
|
|
||||||
Type: data.GetTypeUrl(),
|
|
||||||
Id: in.GetId(),
|
|
||||||
Data: data,
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
mockGetByType := map[string]func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error){
|
|
||||||
"type.googleapis.com/session.Session": mockSession,
|
|
||||||
"type.googleapis.com/user.User": mockUser,
|
|
||||||
}
|
|
||||||
dbdClient := mockDataBrokerServiceClient{
|
|
||||||
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
|
|
||||||
if in.GetId() == "not-existed-id" {
|
|
||||||
return nil, errors.New("not found")
|
|
||||||
}
|
|
||||||
f, ok := mockGetByType[in.GetType()]
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("not found")
|
|
||||||
}
|
|
||||||
return f(ctx, in, opts...)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
o := &config.Options{
|
|
||||||
AuthenticateURLString: "https://authN.example.com",
|
|
||||||
DataBrokerURLString: "https://databroker.example.com",
|
|
||||||
SharedKey: "gXK6ggrlIW2HyKyUF9rUO4azrDgxhDPWqw9y+lJU7B8=",
|
|
||||||
Policies: testPolicies(t),
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
sessionState *sessions.State
|
|
||||||
databrokerClient mockDataBrokerServiceClient
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
"good with data in databroker data",
|
|
||||||
&sessions.State{ID: "dbd_session_id"},
|
|
||||||
mockDataBrokerServiceClient{
|
|
||||||
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
|
|
||||||
data, _ := ptypes.MarshalAny(&session.Session{
|
|
||||||
Id: in.GetId(),
|
|
||||||
UserId: "dbd_user1",
|
|
||||||
})
|
|
||||||
if in.GetType() == "type.googleapis.com/user.User" {
|
|
||||||
data, _ = ptypes.MarshalAny(&user.User{
|
|
||||||
Id: "dbd_user1",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return &databroker.GetResponse{
|
|
||||||
Record: &databroker.Record{
|
|
||||||
Version: 1,
|
|
||||||
Type: data.GetTypeUrl(),
|
|
||||||
Id: in.GetId(),
|
|
||||||
Data: data,
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
false,
|
|
||||||
},
|
|
||||||
{"good", &sessions.State{ID: "SESSION_ID"}, dbdClient, false},
|
|
||||||
{"nil session state", nil, dbdClient, false},
|
|
||||||
{"not found session state", &sessions.State{ID: "not-existed-id"}, dbdClient, true},
|
|
||||||
{
|
|
||||||
"user not found",
|
|
||||||
&sessions.State{ID: "session_with_not_found_user"},
|
|
||||||
mockDataBrokerServiceClient{
|
|
||||||
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
|
|
||||||
if in.GetType() == "type.googleapis.com/user.User" {
|
|
||||||
return nil, errors.New("user not found")
|
|
||||||
}
|
|
||||||
data, _ := ptypes.MarshalAny(&session.Session{
|
|
||||||
Id: in.GetId(),
|
|
||||||
UserId: "user1",
|
|
||||||
})
|
|
||||||
return &databroker.GetResponse{
|
|
||||||
Record: &databroker.Record{
|
|
||||||
Version: 1,
|
|
||||||
Type: data.GetTypeUrl(),
|
|
||||||
Id: in.GetId(),
|
|
||||||
Data: data,
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
tc := tc
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
a, err := New(&config.Config{Options: o})
|
|
||||||
require.NoError(t, err)
|
|
||||||
a.state.Load().dataBrokerClient = dbdClient
|
|
||||||
_, err = a.forceSync(ctx, tc.sessionState)
|
|
||||||
assert.True(t, (err != nil) == tc.wantErr)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type mockDataBrokerServiceClient struct {
|
type mockDataBrokerServiceClient struct {
|
||||||
databroker.DataBrokerServiceClient
|
databroker.DataBrokerServiceClient
|
||||||
|
|
||||||
|
|
|
@ -2,11 +2,32 @@ package authorize
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cenkalti/backoff/v4"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
forceSyncRecordMaxWait = 5 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
type sessionOrServiceAccount interface {
|
||||||
|
GetUserId() string
|
||||||
|
}
|
||||||
|
|
||||||
type dataBrokerSyncer struct {
|
type dataBrokerSyncer struct {
|
||||||
*databroker.Syncer
|
*databroker.Syncer
|
||||||
authorize *Authorize
|
authorize *Authorize
|
||||||
|
@ -29,9 +50,9 @@ func (syncer *dataBrokerSyncer) ClearRecords(ctx context.Context) {
|
||||||
syncer.authorize.store.ClearRecords()
|
syncer.authorize.store.ClearRecords()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (syncer *dataBrokerSyncer) UpdateRecords(ctx context.Context, records []*databroker.Record) {
|
func (syncer *dataBrokerSyncer) UpdateRecords(ctx context.Context, serverVersion uint64, records []*databroker.Record) {
|
||||||
for _, record := range records {
|
for _, record := range records {
|
||||||
syncer.authorize.store.UpdateRecord(record)
|
syncer.authorize.store.UpdateRecord(serverVersion, record)
|
||||||
}
|
}
|
||||||
|
|
||||||
// the first time we update records we signal the initial sync
|
// the first time we update records we signal the initial sync
|
||||||
|
@ -39,3 +60,112 @@ func (syncer *dataBrokerSyncer) UpdateRecords(ctx context.Context, records []*da
|
||||||
close(syncer.authorize.dataBrokerInitialSync)
|
close(syncer.authorize.dataBrokerInitialSync)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Authorize) forceSync(ctx context.Context, ss *sessions.State) (*user.User, error) {
|
||||||
|
ctx, span := trace.StartSpan(ctx, "authorize.forceSync")
|
||||||
|
defer span.End()
|
||||||
|
if ss == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
s := a.forceSyncSession(ctx, ss.ID)
|
||||||
|
if s == nil {
|
||||||
|
return nil, errors.New("session not found")
|
||||||
|
}
|
||||||
|
u := a.forceSyncUser(ctx, s.GetUserId())
|
||||||
|
return u, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Authorize) forceSyncSession(ctx context.Context, sessionID string) sessionOrServiceAccount {
|
||||||
|
ctx, span := trace.StartSpan(ctx, "authorize.forceSyncSession")
|
||||||
|
defer span.End()
|
||||||
|
|
||||||
|
ctx, clearTimeout := context.WithTimeout(ctx, forceSyncRecordMaxWait)
|
||||||
|
defer clearTimeout()
|
||||||
|
|
||||||
|
s, ok := a.store.GetRecordData(grpcutil.GetTypeURL(new(session.Session)), sessionID).(*session.Session)
|
||||||
|
if ok {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
sa, ok := a.store.GetRecordData(grpcutil.GetTypeURL(new(user.ServiceAccount)), sessionID).(*user.ServiceAccount)
|
||||||
|
if ok {
|
||||||
|
return sa
|
||||||
|
}
|
||||||
|
|
||||||
|
// wait for the session to show up
|
||||||
|
record, err := a.waitForRecordSync(ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
s, ok = record.(*session.Session)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Authorize) forceSyncUser(ctx context.Context, userID string) *user.User {
|
||||||
|
ctx, span := trace.StartSpan(ctx, "authorize.forceSyncUser")
|
||||||
|
defer span.End()
|
||||||
|
|
||||||
|
ctx, clearTimeout := context.WithTimeout(ctx, forceSyncRecordMaxWait)
|
||||||
|
defer clearTimeout()
|
||||||
|
|
||||||
|
u, ok := a.store.GetRecordData(grpcutil.GetTypeURL(new(user.User)), userID).(*user.User)
|
||||||
|
if ok {
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// wait for the user to show up
|
||||||
|
record, err := a.waitForRecordSync(ctx, grpcutil.GetTypeURL(new(user.User)), userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
u, ok = record.(*user.User)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitForRecordSync waits for the first sync of a record to complete
|
||||||
|
func (a *Authorize) waitForRecordSync(ctx context.Context, recordTypeURL, recordID string) (proto.Message, error) {
|
||||||
|
bo := backoff.NewExponentialBackOff()
|
||||||
|
bo.InitialInterval = time.Millisecond
|
||||||
|
bo.MaxElapsedTime = 0
|
||||||
|
bo.Reset()
|
||||||
|
|
||||||
|
for {
|
||||||
|
current := a.store.GetRecordData(recordTypeURL, recordID)
|
||||||
|
if current != nil {
|
||||||
|
// record found, so it's already synced
|
||||||
|
return current, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := a.state.Load().dataBrokerClient.Get(ctx, &databroker.GetRequest{
|
||||||
|
Type: recordTypeURL,
|
||||||
|
Id: recordID,
|
||||||
|
})
|
||||||
|
if status.Code(err) == codes.NotFound {
|
||||||
|
// record not found, so no need to wait
|
||||||
|
return nil, nil
|
||||||
|
} else if err != nil {
|
||||||
|
log.Error().
|
||||||
|
Err(err).
|
||||||
|
Str("type", recordTypeURL).
|
||||||
|
Str("id", recordID).
|
||||||
|
Msg("authorize: error retrieving record")
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
log.Warn().
|
||||||
|
Str("type", recordTypeURL).
|
||||||
|
Str("id", recordID).
|
||||||
|
Msg("authorize: first sync of record did not complete")
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-time.After(bo.NextBackOff()):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
116
authorize/sync_test.go
Normal file
116
authorize/sync_test.go
Normal file
|
@ -0,0 +1,116 @@
|
||||||
|
package authorize
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
|
"google.golang.org/protobuf/types/known/anypb"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAuthorize_waitForRecordSync(t *testing.T) {
|
||||||
|
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*30)
|
||||||
|
defer clearTimeout()
|
||||||
|
|
||||||
|
o := &config.Options{
|
||||||
|
AuthenticateURLString: "https://authN.example.com",
|
||||||
|
DataBrokerURLString: "https://databroker.example.com",
|
||||||
|
SharedKey: "gXK6ggrlIW2HyKyUF9rUO4azrDgxhDPWqw9y+lJU7B8=",
|
||||||
|
Policies: testPolicies(t),
|
||||||
|
}
|
||||||
|
t.Run("skip if exists", func(t *testing.T) {
|
||||||
|
a, err := New(&config.Config{Options: o})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
a.store.UpdateRecord(0, newRecord(&session.Session{
|
||||||
|
Id: "SESSION_ID",
|
||||||
|
}))
|
||||||
|
a.state.Load().dataBrokerClient = mockDataBrokerServiceClient{
|
||||||
|
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
|
||||||
|
panic("should never be called")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
a.waitForRecordSync(ctx, grpcutil.GetTypeURL(new(session.Session)), "SESSION_ID")
|
||||||
|
})
|
||||||
|
t.Run("skip if not found", func(t *testing.T) {
|
||||||
|
a, err := New(&config.Config{Options: o})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
callCount := 0
|
||||||
|
a.state.Load().dataBrokerClient = mockDataBrokerServiceClient{
|
||||||
|
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
|
||||||
|
callCount++
|
||||||
|
return nil, status.Error(codes.NotFound, "not found")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
a.waitForRecordSync(ctx, grpcutil.GetTypeURL(new(session.Session)), "SESSION_ID")
|
||||||
|
assert.Equal(t, 1, callCount, "should be called once")
|
||||||
|
})
|
||||||
|
t.Run("poll", func(t *testing.T) {
|
||||||
|
a, err := New(&config.Config{Options: o})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
callCount := 0
|
||||||
|
a.state.Load().dataBrokerClient = mockDataBrokerServiceClient{
|
||||||
|
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
|
||||||
|
callCount++
|
||||||
|
switch callCount {
|
||||||
|
case 1:
|
||||||
|
s := &session.Session{Id: "SESSION_ID"}
|
||||||
|
a.store.UpdateRecord(0, newRecord(s))
|
||||||
|
return &databroker.GetResponse{Record: newRecord(s)}, nil
|
||||||
|
default:
|
||||||
|
panic("should never be called")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
a.waitForRecordSync(ctx, grpcutil.GetTypeURL(new(session.Session)), "SESSION_ID")
|
||||||
|
})
|
||||||
|
t.Run("timeout", func(t *testing.T) {
|
||||||
|
a, err := New(&config.Config{Options: o})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tctx, clearTimeout := context.WithTimeout(ctx, time.Millisecond*100)
|
||||||
|
defer clearTimeout()
|
||||||
|
|
||||||
|
callCount := 0
|
||||||
|
a.state.Load().dataBrokerClient = mockDataBrokerServiceClient{
|
||||||
|
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
|
||||||
|
callCount++
|
||||||
|
s := &session.Session{Id: "SESSION_ID"}
|
||||||
|
return &databroker.GetResponse{Record: newRecord(s)}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
a.waitForRecordSync(tctx, grpcutil.GetTypeURL(new(session.Session)), "SESSION_ID")
|
||||||
|
assert.Greater(t, callCount, 5) // should be ~ 20, but allow for non-determinism
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type storableMessage interface {
|
||||||
|
proto.Message
|
||||||
|
GetId() string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRecord(msg storableMessage) *databroker.Record {
|
||||||
|
any, err := anypb.New(msg)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return &databroker.Record{
|
||||||
|
Version: 1,
|
||||||
|
Type: any.GetTypeUrl(),
|
||||||
|
Id: msg.GetId(),
|
||||||
|
Data: any,
|
||||||
|
}
|
||||||
|
}
|
|
@ -206,7 +206,7 @@ func (s *syncerHandler) ClearRecords(ctx context.Context) {
|
||||||
s.src.mu.Unlock()
|
s.src.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *syncerHandler) UpdateRecords(ctx context.Context, records []*databroker.Record) {
|
func (s *syncerHandler) UpdateRecords(ctx context.Context, serverVersion uint64, records []*databroker.Record) {
|
||||||
if len(records) == 0 {
|
if len(records) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -50,7 +50,7 @@ func (syncer *dataBrokerSyncer) GetDataBrokerServiceClient() databroker.DataBrok
|
||||||
return syncer.cfg.Load().dataBrokerClient
|
return syncer.cfg.Load().dataBrokerClient
|
||||||
}
|
}
|
||||||
|
|
||||||
func (syncer *dataBrokerSyncer) UpdateRecords(ctx context.Context, records []*databroker.Record) {
|
func (syncer *dataBrokerSyncer) UpdateRecords(ctx context.Context, serverVersion uint64, records []*databroker.Record) {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
case syncer.update <- updateRecordsMessage{records: records}:
|
case syncer.update <- updateRecordsMessage{records: records}:
|
||||||
|
|
|
@ -39,7 +39,7 @@ func WithTypeURL(typeURL string) SyncerOption {
|
||||||
type SyncerHandler interface {
|
type SyncerHandler interface {
|
||||||
GetDataBrokerServiceClient() DataBrokerServiceClient
|
GetDataBrokerServiceClient() DataBrokerServiceClient
|
||||||
ClearRecords(ctx context.Context)
|
ClearRecords(ctx context.Context)
|
||||||
UpdateRecords(ctx context.Context, records []*Record)
|
UpdateRecords(ctx context.Context, serverVersion uint64, records []*Record)
|
||||||
}
|
}
|
||||||
|
|
||||||
// A Syncer is a helper type for working with Sync and SyncLatest. It will make a call to
|
// A Syncer is a helper type for working with Sync and SyncLatest. It will make a call to
|
||||||
|
@ -122,7 +122,7 @@ func (syncer *Syncer) init(ctx context.Context) error {
|
||||||
|
|
||||||
syncer.recordVersion = recordVersion
|
syncer.recordVersion = recordVersion
|
||||||
syncer.serverVersion = serverVersion
|
syncer.serverVersion = serverVersion
|
||||||
syncer.handler.UpdateRecords(ctx, records)
|
syncer.handler.UpdateRecords(ctx, serverVersion, records)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -157,7 +157,7 @@ func (syncer *Syncer) sync(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
syncer.recordVersion = res.GetRecord().GetVersion()
|
syncer.recordVersion = res.GetRecord().GetVersion()
|
||||||
if syncer.cfg.typeURL == "" || syncer.cfg.typeURL == res.GetRecord().GetType() {
|
if syncer.cfg.typeURL == "" || syncer.cfg.typeURL == res.GetRecord().GetType() {
|
||||||
syncer.handler.UpdateRecords(ctx, []*Record{res.GetRecord()})
|
syncer.handler.UpdateRecords(ctx, syncer.serverVersion, []*Record{res.GetRecord()})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,7 +19,7 @@ import (
|
||||||
type testSyncerHandler struct {
|
type testSyncerHandler struct {
|
||||||
getDataBrokerServiceClient func() DataBrokerServiceClient
|
getDataBrokerServiceClient func() DataBrokerServiceClient
|
||||||
clearRecords func(ctx context.Context)
|
clearRecords func(ctx context.Context)
|
||||||
updateRecords func(ctx context.Context, records []*Record)
|
updateRecords func(ctx context.Context, serverVersion uint64, records []*Record)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t testSyncerHandler) GetDataBrokerServiceClient() DataBrokerServiceClient {
|
func (t testSyncerHandler) GetDataBrokerServiceClient() DataBrokerServiceClient {
|
||||||
|
@ -30,8 +30,8 @@ func (t testSyncerHandler) ClearRecords(ctx context.Context) {
|
||||||
t.clearRecords(ctx)
|
t.clearRecords(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t testSyncerHandler) UpdateRecords(ctx context.Context, records []*Record) {
|
func (t testSyncerHandler) UpdateRecords(ctx context.Context, serverVersion uint64, records []*Record) {
|
||||||
t.updateRecords(ctx, records)
|
t.updateRecords(ctx, serverVersion, records)
|
||||||
}
|
}
|
||||||
|
|
||||||
type testServer struct {
|
type testServer struct {
|
||||||
|
@ -166,7 +166,7 @@ func TestSyncer(t *testing.T) {
|
||||||
clearRecords: func(ctx context.Context) {
|
clearRecords: func(ctx context.Context) {
|
||||||
clearCh <- struct{}{}
|
clearCh <- struct{}{}
|
||||||
},
|
},
|
||||||
updateRecords: func(ctx context.Context, records []*Record) {
|
updateRecords: func(ctx context.Context, serverVersion uint64, records []*Record) {
|
||||||
updateCh <- records
|
updateCh <- records
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue