authorize: increase test coverage

- Add test cases for sync functions
 - Add test for valid JWT
 - Add session state to Test_getEvaluatorRequest
This commit is contained in:
Cuong Manh Le 2020-08-06 15:06:50 +07:00
parent 0624658e4b
commit 5d3b551524
3 changed files with 239 additions and 28 deletions

View file

@ -12,6 +12,7 @@ import (
envoy_service_auth_v2 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v2" envoy_service_auth_v2 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v2"
envoy_type "github.com/envoyproxy/go-control-plane/envoy/type" envoy_type "github.com/envoyproxy/go-control-plane/envoy/type"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/genproto/googleapis/rpc/status" "google.golang.org/genproto/googleapis/rpc/status"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
@ -19,20 +20,49 @@ import (
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/encoding/jws" "github.com/pomerium/pomerium/internal/encoding/jws"
"github.com/pomerium/pomerium/internal/frontend" "github.com/pomerium/pomerium/internal/frontend"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
) )
func TestAuthorize_okResponse(t *testing.T) { func TestAuthorize_okResponse(t *testing.T) {
a := new(Authorize) opt := &config.Options{
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0}, "") AuthenticateURL: mustParseURL("https://authenticate.example.com"),
a.currentEncoder.Store(encoder)
a.currentOptions.Store(&config.Options{
Policies: []config.Policy{{ Policies: []config.Policy{{
Source: &config.StringURL{URL: &url.URL{Host: "example.com"}}, Source: &config.StringURL{URL: &url.URL{Host: "example.com"}},
SubPolicies: []config.SubPolicy{{ SubPolicies: []config.SubPolicy{{
Rego: []string{"allow = true"}, Rego: []string{"allow = true"},
}}, }},
}}, }},
}) JWTClaimsHeaders: []string{"email"},
}
a := new(Authorize)
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0}, "")
a.currentEncoder.Store(encoder)
a.currentOptions.Store(opt)
a.store = evaluator.NewStore()
pe, err := newPolicyEvaluator(opt, a.store)
require.NoError(t, err)
a.pe = pe
validJWT, _ := a.pe.SignedJWT(a.pe.JWTPayload(&evaluator.Request{
DataBrokerData: evaluator.DataBrokerData{
"type.googleapis.com/session.Session": map[string]interface{}{
"SESSION_ID": &session.Session{
UserId: "USER_ID",
},
},
"type.googleapis.com/user.User": map[string]interface{}{
"USER_ID": &user.User{
Id: "USER_ID",
Name: "foo",
Email: "foo@example.com",
},
},
},
HTTP: evaluator.RequestHTTP{URL: "https://example.com"},
Session: evaluator.RequestSession{
ID: "SESSION_ID",
},
}))
originalGCPIdentityDocURL := gcpIdentityDocURL originalGCPIdentityDocURL := gcpIdentityDocURL
defer func() { defer func() {
@ -142,6 +172,25 @@ func TestAuthorize_okResponse(t *testing.T) {
}, },
}, },
}, },
{
"ok reply with jwt claims header",
&evaluator.Result{
Status: 0,
Message: "ok",
SignedJWT: validJWT,
},
&envoy_service_auth_v2.CheckResponse{
Status: &status.Status{Code: 0, Message: "ok"},
HttpResponse: &envoy_service_auth_v2.CheckResponse_OkResponse{
OkResponse: &envoy_service_auth_v2.OkHttpResponse{
Headers: []*envoy_api_v2_core.HeaderValueOption{
mkHeader("x-pomerium-claim-email", "foo@example.com", false),
mkHeader("x-pomerium-jwt-assertion", validJWT, false),
},
},
},
},
},
} }
for _, tc := range tests { for _, tc := range tests {

View file

@ -126,10 +126,10 @@ func (a *Authorize) forceSyncUser(ctx context.Context, userID string) *user.User
defer span.End() defer span.End()
a.dataBrokerDataLock.RLock() a.dataBrokerDataLock.RLock()
s, ok := a.dataBrokerData.Get(userTypeURL, userID).(*user.User) u, ok := a.dataBrokerData.Get(userTypeURL, userID).(*user.User)
a.dataBrokerDataLock.RUnlock() a.dataBrokerDataLock.RUnlock()
if ok { if ok {
return s return u
} }
res, err := a.dataBrokerClient.Get(ctx, &databroker.GetRequest{ res, err := a.dataBrokerClient.Get(ctx, &databroker.GetRequest{
@ -145,10 +145,10 @@ func (a *Authorize) forceSyncUser(ctx context.Context, userID string) *user.User
if current := a.dataBrokerData.Get(userTypeURL, userID); current == nil { if current := a.dataBrokerData.Get(userTypeURL, userID); current == nil {
a.dataBrokerData.Update(res.GetRecord()) a.dataBrokerData.Update(res.GetRecord())
} }
s, _ = a.dataBrokerData.Get(userTypeURL, userID).(*user.User) u, _ = a.dataBrokerData.Get(userTypeURL, userID).(*user.User)
a.dataBrokerDataLock.Unlock() a.dataBrokerDataLock.Unlock()
return s return u
} }
func (a *Authorize) getEnvoyRequestHeaders(signedJWT string) ([]*envoy_api_v2_core.HeaderValueOption, error) { func (a *Authorize) getEnvoyRequestHeaders(signedJWT string) ([]*envoy_api_v2_core.HeaderValueOption, error) {

View file

@ -1,16 +1,25 @@
package authorize package authorize
import ( import (
"context"
"errors"
"net/url" "net/url"
"testing" "testing"
envoy_service_auth_v2 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v2" envoy_service_auth_v2 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v2"
"github.com/golang/protobuf/ptypes"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"github.com/pomerium/pomerium/authorize/evaluator" "github.com/pomerium/pomerium/authorize/evaluator"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/encoding/jws" "github.com/pomerium/pomerium/internal/encoding/jws"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
) )
const certPEM = ` const certPEM = `
@ -50,29 +59,40 @@ func Test_getEvaluatorRequest(t *testing.T) {
}}, }},
}) })
actual := a.getEvaluatorRequestFromCheckRequest(&envoy_service_auth_v2.CheckRequest{ actual := a.getEvaluatorRequestFromCheckRequest(
Attributes: &envoy_service_auth_v2.AttributeContext{ &envoy_service_auth_v2.CheckRequest{
Source: &envoy_service_auth_v2.AttributeContext_Peer{ Attributes: &envoy_service_auth_v2.AttributeContext{
Certificate: url.QueryEscape(certPEM), Source: &envoy_service_auth_v2.AttributeContext_Peer{
}, Certificate: url.QueryEscape(certPEM),
Request: &envoy_service_auth_v2.AttributeContext_Request{ },
Http: &envoy_service_auth_v2.AttributeContext_HttpRequest{ Request: &envoy_service_auth_v2.AttributeContext_Request{
Id: "id-1234", Http: &envoy_service_auth_v2.AttributeContext_HttpRequest{
Method: "GET", Id: "id-1234",
Headers: map[string]string{ Method: "GET",
"accept": "text/html", Headers: map[string]string{
"x-forwarded-proto": "https", "accept": "text/html",
"x-forwarded-proto": "https",
},
Path: "/some/path?qs=1",
Host: "example.com",
Scheme: "http",
Body: "BODY",
}, },
Path: "/some/path?qs=1",
Host: "example.com",
Scheme: "http",
Body: "BODY",
}, },
}, },
}, },
}, nil) &sessions.State{
ID: "SESSION_ID",
ImpersonateEmail: "foo@example.com",
ImpersonateGroups: []string{"admin", "test"},
},
)
expect := &evaluator.Request{ expect := &evaluator.Request{
Session: evaluator.RequestSession{}, Session: evaluator.RequestSession{
ID: "SESSION_ID",
ImpersonateEmail: "foo@example.com",
ImpersonateGroups: []string{"admin", "test"},
},
HTTP: evaluator.RequestHTTP{ HTTP: evaluator.RequestHTTP{
Method: "GET", Method: "GET",
URL: "https://example.com/some/path?qs=1", URL: "https://example.com/some/path?qs=1",
@ -254,7 +274,7 @@ func Test_handleForwardAuth(t *testing.T) {
tc := tc tc := tc
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
a := new(Authorize) a := new(Authorize)
fau := new(url.URL) var fau *url.URL
if tc.forwardAuthURL != "" { if tc.forwardAuthURL != "" {
fau = mustParseURL(tc.forwardAuthURL) fau = mustParseURL(tc.forwardAuthURL)
} }
@ -317,6 +337,138 @@ func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) {
assert.Equal(t, expect, actual) assert.Equal(t, expect, actual)
} }
func TestSync(t *testing.T) {
mockSession := func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
data, _ := ptypes.MarshalAny(&session.Session{
Id: in.GetId(),
UserId: "user1",
})
return &databroker.GetResponse{
Record: &databroker.Record{
Version: "0001",
Type: data.GetTypeUrl(),
Id: in.GetId(),
Data: data,
},
}, nil
}
mockUser := func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
data, _ := ptypes.MarshalAny(&user.User{Id: in.GetId()})
return &databroker.GetResponse{
Record: &databroker.Record{
Version: "0001",
Type: data.GetTypeUrl(),
Id: in.GetId(),
Data: data,
},
}, nil
}
mockGetByType := map[string]func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error){
"type.googleapis.com/session.Session": mockSession,
"type.googleapis.com/user.User": mockUser,
}
dbdClient := mockDataBrokerServiceClient{
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
if in.GetId() == "not-existed-id" {
return nil, errors.New("not found")
}
f, ok := mockGetByType[in.GetType()]
if !ok {
return nil, errors.New("not found")
}
return f(ctx, in, opts...)
},
}
o := &config.Options{
AuthenticateURL: mustParseURL("https://authN.example.com"),
DataBrokerURL: mustParseURL("https://cache.example.com"),
SharedKey: "gXK6ggrlIW2HyKyUF9rUO4azrDgxhDPWqw9y+lJU7B8=",
Policies: testPolicies(t),
}
ctx := context.Background()
tests := []struct {
name string
sessionState *sessions.State
databrokerClient mockDataBrokerServiceClient
wantErr bool
}{
{
"good with data in databroker data",
&sessions.State{ID: "dbd_session_id"},
mockDataBrokerServiceClient{
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
data, _ := ptypes.MarshalAny(&session.Session{
Id: in.GetId(),
UserId: "dbd_user1",
})
if in.GetType() == "type.googleapis.com/user.User" {
data, _ = ptypes.MarshalAny(&user.User{
Id: "dbd_user1",
})
}
return &databroker.GetResponse{
Record: &databroker.Record{
Version: "0001",
Type: data.GetTypeUrl(),
Id: in.GetId(),
Data: data,
},
}, nil
},
},
false,
},
{"good", &sessions.State{ID: "SESSION_ID"}, dbdClient, false},
{"nil session state", nil, dbdClient, false},
{"not found session state", &sessions.State{ID: "not-existed-id"}, dbdClient, true},
{
"user not found",
&sessions.State{ID: "session_with_not_found_user"},
mockDataBrokerServiceClient{
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
if in.GetType() == "type.googleapis.com/user.User" {
return nil, errors.New("user not found")
}
data, _ := ptypes.MarshalAny(&session.Session{
Id: in.GetId(),
UserId: "user1",
})
return &databroker.GetResponse{
Record: &databroker.Record{
Version: "0001",
Type: data.GetTypeUrl(),
Id: in.GetId(),
Data: data,
},
}, nil
},
},
false,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
a, err := New(o)
require.NoError(t, err)
a.dataBrokerData = evaluator.DataBrokerData{
"type.googleapis.com/session.Session": map[string]interface{}{
"dbd_session_id": &session.Session{UserId: "dbd_user1"},
},
"type.googleapis.com/user.User": map[string]interface{}{
"dbd_user1": &user.User{Id: "dbd_user1"},
},
}
a.dataBrokerClient = tc.databrokerClient
assert.True(t, (a.forceSync(ctx, tc.sessionState) != nil) == tc.wantErr)
})
}
}
func mustParseURL(str string) *url.URL { func mustParseURL(str string) *url.URL {
u, err := url.Parse(str) u, err := url.Parse(str)
if err != nil { if err != nil {
@ -324,3 +476,13 @@ func mustParseURL(str string) *url.URL {
} }
return u return u
} }
type mockDataBrokerServiceClient struct {
databroker.DataBrokerServiceClient
get func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error)
}
func (m mockDataBrokerServiceClient) Get(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
return m.get(ctx, in, opts...)
}