oauth: add minimal device auth support for ssh (#5657)

## Summary

This adds the necessary logic needed for device auth flow in ssh. The
code is not used currently; will follow up with testenv updates that can
let us test this with the mock idp.

## Related issues

<!-- For example...
- #159
-->

## User Explanation

<!-- How would you explain this change to the user? If this
change doesn't create any user-facing changes, you can leave
this blank. If filled out, add the `docs` label -->

## Checklist

- [ ] reference any related issues
- [ ] updated unit tests
- [ ] add appropriate label (`enhancement`, `bug`, `breaking`,
`dependencies`, `ci`)
- [ ] ready for review
This commit is contained in:
Joe Kralicky 2025-06-24 18:05:24 -04:00 committed by GitHub
parent db6449ecca
commit eacf19cd64
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 104 additions and 31 deletions

View file

@ -12,14 +12,18 @@ import (
// MockProvider provides a mocked implementation of the providers interface.
type MockProvider struct {
AuthenticateResponse oauth2.Token
AuthenticateError error
RefreshResponse oauth2.Token
RefreshError error
RevokeError error
UpdateUserInfoError error
SignInError error
SignOutError error
AuthenticateResponse oauth2.Token
AuthenticateError error
RefreshResponse oauth2.Token
RefreshError error
RevokeError error
UpdateUserInfoError error
SignInError error
SignOutError error
DeviceAuthResponse oauth2.DeviceAuthResponse
DeviceAuthError error
DeviceAccessTokenResponse oauth2.Token
DeviceAccessTokenError error
}
// Authenticate is a mocked providers function.
@ -57,6 +61,16 @@ func (mp MockProvider) SignIn(_ http.ResponseWriter, _ *http.Request, _ string)
return mp.SignInError
}
// DeviceAuth implements Authenticator.
func (mp MockProvider) DeviceAuth(_ context.Context) (*oauth2.DeviceAuthResponse, error) {
return &mp.DeviceAuthResponse, mp.DeviceAuthError
}
// DeviceAccessToken implements Authenticator.
func (mp MockProvider) DeviceAccessToken(_ context.Context, _ *oauth2.DeviceAuthResponse, _ identity.State) (*oauth2.Token, error) {
return &mp.DeviceAccessTokenResponse, mp.DeviceAccessTokenError
}
// VerifyAccessToken verifies an access token.
func (mp MockProvider) VerifyAccessToken(_ context.Context, _ string) (claims map[string]any, err error) {
return nil, fmt.Errorf("VerifyAccessToken not implemented")

View file

@ -188,6 +188,14 @@ func (p *Provider) SignOut(_ http.ResponseWriter, _ *http.Request, _, _, _ strin
return oidc.ErrSignoutNotImplemented
}
func (p *Provider) DeviceAuth(_ context.Context) (*oauth2.DeviceAuthResponse, error) {
return nil, oidc.ErrDeviceAuthNotImplemented
}
func (p *Provider) DeviceAccessToken(_ context.Context, _ *oauth2.DeviceAuthResponse, _ identity.State) (*oauth2.Token, error) {
return nil, oidc.ErrDeviceAuthNotImplemented
}
// VerifyAccessToken verifies an access token.
func (p *Provider) VerifyAccessToken(_ context.Context, _ string) (claims map[string]any, err error) {
// apple does not appear to have any way of verifying access tokens

View file

@ -258,6 +258,14 @@ func (p *Provider) SignOut(_ http.ResponseWriter, _ *http.Request, _, _, _ strin
return oidc.ErrSignoutNotImplemented
}
func (p *Provider) DeviceAuth(_ context.Context) (*oauth2.DeviceAuthResponse, error) {
return nil, oidc.ErrDeviceAuthNotImplemented
}
func (p *Provider) DeviceAccessToken(_ context.Context, _ *oauth2.DeviceAuthResponse, _ identity.State) (*oauth2.Token, error) {
return nil, oidc.ErrDeviceAuthNotImplemented
}
// VerifyAccessToken verifies an access token.
func (p *Provider) VerifyAccessToken(ctx context.Context, rawAccessToken string) (claims map[string]any, err error) {
claims = jwtutil.Claims(map[string]any{})

View file

@ -13,6 +13,10 @@ var ErrRevokeNotImplemented = errors.New("identity/oidc: revoke not implemented"
// https://openid.net/specs/openid-connect-frontchannel-1_0.html#RPInitiated
var ErrSignoutNotImplemented = errors.New("identity/oidc: end session not implemented")
// ErrDeviceAuthNotImplemented is returned when device auth is not implemented
// by an identity provider.
var ErrDeviceAuthNotImplemented = errors.New("identity/oidc: device auth not implemented")
// ErrMissingProviderURL is returned when an identity provider requires a provider url
// does not receive one.
var ErrMissingProviderURL = errors.New("identity/oidc: missing provider url")

View file

@ -133,11 +133,30 @@ func (p *Provider) SignIn(w http.ResponseWriter, r *http.Request, state string)
func (p *Provider) Authenticate(ctx context.Context, code string, v identity.State) (*oauth2.Token, error) {
ctx, span := trace.Continue(ctx, "oidc: authenticate")
defer span.End()
token, err := p.authenticate(ctx, v, func(oa *oauth2.Config) (*oauth2.Token, error) {
return p.exchange(ctx, code, oa)
})
if err != nil {
span.SetStatus(codes.Error, err.Error())
return nil, err
}
return token, nil
}
oauth2Token, idToken, err := p.exchange(ctx, code)
func (p *Provider) authenticate(ctx context.Context, v identity.State, exchange func(oa *oauth2.Config) (*oauth2.Token, error)) (*oauth2.Token, error) {
oa, err := p.GetOauthConfig()
if err != nil {
return nil, err
}
oauth2Token, err := exchange(oa)
if err != nil {
return nil, err
}
idToken, err := p.getIDToken(ctx, oauth2Token)
if err != nil {
return nil, fmt.Errorf("identity/oidc: failed getting id_token: %w", err)
}
if rawIDToken, ok := oauth2Token.Extra("id_token").(string); ok {
v.SetRawIDToken(rawIDToken)
@ -146,46 +165,29 @@ func (p *Provider) Authenticate(ctx context.Context, code string, v identity.Sta
// hydrate `v` using claims inside the returned `id_token`
// https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint
if err := idToken.Claims(v); err != nil {
err := fmt.Errorf("identity/oidc: couldn't unmarshal extra claims %w", err)
span.SetStatus(codes.Error, err.Error())
return nil, err
return nil, fmt.Errorf("identity/oidc: couldn't unmarshal extra claims %w", err)
}
if err := p.UpdateUserInfo(ctx, oauth2Token, v); err != nil {
err := fmt.Errorf("identity/oidc: couldn't update user info %w", err)
span.SetStatus(codes.Error, err.Error())
return nil, err
return nil, fmt.Errorf("identity/oidc: couldn't update user info %w", err)
}
return oauth2Token, nil
}
func (p *Provider) exchange(ctx context.Context, code string) (*oauth2.Token, *go_oidc.IDToken, error) {
func (p *Provider) exchange(ctx context.Context, code string, oa *oauth2.Config) (*oauth2.Token, error) {
ctx, span := trace.Continue(ctx, "oidc: token exchange")
defer span.End()
oa, err := p.GetOauthConfig()
if err != nil {
span.SetStatus(codes.Error, err.Error())
return nil, nil, err
}
// Exchange converts an authorization code into a token.
oauth2Token, err := oa.Exchange(ctx, code)
if err != nil {
err := fmt.Errorf("identity/oidc: token exchange failed: %w", err)
span.SetStatus(codes.Error, err.Error())
return nil, nil, err
return nil, err
}
idToken, err := p.getIDToken(ctx, oauth2Token)
if err != nil {
err := fmt.Errorf("identity/oidc: failed getting id_token: %w", err)
span.SetStatus(codes.Error, err.Error())
return nil, nil, err
}
return oauth2Token, idToken, nil
return oauth2Token, nil
}
// UpdateUserInfo calls the OIDC (spec required) UserInfo Endpoint as well as any
@ -366,6 +368,40 @@ func (p *Provider) SignOut(w http.ResponseWriter, r *http.Request, idTokenHint,
return nil
}
func (p *Provider) DeviceAuth(ctx context.Context) (*oauth2.DeviceAuthResponse, error) {
ctx, span := trace.Continue(ctx, "oidc: DeviceAuth")
defer span.End()
oa, err := p.GetOauthConfig()
if err != nil {
return nil, err
}
opts := defaultAuthCodeOptions
for k, v := range p.AuthCodeOptions {
opts = append(opts, oauth2.SetAuthURLParam(k, v))
}
resp, err := oa.DeviceAuth(ctx, opts...)
if err != nil {
return nil, err
}
return resp, nil
}
func (p *Provider) DeviceAccessToken(ctx context.Context, da *oauth2.DeviceAuthResponse, v identity.State) (*oauth2.Token, error) {
ctx, span := trace.Continue(ctx, "oidc: DeviceAccessToken")
defer span.End()
token, err := p.authenticate(ctx, v, func(oa *oauth2.Config) (*oauth2.Token, error) {
return oa.DeviceAccessToken(ctx, da)
})
if err != nil {
span.SetStatus(codes.Error, err.Error())
return nil, err
}
return token, nil
}
// VerifyAccessToken verifies an access token.
func (p *Provider) VerifyAccessToken(ctx context.Context, rawAccessToken string) (claims map[string]any, err error) {
pp, err := p.GetProvider()

View file

@ -41,6 +41,9 @@ type Authenticator interface {
SignIn(w http.ResponseWriter, r *http.Request, state string) error
SignOut(w http.ResponseWriter, r *http.Request, idTokenHint, authenticateSignedOutURL, redirectToURL string) error
DeviceAuth(ctx context.Context) (*oauth2.DeviceAuthResponse, error)
DeviceAccessToken(ctx context.Context, r *oauth2.DeviceAuthResponse, state State) (*oauth2.Token, error)
}
// AuthenticatorConstructor makes an Authenticator from the given options.