mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-15 18:17:49 +02:00
directory: add explicit RefreshUser endpoint for faster sync (#1460)
* directory: add explicit RefreshUser endpoint for faster sync * add test * implement azure * update api call * add test for azure User * implement github * implement AccessToken, gitlab * implement okta * implement onelogin * fix test * fix inconsistent test * implement auth0
This commit is contained in:
parent
9b39deabd8
commit
aa731ae068
23 changed files with 1405 additions and 179 deletions
|
@ -596,5 +596,13 @@ func (a *Authenticate) saveSessionToDataBroker(ctx context.Context, sessionState
|
||||||
}
|
}
|
||||||
sessionState.Version = sessions.Version(res.GetServerVersion())
|
sessionState.Version = sessions.Version(res.GetServerVersion())
|
||||||
|
|
||||||
|
_, err = state.directoryClient.RefreshUser(ctx, &directory.RefreshUserRequest{
|
||||||
|
UserId: s.UserId,
|
||||||
|
AccessToken: accessToken.AccessToken,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("directory: failed to refresh user data")
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,6 +14,8 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
"google.golang.org/protobuf/types/known/emptypb"
|
"google.golang.org/protobuf/types/known/emptypb"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
|
@ -29,10 +31,12 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/golang/protobuf/ptypes"
|
"github.com/golang/protobuf/ptypes"
|
||||||
|
"github.com/golang/protobuf/ptypes/empty"
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
|
@ -171,6 +175,7 @@ func TestAuthenticate_SignIn(t *testing.T) {
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
directoryClient: new(mockDirectoryServiceClient),
|
||||||
}),
|
}),
|
||||||
|
|
||||||
options: config.NewAtomicOptions(),
|
options: config.NewAtomicOptions(),
|
||||||
|
@ -262,6 +267,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
directoryClient: new(mockDirectoryServiceClient),
|
||||||
}),
|
}),
|
||||||
templates: template.Must(frontend.NewTemplates()),
|
templates: template.Must(frontend.NewTemplates()),
|
||||||
options: config.NewAtomicOptions(),
|
options: config.NewAtomicOptions(),
|
||||||
|
@ -366,6 +372,7 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
|
||||||
return &databroker.SetResponse{Record: &databroker.Record{Data: in.Data}}, nil
|
return &databroker.SetResponse{Record: &databroker.Record{Data: in.Data}}, nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
directoryClient: new(mockDirectoryServiceClient),
|
||||||
redirectURL: authURL,
|
redirectURL: authURL,
|
||||||
sessionStore: tt.session,
|
sessionStore: tt.session,
|
||||||
cookieCipher: aead,
|
cookieCipher: aead,
|
||||||
|
@ -515,6 +522,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
directoryClient: new(mockDirectoryServiceClient),
|
||||||
}),
|
}),
|
||||||
options: config.NewAtomicOptions(),
|
options: config.NewAtomicOptions(),
|
||||||
provider: identity.NewAtomicAuthenticator(),
|
provider: identity.NewAtomicAuthenticator(),
|
||||||
|
@ -633,6 +641,7 @@ func TestAuthenticate_Dashboard(t *testing.T) {
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
directoryClient: new(mockDirectoryServiceClient),
|
||||||
}),
|
}),
|
||||||
templates: template.Must(frontend.NewTemplates()),
|
templates: template.Must(frontend.NewTemplates()),
|
||||||
}
|
}
|
||||||
|
@ -679,3 +688,16 @@ func (m mockDataBrokerServiceClient) Get(ctx context.Context, in *databroker.Get
|
||||||
func (m mockDataBrokerServiceClient) Set(ctx context.Context, in *databroker.SetRequest, opts ...grpc.CallOption) (*databroker.SetResponse, error) {
|
func (m mockDataBrokerServiceClient) Set(ctx context.Context, in *databroker.SetRequest, opts ...grpc.CallOption) (*databroker.SetResponse, error) {
|
||||||
return m.set(ctx, in, opts...)
|
return m.set(ctx, in, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type mockDirectoryServiceClient struct {
|
||||||
|
directory.DirectoryServiceClient
|
||||||
|
|
||||||
|
refreshUser func(ctx context.Context, in *directory.RefreshUserRequest, opts ...grpc.CallOption) (*empty.Empty, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m mockDirectoryServiceClient) RefreshUser(ctx context.Context, in *directory.RefreshUserRequest, opts ...grpc.CallOption) (*empty.Empty, error) {
|
||||||
|
if m.refreshUser != nil {
|
||||||
|
return m.refreshUser(ctx, in, opts...)
|
||||||
|
}
|
||||||
|
return nil, status.Error(codes.Unimplemented, "")
|
||||||
|
}
|
||||||
|
|
|
@ -22,6 +22,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc"
|
"github.com/pomerium/pomerium/pkg/grpc"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||||
)
|
)
|
||||||
|
|
||||||
type authenticateState struct {
|
type authenticateState struct {
|
||||||
|
@ -46,6 +47,7 @@ type authenticateState struct {
|
||||||
jwk *jose.JSONWebKeySet
|
jwk *jose.JSONWebKeySet
|
||||||
|
|
||||||
dataBrokerClient databroker.DataBrokerServiceClient
|
dataBrokerClient databroker.DataBrokerServiceClient
|
||||||
|
directoryClient directory.DirectoryServiceClient
|
||||||
}
|
}
|
||||||
|
|
||||||
func newAuthenticateState() *authenticateState {
|
func newAuthenticateState() *authenticateState {
|
||||||
|
@ -129,6 +131,7 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err
|
||||||
}
|
}
|
||||||
|
|
||||||
state.dataBrokerClient = databroker.NewDataBrokerServiceClient(dataBrokerConn)
|
state.dataBrokerClient = databroker.NewDataBrokerServiceClient(dataBrokerConn)
|
||||||
|
state.directoryClient = directory.NewDirectoryServiceClient(dataBrokerConn)
|
||||||
|
|
||||||
return state, nil
|
return state, nil
|
||||||
}
|
}
|
||||||
|
|
8
cache/cache.go
vendored
8
cache/cache.go
vendored
|
@ -7,6 +7,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"gopkg.in/tomb.v2"
|
"gopkg.in/tomb.v2"
|
||||||
|
@ -33,6 +34,9 @@ type Cache struct {
|
||||||
localGRPCConnection *grpc.ClientConn
|
localGRPCConnection *grpc.ClientConn
|
||||||
dataBrokerStorageType string //TODO remove in v0.11
|
dataBrokerStorageType string //TODO remove in v0.11
|
||||||
deprecatedCacheClusterDomain string //TODO: remove in v0.11
|
deprecatedCacheClusterDomain string //TODO: remove in v0.11
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
directoryProvider directory.Provider
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new cache service.
|
// New creates a new cache service.
|
||||||
|
@ -90,6 +94,7 @@ func (c *Cache) OnConfigChange(cfg *config.Config) {
|
||||||
// Register registers all the gRPC services with the given server.
|
// Register registers all the gRPC services with the given server.
|
||||||
func (c *Cache) Register(grpcServer *grpc.Server) {
|
func (c *Cache) Register(grpcServer *grpc.Server) {
|
||||||
databroker.RegisterDataBrokerServiceServer(grpcServer, c.dataBrokerServer)
|
databroker.RegisterDataBrokerServiceServer(grpcServer, c.dataBrokerServer)
|
||||||
|
directory.RegisterDirectoryServiceServer(grpcServer, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run runs the cache components.
|
// Run runs the cache components.
|
||||||
|
@ -132,6 +137,9 @@ func (c *Cache) update(cfg *config.Config) error {
|
||||||
ClientID: cfg.Options.ClientID,
|
ClientID: cfg.Options.ClientID,
|
||||||
ClientSecret: cfg.Options.ClientSecret,
|
ClientSecret: cfg.Options.ClientSecret,
|
||||||
})
|
})
|
||||||
|
c.mu.Lock()
|
||||||
|
c.directoryProvider = directoryProvider
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
dataBrokerClient := databroker.NewDataBrokerServiceClient(c.localGRPCConnection)
|
dataBrokerClient := databroker.NewDataBrokerServiceClient(c.localGRPCConnection)
|
||||||
|
|
||||||
|
|
44
cache/directory.go
vendored
Normal file
44
cache/directory.go
vendored
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"google.golang.org/protobuf/types/known/anypb"
|
||||||
|
"google.golang.org/protobuf/types/known/emptypb"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RefreshUser refreshes a user's directory information.
|
||||||
|
func (c *Cache) RefreshUser(ctx context.Context, req *directory.RefreshUserRequest) (*emptypb.Empty, error) {
|
||||||
|
c.mu.Lock()
|
||||||
|
dp := c.directoryProvider
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
if dp == nil {
|
||||||
|
return nil, errors.New("no directory provider is available for refresh")
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := dp.User(ctx, req.GetUserId(), req.GetAccessToken())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
any, err := anypb.New(u)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = c.dataBrokerServer.Set(ctx, &databroker.SetRequest{
|
||||||
|
Type: any.GetTypeUrl(),
|
||||||
|
Id: u.GetId(),
|
||||||
|
Data: any,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return new(emptypb.Empty), nil
|
||||||
|
}
|
|
@ -21,16 +21,33 @@ import (
|
||||||
// Name is the provider name.
|
// Name is the provider name.
|
||||||
const Name = "auth0"
|
const Name = "auth0"
|
||||||
|
|
||||||
|
type (
|
||||||
// RoleManager defines what is needed to get role info from Auth0.
|
// RoleManager defines what is needed to get role info from Auth0.
|
||||||
type RoleManager interface {
|
RoleManager interface {
|
||||||
List(opts ...management.ListOption) (r *management.RoleList, err error)
|
List(opts ...management.ListOption) (r *management.RoleList, err error)
|
||||||
Users(id string, opts ...management.ListOption) (u *management.UserList, err error)
|
Users(id string, opts ...management.ListOption) (u *management.UserList, err error)
|
||||||
}
|
}
|
||||||
|
// UserManager defines what is needed to get user info from Auth0.
|
||||||
|
UserManager interface {
|
||||||
|
Read(id string) (*management.User, error)
|
||||||
|
Roles(id string, opts ...management.ListOption) (r *management.RoleList, err error)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
type newManagersFunc = func(ctx context.Context, domain string, serviceAccount *ServiceAccount) (RoleManager, UserManager, error)
|
||||||
|
|
||||||
|
func defaultNewManagersFunc(ctx context.Context, domain string, serviceAccount *ServiceAccount) (RoleManager, UserManager, error) {
|
||||||
|
m, err := management.New(domain, serviceAccount.ClientID, serviceAccount.Secret, management.WithContext(ctx))
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("auth0: could not build management: %w", err)
|
||||||
|
}
|
||||||
|
return m.Role, m.User, nil
|
||||||
|
}
|
||||||
|
|
||||||
type config struct {
|
type config struct {
|
||||||
domain string
|
domain string
|
||||||
serviceAccount *ServiceAccount
|
serviceAccount *ServiceAccount
|
||||||
newRoleManager func(ctx context.Context, domain string, serviceAccount *ServiceAccount) (RoleManager, error)
|
newManagers newManagersFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
// Option provides config for the Auth0 Provider.
|
// Option provides config for the Auth0 Provider.
|
||||||
|
@ -50,17 +67,9 @@ func WithDomain(domain string) Option {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func defaultNewRoleManagerFunc(ctx context.Context, domain string, serviceAccount *ServiceAccount) (RoleManager, error) {
|
|
||||||
m, err := management.New(domain, serviceAccount.ClientID, serviceAccount.Secret, management.WithContext(ctx))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("auth0: could not build management")
|
|
||||||
}
|
|
||||||
return m.Role, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getConfig(options ...Option) *config {
|
func getConfig(options ...Option) *config {
|
||||||
cfg := &config{
|
cfg := &config{
|
||||||
newRoleManager: defaultNewRoleManagerFunc,
|
newManagers: defaultNewManagersFunc,
|
||||||
}
|
}
|
||||||
for _, option := range options {
|
for _, option := range options {
|
||||||
option(cfg)
|
option(cfg)
|
||||||
|
@ -82,13 +91,49 @@ func New(options ...Option) *Provider {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Provider) getRoleManager(ctx context.Context) (RoleManager, error) {
|
func (p *Provider) getManagers(ctx context.Context) (RoleManager, UserManager, error) {
|
||||||
return p.cfg.newRoleManager(ctx, p.cfg.domain, p.cfg.serviceAccount)
|
return p.cfg.newManagers(ctx, p.cfg.domain, p.cfg.serviceAccount)
|
||||||
|
}
|
||||||
|
|
||||||
|
// User returns the user record for the given id.
|
||||||
|
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
|
||||||
|
_, um, err := p.getManagers(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("auth0: could not get the role manager: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, providerUserID := databroker.FromUserID(userID)
|
||||||
|
du := &directory.User{
|
||||||
|
Id: userID,
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := um.Read(providerUserID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("auth0: error getting user info: %w", err)
|
||||||
|
}
|
||||||
|
du.DisplayName = u.GetName()
|
||||||
|
du.Email = u.GetEmail()
|
||||||
|
|
||||||
|
for page, hasNext := 0, true; hasNext; page++ {
|
||||||
|
rl, err := um.Roles(providerUserID, management.IncludeTotals(true), management.Page(page))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("auth0: error getting user roles: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, role := range rl.Roles {
|
||||||
|
du.GroupIds = append(du.GroupIds, role.GetID())
|
||||||
|
}
|
||||||
|
|
||||||
|
hasNext = rl.HasNext()
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Strings(du.GroupIds)
|
||||||
|
return du, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UserGroups fetches a slice of groups and users.
|
// UserGroups fetches a slice of groups and users.
|
||||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
||||||
rm, err := p.getRoleManager(ctx)
|
rm, _, err := p.getManagers(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("auth0: could not get the role manager: %w", err)
|
return nil, nil, fmt.Errorf("auth0: could not get the role manager: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -147,6 +192,9 @@ func getRoles(rm RoleManager) ([]*directory.Group, error) {
|
||||||
shouldContinue = listRes.HasNext()
|
shouldContinue = listRes.HasNext()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sort.Slice(roles, func(i, j int) bool {
|
||||||
|
return roles[i].GetId() < roles[j].GetId()
|
||||||
|
})
|
||||||
return roles, nil
|
return roles, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -170,6 +218,7 @@ func getRoleUserIDs(rm RoleManager, roleID string) ([]string, error) {
|
||||||
shouldContinue = usersRes.HasNext()
|
shouldContinue = usersRes.HasNext()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sort.Strings(ids)
|
||||||
return ids, nil
|
return ids, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -180,7 +229,30 @@ type ServiceAccount struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseServiceAccount parses the service account in the config options.
|
// ParseServiceAccount parses the service account in the config options.
|
||||||
func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) {
|
func ParseServiceAccount(options directory.Options) (*ServiceAccount, error) {
|
||||||
|
if options.ServiceAccount != "" {
|
||||||
|
return parseServiceAccountFromString(options.ServiceAccount)
|
||||||
|
}
|
||||||
|
return parseServiceAccountFromOptions(options.ClientID, options.ClientSecret)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseServiceAccountFromOptions(clientID, clientSecret string) (*ServiceAccount, error) {
|
||||||
|
serviceAccount := ServiceAccount{
|
||||||
|
ClientID: clientID,
|
||||||
|
Secret: clientSecret,
|
||||||
|
}
|
||||||
|
|
||||||
|
if serviceAccount.ClientID == "" {
|
||||||
|
return nil, fmt.Errorf("auth0: client_id is required")
|
||||||
|
}
|
||||||
|
if serviceAccount.Secret == "" {
|
||||||
|
return nil, fmt.Errorf("auth0: client_secret is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &serviceAccount, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseServiceAccountFromString(rawServiceAccount string) (*ServiceAccount, error) {
|
||||||
bs, err := base64.StdEncoding.DecodeString(rawServiceAccount)
|
bs, err := base64.StdEncoding.DecodeString(rawServiceAccount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("auth0: could not decode base64: %w", err)
|
return nil, fmt.Errorf("auth0: could not decode base64: %w", err)
|
||||||
|
|
|
@ -2,34 +2,133 @@ package auth0
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi"
|
||||||
|
"github.com/go-chi/chi/middleware"
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"gopkg.in/auth0.v4/management"
|
"gopkg.in/auth0.v4/management"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/directory/auth0/mock_auth0"
|
"github.com/pomerium/pomerium/internal/directory/auth0/mock_auth0"
|
||||||
|
"github.com/pomerium/pomerium/internal/testutil"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type M = map[string]interface{}
|
||||||
|
|
||||||
|
func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
|
||||||
|
r := chi.NewRouter()
|
||||||
|
r.Use(middleware.Logger)
|
||||||
|
r.Post("/oauth/token", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(M{
|
||||||
|
"access_token": "ACCESSTOKEN",
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"refresh_token": "REFRESHTOKEN",
|
||||||
|
})
|
||||||
|
})
|
||||||
|
r.Route("/api/v2", func(r chi.Router) {
|
||||||
|
r.Route("/users", func(r chi.Router) {
|
||||||
|
r.Route("/{user_id}", func(r chi.Router) {
|
||||||
|
r.Get("/roles", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch chi.URLParam(r, "user_id") {
|
||||||
|
case "user1":
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(M{
|
||||||
|
"total": 2,
|
||||||
|
"limit": 2,
|
||||||
|
"roles": []M{
|
||||||
|
{"id": "role1"},
|
||||||
|
{"id": "role2"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
http.Error(w, "not found", http.StatusNotFound)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch chi.URLParam(r, "user_id") {
|
||||||
|
case "user1":
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(M{
|
||||||
|
"user_id": "user1",
|
||||||
|
"email": "user1@example.com",
|
||||||
|
"name": "User 1",
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
http.Error(w, "not found", http.StatusNotFound)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProvider_User(t *testing.T) {
|
||||||
|
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Minute)
|
||||||
|
defer clearTimeout()
|
||||||
|
|
||||||
|
orig := http.DefaultTransport
|
||||||
|
defer func() {
|
||||||
|
http.DefaultTransport = orig
|
||||||
|
}()
|
||||||
|
http.DefaultTransport = &http.Transport{
|
||||||
|
TLSClientConfig: &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var mockAPI http.Handler
|
||||||
|
srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
mockAPI.ServeHTTP(w, r)
|
||||||
|
}))
|
||||||
|
srv.StartTLS()
|
||||||
|
defer srv.Close()
|
||||||
|
mockAPI = newMockAPI(t, srv)
|
||||||
|
|
||||||
|
p := New(
|
||||||
|
WithDomain(srv.URL),
|
||||||
|
WithServiceAccount(&ServiceAccount{ClientID: "CLIENT_ID", Secret: "SECRET"}),
|
||||||
|
)
|
||||||
|
du, err := p.User(ctx, "auth0/user1", "")
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
testutil.AssertProtoJSONEqual(t, `{
|
||||||
|
"id": "auth0/user1",
|
||||||
|
"displayName": "User 1",
|
||||||
|
"email": "user1@example.com",
|
||||||
|
"groupIds": ["role1", "role2"]
|
||||||
|
}`, du)
|
||||||
|
}
|
||||||
|
|
||||||
type mockNewRoleManagerFunc struct {
|
type mockNewRoleManagerFunc struct {
|
||||||
CalledWithContext context.Context
|
CalledWithContext context.Context
|
||||||
CalledWithDomain string
|
CalledWithDomain string
|
||||||
CalledWithServiceAccount *ServiceAccount
|
CalledWithServiceAccount *ServiceAccount
|
||||||
|
|
||||||
ReturnRoleManager RoleManager
|
ReturnRoleManager RoleManager
|
||||||
|
ReturnUserManager UserManager
|
||||||
ReturnError error
|
ReturnError error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockNewRoleManagerFunc) f(ctx context.Context, domain string, serviceAccount *ServiceAccount) (RoleManager, error) {
|
func (m *mockNewRoleManagerFunc) f(ctx context.Context, domain string, serviceAccount *ServiceAccount) (RoleManager, UserManager, error) {
|
||||||
m.CalledWithContext = ctx
|
m.CalledWithContext = ctx
|
||||||
m.CalledWithDomain = domain
|
m.CalledWithDomain = domain
|
||||||
m.CalledWithServiceAccount = serviceAccount
|
m.CalledWithServiceAccount = serviceAccount
|
||||||
|
|
||||||
return m.ReturnRoleManager, m.ReturnError
|
return m.ReturnRoleManager, m.ReturnUserManager, m.ReturnError
|
||||||
}
|
}
|
||||||
|
|
||||||
type listOptionMatcher struct {
|
type listOptionMatcher struct {
|
||||||
|
@ -379,7 +478,7 @@ func TestProvider_UserGroups(t *testing.T) {
|
||||||
|
|
||||||
mRoleManager := mock_auth0.NewMockRoleManager(ctrl)
|
mRoleManager := mock_auth0.NewMockRoleManager(ctrl)
|
||||||
|
|
||||||
mNewRoleManagerFunc := mockNewRoleManagerFunc{
|
mNewManagersFunc := mockNewRoleManagerFunc{
|
||||||
ReturnRoleManager: mRoleManager,
|
ReturnRoleManager: mRoleManager,
|
||||||
ReturnError: tc.newRoleManagerError,
|
ReturnError: tc.newRoleManagerError,
|
||||||
}
|
}
|
||||||
|
@ -392,7 +491,7 @@ func TestProvider_UserGroups(t *testing.T) {
|
||||||
WithDomain(expectedDomain),
|
WithDomain(expectedDomain),
|
||||||
WithServiceAccount(expectedServiceAccount),
|
WithServiceAccount(expectedServiceAccount),
|
||||||
)
|
)
|
||||||
p.cfg.newRoleManager = mNewRoleManagerFunc.f
|
p.cfg.newManagers = mNewManagersFunc.f
|
||||||
|
|
||||||
actualGroups, actualUsers, err := p.UserGroups(context.Background())
|
actualGroups, actualUsers, err := p.UserGroups(context.Background())
|
||||||
if tc.expectedError != nil {
|
if tc.expectedError != nil {
|
||||||
|
@ -404,8 +503,8 @@ func TestProvider_UserGroups(t *testing.T) {
|
||||||
assert.Equal(t, tc.expectedGroups, actualGroups)
|
assert.Equal(t, tc.expectedGroups, actualGroups)
|
||||||
assert.Equal(t, tc.expectedUsers, actualUsers)
|
assert.Equal(t, tc.expectedUsers, actualUsers)
|
||||||
|
|
||||||
assert.Equal(t, expectedDomain, mNewRoleManagerFunc.CalledWithDomain)
|
assert.Equal(t, expectedDomain, mNewManagersFunc.CalledWithDomain)
|
||||||
assert.Equal(t, expectedServiceAccount, mNewRoleManagerFunc.CalledWithServiceAccount)
|
assert.Equal(t, expectedServiceAccount, mNewManagersFunc.CalledWithServiceAccount)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -433,7 +532,7 @@ func TestParseServiceAccount(t *testing.T) {
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
tc := tc
|
tc := tc
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
actualServiceAccount, err := ParseServiceAccount(tc.rawServiceAccount)
|
actualServiceAccount, err := ParseServiceAccount(directory.Options{ServiceAccount: tc.rawServiceAccount})
|
||||||
if tc.expectedError != nil {
|
if tc.expectedError != nil {
|
||||||
assert.EqualError(t, err, tc.expectedError.Error())
|
assert.EqualError(t, err, tc.expectedError.Error())
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -6,14 +6,15 @@ import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -101,6 +102,50 @@ func New(options ...Option) *Provider {
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// User returns the user record for the given id.
|
||||||
|
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
|
||||||
|
if p.cfg.serviceAccount == nil {
|
||||||
|
return nil, fmt.Errorf("azure: service account not defined")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, providerUserID := databroker.FromUserID(userID)
|
||||||
|
|
||||||
|
du := &directory.User{
|
||||||
|
Id: userID,
|
||||||
|
}
|
||||||
|
|
||||||
|
userURL := p.cfg.graphURL.ResolveReference(&url.URL{
|
||||||
|
Path: fmt.Sprintf("/v1.0/users/%s", providerUserID),
|
||||||
|
}).String()
|
||||||
|
|
||||||
|
var u usersDeltaResponseUser
|
||||||
|
err := p.api(ctx, userURL, &u)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
du.DisplayName = u.DisplayName
|
||||||
|
du.Email = u.getEmail()
|
||||||
|
|
||||||
|
groupURL := p.cfg.graphURL.ResolveReference(&url.URL{
|
||||||
|
Path: fmt.Sprintf("/v1.0/users/%s/transitiveMemberOf", providerUserID),
|
||||||
|
}).String()
|
||||||
|
|
||||||
|
var res struct {
|
||||||
|
Value []usersDeltaResponseUser `json:"value"`
|
||||||
|
}
|
||||||
|
err = p.api(ctx, groupURL, &res)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, g := range res.Value {
|
||||||
|
du.GroupIds = append(du.GroupIds, g.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Strings(du.GroupIds)
|
||||||
|
|
||||||
|
return du, nil
|
||||||
|
}
|
||||||
|
|
||||||
// UserGroups returns the directory users in azure active directory.
|
// UserGroups returns the directory users in azure active directory.
|
||||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
||||||
if p.cfg.serviceAccount == nil {
|
if p.cfg.serviceAccount == nil {
|
||||||
|
@ -116,13 +161,13 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
|
||||||
return groups, users, nil
|
return groups, users, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Provider) api(ctx context.Context, method, url string, body io.Reader, out interface{}) error {
|
func (p *Provider) api(ctx context.Context, url string, out interface{}) error {
|
||||||
token, err := p.getToken(ctx)
|
token, err := p.getToken(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, method, url, body)
|
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("azure: error creating HTTP request: %w", err)
|
return fmt.Errorf("azure: error creating HTTP request: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -143,7 +188,7 @@ func (p *Provider) api(ctx context.Context, method, url string, body io.Reader,
|
||||||
}
|
}
|
||||||
|
|
||||||
if res.StatusCode/100 != 2 {
|
if res.StatusCode/100 != 2 {
|
||||||
return fmt.Errorf("azure: error querying api: %s", res.Status)
|
return fmt.Errorf("azure: error querying api (%s): %s", url, res.Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = json.NewDecoder(res.Body).Decode(out)
|
err = json.NewDecoder(res.Body).Decode(out)
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"github.com/go-chi/chi/middleware"
|
"github.com/go-chi/chi/middleware"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/testutil"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -74,12 +75,62 @@ func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
r.Get("/users/{user_id}", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch chi.URLParam(r, "user_id") {
|
||||||
|
case "user-1":
|
||||||
|
_ = json.NewEncoder(w).Encode(M{"id": "user-1", "displayName": "User 1", "mail": "user1@example.com"})
|
||||||
|
default:
|
||||||
|
http.Error(w, "not found", http.StatusNotFound)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
r.Get("/users/{user_id}/transitiveMemberOf", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch chi.URLParam(r, "user_id") {
|
||||||
|
case "user-1":
|
||||||
|
_ = json.NewEncoder(w).Encode(M{
|
||||||
|
"value": []M{
|
||||||
|
{"id": "admin"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
http.Error(w, "not found", http.StatusNotFound)
|
||||||
|
}
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test(t *testing.T) {
|
func TestProvider_User(t *testing.T) {
|
||||||
|
var mockAPI http.Handler
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
mockAPI.ServeHTTP(w, r)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
mockAPI = newMockAPI(t, srv)
|
||||||
|
|
||||||
|
p := New(
|
||||||
|
WithGraphURL(mustParseURL(srv.URL)),
|
||||||
|
WithLoginURL(mustParseURL(srv.URL)),
|
||||||
|
WithServiceAccount(&ServiceAccount{
|
||||||
|
ClientID: "CLIENT_ID",
|
||||||
|
ClientSecret: "CLIENT_SECRET",
|
||||||
|
DirectoryID: "DIRECTORY_ID",
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
|
||||||
|
du, err := p.User(context.Background(), "azure/user-1", "")
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
testutil.AssertProtoJSONEqual(t, `{
|
||||||
|
"id": "azure/user-1",
|
||||||
|
"displayName": "User 1",
|
||||||
|
"email": "user1@example.com",
|
||||||
|
"groupIds": ["admin"]
|
||||||
|
}`, du)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProvider_UserGroups(t *testing.T) {
|
||||||
var mockAPI http.Handler
|
var mockAPI http.Handler
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
mockAPI.ServeHTTP(w, r)
|
mockAPI.ServeHTTP(w, r)
|
||||||
|
@ -118,10 +169,10 @@ func Test(t *testing.T) {
|
||||||
Email: "user3@example.com",
|
Email: "user3@example.com",
|
||||||
},
|
},
|
||||||
}, users)
|
}, users)
|
||||||
assert.Equal(t, []*directory.Group{
|
testutil.AssertProtoJSONEqual(t, `[
|
||||||
{Id: "admin", Name: "Admin Group"},
|
{ "id": "admin", "name": "Admin Group" },
|
||||||
{Id: "test", Name: "Test Group"},
|
{ "id": "test", "name": "Test Group"}
|
||||||
}, groups)
|
]`, groups)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseServiceAccount(t *testing.T) {
|
func TestParseServiceAccount(t *testing.T) {
|
||||||
|
|
|
@ -86,7 +86,7 @@ func (dc *deltaCollection) syncGroups(ctx context.Context) error {
|
||||||
|
|
||||||
for {
|
for {
|
||||||
var res groupsDeltaResponse
|
var res groupsDeltaResponse
|
||||||
err := dc.provider.api(ctx, "GET", apiURL, nil, &res)
|
err := dc.provider.api(ctx, apiURL, &res)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -146,7 +146,7 @@ func (dc *deltaCollection) syncUsers(ctx context.Context) error {
|
||||||
|
|
||||||
for {
|
for {
|
||||||
var res usersDeltaResponse
|
var res usersDeltaResponse
|
||||||
err := dc.provider.api(ctx, "GET", apiURL, nil, &res)
|
err := dc.provider.api(ctx, apiURL, &res)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -197,6 +197,9 @@ func (dc *deltaCollection) CurrentUserGroups() ([]*directory.Group, []*directory
|
||||||
}
|
}
|
||||||
groupLookup.addGroup(g.id, groupIDs, userIDs)
|
groupLookup.addGroup(g.id, groupIDs, userIDs)
|
||||||
}
|
}
|
||||||
|
sort.Slice(groups, func(i, j int) bool {
|
||||||
|
return groups[i].GetId() < groups[j].GetId()
|
||||||
|
})
|
||||||
|
|
||||||
var users []*directory.User
|
var users []*directory.User
|
||||||
for _, u := range dc.users {
|
for _, u := range dc.users {
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
package github
|
package github
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
@ -83,6 +84,47 @@ func New(options ...Option) *Provider {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// User returns the user record for the given id.
|
||||||
|
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
|
||||||
|
if p.cfg.serviceAccount == nil {
|
||||||
|
return nil, fmt.Errorf("github: service account not defined")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, providerUserID := databroker.FromUserID(userID)
|
||||||
|
du := &directory.User{
|
||||||
|
Id: userID,
|
||||||
|
}
|
||||||
|
|
||||||
|
au, err := p.getUser(ctx, providerUserID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
du.DisplayName = au.Name
|
||||||
|
du.Email = au.Email
|
||||||
|
|
||||||
|
teamIDLookup := map[int]struct{}{}
|
||||||
|
orgSlugs, err := p.listOrgs(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, orgSlug := range orgSlugs {
|
||||||
|
teamIDs, err := p.listUserOrganizationTeams(ctx, userID, orgSlug)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, teamID := range teamIDs {
|
||||||
|
teamIDLookup[teamID] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for teamID := range teamIDLookup {
|
||||||
|
du.GroupIds = append(du.GroupIds, strconv.Itoa(teamID))
|
||||||
|
}
|
||||||
|
sort.Strings(du.GroupIds)
|
||||||
|
|
||||||
|
return du, nil
|
||||||
|
}
|
||||||
|
|
||||||
// UserGroups gets the directory user groups for github.
|
// UserGroups gets the directory user groups for github.
|
||||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
||||||
if p.cfg.serviceAccount == nil {
|
if p.cfg.serviceAccount == nil {
|
||||||
|
@ -230,6 +272,77 @@ func (p *Provider) getUser(ctx context.Context, userLogin string) (*apiUserObjec
|
||||||
return &res, nil
|
return &res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *Provider) listUserOrganizationTeams(ctx context.Context, userSlug string, orgSlug string) ([]int, error) {
|
||||||
|
// GitHub's Rest API doesn't have an easy way of querying this data, so we use the GraphQL API.
|
||||||
|
|
||||||
|
enc := func(obj interface{}) string {
|
||||||
|
bs, _ := json.Marshal(obj)
|
||||||
|
return string(bs)
|
||||||
|
}
|
||||||
|
const pageCount = 100
|
||||||
|
|
||||||
|
var teamIDs []int
|
||||||
|
var cursor *string
|
||||||
|
for {
|
||||||
|
var res struct {
|
||||||
|
Data struct {
|
||||||
|
Organization struct {
|
||||||
|
Teams struct {
|
||||||
|
TotalCount int `json:"totalCount"`
|
||||||
|
PageInfo struct {
|
||||||
|
EndCursor string `json:"endCursor"`
|
||||||
|
} `json:"pageInfo"`
|
||||||
|
Edges []struct {
|
||||||
|
Node struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
} `json:"node"`
|
||||||
|
} `json:"edges"`
|
||||||
|
} `json:"teams"`
|
||||||
|
} `json:"organization"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
cursorStr := ""
|
||||||
|
if cursor != nil {
|
||||||
|
cursorStr = fmt.Sprintf(",%s", enc(*cursor))
|
||||||
|
}
|
||||||
|
q := fmt.Sprintf(`query {
|
||||||
|
organization(login:%s) {
|
||||||
|
teams(first:%s, userLogins:[%s] %s) {
|
||||||
|
totalCount
|
||||||
|
pageInfo {
|
||||||
|
endCursor
|
||||||
|
}
|
||||||
|
edges {
|
||||||
|
node {
|
||||||
|
id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}`, enc(orgSlug), enc(pageCount), enc(userSlug), cursorStr)
|
||||||
|
_, err := p.graphql(ctx, q, &res)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(res.Data.Organization.Teams.Edges) == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, edge := range res.Data.Organization.Teams.Edges {
|
||||||
|
teamIDs = append(teamIDs, edge.Node.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(teamIDs) >= res.Data.Organization.Teams.TotalCount {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
cursor = &res.Data.Organization.Teams.PageInfo.EndCursor
|
||||||
|
}
|
||||||
|
|
||||||
|
return teamIDs, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (p *Provider) api(ctx context.Context, apiURL string, out interface{}) (http.Header, error) {
|
func (p *Provider) api(ctx context.Context, apiURL string, out interface{}) (http.Header, error) {
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil)
|
req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -257,6 +370,41 @@ func (p *Provider) api(ctx context.Context, apiURL string, out interface{}) (htt
|
||||||
return res.Header, nil
|
return res.Header, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *Provider) graphql(ctx context.Context, query string, out interface{}) (http.Header, error) {
|
||||||
|
apiURL := p.cfg.url.ResolveReference(&url.URL{
|
||||||
|
Path: "/graphql",
|
||||||
|
}).String()
|
||||||
|
|
||||||
|
bs, _ := json.Marshal(struct {
|
||||||
|
Query string `json:"query"`
|
||||||
|
}{query})
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(bs))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("github: failed to create http request: %w", err)
|
||||||
|
}
|
||||||
|
req.SetBasicAuth(p.cfg.serviceAccount.Username, p.cfg.serviceAccount.PersonalAccessToken)
|
||||||
|
|
||||||
|
res, err := p.cfg.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("github: failed to make http request: %w", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
if res.StatusCode/100 != 2 {
|
||||||
|
return nil, fmt.Errorf("github: error from API: %s", res.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
if out != nil {
|
||||||
|
err := json.NewDecoder(res.Body).Decode(out)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("github: failed to decode json body: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return res.Header, nil
|
||||||
|
}
|
||||||
|
|
||||||
func getNextLink(hdrs http.Header) string {
|
func getNextLink(hdrs http.Header) string {
|
||||||
for _, link := range linkheader.ParseMultiple(hdrs.Values("Link")) {
|
for _, link := range linkheader.ParseMultiple(hdrs.Values("Link")) {
|
||||||
if link.Rel == "next" {
|
if link.Rel == "next" {
|
||||||
|
|
|
@ -29,6 +29,33 @@ func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
r.Post("/graphql", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var body struct {
|
||||||
|
Query string `json:"query"`
|
||||||
|
}
|
||||||
|
json.NewDecoder(r.Body).Decode(&body)
|
||||||
|
|
||||||
|
_ = json.NewEncoder(w).Encode(M{
|
||||||
|
"data": M{
|
||||||
|
"organization": M{
|
||||||
|
"teams": M{
|
||||||
|
"totalCount": 3,
|
||||||
|
"edges": []M{
|
||||||
|
{"node": M{
|
||||||
|
"id": 1,
|
||||||
|
}},
|
||||||
|
{"node": M{
|
||||||
|
"id": 2,
|
||||||
|
}},
|
||||||
|
{"node": M{
|
||||||
|
"id": 3,
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
})
|
||||||
r.Get("/user/orgs", func(w http.ResponseWriter, r *http.Request) {
|
r.Get("/user/orgs", func(w http.ResponseWriter, r *http.Request) {
|
||||||
json.NewEncoder(w).Encode([]M{
|
json.NewEncoder(w).Encode([]M{
|
||||||
{"login": "org1"},
|
{"login": "org1"},
|
||||||
|
@ -88,7 +115,34 @@ func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test(t *testing.T) {
|
func TestProvider_User(t *testing.T) {
|
||||||
|
var mockAPI http.Handler
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
mockAPI.ServeHTTP(w, r)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
mockAPI = newMockAPI(t, srv)
|
||||||
|
|
||||||
|
p := New(
|
||||||
|
WithURL(mustParseURL(srv.URL)),
|
||||||
|
WithServiceAccount(&ServiceAccount{
|
||||||
|
Username: "abc",
|
||||||
|
PersonalAccessToken: "xyz",
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
du, err := p.User(context.Background(), "github/user1", "")
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
testutil.AssertProtoJSONEqual(t, `{
|
||||||
|
"id": "github/user1",
|
||||||
|
"groupIds": ["1", "2", "3"],
|
||||||
|
"displayName": "User 1",
|
||||||
|
"email": "user1@example.com"
|
||||||
|
}`, du)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProvider_UserGroups(t *testing.T) {
|
||||||
var mockAPI http.Handler
|
var mockAPI http.Handler
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
mockAPI.ServeHTTP(w, r)
|
mockAPI.ServeHTTP(w, r)
|
||||||
|
|
|
@ -83,6 +83,32 @@ func New(options ...Option) *Provider {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// User returns the user record for the given id.
|
||||||
|
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
|
||||||
|
_, providerUserID := databroker.FromUserID(userID)
|
||||||
|
du := &directory.User{
|
||||||
|
Id: userID,
|
||||||
|
}
|
||||||
|
|
||||||
|
au, err := p.getUser(ctx, providerUserID, accessToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
du.DisplayName = au.Name
|
||||||
|
du.Email = au.Email
|
||||||
|
|
||||||
|
groups, err := p.listGroups(ctx, accessToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, g := range groups {
|
||||||
|
du.GroupIds = append(du.GroupIds, g.Id)
|
||||||
|
}
|
||||||
|
sort.Strings(du.GroupIds)
|
||||||
|
|
||||||
|
return du, nil
|
||||||
|
}
|
||||||
|
|
||||||
// UserGroups gets the directory user groups for gitlab.
|
// UserGroups gets the directory user groups for gitlab.
|
||||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
||||||
if p.cfg.serviceAccount == nil {
|
if p.cfg.serviceAccount == nil {
|
||||||
|
@ -91,7 +117,7 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
|
||||||
|
|
||||||
p.log.Info().Msg("getting user groups")
|
p.log.Info().Msg("getting user groups")
|
||||||
|
|
||||||
groups, err := p.listGroups(ctx)
|
groups, err := p.listGroups(ctx, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
@ -129,8 +155,20 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
|
||||||
return groups, users, nil
|
return groups, users, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *Provider) getUser(ctx context.Context, userID string, accessToken string) (*apiUserObject, error) {
|
||||||
|
apiURL := p.cfg.url.ResolveReference(&url.URL{
|
||||||
|
Path: fmt.Sprintf("/api/v4/users/%s", userID),
|
||||||
|
}).String()
|
||||||
|
var result apiUserObject
|
||||||
|
_, err := p.api(ctx, accessToken, apiURL, &result)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("gitlab: error querying user: %w", err)
|
||||||
|
}
|
||||||
|
return &result, nil
|
||||||
|
}
|
||||||
|
|
||||||
// listGroups returns a map, with key is group ID, element is group name.
|
// listGroups returns a map, with key is group ID, element is group name.
|
||||||
func (p *Provider) listGroups(ctx context.Context) ([]*directory.Group, error) {
|
func (p *Provider) listGroups(ctx context.Context, accessToken string) ([]*directory.Group, error) {
|
||||||
nextURL := p.cfg.url.ResolveReference(&url.URL{
|
nextURL := p.cfg.url.ResolveReference(&url.URL{
|
||||||
Path: "/api/v4/groups",
|
Path: "/api/v4/groups",
|
||||||
}).String()
|
}).String()
|
||||||
|
@ -140,7 +178,7 @@ func (p *Provider) listGroups(ctx context.Context) ([]*directory.Group, error) {
|
||||||
ID int `json:"id"`
|
ID int `json:"id"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
}
|
}
|
||||||
hdrs, err := p.apiGet(ctx, nextURL, &result)
|
hdrs, err := p.api(ctx, accessToken, nextURL, &result)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("gitlab: error querying groups: %w", err)
|
return nil, fmt.Errorf("gitlab: error querying groups: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -163,7 +201,7 @@ func (p *Provider) listGroupMembers(ctx context.Context, groupID string) (users
|
||||||
}).String()
|
}).String()
|
||||||
for nextURL != "" {
|
for nextURL != "" {
|
||||||
var result []apiUserObject
|
var result []apiUserObject
|
||||||
hdrs, err := p.apiGet(ctx, nextURL, &result)
|
hdrs, err := p.api(ctx, "", nextURL, &result)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("gitlab: error querying group members: %w", err)
|
return nil, fmt.Errorf("gitlab: error querying group members: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -174,14 +212,18 @@ func (p *Provider) listGroupMembers(ctx context.Context, groupID string) (users
|
||||||
return users, nil
|
return users, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Provider) apiGet(ctx context.Context, uri string, out interface{}) (http.Header, error) {
|
func (p *Provider) api(ctx context.Context, accessToken string, uri string, out interface{}) (http.Header, error) {
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", uri, nil)
|
req, err := http.NewRequestWithContext(ctx, "GET", uri, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("gitlab: failed to create HTTP request: %w", err)
|
return nil, fmt.Errorf("gitlab: failed to create HTTP request: %w", err)
|
||||||
}
|
}
|
||||||
req.Header.Set("Accept", "application/json")
|
req.Header.Set("Accept", "application/json")
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
if accessToken != "" {
|
||||||
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))
|
||||||
|
} else {
|
||||||
req.Header.Set("PRIVATE-TOKEN", p.cfg.serviceAccount.PrivateToken)
|
req.Header.Set("PRIVATE-TOKEN", p.cfg.serviceAccount.PrivateToken)
|
||||||
|
}
|
||||||
|
|
||||||
res, err := p.cfg.httpClient.Do(req)
|
res, err := p.cfg.httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -190,7 +232,7 @@ func (p *Provider) apiGet(ctx context.Context, uri string, out interface{}) (htt
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
|
|
||||||
if res.StatusCode/100 != 2 {
|
if res.StatusCode/100 != 2 {
|
||||||
return nil, fmt.Errorf("gitlab: error query api status_code=%d: %s", res.StatusCode, res.Status)
|
return nil, fmt.Errorf("gitlab: error querying api url=%s status_code=%d: %s", uri, res.StatusCode, res.Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = json.NewDecoder(res.Body).Decode(out)
|
err = json.NewDecoder(res.Body).Decode(out)
|
||||||
|
|
|
@ -5,13 +5,16 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"sort"
|
"sort"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"golang.org/x/oauth2/google"
|
"golang.org/x/oauth2/google"
|
||||||
admin "google.golang.org/api/admin/directory/v1"
|
admin "google.golang.org/api/admin/directory/v1"
|
||||||
|
"google.golang.org/api/googleapi"
|
||||||
"google.golang.org/api/option"
|
"google.golang.org/api/option"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
@ -19,11 +22,15 @@ import (
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
// Name is the provider name.
|
// Name is the provider name.
|
||||||
const Name = "google"
|
Name = "google"
|
||||||
|
|
||||||
|
currentAccountCustomerID = "my_customer"
|
||||||
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
defaultProviderURL = "https://accounts.google.com"
|
defaultProviderURL = "https://www.googleapis.com/admin/directory/v1/"
|
||||||
)
|
)
|
||||||
|
|
||||||
type config struct {
|
type config struct {
|
||||||
|
@ -78,6 +85,51 @@ func New(options ...Option) *Provider {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// User returns the user record for the given id.
|
||||||
|
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
|
||||||
|
apiClient, err := p.getAPIClient(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("google: error getting API client: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, providerUserID := databroker.FromUserID(userID)
|
||||||
|
du := &directory.User{
|
||||||
|
Id: userID,
|
||||||
|
}
|
||||||
|
|
||||||
|
au, err := apiClient.Users.Get(providerUserID).
|
||||||
|
Context(ctx).
|
||||||
|
Do()
|
||||||
|
if isAccessDenied(err) {
|
||||||
|
// ignore forbidden errors as a user may login from another gsuite domain
|
||||||
|
return du, nil
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, fmt.Errorf("google: error getting user: %w", err)
|
||||||
|
} else {
|
||||||
|
if au.Name != nil {
|
||||||
|
du.DisplayName = au.Name.FullName
|
||||||
|
}
|
||||||
|
du.Email = au.PrimaryEmail
|
||||||
|
}
|
||||||
|
|
||||||
|
err = apiClient.Groups.List().
|
||||||
|
Context(ctx).
|
||||||
|
UserKey(providerUserID).
|
||||||
|
Pages(ctx, func(res *admin.Groups) error {
|
||||||
|
for _, g := range res.Groups {
|
||||||
|
du.GroupIds = append(du.GroupIds, g.Id)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("google: error getting groups for user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Strings(du.GroupIds)
|
||||||
|
|
||||||
|
return du, nil
|
||||||
|
}
|
||||||
|
|
||||||
// UserGroups returns a slice of group names a given user is in
|
// UserGroups returns a slice of group names a given user is in
|
||||||
// NOTE: groups via Directory API is limited to 1 QPS!
|
// NOTE: groups via Directory API is limited to 1 QPS!
|
||||||
// https://developers.google.com/admin-sdk/directory/v1/reference/groups/list
|
// https://developers.google.com/admin-sdk/directory/v1/reference/groups/list
|
||||||
|
@ -91,7 +143,7 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
|
||||||
var groups []*directory.Group
|
var groups []*directory.Group
|
||||||
err = apiClient.Groups.List().
|
err = apiClient.Groups.List().
|
||||||
Context(ctx).
|
Context(ctx).
|
||||||
Customer("my_customer").
|
Customer(currentAccountCustomerID).
|
||||||
Pages(ctx, func(res *admin.Groups) error {
|
Pages(ctx, func(res *admin.Groups) error {
|
||||||
for _, g := range res.Groups {
|
for _, g := range res.Groups {
|
||||||
// Skip group without member.
|
// Skip group without member.
|
||||||
|
@ -113,7 +165,7 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
|
||||||
userLookup := map[string]apiUserObject{}
|
userLookup := map[string]apiUserObject{}
|
||||||
err = apiClient.Users.List().
|
err = apiClient.Users.List().
|
||||||
Context(ctx).
|
Context(ctx).
|
||||||
Customer("my_customer").
|
Customer(currentAccountCustomerID).
|
||||||
Pages(ctx, func(res *admin.Users) error {
|
Pages(ctx, func(res *admin.Users) error {
|
||||||
for _, u := range res.Users {
|
for _, u := range res.Users {
|
||||||
userLookup[u.Id] = apiUserObject{
|
userLookup[u.Id] = apiUserObject{
|
||||||
|
@ -188,7 +240,7 @@ func (p *Provider) getAPIClient(ctx context.Context) (*admin.Service, error) {
|
||||||
|
|
||||||
ts := config.TokenSource(ctx)
|
ts := config.TokenSource(ctx)
|
||||||
|
|
||||||
p.apiClient, err = admin.NewService(ctx, option.WithTokenSource(ts))
|
p.apiClient, err = admin.NewService(ctx, option.WithTokenSource(ts), option.WithEndpoint(p.cfg.url))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("google: failed creating admin service %w", err)
|
return nil, fmt.Errorf("google: failed creating admin service %w", err)
|
||||||
}
|
}
|
||||||
|
@ -243,3 +295,16 @@ type apiUserObject struct {
|
||||||
DisplayName string
|
DisplayName string
|
||||||
Email string
|
Email string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isAccessDenied(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
gerr := new(googleapi.Error)
|
||||||
|
if errors.As(err, &gerr) {
|
||||||
|
return gerr.Code == http.StatusForbidden
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
140
internal/directory/google/google_test.go
Normal file
140
internal/directory/google/google_test.go
Normal file
|
@ -0,0 +1,140 @@
|
||||||
|
package google
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi"
|
||||||
|
"github.com/go-chi/chi/middleware"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
var privateKey = `
|
||||||
|
-----BEGIN RSA PRIVATE KEY-----
|
||||||
|
MIIG4wIBAAKCAYEAnetGqPqS6dqYnV9S5S8gL34t7RRUMsf4prxIR+1PMv+bEqVH
|
||||||
|
pzXBbzgGKIbQbu+njoRzhQ95RbcEzZiVFivNggijkFoUNkFIjy42O/xXdTRTX3/u
|
||||||
|
4pxu1ctccqYnZwnry6E8ekQTVX7kgmVqgzIrY1Y6K3PVlhkgGDK/TStu+RIPoA73
|
||||||
|
vJUpJTFTw+6tgUSBmCkzctQsGFGUiGBRpqlxogEkImJYrcMUJkWhTopLy79+3OB5
|
||||||
|
eGGKWwA6P2ZhBB4RKHBWyWipsYxr889QNG6P3o+Be6lIzvcdiIIBYK8qWLmO45hR
|
||||||
|
xUGWSRK8sveAJO54t+wGE0dSPVKfoS4oNqGYdGhBzoTwsfZRvyvcWeQob4DNCqQa
|
||||||
|
n41XAuYGOG3X1PexdwGSwwqrq2tuG9d2AJ2NjG8nC9hjuuKDfBGTigTwzrwkrn+F
|
||||||
|
3o94NoglQgsXfZeWoBXR5HDaDTdqexSRK0OpSPbvzkn8QDUymdaw7nRS7kU2O7fa
|
||||||
|
W8kxiV8AVt2v/jYjAgMBAAECggGAS+6NE0Mo0Pki2Mi0+y4ls7BgNNbJhYFRthpi
|
||||||
|
RvN8WXE+B0EhquzWDbxKecIZBr6FOqnFQf2muja+QH1VckutjRDKVOZ7QXsygGYf
|
||||||
|
/cff5aM7U3gYTS4avQIDeb0axRioIElu4vtIsJtLFMfe5yaAZktXvPz9fiamn/wG
|
||||||
|
r/xqZ6ifir6nsC2okxGczWE+XCGsjpWA/321lhvj548os5JV6SfTUBUpvqNGVQC2
|
||||||
|
ByXIPDffsCTfQ1rjQ85gM4vuqiQqKn/KXRrMR1kRIOrglMJ6dllitsadkfWdPkVg
|
||||||
|
fHjM1KAnw/uob5kFhvqeEll9sESfCXttrb4XneKAsEck958ChSpU4csaBAfLFVYP
|
||||||
|
5xyIfoaQ+/CUjWI0u1lQbg6jZO59rYfdd5OlH+MyhHybuHR1a0G1izNfuG9WPOWI
|
||||||
|
aprNayH2Wxy9/ZvlrE5yTAeW9tof28hO6O7wBNOcJTrzztsN+V8pSAo0IE2r4D83
|
||||||
|
h978LneAwhC/8mVvzhd/y2t99vcBAoHBAMumCoHmHRAsYBApg2SHxPCTDi4fS1hC
|
||||||
|
IbcuzyvJLv214bHOsdgg2I72a1Q+bbrZAHgENVgSP9Sx1k8aXlvnGOEbpL8w3jRL
|
||||||
|
G4/qXzGrMBp3LCzupi3dI3yrrIMWb2C0goyHeAejzrfaM+uDYTGW4iqhA39zBj4o
|
||||||
|
zoydz3v0i8Yag7Df9MIwr34WD9Ng0oXh8XRCAYJmS1e43jnM+XcFdSfGVhKn9h1B
|
||||||
|
Cbv/hqUSv6baNloWLlPBffLII5bx633MMwKBwQDGg87fKEPz73pbp1kAOuXDpUnm
|
||||||
|
+OUFuf6Uqr/OqCorh8SOwxb2q3ScqyCWGVLOaVBcSV+zxIdcQzX2kjf1mj7PcQ6Q
|
||||||
|
2xfDIS1/xT+RiX9LO0kbkVDYcwcGeKVtmUwWyjauo96OB2r+SchTsNJpYOT3a/7r
|
||||||
|
JUKdbHFwsFwAx5q7r9mOh0BOybuXM6N7lUDBf4SgrhjnKRh1pME3R0JbJj9m8tZg
|
||||||
|
SsWlHcj04yAXJ7NGemiiYgeDZ4unsAfx7/sS/lECgcEAsL0Yj2XTQU8Ry9ULaDsA
|
||||||
|
az1k6BhWvnEea6lfOQPwGVY5WqQk6oqPB3vK6CEKAEgGRSJ53UZxSTlR4fLjg2UL
|
||||||
|
zYm9MATMQ5wPfpYMKcIFDGLy3sf7RwCNpMwk+tuEq+vdBPMo85BxflQMDVBHEM9+
|
||||||
|
1zpIG9sKxvWJVLY89LnmeHZYZi/nboTsOUQSVgPIkVLmx1vljXMT3jzd+FHxCx+c
|
||||||
|
bnmOB8DnMrpYJWV9SFP+KmNlGkf3ys65bPPPF1g7ZUDLAoHAaKHqtQa1Imr0NED1
|
||||||
|
kUB6AHArjrlbhXQuck+5f4R1jbIm8RR1Exj2AunT6Cl60t8Bg1MNRWRt8DxgwhD5
|
||||||
|
u9NMDezKP6GrWacwIytlQSGW3aFm/EfQs/WVG10V3LmzOEPnJI+s63GPfG6JT0tg
|
||||||
|
7DgtFxhuKaTfAri45iueoq6SqSCb7Brv01dTL/QA1E+r7RF4Z3S8HYM0qDVpvegq
|
||||||
|
Wn7DZlDSm7htioUzeZgJPwsm3BwC8Kv4x9MY8g6/cU8LKEyxAoHAWCaDpLIuQ51r
|
||||||
|
PeL+u/1cfNdi6OOrtZ6S95tu3Vv+mYzpCPnOpgPHFp3l+RGmLg56t7uvHFaFxOvB
|
||||||
|
EjPm4bVhnPSA7pl7ZHQXhinG9+4UgcejoCAJzfg05BI1tMbwFZ+C0tG/PNzBlaX+
|
||||||
|
IwkGO8VP/54N6wL1UqfZ8AKJFZW8G7W7KVkjqye1FS4oeDlJ197t/X+PMn5sFAc7
|
||||||
|
UVsDaSelBqpsfmetXSH8KC3XkbgCtHvgAnJDkGkp84VmJvMr5ukv
|
||||||
|
-----END RSA PRIVATE KEY-----
|
||||||
|
`
|
||||||
|
|
||||||
|
type M = map[string]interface{}
|
||||||
|
|
||||||
|
func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
|
||||||
|
r := chi.NewRouter()
|
||||||
|
r.Use(middleware.Logger)
|
||||||
|
r.Post("/token", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_ = json.NewEncoder(w).Encode(M{
|
||||||
|
"access_token": "ACCESSTOKEN",
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"refresh_token": "REFRESHTOKEN",
|
||||||
|
})
|
||||||
|
})
|
||||||
|
r.Route("/groups", func(r chi.Router) {
|
||||||
|
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Query().Get("userKey") {
|
||||||
|
case "user1":
|
||||||
|
_ = json.NewEncoder(w).Encode(M{
|
||||||
|
"kind": "admin#directory#groups",
|
||||||
|
"groups": []M{
|
||||||
|
{"id": "group1"},
|
||||||
|
{"id": "group2"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
http.Error(w, "not found", http.StatusNotFound)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
r.Route("/users", func(r chi.Router) {
|
||||||
|
r.Get("/{user_id}", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch chi.URLParam(r, "user_id") {
|
||||||
|
case "user1":
|
||||||
|
_ = json.NewEncoder(w).Encode(M{
|
||||||
|
"kind": "admin#directory#user",
|
||||||
|
"id": "1",
|
||||||
|
"name": M{
|
||||||
|
"fullName": "User 1",
|
||||||
|
},
|
||||||
|
"primaryEmail": "user1@example.com",
|
||||||
|
})
|
||||||
|
case "user2":
|
||||||
|
http.Error(w, "forbidden", http.StatusForbidden)
|
||||||
|
default:
|
||||||
|
http.Error(w, "not found", http.StatusNotFound)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProvider_User(t *testing.T) {
|
||||||
|
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*30)
|
||||||
|
defer clearTimeout()
|
||||||
|
|
||||||
|
var mockAPI http.Handler
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
mockAPI.ServeHTTP(w, r)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
mockAPI = newMockAPI(t, srv)
|
||||||
|
|
||||||
|
p := New(WithServiceAccount(&ServiceAccount{
|
||||||
|
Type: "service_account",
|
||||||
|
PrivateKey: privateKey,
|
||||||
|
TokenURL: srv.URL + "/token",
|
||||||
|
}), WithURL(srv.URL))
|
||||||
|
|
||||||
|
du, err := p.User(ctx, "user1", "")
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
assert.Equal(t, "user1", du.Id)
|
||||||
|
assert.Equal(t, "user1@example.com", du.Email)
|
||||||
|
assert.Equal(t, "User 1", du.DisplayName)
|
||||||
|
assert.Equal(t, []string{"group1", "group2"}, du.GroupIds)
|
||||||
|
|
||||||
|
du, err = p.User(ctx, "user2", "")
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
assert.Equal(t, "user2", du.Id)
|
||||||
|
}
|
|
@ -108,6 +108,36 @@ func New(options ...Option) *Provider {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// User returns the user record for the given id.
|
||||||
|
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
|
||||||
|
if p.cfg.serviceAccount == nil {
|
||||||
|
return nil, ErrServiceAccountNotDefined
|
||||||
|
}
|
||||||
|
|
||||||
|
_, providerUserID := databroker.FromUserID(userID)
|
||||||
|
du := &directory.User{
|
||||||
|
Id: userID,
|
||||||
|
}
|
||||||
|
|
||||||
|
au, err := p.getUser(ctx, providerUserID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
du.DisplayName = au.getDisplayName()
|
||||||
|
du.Email = au.Profile.Email
|
||||||
|
|
||||||
|
groups, err := p.listUserGroups(ctx, providerUserID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, g := range groups {
|
||||||
|
du.GroupIds = append(du.GroupIds, g.ID)
|
||||||
|
}
|
||||||
|
sort.Strings(du.GroupIds)
|
||||||
|
|
||||||
|
return du, nil
|
||||||
|
}
|
||||||
|
|
||||||
// UserGroups fetches the groups of which the user is a member
|
// UserGroups fetches the groups of which the user is a member
|
||||||
// https://developer.okta.com/docs/reference/api/users/#get-user-s-groups
|
// https://developer.okta.com/docs/reference/api/users/#get-user-s-groups
|
||||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
||||||
|
@ -159,7 +189,7 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
|
||||||
users = append(users, &directory.User{
|
users = append(users, &directory.User{
|
||||||
Id: databroker.GetUserID(Name, u.ID),
|
Id: databroker.GetUserID(Name, u.ID),
|
||||||
GroupIds: groups,
|
GroupIds: groups,
|
||||||
DisplayName: u.Profile.FirstName + " " + u.Profile.LastName,
|
DisplayName: u.getDisplayName(),
|
||||||
Email: u.Profile.Email,
|
Email: u.Profile.Email,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -183,14 +213,7 @@ func (p *Provider) getGroups(ctx context.Context) ([]*directory.Group, error) {
|
||||||
|
|
||||||
groupURL := p.cfg.providerURL.ResolveReference(u).String()
|
groupURL := p.cfg.providerURL.ResolveReference(u).String()
|
||||||
for groupURL != "" {
|
for groupURL != "" {
|
||||||
var out []struct {
|
var out []apiGroupObject
|
||||||
ID string `json:"id"`
|
|
||||||
Profile struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
} `json:"profile"`
|
|
||||||
LastUpdated string `json:"lastUpdated"`
|
|
||||||
LastMembershipUpdated string `json:"lastMembershipUpdated"`
|
|
||||||
}
|
|
||||||
hdrs, err := p.apiGet(ctx, groupURL, &out)
|
hdrs, err := p.apiGet(ctx, groupURL, &out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("okta: error querying for groups: %w", err)
|
return nil, fmt.Errorf("okta: error querying for groups: %w", err)
|
||||||
|
@ -239,6 +262,36 @@ func (p *Provider) getGroupMembers(ctx context.Context, groupID string) (users [
|
||||||
return users, nil
|
return users, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *Provider) getUser(ctx context.Context, userID string) (*apiUserObject, error) {
|
||||||
|
apiURL := p.cfg.providerURL.ResolveReference(&url.URL{
|
||||||
|
Path: fmt.Sprintf("/api/v1/users/%s", userID),
|
||||||
|
}).String()
|
||||||
|
|
||||||
|
var out apiUserObject
|
||||||
|
_, err := p.apiGet(ctx, apiURL, &out)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("okta: error querying for user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Provider) listUserGroups(ctx context.Context, userID string) (groups []apiGroupObject, err error) {
|
||||||
|
apiURL := p.cfg.providerURL.ResolveReference(&url.URL{
|
||||||
|
Path: fmt.Sprintf("/api/v1/users/%s/groups", userID),
|
||||||
|
}).String()
|
||||||
|
for apiURL != "" {
|
||||||
|
var out []apiGroupObject
|
||||||
|
hdrs, err := p.apiGet(ctx, apiURL, &out)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("okta: error querying for user groups: %w", err)
|
||||||
|
}
|
||||||
|
groups = append(groups, out...)
|
||||||
|
apiURL = getNextLink(hdrs)
|
||||||
|
}
|
||||||
|
return groups, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (p *Provider) apiGet(ctx context.Context, uri string, out interface{}) (http.Header, error) {
|
func (p *Provider) apiGet(ctx context.Context, uri string, out interface{}) (http.Header, error) {
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", uri, nil)
|
req, err := http.NewRequestWithContext(ctx, "GET", uri, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -334,7 +387,16 @@ func (err *APIError) Error() string {
|
||||||
return fmt.Sprintf("okta: error querying API, status_code=%d: %s", err.HTTPStatusCode, err.Body)
|
return fmt.Sprintf("okta: error querying API, status_code=%d: %s", err.HTTPStatusCode, err.Body)
|
||||||
}
|
}
|
||||||
|
|
||||||
type apiUserObject struct {
|
type (
|
||||||
|
apiGroupObject struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Profile struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
} `json:"profile"`
|
||||||
|
LastUpdated string `json:"lastUpdated"`
|
||||||
|
LastMembershipUpdated string `json:"lastMembershipUpdated"`
|
||||||
|
}
|
||||||
|
apiUserObject struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
Profile struct {
|
Profile struct {
|
||||||
FirstName string `json:"firstName"`
|
FirstName string `json:"firstName"`
|
||||||
|
@ -342,3 +404,8 @@ type apiUserObject struct {
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
} `json:"profile"`
|
} `json:"profile"`
|
||||||
}
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func (obj *apiUserObject) getDisplayName() string {
|
||||||
|
return obj.Profile.FirstName + " " + obj.Profile.LastName
|
||||||
|
}
|
||||||
|
|
|
@ -16,6 +16,7 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tomnomnom/linkheader"
|
"github.com/tomnomnom/linkheader"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/testutil"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -43,7 +44,9 @@ func newMockOkta(srv *httptest.Server, userEmailToGroups map[string][]string) ht
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
r.Get("/api/v1/groups", func(w http.ResponseWriter, r *http.Request) {
|
r.Route("/api/v1", func(r chi.Router) {
|
||||||
|
r.Route("/groups", func(r chi.Router) {
|
||||||
|
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
||||||
lastUpdated := strings.Contains(r.URL.Query().Get("filter"), "lastUpdated ")
|
lastUpdated := strings.Contains(r.URL.Query().Get("filter"), "lastUpdated ")
|
||||||
var groups []string
|
var groups []string
|
||||||
for group := range getAllGroups() {
|
for group := range getAllGroups() {
|
||||||
|
@ -86,7 +89,7 @@ func newMockOkta(srv *httptest.Server, userEmailToGroups map[string][]string) ht
|
||||||
|
|
||||||
_ = json.NewEncoder(w).Encode(result)
|
_ = json.NewEncoder(w).Encode(result)
|
||||||
})
|
})
|
||||||
r.Get("/api/v1/groups/{group}/users", func(w http.ResponseWriter, r *http.Request) {
|
r.Get("/{group}/users", func(w http.ResponseWriter, r *http.Request) {
|
||||||
group := chi.URLParam(r, "group")
|
group := chi.URLParam(r, "group")
|
||||||
|
|
||||||
if _, ok := getAllGroups()[group]; !ok {
|
if _, ok := getAllGroups()[group]; !ok {
|
||||||
|
@ -122,9 +125,61 @@ func newMockOkta(srv *httptest.Server, userEmailToGroups map[string][]string) ht
|
||||||
|
|
||||||
_ = json.NewEncoder(w).Encode(result)
|
_ = json.NewEncoder(w).Encode(result)
|
||||||
})
|
})
|
||||||
|
})
|
||||||
|
r.Route("/users", func(r chi.Router) {
|
||||||
|
r.Get("/{user_id}/groups", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var groups []apiGroupObject
|
||||||
|
for _, nm := range userEmailToGroups[chi.URLParam(r, "user_id")] {
|
||||||
|
obj := apiGroupObject{
|
||||||
|
ID: nm,
|
||||||
|
}
|
||||||
|
obj.Profile.Name = nm
|
||||||
|
groups = append(groups, obj)
|
||||||
|
}
|
||||||
|
_ = json.NewEncoder(w).Encode(groups)
|
||||||
|
})
|
||||||
|
r.Get("/{user_id}", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
user := apiUserObject{
|
||||||
|
ID: chi.URLParam(r, "user_id"),
|
||||||
|
}
|
||||||
|
user.Profile.Email = chi.URLParam(r, "user_id")
|
||||||
|
user.Profile.FirstName = "first"
|
||||||
|
user.Profile.LastName = "last"
|
||||||
|
_ = json.NewEncoder(w).Encode(user)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProvider_User(t *testing.T) {
|
||||||
|
var mockOkta http.Handler
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
mockOkta.ServeHTTP(w, r)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
mockOkta = newMockOkta(srv, map[string][]string{
|
||||||
|
"a@example.com": {"user", "admin"},
|
||||||
|
"b@example.com": {"user", "test"},
|
||||||
|
"c@example.com": {"user"},
|
||||||
|
})
|
||||||
|
|
||||||
|
p := New(
|
||||||
|
WithServiceAccount(&ServiceAccount{APIKey: "APITOKEN"}),
|
||||||
|
WithProviderURL(mustParseURL(srv.URL)),
|
||||||
|
)
|
||||||
|
user, err := p.User(context.Background(), "okta/a@example.com", "")
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
testutil.AssertProtoJSONEqual(t, `{
|
||||||
|
"id": "okta/a@example.com",
|
||||||
|
"groupIds": ["admin","user"],
|
||||||
|
"displayName": "first last",
|
||||||
|
"email": "a@example.com"
|
||||||
|
}`, user)
|
||||||
|
}
|
||||||
|
|
||||||
func TestProvider_UserGroups(t *testing.T) {
|
func TestProvider_UserGroups(t *testing.T) {
|
||||||
var mockOkta http.Handler
|
var mockOkta http.Handler
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
|
@ -94,6 +94,32 @@ func New(options ...Option) *Provider {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// User returns the user record for the given id.
|
||||||
|
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
|
||||||
|
if p.cfg.serviceAccount == nil {
|
||||||
|
return nil, fmt.Errorf("onelogin: service account not defined")
|
||||||
|
}
|
||||||
|
_, providerUserID := databroker.FromUserID(userID)
|
||||||
|
du := &directory.User{
|
||||||
|
Id: userID,
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := p.getToken(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
au, err := p.getUser(ctx, token.AccessToken, providerUserID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
du.DisplayName = au.getDisplayName()
|
||||||
|
du.Email = au.Email
|
||||||
|
du.GroupIds = []string{strconv.Itoa(au.GroupID)}
|
||||||
|
|
||||||
|
return du, nil
|
||||||
|
}
|
||||||
|
|
||||||
// UserGroups gets the directory user groups for onelogin.
|
// UserGroups gets the directory user groups for onelogin.
|
||||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
||||||
if p.cfg.serviceAccount == nil {
|
if p.cfg.serviceAccount == nil {
|
||||||
|
@ -107,12 +133,12 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
groups, err := p.listGroups(ctx, token)
|
groups, err := p.listGroups(ctx, token.AccessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
apiUsers, err := p.getUsers(ctx, token)
|
apiUsers, err := p.listUsers(ctx, token.AccessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
@ -133,7 +159,7 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
|
||||||
return groups, users, nil
|
return groups, users, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Provider) listGroups(ctx context.Context, token *oauth2.Token) ([]*directory.Group, error) {
|
func (p *Provider) listGroups(ctx context.Context, accessToken string) ([]*directory.Group, error) {
|
||||||
var groups []*directory.Group
|
var groups []*directory.Group
|
||||||
apiURL := p.cfg.apiURL.ResolveReference(&url.URL{
|
apiURL := p.cfg.apiURL.ResolveReference(&url.URL{
|
||||||
Path: "/api/1/groups",
|
Path: "/api/1/groups",
|
||||||
|
@ -144,9 +170,9 @@ func (p *Provider) listGroups(ctx context.Context, token *oauth2.Token) ([]*dire
|
||||||
ID int `json:"id"`
|
ID int `json:"id"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
}
|
}
|
||||||
nextLink, err := p.apiGet(ctx, token, apiURL, &result)
|
nextLink, err := p.apiGet(ctx, accessToken, apiURL, &result)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("onelogin: error querying group api: %w", err)
|
return nil, fmt.Errorf("onelogin: listing groups: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, r := range result {
|
for _, r := range result {
|
||||||
|
@ -161,7 +187,24 @@ func (p *Provider) listGroups(ctx context.Context, token *oauth2.Token) ([]*dire
|
||||||
return groups, nil
|
return groups, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Provider) getUsers(ctx context.Context, token *oauth2.Token) ([]apiUserObject, error) {
|
func (p *Provider) getUser(ctx context.Context, accessToken string, userID string) (*apiUserObject, error) {
|
||||||
|
apiURL := p.cfg.apiURL.ResolveReference(&url.URL{
|
||||||
|
Path: fmt.Sprintf("/api/1/users/%s", userID),
|
||||||
|
}).String()
|
||||||
|
|
||||||
|
var out []apiUserObject
|
||||||
|
_, err := p.apiGet(ctx, accessToken, apiURL, &out)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("onelogin: error getting user: %w", err)
|
||||||
|
}
|
||||||
|
if len(out) == 0 {
|
||||||
|
return nil, fmt.Errorf("onelogin: user not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &out[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Provider) listUsers(ctx context.Context, accessToken string) ([]apiUserObject, error) {
|
||||||
var users []apiUserObject
|
var users []apiUserObject
|
||||||
|
|
||||||
apiURL := p.cfg.apiURL.ResolveReference(&url.URL{
|
apiURL := p.cfg.apiURL.ResolveReference(&url.URL{
|
||||||
|
@ -170,9 +213,9 @@ func (p *Provider) getUsers(ctx context.Context, token *oauth2.Token) ([]apiUser
|
||||||
}).String()
|
}).String()
|
||||||
for apiURL != "" {
|
for apiURL != "" {
|
||||||
var result []apiUserObject
|
var result []apiUserObject
|
||||||
nextLink, err := p.apiGet(ctx, token, apiURL, &result)
|
nextLink, err := p.apiGet(ctx, accessToken, apiURL, &result)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("onelogin: listing users: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
users = append(users, result...)
|
users = append(users, result...)
|
||||||
|
@ -182,12 +225,12 @@ func (p *Provider) getUsers(ctx context.Context, token *oauth2.Token) ([]apiUser
|
||||||
return users, nil
|
return users, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Provider) apiGet(ctx context.Context, token *oauth2.Token, uri string, out interface{}) (nextLink string, err error) {
|
func (p *Provider) apiGet(ctx context.Context, accessToken string, uri string, out interface{}) (nextLink string, err error) {
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", uri, nil)
|
req, err := http.NewRequestWithContext(ctx, "GET", uri, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
req.Header.Set("Authorization", fmt.Sprintf("bearer:%s", token.AccessToken))
|
req.Header.Set("Authorization", fmt.Sprintf("bearer:%s", accessToken))
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
res, err := p.cfg.httpClient.Do(req)
|
res, err := p.cfg.httpClient.Do(req)
|
||||||
|
@ -307,3 +350,7 @@ type apiUserObject struct {
|
||||||
FirstName string `json:"firstname"`
|
FirstName string `json:"firstname"`
|
||||||
LastName string `json:"lastname"`
|
LastName string `json:"lastname"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (obj *apiUserObject) getDisplayName() string {
|
||||||
|
return obj.FirstName + " " + obj.LastName
|
||||||
|
}
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
"sort"
|
"sort"
|
||||||
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -102,6 +103,28 @@ func newMockAPI(srv *httptest.Server, userIDToGroupName map[int]string) http.Han
|
||||||
|
|
||||||
_ = json.NewEncoder(w).Encode(result)
|
_ = json.NewEncoder(w).Encode(result)
|
||||||
})
|
})
|
||||||
|
r.Get("/users/{user_id}", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userIDToGroupID := map[int]int{}
|
||||||
|
for userID, groupName := range userIDToGroupName {
|
||||||
|
for id, n := range allGroups {
|
||||||
|
if groupName == n {
|
||||||
|
userIDToGroupID[userID] = id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
userID, _ := strconv.Atoi(chi.URLParam(r, "user_id"))
|
||||||
|
|
||||||
|
_ = json.NewEncoder(w).Encode(M{
|
||||||
|
"data": []M{{
|
||||||
|
"id": userID,
|
||||||
|
"email": userIDToGroupName[userID] + "@example.com",
|
||||||
|
"group_id": userIDToGroupID[userID],
|
||||||
|
"firstname": "User",
|
||||||
|
"lastname": fmt.Sprint(userID),
|
||||||
|
}},
|
||||||
|
})
|
||||||
|
})
|
||||||
r.Get("/users", func(w http.ResponseWriter, r *http.Request) {
|
r.Get("/users", func(w http.ResponseWriter, r *http.Request) {
|
||||||
userIDToGroupID := map[int]int{}
|
userIDToGroupID := map[int]int{}
|
||||||
for userID, groupName := range userIDToGroupName {
|
for userID, groupName := range userIDToGroupName {
|
||||||
|
@ -130,6 +153,37 @@ func newMockAPI(srv *httptest.Server, userIDToGroupName map[int]string) http.Han
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProvider_User(t *testing.T) {
|
||||||
|
var mockAPI http.Handler
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
mockAPI.ServeHTTP(w, r)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
mockAPI = newMockAPI(srv, map[int]string{
|
||||||
|
111: "admin",
|
||||||
|
222: "test",
|
||||||
|
333: "user",
|
||||||
|
})
|
||||||
|
|
||||||
|
p := New(
|
||||||
|
WithServiceAccount(&ServiceAccount{
|
||||||
|
ClientID: "CLIENTID",
|
||||||
|
ClientSecret: "CLIENTSECRET",
|
||||||
|
}),
|
||||||
|
WithURL(mustParseURL(srv.URL)),
|
||||||
|
)
|
||||||
|
user, err := p.User(context.Background(), "onelogin/111", "ACCESSTOKEN")
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
testutil.AssertProtoJSONEqual(t, `{
|
||||||
|
"id": "onelogin/111",
|
||||||
|
"groupIds": ["0"],
|
||||||
|
"displayName": "User 111",
|
||||||
|
"email": "admin@example.com"
|
||||||
|
}`, user)
|
||||||
|
}
|
||||||
|
|
||||||
func TestProvider_UserGroups(t *testing.T) {
|
func TestProvider_UserGroups(t *testing.T) {
|
||||||
var mockAPI http.Handler
|
var mockAPI http.Handler
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
|
@ -28,8 +28,12 @@ type User = directory.User
|
||||||
// Options are the options specific to the provider.
|
// Options are the options specific to the provider.
|
||||||
type Options = directory.Options
|
type Options = directory.Options
|
||||||
|
|
||||||
|
// RegisterDirectoryServiceServer registers the directory gRPC service.
|
||||||
|
var RegisterDirectoryServiceServer = directory.RegisterDirectoryServiceServer
|
||||||
|
|
||||||
// A Provider provides user group directory information.
|
// A Provider provides user group directory information.
|
||||||
type Provider interface {
|
type Provider interface {
|
||||||
|
User(ctx context.Context, userID, accessToken string) (*User, error)
|
||||||
UserGroups(ctx context.Context) ([]*Group, []*User, error)
|
UserGroups(ctx context.Context) ([]*Group, []*User, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -55,7 +59,7 @@ func GetProvider(options Options) (provider Provider) {
|
||||||
|
|
||||||
switch options.Provider {
|
switch options.Provider {
|
||||||
case auth0.Name:
|
case auth0.Name:
|
||||||
serviceAccount, err := auth0.ParseServiceAccount(options.ServiceAccount)
|
serviceAccount, err := auth0.ParseServiceAccount(options)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return auth0.New(
|
return auth0.New(
|
||||||
auth0.WithDomain(options.ProviderURL),
|
auth0.WithDomain(options.ProviderURL),
|
||||||
|
@ -139,6 +143,10 @@ func GetProvider(options Options) (provider Provider) {
|
||||||
|
|
||||||
type nullProvider struct{}
|
type nullProvider struct{}
|
||||||
|
|
||||||
|
func (nullProvider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (nullProvider) UserGroups(ctx context.Context) ([]*Group, []*User, error) {
|
func (nullProvider) UserGroups(ctx context.Context) ([]*Group, []*User, error) {
|
||||||
return nil, nil, nil
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,11 +1,22 @@
|
||||||
// Package databroker contains databroker protobuf definitions.
|
// Package databroker contains databroker protobuf definitions.
|
||||||
package databroker
|
package databroker
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
// GetUserID gets the databroker user id from a provider user id.
|
// GetUserID gets the databroker user id from a provider user id.
|
||||||
func GetUserID(provider, providerUserID string) string {
|
func GetUserID(provider, providerUserID string) string {
|
||||||
return provider + "/" + providerUserID
|
return provider + "/" + providerUserID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FromUserID gets the provider and provider user id from a databroker user id.
|
||||||
|
func FromUserID(userID string) (provider, providerUserID string) {
|
||||||
|
ps := strings.SplitN(userID, "/", 2)
|
||||||
|
if len(ps) < 2 {
|
||||||
|
return "", userID
|
||||||
|
}
|
||||||
|
return ps[0], ps[1]
|
||||||
|
}
|
||||||
|
|
||||||
// ApplyOffsetAndLimit applies the offset and limit to the list of records.
|
// ApplyOffsetAndLimit applies the offset and limit to the list of records.
|
||||||
func ApplyOffsetAndLimit(all []*Record, offset, limit int) (records []*Record, totalCount int) {
|
func ApplyOffsetAndLimit(all []*Record, offset, limit int) (records []*Record, totalCount int) {
|
||||||
records = all
|
records = all
|
||||||
|
|
|
@ -7,7 +7,12 @@
|
||||||
package directory
|
package directory
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
context "context"
|
||||||
proto "github.com/golang/protobuf/proto"
|
proto "github.com/golang/protobuf/proto"
|
||||||
|
empty "github.com/golang/protobuf/ptypes/empty"
|
||||||
|
grpc "google.golang.org/grpc"
|
||||||
|
codes "google.golang.org/grpc/codes"
|
||||||
|
status "google.golang.org/grpc/status"
|
||||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||||
reflect "reflect"
|
reflect "reflect"
|
||||||
|
@ -175,29 +180,97 @@ func (x *Group) GetEmail() string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type RefreshUserRequest struct {
|
||||||
|
state protoimpl.MessageState
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
|
||||||
|
UserId string `protobuf:"bytes,1,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"`
|
||||||
|
AccessToken string `protobuf:"bytes,2,opt,name=access_token,json=accessToken,proto3" json:"access_token,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *RefreshUserRequest) Reset() {
|
||||||
|
*x = RefreshUserRequest{}
|
||||||
|
if protoimpl.UnsafeEnabled {
|
||||||
|
mi := &file_directory_proto_msgTypes[2]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *RefreshUserRequest) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*RefreshUserRequest) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *RefreshUserRequest) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_directory_proto_msgTypes[2]
|
||||||
|
if protoimpl.UnsafeEnabled && x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use RefreshUserRequest.ProtoReflect.Descriptor instead.
|
||||||
|
func (*RefreshUserRequest) Descriptor() ([]byte, []int) {
|
||||||
|
return file_directory_proto_rawDescGZIP(), []int{2}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *RefreshUserRequest) GetUserId() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.UserId
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *RefreshUserRequest) GetAccessToken() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.AccessToken
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
var File_directory_proto protoreflect.FileDescriptor
|
var File_directory_proto protoreflect.FileDescriptor
|
||||||
|
|
||||||
var file_directory_proto_rawDesc = []byte{
|
var file_directory_proto_rawDesc = []byte{
|
||||||
0x0a, 0x0f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74,
|
0x0a, 0x0f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74,
|
||||||
0x6f, 0x12, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x22, 0x86, 0x01, 0x0a,
|
0x6f, 0x12, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x1a, 0x1b, 0x67, 0x6f,
|
||||||
0x04, 0x55, 0x73, 0x65, 0x72, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e,
|
0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x65, 0x6d,
|
||||||
0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12,
|
0x70, 0x74, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x86, 0x01, 0x0a, 0x04, 0x55, 0x73,
|
||||||
0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12,
|
0x65, 0x72, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20,
|
||||||
0x1b, 0x0a, 0x09, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x03, 0x20, 0x03,
|
0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x0e, 0x0a, 0x02,
|
||||||
0x28, 0x09, 0x52, 0x08, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x49, 0x64, 0x73, 0x12, 0x21, 0x0a, 0x0c,
|
0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x1b, 0x0a, 0x09,
|
||||||
0x64, 0x69, 0x73, 0x70, 0x6c, 0x61, 0x79, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x04, 0x20, 0x01,
|
0x67, 0x72, 0x6f, 0x75, 0x70, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52,
|
||||||
0x28, 0x09, 0x52, 0x0b, 0x64, 0x69, 0x73, 0x70, 0x6c, 0x61, 0x79, 0x4e, 0x61, 0x6d, 0x65, 0x12,
|
0x08, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x49, 0x64, 0x73, 0x12, 0x21, 0x0a, 0x0c, 0x64, 0x69, 0x73,
|
||||||
0x14, 0x0a, 0x05, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05,
|
0x70, 0x6c, 0x61, 0x79, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52,
|
||||||
0x65, 0x6d, 0x61, 0x69, 0x6c, 0x22, 0x5b, 0x0a, 0x05, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x18,
|
0x0b, 0x64, 0x69, 0x73, 0x70, 0x6c, 0x61, 0x79, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x14, 0x0a, 0x05,
|
||||||
0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52,
|
0x65, 0x6d, 0x61, 0x69, 0x6c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x6d, 0x61,
|
||||||
0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02,
|
0x69, 0x6c, 0x22, 0x5b, 0x0a, 0x05, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x18, 0x0a, 0x07, 0x76,
|
||||||
0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65,
|
0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65,
|
||||||
0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x14, 0x0a, 0x05,
|
0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28,
|
||||||
0x65, 0x6d, 0x61, 0x69, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x6d, 0x61,
|
0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x03, 0x20,
|
||||||
0x69, 0x6c, 0x42, 0x31, 0x5a, 0x2f, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d,
|
0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x6d, 0x61,
|
||||||
0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69,
|
0x69, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x22,
|
||||||
0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x64, 0x69, 0x72, 0x65,
|
0x50, 0x0a, 0x12, 0x52, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65,
|
||||||
0x63, 0x74, 0x6f, 0x72, 0x79, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x17, 0x0a, 0x07, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x69, 0x64,
|
||||||
|
0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x64, 0x12, 0x21,
|
||||||
|
0x0a, 0x0c, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x02,
|
||||||
|
0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x54, 0x6f, 0x6b, 0x65,
|
||||||
|
0x6e, 0x32, 0x58, 0x0a, 0x10, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x53, 0x65,
|
||||||
|
0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x44, 0x0a, 0x0b, 0x52, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68,
|
||||||
|
0x55, 0x73, 0x65, 0x72, 0x12, 0x1d, 0x2e, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79,
|
||||||
|
0x2e, 0x52, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75,
|
||||||
|
0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f,
|
||||||
|
0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x42, 0x31, 0x5a, 0x2f, 0x67,
|
||||||
|
0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69,
|
||||||
|
0x75, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67, 0x2f,
|
||||||
|
0x67, 0x72, 0x70, 0x63, 0x2f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x62, 0x06,
|
||||||
|
0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -212,14 +285,18 @@ func file_directory_proto_rawDescGZIP() []byte {
|
||||||
return file_directory_proto_rawDescData
|
return file_directory_proto_rawDescData
|
||||||
}
|
}
|
||||||
|
|
||||||
var file_directory_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
|
var file_directory_proto_msgTypes = make([]protoimpl.MessageInfo, 3)
|
||||||
var file_directory_proto_goTypes = []interface{}{
|
var file_directory_proto_goTypes = []interface{}{
|
||||||
(*User)(nil), // 0: directory.User
|
(*User)(nil), // 0: directory.User
|
||||||
(*Group)(nil), // 1: directory.Group
|
(*Group)(nil), // 1: directory.Group
|
||||||
|
(*RefreshUserRequest)(nil), // 2: directory.RefreshUserRequest
|
||||||
|
(*empty.Empty)(nil), // 3: google.protobuf.Empty
|
||||||
}
|
}
|
||||||
var file_directory_proto_depIdxs = []int32{
|
var file_directory_proto_depIdxs = []int32{
|
||||||
0, // [0:0] is the sub-list for method output_type
|
2, // 0: directory.DirectoryService.RefreshUser:input_type -> directory.RefreshUserRequest
|
||||||
0, // [0:0] is the sub-list for method input_type
|
3, // 1: directory.DirectoryService.RefreshUser:output_type -> google.protobuf.Empty
|
||||||
|
1, // [1:2] is the sub-list for method output_type
|
||||||
|
0, // [0:1] is the sub-list for method input_type
|
||||||
0, // [0:0] is the sub-list for extension type_name
|
0, // [0:0] is the sub-list for extension type_name
|
||||||
0, // [0:0] is the sub-list for extension extendee
|
0, // [0:0] is the sub-list for extension extendee
|
||||||
0, // [0:0] is the sub-list for field type_name
|
0, // [0:0] is the sub-list for field type_name
|
||||||
|
@ -255,6 +332,18 @@ func file_directory_proto_init() {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
file_directory_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
|
||||||
|
switch v := v.(*RefreshUserRequest); i {
|
||||||
|
case 0:
|
||||||
|
return &v.state
|
||||||
|
case 1:
|
||||||
|
return &v.sizeCache
|
||||||
|
case 2:
|
||||||
|
return &v.unknownFields
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
type x struct{}
|
type x struct{}
|
||||||
out := protoimpl.TypeBuilder{
|
out := protoimpl.TypeBuilder{
|
||||||
|
@ -262,9 +351,9 @@ func file_directory_proto_init() {
|
||||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||||
RawDescriptor: file_directory_proto_rawDesc,
|
RawDescriptor: file_directory_proto_rawDesc,
|
||||||
NumEnums: 0,
|
NumEnums: 0,
|
||||||
NumMessages: 2,
|
NumMessages: 3,
|
||||||
NumExtensions: 0,
|
NumExtensions: 0,
|
||||||
NumServices: 0,
|
NumServices: 1,
|
||||||
},
|
},
|
||||||
GoTypes: file_directory_proto_goTypes,
|
GoTypes: file_directory_proto_goTypes,
|
||||||
DependencyIndexes: file_directory_proto_depIdxs,
|
DependencyIndexes: file_directory_proto_depIdxs,
|
||||||
|
@ -275,3 +364,83 @@ func file_directory_proto_init() {
|
||||||
file_directory_proto_goTypes = nil
|
file_directory_proto_goTypes = nil
|
||||||
file_directory_proto_depIdxs = nil
|
file_directory_proto_depIdxs = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reference imports to suppress errors if they are not otherwise used.
|
||||||
|
var _ context.Context
|
||||||
|
var _ grpc.ClientConnInterface
|
||||||
|
|
||||||
|
// This is a compile-time assertion to ensure that this generated file
|
||||||
|
// is compatible with the grpc package it is being compiled against.
|
||||||
|
const _ = grpc.SupportPackageIsVersion6
|
||||||
|
|
||||||
|
// DirectoryServiceClient is the client API for DirectoryService service.
|
||||||
|
//
|
||||||
|
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.
|
||||||
|
type DirectoryServiceClient interface {
|
||||||
|
RefreshUser(ctx context.Context, in *RefreshUserRequest, opts ...grpc.CallOption) (*empty.Empty, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type directoryServiceClient struct {
|
||||||
|
cc grpc.ClientConnInterface
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDirectoryServiceClient(cc grpc.ClientConnInterface) DirectoryServiceClient {
|
||||||
|
return &directoryServiceClient{cc}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *directoryServiceClient) RefreshUser(ctx context.Context, in *RefreshUserRequest, opts ...grpc.CallOption) (*empty.Empty, error) {
|
||||||
|
out := new(empty.Empty)
|
||||||
|
err := c.cc.Invoke(ctx, "/directory.DirectoryService/RefreshUser", in, out, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DirectoryServiceServer is the server API for DirectoryService service.
|
||||||
|
type DirectoryServiceServer interface {
|
||||||
|
RefreshUser(context.Context, *RefreshUserRequest) (*empty.Empty, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnimplementedDirectoryServiceServer can be embedded to have forward compatible implementations.
|
||||||
|
type UnimplementedDirectoryServiceServer struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*UnimplementedDirectoryServiceServer) RefreshUser(context.Context, *RefreshUserRequest) (*empty.Empty, error) {
|
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method RefreshUser not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisterDirectoryServiceServer(s *grpc.Server, srv DirectoryServiceServer) {
|
||||||
|
s.RegisterService(&_DirectoryService_serviceDesc, srv)
|
||||||
|
}
|
||||||
|
|
||||||
|
func _DirectoryService_RefreshUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
|
in := new(RefreshUserRequest)
|
||||||
|
if err := dec(in); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if interceptor == nil {
|
||||||
|
return srv.(DirectoryServiceServer).RefreshUser(ctx, in)
|
||||||
|
}
|
||||||
|
info := &grpc.UnaryServerInfo{
|
||||||
|
Server: srv,
|
||||||
|
FullMethod: "/directory.DirectoryService/RefreshUser",
|
||||||
|
}
|
||||||
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
return srv.(DirectoryServiceServer).RefreshUser(ctx, req.(*RefreshUserRequest))
|
||||||
|
}
|
||||||
|
return interceptor(ctx, in, info, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
var _DirectoryService_serviceDesc = grpc.ServiceDesc{
|
||||||
|
ServiceName: "directory.DirectoryService",
|
||||||
|
HandlerType: (*DirectoryServiceServer)(nil),
|
||||||
|
Methods: []grpc.MethodDesc{
|
||||||
|
{
|
||||||
|
MethodName: "RefreshUser",
|
||||||
|
Handler: _DirectoryService_RefreshUser_Handler,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Streams: []grpc.StreamDesc{},
|
||||||
|
Metadata: "directory.proto",
|
||||||
|
}
|
||||||
|
|
|
@ -3,6 +3,8 @@ syntax = "proto3";
|
||||||
package directory;
|
package directory;
|
||||||
option go_package = "github.com/pomerium/pomerium/pkg/grpc/directory";
|
option go_package = "github.com/pomerium/pomerium/pkg/grpc/directory";
|
||||||
|
|
||||||
|
import "google/protobuf/empty.proto";
|
||||||
|
|
||||||
message User {
|
message User {
|
||||||
string version = 1;
|
string version = 1;
|
||||||
string id = 2;
|
string id = 2;
|
||||||
|
@ -17,3 +19,12 @@ message Group {
|
||||||
string name = 3;
|
string name = 3;
|
||||||
string email = 4;
|
string email = 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
message RefreshUserRequest {
|
||||||
|
string user_id = 1;
|
||||||
|
string access_token = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
service DirectoryService {
|
||||||
|
rpc RefreshUser(RefreshUserRequest) returns (google.protobuf.Empty);
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue