mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-05 20:32:57 +02:00
directory/azure: add paging support to user group members call (#2311)
This commit is contained in:
parent
fcb33966e2
commit
b1d7a126ab
4 changed files with 115 additions and 51 deletions
45
internal/directory/azure/api.go
Normal file
45
internal/directory/azure/api.go
Normal file
|
@ -0,0 +1,45 @@
|
||||||
|
package azure
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
type (
|
||||||
|
apiGetUserResponse struct {
|
||||||
|
apiUser
|
||||||
|
}
|
||||||
|
apiGetUserMembersResponse struct {
|
||||||
|
Context string `json:"@odata.context"`
|
||||||
|
NextLink string `json:"@odata.nextLink,omitempty"`
|
||||||
|
Value []apiGroup `json:"value"`
|
||||||
|
}
|
||||||
|
|
||||||
|
apiGroup struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
DisplayName string `json:"displayName"`
|
||||||
|
}
|
||||||
|
apiUser struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
DisplayName string `json:"displayName"`
|
||||||
|
Mail string `json:"mail"`
|
||||||
|
UserPrincipalName string `json:"userPrincipalName"`
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func (obj apiUser) getEmail() string {
|
||||||
|
if obj.Mail != "" {
|
||||||
|
return obj.Mail
|
||||||
|
}
|
||||||
|
|
||||||
|
// AD often doesn't have the email address returned, but we can parse it from the UPN
|
||||||
|
|
||||||
|
// UPN looks like:
|
||||||
|
// cdoxsey_pomerium.com#EXT#@cdoxseypomerium.onmicrosoft.com
|
||||||
|
email := obj.UserPrincipalName
|
||||||
|
if idx := strings.Index(email, "#EXT"); idx > 0 {
|
||||||
|
email = email[:idx]
|
||||||
|
}
|
||||||
|
// find the last _ and replace it with @
|
||||||
|
if idx := strings.LastIndex(email, "_"); idx > 0 {
|
||||||
|
email = email[:idx] + "@" + email[idx+1:]
|
||||||
|
}
|
||||||
|
return email
|
||||||
|
}
|
|
@ -115,30 +115,17 @@ func (p *Provider) User(ctx context.Context, userID, accessToken string) (*direc
|
||||||
Path: fmt.Sprintf("/v1.0/users/%s", userID),
|
Path: fmt.Sprintf("/v1.0/users/%s", userID),
|
||||||
}).String()
|
}).String()
|
||||||
|
|
||||||
var u usersDeltaResponseUser
|
var u apiGetUserResponse
|
||||||
err := p.api(ctx, userURL, &u)
|
err := p.api(ctx, userURL, &u)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
du.DisplayName = u.DisplayName
|
du.DisplayName = u.DisplayName
|
||||||
du.Email = u.getEmail()
|
du.Email = u.getEmail()
|
||||||
|
du.GroupIds, err = p.transitiveMemberOf(ctx, userID)
|
||||||
groupURL := p.cfg.graphURL.ResolveReference(&url.URL{
|
|
||||||
Path: fmt.Sprintf("/v1.0/users/%s/transitiveMemberOf", userID),
|
|
||||||
}).String()
|
|
||||||
|
|
||||||
var res struct {
|
|
||||||
Value []usersDeltaResponseUser `json:"value"`
|
|
||||||
}
|
|
||||||
err = p.api(ctx, groupURL, &res)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
for _, g := range res.Value {
|
|
||||||
du.GroupIds = append(du.GroupIds, g.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
sort.Strings(du.GroupIds)
|
|
||||||
|
|
||||||
return du, nil
|
return du, nil
|
||||||
}
|
}
|
||||||
|
@ -246,6 +233,28 @@ func (p *Provider) getToken(ctx context.Context) (*oauth2.Token, error) {
|
||||||
return p.token, nil
|
return p.token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *Provider) transitiveMemberOf(ctx context.Context, userID string) (groupIDs []string, err error) {
|
||||||
|
apiURL := p.cfg.graphURL.ResolveReference(&url.URL{
|
||||||
|
Path: fmt.Sprintf("/v1.0/users/%s/transitiveMemberOf", userID),
|
||||||
|
}).String()
|
||||||
|
for {
|
||||||
|
var res apiGetUserMembersResponse
|
||||||
|
err := p.api(ctx, apiURL, &res)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, g := range res.Value {
|
||||||
|
groupIDs = append(groupIDs, g.ID)
|
||||||
|
}
|
||||||
|
if res.NextLink == "" {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
apiURL = res.NextLink
|
||||||
|
}
|
||||||
|
sort.Strings(groupIDs)
|
||||||
|
return groupIDs, nil
|
||||||
|
}
|
||||||
|
|
||||||
// A ServiceAccount is used by the Azure provider to query the Microsoft Graph API.
|
// A ServiceAccount is used by the Azure provider to query the Microsoft Graph API.
|
||||||
type ServiceAccount struct {
|
type ServiceAccount struct {
|
||||||
ClientID string `json:"client_id"`
|
ClientID string `json:"client_id"`
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/go-chi/chi"
|
"github.com/go-chi/chi"
|
||||||
|
@ -86,11 +87,28 @@ func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
|
||||||
r.Get("/users/{user_id}/transitiveMemberOf", func(w http.ResponseWriter, r *http.Request) {
|
r.Get("/users/{user_id}/transitiveMemberOf", func(w http.ResponseWriter, r *http.Request) {
|
||||||
switch chi.URLParam(r, "user_id") {
|
switch chi.URLParam(r, "user_id") {
|
||||||
case "user-1":
|
case "user-1":
|
||||||
_ = json.NewEncoder(w).Encode(M{
|
switch r.URL.Query().Get("page") {
|
||||||
"value": []M{
|
case "":
|
||||||
{"id": "admin"},
|
_ = json.NewEncoder(w).Encode(M{
|
||||||
},
|
"value": []M{
|
||||||
})
|
{"id": "admin"},
|
||||||
|
},
|
||||||
|
"@odata.nextLink": getPageURL(r, 1),
|
||||||
|
})
|
||||||
|
case "1":
|
||||||
|
_ = json.NewEncoder(w).Encode(M{
|
||||||
|
"value": []M{
|
||||||
|
{"id": "group1"},
|
||||||
|
},
|
||||||
|
"@odata.nextLink": getPageURL(r, 2),
|
||||||
|
})
|
||||||
|
case "2":
|
||||||
|
_ = json.NewEncoder(w).Encode(M{
|
||||||
|
"value": []M{
|
||||||
|
{"id": "group2"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
http.Error(w, "not found", http.StatusNotFound)
|
http.Error(w, "not found", http.StatusNotFound)
|
||||||
}
|
}
|
||||||
|
@ -126,7 +144,7 @@ func TestProvider_User(t *testing.T) {
|
||||||
"id": "user-1",
|
"id": "user-1",
|
||||||
"displayName": "User 1",
|
"displayName": "User 1",
|
||||||
"email": "user1@example.com",
|
"email": "user1@example.com",
|
||||||
"groupIds": ["admin"]
|
"groupIds": ["admin", "group1", "group2"]
|
||||||
}`, du)
|
}`, du)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -219,3 +237,20 @@ func mustParseURL(rawurl string) *url.URL {
|
||||||
}
|
}
|
||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getPageURL(r *http.Request, page int) string {
|
||||||
|
var u url.URL
|
||||||
|
u = *r.URL
|
||||||
|
if r.TLS == nil {
|
||||||
|
u.Scheme = "http"
|
||||||
|
} else {
|
||||||
|
u.Scheme = "https"
|
||||||
|
}
|
||||||
|
if u.Host == "" {
|
||||||
|
u.Host = r.Host
|
||||||
|
}
|
||||||
|
q := u.Query()
|
||||||
|
q.Set("page", strconv.Itoa(page))
|
||||||
|
u.RawQuery = q.Encode()
|
||||||
|
return u.String()
|
||||||
|
}
|
||||||
|
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"net/url"
|
"net/url"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||||
)
|
)
|
||||||
|
@ -229,10 +228,9 @@ type (
|
||||||
Value []groupsDeltaResponseGroup `json:"value"`
|
Value []groupsDeltaResponseGroup `json:"value"`
|
||||||
}
|
}
|
||||||
groupsDeltaResponseGroup struct {
|
groupsDeltaResponseGroup struct {
|
||||||
ID string `json:"id"`
|
apiGroup
|
||||||
DisplayName string `json:"displayName"`
|
Members []groupsDeltaResponseGroupMember `json:"members@delta"`
|
||||||
Members []groupsDeltaResponseGroupMember `json:"members@delta"`
|
Removed *deltaResponseRemoved `json:"@removed,omitempty"`
|
||||||
Removed *deltaResponseRemoved `json:"@removed,omitempty"`
|
|
||||||
}
|
}
|
||||||
groupsDeltaResponseGroupMember struct {
|
groupsDeltaResponseGroupMember struct {
|
||||||
Type string `json:"@odata.type"`
|
Type string `json:"@odata.type"`
|
||||||
|
@ -247,30 +245,7 @@ type (
|
||||||
Value []usersDeltaResponseUser `json:"value"`
|
Value []usersDeltaResponseUser `json:"value"`
|
||||||
}
|
}
|
||||||
usersDeltaResponseUser struct {
|
usersDeltaResponseUser struct {
|
||||||
ID string `json:"id"`
|
apiUser
|
||||||
DisplayName string `json:"displayName"`
|
Removed *deltaResponseRemoved `json:"@removed,omitempty"`
|
||||||
Mail string `json:"mail"`
|
|
||||||
UserPrincipalName string `json:"userPrincipalName"`
|
|
||||||
Removed *deltaResponseRemoved `json:"@removed,omitempty"`
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
func (obj usersDeltaResponseUser) getEmail() string {
|
|
||||||
if obj.Mail != "" {
|
|
||||||
return obj.Mail
|
|
||||||
}
|
|
||||||
|
|
||||||
// AD often doesn't have the email address returned, but we can parse it from the UPN
|
|
||||||
|
|
||||||
// UPN looks like:
|
|
||||||
// cdoxsey_pomerium.com#EXT#@cdoxseypomerium.onmicrosoft.com
|
|
||||||
email := obj.UserPrincipalName
|
|
||||||
if idx := strings.Index(email, "#EXT"); idx > 0 {
|
|
||||||
email = email[:idx]
|
|
||||||
}
|
|
||||||
// find the last _ and replace it with @
|
|
||||||
if idx := strings.LastIndex(email, "_"); idx > 0 {
|
|
||||||
email = email[:idx] + "@" + email[idx+1:]
|
|
||||||
}
|
|
||||||
return email
|
|
||||||
}
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue