package ping import ( "context" "encoding/json" "fmt" "net/http" "net/http/httptest" "net/url" "sort" "testing" "time" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/stretchr/testify/require" "github.com/pomerium/pomerium/internal/testutil" ) type M = map[string]interface{} func newMockAPI(userIDToGroupIDs map[string][]string) http.Handler { lookup := map[string]struct{}{} for _, groups := range userIDToGroupIDs { for _, group := range groups { lookup[group] = struct{}{} } } var allGroups []string for groupID := range lookup { allGroups = append(allGroups, groupID) } sort.Strings(allGroups) var allUserIDs []string for userID := range userIDToGroupIDs { allUserIDs = append(allUserIDs, userID) } sort.Strings(allUserIDs) filterToUserIDs := map[string][]string{} for userID, groupIDs := range userIDToGroupIDs { for _, groupID := range groupIDs { filter := fmt.Sprintf(`memberOfGroups[id eq "%s"]`, groupID) filterToUserIDs[filter] = append(filterToUserIDs[filter], userID) } } r := chi.NewRouter() r.Use(middleware.Logger) r.Post("/ENVIRONMENTID/as/token", func(w http.ResponseWriter, r *http.Request) { u, p, _ := r.BasicAuth() if u != "CLIENTID" || p != "CLIENTSECRET" { http.Error(w, "forbidden", http.StatusForbidden) return } grantType := r.FormValue("grant_type") if grantType != "client_credentials" { http.Error(w, "invalid grant_type", http.StatusBadRequest) return } w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(M{ "access_token": "ACCESSTOKEN", "created_at": time.Now().Format(time.RFC3339), "expires_in": 360000, "refresh_token": "REFRESHTOKEN", "token_type": "bearer", }) }) r.Route("/v1/environments/ENVIRONMENTID", func(r chi.Router) { r.Get("/groups", func(w http.ResponseWriter, r *http.Request) { var apiGroups []apiGroup for _, id := range allGroups { apiGroups = append(apiGroups, apiGroup{ ID: id, Name: "Group " + id, }) } w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(M{ "_embedded": M{ "groups": apiGroups, }, }) }) r.Route("/users", func(r chi.Router) { r.Get("/{user_id}", func(w http.ResponseWriter, r *http.Request) { userID := chi.URLParam(r, "user_id") groupIDs, ok := userIDToGroupIDs[userID] if !ok { http.NotFound(w, r) return } au := apiUser{ ID: userID, Email: userID + "@example.com", Name: apiUserName{ Given: "Given-" + userID, Middle: "Middle-" + userID, Family: "Family-" + userID, }, } if r.URL.Query().Get("include") == "memberOfGroupIDs" { au.MemberOfGroupIDs = groupIDs } w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(au) }) r.Get("/", func(w http.ResponseWriter, r *http.Request) { filter := r.URL.Query().Get("filter") userIDs, ok := filterToUserIDs[filter] if !ok { http.Error(w, "expected filter", http.StatusBadRequest) return } var apiUsers []apiUser for _, id := range userIDs { apiUsers = append(apiUsers, apiUser{ ID: id, Email: id + "@example.com", Name: apiUserName{ Given: "Given-" + id, Middle: "Middle-" + id, Family: "Family-" + id, }, }) } w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(M{ "_embedded": M{ "users": apiUsers, }, }) }) }) }) return r } func TestProvider_User(t *testing.T) { ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) defer clearTimeout() srv := httptest.NewServer(newMockAPI(map[string][]string{ "user1": {"group1", "group2"}, "user2": {"group1", "group3"}, "user3": {"group3"}, })) defer srv.Close() u, err := url.Parse(srv.URL) require.NoError(t, err) p := New( WithAPIURL(u), WithAuthURL(u), WithEnvironmentID("ENVIRONMENTID"), WithServiceAccount(&ServiceAccount{ ClientID: "CLIENTID", ClientSecret: "CLIENTSECRET", })) du, err := p.User(ctx, "user1", "") require.NoError(t, err) testutil.AssertProtoJSONEqual(t, `{ "id": "user1", "email": "user1@example.com", "displayName": "Given-user1 Middle-user1 Family-user1", "groupIds": ["group1", "group2"] }`, du) } func TestProvider_UserGroups(t *testing.T) { ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) defer clearTimeout() srv := httptest.NewServer(newMockAPI(map[string][]string{ "user1": {"group1", "group2"}, "user2": {"group1", "group3"}, "user3": {"group3"}, })) defer srv.Close() u, err := url.Parse(srv.URL) require.NoError(t, err) p := New( WithAPIURL(u), WithAuthURL(u), WithEnvironmentID("ENVIRONMENTID"), WithServiceAccount(&ServiceAccount{ ClientID: "CLIENTID", ClientSecret: "CLIENTSECRET", })) dgs, dus, err := p.UserGroups(ctx) require.NoError(t, err) testutil.AssertProtoJSONEqual(t, `[ { "id": "group1", "name": "Group group1" }, { "id": "group2", "name": "Group group2" }, { "id": "group3", "name": "Group group3" } ]`, dgs) testutil.AssertProtoJSONEqual(t, `[ { "id": "user1", "email": "user1@example.com", "displayName": "Given-user1 Middle-user1 Family-user1", "groupIds": ["group1", "group2"] }, { "id": "user2", "email": "user2@example.com", "displayName": "Given-user2 Middle-user2 Family-user2", "groupIds": ["group1", "group3"] }, { "id": "user3", "email": "user3@example.com", "displayName": "Given-user3 Middle-user3 Family-user3", "groupIds": ["group3"] } ]`, dus) }