mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-30 17:37:25 +02:00
authorize: bypass data in rego for databroker data (#2041)
This commit is contained in:
parent
76bc7a7e9a
commit
4218f49741
7 changed files with 141 additions and 139 deletions
|
@ -59,6 +59,7 @@ func New(options *config.Options, store *Store) (*Evaluator, error) {
|
||||||
rego.Module("pomerium.authz", string(authzPolicy)),
|
rego.Module("pomerium.authz", string(authzPolicy)),
|
||||||
rego.Query("result = data.pomerium.authz"),
|
rego.Query("result = data.pomerium.authz"),
|
||||||
getGoogleCloudServerlessHeadersRegoOption,
|
getGoogleCloudServerlessHeadersRegoOption,
|
||||||
|
store.GetDataBrokerRecordOption(),
|
||||||
)
|
)
|
||||||
|
|
||||||
e.query, err = e.rego.PrepareForEval(context.Background())
|
e.query, err = e.rego.PrepareForEval(context.Background())
|
||||||
|
|
|
@ -12,6 +12,8 @@ import (
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"google.golang.org/protobuf/types/known/anypb"
|
||||||
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
@ -168,18 +170,18 @@ func BenchmarkEvaluator_Evaluate(b *testing.B) {
|
||||||
|
|
||||||
lastSessionID := ""
|
lastSessionID := ""
|
||||||
|
|
||||||
for i := 0; i < 100; i++ {
|
for i := 0; i < 100000; i++ {
|
||||||
sessionID := uuid.New().String()
|
sessionID := uuid.New().String()
|
||||||
lastSessionID = sessionID
|
lastSessionID = sessionID
|
||||||
userID := uuid.New().String()
|
userID := uuid.New().String()
|
||||||
data, _ := ptypes.MarshalAny(&session.Session{
|
data, _ := anypb.New(&session.Session{
|
||||||
Version: fmt.Sprint(i),
|
Version: fmt.Sprint(i),
|
||||||
Id: sessionID,
|
Id: sessionID,
|
||||||
UserId: userID,
|
UserId: userID,
|
||||||
IdToken: &session.IDToken{
|
IdToken: &session.IDToken{
|
||||||
Issuer: "benchmark",
|
Issuer: "benchmark",
|
||||||
Subject: userID,
|
Subject: userID,
|
||||||
IssuedAt: ptypes.TimestampNow(),
|
IssuedAt: timestamppb.Now(),
|
||||||
},
|
},
|
||||||
OauthToken: &session.OAuthToken{
|
OauthToken: &session.OAuthToken{
|
||||||
AccessToken: "ACCESS TOKEN",
|
AccessToken: "ACCESS TOKEN",
|
||||||
|
@ -193,7 +195,7 @@ func BenchmarkEvaluator_Evaluate(b *testing.B) {
|
||||||
Id: sessionID,
|
Id: sessionID,
|
||||||
Data: data,
|
Data: data,
|
||||||
})
|
})
|
||||||
data, _ = ptypes.MarshalAny(&user.User{
|
data, _ = anypb.New(&user.User{
|
||||||
Version: fmt.Sprint(i),
|
Version: fmt.Sprint(i),
|
||||||
Id: userID,
|
Id: userID,
|
||||||
})
|
})
|
||||||
|
@ -203,6 +205,29 @@ func BenchmarkEvaluator_Evaluate(b *testing.B) {
|
||||||
Id: userID,
|
Id: userID,
|
||||||
Data: data,
|
Data: data,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
data, _ = anypb.New(&directory.User{
|
||||||
|
Version: fmt.Sprint(i),
|
||||||
|
Id: userID,
|
||||||
|
GroupIds: []string{"1", "2", "3", "4"},
|
||||||
|
})
|
||||||
|
store.UpdateRecord(&databroker.Record{
|
||||||
|
Version: uint64(i),
|
||||||
|
Type: data.TypeUrl,
|
||||||
|
Id: userID,
|
||||||
|
Data: data,
|
||||||
|
})
|
||||||
|
|
||||||
|
data, _ = anypb.New(&directory.Group{
|
||||||
|
Version: fmt.Sprint(i),
|
||||||
|
Id: fmt.Sprint(i),
|
||||||
|
})
|
||||||
|
store.UpdateRecord(&databroker.Record{
|
||||||
|
Version: uint64(i),
|
||||||
|
Type: data.TypeUrl,
|
||||||
|
Id: fmt.Sprint(i),
|
||||||
|
Data: data,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
|
|
|
@ -10,30 +10,30 @@ route_policy_idx := first_allowed_route_policy_idx(input.http.url)
|
||||||
route_policy := data.route_policies[route_policy_idx]
|
route_policy := data.route_policies[route_policy_idx]
|
||||||
|
|
||||||
session = s {
|
session = s {
|
||||||
s = object_get(data.databroker_data["type.googleapis.com"]["user.ServiceAccount"], input.session.id, null)
|
s = get_databroker_record("type.googleapis.com/user.ServiceAccount", input.session.id)
|
||||||
s != null
|
s != null
|
||||||
} else = s {
|
} else = s {
|
||||||
s = object_get(data.databroker_data["type.googleapis.com"]["session.Session"], input.session.id, null)
|
s = get_databroker_record("type.googleapis.com/session.Session", input.session.id)
|
||||||
s != null
|
s != null
|
||||||
} else = {} {
|
} else = {} {
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
user = u {
|
user = u {
|
||||||
u = object_get(data.databroker_data["type.googleapis.com"]["user.User"], session.impersonate_user_id, null)
|
u = get_databroker_record("type.googleapis.com/user.User", session.impersonate_user_id)
|
||||||
u != null
|
u != null
|
||||||
} else = u {
|
} else = u {
|
||||||
u = object_get(data.databroker_data["type.googleapis.com"]["user.User"], session.user_id, null)
|
u = get_databroker_record("type.googleapis.com/user.User", session.user_id)
|
||||||
u != null
|
u != null
|
||||||
} else = {} {
|
} else = {} {
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
directory_user = du {
|
directory_user = du {
|
||||||
du = object_get(data.databroker_data["type.googleapis.com"]["directory.User"], session.impersonate_user_id, null)
|
du = get_databroker_record("type.googleapis.com/directory.User", session.impersonate_user_id)
|
||||||
du != null
|
du != null
|
||||||
} else = du {
|
} else = du {
|
||||||
du = object_get(data.databroker_data["type.googleapis.com"]["directory.User"], session.user_id, null)
|
du = get_databroker_record("type.googleapis.com/directory.User", session.user_id)
|
||||||
du != null
|
du != null
|
||||||
} else = {} {
|
} else = {} {
|
||||||
true
|
true
|
||||||
|
@ -212,7 +212,7 @@ jwt_payload_groups = v {
|
||||||
v = array.concat(group_ids, get_databroker_group_names(group_ids))
|
v = array.concat(group_ids, get_databroker_group_names(group_ids))
|
||||||
v != []
|
v != []
|
||||||
} else = v {
|
} else = v {
|
||||||
v = session.claims["groups"]
|
v = session.claims.groups
|
||||||
v != null
|
v != null
|
||||||
} else = [] {
|
} else = [] {
|
||||||
true
|
true
|
||||||
|
@ -398,11 +398,11 @@ are_claims_allowed(a, b) {
|
||||||
}
|
}
|
||||||
|
|
||||||
get_databroker_group_names(ids) = gs {
|
get_databroker_group_names(ids) = gs {
|
||||||
gs := [name | id := ids[i]; group := data.databroker_data["type.googleapis.com"]["directory.Group"][id]; name := group.name]
|
gs := [name | id := ids[i]; group := get_databroker_record("type.googleapis.com/directory.Group", id); name := group.name]
|
||||||
}
|
}
|
||||||
|
|
||||||
get_databroker_group_emails(ids) = gs {
|
get_databroker_group_emails(ids) = gs {
|
||||||
gs := [email | id := ids[i]; group := data.databroker_data["type.googleapis.com"]["directory.Group"][id]; email := group.email]
|
gs := [email | id := ids[i]; group := get_databroker_record("type.googleapis.com/directory.Group", id); email := group.email]
|
||||||
}
|
}
|
||||||
|
|
||||||
get_header_string_value(obj) = s {
|
get_header_string_value(obj) = s {
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
package pomerium.authz
|
package pomerium.authz
|
||||||
|
|
||||||
get_google_cloud_serverless_headers(serviceAccount, audience) = h {
|
get_google_cloud_serverless_headers(serviceAccount, audience) = h {
|
||||||
h := {
|
h := {"Authorization": "Bearer xxx"}
|
||||||
"Authorization": "Bearer xxx"
|
}
|
||||||
}
|
|
||||||
|
get_databroker_record(typeURL, id) = v {
|
||||||
|
v := object_get(data.databroker_data, typeURL, null)[id]
|
||||||
}
|
}
|
||||||
|
|
||||||
test_email_allowed {
|
test_email_allowed {
|
||||||
|
|
|
@ -34,7 +34,7 @@ func TestOPA(t *testing.T) {
|
||||||
publicJWK, err := cryptutil.PublicJWKFromBytes(encodedSigningKey, jose.ES256)
|
publicJWK, err := cryptutil.PublicJWKFromBytes(encodedSigningKey, jose.ES256)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
eval := func(policies []config.Policy, data []proto.Message, req *Request, isValidClientCertificate bool) rego.Result {
|
eval := func(t *testing.T, policies []config.Policy, data []proto.Message, req *Request, isValidClientCertificate bool) rego.Result {
|
||||||
authzPolicy, err := readPolicy()
|
authzPolicy, err := readPolicy()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
store := NewStoreFromProtos(data...)
|
store := NewStoreFromProtos(data...)
|
||||||
|
@ -47,6 +47,7 @@ func TestOPA(t *testing.T) {
|
||||||
rego.Module("pomerium.authz", string(authzPolicy)),
|
rego.Module("pomerium.authz", string(authzPolicy)),
|
||||||
rego.Query("result = data.pomerium.authz"),
|
rego.Query("result = data.pomerium.authz"),
|
||||||
getGoogleCloudServerlessHeadersRegoOption,
|
getGoogleCloudServerlessHeadersRegoOption,
|
||||||
|
store.GetDataBrokerRecordOption(),
|
||||||
)
|
)
|
||||||
q, err := r.PrepareForEval(context.Background())
|
q, err := r.PrepareForEval(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -59,14 +60,14 @@ func TestOPA(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("client certificate", func(t *testing.T) {
|
t.Run("client certificate", func(t *testing.T) {
|
||||||
res := eval(nil, nil, &Request{}, false)
|
res := eval(t, nil, nil, &Request{}, false)
|
||||||
assert.Equal(t,
|
assert.Equal(t,
|
||||||
A{A{json.Number("495"), "invalid client certificate"}},
|
A{A{json.Number("495"), "invalid client certificate"}},
|
||||||
res.Bindings["result"].(M)["deny"])
|
res.Bindings["result"].(M)["deny"])
|
||||||
})
|
})
|
||||||
t.Run("identity_headers", func(t *testing.T) {
|
t.Run("identity_headers", func(t *testing.T) {
|
||||||
t.Run("kubernetes", func(t *testing.T) {
|
t.Run("kubernetes", func(t *testing.T) {
|
||||||
res := eval([]config.Policy{{
|
res := eval(t, []config.Policy{{
|
||||||
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
|
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
|
||||||
To: config.WeightedURLs{
|
To: config.WeightedURLs{
|
||||||
{URL: *mustParseURL("https://to.example.com")},
|
{URL: *mustParseURL("https://to.example.com")},
|
||||||
|
@ -98,7 +99,7 @@ func TestOPA(t *testing.T) {
|
||||||
})
|
})
|
||||||
t.Run("google_cloud_serverless", func(t *testing.T) {
|
t.Run("google_cloud_serverless", func(t *testing.T) {
|
||||||
withMockGCP(t, func() {
|
withMockGCP(t, func() {
|
||||||
res := eval([]config.Policy{{
|
res := eval(t, []config.Policy{{
|
||||||
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
|
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
|
||||||
To: config.WeightedURLs{
|
To: config.WeightedURLs{
|
||||||
{URL: *mustParseURL("https://to.example.com")},
|
{URL: *mustParseURL("https://to.example.com")},
|
||||||
|
@ -130,7 +131,7 @@ func TestOPA(t *testing.T) {
|
||||||
})
|
})
|
||||||
t.Run("jwt", func(t *testing.T) {
|
t.Run("jwt", func(t *testing.T) {
|
||||||
evalJWT := func(msgs ...proto.Message) M {
|
evalJWT := func(msgs ...proto.Message) M {
|
||||||
res := eval([]config.Policy{{
|
res := eval(t, []config.Policy{{
|
||||||
Source: &config.StringURL{URL: mustParseURL("https://from.example.com:8000")},
|
Source: &config.StringURL{URL: mustParseURL("https://from.example.com:8000")},
|
||||||
To: config.WeightedURLs{
|
To: config.WeightedURLs{
|
||||||
{URL: *mustParseURL("https://to.example.com")},
|
{URL: *mustParseURL("https://to.example.com")},
|
||||||
|
@ -226,7 +227,7 @@ func TestOPA(t *testing.T) {
|
||||||
})
|
})
|
||||||
t.Run("email", func(t *testing.T) {
|
t.Run("email", func(t *testing.T) {
|
||||||
t.Run("allowed", func(t *testing.T) {
|
t.Run("allowed", func(t *testing.T) {
|
||||||
res := eval([]config.Policy{
|
res := eval(t, []config.Policy{
|
||||||
{
|
{
|
||||||
Source: &config.StringURL{URL: mustParseURL("https://from.example.com:8000")},
|
Source: &config.StringURL{URL: mustParseURL("https://from.example.com:8000")},
|
||||||
To: config.WeightedURLs{
|
To: config.WeightedURLs{
|
||||||
|
@ -255,7 +256,7 @@ func TestOPA(t *testing.T) {
|
||||||
assert.True(t, res.Bindings["result"].(M)["allow"].(bool))
|
assert.True(t, res.Bindings["result"].(M)["allow"].(bool))
|
||||||
})
|
})
|
||||||
t.Run("denied", func(t *testing.T) {
|
t.Run("denied", func(t *testing.T) {
|
||||||
res := eval([]config.Policy{
|
res := eval(t, []config.Policy{
|
||||||
{
|
{
|
||||||
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
|
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
|
||||||
To: config.WeightedURLs{
|
To: config.WeightedURLs{
|
||||||
|
@ -286,7 +287,7 @@ func TestOPA(t *testing.T) {
|
||||||
})
|
})
|
||||||
t.Run("impersonate email", func(t *testing.T) {
|
t.Run("impersonate email", func(t *testing.T) {
|
||||||
t.Run("allowed", func(t *testing.T) {
|
t.Run("allowed", func(t *testing.T) {
|
||||||
res := eval([]config.Policy{
|
res := eval(t, []config.Policy{
|
||||||
{
|
{
|
||||||
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
|
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
|
||||||
To: config.WeightedURLs{
|
To: config.WeightedURLs{
|
||||||
|
@ -316,7 +317,7 @@ func TestOPA(t *testing.T) {
|
||||||
assert.True(t, res.Bindings["result"].(M)["allow"].(bool))
|
assert.True(t, res.Bindings["result"].(M)["allow"].(bool))
|
||||||
})
|
})
|
||||||
t.Run("denied", func(t *testing.T) {
|
t.Run("denied", func(t *testing.T) {
|
||||||
res := eval([]config.Policy{
|
res := eval(t, []config.Policy{
|
||||||
{
|
{
|
||||||
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
|
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
|
||||||
To: config.WeightedURLs{
|
To: config.WeightedURLs{
|
||||||
|
@ -347,7 +348,7 @@ func TestOPA(t *testing.T) {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
t.Run("user_id", func(t *testing.T) {
|
t.Run("user_id", func(t *testing.T) {
|
||||||
res := eval([]config.Policy{
|
res := eval(t, []config.Policy{
|
||||||
{
|
{
|
||||||
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
|
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
|
||||||
To: config.WeightedURLs{
|
To: config.WeightedURLs{
|
||||||
|
@ -377,7 +378,7 @@ func TestOPA(t *testing.T) {
|
||||||
})
|
})
|
||||||
t.Run("domain", func(t *testing.T) {
|
t.Run("domain", func(t *testing.T) {
|
||||||
t.Run("allowed", func(t *testing.T) {
|
t.Run("allowed", func(t *testing.T) {
|
||||||
res := eval([]config.Policy{
|
res := eval(t, []config.Policy{
|
||||||
{
|
{
|
||||||
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
|
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
|
||||||
To: config.WeightedURLs{
|
To: config.WeightedURLs{
|
||||||
|
@ -407,7 +408,7 @@ func TestOPA(t *testing.T) {
|
||||||
assert.True(t, res.Bindings["result"].(M)["allow"].(bool))
|
assert.True(t, res.Bindings["result"].(M)["allow"].(bool))
|
||||||
})
|
})
|
||||||
t.Run("denied", func(t *testing.T) {
|
t.Run("denied", func(t *testing.T) {
|
||||||
res := eval([]config.Policy{
|
res := eval(t, []config.Policy{
|
||||||
{
|
{
|
||||||
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
|
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
|
||||||
To: config.WeightedURLs{
|
To: config.WeightedURLs{
|
||||||
|
@ -438,7 +439,7 @@ func TestOPA(t *testing.T) {
|
||||||
})
|
})
|
||||||
t.Run("impersonate domain", func(t *testing.T) {
|
t.Run("impersonate domain", func(t *testing.T) {
|
||||||
t.Run("allowed", func(t *testing.T) {
|
t.Run("allowed", func(t *testing.T) {
|
||||||
res := eval([]config.Policy{
|
res := eval(t, []config.Policy{
|
||||||
{
|
{
|
||||||
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
|
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
|
||||||
To: config.WeightedURLs{
|
To: config.WeightedURLs{
|
||||||
|
@ -468,7 +469,7 @@ func TestOPA(t *testing.T) {
|
||||||
assert.True(t, res.Bindings["result"].(M)["allow"].(bool))
|
assert.True(t, res.Bindings["result"].(M)["allow"].(bool))
|
||||||
})
|
})
|
||||||
t.Run("denied", func(t *testing.T) {
|
t.Run("denied", func(t *testing.T) {
|
||||||
res := eval([]config.Policy{
|
res := eval(t, []config.Policy{
|
||||||
{
|
{
|
||||||
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
|
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
|
||||||
To: config.WeightedURLs{
|
To: config.WeightedURLs{
|
||||||
|
@ -502,7 +503,7 @@ func TestOPA(t *testing.T) {
|
||||||
t.Run("allowed", func(t *testing.T) {
|
t.Run("allowed", func(t *testing.T) {
|
||||||
for _, nm := range []string{"group1", "group1name", "group1@example.com"} {
|
for _, nm := range []string{"group1", "group1name", "group1@example.com"} {
|
||||||
t.Run(nm, func(t *testing.T) {
|
t.Run(nm, func(t *testing.T) {
|
||||||
res := eval([]config.Policy{
|
res := eval(t, []config.Policy{
|
||||||
{
|
{
|
||||||
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
|
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
|
||||||
To: config.WeightedURLs{
|
To: config.WeightedURLs{
|
||||||
|
@ -542,7 +543,7 @@ func TestOPA(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
t.Run("denied", func(t *testing.T) {
|
t.Run("denied", func(t *testing.T) {
|
||||||
res := eval([]config.Policy{
|
res := eval(t, []config.Policy{
|
||||||
{
|
{
|
||||||
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
|
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
|
||||||
To: config.WeightedURLs{
|
To: config.WeightedURLs{
|
||||||
|
@ -581,7 +582,7 @@ func TestOPA(t *testing.T) {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
t.Run("impersonate groups", func(t *testing.T) {
|
t.Run("impersonate groups", func(t *testing.T) {
|
||||||
res := eval([]config.Policy{
|
res := eval(t, []config.Policy{
|
||||||
{
|
{
|
||||||
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
|
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
|
||||||
To: config.WeightedURLs{
|
To: config.WeightedURLs{
|
||||||
|
@ -617,7 +618,7 @@ func TestOPA(t *testing.T) {
|
||||||
assert.True(t, res.Bindings["result"].(M)["allow"].(bool))
|
assert.True(t, res.Bindings["result"].(M)["allow"].(bool))
|
||||||
})
|
})
|
||||||
t.Run("any authenticated user", func(t *testing.T) {
|
t.Run("any authenticated user", func(t *testing.T) {
|
||||||
res := eval([]config.Policy{
|
res := eval(t, []config.Policy{
|
||||||
{
|
{
|
||||||
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
|
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
|
||||||
To: config.WeightedURLs{
|
To: config.WeightedURLs{
|
||||||
|
|
|
@ -4,10 +4,14 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/open-policy-agent/opa/ast"
|
||||||
|
"github.com/open-policy-agent/opa/rego"
|
||||||
"github.com/open-policy-agent/opa/storage"
|
"github.com/open-policy-agent/opa/storage"
|
||||||
"github.com/open-policy-agent/opa/storage/inmem"
|
"github.com/open-policy-agent/opa/storage/inmem"
|
||||||
|
"github.com/open-policy-agent/opa/types"
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
"google.golang.org/protobuf/types/known/anypb"
|
"google.golang.org/protobuf/types/known/anypb"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
@ -22,12 +26,16 @@ import (
|
||||||
// A Store stores data for the OPA rego policy evaluation.
|
// A Store stores data for the OPA rego policy evaluation.
|
||||||
type Store struct {
|
type Store struct {
|
||||||
opaStore storage.Store
|
opaStore storage.Store
|
||||||
|
|
||||||
|
mu sync.RWMutex
|
||||||
|
dataBrokerData map[string]map[string]proto.Message
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewStore creates a new Store.
|
// NewStore creates a new Store.
|
||||||
func NewStore() *Store {
|
func NewStore() *Store {
|
||||||
return &Store{
|
return &Store{
|
||||||
opaStore: inmem.New(),
|
opaStore: inmem.New(),
|
||||||
|
dataBrokerData: make(map[string]map[string]proto.Message),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -57,37 +65,24 @@ func NewStoreFromProtos(msgs ...proto.Message) *Store {
|
||||||
|
|
||||||
// ClearRecords removes all the records from the store.
|
// ClearRecords removes all the records from the store.
|
||||||
func (s *Store) ClearRecords() {
|
func (s *Store) ClearRecords() {
|
||||||
rawPath := "/databroker_data"
|
s.mu.Lock()
|
||||||
s.delete(rawPath)
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
s.dataBrokerData = make(map[string]map[string]proto.Message)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRecordData gets a record's data from the store. `nil` is returned
|
// GetRecordData gets a record's data from the store. `nil` is returned
|
||||||
// if no record exists for the given type and id.
|
// if no record exists for the given type and id.
|
||||||
func (s *Store) GetRecordData(typeURL, id string) proto.Message {
|
func (s *Store) GetRecordData(typeURL, id string) proto.Message {
|
||||||
rawPath := fmt.Sprintf("/databroker_data/%s/%s", typeURL, id)
|
s.mu.RLock()
|
||||||
data := s.get(rawPath)
|
defer s.mu.RUnlock()
|
||||||
if data == nil {
|
|
||||||
|
m, ok := s.dataBrokerData[typeURL]
|
||||||
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
any := anypb.Any{
|
return m[id]
|
||||||
TypeUrl: typeURL,
|
|
||||||
}
|
|
||||||
msg, err := any.UnmarshalNew()
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
bs, err := json.Marshal(data)
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(bs, &msg)
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return msg
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateIssuer updates the issuer in the store. The issuer is used as part of JWT construction.
|
// UpdateIssuer updates the issuer in the store. The issuer is used as part of JWT construction.
|
||||||
|
@ -113,51 +108,18 @@ func (s *Store) UpdateRoutePolicies(routePolicies []config.Policy) {
|
||||||
|
|
||||||
// UpdateRecord updates a record in the store.
|
// UpdateRecord updates a record in the store.
|
||||||
func (s *Store) UpdateRecord(record *databroker.Record) {
|
func (s *Store) UpdateRecord(record *databroker.Record) {
|
||||||
rawPath := fmt.Sprintf("/databroker_data/%s/%s", record.GetType(), record.GetId())
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
if record.GetDeletedAt() != nil {
|
m, ok := s.dataBrokerData[record.GetType()]
|
||||||
s.delete(rawPath)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
msg, err := record.GetData().UnmarshalNew()
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).
|
|
||||||
Str("path", rawPath).
|
|
||||||
Msg("opa-store: error unmarshaling record data, ignoring")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
s.write(rawPath, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) delete(rawPath string) {
|
|
||||||
p, ok := storage.ParsePath(rawPath)
|
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Error().
|
m = make(map[string]proto.Message)
|
||||||
Str("path", rawPath).
|
s.dataBrokerData[record.GetType()] = m
|
||||||
Msg("opa-store: invalid path, ignoring data")
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
if record.GetDeletedAt() != nil {
|
||||||
err := storage.Txn(context.Background(), s.opaStore, storage.WriteParams, func(txn storage.Transaction) error {
|
delete(m, record.GetId())
|
||||||
_, err := s.opaStore.Read(context.Background(), txn, p)
|
} else {
|
||||||
if storage.IsNotFound(err) {
|
m[record.GetId()], _ = record.GetData().UnmarshalNew()
|
||||||
return nil
|
|
||||||
} else if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = s.opaStore.Write(context.Background(), txn, storage.RemoveOp, p, nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msg("opa-store: error deleting data")
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -167,27 +129,6 @@ func (s *Store) UpdateSigningKey(signingKey *jose.JSONWebKey) {
|
||||||
s.write("/signing_key", signingKey)
|
s.write("/signing_key", signingKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) get(rawPath string) (value interface{}) {
|
|
||||||
p, ok := storage.ParsePath(rawPath)
|
|
||||||
if !ok {
|
|
||||||
log.Error().
|
|
||||||
Str("path", rawPath).
|
|
||||||
Msg("opa-store: invalid path, ignoring data")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
|
||||||
value, err = storage.ReadOne(context.Background(), s.opaStore, p)
|
|
||||||
if storage.IsNotFound(err) {
|
|
||||||
return nil
|
|
||||||
} else if err != nil {
|
|
||||||
log.Error().Err(err).Msg("opa-store: error reading data")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return value
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) write(rawPath string, value interface{}) {
|
func (s *Store) write(rawPath string, value interface{}) {
|
||||||
p, ok := storage.ParsePath(rawPath)
|
p, ok := storage.ParsePath(rawPath)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
@ -220,3 +161,44 @@ func (s *Store) write(rawPath string, value interface{}) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetDataBrokerRecordOption returns a function option that can retrieve databroker data.
|
||||||
|
func (s *Store) GetDataBrokerRecordOption() func(*rego.Rego) {
|
||||||
|
return rego.Function2(®o.Function{
|
||||||
|
Name: "get_databroker_record",
|
||||||
|
Decl: types.NewFunction(
|
||||||
|
types.Args(types.S, types.S),
|
||||||
|
types.NewObject(nil, types.NewDynamicProperty(types.S, types.S)),
|
||||||
|
),
|
||||||
|
}, func(bctx rego.BuiltinContext, op1 *ast.Term, op2 *ast.Term) (*ast.Term, error) {
|
||||||
|
recordType, ok := op1.Value.(ast.String)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("invalid record type: %T", op1)
|
||||||
|
}
|
||||||
|
|
||||||
|
recordID, ok := op2.Value.(ast.String)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("invalid record id: %T", op2)
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := s.GetRecordData(string(recordType), string(recordID))
|
||||||
|
if msg == nil {
|
||||||
|
return ast.NullTerm(), nil
|
||||||
|
}
|
||||||
|
obj := toMap(msg)
|
||||||
|
|
||||||
|
value, err := ast.InterfaceToValue(obj)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return ast.NewTerm(value), nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func toMap(msg proto.Message) map[string]interface{} {
|
||||||
|
bs, _ := json.Marshal(msg)
|
||||||
|
var obj map[string]interface{}
|
||||||
|
_ = json.Unmarshal(bs, &obj)
|
||||||
|
return obj
|
||||||
|
}
|
||||||
|
|
|
@ -1,22 +1,17 @@
|
||||||
package evaluator
|
package evaluator
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/golang/protobuf/ptypes"
|
|
||||||
"github.com/open-policy-agent/opa/storage"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"google.golang.org/protobuf/types/known/anypb"
|
||||||
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestStore(t *testing.T) {
|
func TestStore(t *testing.T) {
|
||||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
|
||||||
defer clearTimeout()
|
|
||||||
|
|
||||||
s := NewStore()
|
s := NewStore()
|
||||||
t.Run("records", func(t *testing.T) {
|
t.Run("records", func(t *testing.T) {
|
||||||
u := &user.User{
|
u := &user.User{
|
||||||
|
@ -25,7 +20,7 @@ func TestStore(t *testing.T) {
|
||||||
Name: "name",
|
Name: "name",
|
||||||
Email: "name@example.com",
|
Email: "name@example.com",
|
||||||
}
|
}
|
||||||
any, _ := ptypes.MarshalAny(u)
|
any, _ := anypb.New(u)
|
||||||
s.UpdateRecord(&databroker.Record{
|
s.UpdateRecord(&databroker.Record{
|
||||||
Version: 1,
|
Version: 1,
|
||||||
Type: any.GetTypeUrl(),
|
Type: any.GetTypeUrl(),
|
||||||
|
@ -33,25 +28,23 @@ func TestStore(t *testing.T) {
|
||||||
Data: any,
|
Data: any,
|
||||||
})
|
})
|
||||||
|
|
||||||
v, err := storage.ReadOne(ctx, s.opaStore, storage.MustParsePath("/databroker_data/type.googleapis.com/user.User/u1"))
|
v := s.GetRecordData(any.GetTypeUrl(), u.GetId())
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, map[string]interface{}{
|
assert.Equal(t, map[string]interface{}{
|
||||||
"version": "v1",
|
"version": "v1",
|
||||||
"id": "u1",
|
"id": "u1",
|
||||||
"name": "name",
|
"name": "name",
|
||||||
"email": "name@example.com",
|
"email": "name@example.com",
|
||||||
}, v)
|
}, toMap(v))
|
||||||
|
|
||||||
s.UpdateRecord(&databroker.Record{
|
s.UpdateRecord(&databroker.Record{
|
||||||
Version: 2,
|
Version: 2,
|
||||||
Type: any.GetTypeUrl(),
|
Type: any.GetTypeUrl(),
|
||||||
Id: u.GetId(),
|
Id: u.GetId(),
|
||||||
Data: any,
|
Data: any,
|
||||||
DeletedAt: ptypes.TimestampNow(),
|
DeletedAt: timestamppb.Now(),
|
||||||
})
|
})
|
||||||
|
|
||||||
v, err = storage.ReadOne(ctx, s.opaStore, storage.MustParsePath("/databroker_data/type.googleapis.com/user.User/u1"))
|
v = s.GetRecordData(any.GetTypeUrl(), u.GetId())
|
||||||
assert.Error(t, err)
|
|
||||||
assert.Nil(t, v)
|
assert.Nil(t, v)
|
||||||
|
|
||||||
s.UpdateRecord(&databroker.Record{
|
s.UpdateRecord(&databroker.Record{
|
||||||
|
@ -61,13 +54,11 @@ func TestStore(t *testing.T) {
|
||||||
Data: any,
|
Data: any,
|
||||||
})
|
})
|
||||||
|
|
||||||
v, err = storage.ReadOne(ctx, s.opaStore, storage.MustParsePath("/databroker_data/type.googleapis.com/user.User/u1"))
|
v = s.GetRecordData(any.GetTypeUrl(), u.GetId())
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.NotNil(t, v)
|
assert.NotNil(t, v)
|
||||||
|
|
||||||
s.ClearRecords()
|
s.ClearRecords()
|
||||||
v, err = storage.ReadOne(ctx, s.opaStore, storage.MustParsePath("/databroker_data/type.googleapis.com/user.User/u1"))
|
v = s.GetRecordData(any.GetTypeUrl(), u.GetId())
|
||||||
assert.Error(t, err)
|
|
||||||
assert.Nil(t, v)
|
assert.Nil(t, v)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue