mirror of
https://github.com/pomerium/pomerium.git
synced 2025-07-06 19:38:09 +02:00
oauth: add minimal device auth support for ssh (#5657)
## Summary This adds the necessary logic needed for device auth flow in ssh. The code is not used currently; will follow up with testenv updates that can let us test this with the mock idp. ## Related issues <!-- For example... - #159 --> ## User Explanation <!-- How would you explain this change to the user? If this change doesn't create any user-facing changes, you can leave this blank. If filled out, add the `docs` label --> ## Checklist - [ ] reference any related issues - [ ] updated unit tests - [ ] add appropriate label (`enhancement`, `bug`, `breaking`, `dependencies`, `ci`) - [ ] ready for review
This commit is contained in:
parent
db6449ecca
commit
eacf19cd64
6 changed files with 104 additions and 31 deletions
|
@ -20,6 +20,10 @@ type MockProvider struct {
|
|||
UpdateUserInfoError error
|
||||
SignInError error
|
||||
SignOutError error
|
||||
DeviceAuthResponse oauth2.DeviceAuthResponse
|
||||
DeviceAuthError error
|
||||
DeviceAccessTokenResponse oauth2.Token
|
||||
DeviceAccessTokenError error
|
||||
}
|
||||
|
||||
// Authenticate is a mocked providers function.
|
||||
|
@ -57,6 +61,16 @@ func (mp MockProvider) SignIn(_ http.ResponseWriter, _ *http.Request, _ string)
|
|||
return mp.SignInError
|
||||
}
|
||||
|
||||
// DeviceAuth implements Authenticator.
|
||||
func (mp MockProvider) DeviceAuth(_ context.Context) (*oauth2.DeviceAuthResponse, error) {
|
||||
return &mp.DeviceAuthResponse, mp.DeviceAuthError
|
||||
}
|
||||
|
||||
// DeviceAccessToken implements Authenticator.
|
||||
func (mp MockProvider) DeviceAccessToken(_ context.Context, _ *oauth2.DeviceAuthResponse, _ identity.State) (*oauth2.Token, error) {
|
||||
return &mp.DeviceAccessTokenResponse, mp.DeviceAccessTokenError
|
||||
}
|
||||
|
||||
// VerifyAccessToken verifies an access token.
|
||||
func (mp MockProvider) VerifyAccessToken(_ context.Context, _ string) (claims map[string]any, err error) {
|
||||
return nil, fmt.Errorf("VerifyAccessToken not implemented")
|
||||
|
|
|
@ -188,6 +188,14 @@ func (p *Provider) SignOut(_ http.ResponseWriter, _ *http.Request, _, _, _ strin
|
|||
return oidc.ErrSignoutNotImplemented
|
||||
}
|
||||
|
||||
func (p *Provider) DeviceAuth(_ context.Context) (*oauth2.DeviceAuthResponse, error) {
|
||||
return nil, oidc.ErrDeviceAuthNotImplemented
|
||||
}
|
||||
|
||||
func (p *Provider) DeviceAccessToken(_ context.Context, _ *oauth2.DeviceAuthResponse, _ identity.State) (*oauth2.Token, error) {
|
||||
return nil, oidc.ErrDeviceAuthNotImplemented
|
||||
}
|
||||
|
||||
// VerifyAccessToken verifies an access token.
|
||||
func (p *Provider) VerifyAccessToken(_ context.Context, _ string) (claims map[string]any, err error) {
|
||||
// apple does not appear to have any way of verifying access tokens
|
||||
|
|
|
@ -258,6 +258,14 @@ func (p *Provider) SignOut(_ http.ResponseWriter, _ *http.Request, _, _, _ strin
|
|||
return oidc.ErrSignoutNotImplemented
|
||||
}
|
||||
|
||||
func (p *Provider) DeviceAuth(_ context.Context) (*oauth2.DeviceAuthResponse, error) {
|
||||
return nil, oidc.ErrDeviceAuthNotImplemented
|
||||
}
|
||||
|
||||
func (p *Provider) DeviceAccessToken(_ context.Context, _ *oauth2.DeviceAuthResponse, _ identity.State) (*oauth2.Token, error) {
|
||||
return nil, oidc.ErrDeviceAuthNotImplemented
|
||||
}
|
||||
|
||||
// VerifyAccessToken verifies an access token.
|
||||
func (p *Provider) VerifyAccessToken(ctx context.Context, rawAccessToken string) (claims map[string]any, err error) {
|
||||
claims = jwtutil.Claims(map[string]any{})
|
||||
|
|
|
@ -13,6 +13,10 @@ var ErrRevokeNotImplemented = errors.New("identity/oidc: revoke not implemented"
|
|||
// https://openid.net/specs/openid-connect-frontchannel-1_0.html#RPInitiated
|
||||
var ErrSignoutNotImplemented = errors.New("identity/oidc: end session not implemented")
|
||||
|
||||
// ErrDeviceAuthNotImplemented is returned when device auth is not implemented
|
||||
// by an identity provider.
|
||||
var ErrDeviceAuthNotImplemented = errors.New("identity/oidc: device auth not implemented")
|
||||
|
||||
// ErrMissingProviderURL is returned when an identity provider requires a provider url
|
||||
// does not receive one.
|
||||
var ErrMissingProviderURL = errors.New("identity/oidc: missing provider url")
|
||||
|
|
|
@ -133,11 +133,30 @@ func (p *Provider) SignIn(w http.ResponseWriter, r *http.Request, state string)
|
|||
func (p *Provider) Authenticate(ctx context.Context, code string, v identity.State) (*oauth2.Token, error) {
|
||||
ctx, span := trace.Continue(ctx, "oidc: authenticate")
|
||||
defer span.End()
|
||||
token, err := p.authenticate(ctx, v, func(oa *oauth2.Config) (*oauth2.Token, error) {
|
||||
return p.exchange(ctx, code, oa)
|
||||
})
|
||||
if err != nil {
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
return nil, err
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
oauth2Token, idToken, err := p.exchange(ctx, code)
|
||||
func (p *Provider) authenticate(ctx context.Context, v identity.State, exchange func(oa *oauth2.Config) (*oauth2.Token, error)) (*oauth2.Token, error) {
|
||||
oa, err := p.GetOauthConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
oauth2Token, err := exchange(oa)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
idToken, err := p.getIDToken(ctx, oauth2Token)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("identity/oidc: failed getting id_token: %w", err)
|
||||
}
|
||||
|
||||
if rawIDToken, ok := oauth2Token.Extra("id_token").(string); ok {
|
||||
v.SetRawIDToken(rawIDToken)
|
||||
|
@ -146,46 +165,29 @@ func (p *Provider) Authenticate(ctx context.Context, code string, v identity.Sta
|
|||
// hydrate `v` using claims inside the returned `id_token`
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint
|
||||
if err := idToken.Claims(v); err != nil {
|
||||
err := fmt.Errorf("identity/oidc: couldn't unmarshal extra claims %w", err)
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("identity/oidc: couldn't unmarshal extra claims %w", err)
|
||||
}
|
||||
|
||||
if err := p.UpdateUserInfo(ctx, oauth2Token, v); err != nil {
|
||||
err := fmt.Errorf("identity/oidc: couldn't update user info %w", err)
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("identity/oidc: couldn't update user info %w", err)
|
||||
}
|
||||
|
||||
return oauth2Token, nil
|
||||
}
|
||||
|
||||
func (p *Provider) exchange(ctx context.Context, code string) (*oauth2.Token, *go_oidc.IDToken, error) {
|
||||
func (p *Provider) exchange(ctx context.Context, code string, oa *oauth2.Config) (*oauth2.Token, error) {
|
||||
ctx, span := trace.Continue(ctx, "oidc: token exchange")
|
||||
defer span.End()
|
||||
|
||||
oa, err := p.GetOauthConfig()
|
||||
if err != nil {
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Exchange converts an authorization code into a token.
|
||||
oauth2Token, err := oa.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
err := fmt.Errorf("identity/oidc: token exchange failed: %w", err)
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
idToken, err := p.getIDToken(ctx, oauth2Token)
|
||||
if err != nil {
|
||||
err := fmt.Errorf("identity/oidc: failed getting id_token: %w", err)
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return oauth2Token, idToken, nil
|
||||
return oauth2Token, nil
|
||||
}
|
||||
|
||||
// UpdateUserInfo calls the OIDC (spec required) UserInfo Endpoint as well as any
|
||||
|
@ -366,6 +368,40 @@ func (p *Provider) SignOut(w http.ResponseWriter, r *http.Request, idTokenHint,
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *Provider) DeviceAuth(ctx context.Context) (*oauth2.DeviceAuthResponse, error) {
|
||||
ctx, span := trace.Continue(ctx, "oidc: DeviceAuth")
|
||||
defer span.End()
|
||||
|
||||
oa, err := p.GetOauthConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
opts := defaultAuthCodeOptions
|
||||
for k, v := range p.AuthCodeOptions {
|
||||
opts = append(opts, oauth2.SetAuthURLParam(k, v))
|
||||
}
|
||||
|
||||
resp, err := oa.DeviceAuth(ctx, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (p *Provider) DeviceAccessToken(ctx context.Context, da *oauth2.DeviceAuthResponse, v identity.State) (*oauth2.Token, error) {
|
||||
ctx, span := trace.Continue(ctx, "oidc: DeviceAccessToken")
|
||||
defer span.End()
|
||||
token, err := p.authenticate(ctx, v, func(oa *oauth2.Config) (*oauth2.Token, error) {
|
||||
return oa.DeviceAccessToken(ctx, da)
|
||||
})
|
||||
if err != nil {
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
return nil, err
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// VerifyAccessToken verifies an access token.
|
||||
func (p *Provider) VerifyAccessToken(ctx context.Context, rawAccessToken string) (claims map[string]any, err error) {
|
||||
pp, err := p.GetProvider()
|
||||
|
|
|
@ -41,6 +41,9 @@ type Authenticator interface {
|
|||
|
||||
SignIn(w http.ResponseWriter, r *http.Request, state string) error
|
||||
SignOut(w http.ResponseWriter, r *http.Request, idTokenHint, authenticateSignedOutURL, redirectToURL string) error
|
||||
|
||||
DeviceAuth(ctx context.Context) (*oauth2.DeviceAuthResponse, error)
|
||||
DeviceAccessToken(ctx context.Context, r *oauth2.DeviceAuthResponse, state State) (*oauth2.Token, error)
|
||||
}
|
||||
|
||||
// AuthenticatorConstructor makes an Authenticator from the given options.
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue