oauth: add minimal device auth support for ssh (#5657)

## Summary

This adds the necessary logic needed for device auth flow in ssh. The
code is not used currently; will follow up with testenv updates that can
let us test this with the mock idp.

## Related issues

<!-- For example...
- #159
-->

## User Explanation

<!-- How would you explain this change to the user? If this
change doesn't create any user-facing changes, you can leave
this blank. If filled out, add the `docs` label -->

## Checklist

- [ ] reference any related issues
- [ ] updated unit tests
- [ ] add appropriate label (`enhancement`, `bug`, `breaking`,
`dependencies`, `ci`)
- [ ] ready for review
This commit is contained in:
Joe Kralicky 2025-06-24 18:05:24 -04:00 committed by GitHub
parent db6449ecca
commit eacf19cd64
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 104 additions and 31 deletions

View file

@ -133,11 +133,30 @@ func (p *Provider) SignIn(w http.ResponseWriter, r *http.Request, state string)
func (p *Provider) Authenticate(ctx context.Context, code string, v identity.State) (*oauth2.Token, error) {
ctx, span := trace.Continue(ctx, "oidc: authenticate")
defer span.End()
token, err := p.authenticate(ctx, v, func(oa *oauth2.Config) (*oauth2.Token, error) {
return p.exchange(ctx, code, oa)
})
if err != nil {
span.SetStatus(codes.Error, err.Error())
return nil, err
}
return token, nil
}
oauth2Token, idToken, err := p.exchange(ctx, code)
func (p *Provider) authenticate(ctx context.Context, v identity.State, exchange func(oa *oauth2.Config) (*oauth2.Token, error)) (*oauth2.Token, error) {
oa, err := p.GetOauthConfig()
if err != nil {
return nil, err
}
oauth2Token, err := exchange(oa)
if err != nil {
return nil, err
}
idToken, err := p.getIDToken(ctx, oauth2Token)
if err != nil {
return nil, fmt.Errorf("identity/oidc: failed getting id_token: %w", err)
}
if rawIDToken, ok := oauth2Token.Extra("id_token").(string); ok {
v.SetRawIDToken(rawIDToken)
@ -146,46 +165,29 @@ func (p *Provider) Authenticate(ctx context.Context, code string, v identity.Sta
// hydrate `v` using claims inside the returned `id_token`
// https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint
if err := idToken.Claims(v); err != nil {
err := fmt.Errorf("identity/oidc: couldn't unmarshal extra claims %w", err)
span.SetStatus(codes.Error, err.Error())
return nil, err
return nil, fmt.Errorf("identity/oidc: couldn't unmarshal extra claims %w", err)
}
if err := p.UpdateUserInfo(ctx, oauth2Token, v); err != nil {
err := fmt.Errorf("identity/oidc: couldn't update user info %w", err)
span.SetStatus(codes.Error, err.Error())
return nil, err
return nil, fmt.Errorf("identity/oidc: couldn't update user info %w", err)
}
return oauth2Token, nil
}
func (p *Provider) exchange(ctx context.Context, code string) (*oauth2.Token, *go_oidc.IDToken, error) {
func (p *Provider) exchange(ctx context.Context, code string, oa *oauth2.Config) (*oauth2.Token, error) {
ctx, span := trace.Continue(ctx, "oidc: token exchange")
defer span.End()
oa, err := p.GetOauthConfig()
if err != nil {
span.SetStatus(codes.Error, err.Error())
return nil, nil, err
}
// Exchange converts an authorization code into a token.
oauth2Token, err := oa.Exchange(ctx, code)
if err != nil {
err := fmt.Errorf("identity/oidc: token exchange failed: %w", err)
span.SetStatus(codes.Error, err.Error())
return nil, nil, err
return nil, err
}
idToken, err := p.getIDToken(ctx, oauth2Token)
if err != nil {
err := fmt.Errorf("identity/oidc: failed getting id_token: %w", err)
span.SetStatus(codes.Error, err.Error())
return nil, nil, err
}
return oauth2Token, idToken, nil
return oauth2Token, nil
}
// UpdateUserInfo calls the OIDC (spec required) UserInfo Endpoint as well as any
@ -366,6 +368,40 @@ func (p *Provider) SignOut(w http.ResponseWriter, r *http.Request, idTokenHint,
return nil
}
func (p *Provider) DeviceAuth(ctx context.Context) (*oauth2.DeviceAuthResponse, error) {
ctx, span := trace.Continue(ctx, "oidc: DeviceAuth")
defer span.End()
oa, err := p.GetOauthConfig()
if err != nil {
return nil, err
}
opts := defaultAuthCodeOptions
for k, v := range p.AuthCodeOptions {
opts = append(opts, oauth2.SetAuthURLParam(k, v))
}
resp, err := oa.DeviceAuth(ctx, opts...)
if err != nil {
return nil, err
}
return resp, nil
}
func (p *Provider) DeviceAccessToken(ctx context.Context, da *oauth2.DeviceAuthResponse, v identity.State) (*oauth2.Token, error) {
ctx, span := trace.Continue(ctx, "oidc: DeviceAccessToken")
defer span.End()
token, err := p.authenticate(ctx, v, func(oa *oauth2.Config) (*oauth2.Token, error) {
return oa.DeviceAccessToken(ctx, da)
})
if err != nil {
span.SetStatus(codes.Error, err.Error())
return nil, err
}
return token, nil
}
// VerifyAccessToken verifies an access token.
func (p *Provider) VerifyAccessToken(ctx context.Context, rawAccessToken string) (claims map[string]any, err error) {
pp, err := p.GetProvider()