mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-13 17:17:43 +02:00
directory: add explicit RefreshUser endpoint for faster sync (#1460)
* directory: add explicit RefreshUser endpoint for faster sync * add test * implement azure * update api call * add test for azure User * implement github * implement AccessToken, gitlab * implement okta * implement onelogin * fix test * fix inconsistent test * implement auth0
This commit is contained in:
parent
9b39deabd8
commit
aa731ae068
23 changed files with 1405 additions and 179 deletions
|
@ -596,5 +596,13 @@ func (a *Authenticate) saveSessionToDataBroker(ctx context.Context, sessionState
|
|||
}
|
||||
sessionState.Version = sessions.Version(res.GetServerVersion())
|
||||
|
||||
_, err = state.directoryClient.RefreshUser(ctx, &directory.RefreshUserRequest{
|
||||
UserId: s.UserId,
|
||||
AccessToken: accessToken.AccessToken,
|
||||
})
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("directory: failed to refresh user data")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -14,6 +14,8 @@ import (
|
|||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
|
@ -29,10 +31,12 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
"github.com/golang/protobuf/ptypes/empty"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
|
@ -171,6 +175,7 @@ func TestAuthenticate_SignIn(t *testing.T) {
|
|||
}, nil
|
||||
},
|
||||
},
|
||||
directoryClient: new(mockDirectoryServiceClient),
|
||||
}),
|
||||
|
||||
options: config.NewAtomicOptions(),
|
||||
|
@ -262,6 +267,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
|||
}, nil
|
||||
},
|
||||
},
|
||||
directoryClient: new(mockDirectoryServiceClient),
|
||||
}),
|
||||
templates: template.Must(frontend.NewTemplates()),
|
||||
options: config.NewAtomicOptions(),
|
||||
|
@ -366,6 +372,7 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
|
|||
return &databroker.SetResponse{Record: &databroker.Record{Data: in.Data}}, nil
|
||||
},
|
||||
},
|
||||
directoryClient: new(mockDirectoryServiceClient),
|
||||
redirectURL: authURL,
|
||||
sessionStore: tt.session,
|
||||
cookieCipher: aead,
|
||||
|
@ -515,6 +522,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
|||
}, nil
|
||||
},
|
||||
},
|
||||
directoryClient: new(mockDirectoryServiceClient),
|
||||
}),
|
||||
options: config.NewAtomicOptions(),
|
||||
provider: identity.NewAtomicAuthenticator(),
|
||||
|
@ -633,6 +641,7 @@ func TestAuthenticate_Dashboard(t *testing.T) {
|
|||
}, nil
|
||||
},
|
||||
},
|
||||
directoryClient: new(mockDirectoryServiceClient),
|
||||
}),
|
||||
templates: template.Must(frontend.NewTemplates()),
|
||||
}
|
||||
|
@ -679,3 +688,16 @@ func (m mockDataBrokerServiceClient) Get(ctx context.Context, in *databroker.Get
|
|||
func (m mockDataBrokerServiceClient) Set(ctx context.Context, in *databroker.SetRequest, opts ...grpc.CallOption) (*databroker.SetResponse, error) {
|
||||
return m.set(ctx, in, opts...)
|
||||
}
|
||||
|
||||
type mockDirectoryServiceClient struct {
|
||||
directory.DirectoryServiceClient
|
||||
|
||||
refreshUser func(ctx context.Context, in *directory.RefreshUserRequest, opts ...grpc.CallOption) (*empty.Empty, error)
|
||||
}
|
||||
|
||||
func (m mockDirectoryServiceClient) RefreshUser(ctx context.Context, in *directory.RefreshUserRequest, opts ...grpc.CallOption) (*empty.Empty, error) {
|
||||
if m.refreshUser != nil {
|
||||
return m.refreshUser(ctx, in, opts...)
|
||||
}
|
||||
return nil, status.Error(codes.Unimplemented, "")
|
||||
}
|
||||
|
|
|
@ -22,6 +22,7 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
)
|
||||
|
||||
type authenticateState struct {
|
||||
|
@ -46,6 +47,7 @@ type authenticateState struct {
|
|||
jwk *jose.JSONWebKeySet
|
||||
|
||||
dataBrokerClient databroker.DataBrokerServiceClient
|
||||
directoryClient directory.DirectoryServiceClient
|
||||
}
|
||||
|
||||
func newAuthenticateState() *authenticateState {
|
||||
|
@ -129,6 +131,7 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err
|
|||
}
|
||||
|
||||
state.dataBrokerClient = databroker.NewDataBrokerServiceClient(dataBrokerConn)
|
||||
state.directoryClient = directory.NewDirectoryServiceClient(dataBrokerConn)
|
||||
|
||||
return state, nil
|
||||
}
|
||||
|
|
8
cache/cache.go
vendored
8
cache/cache.go
vendored
|
@ -7,6 +7,7 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"gopkg.in/tomb.v2"
|
||||
|
@ -33,6 +34,9 @@ type Cache struct {
|
|||
localGRPCConnection *grpc.ClientConn
|
||||
dataBrokerStorageType string //TODO remove in v0.11
|
||||
deprecatedCacheClusterDomain string //TODO: remove in v0.11
|
||||
|
||||
mu sync.Mutex
|
||||
directoryProvider directory.Provider
|
||||
}
|
||||
|
||||
// New creates a new cache service.
|
||||
|
@ -90,6 +94,7 @@ func (c *Cache) OnConfigChange(cfg *config.Config) {
|
|||
// Register registers all the gRPC services with the given server.
|
||||
func (c *Cache) Register(grpcServer *grpc.Server) {
|
||||
databroker.RegisterDataBrokerServiceServer(grpcServer, c.dataBrokerServer)
|
||||
directory.RegisterDirectoryServiceServer(grpcServer, c)
|
||||
}
|
||||
|
||||
// Run runs the cache components.
|
||||
|
@ -132,6 +137,9 @@ func (c *Cache) update(cfg *config.Config) error {
|
|||
ClientID: cfg.Options.ClientID,
|
||||
ClientSecret: cfg.Options.ClientSecret,
|
||||
})
|
||||
c.mu.Lock()
|
||||
c.directoryProvider = directoryProvider
|
||||
c.mu.Unlock()
|
||||
|
||||
dataBrokerClient := databroker.NewDataBrokerServiceClient(c.localGRPCConnection)
|
||||
|
||||
|
|
44
cache/directory.go
vendored
Normal file
44
cache/directory.go
vendored
Normal file
|
@ -0,0 +1,44 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
)
|
||||
|
||||
// RefreshUser refreshes a user's directory information.
|
||||
func (c *Cache) RefreshUser(ctx context.Context, req *directory.RefreshUserRequest) (*emptypb.Empty, error) {
|
||||
c.mu.Lock()
|
||||
dp := c.directoryProvider
|
||||
c.mu.Unlock()
|
||||
|
||||
if dp == nil {
|
||||
return nil, errors.New("no directory provider is available for refresh")
|
||||
}
|
||||
|
||||
u, err := dp.User(ctx, req.GetUserId(), req.GetAccessToken())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
any, err := anypb.New(u)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = c.dataBrokerServer.Set(ctx, &databroker.SetRequest{
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: u.GetId(),
|
||||
Data: any,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return new(emptypb.Empty), nil
|
||||
}
|
|
@ -21,16 +21,33 @@ import (
|
|||
// Name is the provider name.
|
||||
const Name = "auth0"
|
||||
|
||||
// RoleManager defines what is needed to get role info from Auth0.
|
||||
type RoleManager interface {
|
||||
type (
|
||||
// RoleManager defines what is needed to get role info from Auth0.
|
||||
RoleManager interface {
|
||||
List(opts ...management.ListOption) (r *management.RoleList, err error)
|
||||
Users(id string, opts ...management.ListOption) (u *management.UserList, err error)
|
||||
}
|
||||
// UserManager defines what is needed to get user info from Auth0.
|
||||
UserManager interface {
|
||||
Read(id string) (*management.User, error)
|
||||
Roles(id string, opts ...management.ListOption) (r *management.RoleList, err error)
|
||||
}
|
||||
)
|
||||
|
||||
type newManagersFunc = func(ctx context.Context, domain string, serviceAccount *ServiceAccount) (RoleManager, UserManager, error)
|
||||
|
||||
func defaultNewManagersFunc(ctx context.Context, domain string, serviceAccount *ServiceAccount) (RoleManager, UserManager, error) {
|
||||
m, err := management.New(domain, serviceAccount.ClientID, serviceAccount.Secret, management.WithContext(ctx))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("auth0: could not build management: %w", err)
|
||||
}
|
||||
return m.Role, m.User, nil
|
||||
}
|
||||
|
||||
type config struct {
|
||||
domain string
|
||||
serviceAccount *ServiceAccount
|
||||
newRoleManager func(ctx context.Context, domain string, serviceAccount *ServiceAccount) (RoleManager, error)
|
||||
newManagers newManagersFunc
|
||||
}
|
||||
|
||||
// Option provides config for the Auth0 Provider.
|
||||
|
@ -50,17 +67,9 @@ func WithDomain(domain string) Option {
|
|||
}
|
||||
}
|
||||
|
||||
func defaultNewRoleManagerFunc(ctx context.Context, domain string, serviceAccount *ServiceAccount) (RoleManager, error) {
|
||||
m, err := management.New(domain, serviceAccount.ClientID, serviceAccount.Secret, management.WithContext(ctx))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("auth0: could not build management")
|
||||
}
|
||||
return m.Role, nil
|
||||
}
|
||||
|
||||
func getConfig(options ...Option) *config {
|
||||
cfg := &config{
|
||||
newRoleManager: defaultNewRoleManagerFunc,
|
||||
newManagers: defaultNewManagersFunc,
|
||||
}
|
||||
for _, option := range options {
|
||||
option(cfg)
|
||||
|
@ -82,13 +91,49 @@ func New(options ...Option) *Provider {
|
|||
}
|
||||
}
|
||||
|
||||
func (p *Provider) getRoleManager(ctx context.Context) (RoleManager, error) {
|
||||
return p.cfg.newRoleManager(ctx, p.cfg.domain, p.cfg.serviceAccount)
|
||||
func (p *Provider) getManagers(ctx context.Context) (RoleManager, UserManager, error) {
|
||||
return p.cfg.newManagers(ctx, p.cfg.domain, p.cfg.serviceAccount)
|
||||
}
|
||||
|
||||
// User returns the user record for the given id.
|
||||
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
|
||||
_, um, err := p.getManagers(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("auth0: could not get the role manager: %w", err)
|
||||
}
|
||||
|
||||
_, providerUserID := databroker.FromUserID(userID)
|
||||
du := &directory.User{
|
||||
Id: userID,
|
||||
}
|
||||
|
||||
u, err := um.Read(providerUserID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("auth0: error getting user info: %w", err)
|
||||
}
|
||||
du.DisplayName = u.GetName()
|
||||
du.Email = u.GetEmail()
|
||||
|
||||
for page, hasNext := 0, true; hasNext; page++ {
|
||||
rl, err := um.Roles(providerUserID, management.IncludeTotals(true), management.Page(page))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("auth0: error getting user roles: %w", err)
|
||||
}
|
||||
|
||||
for _, role := range rl.Roles {
|
||||
du.GroupIds = append(du.GroupIds, role.GetID())
|
||||
}
|
||||
|
||||
hasNext = rl.HasNext()
|
||||
}
|
||||
|
||||
sort.Strings(du.GroupIds)
|
||||
return du, nil
|
||||
}
|
||||
|
||||
// UserGroups fetches a slice of groups and users.
|
||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
||||
rm, err := p.getRoleManager(ctx)
|
||||
rm, _, err := p.getManagers(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("auth0: could not get the role manager: %w", err)
|
||||
}
|
||||
|
@ -147,6 +192,9 @@ func getRoles(rm RoleManager) ([]*directory.Group, error) {
|
|||
shouldContinue = listRes.HasNext()
|
||||
}
|
||||
|
||||
sort.Slice(roles, func(i, j int) bool {
|
||||
return roles[i].GetId() < roles[j].GetId()
|
||||
})
|
||||
return roles, nil
|
||||
}
|
||||
|
||||
|
@ -170,6 +218,7 @@ func getRoleUserIDs(rm RoleManager, roleID string) ([]string, error) {
|
|||
shouldContinue = usersRes.HasNext()
|
||||
}
|
||||
|
||||
sort.Strings(ids)
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
|
@ -180,7 +229,30 @@ type ServiceAccount struct {
|
|||
}
|
||||
|
||||
// ParseServiceAccount parses the service account in the config options.
|
||||
func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) {
|
||||
func ParseServiceAccount(options directory.Options) (*ServiceAccount, error) {
|
||||
if options.ServiceAccount != "" {
|
||||
return parseServiceAccountFromString(options.ServiceAccount)
|
||||
}
|
||||
return parseServiceAccountFromOptions(options.ClientID, options.ClientSecret)
|
||||
}
|
||||
|
||||
func parseServiceAccountFromOptions(clientID, clientSecret string) (*ServiceAccount, error) {
|
||||
serviceAccount := ServiceAccount{
|
||||
ClientID: clientID,
|
||||
Secret: clientSecret,
|
||||
}
|
||||
|
||||
if serviceAccount.ClientID == "" {
|
||||
return nil, fmt.Errorf("auth0: client_id is required")
|
||||
}
|
||||
if serviceAccount.Secret == "" {
|
||||
return nil, fmt.Errorf("auth0: client_secret is required")
|
||||
}
|
||||
|
||||
return &serviceAccount, nil
|
||||
}
|
||||
|
||||
func parseServiceAccountFromString(rawServiceAccount string) (*ServiceAccount, error) {
|
||||
bs, err := base64.StdEncoding.DecodeString(rawServiceAccount)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("auth0: could not decode base64: %w", err)
|
||||
|
|
|
@ -2,34 +2,133 @@ package auth0
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/go-chi/chi/middleware"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gopkg.in/auth0.v4/management"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/directory/auth0/mock_auth0"
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
)
|
||||
|
||||
type M = map[string]interface{}
|
||||
|
||||
func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
|
||||
r := chi.NewRouter()
|
||||
r.Use(middleware.Logger)
|
||||
r.Post("/oauth/token", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"access_token": "ACCESSTOKEN",
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": "REFRESHTOKEN",
|
||||
})
|
||||
})
|
||||
r.Route("/api/v2", func(r chi.Router) {
|
||||
r.Route("/users", func(r chi.Router) {
|
||||
r.Route("/{user_id}", func(r chi.Router) {
|
||||
r.Get("/roles", func(w http.ResponseWriter, r *http.Request) {
|
||||
switch chi.URLParam(r, "user_id") {
|
||||
case "user1":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"total": 2,
|
||||
"limit": 2,
|
||||
"roles": []M{
|
||||
{"id": "role1"},
|
||||
{"id": "role2"},
|
||||
},
|
||||
})
|
||||
default:
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
switch chi.URLParam(r, "user_id") {
|
||||
case "user1":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"user_id": "user1",
|
||||
"email": "user1@example.com",
|
||||
"name": "User 1",
|
||||
})
|
||||
default:
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func TestProvider_User(t *testing.T) {
|
||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer clearTimeout()
|
||||
|
||||
orig := http.DefaultTransport
|
||||
defer func() {
|
||||
http.DefaultTransport = orig
|
||||
}()
|
||||
http.DefaultTransport = &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
}
|
||||
|
||||
var mockAPI http.Handler
|
||||
srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mockAPI.ServeHTTP(w, r)
|
||||
}))
|
||||
srv.StartTLS()
|
||||
defer srv.Close()
|
||||
mockAPI = newMockAPI(t, srv)
|
||||
|
||||
p := New(
|
||||
WithDomain(srv.URL),
|
||||
WithServiceAccount(&ServiceAccount{ClientID: "CLIENT_ID", Secret: "SECRET"}),
|
||||
)
|
||||
du, err := p.User(ctx, "auth0/user1", "")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
testutil.AssertProtoJSONEqual(t, `{
|
||||
"id": "auth0/user1",
|
||||
"displayName": "User 1",
|
||||
"email": "user1@example.com",
|
||||
"groupIds": ["role1", "role2"]
|
||||
}`, du)
|
||||
}
|
||||
|
||||
type mockNewRoleManagerFunc struct {
|
||||
CalledWithContext context.Context
|
||||
CalledWithDomain string
|
||||
CalledWithServiceAccount *ServiceAccount
|
||||
|
||||
ReturnRoleManager RoleManager
|
||||
ReturnUserManager UserManager
|
||||
ReturnError error
|
||||
}
|
||||
|
||||
func (m *mockNewRoleManagerFunc) f(ctx context.Context, domain string, serviceAccount *ServiceAccount) (RoleManager, error) {
|
||||
func (m *mockNewRoleManagerFunc) f(ctx context.Context, domain string, serviceAccount *ServiceAccount) (RoleManager, UserManager, error) {
|
||||
m.CalledWithContext = ctx
|
||||
m.CalledWithDomain = domain
|
||||
m.CalledWithServiceAccount = serviceAccount
|
||||
|
||||
return m.ReturnRoleManager, m.ReturnError
|
||||
return m.ReturnRoleManager, m.ReturnUserManager, m.ReturnError
|
||||
}
|
||||
|
||||
type listOptionMatcher struct {
|
||||
|
@ -379,7 +478,7 @@ func TestProvider_UserGroups(t *testing.T) {
|
|||
|
||||
mRoleManager := mock_auth0.NewMockRoleManager(ctrl)
|
||||
|
||||
mNewRoleManagerFunc := mockNewRoleManagerFunc{
|
||||
mNewManagersFunc := mockNewRoleManagerFunc{
|
||||
ReturnRoleManager: mRoleManager,
|
||||
ReturnError: tc.newRoleManagerError,
|
||||
}
|
||||
|
@ -392,7 +491,7 @@ func TestProvider_UserGroups(t *testing.T) {
|
|||
WithDomain(expectedDomain),
|
||||
WithServiceAccount(expectedServiceAccount),
|
||||
)
|
||||
p.cfg.newRoleManager = mNewRoleManagerFunc.f
|
||||
p.cfg.newManagers = mNewManagersFunc.f
|
||||
|
||||
actualGroups, actualUsers, err := p.UserGroups(context.Background())
|
||||
if tc.expectedError != nil {
|
||||
|
@ -404,8 +503,8 @@ func TestProvider_UserGroups(t *testing.T) {
|
|||
assert.Equal(t, tc.expectedGroups, actualGroups)
|
||||
assert.Equal(t, tc.expectedUsers, actualUsers)
|
||||
|
||||
assert.Equal(t, expectedDomain, mNewRoleManagerFunc.CalledWithDomain)
|
||||
assert.Equal(t, expectedServiceAccount, mNewRoleManagerFunc.CalledWithServiceAccount)
|
||||
assert.Equal(t, expectedDomain, mNewManagersFunc.CalledWithDomain)
|
||||
assert.Equal(t, expectedServiceAccount, mNewManagersFunc.CalledWithServiceAccount)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -433,7 +532,7 @@ func TestParseServiceAccount(t *testing.T) {
|
|||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
actualServiceAccount, err := ParseServiceAccount(tc.rawServiceAccount)
|
||||
actualServiceAccount, err := ParseServiceAccount(directory.Options{ServiceAccount: tc.rawServiceAccount})
|
||||
if tc.expectedError != nil {
|
||||
assert.EqualError(t, err, tc.expectedError.Error())
|
||||
} else {
|
||||
|
|
|
@ -6,14 +6,15 @@ import (
|
|||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
)
|
||||
|
||||
|
@ -101,6 +102,50 @@ func New(options ...Option) *Provider {
|
|||
return p
|
||||
}
|
||||
|
||||
// User returns the user record for the given id.
|
||||
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
|
||||
if p.cfg.serviceAccount == nil {
|
||||
return nil, fmt.Errorf("azure: service account not defined")
|
||||
}
|
||||
|
||||
_, providerUserID := databroker.FromUserID(userID)
|
||||
|
||||
du := &directory.User{
|
||||
Id: userID,
|
||||
}
|
||||
|
||||
userURL := p.cfg.graphURL.ResolveReference(&url.URL{
|
||||
Path: fmt.Sprintf("/v1.0/users/%s", providerUserID),
|
||||
}).String()
|
||||
|
||||
var u usersDeltaResponseUser
|
||||
err := p.api(ctx, userURL, &u)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
du.DisplayName = u.DisplayName
|
||||
du.Email = u.getEmail()
|
||||
|
||||
groupURL := p.cfg.graphURL.ResolveReference(&url.URL{
|
||||
Path: fmt.Sprintf("/v1.0/users/%s/transitiveMemberOf", providerUserID),
|
||||
}).String()
|
||||
|
||||
var res struct {
|
||||
Value []usersDeltaResponseUser `json:"value"`
|
||||
}
|
||||
err = p.api(ctx, groupURL, &res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, g := range res.Value {
|
||||
du.GroupIds = append(du.GroupIds, g.ID)
|
||||
}
|
||||
|
||||
sort.Strings(du.GroupIds)
|
||||
|
||||
return du, nil
|
||||
}
|
||||
|
||||
// UserGroups returns the directory users in azure active directory.
|
||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
||||
if p.cfg.serviceAccount == nil {
|
||||
|
@ -116,13 +161,13 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
|
|||
return groups, users, nil
|
||||
}
|
||||
|
||||
func (p *Provider) api(ctx context.Context, method, url string, body io.Reader, out interface{}) error {
|
||||
func (p *Provider) api(ctx context.Context, url string, out interface{}) error {
|
||||
token, err := p.getToken(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, url, body)
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("azure: error creating HTTP request: %w", err)
|
||||
}
|
||||
|
@ -143,7 +188,7 @@ func (p *Provider) api(ctx context.Context, method, url string, body io.Reader,
|
|||
}
|
||||
|
||||
if res.StatusCode/100 != 2 {
|
||||
return fmt.Errorf("azure: error querying api: %s", res.Status)
|
||||
return fmt.Errorf("azure: error querying api (%s): %s", url, res.Status)
|
||||
}
|
||||
|
||||
err = json.NewDecoder(res.Body).Decode(out)
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/go-chi/chi/middleware"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
)
|
||||
|
||||
|
@ -74,12 +75,62 @@ func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
|
|||
},
|
||||
})
|
||||
})
|
||||
r.Get("/users/{user_id}", func(w http.ResponseWriter, r *http.Request) {
|
||||
switch chi.URLParam(r, "user_id") {
|
||||
case "user-1":
|
||||
_ = json.NewEncoder(w).Encode(M{"id": "user-1", "displayName": "User 1", "mail": "user1@example.com"})
|
||||
default:
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
r.Get("/users/{user_id}/transitiveMemberOf", func(w http.ResponseWriter, r *http.Request) {
|
||||
switch chi.URLParam(r, "user_id") {
|
||||
case "user-1":
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"value": []M{
|
||||
{"id": "admin"},
|
||||
},
|
||||
})
|
||||
default:
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func Test(t *testing.T) {
|
||||
func TestProvider_User(t *testing.T) {
|
||||
var mockAPI http.Handler
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mockAPI.ServeHTTP(w, r)
|
||||
}))
|
||||
defer srv.Close()
|
||||
mockAPI = newMockAPI(t, srv)
|
||||
|
||||
p := New(
|
||||
WithGraphURL(mustParseURL(srv.URL)),
|
||||
WithLoginURL(mustParseURL(srv.URL)),
|
||||
WithServiceAccount(&ServiceAccount{
|
||||
ClientID: "CLIENT_ID",
|
||||
ClientSecret: "CLIENT_SECRET",
|
||||
DirectoryID: "DIRECTORY_ID",
|
||||
}),
|
||||
)
|
||||
|
||||
du, err := p.User(context.Background(), "azure/user-1", "")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
testutil.AssertProtoJSONEqual(t, `{
|
||||
"id": "azure/user-1",
|
||||
"displayName": "User 1",
|
||||
"email": "user1@example.com",
|
||||
"groupIds": ["admin"]
|
||||
}`, du)
|
||||
}
|
||||
|
||||
func TestProvider_UserGroups(t *testing.T) {
|
||||
var mockAPI http.Handler
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mockAPI.ServeHTTP(w, r)
|
||||
|
@ -118,10 +169,10 @@ func Test(t *testing.T) {
|
|||
Email: "user3@example.com",
|
||||
},
|
||||
}, users)
|
||||
assert.Equal(t, []*directory.Group{
|
||||
{Id: "admin", Name: "Admin Group"},
|
||||
{Id: "test", Name: "Test Group"},
|
||||
}, groups)
|
||||
testutil.AssertProtoJSONEqual(t, `[
|
||||
{ "id": "admin", "name": "Admin Group" },
|
||||
{ "id": "test", "name": "Test Group"}
|
||||
]`, groups)
|
||||
}
|
||||
|
||||
func TestParseServiceAccount(t *testing.T) {
|
||||
|
|
|
@ -86,7 +86,7 @@ func (dc *deltaCollection) syncGroups(ctx context.Context) error {
|
|||
|
||||
for {
|
||||
var res groupsDeltaResponse
|
||||
err := dc.provider.api(ctx, "GET", apiURL, nil, &res)
|
||||
err := dc.provider.api(ctx, apiURL, &res)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -146,7 +146,7 @@ func (dc *deltaCollection) syncUsers(ctx context.Context) error {
|
|||
|
||||
for {
|
||||
var res usersDeltaResponse
|
||||
err := dc.provider.api(ctx, "GET", apiURL, nil, &res)
|
||||
err := dc.provider.api(ctx, apiURL, &res)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -197,6 +197,9 @@ func (dc *deltaCollection) CurrentUserGroups() ([]*directory.Group, []*directory
|
|||
}
|
||||
groupLookup.addGroup(g.id, groupIDs, userIDs)
|
||||
}
|
||||
sort.Slice(groups, func(i, j int) bool {
|
||||
return groups[i].GetId() < groups[j].GetId()
|
||||
})
|
||||
|
||||
var users []*directory.User
|
||||
for _, u := range dc.users {
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
package github
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
|
@ -83,6 +84,47 @@ func New(options ...Option) *Provider {
|
|||
}
|
||||
}
|
||||
|
||||
// User returns the user record for the given id.
|
||||
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
|
||||
if p.cfg.serviceAccount == nil {
|
||||
return nil, fmt.Errorf("github: service account not defined")
|
||||
}
|
||||
|
||||
_, providerUserID := databroker.FromUserID(userID)
|
||||
du := &directory.User{
|
||||
Id: userID,
|
||||
}
|
||||
|
||||
au, err := p.getUser(ctx, providerUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
du.DisplayName = au.Name
|
||||
du.Email = au.Email
|
||||
|
||||
teamIDLookup := map[int]struct{}{}
|
||||
orgSlugs, err := p.listOrgs(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, orgSlug := range orgSlugs {
|
||||
teamIDs, err := p.listUserOrganizationTeams(ctx, userID, orgSlug)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, teamID := range teamIDs {
|
||||
teamIDLookup[teamID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
for teamID := range teamIDLookup {
|
||||
du.GroupIds = append(du.GroupIds, strconv.Itoa(teamID))
|
||||
}
|
||||
sort.Strings(du.GroupIds)
|
||||
|
||||
return du, nil
|
||||
}
|
||||
|
||||
// UserGroups gets the directory user groups for github.
|
||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
||||
if p.cfg.serviceAccount == nil {
|
||||
|
@ -230,6 +272,77 @@ func (p *Provider) getUser(ctx context.Context, userLogin string) (*apiUserObjec
|
|||
return &res, nil
|
||||
}
|
||||
|
||||
func (p *Provider) listUserOrganizationTeams(ctx context.Context, userSlug string, orgSlug string) ([]int, error) {
|
||||
// GitHub's Rest API doesn't have an easy way of querying this data, so we use the GraphQL API.
|
||||
|
||||
enc := func(obj interface{}) string {
|
||||
bs, _ := json.Marshal(obj)
|
||||
return string(bs)
|
||||
}
|
||||
const pageCount = 100
|
||||
|
||||
var teamIDs []int
|
||||
var cursor *string
|
||||
for {
|
||||
var res struct {
|
||||
Data struct {
|
||||
Organization struct {
|
||||
Teams struct {
|
||||
TotalCount int `json:"totalCount"`
|
||||
PageInfo struct {
|
||||
EndCursor string `json:"endCursor"`
|
||||
} `json:"pageInfo"`
|
||||
Edges []struct {
|
||||
Node struct {
|
||||
ID int `json:"id"`
|
||||
} `json:"node"`
|
||||
} `json:"edges"`
|
||||
} `json:"teams"`
|
||||
} `json:"organization"`
|
||||
} `json:"data"`
|
||||
}
|
||||
cursorStr := ""
|
||||
if cursor != nil {
|
||||
cursorStr = fmt.Sprintf(",%s", enc(*cursor))
|
||||
}
|
||||
q := fmt.Sprintf(`query {
|
||||
organization(login:%s) {
|
||||
teams(first:%s, userLogins:[%s] %s) {
|
||||
totalCount
|
||||
pageInfo {
|
||||
endCursor
|
||||
}
|
||||
edges {
|
||||
node {
|
||||
id
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`, enc(orgSlug), enc(pageCount), enc(userSlug), cursorStr)
|
||||
_, err := p.graphql(ctx, q, &res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(res.Data.Organization.Teams.Edges) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
for _, edge := range res.Data.Organization.Teams.Edges {
|
||||
teamIDs = append(teamIDs, edge.Node.ID)
|
||||
}
|
||||
|
||||
if len(teamIDs) >= res.Data.Organization.Teams.TotalCount {
|
||||
break
|
||||
}
|
||||
|
||||
cursor = &res.Data.Organization.Teams.PageInfo.EndCursor
|
||||
}
|
||||
|
||||
return teamIDs, nil
|
||||
}
|
||||
|
||||
func (p *Provider) api(ctx context.Context, apiURL string, out interface{}) (http.Header, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil)
|
||||
if err != nil {
|
||||
|
@ -257,6 +370,41 @@ func (p *Provider) api(ctx context.Context, apiURL string, out interface{}) (htt
|
|||
return res.Header, nil
|
||||
}
|
||||
|
||||
func (p *Provider) graphql(ctx context.Context, query string, out interface{}) (http.Header, error) {
|
||||
apiURL := p.cfg.url.ResolveReference(&url.URL{
|
||||
Path: "/graphql",
|
||||
}).String()
|
||||
|
||||
bs, _ := json.Marshal(struct {
|
||||
Query string `json:"query"`
|
||||
}{query})
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(bs))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("github: failed to create http request: %w", err)
|
||||
}
|
||||
req.SetBasicAuth(p.cfg.serviceAccount.Username, p.cfg.serviceAccount.PersonalAccessToken)
|
||||
|
||||
res, err := p.cfg.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("github: failed to make http request: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode/100 != 2 {
|
||||
return nil, fmt.Errorf("github: error from API: %s", res.Status)
|
||||
}
|
||||
|
||||
if out != nil {
|
||||
err := json.NewDecoder(res.Body).Decode(out)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("github: failed to decode json body: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return res.Header, nil
|
||||
}
|
||||
|
||||
func getNextLink(hdrs http.Header) string {
|
||||
for _, link := range linkheader.ParseMultiple(hdrs.Values("Link")) {
|
||||
if link.Rel == "next" {
|
||||
|
|
|
@ -29,6 +29,33 @@ func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
|
|||
next.ServeHTTP(w, r)
|
||||
})
|
||||
})
|
||||
r.Post("/graphql", func(w http.ResponseWriter, r *http.Request) {
|
||||
var body struct {
|
||||
Query string `json:"query"`
|
||||
}
|
||||
json.NewDecoder(r.Body).Decode(&body)
|
||||
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"data": M{
|
||||
"organization": M{
|
||||
"teams": M{
|
||||
"totalCount": 3,
|
||||
"edges": []M{
|
||||
{"node": M{
|
||||
"id": 1,
|
||||
}},
|
||||
{"node": M{
|
||||
"id": 2,
|
||||
}},
|
||||
{"node": M{
|
||||
"id": 3,
|
||||
}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
})
|
||||
r.Get("/user/orgs", func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode([]M{
|
||||
{"login": "org1"},
|
||||
|
@ -88,7 +115,34 @@ func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
|
|||
return r
|
||||
}
|
||||
|
||||
func Test(t *testing.T) {
|
||||
func TestProvider_User(t *testing.T) {
|
||||
var mockAPI http.Handler
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mockAPI.ServeHTTP(w, r)
|
||||
}))
|
||||
defer srv.Close()
|
||||
mockAPI = newMockAPI(t, srv)
|
||||
|
||||
p := New(
|
||||
WithURL(mustParseURL(srv.URL)),
|
||||
WithServiceAccount(&ServiceAccount{
|
||||
Username: "abc",
|
||||
PersonalAccessToken: "xyz",
|
||||
}),
|
||||
)
|
||||
du, err := p.User(context.Background(), "github/user1", "")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
testutil.AssertProtoJSONEqual(t, `{
|
||||
"id": "github/user1",
|
||||
"groupIds": ["1", "2", "3"],
|
||||
"displayName": "User 1",
|
||||
"email": "user1@example.com"
|
||||
}`, du)
|
||||
}
|
||||
|
||||
func TestProvider_UserGroups(t *testing.T) {
|
||||
var mockAPI http.Handler
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mockAPI.ServeHTTP(w, r)
|
||||
|
|
|
@ -83,6 +83,32 @@ func New(options ...Option) *Provider {
|
|||
}
|
||||
}
|
||||
|
||||
// User returns the user record for the given id.
|
||||
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
|
||||
_, providerUserID := databroker.FromUserID(userID)
|
||||
du := &directory.User{
|
||||
Id: userID,
|
||||
}
|
||||
|
||||
au, err := p.getUser(ctx, providerUserID, accessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
du.DisplayName = au.Name
|
||||
du.Email = au.Email
|
||||
|
||||
groups, err := p.listGroups(ctx, accessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, g := range groups {
|
||||
du.GroupIds = append(du.GroupIds, g.Id)
|
||||
}
|
||||
sort.Strings(du.GroupIds)
|
||||
|
||||
return du, nil
|
||||
}
|
||||
|
||||
// UserGroups gets the directory user groups for gitlab.
|
||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
||||
if p.cfg.serviceAccount == nil {
|
||||
|
@ -91,7 +117,7 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
|
|||
|
||||
p.log.Info().Msg("getting user groups")
|
||||
|
||||
groups, err := p.listGroups(ctx)
|
||||
groups, err := p.listGroups(ctx, "")
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -129,8 +155,20 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
|
|||
return groups, users, nil
|
||||
}
|
||||
|
||||
func (p *Provider) getUser(ctx context.Context, userID string, accessToken string) (*apiUserObject, error) {
|
||||
apiURL := p.cfg.url.ResolveReference(&url.URL{
|
||||
Path: fmt.Sprintf("/api/v4/users/%s", userID),
|
||||
}).String()
|
||||
var result apiUserObject
|
||||
_, err := p.api(ctx, accessToken, apiURL, &result)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab: error querying user: %w", err)
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// listGroups returns a map, with key is group ID, element is group name.
|
||||
func (p *Provider) listGroups(ctx context.Context) ([]*directory.Group, error) {
|
||||
func (p *Provider) listGroups(ctx context.Context, accessToken string) ([]*directory.Group, error) {
|
||||
nextURL := p.cfg.url.ResolveReference(&url.URL{
|
||||
Path: "/api/v4/groups",
|
||||
}).String()
|
||||
|
@ -140,7 +178,7 @@ func (p *Provider) listGroups(ctx context.Context) ([]*directory.Group, error) {
|
|||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
hdrs, err := p.apiGet(ctx, nextURL, &result)
|
||||
hdrs, err := p.api(ctx, accessToken, nextURL, &result)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab: error querying groups: %w", err)
|
||||
}
|
||||
|
@ -163,7 +201,7 @@ func (p *Provider) listGroupMembers(ctx context.Context, groupID string) (users
|
|||
}).String()
|
||||
for nextURL != "" {
|
||||
var result []apiUserObject
|
||||
hdrs, err := p.apiGet(ctx, nextURL, &result)
|
||||
hdrs, err := p.api(ctx, "", nextURL, &result)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab: error querying group members: %w", err)
|
||||
}
|
||||
|
@ -174,14 +212,18 @@ func (p *Provider) listGroupMembers(ctx context.Context, groupID string) (users
|
|||
return users, nil
|
||||
}
|
||||
|
||||
func (p *Provider) apiGet(ctx context.Context, uri string, out interface{}) (http.Header, error) {
|
||||
func (p *Provider) api(ctx context.Context, accessToken string, uri string, out interface{}) (http.Header, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", uri, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab: failed to create HTTP request: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if accessToken != "" {
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))
|
||||
} else {
|
||||
req.Header.Set("PRIVATE-TOKEN", p.cfg.serviceAccount.PrivateToken)
|
||||
}
|
||||
|
||||
res, err := p.cfg.httpClient.Do(req)
|
||||
if err != nil {
|
||||
|
@ -190,7 +232,7 @@ func (p *Provider) apiGet(ctx context.Context, uri string, out interface{}) (htt
|
|||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode/100 != 2 {
|
||||
return nil, fmt.Errorf("gitlab: error query api status_code=%d: %s", res.StatusCode, res.Status)
|
||||
return nil, fmt.Errorf("gitlab: error querying api url=%s status_code=%d: %s", uri, res.StatusCode, res.Status)
|
||||
}
|
||||
|
||||
err = json.NewDecoder(res.Body).Decode(out)
|
||||
|
|
|
@ -5,13 +5,16 @@ import (
|
|||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"golang.org/x/oauth2/google"
|
||||
admin "google.golang.org/api/admin/directory/v1"
|
||||
"google.golang.org/api/googleapi"
|
||||
"google.golang.org/api/option"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
|
@ -19,11 +22,15 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
)
|
||||
|
||||
// Name is the provider name.
|
||||
const Name = "google"
|
||||
const (
|
||||
// Name is the provider name.
|
||||
Name = "google"
|
||||
|
||||
currentAccountCustomerID = "my_customer"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultProviderURL = "https://accounts.google.com"
|
||||
defaultProviderURL = "https://www.googleapis.com/admin/directory/v1/"
|
||||
)
|
||||
|
||||
type config struct {
|
||||
|
@ -78,6 +85,51 @@ func New(options ...Option) *Provider {
|
|||
}
|
||||
}
|
||||
|
||||
// User returns the user record for the given id.
|
||||
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
|
||||
apiClient, err := p.getAPIClient(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("google: error getting API client: %w", err)
|
||||
}
|
||||
|
||||
_, providerUserID := databroker.FromUserID(userID)
|
||||
du := &directory.User{
|
||||
Id: userID,
|
||||
}
|
||||
|
||||
au, err := apiClient.Users.Get(providerUserID).
|
||||
Context(ctx).
|
||||
Do()
|
||||
if isAccessDenied(err) {
|
||||
// ignore forbidden errors as a user may login from another gsuite domain
|
||||
return du, nil
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("google: error getting user: %w", err)
|
||||
} else {
|
||||
if au.Name != nil {
|
||||
du.DisplayName = au.Name.FullName
|
||||
}
|
||||
du.Email = au.PrimaryEmail
|
||||
}
|
||||
|
||||
err = apiClient.Groups.List().
|
||||
Context(ctx).
|
||||
UserKey(providerUserID).
|
||||
Pages(ctx, func(res *admin.Groups) error {
|
||||
for _, g := range res.Groups {
|
||||
du.GroupIds = append(du.GroupIds, g.Id)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("google: error getting groups for user: %w", err)
|
||||
}
|
||||
|
||||
sort.Strings(du.GroupIds)
|
||||
|
||||
return du, nil
|
||||
}
|
||||
|
||||
// UserGroups returns a slice of group names a given user is in
|
||||
// NOTE: groups via Directory API is limited to 1 QPS!
|
||||
// https://developers.google.com/admin-sdk/directory/v1/reference/groups/list
|
||||
|
@ -91,7 +143,7 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
|
|||
var groups []*directory.Group
|
||||
err = apiClient.Groups.List().
|
||||
Context(ctx).
|
||||
Customer("my_customer").
|
||||
Customer(currentAccountCustomerID).
|
||||
Pages(ctx, func(res *admin.Groups) error {
|
||||
for _, g := range res.Groups {
|
||||
// Skip group without member.
|
||||
|
@ -113,7 +165,7 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
|
|||
userLookup := map[string]apiUserObject{}
|
||||
err = apiClient.Users.List().
|
||||
Context(ctx).
|
||||
Customer("my_customer").
|
||||
Customer(currentAccountCustomerID).
|
||||
Pages(ctx, func(res *admin.Users) error {
|
||||
for _, u := range res.Users {
|
||||
userLookup[u.Id] = apiUserObject{
|
||||
|
@ -188,7 +240,7 @@ func (p *Provider) getAPIClient(ctx context.Context) (*admin.Service, error) {
|
|||
|
||||
ts := config.TokenSource(ctx)
|
||||
|
||||
p.apiClient, err = admin.NewService(ctx, option.WithTokenSource(ts))
|
||||
p.apiClient, err = admin.NewService(ctx, option.WithTokenSource(ts), option.WithEndpoint(p.cfg.url))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("google: failed creating admin service %w", err)
|
||||
}
|
||||
|
@ -243,3 +295,16 @@ type apiUserObject struct {
|
|||
DisplayName string
|
||||
Email string
|
||||
}
|
||||
|
||||
func isAccessDenied(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
gerr := new(googleapi.Error)
|
||||
if errors.As(err, &gerr) {
|
||||
return gerr.Code == http.StatusForbidden
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
|
140
internal/directory/google/google_test.go
Normal file
140
internal/directory/google/google_test.go
Normal file
|
@ -0,0 +1,140 @@
|
|||
package google
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/go-chi/chi/middleware"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var privateKey = `
|
||||
-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIG4wIBAAKCAYEAnetGqPqS6dqYnV9S5S8gL34t7RRUMsf4prxIR+1PMv+bEqVH
|
||||
pzXBbzgGKIbQbu+njoRzhQ95RbcEzZiVFivNggijkFoUNkFIjy42O/xXdTRTX3/u
|
||||
4pxu1ctccqYnZwnry6E8ekQTVX7kgmVqgzIrY1Y6K3PVlhkgGDK/TStu+RIPoA73
|
||||
vJUpJTFTw+6tgUSBmCkzctQsGFGUiGBRpqlxogEkImJYrcMUJkWhTopLy79+3OB5
|
||||
eGGKWwA6P2ZhBB4RKHBWyWipsYxr889QNG6P3o+Be6lIzvcdiIIBYK8qWLmO45hR
|
||||
xUGWSRK8sveAJO54t+wGE0dSPVKfoS4oNqGYdGhBzoTwsfZRvyvcWeQob4DNCqQa
|
||||
n41XAuYGOG3X1PexdwGSwwqrq2tuG9d2AJ2NjG8nC9hjuuKDfBGTigTwzrwkrn+F
|
||||
3o94NoglQgsXfZeWoBXR5HDaDTdqexSRK0OpSPbvzkn8QDUymdaw7nRS7kU2O7fa
|
||||
W8kxiV8AVt2v/jYjAgMBAAECggGAS+6NE0Mo0Pki2Mi0+y4ls7BgNNbJhYFRthpi
|
||||
RvN8WXE+B0EhquzWDbxKecIZBr6FOqnFQf2muja+QH1VckutjRDKVOZ7QXsygGYf
|
||||
/cff5aM7U3gYTS4avQIDeb0axRioIElu4vtIsJtLFMfe5yaAZktXvPz9fiamn/wG
|
||||
r/xqZ6ifir6nsC2okxGczWE+XCGsjpWA/321lhvj548os5JV6SfTUBUpvqNGVQC2
|
||||
ByXIPDffsCTfQ1rjQ85gM4vuqiQqKn/KXRrMR1kRIOrglMJ6dllitsadkfWdPkVg
|
||||
fHjM1KAnw/uob5kFhvqeEll9sESfCXttrb4XneKAsEck958ChSpU4csaBAfLFVYP
|
||||
5xyIfoaQ+/CUjWI0u1lQbg6jZO59rYfdd5OlH+MyhHybuHR1a0G1izNfuG9WPOWI
|
||||
aprNayH2Wxy9/ZvlrE5yTAeW9tof28hO6O7wBNOcJTrzztsN+V8pSAo0IE2r4D83
|
||||
h978LneAwhC/8mVvzhd/y2t99vcBAoHBAMumCoHmHRAsYBApg2SHxPCTDi4fS1hC
|
||||
IbcuzyvJLv214bHOsdgg2I72a1Q+bbrZAHgENVgSP9Sx1k8aXlvnGOEbpL8w3jRL
|
||||
G4/qXzGrMBp3LCzupi3dI3yrrIMWb2C0goyHeAejzrfaM+uDYTGW4iqhA39zBj4o
|
||||
zoydz3v0i8Yag7Df9MIwr34WD9Ng0oXh8XRCAYJmS1e43jnM+XcFdSfGVhKn9h1B
|
||||
Cbv/hqUSv6baNloWLlPBffLII5bx633MMwKBwQDGg87fKEPz73pbp1kAOuXDpUnm
|
||||
+OUFuf6Uqr/OqCorh8SOwxb2q3ScqyCWGVLOaVBcSV+zxIdcQzX2kjf1mj7PcQ6Q
|
||||
2xfDIS1/xT+RiX9LO0kbkVDYcwcGeKVtmUwWyjauo96OB2r+SchTsNJpYOT3a/7r
|
||||
JUKdbHFwsFwAx5q7r9mOh0BOybuXM6N7lUDBf4SgrhjnKRh1pME3R0JbJj9m8tZg
|
||||
SsWlHcj04yAXJ7NGemiiYgeDZ4unsAfx7/sS/lECgcEAsL0Yj2XTQU8Ry9ULaDsA
|
||||
az1k6BhWvnEea6lfOQPwGVY5WqQk6oqPB3vK6CEKAEgGRSJ53UZxSTlR4fLjg2UL
|
||||
zYm9MATMQ5wPfpYMKcIFDGLy3sf7RwCNpMwk+tuEq+vdBPMo85BxflQMDVBHEM9+
|
||||
1zpIG9sKxvWJVLY89LnmeHZYZi/nboTsOUQSVgPIkVLmx1vljXMT3jzd+FHxCx+c
|
||||
bnmOB8DnMrpYJWV9SFP+KmNlGkf3ys65bPPPF1g7ZUDLAoHAaKHqtQa1Imr0NED1
|
||||
kUB6AHArjrlbhXQuck+5f4R1jbIm8RR1Exj2AunT6Cl60t8Bg1MNRWRt8DxgwhD5
|
||||
u9NMDezKP6GrWacwIytlQSGW3aFm/EfQs/WVG10V3LmzOEPnJI+s63GPfG6JT0tg
|
||||
7DgtFxhuKaTfAri45iueoq6SqSCb7Brv01dTL/QA1E+r7RF4Z3S8HYM0qDVpvegq
|
||||
Wn7DZlDSm7htioUzeZgJPwsm3BwC8Kv4x9MY8g6/cU8LKEyxAoHAWCaDpLIuQ51r
|
||||
PeL+u/1cfNdi6OOrtZ6S95tu3Vv+mYzpCPnOpgPHFp3l+RGmLg56t7uvHFaFxOvB
|
||||
EjPm4bVhnPSA7pl7ZHQXhinG9+4UgcejoCAJzfg05BI1tMbwFZ+C0tG/PNzBlaX+
|
||||
IwkGO8VP/54N6wL1UqfZ8AKJFZW8G7W7KVkjqye1FS4oeDlJ197t/X+PMn5sFAc7
|
||||
UVsDaSelBqpsfmetXSH8KC3XkbgCtHvgAnJDkGkp84VmJvMr5ukv
|
||||
-----END RSA PRIVATE KEY-----
|
||||
`
|
||||
|
||||
type M = map[string]interface{}
|
||||
|
||||
func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
|
||||
r := chi.NewRouter()
|
||||
r.Use(middleware.Logger)
|
||||
r.Post("/token", func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"access_token": "ACCESSTOKEN",
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": "REFRESHTOKEN",
|
||||
})
|
||||
})
|
||||
r.Route("/groups", func(r chi.Router) {
|
||||
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Query().Get("userKey") {
|
||||
case "user1":
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"kind": "admin#directory#groups",
|
||||
"groups": []M{
|
||||
{"id": "group1"},
|
||||
{"id": "group2"},
|
||||
},
|
||||
})
|
||||
default:
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
})
|
||||
r.Route("/users", func(r chi.Router) {
|
||||
r.Get("/{user_id}", func(w http.ResponseWriter, r *http.Request) {
|
||||
switch chi.URLParam(r, "user_id") {
|
||||
case "user1":
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"kind": "admin#directory#user",
|
||||
"id": "1",
|
||||
"name": M{
|
||||
"fullName": "User 1",
|
||||
},
|
||||
"primaryEmail": "user1@example.com",
|
||||
})
|
||||
case "user2":
|
||||
http.Error(w, "forbidden", http.StatusForbidden)
|
||||
default:
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func TestProvider_User(t *testing.T) {
|
||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*30)
|
||||
defer clearTimeout()
|
||||
|
||||
var mockAPI http.Handler
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mockAPI.ServeHTTP(w, r)
|
||||
}))
|
||||
defer srv.Close()
|
||||
mockAPI = newMockAPI(t, srv)
|
||||
|
||||
p := New(WithServiceAccount(&ServiceAccount{
|
||||
Type: "service_account",
|
||||
PrivateKey: privateKey,
|
||||
TokenURL: srv.URL + "/token",
|
||||
}), WithURL(srv.URL))
|
||||
|
||||
du, err := p.User(ctx, "user1", "")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
assert.Equal(t, "user1", du.Id)
|
||||
assert.Equal(t, "user1@example.com", du.Email)
|
||||
assert.Equal(t, "User 1", du.DisplayName)
|
||||
assert.Equal(t, []string{"group1", "group2"}, du.GroupIds)
|
||||
|
||||
du, err = p.User(ctx, "user2", "")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
assert.Equal(t, "user2", du.Id)
|
||||
}
|
|
@ -108,6 +108,36 @@ func New(options ...Option) *Provider {
|
|||
}
|
||||
}
|
||||
|
||||
// User returns the user record for the given id.
|
||||
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
|
||||
if p.cfg.serviceAccount == nil {
|
||||
return nil, ErrServiceAccountNotDefined
|
||||
}
|
||||
|
||||
_, providerUserID := databroker.FromUserID(userID)
|
||||
du := &directory.User{
|
||||
Id: userID,
|
||||
}
|
||||
|
||||
au, err := p.getUser(ctx, providerUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
du.DisplayName = au.getDisplayName()
|
||||
du.Email = au.Profile.Email
|
||||
|
||||
groups, err := p.listUserGroups(ctx, providerUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, g := range groups {
|
||||
du.GroupIds = append(du.GroupIds, g.ID)
|
||||
}
|
||||
sort.Strings(du.GroupIds)
|
||||
|
||||
return du, nil
|
||||
}
|
||||
|
||||
// UserGroups fetches the groups of which the user is a member
|
||||
// https://developer.okta.com/docs/reference/api/users/#get-user-s-groups
|
||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
||||
|
@ -159,7 +189,7 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
|
|||
users = append(users, &directory.User{
|
||||
Id: databroker.GetUserID(Name, u.ID),
|
||||
GroupIds: groups,
|
||||
DisplayName: u.Profile.FirstName + " " + u.Profile.LastName,
|
||||
DisplayName: u.getDisplayName(),
|
||||
Email: u.Profile.Email,
|
||||
})
|
||||
}
|
||||
|
@ -183,14 +213,7 @@ func (p *Provider) getGroups(ctx context.Context) ([]*directory.Group, error) {
|
|||
|
||||
groupURL := p.cfg.providerURL.ResolveReference(u).String()
|
||||
for groupURL != "" {
|
||||
var out []struct {
|
||||
ID string `json:"id"`
|
||||
Profile struct {
|
||||
Name string `json:"name"`
|
||||
} `json:"profile"`
|
||||
LastUpdated string `json:"lastUpdated"`
|
||||
LastMembershipUpdated string `json:"lastMembershipUpdated"`
|
||||
}
|
||||
var out []apiGroupObject
|
||||
hdrs, err := p.apiGet(ctx, groupURL, &out)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("okta: error querying for groups: %w", err)
|
||||
|
@ -239,6 +262,36 @@ func (p *Provider) getGroupMembers(ctx context.Context, groupID string) (users [
|
|||
return users, nil
|
||||
}
|
||||
|
||||
func (p *Provider) getUser(ctx context.Context, userID string) (*apiUserObject, error) {
|
||||
apiURL := p.cfg.providerURL.ResolveReference(&url.URL{
|
||||
Path: fmt.Sprintf("/api/v1/users/%s", userID),
|
||||
}).String()
|
||||
|
||||
var out apiUserObject
|
||||
_, err := p.apiGet(ctx, apiURL, &out)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("okta: error querying for user: %w", err)
|
||||
}
|
||||
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func (p *Provider) listUserGroups(ctx context.Context, userID string) (groups []apiGroupObject, err error) {
|
||||
apiURL := p.cfg.providerURL.ResolveReference(&url.URL{
|
||||
Path: fmt.Sprintf("/api/v1/users/%s/groups", userID),
|
||||
}).String()
|
||||
for apiURL != "" {
|
||||
var out []apiGroupObject
|
||||
hdrs, err := p.apiGet(ctx, apiURL, &out)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("okta: error querying for user groups: %w", err)
|
||||
}
|
||||
groups = append(groups, out...)
|
||||
apiURL = getNextLink(hdrs)
|
||||
}
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
func (p *Provider) apiGet(ctx context.Context, uri string, out interface{}) (http.Header, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", uri, nil)
|
||||
if err != nil {
|
||||
|
@ -334,11 +387,25 @@ func (err *APIError) Error() string {
|
|||
return fmt.Sprintf("okta: error querying API, status_code=%d: %s", err.HTTPStatusCode, err.Body)
|
||||
}
|
||||
|
||||
type apiUserObject struct {
|
||||
type (
|
||||
apiGroupObject struct {
|
||||
ID string `json:"id"`
|
||||
Profile struct {
|
||||
Name string `json:"name"`
|
||||
} `json:"profile"`
|
||||
LastUpdated string `json:"lastUpdated"`
|
||||
LastMembershipUpdated string `json:"lastMembershipUpdated"`
|
||||
}
|
||||
apiUserObject struct {
|
||||
ID string `json:"id"`
|
||||
Profile struct {
|
||||
FirstName string `json:"firstName"`
|
||||
LastName string `json:"lastName"`
|
||||
Email string `json:"email"`
|
||||
} `json:"profile"`
|
||||
}
|
||||
)
|
||||
|
||||
func (obj *apiUserObject) getDisplayName() string {
|
||||
return obj.Profile.FirstName + " " + obj.Profile.LastName
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
"github.com/tomnomnom/linkheader"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
)
|
||||
|
||||
|
@ -43,7 +44,9 @@ func newMockOkta(srv *httptest.Server, userEmailToGroups map[string][]string) ht
|
|||
next.ServeHTTP(w, r)
|
||||
})
|
||||
})
|
||||
r.Get("/api/v1/groups", func(w http.ResponseWriter, r *http.Request) {
|
||||
r.Route("/api/v1", func(r chi.Router) {
|
||||
r.Route("/groups", func(r chi.Router) {
|
||||
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
lastUpdated := strings.Contains(r.URL.Query().Get("filter"), "lastUpdated ")
|
||||
var groups []string
|
||||
for group := range getAllGroups() {
|
||||
|
@ -86,7 +89,7 @@ func newMockOkta(srv *httptest.Server, userEmailToGroups map[string][]string) ht
|
|||
|
||||
_ = json.NewEncoder(w).Encode(result)
|
||||
})
|
||||
r.Get("/api/v1/groups/{group}/users", func(w http.ResponseWriter, r *http.Request) {
|
||||
r.Get("/{group}/users", func(w http.ResponseWriter, r *http.Request) {
|
||||
group := chi.URLParam(r, "group")
|
||||
|
||||
if _, ok := getAllGroups()[group]; !ok {
|
||||
|
@ -122,9 +125,61 @@ func newMockOkta(srv *httptest.Server, userEmailToGroups map[string][]string) ht
|
|||
|
||||
_ = json.NewEncoder(w).Encode(result)
|
||||
})
|
||||
})
|
||||
r.Route("/users", func(r chi.Router) {
|
||||
r.Get("/{user_id}/groups", func(w http.ResponseWriter, r *http.Request) {
|
||||
var groups []apiGroupObject
|
||||
for _, nm := range userEmailToGroups[chi.URLParam(r, "user_id")] {
|
||||
obj := apiGroupObject{
|
||||
ID: nm,
|
||||
}
|
||||
obj.Profile.Name = nm
|
||||
groups = append(groups, obj)
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(groups)
|
||||
})
|
||||
r.Get("/{user_id}", func(w http.ResponseWriter, r *http.Request) {
|
||||
user := apiUserObject{
|
||||
ID: chi.URLParam(r, "user_id"),
|
||||
}
|
||||
user.Profile.Email = chi.URLParam(r, "user_id")
|
||||
user.Profile.FirstName = "first"
|
||||
user.Profile.LastName = "last"
|
||||
_ = json.NewEncoder(w).Encode(user)
|
||||
})
|
||||
})
|
||||
})
|
||||
return r
|
||||
}
|
||||
|
||||
func TestProvider_User(t *testing.T) {
|
||||
var mockOkta http.Handler
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mockOkta.ServeHTTP(w, r)
|
||||
}))
|
||||
defer srv.Close()
|
||||
mockOkta = newMockOkta(srv, map[string][]string{
|
||||
"a@example.com": {"user", "admin"},
|
||||
"b@example.com": {"user", "test"},
|
||||
"c@example.com": {"user"},
|
||||
})
|
||||
|
||||
p := New(
|
||||
WithServiceAccount(&ServiceAccount{APIKey: "APITOKEN"}),
|
||||
WithProviderURL(mustParseURL(srv.URL)),
|
||||
)
|
||||
user, err := p.User(context.Background(), "okta/a@example.com", "")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
testutil.AssertProtoJSONEqual(t, `{
|
||||
"id": "okta/a@example.com",
|
||||
"groupIds": ["admin","user"],
|
||||
"displayName": "first last",
|
||||
"email": "a@example.com"
|
||||
}`, user)
|
||||
}
|
||||
|
||||
func TestProvider_UserGroups(t *testing.T) {
|
||||
var mockOkta http.Handler
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
|
|
@ -94,6 +94,32 @@ func New(options ...Option) *Provider {
|
|||
}
|
||||
}
|
||||
|
||||
// User returns the user record for the given id.
|
||||
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
|
||||
if p.cfg.serviceAccount == nil {
|
||||
return nil, fmt.Errorf("onelogin: service account not defined")
|
||||
}
|
||||
_, providerUserID := databroker.FromUserID(userID)
|
||||
du := &directory.User{
|
||||
Id: userID,
|
||||
}
|
||||
|
||||
token, err := p.getToken(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
au, err := p.getUser(ctx, token.AccessToken, providerUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
du.DisplayName = au.getDisplayName()
|
||||
du.Email = au.Email
|
||||
du.GroupIds = []string{strconv.Itoa(au.GroupID)}
|
||||
|
||||
return du, nil
|
||||
}
|
||||
|
||||
// UserGroups gets the directory user groups for onelogin.
|
||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
||||
if p.cfg.serviceAccount == nil {
|
||||
|
@ -107,12 +133,12 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
|
|||
return nil, nil, err
|
||||
}
|
||||
|
||||
groups, err := p.listGroups(ctx, token)
|
||||
groups, err := p.listGroups(ctx, token.AccessToken)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
apiUsers, err := p.getUsers(ctx, token)
|
||||
apiUsers, err := p.listUsers(ctx, token.AccessToken)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -133,7 +159,7 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
|
|||
return groups, users, nil
|
||||
}
|
||||
|
||||
func (p *Provider) listGroups(ctx context.Context, token *oauth2.Token) ([]*directory.Group, error) {
|
||||
func (p *Provider) listGroups(ctx context.Context, accessToken string) ([]*directory.Group, error) {
|
||||
var groups []*directory.Group
|
||||
apiURL := p.cfg.apiURL.ResolveReference(&url.URL{
|
||||
Path: "/api/1/groups",
|
||||
|
@ -144,9 +170,9 @@ func (p *Provider) listGroups(ctx context.Context, token *oauth2.Token) ([]*dire
|
|||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
nextLink, err := p.apiGet(ctx, token, apiURL, &result)
|
||||
nextLink, err := p.apiGet(ctx, accessToken, apiURL, &result)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("onelogin: error querying group api: %w", err)
|
||||
return nil, fmt.Errorf("onelogin: listing groups: %w", err)
|
||||
}
|
||||
|
||||
for _, r := range result {
|
||||
|
@ -161,7 +187,24 @@ func (p *Provider) listGroups(ctx context.Context, token *oauth2.Token) ([]*dire
|
|||
return groups, nil
|
||||
}
|
||||
|
||||
func (p *Provider) getUsers(ctx context.Context, token *oauth2.Token) ([]apiUserObject, error) {
|
||||
func (p *Provider) getUser(ctx context.Context, accessToken string, userID string) (*apiUserObject, error) {
|
||||
apiURL := p.cfg.apiURL.ResolveReference(&url.URL{
|
||||
Path: fmt.Sprintf("/api/1/users/%s", userID),
|
||||
}).String()
|
||||
|
||||
var out []apiUserObject
|
||||
_, err := p.apiGet(ctx, accessToken, apiURL, &out)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("onelogin: error getting user: %w", err)
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil, fmt.Errorf("onelogin: user not found")
|
||||
}
|
||||
|
||||
return &out[0], nil
|
||||
}
|
||||
|
||||
func (p *Provider) listUsers(ctx context.Context, accessToken string) ([]apiUserObject, error) {
|
||||
var users []apiUserObject
|
||||
|
||||
apiURL := p.cfg.apiURL.ResolveReference(&url.URL{
|
||||
|
@ -170,9 +213,9 @@ func (p *Provider) getUsers(ctx context.Context, token *oauth2.Token) ([]apiUser
|
|||
}).String()
|
||||
for apiURL != "" {
|
||||
var result []apiUserObject
|
||||
nextLink, err := p.apiGet(ctx, token, apiURL, &result)
|
||||
nextLink, err := p.apiGet(ctx, accessToken, apiURL, &result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("onelogin: listing users: %w", err)
|
||||
}
|
||||
|
||||
users = append(users, result...)
|
||||
|
@ -182,12 +225,12 @@ func (p *Provider) getUsers(ctx context.Context, token *oauth2.Token) ([]apiUser
|
|||
return users, nil
|
||||
}
|
||||
|
||||
func (p *Provider) apiGet(ctx context.Context, token *oauth2.Token, uri string, out interface{}) (nextLink string, err error) {
|
||||
func (p *Provider) apiGet(ctx context.Context, accessToken string, uri string, out interface{}) (nextLink string, err error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", uri, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("bearer:%s", token.AccessToken))
|
||||
req.Header.Set("Authorization", fmt.Sprintf("bearer:%s", accessToken))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
res, err := p.cfg.httpClient.Do(req)
|
||||
|
@ -307,3 +350,7 @@ type apiUserObject struct {
|
|||
FirstName string `json:"firstname"`
|
||||
LastName string `json:"lastname"`
|
||||
}
|
||||
|
||||
func (obj *apiUserObject) getDisplayName() string {
|
||||
return obj.FirstName + " " + obj.LastName
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"net/http/httptest"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -102,6 +103,28 @@ func newMockAPI(srv *httptest.Server, userIDToGroupName map[int]string) http.Han
|
|||
|
||||
_ = json.NewEncoder(w).Encode(result)
|
||||
})
|
||||
r.Get("/users/{user_id}", func(w http.ResponseWriter, r *http.Request) {
|
||||
userIDToGroupID := map[int]int{}
|
||||
for userID, groupName := range userIDToGroupName {
|
||||
for id, n := range allGroups {
|
||||
if groupName == n {
|
||||
userIDToGroupID[userID] = id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
userID, _ := strconv.Atoi(chi.URLParam(r, "user_id"))
|
||||
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"data": []M{{
|
||||
"id": userID,
|
||||
"email": userIDToGroupName[userID] + "@example.com",
|
||||
"group_id": userIDToGroupID[userID],
|
||||
"firstname": "User",
|
||||
"lastname": fmt.Sprint(userID),
|
||||
}},
|
||||
})
|
||||
})
|
||||
r.Get("/users", func(w http.ResponseWriter, r *http.Request) {
|
||||
userIDToGroupID := map[int]int{}
|
||||
for userID, groupName := range userIDToGroupName {
|
||||
|
@ -130,6 +153,37 @@ func newMockAPI(srv *httptest.Server, userIDToGroupName map[int]string) http.Han
|
|||
return r
|
||||
}
|
||||
|
||||
func TestProvider_User(t *testing.T) {
|
||||
var mockAPI http.Handler
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mockAPI.ServeHTTP(w, r)
|
||||
}))
|
||||
defer srv.Close()
|
||||
mockAPI = newMockAPI(srv, map[int]string{
|
||||
111: "admin",
|
||||
222: "test",
|
||||
333: "user",
|
||||
})
|
||||
|
||||
p := New(
|
||||
WithServiceAccount(&ServiceAccount{
|
||||
ClientID: "CLIENTID",
|
||||
ClientSecret: "CLIENTSECRET",
|
||||
}),
|
||||
WithURL(mustParseURL(srv.URL)),
|
||||
)
|
||||
user, err := p.User(context.Background(), "onelogin/111", "ACCESSTOKEN")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
testutil.AssertProtoJSONEqual(t, `{
|
||||
"id": "onelogin/111",
|
||||
"groupIds": ["0"],
|
||||
"displayName": "User 111",
|
||||
"email": "admin@example.com"
|
||||
}`, user)
|
||||
}
|
||||
|
||||
func TestProvider_UserGroups(t *testing.T) {
|
||||
var mockAPI http.Handler
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
|
|
@ -28,8 +28,12 @@ type User = directory.User
|
|||
// Options are the options specific to the provider.
|
||||
type Options = directory.Options
|
||||
|
||||
// RegisterDirectoryServiceServer registers the directory gRPC service.
|
||||
var RegisterDirectoryServiceServer = directory.RegisterDirectoryServiceServer
|
||||
|
||||
// A Provider provides user group directory information.
|
||||
type Provider interface {
|
||||
User(ctx context.Context, userID, accessToken string) (*User, error)
|
||||
UserGroups(ctx context.Context) ([]*Group, []*User, error)
|
||||
}
|
||||
|
||||
|
@ -55,7 +59,7 @@ func GetProvider(options Options) (provider Provider) {
|
|||
|
||||
switch options.Provider {
|
||||
case auth0.Name:
|
||||
serviceAccount, err := auth0.ParseServiceAccount(options.ServiceAccount)
|
||||
serviceAccount, err := auth0.ParseServiceAccount(options)
|
||||
if err == nil {
|
||||
return auth0.New(
|
||||
auth0.WithDomain(options.ProviderURL),
|
||||
|
@ -139,6 +143,10 @@ func GetProvider(options Options) (provider Provider) {
|
|||
|
||||
type nullProvider struct{}
|
||||
|
||||
func (nullProvider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (nullProvider) UserGroups(ctx context.Context) ([]*Group, []*User, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
|
|
@ -1,11 +1,22 @@
|
|||
// Package databroker contains databroker protobuf definitions.
|
||||
package databroker
|
||||
|
||||
import "strings"
|
||||
|
||||
// GetUserID gets the databroker user id from a provider user id.
|
||||
func GetUserID(provider, providerUserID string) string {
|
||||
return provider + "/" + providerUserID
|
||||
}
|
||||
|
||||
// FromUserID gets the provider and provider user id from a databroker user id.
|
||||
func FromUserID(userID string) (provider, providerUserID string) {
|
||||
ps := strings.SplitN(userID, "/", 2)
|
||||
if len(ps) < 2 {
|
||||
return "", userID
|
||||
}
|
||||
return ps[0], ps[1]
|
||||
}
|
||||
|
||||
// ApplyOffsetAndLimit applies the offset and limit to the list of records.
|
||||
func ApplyOffsetAndLimit(all []*Record, offset, limit int) (records []*Record, totalCount int) {
|
||||
records = all
|
||||
|
|
|
@ -7,7 +7,12 @@
|
|||
package directory
|
||||
|
||||
import (
|
||||
context "context"
|
||||
proto "github.com/golang/protobuf/proto"
|
||||
empty "github.com/golang/protobuf/ptypes/empty"
|
||||
grpc "google.golang.org/grpc"
|
||||
codes "google.golang.org/grpc/codes"
|
||||
status "google.golang.org/grpc/status"
|
||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||
reflect "reflect"
|
||||
|
@ -175,29 +180,97 @@ func (x *Group) GetEmail() string {
|
|||
return ""
|
||||
}
|
||||
|
||||
type RefreshUserRequest struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
UserId string `protobuf:"bytes,1,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"`
|
||||
AccessToken string `protobuf:"bytes,2,opt,name=access_token,json=accessToken,proto3" json:"access_token,omitempty"`
|
||||
}
|
||||
|
||||
func (x *RefreshUserRequest) Reset() {
|
||||
*x = RefreshUserRequest{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_directory_proto_msgTypes[2]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *RefreshUserRequest) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*RefreshUserRequest) ProtoMessage() {}
|
||||
|
||||
func (x *RefreshUserRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_directory_proto_msgTypes[2]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use RefreshUserRequest.ProtoReflect.Descriptor instead.
|
||||
func (*RefreshUserRequest) Descriptor() ([]byte, []int) {
|
||||
return file_directory_proto_rawDescGZIP(), []int{2}
|
||||
}
|
||||
|
||||
func (x *RefreshUserRequest) GetUserId() string {
|
||||
if x != nil {
|
||||
return x.UserId
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *RefreshUserRequest) GetAccessToken() string {
|
||||
if x != nil {
|
||||
return x.AccessToken
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
var File_directory_proto protoreflect.FileDescriptor
|
||||
|
||||
var file_directory_proto_rawDesc = []byte{
|
||||
0x0a, 0x0f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74,
|
||||
0x6f, 0x12, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x22, 0x86, 0x01, 0x0a,
|
||||
0x04, 0x55, 0x73, 0x65, 0x72, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e,
|
||||
0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12,
|
||||
0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12,
|
||||
0x1b, 0x0a, 0x09, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x03, 0x20, 0x03,
|
||||
0x28, 0x09, 0x52, 0x08, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x49, 0x64, 0x73, 0x12, 0x21, 0x0a, 0x0c,
|
||||
0x64, 0x69, 0x73, 0x70, 0x6c, 0x61, 0x79, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x04, 0x20, 0x01,
|
||||
0x28, 0x09, 0x52, 0x0b, 0x64, 0x69, 0x73, 0x70, 0x6c, 0x61, 0x79, 0x4e, 0x61, 0x6d, 0x65, 0x12,
|
||||
0x14, 0x0a, 0x05, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05,
|
||||
0x65, 0x6d, 0x61, 0x69, 0x6c, 0x22, 0x5b, 0x0a, 0x05, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x18,
|
||||
0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52,
|
||||
0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02,
|
||||
0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65,
|
||||
0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x14, 0x0a, 0x05,
|
||||
0x65, 0x6d, 0x61, 0x69, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x6d, 0x61,
|
||||
0x69, 0x6c, 0x42, 0x31, 0x5a, 0x2f, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d,
|
||||
0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69,
|
||||
0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x64, 0x69, 0x72, 0x65,
|
||||
0x63, 0x74, 0x6f, 0x72, 0x79, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
0x6f, 0x12, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x1a, 0x1b, 0x67, 0x6f,
|
||||
0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x65, 0x6d,
|
||||
0x70, 0x74, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x86, 0x01, 0x0a, 0x04, 0x55, 0x73,
|
||||
0x65, 0x72, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20,
|
||||
0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x0e, 0x0a, 0x02,
|
||||
0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x1b, 0x0a, 0x09,
|
||||
0x67, 0x72, 0x6f, 0x75, 0x70, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52,
|
||||
0x08, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x49, 0x64, 0x73, 0x12, 0x21, 0x0a, 0x0c, 0x64, 0x69, 0x73,
|
||||
0x70, 0x6c, 0x61, 0x79, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52,
|
||||
0x0b, 0x64, 0x69, 0x73, 0x70, 0x6c, 0x61, 0x79, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x14, 0x0a, 0x05,
|
||||
0x65, 0x6d, 0x61, 0x69, 0x6c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x6d, 0x61,
|
||||
0x69, 0x6c, 0x22, 0x5b, 0x0a, 0x05, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x18, 0x0a, 0x07, 0x76,
|
||||
0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65,
|
||||
0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28,
|
||||
0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x03, 0x20,
|
||||
0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x6d, 0x61,
|
||||
0x69, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x22,
|
||||
0x50, 0x0a, 0x12, 0x52, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65,
|
||||
0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x17, 0x0a, 0x07, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x69, 0x64,
|
||||
0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x64, 0x12, 0x21,
|
||||
0x0a, 0x0c, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x02,
|
||||
0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x54, 0x6f, 0x6b, 0x65,
|
||||
0x6e, 0x32, 0x58, 0x0a, 0x10, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x53, 0x65,
|
||||
0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x44, 0x0a, 0x0b, 0x52, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68,
|
||||
0x55, 0x73, 0x65, 0x72, 0x12, 0x1d, 0x2e, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79,
|
||||
0x2e, 0x52, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75,
|
||||
0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f,
|
||||
0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x42, 0x31, 0x5a, 0x2f, 0x67,
|
||||
0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69,
|
||||
0x75, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67, 0x2f,
|
||||
0x67, 0x72, 0x70, 0x63, 0x2f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x62, 0x06,
|
||||
0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
|
@ -212,14 +285,18 @@ func file_directory_proto_rawDescGZIP() []byte {
|
|||
return file_directory_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_directory_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
|
||||
var file_directory_proto_msgTypes = make([]protoimpl.MessageInfo, 3)
|
||||
var file_directory_proto_goTypes = []interface{}{
|
||||
(*User)(nil), // 0: directory.User
|
||||
(*Group)(nil), // 1: directory.Group
|
||||
(*RefreshUserRequest)(nil), // 2: directory.RefreshUserRequest
|
||||
(*empty.Empty)(nil), // 3: google.protobuf.Empty
|
||||
}
|
||||
var file_directory_proto_depIdxs = []int32{
|
||||
0, // [0:0] is the sub-list for method output_type
|
||||
0, // [0:0] is the sub-list for method input_type
|
||||
2, // 0: directory.DirectoryService.RefreshUser:input_type -> directory.RefreshUserRequest
|
||||
3, // 1: directory.DirectoryService.RefreshUser:output_type -> google.protobuf.Empty
|
||||
1, // [1:2] is the sub-list for method output_type
|
||||
0, // [0:1] is the sub-list for method input_type
|
||||
0, // [0:0] is the sub-list for extension type_name
|
||||
0, // [0:0] is the sub-list for extension extendee
|
||||
0, // [0:0] is the sub-list for field type_name
|
||||
|
@ -255,6 +332,18 @@ func file_directory_proto_init() {
|
|||
return nil
|
||||
}
|
||||
}
|
||||
file_directory_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*RefreshUserRequest); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
|
@ -262,9 +351,9 @@ func file_directory_proto_init() {
|
|||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: file_directory_proto_rawDesc,
|
||||
NumEnums: 0,
|
||||
NumMessages: 2,
|
||||
NumMessages: 3,
|
||||
NumExtensions: 0,
|
||||
NumServices: 0,
|
||||
NumServices: 1,
|
||||
},
|
||||
GoTypes: file_directory_proto_goTypes,
|
||||
DependencyIndexes: file_directory_proto_depIdxs,
|
||||
|
@ -275,3 +364,83 @@ func file_directory_proto_init() {
|
|||
file_directory_proto_goTypes = nil
|
||||
file_directory_proto_depIdxs = nil
|
||||
}
|
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ context.Context
|
||||
var _ grpc.ClientConnInterface
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the grpc package it is being compiled against.
|
||||
const _ = grpc.SupportPackageIsVersion6
|
||||
|
||||
// DirectoryServiceClient is the client API for DirectoryService service.
|
||||
//
|
||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.
|
||||
type DirectoryServiceClient interface {
|
||||
RefreshUser(ctx context.Context, in *RefreshUserRequest, opts ...grpc.CallOption) (*empty.Empty, error)
|
||||
}
|
||||
|
||||
type directoryServiceClient struct {
|
||||
cc grpc.ClientConnInterface
|
||||
}
|
||||
|
||||
func NewDirectoryServiceClient(cc grpc.ClientConnInterface) DirectoryServiceClient {
|
||||
return &directoryServiceClient{cc}
|
||||
}
|
||||
|
||||
func (c *directoryServiceClient) RefreshUser(ctx context.Context, in *RefreshUserRequest, opts ...grpc.CallOption) (*empty.Empty, error) {
|
||||
out := new(empty.Empty)
|
||||
err := c.cc.Invoke(ctx, "/directory.DirectoryService/RefreshUser", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// DirectoryServiceServer is the server API for DirectoryService service.
|
||||
type DirectoryServiceServer interface {
|
||||
RefreshUser(context.Context, *RefreshUserRequest) (*empty.Empty, error)
|
||||
}
|
||||
|
||||
// UnimplementedDirectoryServiceServer can be embedded to have forward compatible implementations.
|
||||
type UnimplementedDirectoryServiceServer struct {
|
||||
}
|
||||
|
||||
func (*UnimplementedDirectoryServiceServer) RefreshUser(context.Context, *RefreshUserRequest) (*empty.Empty, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method RefreshUser not implemented")
|
||||
}
|
||||
|
||||
func RegisterDirectoryServiceServer(s *grpc.Server, srv DirectoryServiceServer) {
|
||||
s.RegisterService(&_DirectoryService_serviceDesc, srv)
|
||||
}
|
||||
|
||||
func _DirectoryService_RefreshUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(RefreshUserRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DirectoryServiceServer).RefreshUser(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/directory.DirectoryService/RefreshUser",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DirectoryServiceServer).RefreshUser(ctx, req.(*RefreshUserRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
var _DirectoryService_serviceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "directory.DirectoryService",
|
||||
HandlerType: (*DirectoryServiceServer)(nil),
|
||||
Methods: []grpc.MethodDesc{
|
||||
{
|
||||
MethodName: "RefreshUser",
|
||||
Handler: _DirectoryService_RefreshUser_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{},
|
||||
Metadata: "directory.proto",
|
||||
}
|
||||
|
|
|
@ -3,6 +3,8 @@ syntax = "proto3";
|
|||
package directory;
|
||||
option go_package = "github.com/pomerium/pomerium/pkg/grpc/directory";
|
||||
|
||||
import "google/protobuf/empty.proto";
|
||||
|
||||
message User {
|
||||
string version = 1;
|
||||
string id = 2;
|
||||
|
@ -17,3 +19,12 @@ message Group {
|
|||
string name = 3;
|
||||
string email = 4;
|
||||
}
|
||||
|
||||
message RefreshUserRequest {
|
||||
string user_id = 1;
|
||||
string access_token = 2;
|
||||
}
|
||||
|
||||
service DirectoryService {
|
||||
rpc RefreshUser(RefreshUserRequest) returns (google.protobuf.Empty);
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue