directory: add explicit RefreshUser endpoint for faster sync (#1460)

* directory: add explicit RefreshUser endpoint for faster sync

* add test

* implement azure

* update api call

* add test for azure User

* implement github

* implement AccessToken, gitlab

* implement okta

* implement onelogin

* fix test

* fix inconsistent test

* implement auth0
This commit is contained in:
Caleb Doxsey 2020-10-05 08:23:15 -06:00 committed by GitHub
parent 9b39deabd8
commit aa731ae068
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
23 changed files with 1405 additions and 179 deletions

View file

@ -596,5 +596,13 @@ func (a *Authenticate) saveSessionToDataBroker(ctx context.Context, sessionState
}
sessionState.Version = sessions.Version(res.GetServerVersion())
_, err = state.directoryClient.RefreshUser(ctx, &directory.RefreshUserRequest{
UserId: s.UserId,
AccessToken: accessToken.AccessToken,
})
if err != nil {
log.Error().Err(err).Msg("directory: failed to refresh user data")
}
return nil
}

View file

@ -14,6 +14,8 @@ import (
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
"github.com/pomerium/pomerium/config"
@ -29,10 +31,12 @@ import (
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/directory"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/golang/mock/gomock"
"github.com/golang/protobuf/ptypes"
"github.com/golang/protobuf/ptypes/empty"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/chacha20poly1305"
@ -171,6 +175,7 @@ func TestAuthenticate_SignIn(t *testing.T) {
}, nil
},
},
directoryClient: new(mockDirectoryServiceClient),
}),
options: config.NewAtomicOptions(),
@ -262,6 +267,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
}, nil
},
},
directoryClient: new(mockDirectoryServiceClient),
}),
templates: template.Must(frontend.NewTemplates()),
options: config.NewAtomicOptions(),
@ -366,6 +372,7 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
return &databroker.SetResponse{Record: &databroker.Record{Data: in.Data}}, nil
},
},
directoryClient: new(mockDirectoryServiceClient),
redirectURL: authURL,
sessionStore: tt.session,
cookieCipher: aead,
@ -515,6 +522,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
}, nil
},
},
directoryClient: new(mockDirectoryServiceClient),
}),
options: config.NewAtomicOptions(),
provider: identity.NewAtomicAuthenticator(),
@ -633,6 +641,7 @@ func TestAuthenticate_Dashboard(t *testing.T) {
}, nil
},
},
directoryClient: new(mockDirectoryServiceClient),
}),
templates: template.Must(frontend.NewTemplates()),
}
@ -679,3 +688,16 @@ func (m mockDataBrokerServiceClient) Get(ctx context.Context, in *databroker.Get
func (m mockDataBrokerServiceClient) Set(ctx context.Context, in *databroker.SetRequest, opts ...grpc.CallOption) (*databroker.SetResponse, error) {
return m.set(ctx, in, opts...)
}
type mockDirectoryServiceClient struct {
directory.DirectoryServiceClient
refreshUser func(ctx context.Context, in *directory.RefreshUserRequest, opts ...grpc.CallOption) (*empty.Empty, error)
}
func (m mockDirectoryServiceClient) RefreshUser(ctx context.Context, in *directory.RefreshUserRequest, opts ...grpc.CallOption) (*empty.Empty, error) {
if m.refreshUser != nil {
return m.refreshUser(ctx, in, opts...)
}
return nil, status.Error(codes.Unimplemented, "")
}

View file

@ -22,6 +22,7 @@ import (
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/directory"
)
type authenticateState struct {
@ -46,6 +47,7 @@ type authenticateState struct {
jwk *jose.JSONWebKeySet
dataBrokerClient databroker.DataBrokerServiceClient
directoryClient directory.DirectoryServiceClient
}
func newAuthenticateState() *authenticateState {
@ -129,6 +131,7 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err
}
state.dataBrokerClient = databroker.NewDataBrokerServiceClient(dataBrokerConn)
state.directoryClient = directory.NewDirectoryServiceClient(dataBrokerConn)
return state, nil
}

8
cache/cache.go vendored
View file

@ -7,6 +7,7 @@ import (
"context"
"fmt"
"net"
"sync"
"google.golang.org/grpc"
"gopkg.in/tomb.v2"
@ -33,6 +34,9 @@ type Cache struct {
localGRPCConnection *grpc.ClientConn
dataBrokerStorageType string //TODO remove in v0.11
deprecatedCacheClusterDomain string //TODO: remove in v0.11
mu sync.Mutex
directoryProvider directory.Provider
}
// New creates a new cache service.
@ -90,6 +94,7 @@ func (c *Cache) OnConfigChange(cfg *config.Config) {
// Register registers all the gRPC services with the given server.
func (c *Cache) Register(grpcServer *grpc.Server) {
databroker.RegisterDataBrokerServiceServer(grpcServer, c.dataBrokerServer)
directory.RegisterDirectoryServiceServer(grpcServer, c)
}
// Run runs the cache components.
@ -132,6 +137,9 @@ func (c *Cache) update(cfg *config.Config) error {
ClientID: cfg.Options.ClientID,
ClientSecret: cfg.Options.ClientSecret,
})
c.mu.Lock()
c.directoryProvider = directoryProvider
c.mu.Unlock()
dataBrokerClient := databroker.NewDataBrokerServiceClient(c.localGRPCConnection)

44
cache/directory.go vendored Normal file
View file

@ -0,0 +1,44 @@
package cache
import (
"context"
"errors"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/emptypb"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/directory"
)
// RefreshUser refreshes a user's directory information.
func (c *Cache) RefreshUser(ctx context.Context, req *directory.RefreshUserRequest) (*emptypb.Empty, error) {
c.mu.Lock()
dp := c.directoryProvider
c.mu.Unlock()
if dp == nil {
return nil, errors.New("no directory provider is available for refresh")
}
u, err := dp.User(ctx, req.GetUserId(), req.GetAccessToken())
if err != nil {
return nil, err
}
any, err := anypb.New(u)
if err != nil {
return nil, err
}
_, err = c.dataBrokerServer.Set(ctx, &databroker.SetRequest{
Type: any.GetTypeUrl(),
Id: u.GetId(),
Data: any,
})
if err != nil {
return nil, err
}
return new(emptypb.Empty), nil
}

View file

@ -21,16 +21,33 @@ import (
// Name is the provider name.
const Name = "auth0"
// RoleManager defines what is needed to get role info from Auth0.
type RoleManager interface {
type (
// RoleManager defines what is needed to get role info from Auth0.
RoleManager interface {
List(opts ...management.ListOption) (r *management.RoleList, err error)
Users(id string, opts ...management.ListOption) (u *management.UserList, err error)
}
// UserManager defines what is needed to get user info from Auth0.
UserManager interface {
Read(id string) (*management.User, error)
Roles(id string, opts ...management.ListOption) (r *management.RoleList, err error)
}
)
type newManagersFunc = func(ctx context.Context, domain string, serviceAccount *ServiceAccount) (RoleManager, UserManager, error)
func defaultNewManagersFunc(ctx context.Context, domain string, serviceAccount *ServiceAccount) (RoleManager, UserManager, error) {
m, err := management.New(domain, serviceAccount.ClientID, serviceAccount.Secret, management.WithContext(ctx))
if err != nil {
return nil, nil, fmt.Errorf("auth0: could not build management: %w", err)
}
return m.Role, m.User, nil
}
type config struct {
domain string
serviceAccount *ServiceAccount
newRoleManager func(ctx context.Context, domain string, serviceAccount *ServiceAccount) (RoleManager, error)
newManagers newManagersFunc
}
// Option provides config for the Auth0 Provider.
@ -50,17 +67,9 @@ func WithDomain(domain string) Option {
}
}
func defaultNewRoleManagerFunc(ctx context.Context, domain string, serviceAccount *ServiceAccount) (RoleManager, error) {
m, err := management.New(domain, serviceAccount.ClientID, serviceAccount.Secret, management.WithContext(ctx))
if err != nil {
return nil, fmt.Errorf("auth0: could not build management")
}
return m.Role, nil
}
func getConfig(options ...Option) *config {
cfg := &config{
newRoleManager: defaultNewRoleManagerFunc,
newManagers: defaultNewManagersFunc,
}
for _, option := range options {
option(cfg)
@ -82,13 +91,49 @@ func New(options ...Option) *Provider {
}
}
func (p *Provider) getRoleManager(ctx context.Context) (RoleManager, error) {
return p.cfg.newRoleManager(ctx, p.cfg.domain, p.cfg.serviceAccount)
func (p *Provider) getManagers(ctx context.Context) (RoleManager, UserManager, error) {
return p.cfg.newManagers(ctx, p.cfg.domain, p.cfg.serviceAccount)
}
// User returns the user record for the given id.
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
_, um, err := p.getManagers(ctx)
if err != nil {
return nil, fmt.Errorf("auth0: could not get the role manager: %w", err)
}
_, providerUserID := databroker.FromUserID(userID)
du := &directory.User{
Id: userID,
}
u, err := um.Read(providerUserID)
if err != nil {
return nil, fmt.Errorf("auth0: error getting user info: %w", err)
}
du.DisplayName = u.GetName()
du.Email = u.GetEmail()
for page, hasNext := 0, true; hasNext; page++ {
rl, err := um.Roles(providerUserID, management.IncludeTotals(true), management.Page(page))
if err != nil {
return nil, fmt.Errorf("auth0: error getting user roles: %w", err)
}
for _, role := range rl.Roles {
du.GroupIds = append(du.GroupIds, role.GetID())
}
hasNext = rl.HasNext()
}
sort.Strings(du.GroupIds)
return du, nil
}
// UserGroups fetches a slice of groups and users.
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
rm, err := p.getRoleManager(ctx)
rm, _, err := p.getManagers(ctx)
if err != nil {
return nil, nil, fmt.Errorf("auth0: could not get the role manager: %w", err)
}
@ -147,6 +192,9 @@ func getRoles(rm RoleManager) ([]*directory.Group, error) {
shouldContinue = listRes.HasNext()
}
sort.Slice(roles, func(i, j int) bool {
return roles[i].GetId() < roles[j].GetId()
})
return roles, nil
}
@ -170,6 +218,7 @@ func getRoleUserIDs(rm RoleManager, roleID string) ([]string, error) {
shouldContinue = usersRes.HasNext()
}
sort.Strings(ids)
return ids, nil
}
@ -180,7 +229,30 @@ type ServiceAccount struct {
}
// ParseServiceAccount parses the service account in the config options.
func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) {
func ParseServiceAccount(options directory.Options) (*ServiceAccount, error) {
if options.ServiceAccount != "" {
return parseServiceAccountFromString(options.ServiceAccount)
}
return parseServiceAccountFromOptions(options.ClientID, options.ClientSecret)
}
func parseServiceAccountFromOptions(clientID, clientSecret string) (*ServiceAccount, error) {
serviceAccount := ServiceAccount{
ClientID: clientID,
Secret: clientSecret,
}
if serviceAccount.ClientID == "" {
return nil, fmt.Errorf("auth0: client_id is required")
}
if serviceAccount.Secret == "" {
return nil, fmt.Errorf("auth0: client_secret is required")
}
return &serviceAccount, nil
}
func parseServiceAccountFromString(rawServiceAccount string) (*ServiceAccount, error) {
bs, err := base64.StdEncoding.DecodeString(rawServiceAccount)
if err != nil {
return nil, fmt.Errorf("auth0: could not decode base64: %w", err)

View file

@ -2,34 +2,133 @@ package auth0
import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"github.com/go-chi/chi"
"github.com/go-chi/chi/middleware"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"gopkg.in/auth0.v4/management"
"github.com/pomerium/pomerium/internal/directory/auth0/mock_auth0"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/grpc/directory"
)
type M = map[string]interface{}
func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
r := chi.NewRouter()
r.Use(middleware.Logger)
r.Post("/oauth/token", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(M{
"access_token": "ACCESSTOKEN",
"token_type": "Bearer",
"refresh_token": "REFRESHTOKEN",
})
})
r.Route("/api/v2", func(r chi.Router) {
r.Route("/users", func(r chi.Router) {
r.Route("/{user_id}", func(r chi.Router) {
r.Get("/roles", func(w http.ResponseWriter, r *http.Request) {
switch chi.URLParam(r, "user_id") {
case "user1":
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(M{
"total": 2,
"limit": 2,
"roles": []M{
{"id": "role1"},
{"id": "role2"},
},
})
default:
http.Error(w, "not found", http.StatusNotFound)
}
})
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
switch chi.URLParam(r, "user_id") {
case "user1":
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(M{
"user_id": "user1",
"email": "user1@example.com",
"name": "User 1",
})
default:
http.Error(w, "not found", http.StatusNotFound)
}
})
})
})
})
return r
}
func TestProvider_User(t *testing.T) {
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Minute)
defer clearTimeout()
orig := http.DefaultTransport
defer func() {
http.DefaultTransport = orig
}()
http.DefaultTransport = &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
}
var mockAPI http.Handler
srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mockAPI.ServeHTTP(w, r)
}))
srv.StartTLS()
defer srv.Close()
mockAPI = newMockAPI(t, srv)
p := New(
WithDomain(srv.URL),
WithServiceAccount(&ServiceAccount{ClientID: "CLIENT_ID", Secret: "SECRET"}),
)
du, err := p.User(ctx, "auth0/user1", "")
if !assert.NoError(t, err) {
return
}
testutil.AssertProtoJSONEqual(t, `{
"id": "auth0/user1",
"displayName": "User 1",
"email": "user1@example.com",
"groupIds": ["role1", "role2"]
}`, du)
}
type mockNewRoleManagerFunc struct {
CalledWithContext context.Context
CalledWithDomain string
CalledWithServiceAccount *ServiceAccount
ReturnRoleManager RoleManager
ReturnUserManager UserManager
ReturnError error
}
func (m *mockNewRoleManagerFunc) f(ctx context.Context, domain string, serviceAccount *ServiceAccount) (RoleManager, error) {
func (m *mockNewRoleManagerFunc) f(ctx context.Context, domain string, serviceAccount *ServiceAccount) (RoleManager, UserManager, error) {
m.CalledWithContext = ctx
m.CalledWithDomain = domain
m.CalledWithServiceAccount = serviceAccount
return m.ReturnRoleManager, m.ReturnError
return m.ReturnRoleManager, m.ReturnUserManager, m.ReturnError
}
type listOptionMatcher struct {
@ -379,7 +478,7 @@ func TestProvider_UserGroups(t *testing.T) {
mRoleManager := mock_auth0.NewMockRoleManager(ctrl)
mNewRoleManagerFunc := mockNewRoleManagerFunc{
mNewManagersFunc := mockNewRoleManagerFunc{
ReturnRoleManager: mRoleManager,
ReturnError: tc.newRoleManagerError,
}
@ -392,7 +491,7 @@ func TestProvider_UserGroups(t *testing.T) {
WithDomain(expectedDomain),
WithServiceAccount(expectedServiceAccount),
)
p.cfg.newRoleManager = mNewRoleManagerFunc.f
p.cfg.newManagers = mNewManagersFunc.f
actualGroups, actualUsers, err := p.UserGroups(context.Background())
if tc.expectedError != nil {
@ -404,8 +503,8 @@ func TestProvider_UserGroups(t *testing.T) {
assert.Equal(t, tc.expectedGroups, actualGroups)
assert.Equal(t, tc.expectedUsers, actualUsers)
assert.Equal(t, expectedDomain, mNewRoleManagerFunc.CalledWithDomain)
assert.Equal(t, expectedServiceAccount, mNewRoleManagerFunc.CalledWithServiceAccount)
assert.Equal(t, expectedDomain, mNewManagersFunc.CalledWithDomain)
assert.Equal(t, expectedServiceAccount, mNewManagersFunc.CalledWithServiceAccount)
})
}
@ -433,7 +532,7 @@ func TestParseServiceAccount(t *testing.T) {
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
actualServiceAccount, err := ParseServiceAccount(tc.rawServiceAccount)
actualServiceAccount, err := ParseServiceAccount(directory.Options{ServiceAccount: tc.rawServiceAccount})
if tc.expectedError != nil {
assert.EqualError(t, err, tc.expectedError.Error())
} else {

View file

@ -6,14 +6,15 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"sort"
"strings"
"sync"
"golang.org/x/oauth2"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/directory"
)
@ -101,6 +102,50 @@ func New(options ...Option) *Provider {
return p
}
// User returns the user record for the given id.
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
if p.cfg.serviceAccount == nil {
return nil, fmt.Errorf("azure: service account not defined")
}
_, providerUserID := databroker.FromUserID(userID)
du := &directory.User{
Id: userID,
}
userURL := p.cfg.graphURL.ResolveReference(&url.URL{
Path: fmt.Sprintf("/v1.0/users/%s", providerUserID),
}).String()
var u usersDeltaResponseUser
err := p.api(ctx, userURL, &u)
if err != nil {
return nil, err
}
du.DisplayName = u.DisplayName
du.Email = u.getEmail()
groupURL := p.cfg.graphURL.ResolveReference(&url.URL{
Path: fmt.Sprintf("/v1.0/users/%s/transitiveMemberOf", providerUserID),
}).String()
var res struct {
Value []usersDeltaResponseUser `json:"value"`
}
err = p.api(ctx, groupURL, &res)
if err != nil {
return nil, err
}
for _, g := range res.Value {
du.GroupIds = append(du.GroupIds, g.ID)
}
sort.Strings(du.GroupIds)
return du, nil
}
// UserGroups returns the directory users in azure active directory.
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
if p.cfg.serviceAccount == nil {
@ -116,13 +161,13 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
return groups, users, nil
}
func (p *Provider) api(ctx context.Context, method, url string, body io.Reader, out interface{}) error {
func (p *Provider) api(ctx context.Context, url string, out interface{}) error {
token, err := p.getToken(ctx)
if err != nil {
return err
}
req, err := http.NewRequestWithContext(ctx, method, url, body)
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return fmt.Errorf("azure: error creating HTTP request: %w", err)
}
@ -143,7 +188,7 @@ func (p *Provider) api(ctx context.Context, method, url string, body io.Reader,
}
if res.StatusCode/100 != 2 {
return fmt.Errorf("azure: error querying api: %s", res.Status)
return fmt.Errorf("azure: error querying api (%s): %s", url, res.Status)
}
err = json.NewDecoder(res.Body).Decode(out)

View file

@ -13,6 +13,7 @@ import (
"github.com/go-chi/chi/middleware"
"github.com/stretchr/testify/assert"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/grpc/directory"
)
@ -74,12 +75,62 @@ func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
},
})
})
r.Get("/users/{user_id}", func(w http.ResponseWriter, r *http.Request) {
switch chi.URLParam(r, "user_id") {
case "user-1":
_ = json.NewEncoder(w).Encode(M{"id": "user-1", "displayName": "User 1", "mail": "user1@example.com"})
default:
http.Error(w, "not found", http.StatusNotFound)
}
})
r.Get("/users/{user_id}/transitiveMemberOf", func(w http.ResponseWriter, r *http.Request) {
switch chi.URLParam(r, "user_id") {
case "user-1":
_ = json.NewEncoder(w).Encode(M{
"value": []M{
{"id": "admin"},
},
})
default:
http.Error(w, "not found", http.StatusNotFound)
}
})
})
return r
}
func Test(t *testing.T) {
func TestProvider_User(t *testing.T) {
var mockAPI http.Handler
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mockAPI.ServeHTTP(w, r)
}))
defer srv.Close()
mockAPI = newMockAPI(t, srv)
p := New(
WithGraphURL(mustParseURL(srv.URL)),
WithLoginURL(mustParseURL(srv.URL)),
WithServiceAccount(&ServiceAccount{
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
DirectoryID: "DIRECTORY_ID",
}),
)
du, err := p.User(context.Background(), "azure/user-1", "")
if !assert.NoError(t, err) {
return
}
testutil.AssertProtoJSONEqual(t, `{
"id": "azure/user-1",
"displayName": "User 1",
"email": "user1@example.com",
"groupIds": ["admin"]
}`, du)
}
func TestProvider_UserGroups(t *testing.T) {
var mockAPI http.Handler
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mockAPI.ServeHTTP(w, r)
@ -118,10 +169,10 @@ func Test(t *testing.T) {
Email: "user3@example.com",
},
}, users)
assert.Equal(t, []*directory.Group{
{Id: "admin", Name: "Admin Group"},
{Id: "test", Name: "Test Group"},
}, groups)
testutil.AssertProtoJSONEqual(t, `[
{ "id": "admin", "name": "Admin Group" },
{ "id": "test", "name": "Test Group"}
]`, groups)
}
func TestParseServiceAccount(t *testing.T) {

View file

@ -86,7 +86,7 @@ func (dc *deltaCollection) syncGroups(ctx context.Context) error {
for {
var res groupsDeltaResponse
err := dc.provider.api(ctx, "GET", apiURL, nil, &res)
err := dc.provider.api(ctx, apiURL, &res)
if err != nil {
return err
}
@ -146,7 +146,7 @@ func (dc *deltaCollection) syncUsers(ctx context.Context) error {
for {
var res usersDeltaResponse
err := dc.provider.api(ctx, "GET", apiURL, nil, &res)
err := dc.provider.api(ctx, apiURL, &res)
if err != nil {
return err
}
@ -197,6 +197,9 @@ func (dc *deltaCollection) CurrentUserGroups() ([]*directory.Group, []*directory
}
groupLookup.addGroup(g.id, groupIDs, userIDs)
}
sort.Slice(groups, func(i, j int) bool {
return groups[i].GetId() < groups[j].GetId()
})
var users []*directory.User
for _, u := range dc.users {

View file

@ -2,6 +2,7 @@
package github
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
@ -83,6 +84,47 @@ func New(options ...Option) *Provider {
}
}
// User returns the user record for the given id.
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
if p.cfg.serviceAccount == nil {
return nil, fmt.Errorf("github: service account not defined")
}
_, providerUserID := databroker.FromUserID(userID)
du := &directory.User{
Id: userID,
}
au, err := p.getUser(ctx, providerUserID)
if err != nil {
return nil, err
}
du.DisplayName = au.Name
du.Email = au.Email
teamIDLookup := map[int]struct{}{}
orgSlugs, err := p.listOrgs(ctx)
if err != nil {
return nil, err
}
for _, orgSlug := range orgSlugs {
teamIDs, err := p.listUserOrganizationTeams(ctx, userID, orgSlug)
if err != nil {
return nil, err
}
for _, teamID := range teamIDs {
teamIDLookup[teamID] = struct{}{}
}
}
for teamID := range teamIDLookup {
du.GroupIds = append(du.GroupIds, strconv.Itoa(teamID))
}
sort.Strings(du.GroupIds)
return du, nil
}
// UserGroups gets the directory user groups for github.
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
if p.cfg.serviceAccount == nil {
@ -230,6 +272,77 @@ func (p *Provider) getUser(ctx context.Context, userLogin string) (*apiUserObjec
return &res, nil
}
func (p *Provider) listUserOrganizationTeams(ctx context.Context, userSlug string, orgSlug string) ([]int, error) {
// GitHub's Rest API doesn't have an easy way of querying this data, so we use the GraphQL API.
enc := func(obj interface{}) string {
bs, _ := json.Marshal(obj)
return string(bs)
}
const pageCount = 100
var teamIDs []int
var cursor *string
for {
var res struct {
Data struct {
Organization struct {
Teams struct {
TotalCount int `json:"totalCount"`
PageInfo struct {
EndCursor string `json:"endCursor"`
} `json:"pageInfo"`
Edges []struct {
Node struct {
ID int `json:"id"`
} `json:"node"`
} `json:"edges"`
} `json:"teams"`
} `json:"organization"`
} `json:"data"`
}
cursorStr := ""
if cursor != nil {
cursorStr = fmt.Sprintf(",%s", enc(*cursor))
}
q := fmt.Sprintf(`query {
organization(login:%s) {
teams(first:%s, userLogins:[%s] %s) {
totalCount
pageInfo {
endCursor
}
edges {
node {
id
}
}
}
}
}`, enc(orgSlug), enc(pageCount), enc(userSlug), cursorStr)
_, err := p.graphql(ctx, q, &res)
if err != nil {
return nil, err
}
if len(res.Data.Organization.Teams.Edges) == 0 {
break
}
for _, edge := range res.Data.Organization.Teams.Edges {
teamIDs = append(teamIDs, edge.Node.ID)
}
if len(teamIDs) >= res.Data.Organization.Teams.TotalCount {
break
}
cursor = &res.Data.Organization.Teams.PageInfo.EndCursor
}
return teamIDs, nil
}
func (p *Provider) api(ctx context.Context, apiURL string, out interface{}) (http.Header, error) {
req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil)
if err != nil {
@ -257,6 +370,41 @@ func (p *Provider) api(ctx context.Context, apiURL string, out interface{}) (htt
return res.Header, nil
}
func (p *Provider) graphql(ctx context.Context, query string, out interface{}) (http.Header, error) {
apiURL := p.cfg.url.ResolveReference(&url.URL{
Path: "/graphql",
}).String()
bs, _ := json.Marshal(struct {
Query string `json:"query"`
}{query})
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(bs))
if err != nil {
return nil, fmt.Errorf("github: failed to create http request: %w", err)
}
req.SetBasicAuth(p.cfg.serviceAccount.Username, p.cfg.serviceAccount.PersonalAccessToken)
res, err := p.cfg.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("github: failed to make http request: %w", err)
}
defer res.Body.Close()
if res.StatusCode/100 != 2 {
return nil, fmt.Errorf("github: error from API: %s", res.Status)
}
if out != nil {
err := json.NewDecoder(res.Body).Decode(out)
if err != nil {
return nil, fmt.Errorf("github: failed to decode json body: %w", err)
}
}
return res.Header, nil
}
func getNextLink(hdrs http.Header) string {
for _, link := range linkheader.ParseMultiple(hdrs.Values("Link")) {
if link.Rel == "next" {

View file

@ -29,6 +29,33 @@ func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
next.ServeHTTP(w, r)
})
})
r.Post("/graphql", func(w http.ResponseWriter, r *http.Request) {
var body struct {
Query string `json:"query"`
}
json.NewDecoder(r.Body).Decode(&body)
_ = json.NewEncoder(w).Encode(M{
"data": M{
"organization": M{
"teams": M{
"totalCount": 3,
"edges": []M{
{"node": M{
"id": 1,
}},
{"node": M{
"id": 2,
}},
{"node": M{
"id": 3,
}},
},
},
},
},
})
})
r.Get("/user/orgs", func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode([]M{
{"login": "org1"},
@ -88,7 +115,34 @@ func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
return r
}
func Test(t *testing.T) {
func TestProvider_User(t *testing.T) {
var mockAPI http.Handler
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mockAPI.ServeHTTP(w, r)
}))
defer srv.Close()
mockAPI = newMockAPI(t, srv)
p := New(
WithURL(mustParseURL(srv.URL)),
WithServiceAccount(&ServiceAccount{
Username: "abc",
PersonalAccessToken: "xyz",
}),
)
du, err := p.User(context.Background(), "github/user1", "")
if !assert.NoError(t, err) {
return
}
testutil.AssertProtoJSONEqual(t, `{
"id": "github/user1",
"groupIds": ["1", "2", "3"],
"displayName": "User 1",
"email": "user1@example.com"
}`, du)
}
func TestProvider_UserGroups(t *testing.T) {
var mockAPI http.Handler
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mockAPI.ServeHTTP(w, r)

View file

@ -83,6 +83,32 @@ func New(options ...Option) *Provider {
}
}
// User returns the user record for the given id.
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
_, providerUserID := databroker.FromUserID(userID)
du := &directory.User{
Id: userID,
}
au, err := p.getUser(ctx, providerUserID, accessToken)
if err != nil {
return nil, err
}
du.DisplayName = au.Name
du.Email = au.Email
groups, err := p.listGroups(ctx, accessToken)
if err != nil {
return nil, err
}
for _, g := range groups {
du.GroupIds = append(du.GroupIds, g.Id)
}
sort.Strings(du.GroupIds)
return du, nil
}
// UserGroups gets the directory user groups for gitlab.
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
if p.cfg.serviceAccount == nil {
@ -91,7 +117,7 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
p.log.Info().Msg("getting user groups")
groups, err := p.listGroups(ctx)
groups, err := p.listGroups(ctx, "")
if err != nil {
return nil, nil, err
}
@ -129,8 +155,20 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
return groups, users, nil
}
func (p *Provider) getUser(ctx context.Context, userID string, accessToken string) (*apiUserObject, error) {
apiURL := p.cfg.url.ResolveReference(&url.URL{
Path: fmt.Sprintf("/api/v4/users/%s", userID),
}).String()
var result apiUserObject
_, err := p.api(ctx, accessToken, apiURL, &result)
if err != nil {
return nil, fmt.Errorf("gitlab: error querying user: %w", err)
}
return &result, nil
}
// listGroups returns a map, with key is group ID, element is group name.
func (p *Provider) listGroups(ctx context.Context) ([]*directory.Group, error) {
func (p *Provider) listGroups(ctx context.Context, accessToken string) ([]*directory.Group, error) {
nextURL := p.cfg.url.ResolveReference(&url.URL{
Path: "/api/v4/groups",
}).String()
@ -140,7 +178,7 @@ func (p *Provider) listGroups(ctx context.Context) ([]*directory.Group, error) {
ID int `json:"id"`
Name string `json:"name"`
}
hdrs, err := p.apiGet(ctx, nextURL, &result)
hdrs, err := p.api(ctx, accessToken, nextURL, &result)
if err != nil {
return nil, fmt.Errorf("gitlab: error querying groups: %w", err)
}
@ -163,7 +201,7 @@ func (p *Provider) listGroupMembers(ctx context.Context, groupID string) (users
}).String()
for nextURL != "" {
var result []apiUserObject
hdrs, err := p.apiGet(ctx, nextURL, &result)
hdrs, err := p.api(ctx, "", nextURL, &result)
if err != nil {
return nil, fmt.Errorf("gitlab: error querying group members: %w", err)
}
@ -174,14 +212,18 @@ func (p *Provider) listGroupMembers(ctx context.Context, groupID string) (users
return users, nil
}
func (p *Provider) apiGet(ctx context.Context, uri string, out interface{}) (http.Header, error) {
func (p *Provider) api(ctx context.Context, accessToken string, uri string, out interface{}) (http.Header, error) {
req, err := http.NewRequestWithContext(ctx, "GET", uri, nil)
if err != nil {
return nil, fmt.Errorf("gitlab: failed to create HTTP request: %w", err)
}
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/json")
if accessToken != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))
} else {
req.Header.Set("PRIVATE-TOKEN", p.cfg.serviceAccount.PrivateToken)
}
res, err := p.cfg.httpClient.Do(req)
if err != nil {
@ -190,7 +232,7 @@ func (p *Provider) apiGet(ctx context.Context, uri string, out interface{}) (htt
defer res.Body.Close()
if res.StatusCode/100 != 2 {
return nil, fmt.Errorf("gitlab: error query api status_code=%d: %s", res.StatusCode, res.Status)
return nil, fmt.Errorf("gitlab: error querying api url=%s status_code=%d: %s", uri, res.StatusCode, res.Status)
}
err = json.NewDecoder(res.Body).Decode(out)

View file

@ -5,13 +5,16 @@ import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"sort"
"sync"
"github.com/rs/zerolog"
"golang.org/x/oauth2/google"
admin "google.golang.org/api/admin/directory/v1"
"google.golang.org/api/googleapi"
"google.golang.org/api/option"
"github.com/pomerium/pomerium/internal/log"
@ -19,11 +22,15 @@ import (
"github.com/pomerium/pomerium/pkg/grpc/directory"
)
// Name is the provider name.
const Name = "google"
const (
// Name is the provider name.
Name = "google"
currentAccountCustomerID = "my_customer"
)
const (
defaultProviderURL = "https://accounts.google.com"
defaultProviderURL = "https://www.googleapis.com/admin/directory/v1/"
)
type config struct {
@ -78,6 +85,51 @@ func New(options ...Option) *Provider {
}
}
// User returns the user record for the given id.
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
apiClient, err := p.getAPIClient(ctx)
if err != nil {
return nil, fmt.Errorf("google: error getting API client: %w", err)
}
_, providerUserID := databroker.FromUserID(userID)
du := &directory.User{
Id: userID,
}
au, err := apiClient.Users.Get(providerUserID).
Context(ctx).
Do()
if isAccessDenied(err) {
// ignore forbidden errors as a user may login from another gsuite domain
return du, nil
} else if err != nil {
return nil, fmt.Errorf("google: error getting user: %w", err)
} else {
if au.Name != nil {
du.DisplayName = au.Name.FullName
}
du.Email = au.PrimaryEmail
}
err = apiClient.Groups.List().
Context(ctx).
UserKey(providerUserID).
Pages(ctx, func(res *admin.Groups) error {
for _, g := range res.Groups {
du.GroupIds = append(du.GroupIds, g.Id)
}
return nil
})
if err != nil {
return nil, fmt.Errorf("google: error getting groups for user: %w", err)
}
sort.Strings(du.GroupIds)
return du, nil
}
// UserGroups returns a slice of group names a given user is in
// NOTE: groups via Directory API is limited to 1 QPS!
// https://developers.google.com/admin-sdk/directory/v1/reference/groups/list
@ -91,7 +143,7 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
var groups []*directory.Group
err = apiClient.Groups.List().
Context(ctx).
Customer("my_customer").
Customer(currentAccountCustomerID).
Pages(ctx, func(res *admin.Groups) error {
for _, g := range res.Groups {
// Skip group without member.
@ -113,7 +165,7 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
userLookup := map[string]apiUserObject{}
err = apiClient.Users.List().
Context(ctx).
Customer("my_customer").
Customer(currentAccountCustomerID).
Pages(ctx, func(res *admin.Users) error {
for _, u := range res.Users {
userLookup[u.Id] = apiUserObject{
@ -188,7 +240,7 @@ func (p *Provider) getAPIClient(ctx context.Context) (*admin.Service, error) {
ts := config.TokenSource(ctx)
p.apiClient, err = admin.NewService(ctx, option.WithTokenSource(ts))
p.apiClient, err = admin.NewService(ctx, option.WithTokenSource(ts), option.WithEndpoint(p.cfg.url))
if err != nil {
return nil, fmt.Errorf("google: failed creating admin service %w", err)
}
@ -243,3 +295,16 @@ type apiUserObject struct {
DisplayName string
Email string
}
func isAccessDenied(err error) bool {
if err == nil {
return false
}
gerr := new(googleapi.Error)
if errors.As(err, &gerr) {
return gerr.Code == http.StatusForbidden
}
return false
}

View file

@ -0,0 +1,140 @@
package google
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/go-chi/chi"
"github.com/go-chi/chi/middleware"
"github.com/stretchr/testify/assert"
)
var privateKey = `
-----BEGIN RSA PRIVATE KEY-----
MIIG4wIBAAKCAYEAnetGqPqS6dqYnV9S5S8gL34t7RRUMsf4prxIR+1PMv+bEqVH
pzXBbzgGKIbQbu+njoRzhQ95RbcEzZiVFivNggijkFoUNkFIjy42O/xXdTRTX3/u
4pxu1ctccqYnZwnry6E8ekQTVX7kgmVqgzIrY1Y6K3PVlhkgGDK/TStu+RIPoA73
vJUpJTFTw+6tgUSBmCkzctQsGFGUiGBRpqlxogEkImJYrcMUJkWhTopLy79+3OB5
eGGKWwA6P2ZhBB4RKHBWyWipsYxr889QNG6P3o+Be6lIzvcdiIIBYK8qWLmO45hR
xUGWSRK8sveAJO54t+wGE0dSPVKfoS4oNqGYdGhBzoTwsfZRvyvcWeQob4DNCqQa
n41XAuYGOG3X1PexdwGSwwqrq2tuG9d2AJ2NjG8nC9hjuuKDfBGTigTwzrwkrn+F
3o94NoglQgsXfZeWoBXR5HDaDTdqexSRK0OpSPbvzkn8QDUymdaw7nRS7kU2O7fa
W8kxiV8AVt2v/jYjAgMBAAECggGAS+6NE0Mo0Pki2Mi0+y4ls7BgNNbJhYFRthpi
RvN8WXE+B0EhquzWDbxKecIZBr6FOqnFQf2muja+QH1VckutjRDKVOZ7QXsygGYf
/cff5aM7U3gYTS4avQIDeb0axRioIElu4vtIsJtLFMfe5yaAZktXvPz9fiamn/wG
r/xqZ6ifir6nsC2okxGczWE+XCGsjpWA/321lhvj548os5JV6SfTUBUpvqNGVQC2
ByXIPDffsCTfQ1rjQ85gM4vuqiQqKn/KXRrMR1kRIOrglMJ6dllitsadkfWdPkVg
fHjM1KAnw/uob5kFhvqeEll9sESfCXttrb4XneKAsEck958ChSpU4csaBAfLFVYP
5xyIfoaQ+/CUjWI0u1lQbg6jZO59rYfdd5OlH+MyhHybuHR1a0G1izNfuG9WPOWI
aprNayH2Wxy9/ZvlrE5yTAeW9tof28hO6O7wBNOcJTrzztsN+V8pSAo0IE2r4D83
h978LneAwhC/8mVvzhd/y2t99vcBAoHBAMumCoHmHRAsYBApg2SHxPCTDi4fS1hC
IbcuzyvJLv214bHOsdgg2I72a1Q+bbrZAHgENVgSP9Sx1k8aXlvnGOEbpL8w3jRL
G4/qXzGrMBp3LCzupi3dI3yrrIMWb2C0goyHeAejzrfaM+uDYTGW4iqhA39zBj4o
zoydz3v0i8Yag7Df9MIwr34WD9Ng0oXh8XRCAYJmS1e43jnM+XcFdSfGVhKn9h1B
Cbv/hqUSv6baNloWLlPBffLII5bx633MMwKBwQDGg87fKEPz73pbp1kAOuXDpUnm
+OUFuf6Uqr/OqCorh8SOwxb2q3ScqyCWGVLOaVBcSV+zxIdcQzX2kjf1mj7PcQ6Q
2xfDIS1/xT+RiX9LO0kbkVDYcwcGeKVtmUwWyjauo96OB2r+SchTsNJpYOT3a/7r
JUKdbHFwsFwAx5q7r9mOh0BOybuXM6N7lUDBf4SgrhjnKRh1pME3R0JbJj9m8tZg
SsWlHcj04yAXJ7NGemiiYgeDZ4unsAfx7/sS/lECgcEAsL0Yj2XTQU8Ry9ULaDsA
az1k6BhWvnEea6lfOQPwGVY5WqQk6oqPB3vK6CEKAEgGRSJ53UZxSTlR4fLjg2UL
zYm9MATMQ5wPfpYMKcIFDGLy3sf7RwCNpMwk+tuEq+vdBPMo85BxflQMDVBHEM9+
1zpIG9sKxvWJVLY89LnmeHZYZi/nboTsOUQSVgPIkVLmx1vljXMT3jzd+FHxCx+c
bnmOB8DnMrpYJWV9SFP+KmNlGkf3ys65bPPPF1g7ZUDLAoHAaKHqtQa1Imr0NED1
kUB6AHArjrlbhXQuck+5f4R1jbIm8RR1Exj2AunT6Cl60t8Bg1MNRWRt8DxgwhD5
u9NMDezKP6GrWacwIytlQSGW3aFm/EfQs/WVG10V3LmzOEPnJI+s63GPfG6JT0tg
7DgtFxhuKaTfAri45iueoq6SqSCb7Brv01dTL/QA1E+r7RF4Z3S8HYM0qDVpvegq
Wn7DZlDSm7htioUzeZgJPwsm3BwC8Kv4x9MY8g6/cU8LKEyxAoHAWCaDpLIuQ51r
PeL+u/1cfNdi6OOrtZ6S95tu3Vv+mYzpCPnOpgPHFp3l+RGmLg56t7uvHFaFxOvB
EjPm4bVhnPSA7pl7ZHQXhinG9+4UgcejoCAJzfg05BI1tMbwFZ+C0tG/PNzBlaX+
IwkGO8VP/54N6wL1UqfZ8AKJFZW8G7W7KVkjqye1FS4oeDlJ197t/X+PMn5sFAc7
UVsDaSelBqpsfmetXSH8KC3XkbgCtHvgAnJDkGkp84VmJvMr5ukv
-----END RSA PRIVATE KEY-----
`
type M = map[string]interface{}
func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
r := chi.NewRouter()
r.Use(middleware.Logger)
r.Post("/token", func(w http.ResponseWriter, r *http.Request) {
_ = json.NewEncoder(w).Encode(M{
"access_token": "ACCESSTOKEN",
"token_type": "Bearer",
"refresh_token": "REFRESHTOKEN",
})
})
r.Route("/groups", func(r chi.Router) {
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Query().Get("userKey") {
case "user1":
_ = json.NewEncoder(w).Encode(M{
"kind": "admin#directory#groups",
"groups": []M{
{"id": "group1"},
{"id": "group2"},
},
})
default:
http.Error(w, "not found", http.StatusNotFound)
}
})
})
r.Route("/users", func(r chi.Router) {
r.Get("/{user_id}", func(w http.ResponseWriter, r *http.Request) {
switch chi.URLParam(r, "user_id") {
case "user1":
_ = json.NewEncoder(w).Encode(M{
"kind": "admin#directory#user",
"id": "1",
"name": M{
"fullName": "User 1",
},
"primaryEmail": "user1@example.com",
})
case "user2":
http.Error(w, "forbidden", http.StatusForbidden)
default:
http.Error(w, "not found", http.StatusNotFound)
}
})
})
return r
}
func TestProvider_User(t *testing.T) {
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*30)
defer clearTimeout()
var mockAPI http.Handler
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mockAPI.ServeHTTP(w, r)
}))
defer srv.Close()
mockAPI = newMockAPI(t, srv)
p := New(WithServiceAccount(&ServiceAccount{
Type: "service_account",
PrivateKey: privateKey,
TokenURL: srv.URL + "/token",
}), WithURL(srv.URL))
du, err := p.User(ctx, "user1", "")
if !assert.NoError(t, err) {
return
}
assert.Equal(t, "user1", du.Id)
assert.Equal(t, "user1@example.com", du.Email)
assert.Equal(t, "User 1", du.DisplayName)
assert.Equal(t, []string{"group1", "group2"}, du.GroupIds)
du, err = p.User(ctx, "user2", "")
if !assert.NoError(t, err) {
return
}
assert.Equal(t, "user2", du.Id)
}

View file

@ -108,6 +108,36 @@ func New(options ...Option) *Provider {
}
}
// User returns the user record for the given id.
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
if p.cfg.serviceAccount == nil {
return nil, ErrServiceAccountNotDefined
}
_, providerUserID := databroker.FromUserID(userID)
du := &directory.User{
Id: userID,
}
au, err := p.getUser(ctx, providerUserID)
if err != nil {
return nil, err
}
du.DisplayName = au.getDisplayName()
du.Email = au.Profile.Email
groups, err := p.listUserGroups(ctx, providerUserID)
if err != nil {
return nil, err
}
for _, g := range groups {
du.GroupIds = append(du.GroupIds, g.ID)
}
sort.Strings(du.GroupIds)
return du, nil
}
// UserGroups fetches the groups of which the user is a member
// https://developer.okta.com/docs/reference/api/users/#get-user-s-groups
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
@ -159,7 +189,7 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
users = append(users, &directory.User{
Id: databroker.GetUserID(Name, u.ID),
GroupIds: groups,
DisplayName: u.Profile.FirstName + " " + u.Profile.LastName,
DisplayName: u.getDisplayName(),
Email: u.Profile.Email,
})
}
@ -183,14 +213,7 @@ func (p *Provider) getGroups(ctx context.Context) ([]*directory.Group, error) {
groupURL := p.cfg.providerURL.ResolveReference(u).String()
for groupURL != "" {
var out []struct {
ID string `json:"id"`
Profile struct {
Name string `json:"name"`
} `json:"profile"`
LastUpdated string `json:"lastUpdated"`
LastMembershipUpdated string `json:"lastMembershipUpdated"`
}
var out []apiGroupObject
hdrs, err := p.apiGet(ctx, groupURL, &out)
if err != nil {
return nil, fmt.Errorf("okta: error querying for groups: %w", err)
@ -239,6 +262,36 @@ func (p *Provider) getGroupMembers(ctx context.Context, groupID string) (users [
return users, nil
}
func (p *Provider) getUser(ctx context.Context, userID string) (*apiUserObject, error) {
apiURL := p.cfg.providerURL.ResolveReference(&url.URL{
Path: fmt.Sprintf("/api/v1/users/%s", userID),
}).String()
var out apiUserObject
_, err := p.apiGet(ctx, apiURL, &out)
if err != nil {
return nil, fmt.Errorf("okta: error querying for user: %w", err)
}
return &out, nil
}
func (p *Provider) listUserGroups(ctx context.Context, userID string) (groups []apiGroupObject, err error) {
apiURL := p.cfg.providerURL.ResolveReference(&url.URL{
Path: fmt.Sprintf("/api/v1/users/%s/groups", userID),
}).String()
for apiURL != "" {
var out []apiGroupObject
hdrs, err := p.apiGet(ctx, apiURL, &out)
if err != nil {
return nil, fmt.Errorf("okta: error querying for user groups: %w", err)
}
groups = append(groups, out...)
apiURL = getNextLink(hdrs)
}
return groups, nil
}
func (p *Provider) apiGet(ctx context.Context, uri string, out interface{}) (http.Header, error) {
req, err := http.NewRequestWithContext(ctx, "GET", uri, nil)
if err != nil {
@ -334,11 +387,25 @@ func (err *APIError) Error() string {
return fmt.Sprintf("okta: error querying API, status_code=%d: %s", err.HTTPStatusCode, err.Body)
}
type apiUserObject struct {
type (
apiGroupObject struct {
ID string `json:"id"`
Profile struct {
Name string `json:"name"`
} `json:"profile"`
LastUpdated string `json:"lastUpdated"`
LastMembershipUpdated string `json:"lastMembershipUpdated"`
}
apiUserObject struct {
ID string `json:"id"`
Profile struct {
FirstName string `json:"firstName"`
LastName string `json:"lastName"`
Email string `json:"email"`
} `json:"profile"`
}
)
func (obj *apiUserObject) getDisplayName() string {
return obj.Profile.FirstName + " " + obj.Profile.LastName
}

View file

@ -16,6 +16,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/tomnomnom/linkheader"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/grpc/directory"
)
@ -43,7 +44,9 @@ func newMockOkta(srv *httptest.Server, userEmailToGroups map[string][]string) ht
next.ServeHTTP(w, r)
})
})
r.Get("/api/v1/groups", func(w http.ResponseWriter, r *http.Request) {
r.Route("/api/v1", func(r chi.Router) {
r.Route("/groups", func(r chi.Router) {
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
lastUpdated := strings.Contains(r.URL.Query().Get("filter"), "lastUpdated ")
var groups []string
for group := range getAllGroups() {
@ -86,7 +89,7 @@ func newMockOkta(srv *httptest.Server, userEmailToGroups map[string][]string) ht
_ = json.NewEncoder(w).Encode(result)
})
r.Get("/api/v1/groups/{group}/users", func(w http.ResponseWriter, r *http.Request) {
r.Get("/{group}/users", func(w http.ResponseWriter, r *http.Request) {
group := chi.URLParam(r, "group")
if _, ok := getAllGroups()[group]; !ok {
@ -122,9 +125,61 @@ func newMockOkta(srv *httptest.Server, userEmailToGroups map[string][]string) ht
_ = json.NewEncoder(w).Encode(result)
})
})
r.Route("/users", func(r chi.Router) {
r.Get("/{user_id}/groups", func(w http.ResponseWriter, r *http.Request) {
var groups []apiGroupObject
for _, nm := range userEmailToGroups[chi.URLParam(r, "user_id")] {
obj := apiGroupObject{
ID: nm,
}
obj.Profile.Name = nm
groups = append(groups, obj)
}
_ = json.NewEncoder(w).Encode(groups)
})
r.Get("/{user_id}", func(w http.ResponseWriter, r *http.Request) {
user := apiUserObject{
ID: chi.URLParam(r, "user_id"),
}
user.Profile.Email = chi.URLParam(r, "user_id")
user.Profile.FirstName = "first"
user.Profile.LastName = "last"
_ = json.NewEncoder(w).Encode(user)
})
})
})
return r
}
func TestProvider_User(t *testing.T) {
var mockOkta http.Handler
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mockOkta.ServeHTTP(w, r)
}))
defer srv.Close()
mockOkta = newMockOkta(srv, map[string][]string{
"a@example.com": {"user", "admin"},
"b@example.com": {"user", "test"},
"c@example.com": {"user"},
})
p := New(
WithServiceAccount(&ServiceAccount{APIKey: "APITOKEN"}),
WithProviderURL(mustParseURL(srv.URL)),
)
user, err := p.User(context.Background(), "okta/a@example.com", "")
if !assert.NoError(t, err) {
return
}
testutil.AssertProtoJSONEqual(t, `{
"id": "okta/a@example.com",
"groupIds": ["admin","user"],
"displayName": "first last",
"email": "a@example.com"
}`, user)
}
func TestProvider_UserGroups(t *testing.T) {
var mockOkta http.Handler
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

View file

@ -94,6 +94,32 @@ func New(options ...Option) *Provider {
}
}
// User returns the user record for the given id.
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
if p.cfg.serviceAccount == nil {
return nil, fmt.Errorf("onelogin: service account not defined")
}
_, providerUserID := databroker.FromUserID(userID)
du := &directory.User{
Id: userID,
}
token, err := p.getToken(ctx)
if err != nil {
return nil, err
}
au, err := p.getUser(ctx, token.AccessToken, providerUserID)
if err != nil {
return nil, err
}
du.DisplayName = au.getDisplayName()
du.Email = au.Email
du.GroupIds = []string{strconv.Itoa(au.GroupID)}
return du, nil
}
// UserGroups gets the directory user groups for onelogin.
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
if p.cfg.serviceAccount == nil {
@ -107,12 +133,12 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
return nil, nil, err
}
groups, err := p.listGroups(ctx, token)
groups, err := p.listGroups(ctx, token.AccessToken)
if err != nil {
return nil, nil, err
}
apiUsers, err := p.getUsers(ctx, token)
apiUsers, err := p.listUsers(ctx, token.AccessToken)
if err != nil {
return nil, nil, err
}
@ -133,7 +159,7 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
return groups, users, nil
}
func (p *Provider) listGroups(ctx context.Context, token *oauth2.Token) ([]*directory.Group, error) {
func (p *Provider) listGroups(ctx context.Context, accessToken string) ([]*directory.Group, error) {
var groups []*directory.Group
apiURL := p.cfg.apiURL.ResolveReference(&url.URL{
Path: "/api/1/groups",
@ -144,9 +170,9 @@ func (p *Provider) listGroups(ctx context.Context, token *oauth2.Token) ([]*dire
ID int `json:"id"`
Name string `json:"name"`
}
nextLink, err := p.apiGet(ctx, token, apiURL, &result)
nextLink, err := p.apiGet(ctx, accessToken, apiURL, &result)
if err != nil {
return nil, fmt.Errorf("onelogin: error querying group api: %w", err)
return nil, fmt.Errorf("onelogin: listing groups: %w", err)
}
for _, r := range result {
@ -161,7 +187,24 @@ func (p *Provider) listGroups(ctx context.Context, token *oauth2.Token) ([]*dire
return groups, nil
}
func (p *Provider) getUsers(ctx context.Context, token *oauth2.Token) ([]apiUserObject, error) {
func (p *Provider) getUser(ctx context.Context, accessToken string, userID string) (*apiUserObject, error) {
apiURL := p.cfg.apiURL.ResolveReference(&url.URL{
Path: fmt.Sprintf("/api/1/users/%s", userID),
}).String()
var out []apiUserObject
_, err := p.apiGet(ctx, accessToken, apiURL, &out)
if err != nil {
return nil, fmt.Errorf("onelogin: error getting user: %w", err)
}
if len(out) == 0 {
return nil, fmt.Errorf("onelogin: user not found")
}
return &out[0], nil
}
func (p *Provider) listUsers(ctx context.Context, accessToken string) ([]apiUserObject, error) {
var users []apiUserObject
apiURL := p.cfg.apiURL.ResolveReference(&url.URL{
@ -170,9 +213,9 @@ func (p *Provider) getUsers(ctx context.Context, token *oauth2.Token) ([]apiUser
}).String()
for apiURL != "" {
var result []apiUserObject
nextLink, err := p.apiGet(ctx, token, apiURL, &result)
nextLink, err := p.apiGet(ctx, accessToken, apiURL, &result)
if err != nil {
return nil, err
return nil, fmt.Errorf("onelogin: listing users: %w", err)
}
users = append(users, result...)
@ -182,12 +225,12 @@ func (p *Provider) getUsers(ctx context.Context, token *oauth2.Token) ([]apiUser
return users, nil
}
func (p *Provider) apiGet(ctx context.Context, token *oauth2.Token, uri string, out interface{}) (nextLink string, err error) {
func (p *Provider) apiGet(ctx context.Context, accessToken string, uri string, out interface{}) (nextLink string, err error) {
req, err := http.NewRequestWithContext(ctx, "GET", uri, nil)
if err != nil {
return "", err
}
req.Header.Set("Authorization", fmt.Sprintf("bearer:%s", token.AccessToken))
req.Header.Set("Authorization", fmt.Sprintf("bearer:%s", accessToken))
req.Header.Set("Content-Type", "application/json")
res, err := p.cfg.httpClient.Do(req)
@ -307,3 +350,7 @@ type apiUserObject struct {
FirstName string `json:"firstname"`
LastName string `json:"lastname"`
}
func (obj *apiUserObject) getDisplayName() string {
return obj.FirstName + " " + obj.LastName
}

View file

@ -8,6 +8,7 @@ import (
"net/http/httptest"
"net/url"
"sort"
"strconv"
"testing"
"time"
@ -102,6 +103,28 @@ func newMockAPI(srv *httptest.Server, userIDToGroupName map[int]string) http.Han
_ = json.NewEncoder(w).Encode(result)
})
r.Get("/users/{user_id}", func(w http.ResponseWriter, r *http.Request) {
userIDToGroupID := map[int]int{}
for userID, groupName := range userIDToGroupName {
for id, n := range allGroups {
if groupName == n {
userIDToGroupID[userID] = id
}
}
}
userID, _ := strconv.Atoi(chi.URLParam(r, "user_id"))
_ = json.NewEncoder(w).Encode(M{
"data": []M{{
"id": userID,
"email": userIDToGroupName[userID] + "@example.com",
"group_id": userIDToGroupID[userID],
"firstname": "User",
"lastname": fmt.Sprint(userID),
}},
})
})
r.Get("/users", func(w http.ResponseWriter, r *http.Request) {
userIDToGroupID := map[int]int{}
for userID, groupName := range userIDToGroupName {
@ -130,6 +153,37 @@ func newMockAPI(srv *httptest.Server, userIDToGroupName map[int]string) http.Han
return r
}
func TestProvider_User(t *testing.T) {
var mockAPI http.Handler
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mockAPI.ServeHTTP(w, r)
}))
defer srv.Close()
mockAPI = newMockAPI(srv, map[int]string{
111: "admin",
222: "test",
333: "user",
})
p := New(
WithServiceAccount(&ServiceAccount{
ClientID: "CLIENTID",
ClientSecret: "CLIENTSECRET",
}),
WithURL(mustParseURL(srv.URL)),
)
user, err := p.User(context.Background(), "onelogin/111", "ACCESSTOKEN")
if !assert.NoError(t, err) {
return
}
testutil.AssertProtoJSONEqual(t, `{
"id": "onelogin/111",
"groupIds": ["0"],
"displayName": "User 111",
"email": "admin@example.com"
}`, user)
}
func TestProvider_UserGroups(t *testing.T) {
var mockAPI http.Handler
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

View file

@ -28,8 +28,12 @@ type User = directory.User
// Options are the options specific to the provider.
type Options = directory.Options
// RegisterDirectoryServiceServer registers the directory gRPC service.
var RegisterDirectoryServiceServer = directory.RegisterDirectoryServiceServer
// A Provider provides user group directory information.
type Provider interface {
User(ctx context.Context, userID, accessToken string) (*User, error)
UserGroups(ctx context.Context) ([]*Group, []*User, error)
}
@ -55,7 +59,7 @@ func GetProvider(options Options) (provider Provider) {
switch options.Provider {
case auth0.Name:
serviceAccount, err := auth0.ParseServiceAccount(options.ServiceAccount)
serviceAccount, err := auth0.ParseServiceAccount(options)
if err == nil {
return auth0.New(
auth0.WithDomain(options.ProviderURL),
@ -139,6 +143,10 @@ func GetProvider(options Options) (provider Provider) {
type nullProvider struct{}
func (nullProvider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
return nil, nil
}
func (nullProvider) UserGroups(ctx context.Context) ([]*Group, []*User, error) {
return nil, nil, nil
}

View file

@ -1,11 +1,22 @@
// Package databroker contains databroker protobuf definitions.
package databroker
import "strings"
// GetUserID gets the databroker user id from a provider user id.
func GetUserID(provider, providerUserID string) string {
return provider + "/" + providerUserID
}
// FromUserID gets the provider and provider user id from a databroker user id.
func FromUserID(userID string) (provider, providerUserID string) {
ps := strings.SplitN(userID, "/", 2)
if len(ps) < 2 {
return "", userID
}
return ps[0], ps[1]
}
// ApplyOffsetAndLimit applies the offset and limit to the list of records.
func ApplyOffsetAndLimit(all []*Record, offset, limit int) (records []*Record, totalCount int) {
records = all

View file

@ -7,7 +7,12 @@
package directory
import (
context "context"
proto "github.com/golang/protobuf/proto"
empty "github.com/golang/protobuf/ptypes/empty"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
@ -175,29 +180,97 @@ func (x *Group) GetEmail() string {
return ""
}
type RefreshUserRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
UserId string `protobuf:"bytes,1,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"`
AccessToken string `protobuf:"bytes,2,opt,name=access_token,json=accessToken,proto3" json:"access_token,omitempty"`
}
func (x *RefreshUserRequest) Reset() {
*x = RefreshUserRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_directory_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *RefreshUserRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*RefreshUserRequest) ProtoMessage() {}
func (x *RefreshUserRequest) ProtoReflect() protoreflect.Message {
mi := &file_directory_proto_msgTypes[2]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use RefreshUserRequest.ProtoReflect.Descriptor instead.
func (*RefreshUserRequest) Descriptor() ([]byte, []int) {
return file_directory_proto_rawDescGZIP(), []int{2}
}
func (x *RefreshUserRequest) GetUserId() string {
if x != nil {
return x.UserId
}
return ""
}
func (x *RefreshUserRequest) GetAccessToken() string {
if x != nil {
return x.AccessToken
}
return ""
}
var File_directory_proto protoreflect.FileDescriptor
var file_directory_proto_rawDesc = []byte{
0x0a, 0x0f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x12, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x22, 0x86, 0x01, 0x0a,
0x04, 0x55, 0x73, 0x65, 0x72, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e,
0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12,
0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12,
0x1b, 0x0a, 0x09, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x03, 0x20, 0x03,
0x28, 0x09, 0x52, 0x08, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x49, 0x64, 0x73, 0x12, 0x21, 0x0a, 0x0c,
0x64, 0x69, 0x73, 0x70, 0x6c, 0x61, 0x79, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x04, 0x20, 0x01,
0x28, 0x09, 0x52, 0x0b, 0x64, 0x69, 0x73, 0x70, 0x6c, 0x61, 0x79, 0x4e, 0x61, 0x6d, 0x65, 0x12,
0x14, 0x0a, 0x05, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05,
0x65, 0x6d, 0x61, 0x69, 0x6c, 0x22, 0x5b, 0x0a, 0x05, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x18,
0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52,
0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02,
0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65,
0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x14, 0x0a, 0x05,
0x65, 0x6d, 0x61, 0x69, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x6d, 0x61,
0x69, 0x6c, 0x42, 0x31, 0x5a, 0x2f, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d,
0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69,
0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x64, 0x69, 0x72, 0x65,
0x63, 0x74, 0x6f, 0x72, 0x79, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
0x6f, 0x12, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x1a, 0x1b, 0x67, 0x6f,
0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x65, 0x6d,
0x70, 0x74, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x86, 0x01, 0x0a, 0x04, 0x55, 0x73,
0x65, 0x72, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20,
0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x0e, 0x0a, 0x02,
0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x1b, 0x0a, 0x09,
0x67, 0x72, 0x6f, 0x75, 0x70, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52,
0x08, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x49, 0x64, 0x73, 0x12, 0x21, 0x0a, 0x0c, 0x64, 0x69, 0x73,
0x70, 0x6c, 0x61, 0x79, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52,
0x0b, 0x64, 0x69, 0x73, 0x70, 0x6c, 0x61, 0x79, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x14, 0x0a, 0x05,
0x65, 0x6d, 0x61, 0x69, 0x6c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x6d, 0x61,
0x69, 0x6c, 0x22, 0x5b, 0x0a, 0x05, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x18, 0x0a, 0x07, 0x76,
0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65,
0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28,
0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x03, 0x20,
0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x6d, 0x61,
0x69, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x22,
0x50, 0x0a, 0x12, 0x52, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65,
0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x17, 0x0a, 0x07, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x69, 0x64,
0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x64, 0x12, 0x21,
0x0a, 0x0c, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x02,
0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x54, 0x6f, 0x6b, 0x65,
0x6e, 0x32, 0x58, 0x0a, 0x10, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x53, 0x65,
0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x44, 0x0a, 0x0b, 0x52, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68,
0x55, 0x73, 0x65, 0x72, 0x12, 0x1d, 0x2e, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79,
0x2e, 0x52, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75,
0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x42, 0x31, 0x5a, 0x2f, 0x67,
0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69,
0x75, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67, 0x2f,
0x67, 0x72, 0x70, 0x63, 0x2f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x62, 0x06,
0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
@ -212,14 +285,18 @@ func file_directory_proto_rawDescGZIP() []byte {
return file_directory_proto_rawDescData
}
var file_directory_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_directory_proto_msgTypes = make([]protoimpl.MessageInfo, 3)
var file_directory_proto_goTypes = []interface{}{
(*User)(nil), // 0: directory.User
(*Group)(nil), // 1: directory.Group
(*RefreshUserRequest)(nil), // 2: directory.RefreshUserRequest
(*empty.Empty)(nil), // 3: google.protobuf.Empty
}
var file_directory_proto_depIdxs = []int32{
0, // [0:0] is the sub-list for method output_type
0, // [0:0] is the sub-list for method input_type
2, // 0: directory.DirectoryService.RefreshUser:input_type -> directory.RefreshUserRequest
3, // 1: directory.DirectoryService.RefreshUser:output_type -> google.protobuf.Empty
1, // [1:2] is the sub-list for method output_type
0, // [0:1] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
@ -255,6 +332,18 @@ func file_directory_proto_init() {
return nil
}
}
file_directory_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*RefreshUserRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
@ -262,9 +351,9 @@ func file_directory_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_directory_proto_rawDesc,
NumEnums: 0,
NumMessages: 2,
NumMessages: 3,
NumExtensions: 0,
NumServices: 0,
NumServices: 1,
},
GoTypes: file_directory_proto_goTypes,
DependencyIndexes: file_directory_proto_depIdxs,
@ -275,3 +364,83 @@ func file_directory_proto_init() {
file_directory_proto_goTypes = nil
file_directory_proto_depIdxs = nil
}
// Reference imports to suppress errors if they are not otherwise used.
var _ context.Context
var _ grpc.ClientConnInterface
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
const _ = grpc.SupportPackageIsVersion6
// DirectoryServiceClient is the client API for DirectoryService service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.
type DirectoryServiceClient interface {
RefreshUser(ctx context.Context, in *RefreshUserRequest, opts ...grpc.CallOption) (*empty.Empty, error)
}
type directoryServiceClient struct {
cc grpc.ClientConnInterface
}
func NewDirectoryServiceClient(cc grpc.ClientConnInterface) DirectoryServiceClient {
return &directoryServiceClient{cc}
}
func (c *directoryServiceClient) RefreshUser(ctx context.Context, in *RefreshUserRequest, opts ...grpc.CallOption) (*empty.Empty, error) {
out := new(empty.Empty)
err := c.cc.Invoke(ctx, "/directory.DirectoryService/RefreshUser", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// DirectoryServiceServer is the server API for DirectoryService service.
type DirectoryServiceServer interface {
RefreshUser(context.Context, *RefreshUserRequest) (*empty.Empty, error)
}
// UnimplementedDirectoryServiceServer can be embedded to have forward compatible implementations.
type UnimplementedDirectoryServiceServer struct {
}
func (*UnimplementedDirectoryServiceServer) RefreshUser(context.Context, *RefreshUserRequest) (*empty.Empty, error) {
return nil, status.Errorf(codes.Unimplemented, "method RefreshUser not implemented")
}
func RegisterDirectoryServiceServer(s *grpc.Server, srv DirectoryServiceServer) {
s.RegisterService(&_DirectoryService_serviceDesc, srv)
}
func _DirectoryService_RefreshUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RefreshUserRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DirectoryServiceServer).RefreshUser(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/directory.DirectoryService/RefreshUser",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DirectoryServiceServer).RefreshUser(ctx, req.(*RefreshUserRequest))
}
return interceptor(ctx, in, info, handler)
}
var _DirectoryService_serviceDesc = grpc.ServiceDesc{
ServiceName: "directory.DirectoryService",
HandlerType: (*DirectoryServiceServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "RefreshUser",
Handler: _DirectoryService_RefreshUser_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "directory.proto",
}

View file

@ -3,6 +3,8 @@ syntax = "proto3";
package directory;
option go_package = "github.com/pomerium/pomerium/pkg/grpc/directory";
import "google/protobuf/empty.proto";
message User {
string version = 1;
string id = 2;
@ -17,3 +19,12 @@ message Group {
string name = 3;
string email = 4;
}
message RefreshUserRequest {
string user_id = 1;
string access_token = 2;
}
service DirectoryService {
rpc RefreshUser(RefreshUserRequest) returns (google.protobuf.Empty);
}