mirror of
https://github.com/pomerium/pomerium.git
synced 2025-07-06 19:38:09 +02:00
oauth: add minimal device auth support for ssh (#5657)
## Summary This adds the necessary logic needed for device auth flow in ssh. The code is not used currently; will follow up with testenv updates that can let us test this with the mock idp. ## Related issues <!-- For example... - #159 --> ## User Explanation <!-- How would you explain this change to the user? If this change doesn't create any user-facing changes, you can leave this blank. If filled out, add the `docs` label --> ## Checklist - [ ] reference any related issues - [ ] updated unit tests - [ ] add appropriate label (`enhancement`, `bug`, `breaking`, `dependencies`, `ci`) - [ ] ready for review
This commit is contained in:
parent
db6449ecca
commit
eacf19cd64
6 changed files with 104 additions and 31 deletions
|
@ -20,6 +20,10 @@ type MockProvider struct {
|
||||||
UpdateUserInfoError error
|
UpdateUserInfoError error
|
||||||
SignInError error
|
SignInError error
|
||||||
SignOutError error
|
SignOutError error
|
||||||
|
DeviceAuthResponse oauth2.DeviceAuthResponse
|
||||||
|
DeviceAuthError error
|
||||||
|
DeviceAccessTokenResponse oauth2.Token
|
||||||
|
DeviceAccessTokenError error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Authenticate is a mocked providers function.
|
// Authenticate is a mocked providers function.
|
||||||
|
@ -57,6 +61,16 @@ func (mp MockProvider) SignIn(_ http.ResponseWriter, _ *http.Request, _ string)
|
||||||
return mp.SignInError
|
return mp.SignInError
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeviceAuth implements Authenticator.
|
||||||
|
func (mp MockProvider) DeviceAuth(_ context.Context) (*oauth2.DeviceAuthResponse, error) {
|
||||||
|
return &mp.DeviceAuthResponse, mp.DeviceAuthError
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeviceAccessToken implements Authenticator.
|
||||||
|
func (mp MockProvider) DeviceAccessToken(_ context.Context, _ *oauth2.DeviceAuthResponse, _ identity.State) (*oauth2.Token, error) {
|
||||||
|
return &mp.DeviceAccessTokenResponse, mp.DeviceAccessTokenError
|
||||||
|
}
|
||||||
|
|
||||||
// VerifyAccessToken verifies an access token.
|
// VerifyAccessToken verifies an access token.
|
||||||
func (mp MockProvider) VerifyAccessToken(_ context.Context, _ string) (claims map[string]any, err error) {
|
func (mp MockProvider) VerifyAccessToken(_ context.Context, _ string) (claims map[string]any, err error) {
|
||||||
return nil, fmt.Errorf("VerifyAccessToken not implemented")
|
return nil, fmt.Errorf("VerifyAccessToken not implemented")
|
||||||
|
|
|
@ -188,6 +188,14 @@ func (p *Provider) SignOut(_ http.ResponseWriter, _ *http.Request, _, _, _ strin
|
||||||
return oidc.ErrSignoutNotImplemented
|
return oidc.ErrSignoutNotImplemented
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *Provider) DeviceAuth(_ context.Context) (*oauth2.DeviceAuthResponse, error) {
|
||||||
|
return nil, oidc.ErrDeviceAuthNotImplemented
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Provider) DeviceAccessToken(_ context.Context, _ *oauth2.DeviceAuthResponse, _ identity.State) (*oauth2.Token, error) {
|
||||||
|
return nil, oidc.ErrDeviceAuthNotImplemented
|
||||||
|
}
|
||||||
|
|
||||||
// VerifyAccessToken verifies an access token.
|
// VerifyAccessToken verifies an access token.
|
||||||
func (p *Provider) VerifyAccessToken(_ context.Context, _ string) (claims map[string]any, err error) {
|
func (p *Provider) VerifyAccessToken(_ context.Context, _ string) (claims map[string]any, err error) {
|
||||||
// apple does not appear to have any way of verifying access tokens
|
// apple does not appear to have any way of verifying access tokens
|
||||||
|
|
|
@ -258,6 +258,14 @@ func (p *Provider) SignOut(_ http.ResponseWriter, _ *http.Request, _, _, _ strin
|
||||||
return oidc.ErrSignoutNotImplemented
|
return oidc.ErrSignoutNotImplemented
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *Provider) DeviceAuth(_ context.Context) (*oauth2.DeviceAuthResponse, error) {
|
||||||
|
return nil, oidc.ErrDeviceAuthNotImplemented
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Provider) DeviceAccessToken(_ context.Context, _ *oauth2.DeviceAuthResponse, _ identity.State) (*oauth2.Token, error) {
|
||||||
|
return nil, oidc.ErrDeviceAuthNotImplemented
|
||||||
|
}
|
||||||
|
|
||||||
// VerifyAccessToken verifies an access token.
|
// VerifyAccessToken verifies an access token.
|
||||||
func (p *Provider) VerifyAccessToken(ctx context.Context, rawAccessToken string) (claims map[string]any, err error) {
|
func (p *Provider) VerifyAccessToken(ctx context.Context, rawAccessToken string) (claims map[string]any, err error) {
|
||||||
claims = jwtutil.Claims(map[string]any{})
|
claims = jwtutil.Claims(map[string]any{})
|
||||||
|
|
|
@ -13,6 +13,10 @@ var ErrRevokeNotImplemented = errors.New("identity/oidc: revoke not implemented"
|
||||||
// https://openid.net/specs/openid-connect-frontchannel-1_0.html#RPInitiated
|
// https://openid.net/specs/openid-connect-frontchannel-1_0.html#RPInitiated
|
||||||
var ErrSignoutNotImplemented = errors.New("identity/oidc: end session not implemented")
|
var ErrSignoutNotImplemented = errors.New("identity/oidc: end session not implemented")
|
||||||
|
|
||||||
|
// ErrDeviceAuthNotImplemented is returned when device auth is not implemented
|
||||||
|
// by an identity provider.
|
||||||
|
var ErrDeviceAuthNotImplemented = errors.New("identity/oidc: device auth not implemented")
|
||||||
|
|
||||||
// ErrMissingProviderURL is returned when an identity provider requires a provider url
|
// ErrMissingProviderURL is returned when an identity provider requires a provider url
|
||||||
// does not receive one.
|
// does not receive one.
|
||||||
var ErrMissingProviderURL = errors.New("identity/oidc: missing provider url")
|
var ErrMissingProviderURL = errors.New("identity/oidc: missing provider url")
|
||||||
|
|
|
@ -133,11 +133,30 @@ func (p *Provider) SignIn(w http.ResponseWriter, r *http.Request, state string)
|
||||||
func (p *Provider) Authenticate(ctx context.Context, code string, v identity.State) (*oauth2.Token, error) {
|
func (p *Provider) Authenticate(ctx context.Context, code string, v identity.State) (*oauth2.Token, error) {
|
||||||
ctx, span := trace.Continue(ctx, "oidc: authenticate")
|
ctx, span := trace.Continue(ctx, "oidc: authenticate")
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
token, err := p.authenticate(ctx, v, func(oa *oauth2.Config) (*oauth2.Token, error) {
|
||||||
|
return p.exchange(ctx, code, oa)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
span.SetStatus(codes.Error, err.Error())
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
oauth2Token, idToken, err := p.exchange(ctx, code)
|
func (p *Provider) authenticate(ctx context.Context, v identity.State, exchange func(oa *oauth2.Config) (*oauth2.Token, error)) (*oauth2.Token, error) {
|
||||||
|
oa, err := p.GetOauthConfig()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
oauth2Token, err := exchange(oa)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
idToken, err := p.getIDToken(ctx, oauth2Token)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("identity/oidc: failed getting id_token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
if rawIDToken, ok := oauth2Token.Extra("id_token").(string); ok {
|
if rawIDToken, ok := oauth2Token.Extra("id_token").(string); ok {
|
||||||
v.SetRawIDToken(rawIDToken)
|
v.SetRawIDToken(rawIDToken)
|
||||||
|
@ -146,46 +165,29 @@ func (p *Provider) Authenticate(ctx context.Context, code string, v identity.Sta
|
||||||
// hydrate `v` using claims inside the returned `id_token`
|
// hydrate `v` using claims inside the returned `id_token`
|
||||||
// https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint
|
// https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint
|
||||||
if err := idToken.Claims(v); err != nil {
|
if err := idToken.Claims(v); err != nil {
|
||||||
err := fmt.Errorf("identity/oidc: couldn't unmarshal extra claims %w", err)
|
return nil, fmt.Errorf("identity/oidc: couldn't unmarshal extra claims %w", err)
|
||||||
span.SetStatus(codes.Error, err.Error())
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := p.UpdateUserInfo(ctx, oauth2Token, v); err != nil {
|
if err := p.UpdateUserInfo(ctx, oauth2Token, v); err != nil {
|
||||||
err := fmt.Errorf("identity/oidc: couldn't update user info %w", err)
|
return nil, fmt.Errorf("identity/oidc: couldn't update user info %w", err)
|
||||||
span.SetStatus(codes.Error, err.Error())
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return oauth2Token, nil
|
return oauth2Token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Provider) exchange(ctx context.Context, code string) (*oauth2.Token, *go_oidc.IDToken, error) {
|
func (p *Provider) exchange(ctx context.Context, code string, oa *oauth2.Config) (*oauth2.Token, error) {
|
||||||
ctx, span := trace.Continue(ctx, "oidc: token exchange")
|
ctx, span := trace.Continue(ctx, "oidc: token exchange")
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
oa, err := p.GetOauthConfig()
|
|
||||||
if err != nil {
|
|
||||||
span.SetStatus(codes.Error, err.Error())
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Exchange converts an authorization code into a token.
|
// Exchange converts an authorization code into a token.
|
||||||
oauth2Token, err := oa.Exchange(ctx, code)
|
oauth2Token, err := oa.Exchange(ctx, code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err := fmt.Errorf("identity/oidc: token exchange failed: %w", err)
|
err := fmt.Errorf("identity/oidc: token exchange failed: %w", err)
|
||||||
span.SetStatus(codes.Error, err.Error())
|
span.SetStatus(codes.Error, err.Error())
|
||||||
return nil, nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
idToken, err := p.getIDToken(ctx, oauth2Token)
|
return oauth2Token, nil
|
||||||
if err != nil {
|
|
||||||
err := fmt.Errorf("identity/oidc: failed getting id_token: %w", err)
|
|
||||||
span.SetStatus(codes.Error, err.Error())
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return oauth2Token, idToken, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateUserInfo calls the OIDC (spec required) UserInfo Endpoint as well as any
|
// UpdateUserInfo calls the OIDC (spec required) UserInfo Endpoint as well as any
|
||||||
|
@ -366,6 +368,40 @@ func (p *Provider) SignOut(w http.ResponseWriter, r *http.Request, idTokenHint,
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *Provider) DeviceAuth(ctx context.Context) (*oauth2.DeviceAuthResponse, error) {
|
||||||
|
ctx, span := trace.Continue(ctx, "oidc: DeviceAuth")
|
||||||
|
defer span.End()
|
||||||
|
|
||||||
|
oa, err := p.GetOauthConfig()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
opts := defaultAuthCodeOptions
|
||||||
|
for k, v := range p.AuthCodeOptions {
|
||||||
|
opts = append(opts, oauth2.SetAuthURLParam(k, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := oa.DeviceAuth(ctx, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Provider) DeviceAccessToken(ctx context.Context, da *oauth2.DeviceAuthResponse, v identity.State) (*oauth2.Token, error) {
|
||||||
|
ctx, span := trace.Continue(ctx, "oidc: DeviceAccessToken")
|
||||||
|
defer span.End()
|
||||||
|
token, err := p.authenticate(ctx, v, func(oa *oauth2.Config) (*oauth2.Token, error) {
|
||||||
|
return oa.DeviceAccessToken(ctx, da)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
span.SetStatus(codes.Error, err.Error())
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
// VerifyAccessToken verifies an access token.
|
// VerifyAccessToken verifies an access token.
|
||||||
func (p *Provider) VerifyAccessToken(ctx context.Context, rawAccessToken string) (claims map[string]any, err error) {
|
func (p *Provider) VerifyAccessToken(ctx context.Context, rawAccessToken string) (claims map[string]any, err error) {
|
||||||
pp, err := p.GetProvider()
|
pp, err := p.GetProvider()
|
||||||
|
|
|
@ -41,6 +41,9 @@ type Authenticator interface {
|
||||||
|
|
||||||
SignIn(w http.ResponseWriter, r *http.Request, state string) error
|
SignIn(w http.ResponseWriter, r *http.Request, state string) error
|
||||||
SignOut(w http.ResponseWriter, r *http.Request, idTokenHint, authenticateSignedOutURL, redirectToURL string) error
|
SignOut(w http.ResponseWriter, r *http.Request, idTokenHint, authenticateSignedOutURL, redirectToURL string) error
|
||||||
|
|
||||||
|
DeviceAuth(ctx context.Context) (*oauth2.DeviceAuthResponse, error)
|
||||||
|
DeviceAccessToken(ctx context.Context, r *oauth2.DeviceAuthResponse, state State) (*oauth2.Token, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthenticatorConstructor makes an Authenticator from the given options.
|
// AuthenticatorConstructor makes an Authenticator from the given options.
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue