oauth: add minimal device auth support for ssh (#5657)

## Summary

This adds the necessary logic needed for device auth flow in ssh. The
code is not used currently; will follow up with testenv updates that can
let us test this with the mock idp.

## Related issues

<!-- For example...
- #159
-->

## User Explanation

<!-- How would you explain this change to the user? If this
change doesn't create any user-facing changes, you can leave
this blank. If filled out, add the `docs` label -->

## Checklist

- [ ] reference any related issues
- [ ] updated unit tests
- [ ] add appropriate label (`enhancement`, `bug`, `breaking`,
`dependencies`, `ci`)
- [ ] ready for review
This commit is contained in:
Joe Kralicky 2025-06-24 18:05:24 -04:00 committed by GitHub
parent db6449ecca
commit eacf19cd64
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 104 additions and 31 deletions

View file

@ -20,6 +20,10 @@ type MockProvider struct {
UpdateUserInfoError error
SignInError error
SignOutError error
DeviceAuthResponse oauth2.DeviceAuthResponse
DeviceAuthError error
DeviceAccessTokenResponse oauth2.Token
DeviceAccessTokenError error
}
// Authenticate is a mocked providers function.
@ -57,6 +61,16 @@ func (mp MockProvider) SignIn(_ http.ResponseWriter, _ *http.Request, _ string)
return mp.SignInError
}
// DeviceAuth implements Authenticator.
func (mp MockProvider) DeviceAuth(_ context.Context) (*oauth2.DeviceAuthResponse, error) {
return &mp.DeviceAuthResponse, mp.DeviceAuthError
}
// DeviceAccessToken implements Authenticator.
func (mp MockProvider) DeviceAccessToken(_ context.Context, _ *oauth2.DeviceAuthResponse, _ identity.State) (*oauth2.Token, error) {
return &mp.DeviceAccessTokenResponse, mp.DeviceAccessTokenError
}
// VerifyAccessToken verifies an access token.
func (mp MockProvider) VerifyAccessToken(_ context.Context, _ string) (claims map[string]any, err error) {
return nil, fmt.Errorf("VerifyAccessToken not implemented")

View file

@ -188,6 +188,14 @@ func (p *Provider) SignOut(_ http.ResponseWriter, _ *http.Request, _, _, _ strin
return oidc.ErrSignoutNotImplemented
}
func (p *Provider) DeviceAuth(_ context.Context) (*oauth2.DeviceAuthResponse, error) {
return nil, oidc.ErrDeviceAuthNotImplemented
}
func (p *Provider) DeviceAccessToken(_ context.Context, _ *oauth2.DeviceAuthResponse, _ identity.State) (*oauth2.Token, error) {
return nil, oidc.ErrDeviceAuthNotImplemented
}
// VerifyAccessToken verifies an access token.
func (p *Provider) VerifyAccessToken(_ context.Context, _ string) (claims map[string]any, err error) {
// apple does not appear to have any way of verifying access tokens

View file

@ -258,6 +258,14 @@ func (p *Provider) SignOut(_ http.ResponseWriter, _ *http.Request, _, _, _ strin
return oidc.ErrSignoutNotImplemented
}
func (p *Provider) DeviceAuth(_ context.Context) (*oauth2.DeviceAuthResponse, error) {
return nil, oidc.ErrDeviceAuthNotImplemented
}
func (p *Provider) DeviceAccessToken(_ context.Context, _ *oauth2.DeviceAuthResponse, _ identity.State) (*oauth2.Token, error) {
return nil, oidc.ErrDeviceAuthNotImplemented
}
// VerifyAccessToken verifies an access token.
func (p *Provider) VerifyAccessToken(ctx context.Context, rawAccessToken string) (claims map[string]any, err error) {
claims = jwtutil.Claims(map[string]any{})

View file

@ -13,6 +13,10 @@ var ErrRevokeNotImplemented = errors.New("identity/oidc: revoke not implemented"
// https://openid.net/specs/openid-connect-frontchannel-1_0.html#RPInitiated
var ErrSignoutNotImplemented = errors.New("identity/oidc: end session not implemented")
// ErrDeviceAuthNotImplemented is returned when device auth is not implemented
// by an identity provider.
var ErrDeviceAuthNotImplemented = errors.New("identity/oidc: device auth not implemented")
// ErrMissingProviderURL is returned when an identity provider requires a provider url
// does not receive one.
var ErrMissingProviderURL = errors.New("identity/oidc: missing provider url")

View file

@ -133,11 +133,30 @@ func (p *Provider) SignIn(w http.ResponseWriter, r *http.Request, state string)
func (p *Provider) Authenticate(ctx context.Context, code string, v identity.State) (*oauth2.Token, error) {
ctx, span := trace.Continue(ctx, "oidc: authenticate")
defer span.End()
token, err := p.authenticate(ctx, v, func(oa *oauth2.Config) (*oauth2.Token, error) {
return p.exchange(ctx, code, oa)
})
if err != nil {
span.SetStatus(codes.Error, err.Error())
return nil, err
}
return token, nil
}
oauth2Token, idToken, err := p.exchange(ctx, code)
func (p *Provider) authenticate(ctx context.Context, v identity.State, exchange func(oa *oauth2.Config) (*oauth2.Token, error)) (*oauth2.Token, error) {
oa, err := p.GetOauthConfig()
if err != nil {
return nil, err
}
oauth2Token, err := exchange(oa)
if err != nil {
return nil, err
}
idToken, err := p.getIDToken(ctx, oauth2Token)
if err != nil {
return nil, fmt.Errorf("identity/oidc: failed getting id_token: %w", err)
}
if rawIDToken, ok := oauth2Token.Extra("id_token").(string); ok {
v.SetRawIDToken(rawIDToken)
@ -146,46 +165,29 @@ func (p *Provider) Authenticate(ctx context.Context, code string, v identity.Sta
// hydrate `v` using claims inside the returned `id_token`
// https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint
if err := idToken.Claims(v); err != nil {
err := fmt.Errorf("identity/oidc: couldn't unmarshal extra claims %w", err)
span.SetStatus(codes.Error, err.Error())
return nil, err
return nil, fmt.Errorf("identity/oidc: couldn't unmarshal extra claims %w", err)
}
if err := p.UpdateUserInfo(ctx, oauth2Token, v); err != nil {
err := fmt.Errorf("identity/oidc: couldn't update user info %w", err)
span.SetStatus(codes.Error, err.Error())
return nil, err
return nil, fmt.Errorf("identity/oidc: couldn't update user info %w", err)
}
return oauth2Token, nil
}
func (p *Provider) exchange(ctx context.Context, code string) (*oauth2.Token, *go_oidc.IDToken, error) {
func (p *Provider) exchange(ctx context.Context, code string, oa *oauth2.Config) (*oauth2.Token, error) {
ctx, span := trace.Continue(ctx, "oidc: token exchange")
defer span.End()
oa, err := p.GetOauthConfig()
if err != nil {
span.SetStatus(codes.Error, err.Error())
return nil, nil, err
}
// Exchange converts an authorization code into a token.
oauth2Token, err := oa.Exchange(ctx, code)
if err != nil {
err := fmt.Errorf("identity/oidc: token exchange failed: %w", err)
span.SetStatus(codes.Error, err.Error())
return nil, nil, err
return nil, err
}
idToken, err := p.getIDToken(ctx, oauth2Token)
if err != nil {
err := fmt.Errorf("identity/oidc: failed getting id_token: %w", err)
span.SetStatus(codes.Error, err.Error())
return nil, nil, err
}
return oauth2Token, idToken, nil
return oauth2Token, nil
}
// UpdateUserInfo calls the OIDC (spec required) UserInfo Endpoint as well as any
@ -366,6 +368,40 @@ func (p *Provider) SignOut(w http.ResponseWriter, r *http.Request, idTokenHint,
return nil
}
func (p *Provider) DeviceAuth(ctx context.Context) (*oauth2.DeviceAuthResponse, error) {
ctx, span := trace.Continue(ctx, "oidc: DeviceAuth")
defer span.End()
oa, err := p.GetOauthConfig()
if err != nil {
return nil, err
}
opts := defaultAuthCodeOptions
for k, v := range p.AuthCodeOptions {
opts = append(opts, oauth2.SetAuthURLParam(k, v))
}
resp, err := oa.DeviceAuth(ctx, opts...)
if err != nil {
return nil, err
}
return resp, nil
}
func (p *Provider) DeviceAccessToken(ctx context.Context, da *oauth2.DeviceAuthResponse, v identity.State) (*oauth2.Token, error) {
ctx, span := trace.Continue(ctx, "oidc: DeviceAccessToken")
defer span.End()
token, err := p.authenticate(ctx, v, func(oa *oauth2.Config) (*oauth2.Token, error) {
return oa.DeviceAccessToken(ctx, da)
})
if err != nil {
span.SetStatus(codes.Error, err.Error())
return nil, err
}
return token, nil
}
// VerifyAccessToken verifies an access token.
func (p *Provider) VerifyAccessToken(ctx context.Context, rawAccessToken string) (claims map[string]any, err error) {
pp, err := p.GetProvider()

View file

@ -41,6 +41,9 @@ type Authenticator interface {
SignIn(w http.ResponseWriter, r *http.Request, state string) error
SignOut(w http.ResponseWriter, r *http.Request, idTokenHint, authenticateSignedOutURL, redirectToURL string) error
DeviceAuth(ctx context.Context) (*oauth2.DeviceAuthResponse, error)
DeviceAccessToken(ctx context.Context, r *oauth2.DeviceAuthResponse, state State) (*oauth2.Token, error)
}
// AuthenticatorConstructor makes an Authenticator from the given options.