authorize: bypass data in rego for databroker data (#2041)

This commit is contained in:
Caleb Doxsey 2021-03-30 14:14:32 -06:00 committed by GitHub
parent 76bc7a7e9a
commit 4218f49741
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 141 additions and 139 deletions

View file

@ -59,6 +59,7 @@ func New(options *config.Options, store *Store) (*Evaluator, error) {
rego.Module("pomerium.authz", string(authzPolicy)), rego.Module("pomerium.authz", string(authzPolicy)),
rego.Query("result = data.pomerium.authz"), rego.Query("result = data.pomerium.authz"),
getGoogleCloudServerlessHeadersRegoOption, getGoogleCloudServerlessHeadersRegoOption,
store.GetDataBrokerRecordOption(),
) )
e.query, err = e.rego.PrepareForEval(context.Background()) e.query, err = e.rego.PrepareForEval(context.Background())

View file

@ -12,6 +12,8 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
@ -168,18 +170,18 @@ func BenchmarkEvaluator_Evaluate(b *testing.B) {
lastSessionID := "" lastSessionID := ""
for i := 0; i < 100; i++ { for i := 0; i < 100000; i++ {
sessionID := uuid.New().String() sessionID := uuid.New().String()
lastSessionID = sessionID lastSessionID = sessionID
userID := uuid.New().String() userID := uuid.New().String()
data, _ := ptypes.MarshalAny(&session.Session{ data, _ := anypb.New(&session.Session{
Version: fmt.Sprint(i), Version: fmt.Sprint(i),
Id: sessionID, Id: sessionID,
UserId: userID, UserId: userID,
IdToken: &session.IDToken{ IdToken: &session.IDToken{
Issuer: "benchmark", Issuer: "benchmark",
Subject: userID, Subject: userID,
IssuedAt: ptypes.TimestampNow(), IssuedAt: timestamppb.Now(),
}, },
OauthToken: &session.OAuthToken{ OauthToken: &session.OAuthToken{
AccessToken: "ACCESS TOKEN", AccessToken: "ACCESS TOKEN",
@ -193,7 +195,7 @@ func BenchmarkEvaluator_Evaluate(b *testing.B) {
Id: sessionID, Id: sessionID,
Data: data, Data: data,
}) })
data, _ = ptypes.MarshalAny(&user.User{ data, _ = anypb.New(&user.User{
Version: fmt.Sprint(i), Version: fmt.Sprint(i),
Id: userID, Id: userID,
}) })
@ -203,6 +205,29 @@ func BenchmarkEvaluator_Evaluate(b *testing.B) {
Id: userID, Id: userID,
Data: data, Data: data,
}) })
data, _ = anypb.New(&directory.User{
Version: fmt.Sprint(i),
Id: userID,
GroupIds: []string{"1", "2", "3", "4"},
})
store.UpdateRecord(&databroker.Record{
Version: uint64(i),
Type: data.TypeUrl,
Id: userID,
Data: data,
})
data, _ = anypb.New(&directory.Group{
Version: fmt.Sprint(i),
Id: fmt.Sprint(i),
})
store.UpdateRecord(&databroker.Record{
Version: uint64(i),
Type: data.TypeUrl,
Id: fmt.Sprint(i),
Data: data,
})
} }
b.ResetTimer() b.ResetTimer()

View file

@ -10,30 +10,30 @@ route_policy_idx := first_allowed_route_policy_idx(input.http.url)
route_policy := data.route_policies[route_policy_idx] route_policy := data.route_policies[route_policy_idx]
session = s { session = s {
s = object_get(data.databroker_data["type.googleapis.com"]["user.ServiceAccount"], input.session.id, null) s = get_databroker_record("type.googleapis.com/user.ServiceAccount", input.session.id)
s != null s != null
} else = s { } else = s {
s = object_get(data.databroker_data["type.googleapis.com"]["session.Session"], input.session.id, null) s = get_databroker_record("type.googleapis.com/session.Session", input.session.id)
s != null s != null
} else = {} { } else = {} {
true true
} }
user = u { user = u {
u = object_get(data.databroker_data["type.googleapis.com"]["user.User"], session.impersonate_user_id, null) u = get_databroker_record("type.googleapis.com/user.User", session.impersonate_user_id)
u != null u != null
} else = u { } else = u {
u = object_get(data.databroker_data["type.googleapis.com"]["user.User"], session.user_id, null) u = get_databroker_record("type.googleapis.com/user.User", session.user_id)
u != null u != null
} else = {} { } else = {} {
true true
} }
directory_user = du { directory_user = du {
du = object_get(data.databroker_data["type.googleapis.com"]["directory.User"], session.impersonate_user_id, null) du = get_databroker_record("type.googleapis.com/directory.User", session.impersonate_user_id)
du != null du != null
} else = du { } else = du {
du = object_get(data.databroker_data["type.googleapis.com"]["directory.User"], session.user_id, null) du = get_databroker_record("type.googleapis.com/directory.User", session.user_id)
du != null du != null
} else = {} { } else = {} {
true true
@ -212,7 +212,7 @@ jwt_payload_groups = v {
v = array.concat(group_ids, get_databroker_group_names(group_ids)) v = array.concat(group_ids, get_databroker_group_names(group_ids))
v != [] v != []
} else = v { } else = v {
v = session.claims["groups"] v = session.claims.groups
v != null v != null
} else = [] { } else = [] {
true true
@ -398,11 +398,11 @@ are_claims_allowed(a, b) {
} }
get_databroker_group_names(ids) = gs { get_databroker_group_names(ids) = gs {
gs := [name | id := ids[i]; group := data.databroker_data["type.googleapis.com"]["directory.Group"][id]; name := group.name] gs := [name | id := ids[i]; group := get_databroker_record("type.googleapis.com/directory.Group", id); name := group.name]
} }
get_databroker_group_emails(ids) = gs { get_databroker_group_emails(ids) = gs {
gs := [email | id := ids[i]; group := data.databroker_data["type.googleapis.com"]["directory.Group"][id]; email := group.email] gs := [email | id := ids[i]; group := get_databroker_record("type.googleapis.com/directory.Group", id); email := group.email]
} }
get_header_string_value(obj) = s { get_header_string_value(obj) = s {

View file

@ -1,9 +1,11 @@
package pomerium.authz package pomerium.authz
get_google_cloud_serverless_headers(serviceAccount, audience) = h { get_google_cloud_serverless_headers(serviceAccount, audience) = h {
h := { h := {"Authorization": "Bearer xxx"}
"Authorization": "Bearer xxx" }
}
get_databroker_record(typeURL, id) = v {
v := object_get(data.databroker_data, typeURL, null)[id]
} }
test_email_allowed { test_email_allowed {

View file

@ -34,7 +34,7 @@ func TestOPA(t *testing.T) {
publicJWK, err := cryptutil.PublicJWKFromBytes(encodedSigningKey, jose.ES256) publicJWK, err := cryptutil.PublicJWKFromBytes(encodedSigningKey, jose.ES256)
require.NoError(t, err) require.NoError(t, err)
eval := func(policies []config.Policy, data []proto.Message, req *Request, isValidClientCertificate bool) rego.Result { eval := func(t *testing.T, policies []config.Policy, data []proto.Message, req *Request, isValidClientCertificate bool) rego.Result {
authzPolicy, err := readPolicy() authzPolicy, err := readPolicy()
require.NoError(t, err) require.NoError(t, err)
store := NewStoreFromProtos(data...) store := NewStoreFromProtos(data...)
@ -47,6 +47,7 @@ func TestOPA(t *testing.T) {
rego.Module("pomerium.authz", string(authzPolicy)), rego.Module("pomerium.authz", string(authzPolicy)),
rego.Query("result = data.pomerium.authz"), rego.Query("result = data.pomerium.authz"),
getGoogleCloudServerlessHeadersRegoOption, getGoogleCloudServerlessHeadersRegoOption,
store.GetDataBrokerRecordOption(),
) )
q, err := r.PrepareForEval(context.Background()) q, err := r.PrepareForEval(context.Background())
require.NoError(t, err) require.NoError(t, err)
@ -59,14 +60,14 @@ func TestOPA(t *testing.T) {
} }
t.Run("client certificate", func(t *testing.T) { t.Run("client certificate", func(t *testing.T) {
res := eval(nil, nil, &Request{}, false) res := eval(t, nil, nil, &Request{}, false)
assert.Equal(t, assert.Equal(t,
A{A{json.Number("495"), "invalid client certificate"}}, A{A{json.Number("495"), "invalid client certificate"}},
res.Bindings["result"].(M)["deny"]) res.Bindings["result"].(M)["deny"])
}) })
t.Run("identity_headers", func(t *testing.T) { t.Run("identity_headers", func(t *testing.T) {
t.Run("kubernetes", func(t *testing.T) { t.Run("kubernetes", func(t *testing.T) {
res := eval([]config.Policy{{ res := eval(t, []config.Policy{{
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")}, Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
To: config.WeightedURLs{ To: config.WeightedURLs{
{URL: *mustParseURL("https://to.example.com")}, {URL: *mustParseURL("https://to.example.com")},
@ -98,7 +99,7 @@ func TestOPA(t *testing.T) {
}) })
t.Run("google_cloud_serverless", func(t *testing.T) { t.Run("google_cloud_serverless", func(t *testing.T) {
withMockGCP(t, func() { withMockGCP(t, func() {
res := eval([]config.Policy{{ res := eval(t, []config.Policy{{
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")}, Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
To: config.WeightedURLs{ To: config.WeightedURLs{
{URL: *mustParseURL("https://to.example.com")}, {URL: *mustParseURL("https://to.example.com")},
@ -130,7 +131,7 @@ func TestOPA(t *testing.T) {
}) })
t.Run("jwt", func(t *testing.T) { t.Run("jwt", func(t *testing.T) {
evalJWT := func(msgs ...proto.Message) M { evalJWT := func(msgs ...proto.Message) M {
res := eval([]config.Policy{{ res := eval(t, []config.Policy{{
Source: &config.StringURL{URL: mustParseURL("https://from.example.com:8000")}, Source: &config.StringURL{URL: mustParseURL("https://from.example.com:8000")},
To: config.WeightedURLs{ To: config.WeightedURLs{
{URL: *mustParseURL("https://to.example.com")}, {URL: *mustParseURL("https://to.example.com")},
@ -226,7 +227,7 @@ func TestOPA(t *testing.T) {
}) })
t.Run("email", func(t *testing.T) { t.Run("email", func(t *testing.T) {
t.Run("allowed", func(t *testing.T) { t.Run("allowed", func(t *testing.T) {
res := eval([]config.Policy{ res := eval(t, []config.Policy{
{ {
Source: &config.StringURL{URL: mustParseURL("https://from.example.com:8000")}, Source: &config.StringURL{URL: mustParseURL("https://from.example.com:8000")},
To: config.WeightedURLs{ To: config.WeightedURLs{
@ -255,7 +256,7 @@ func TestOPA(t *testing.T) {
assert.True(t, res.Bindings["result"].(M)["allow"].(bool)) assert.True(t, res.Bindings["result"].(M)["allow"].(bool))
}) })
t.Run("denied", func(t *testing.T) { t.Run("denied", func(t *testing.T) {
res := eval([]config.Policy{ res := eval(t, []config.Policy{
{ {
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")}, Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
To: config.WeightedURLs{ To: config.WeightedURLs{
@ -286,7 +287,7 @@ func TestOPA(t *testing.T) {
}) })
t.Run("impersonate email", func(t *testing.T) { t.Run("impersonate email", func(t *testing.T) {
t.Run("allowed", func(t *testing.T) { t.Run("allowed", func(t *testing.T) {
res := eval([]config.Policy{ res := eval(t, []config.Policy{
{ {
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")}, Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
To: config.WeightedURLs{ To: config.WeightedURLs{
@ -316,7 +317,7 @@ func TestOPA(t *testing.T) {
assert.True(t, res.Bindings["result"].(M)["allow"].(bool)) assert.True(t, res.Bindings["result"].(M)["allow"].(bool))
}) })
t.Run("denied", func(t *testing.T) { t.Run("denied", func(t *testing.T) {
res := eval([]config.Policy{ res := eval(t, []config.Policy{
{ {
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")}, Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
To: config.WeightedURLs{ To: config.WeightedURLs{
@ -347,7 +348,7 @@ func TestOPA(t *testing.T) {
}) })
}) })
t.Run("user_id", func(t *testing.T) { t.Run("user_id", func(t *testing.T) {
res := eval([]config.Policy{ res := eval(t, []config.Policy{
{ {
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")}, Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
To: config.WeightedURLs{ To: config.WeightedURLs{
@ -377,7 +378,7 @@ func TestOPA(t *testing.T) {
}) })
t.Run("domain", func(t *testing.T) { t.Run("domain", func(t *testing.T) {
t.Run("allowed", func(t *testing.T) { t.Run("allowed", func(t *testing.T) {
res := eval([]config.Policy{ res := eval(t, []config.Policy{
{ {
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")}, Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
To: config.WeightedURLs{ To: config.WeightedURLs{
@ -407,7 +408,7 @@ func TestOPA(t *testing.T) {
assert.True(t, res.Bindings["result"].(M)["allow"].(bool)) assert.True(t, res.Bindings["result"].(M)["allow"].(bool))
}) })
t.Run("denied", func(t *testing.T) { t.Run("denied", func(t *testing.T) {
res := eval([]config.Policy{ res := eval(t, []config.Policy{
{ {
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")}, Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
To: config.WeightedURLs{ To: config.WeightedURLs{
@ -438,7 +439,7 @@ func TestOPA(t *testing.T) {
}) })
t.Run("impersonate domain", func(t *testing.T) { t.Run("impersonate domain", func(t *testing.T) {
t.Run("allowed", func(t *testing.T) { t.Run("allowed", func(t *testing.T) {
res := eval([]config.Policy{ res := eval(t, []config.Policy{
{ {
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")}, Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
To: config.WeightedURLs{ To: config.WeightedURLs{
@ -468,7 +469,7 @@ func TestOPA(t *testing.T) {
assert.True(t, res.Bindings["result"].(M)["allow"].(bool)) assert.True(t, res.Bindings["result"].(M)["allow"].(bool))
}) })
t.Run("denied", func(t *testing.T) { t.Run("denied", func(t *testing.T) {
res := eval([]config.Policy{ res := eval(t, []config.Policy{
{ {
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")}, Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
To: config.WeightedURLs{ To: config.WeightedURLs{
@ -502,7 +503,7 @@ func TestOPA(t *testing.T) {
t.Run("allowed", func(t *testing.T) { t.Run("allowed", func(t *testing.T) {
for _, nm := range []string{"group1", "group1name", "group1@example.com"} { for _, nm := range []string{"group1", "group1name", "group1@example.com"} {
t.Run(nm, func(t *testing.T) { t.Run(nm, func(t *testing.T) {
res := eval([]config.Policy{ res := eval(t, []config.Policy{
{ {
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")}, Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
To: config.WeightedURLs{ To: config.WeightedURLs{
@ -542,7 +543,7 @@ func TestOPA(t *testing.T) {
} }
}) })
t.Run("denied", func(t *testing.T) { t.Run("denied", func(t *testing.T) {
res := eval([]config.Policy{ res := eval(t, []config.Policy{
{ {
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")}, Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
To: config.WeightedURLs{ To: config.WeightedURLs{
@ -581,7 +582,7 @@ func TestOPA(t *testing.T) {
}) })
}) })
t.Run("impersonate groups", func(t *testing.T) { t.Run("impersonate groups", func(t *testing.T) {
res := eval([]config.Policy{ res := eval(t, []config.Policy{
{ {
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")}, Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
To: config.WeightedURLs{ To: config.WeightedURLs{
@ -617,7 +618,7 @@ func TestOPA(t *testing.T) {
assert.True(t, res.Bindings["result"].(M)["allow"].(bool)) assert.True(t, res.Bindings["result"].(M)["allow"].(bool))
}) })
t.Run("any authenticated user", func(t *testing.T) { t.Run("any authenticated user", func(t *testing.T) {
res := eval([]config.Policy{ res := eval(t, []config.Policy{
{ {
Source: &config.StringURL{URL: mustParseURL("https://from.example.com")}, Source: &config.StringURL{URL: mustParseURL("https://from.example.com")},
To: config.WeightedURLs{ To: config.WeightedURLs{

View file

@ -4,10 +4,14 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"sync"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/rego"
"github.com/open-policy-agent/opa/storage" "github.com/open-policy-agent/opa/storage"
"github.com/open-policy-agent/opa/storage/inmem" "github.com/open-policy-agent/opa/storage/inmem"
"github.com/open-policy-agent/opa/types"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
@ -22,12 +26,16 @@ import (
// A Store stores data for the OPA rego policy evaluation. // A Store stores data for the OPA rego policy evaluation.
type Store struct { type Store struct {
opaStore storage.Store opaStore storage.Store
mu sync.RWMutex
dataBrokerData map[string]map[string]proto.Message
} }
// NewStore creates a new Store. // NewStore creates a new Store.
func NewStore() *Store { func NewStore() *Store {
return &Store{ return &Store{
opaStore: inmem.New(), opaStore: inmem.New(),
dataBrokerData: make(map[string]map[string]proto.Message),
} }
} }
@ -57,37 +65,24 @@ func NewStoreFromProtos(msgs ...proto.Message) *Store {
// ClearRecords removes all the records from the store. // ClearRecords removes all the records from the store.
func (s *Store) ClearRecords() { func (s *Store) ClearRecords() {
rawPath := "/databroker_data" s.mu.Lock()
s.delete(rawPath) defer s.mu.Unlock()
s.dataBrokerData = make(map[string]map[string]proto.Message)
} }
// GetRecordData gets a record's data from the store. `nil` is returned // GetRecordData gets a record's data from the store. `nil` is returned
// if no record exists for the given type and id. // if no record exists for the given type and id.
func (s *Store) GetRecordData(typeURL, id string) proto.Message { func (s *Store) GetRecordData(typeURL, id string) proto.Message {
rawPath := fmt.Sprintf("/databroker_data/%s/%s", typeURL, id) s.mu.RLock()
data := s.get(rawPath) defer s.mu.RUnlock()
if data == nil {
m, ok := s.dataBrokerData[typeURL]
if !ok {
return nil return nil
} }
any := anypb.Any{ return m[id]
TypeUrl: typeURL,
}
msg, err := any.UnmarshalNew()
if err != nil {
return nil
}
bs, err := json.Marshal(data)
if err != nil {
return nil
}
err = json.Unmarshal(bs, &msg)
if err != nil {
return nil
}
return msg
} }
// UpdateIssuer updates the issuer in the store. The issuer is used as part of JWT construction. // UpdateIssuer updates the issuer in the store. The issuer is used as part of JWT construction.
@ -113,51 +108,18 @@ func (s *Store) UpdateRoutePolicies(routePolicies []config.Policy) {
// UpdateRecord updates a record in the store. // UpdateRecord updates a record in the store.
func (s *Store) UpdateRecord(record *databroker.Record) { func (s *Store) UpdateRecord(record *databroker.Record) {
rawPath := fmt.Sprintf("/databroker_data/%s/%s", record.GetType(), record.GetId()) s.mu.Lock()
defer s.mu.Unlock()
if record.GetDeletedAt() != nil { m, ok := s.dataBrokerData[record.GetType()]
s.delete(rawPath)
return
}
msg, err := record.GetData().UnmarshalNew()
if err != nil {
log.Error().Err(err).
Str("path", rawPath).
Msg("opa-store: error unmarshaling record data, ignoring")
return
}
s.write(rawPath, msg)
}
func (s *Store) delete(rawPath string) {
p, ok := storage.ParsePath(rawPath)
if !ok { if !ok {
log.Error(). m = make(map[string]proto.Message)
Str("path", rawPath). s.dataBrokerData[record.GetType()] = m
Msg("opa-store: invalid path, ignoring data")
return
} }
if record.GetDeletedAt() != nil {
err := storage.Txn(context.Background(), s.opaStore, storage.WriteParams, func(txn storage.Transaction) error { delete(m, record.GetId())
_, err := s.opaStore.Read(context.Background(), txn, p) } else {
if storage.IsNotFound(err) { m[record.GetId()], _ = record.GetData().UnmarshalNew()
return nil
} else if err != nil {
return err
}
err = s.opaStore.Write(context.Background(), txn, storage.RemoveOp, p, nil)
if err != nil {
return err
}
return nil
})
if err != nil {
log.Error().Err(err).Msg("opa-store: error deleting data")
return
} }
} }
@ -167,27 +129,6 @@ func (s *Store) UpdateSigningKey(signingKey *jose.JSONWebKey) {
s.write("/signing_key", signingKey) s.write("/signing_key", signingKey)
} }
func (s *Store) get(rawPath string) (value interface{}) {
p, ok := storage.ParsePath(rawPath)
if !ok {
log.Error().
Str("path", rawPath).
Msg("opa-store: invalid path, ignoring data")
return nil
}
var err error
value, err = storage.ReadOne(context.Background(), s.opaStore, p)
if storage.IsNotFound(err) {
return nil
} else if err != nil {
log.Error().Err(err).Msg("opa-store: error reading data")
return nil
}
return value
}
func (s *Store) write(rawPath string, value interface{}) { func (s *Store) write(rawPath string, value interface{}) {
p, ok := storage.ParsePath(rawPath) p, ok := storage.ParsePath(rawPath)
if !ok { if !ok {
@ -220,3 +161,44 @@ func (s *Store) write(rawPath string, value interface{}) {
return return
} }
} }
// GetDataBrokerRecordOption returns a function option that can retrieve databroker data.
func (s *Store) GetDataBrokerRecordOption() func(*rego.Rego) {
return rego.Function2(&rego.Function{
Name: "get_databroker_record",
Decl: types.NewFunction(
types.Args(types.S, types.S),
types.NewObject(nil, types.NewDynamicProperty(types.S, types.S)),
),
}, func(bctx rego.BuiltinContext, op1 *ast.Term, op2 *ast.Term) (*ast.Term, error) {
recordType, ok := op1.Value.(ast.String)
if !ok {
return nil, fmt.Errorf("invalid record type: %T", op1)
}
recordID, ok := op2.Value.(ast.String)
if !ok {
return nil, fmt.Errorf("invalid record id: %T", op2)
}
msg := s.GetRecordData(string(recordType), string(recordID))
if msg == nil {
return ast.NullTerm(), nil
}
obj := toMap(msg)
value, err := ast.InterfaceToValue(obj)
if err != nil {
return nil, err
}
return ast.NewTerm(value), nil
})
}
func toMap(msg proto.Message) map[string]interface{} {
bs, _ := json.Marshal(msg)
var obj map[string]interface{}
_ = json.Unmarshal(bs, &obj)
return obj
}

View file

@ -1,22 +1,17 @@
package evaluator package evaluator
import ( import (
"context"
"testing" "testing"
"time"
"github.com/golang/protobuf/ptypes"
"github.com/open-policy-agent/opa/storage"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/user" "github.com/pomerium/pomerium/pkg/grpc/user"
) )
func TestStore(t *testing.T) { func TestStore(t *testing.T) {
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
defer clearTimeout()
s := NewStore() s := NewStore()
t.Run("records", func(t *testing.T) { t.Run("records", func(t *testing.T) {
u := &user.User{ u := &user.User{
@ -25,7 +20,7 @@ func TestStore(t *testing.T) {
Name: "name", Name: "name",
Email: "name@example.com", Email: "name@example.com",
} }
any, _ := ptypes.MarshalAny(u) any, _ := anypb.New(u)
s.UpdateRecord(&databroker.Record{ s.UpdateRecord(&databroker.Record{
Version: 1, Version: 1,
Type: any.GetTypeUrl(), Type: any.GetTypeUrl(),
@ -33,25 +28,23 @@ func TestStore(t *testing.T) {
Data: any, Data: any,
}) })
v, err := storage.ReadOne(ctx, s.opaStore, storage.MustParsePath("/databroker_data/type.googleapis.com/user.User/u1")) v := s.GetRecordData(any.GetTypeUrl(), u.GetId())
assert.NoError(t, err)
assert.Equal(t, map[string]interface{}{ assert.Equal(t, map[string]interface{}{
"version": "v1", "version": "v1",
"id": "u1", "id": "u1",
"name": "name", "name": "name",
"email": "name@example.com", "email": "name@example.com",
}, v) }, toMap(v))
s.UpdateRecord(&databroker.Record{ s.UpdateRecord(&databroker.Record{
Version: 2, Version: 2,
Type: any.GetTypeUrl(), Type: any.GetTypeUrl(),
Id: u.GetId(), Id: u.GetId(),
Data: any, Data: any,
DeletedAt: ptypes.TimestampNow(), DeletedAt: timestamppb.Now(),
}) })
v, err = storage.ReadOne(ctx, s.opaStore, storage.MustParsePath("/databroker_data/type.googleapis.com/user.User/u1")) v = s.GetRecordData(any.GetTypeUrl(), u.GetId())
assert.Error(t, err)
assert.Nil(t, v) assert.Nil(t, v)
s.UpdateRecord(&databroker.Record{ s.UpdateRecord(&databroker.Record{
@ -61,13 +54,11 @@ func TestStore(t *testing.T) {
Data: any, Data: any,
}) })
v, err = storage.ReadOne(ctx, s.opaStore, storage.MustParsePath("/databroker_data/type.googleapis.com/user.User/u1")) v = s.GetRecordData(any.GetTypeUrl(), u.GetId())
assert.NoError(t, err)
assert.NotNil(t, v) assert.NotNil(t, v)
s.ClearRecords() s.ClearRecords()
v, err = storage.ReadOne(ctx, s.opaStore, storage.MustParsePath("/databroker_data/type.googleapis.com/user.User/u1")) v = s.GetRecordData(any.GetTypeUrl(), u.GetId())
assert.Error(t, err)
assert.Nil(t, v) assert.Nil(t, v)
}) })
} }