pomerium/internal/directory/ping/provider_test.go
dependabot[bot] 58ca681f40
chore(deps): bump github.com/go-chi/chi from 1.5.4 to 4.1.2+incompatible (#2910)
* chore(deps): bump github.com/go-chi/chi from 1.5.4 to 4.1.2+incompatible

Bumps [github.com/go-chi/chi](https://github.com/go-chi/chi) from 1.5.4 to 4.1.2+incompatible.
- [Release notes](https://github.com/go-chi/chi/releases)
- [Changelog](https://github.com/go-chi/chi/blob/master/CHANGELOG.md)
- [Commits](https://github.com/go-chi/chi/compare/v1.5.4...v4.1.2)

---
updated-dependencies:
- dependency-name: github.com/go-chi/chi
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>

* upgrade chi

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Caleb Doxsey <cdoxsey@pomerium.com>
2022-01-10 10:50:11 -07:00

230 lines
5.6 KiB
Go

package ping
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"sort"
"testing"
"time"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/internal/testutil"
)
// M is a shorthand alias for a loosely typed JSON object. The mock API
// handlers in this file use it to build the response bodies they pass to the
// JSON encoder.
type M = map[string]interface{}
func newMockAPI(userIDToGroupIDs map[string][]string) http.Handler {
lookup := map[string]struct{}{}
for _, groups := range userIDToGroupIDs {
for _, group := range groups {
lookup[group] = struct{}{}
}
}
var allGroups []string
for groupID := range lookup {
allGroups = append(allGroups, groupID)
}
sort.Strings(allGroups)
var allUserIDs []string
for userID := range userIDToGroupIDs {
allUserIDs = append(allUserIDs, userID)
}
sort.Strings(allUserIDs)
filterToUserIDs := map[string][]string{}
for userID, groupIDs := range userIDToGroupIDs {
for _, groupID := range groupIDs {
filter := fmt.Sprintf(`memberOfGroups[id eq "%s"]`, groupID)
filterToUserIDs[filter] = append(filterToUserIDs[filter], userID)
}
}
r := chi.NewRouter()
r.Use(middleware.Logger)
r.Post("/ENVIRONMENTID/as/token", func(w http.ResponseWriter, r *http.Request) {
u, p, _ := r.BasicAuth()
if u != "CLIENTID" || p != "CLIENTSECRET" {
http.Error(w, "forbidden", http.StatusForbidden)
return
}
grantType := r.FormValue("grant_type")
if grantType != "client_credentials" {
http.Error(w, "invalid grant_type", http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(M{
"access_token": "ACCESSTOKEN",
"created_at": time.Now().Format(time.RFC3339),
"expires_in": 360000,
"refresh_token": "REFRESHTOKEN",
"token_type": "bearer",
})
})
r.Route("/v1/environments/ENVIRONMENTID", func(r chi.Router) {
r.Get("/groups", func(w http.ResponseWriter, r *http.Request) {
var apiGroups []apiGroup
for _, id := range allGroups {
apiGroups = append(apiGroups, apiGroup{
ID: id,
Name: "Group " + id,
})
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(M{
"_embedded": M{
"groups": apiGroups,
},
})
})
r.Route("/users", func(r chi.Router) {
r.Get("/{user_id}", func(w http.ResponseWriter, r *http.Request) {
userID := chi.URLParam(r, "user_id")
groupIDs, ok := userIDToGroupIDs[userID]
if !ok {
http.NotFound(w, r)
return
}
au := apiUser{
ID: userID,
Email: userID + "@example.com",
Name: apiUserName{
Given: "Given-" + userID,
Middle: "Middle-" + userID,
Family: "Family-" + userID,
},
}
if r.URL.Query().Get("include") == "memberOfGroupIDs" {
au.MemberOfGroupIDs = groupIDs
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(au)
})
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
filter := r.URL.Query().Get("filter")
userIDs, ok := filterToUserIDs[filter]
if !ok {
http.Error(w, "expected filter", http.StatusBadRequest)
return
}
var apiUsers []apiUser
for _, id := range userIDs {
apiUsers = append(apiUsers, apiUser{
ID: id,
Email: id + "@example.com",
Name: apiUserName{
Given: "Given-" + id,
Middle: "Middle-" + id,
Family: "Family-" + id,
},
})
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(M{
"_embedded": M{
"users": apiUsers,
},
})
})
})
})
return r
}
// TestProvider_User starts the mock API on a local test server, points a
// provider at it for both the API and auth URLs, and checks that looking up
// "user1" yields the expected id, email, display name and group IDs.
func TestProvider_User(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	memberships := map[string][]string{
		"user1": {"group1", "group2"},
		"user2": {"group1", "group3"},
		"user3": {"group3"},
	}
	server := httptest.NewServer(newMockAPI(memberships))
	defer server.Close()

	baseURL, err := url.Parse(server.URL)
	require.NoError(t, err)

	account := &ServiceAccount{
		ClientID:     "CLIENTID",
		ClientSecret: "CLIENTSECRET",
	}
	provider := New(
		WithAPIURL(baseURL),
		WithAuthURL(baseURL),
		WithEnvironmentID("ENVIRONMENTID"),
		WithServiceAccount(account),
	)

	got, err := provider.User(ctx, "user1", "")
	require.NoError(t, err)

	// The display name is the given, middle and family names joined by spaces.
	const want = `{
"id": "user1",
"email": "user1@example.com",
"displayName": "Given-user1 Middle-user1 Family-user1",
"groupIds": ["group1", "group2"]
}`
	testutil.AssertProtoJSONEqual(t, want, got)
}
// TestProvider_UserGroups starts the mock API on a local test server, points
// a provider at it for both the API and auth URLs, and checks that a full
// listing returns every group and every user together with the IDs of the
// groups each user belongs to.
func TestProvider_UserGroups(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	memberships := map[string][]string{
		"user1": {"group1", "group2"},
		"user2": {"group1", "group3"},
		"user3": {"group3"},
	}
	server := httptest.NewServer(newMockAPI(memberships))
	defer server.Close()

	baseURL, err := url.Parse(server.URL)
	require.NoError(t, err)

	account := &ServiceAccount{
		ClientID:     "CLIENTID",
		ClientSecret: "CLIENTSECRET",
	}
	provider := New(
		WithAPIURL(baseURL),
		WithAuthURL(baseURL),
		WithEnvironmentID("ENVIRONMENTID"),
		WithServiceAccount(account),
	)

	groups, users, err := provider.UserGroups(ctx)
	require.NoError(t, err)

	const wantGroups = `[
{ "id": "group1", "name": "Group group1" },
{ "id": "group2", "name": "Group group2" },
{ "id": "group3", "name": "Group group3" }
]`
	testutil.AssertProtoJSONEqual(t, wantGroups, groups)

	const wantUsers = `[
{
"id": "user1",
"email": "user1@example.com",
"displayName": "Given-user1 Middle-user1 Family-user1",
"groupIds": ["group1", "group2"]
},
{
"id": "user2",
"email": "user2@example.com",
"displayName": "Given-user2 Middle-user2 Family-user2",
"groupIds": ["group1", "group3"]
},
{
"id": "user3",
"email": "user3@example.com",
"displayName": "Given-user3 Middle-user3 Family-user3",
"groupIds": ["group3"]
}
]`
	testutil.AssertProtoJSONEqual(t, wantUsers, users)
}