mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-28 18:06:34 +02:00
google: support groups for users outside of the organization (#2950)
* google: support groups for users outside of the organization * wrap error
This commit is contained in:
parent
9f4fc986ee
commit
ed6c3e5087
5 changed files with 89 additions and 38 deletions
|
@ -3,9 +3,13 @@ package databroker
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/directory/directoryerrors"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
|
@ -22,7 +26,20 @@ func (c *DataBroker) RefreshUser(ctx context.Context, req *directory.RefreshUser
|
|||
}
|
||||
|
||||
u, err := dp.User(ctx, req.GetUserId(), req.GetAccessToken())
|
||||
if err != nil {
|
||||
// if the returned error signals we should prefer existing information
|
||||
if errors.Is(err, directoryerrors.ErrPreferExistingInformation) {
|
||||
_, err = c.dataBrokerServer.Get(ctx, &databroker.GetRequest{
|
||||
Type: protoutil.GetTypeURL(new(directory.User)),
|
||||
Id: req.GetUserId(),
|
||||
})
|
||||
switch status.Code(err) {
|
||||
case codes.OK:
|
||||
return new(emptypb.Empty), nil
|
||||
case codes.NotFound: // go ahead and save the user that was returned
|
||||
default:
|
||||
return nil, fmt.Errorf("databroker: error retrieving existing user record for refresh: %w", err)
|
||||
}
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
8
internal/directory/directoryerrors/errors.go
Normal file
8
internal/directory/directoryerrors/errors.go
Normal file
|
@ -0,0 +1,8 @@
|
|||
// Package directoryerrors contains errors used by directory providers.
|
||||
package directoryerrors
|
||||
|
||||
import "errors"
|
||||
|
||||
// ErrPreferExistingInformation indicates that the information returned by the provider should
|
||||
// only be used if a record is brand new, otherwise the existing information should be kept as is.
|
||||
var ErrPreferExistingInformation = errors.New("user ignored")
|
|
@ -17,6 +17,7 @@ import (
|
|||
"google.golang.org/api/googleapi"
|
||||
"google.golang.org/api/option"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/directory/directoryerrors"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
)
|
||||
|
@ -100,7 +101,7 @@ func (p *Provider) User(ctx context.Context, userID, accessToken string) (*direc
|
|||
Do()
|
||||
if isAccessDenied(err) {
|
||||
// ignore forbidden errors as a user may login from another gsuite domain
|
||||
return du, nil
|
||||
return du, directoryerrors.ErrPreferExistingInformation
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("google: error getting user: %w", err)
|
||||
} else {
|
||||
|
@ -138,6 +139,7 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
|
|||
return nil, nil, fmt.Errorf("google: error getting API client: %w", err)
|
||||
}
|
||||
|
||||
// query all the groups
|
||||
var groups []*directory.Group
|
||||
err = apiClient.Groups.List().
|
||||
Context(ctx).
|
||||
|
@ -160,7 +162,37 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
|
|||
return nil, nil, fmt.Errorf("google: error getting groups: %w", err)
|
||||
}
|
||||
|
||||
// query all the user members for each group
|
||||
// - create a lookup table for the user (storing id and name)
|
||||
// (this includes users who aren't necessarily members of the same organization)
|
||||
// - create a lookup table for the user's groups
|
||||
userLookup := map[string]apiUserObject{}
|
||||
userIDToGroups := map[string][]string{}
|
||||
for _, group := range groups {
|
||||
group := group
|
||||
err = apiClient.Members.List(group.Id).
|
||||
Context(ctx).
|
||||
Pages(ctx, func(res *admin.Members) error {
|
||||
for _, member := range res.Members {
|
||||
// only include user objects
|
||||
if member.Type != "USER" {
|
||||
continue
|
||||
}
|
||||
|
||||
userLookup[member.Id] = apiUserObject{
|
||||
ID: member.Id,
|
||||
Email: member.Email,
|
||||
}
|
||||
userIDToGroups[member.Id] = append(userIDToGroups[member.Id], group.Id)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("google: error getting group members: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// query all the users in the organization
|
||||
err = apiClient.Users.List().
|
||||
Context(ctx).
|
||||
Customer(currentAccountCustomerID).
|
||||
|
@ -181,22 +213,6 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
|
|||
return nil, nil, fmt.Errorf("google: error getting users: %w", err)
|
||||
}
|
||||
|
||||
userIDToGroups := map[string][]string{}
|
||||
for _, group := range groups {
|
||||
group := group
|
||||
err = apiClient.Members.List(group.Id).
|
||||
Context(ctx).
|
||||
Pages(ctx, func(res *admin.Members) error {
|
||||
for _, member := range res.Members {
|
||||
userIDToGroups[member.Id] = append(userIDToGroups[member.Id], group.Id)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("google: error getting group members: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var users []*directory.User
|
||||
for _, u := range userLookup {
|
||||
groups := userIDToGroups[u.ID]
|
||||
|
|
|
@ -12,6 +12,8 @@ import (
|
|||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/directory/directoryerrors"
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
)
|
||||
|
||||
|
@ -85,7 +87,7 @@ func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
|
|||
_ = json.NewEncoder(w).Encode(M{
|
||||
"kind": "admin#directory#groups",
|
||||
"groups": []M{
|
||||
{"id": "group1", "directMembersCount": "1"},
|
||||
{"id": "group1", "directMembersCount": "2"},
|
||||
{"id": "group2"},
|
||||
},
|
||||
})
|
||||
|
@ -97,8 +99,16 @@ func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
|
|||
_ = json.NewEncoder(w).Encode(M{
|
||||
"members": []M{
|
||||
{
|
||||
"kind": "admin#directory#member",
|
||||
"id": "user1",
|
||||
"kind": "admin#directory#member",
|
||||
"id": "inside-user1",
|
||||
"email": "user1@inside.test",
|
||||
"type": "USER",
|
||||
},
|
||||
{
|
||||
"kind": "admin#directory#member",
|
||||
"id": "outside-user1",
|
||||
"email": "user1@outside.test",
|
||||
"type": "USER",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
@ -112,24 +122,24 @@ func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
|
|||
"users": []M{
|
||||
{
|
||||
"kind": "admin#directory#user",
|
||||
"id": "user1",
|
||||
"primaryEmail": "user1@example.com",
|
||||
"id": "inside-user1",
|
||||
"primaryEmail": "user1@inside.test",
|
||||
},
|
||||
},
|
||||
})
|
||||
})
|
||||
r.Get("/{user_id}", func(w http.ResponseWriter, r *http.Request) {
|
||||
switch chi.URLParam(r, "user_id") {
|
||||
case "user1":
|
||||
case "inside-user1":
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"kind": "admin#directory#user",
|
||||
"id": "user1",
|
||||
"id": "inside-user1",
|
||||
"name": M{
|
||||
"fullName": "User 1",
|
||||
},
|
||||
"primaryEmail": "user1@example.com",
|
||||
"primaryEmail": "user1@inside.test",
|
||||
})
|
||||
case "user2":
|
||||
case "outside-user1":
|
||||
http.Error(w, "forbidden", http.StatusForbidden)
|
||||
default:
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
|
@ -158,20 +168,20 @@ func TestProvider_User(t *testing.T) {
|
|||
TokenURL: srv.URL + "/token",
|
||||
}), WithURL(srv.URL))
|
||||
|
||||
du, err := p.User(ctx, "user1", "")
|
||||
du, err := p.User(ctx, "inside-user1", "")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
assert.Equal(t, "user1", du.Id)
|
||||
assert.Equal(t, "user1@example.com", du.Email)
|
||||
assert.Equal(t, "inside-user1", du.Id)
|
||||
assert.Equal(t, "user1@inside.test", du.Email)
|
||||
assert.Equal(t, "User 1", du.DisplayName)
|
||||
assert.Equal(t, []string{"group1", "group2"}, du.GroupIds)
|
||||
|
||||
du, err = p.User(ctx, "user2", "")
|
||||
if !assert.NoError(t, err) {
|
||||
du, err = p.User(ctx, "outside-user1", "")
|
||||
if assert.ErrorIs(t, err, directoryerrors.ErrPreferExistingInformation) {
|
||||
return
|
||||
}
|
||||
assert.Equal(t, "user2", du.Id)
|
||||
assert.Equal(t, "outside-user1", du.Id)
|
||||
}
|
||||
|
||||
func TestProvider_UserGroups(t *testing.T) {
|
||||
|
@ -199,7 +209,8 @@ func TestProvider_UserGroups(t *testing.T) {
|
|||
assert.Equal(t, []*directory.Group{
|
||||
{Id: "group1"},
|
||||
}, dgs)
|
||||
assert.Equal(t, []*directory.User{
|
||||
{Id: "user1", Email: "user1@example.com", GroupIds: []string{"group1"}},
|
||||
}, dus)
|
||||
testutil.AssertProtoJSONEqual(t, `[
|
||||
{ "id": "inside-user1", "email": "user1@inside.test", "groupIds": ["group1"] },
|
||||
{ "id": "outside-user1", "email": "user1@outside.test", "groupIds": ["group1"] }
|
||||
]`, dus)
|
||||
}
|
||||
|
|
|
@ -9,8 +9,6 @@ import (
|
|||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/directory/ping"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/directory/auth0"
|
||||
"github.com/pomerium/pomerium/internal/directory/azure"
|
||||
"github.com/pomerium/pomerium/internal/directory/github"
|
||||
|
@ -18,6 +16,7 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/directory/google"
|
||||
"github.com/pomerium/pomerium/internal/directory/okta"
|
||||
"github.com/pomerium/pomerium/internal/directory/onelogin"
|
||||
"github.com/pomerium/pomerium/internal/directory/ping"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
)
|
||||
|
|
Loading…
Add table
Reference in a new issue