mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-02 02:42:57 +02:00
directory/azure: add paging support to user group members call (#2311)
This commit is contained in:
parent
fcb33966e2
commit
b1d7a126ab
4 changed files with 115 additions and 51 deletions
45
internal/directory/azure/api.go
Normal file
45
internal/directory/azure/api.go
Normal file
|
@ -0,0 +1,45 @@
|
|||
package azure
|
||||
|
||||
import "strings"
|
||||
|
||||
type (
|
||||
apiGetUserResponse struct {
|
||||
apiUser
|
||||
}
|
||||
apiGetUserMembersResponse struct {
|
||||
Context string `json:"@odata.context"`
|
||||
NextLink string `json:"@odata.nextLink,omitempty"`
|
||||
Value []apiGroup `json:"value"`
|
||||
}
|
||||
|
||||
apiGroup struct {
|
||||
ID string `json:"id"`
|
||||
DisplayName string `json:"displayName"`
|
||||
}
|
||||
apiUser struct {
|
||||
ID string `json:"id"`
|
||||
DisplayName string `json:"displayName"`
|
||||
Mail string `json:"mail"`
|
||||
UserPrincipalName string `json:"userPrincipalName"`
|
||||
}
|
||||
)
|
||||
|
||||
func (obj apiUser) getEmail() string {
|
||||
if obj.Mail != "" {
|
||||
return obj.Mail
|
||||
}
|
||||
|
||||
// AD often doesn't have the email address returned, but we can parse it from the UPN
|
||||
|
||||
// UPN looks like:
|
||||
// cdoxsey_pomerium.com#EXT#@cdoxseypomerium.onmicrosoft.com
|
||||
email := obj.UserPrincipalName
|
||||
if idx := strings.Index(email, "#EXT"); idx > 0 {
|
||||
email = email[:idx]
|
||||
}
|
||||
// find the last _ and replace it with @
|
||||
if idx := strings.LastIndex(email, "_"); idx > 0 {
|
||||
email = email[:idx] + "@" + email[idx+1:]
|
||||
}
|
||||
return email
|
||||
}
|
|
@ -115,30 +115,17 @@ func (p *Provider) User(ctx context.Context, userID, accessToken string) (*direc
|
|||
Path: fmt.Sprintf("/v1.0/users/%s", userID),
|
||||
}).String()
|
||||
|
||||
var u usersDeltaResponseUser
|
||||
var u apiGetUserResponse
|
||||
err := p.api(ctx, userURL, &u)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
du.DisplayName = u.DisplayName
|
||||
du.Email = u.getEmail()
|
||||
|
||||
groupURL := p.cfg.graphURL.ResolveReference(&url.URL{
|
||||
Path: fmt.Sprintf("/v1.0/users/%s/transitiveMemberOf", userID),
|
||||
}).String()
|
||||
|
||||
var res struct {
|
||||
Value []usersDeltaResponseUser `json:"value"`
|
||||
}
|
||||
err = p.api(ctx, groupURL, &res)
|
||||
du.GroupIds, err = p.transitiveMemberOf(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, g := range res.Value {
|
||||
du.GroupIds = append(du.GroupIds, g.ID)
|
||||
}
|
||||
|
||||
sort.Strings(du.GroupIds)
|
||||
|
||||
return du, nil
|
||||
}
|
||||
|
@ -246,6 +233,28 @@ func (p *Provider) getToken(ctx context.Context) (*oauth2.Token, error) {
|
|||
return p.token, nil
|
||||
}
|
||||
|
||||
func (p *Provider) transitiveMemberOf(ctx context.Context, userID string) (groupIDs []string, err error) {
|
||||
apiURL := p.cfg.graphURL.ResolveReference(&url.URL{
|
||||
Path: fmt.Sprintf("/v1.0/users/%s/transitiveMemberOf", userID),
|
||||
}).String()
|
||||
for {
|
||||
var res apiGetUserMembersResponse
|
||||
err := p.api(ctx, apiURL, &res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, g := range res.Value {
|
||||
groupIDs = append(groupIDs, g.ID)
|
||||
}
|
||||
if res.NextLink == "" {
|
||||
break
|
||||
}
|
||||
apiURL = res.NextLink
|
||||
}
|
||||
sort.Strings(groupIDs)
|
||||
return groupIDs, nil
|
||||
}
|
||||
|
||||
// A ServiceAccount is used by the Azure provider to query the Microsoft Graph API.
|
||||
type ServiceAccount struct {
|
||||
ClientID string `json:"client_id"`
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
|
@ -86,11 +87,28 @@ func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
|
|||
r.Get("/users/{user_id}/transitiveMemberOf", func(w http.ResponseWriter, r *http.Request) {
|
||||
switch chi.URLParam(r, "user_id") {
|
||||
case "user-1":
|
||||
switch r.URL.Query().Get("page") {
|
||||
case "":
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"value": []M{
|
||||
{"id": "admin"},
|
||||
},
|
||||
"@odata.nextLink": getPageURL(r, 1),
|
||||
})
|
||||
case "1":
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"value": []M{
|
||||
{"id": "group1"},
|
||||
},
|
||||
"@odata.nextLink": getPageURL(r, 2),
|
||||
})
|
||||
case "2":
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"value": []M{
|
||||
{"id": "group2"},
|
||||
},
|
||||
})
|
||||
}
|
||||
default:
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
}
|
||||
|
@ -126,7 +144,7 @@ func TestProvider_User(t *testing.T) {
|
|||
"id": "user-1",
|
||||
"displayName": "User 1",
|
||||
"email": "user1@example.com",
|
||||
"groupIds": ["admin"]
|
||||
"groupIds": ["admin", "group1", "group2"]
|
||||
}`, du)
|
||||
}
|
||||
|
||||
|
@ -219,3 +237,20 @@ func mustParseURL(rawurl string) *url.URL {
|
|||
}
|
||||
return u
|
||||
}
|
||||
|
||||
func getPageURL(r *http.Request, page int) string {
|
||||
var u url.URL
|
||||
u = *r.URL
|
||||
if r.TLS == nil {
|
||||
u.Scheme = "http"
|
||||
} else {
|
||||
u.Scheme = "https"
|
||||
}
|
||||
if u.Host == "" {
|
||||
u.Host = r.Host
|
||||
}
|
||||
q := u.Query()
|
||||
q.Set("page", strconv.Itoa(page))
|
||||
u.RawQuery = q.Encode()
|
||||
return u.String()
|
||||
}
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"context"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
)
|
||||
|
@ -229,8 +228,7 @@ type (
|
|||
Value []groupsDeltaResponseGroup `json:"value"`
|
||||
}
|
||||
groupsDeltaResponseGroup struct {
|
||||
ID string `json:"id"`
|
||||
DisplayName string `json:"displayName"`
|
||||
apiGroup
|
||||
Members []groupsDeltaResponseGroupMember `json:"members@delta"`
|
||||
Removed *deltaResponseRemoved `json:"@removed,omitempty"`
|
||||
}
|
||||
|
@ -247,30 +245,7 @@ type (
|
|||
Value []usersDeltaResponseUser `json:"value"`
|
||||
}
|
||||
usersDeltaResponseUser struct {
|
||||
ID string `json:"id"`
|
||||
DisplayName string `json:"displayName"`
|
||||
Mail string `json:"mail"`
|
||||
UserPrincipalName string `json:"userPrincipalName"`
|
||||
apiUser
|
||||
Removed *deltaResponseRemoved `json:"@removed,omitempty"`
|
||||
}
|
||||
)
|
||||
|
||||
func (obj usersDeltaResponseUser) getEmail() string {
|
||||
if obj.Mail != "" {
|
||||
return obj.Mail
|
||||
}
|
||||
|
||||
// AD often doesn't have the email address returned, but we can parse it from the UPN
|
||||
|
||||
// UPN looks like:
|
||||
// cdoxsey_pomerium.com#EXT#@cdoxseypomerium.onmicrosoft.com
|
||||
email := obj.UserPrincipalName
|
||||
if idx := strings.Index(email, "#EXT"); idx > 0 {
|
||||
email = email[:idx]
|
||||
}
|
||||
// find the last _ and replace it with @
|
||||
if idx := strings.LastIndex(email, "_"); idx > 0 {
|
||||
email = email[:idx] + "@" + email[idx+1:]
|
||||
}
|
||||
return email
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue