all: support route scoped sessions

Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
This commit is contained in:
Bobby DeSimone 2019-11-06 17:30:27 -08:00 committed by Bobby DeSimone
parent 83342112bb
commit d3d60d1055
53 changed files with 2092 additions and 2416 deletions

View file

@ -2,14 +2,9 @@ package identity // import "github.com/pomerium/pomerium/internal/identity"
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"time"
oidc "github.com/pomerium/go-oidc"
"golang.org/x/oauth2"
@ -17,6 +12,7 @@ import (
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/internal/version"
)
@ -26,7 +22,8 @@ import (
type OktaProvider struct {
*Provider
RevokeURL *url.URL
RevokeURL string `json:"revocation_endpoint"`
userAPI *url.URL
}
// NewOktaProvider creates a new instance of Okta as an identity provider.
@ -53,80 +50,62 @@ func NewOktaProvider(p *Provider) (*OktaProvider, error) {
}
// okta supports a revocation endpoint
var claims struct {
RevokeURL string `json:"revocation_endpoint"`
}
if err := p.provider.Claims(&claims); err != nil {
return nil, err
}
oktaProvider := OktaProvider{Provider: p}
oktaProvider.RevokeURL, err = url.Parse(claims.RevokeURL)
if err != nil {
if err := p.provider.Claims(&oktaProvider); err != nil {
return nil, err
}
if p.ServiceAccount != "" {
p.UserGroupFn = oktaProvider.UserGroups
userAPI, err := urlutil.ParseAndValidateURL(p.ProviderURL)
if err != nil {
return nil, err
}
userAPI.Path = "/api/v1/users/"
oktaProvider.userAPI = userAPI
} else {
log.Warn().Msg("identity/okta: api token provided, cannot retrieve groups")
}
return &oktaProvider, nil
}
// Revoke revokes the access token a given session state.
// https://developer.okta.com/docs/api/resources/oidc#revoke
func (p *OktaProvider) Revoke(token string) error {
func (p *OktaProvider) Revoke(ctx context.Context, token *oauth2.Token) error {
params := url.Values{}
params.Add("client_id", p.ClientID)
params.Add("client_secret", p.ClientSecret)
params.Add("token", token)
params.Add("token", token.AccessToken)
params.Add("token_type_hint", "refresh_token")
err := httputil.Client(http.MethodPost, p.RevokeURL.String(), version.UserAgent(), nil, params, nil)
err := httputil.Client(ctx, http.MethodPost, p.RevokeURL, version.UserAgent(), nil, params, nil)
if err != nil && err != httputil.ErrTokenRevoked {
return err
}
return nil
}
type accessToken struct {
Subject string `json:"sub"`
Groups []string `json:"groups"`
}
// Refresh renews a user's session using an oid refresh token without reprompting the user.
// Group membership is also refreshed. If configured properly, Okta is we can configure the access token
// to include group membership claims which allows us to avoid a follow up oauth2 call.
func (p *OktaProvider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) {
if s.RefreshToken == "" {
return nil, errors.New("identity/okta: missing refresh token")
// UserGroups fetches the groups of which the user is a member
// https://developer.okta.com/docs/reference/api/users/#get-user-s-groups
func (p *OktaProvider) UserGroups(ctx context.Context, s *sessions.State) ([]string, error) {
var response []struct {
ID string `json:"id"`
Profile struct {
Name string `json:"name"`
Description string `json:"description"`
} `json:"profile"`
}
t := oauth2.Token{RefreshToken: s.RefreshToken}
newToken, err := p.oauth.TokenSource(ctx, &t).Token()
headers := map[string]string{"Authorization": fmt.Sprintf("SSWS %s", p.ServiceAccount)}
err := httputil.Client(ctx, http.MethodGet, fmt.Sprintf("%s/%s/groups", p.userAPI.String(), s.Subject), version.UserAgent(), headers, nil, &response)
if err != nil {
log.Error().Err(err).Msg("identity/okta: refresh failed")
return nil, err
}
payload, err := parseJWT(newToken.AccessToken)
if err != nil {
return nil, fmt.Errorf("identity/okta: malformed access token jwt: %v", err)
var groups []string
for _, group := range response {
log.Debug().Interface("group", group).Msg("identity/okta: group")
groups = append(groups, group.ID)
}
var token accessToken
if err := json.Unmarshal(payload, &token); err != nil {
return nil, fmt.Errorf("identity/okta: failed to unmarshal access token claims: %v", err)
}
if len(token.Groups) != 0 {
s.Groups = token.Groups
}
s.AccessToken = newToken.AccessToken
s.RefreshDeadline = newToken.Expiry.Truncate(time.Second)
return s, nil
}
func parseJWT(p string) ([]byte, error) {
parts := strings.Split(p, ".")
if len(parts) < 2 {
return nil, fmt.Errorf("oidc: malformed jwt, expected 3 parts got %d", len(parts))
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, fmt.Errorf("oidc: malformed jwt payload: %v", err)
}
return payload, nil
return groups, nil
}