From eacf19cd64236e515c650cbddb7b25f77cc0c0c6 Mon Sep 17 00:00:00 2001 From: Joe Kralicky Date: Tue, 24 Jun 2025 18:05:24 -0400 Subject: [PATCH] oauth: add minimal device auth support for ssh (#5657) ## Summary This adds the necessary logic needed for device auth flow in ssh. The code is not used currently; will follow up with testenv updates that can let us test this with the mock idp. ## Related issues ## User Explanation ## Checklist - [ ] reference any related issues - [ ] updated unit tests - [ ] add appropriate label (`enhancement`, `bug`, `breaking`, `dependencies`, `ci`) - [ ] ready for review --- pkg/identity/mock_provider.go | 30 ++++++++--- pkg/identity/oauth/apple/apple.go | 8 +++ pkg/identity/oauth/github/github.go | 8 +++ pkg/identity/oidc/errors.go | 4 ++ pkg/identity/oidc/oidc.go | 82 +++++++++++++++++++++-------- pkg/identity/providers.go | 3 ++ 6 files changed, 104 insertions(+), 31 deletions(-) diff --git a/pkg/identity/mock_provider.go b/pkg/identity/mock_provider.go index ab913dbf1..68cc7b5d3 100644 --- a/pkg/identity/mock_provider.go +++ b/pkg/identity/mock_provider.go @@ -12,14 +12,18 @@ import ( // MockProvider provides a mocked implementation of the providers interface. type MockProvider struct { - AuthenticateResponse oauth2.Token - AuthenticateError error - RefreshResponse oauth2.Token - RefreshError error - RevokeError error - UpdateUserInfoError error - SignInError error - SignOutError error + AuthenticateResponse oauth2.Token + AuthenticateError error + RefreshResponse oauth2.Token + RefreshError error + RevokeError error + UpdateUserInfoError error + SignInError error + SignOutError error + DeviceAuthResponse oauth2.DeviceAuthResponse + DeviceAuthError error + DeviceAccessTokenResponse oauth2.Token + DeviceAccessTokenError error } // Authenticate is a mocked providers function. @@ -57,6 +61,16 @@ func (mp MockProvider) SignIn(_ http.ResponseWriter, _ *http.Request, _ string) return mp.SignInError } +// DeviceAuth implements Authenticator. +func (mp MockProvider) DeviceAuth(_ context.Context) (*oauth2.DeviceAuthResponse, error) { + return &mp.DeviceAuthResponse, mp.DeviceAuthError +} + +// DeviceAccessToken implements Authenticator. +func (mp MockProvider) DeviceAccessToken(_ context.Context, _ *oauth2.DeviceAuthResponse, _ identity.State) (*oauth2.Token, error) { + return &mp.DeviceAccessTokenResponse, mp.DeviceAccessTokenError +} + // VerifyAccessToken verifies an access token. func (mp MockProvider) VerifyAccessToken(_ context.Context, _ string) (claims map[string]any, err error) { return nil, fmt.Errorf("VerifyAccessToken not implemented") diff --git a/pkg/identity/oauth/apple/apple.go b/pkg/identity/oauth/apple/apple.go index 051c9d224..7e132481b 100644 --- a/pkg/identity/oauth/apple/apple.go +++ b/pkg/identity/oauth/apple/apple.go @@ -188,6 +188,14 @@ func (p *Provider) SignOut(_ http.ResponseWriter, _ *http.Request, _, _, _ strin return oidc.ErrSignoutNotImplemented } +func (p *Provider) DeviceAuth(_ context.Context) (*oauth2.DeviceAuthResponse, error) { + return nil, oidc.ErrDeviceAuthNotImplemented +} + +func (p *Provider) DeviceAccessToken(_ context.Context, _ *oauth2.DeviceAuthResponse, _ identity.State) (*oauth2.Token, error) { + return nil, oidc.ErrDeviceAuthNotImplemented +} + // VerifyAccessToken verifies an access token. func (p *Provider) VerifyAccessToken(_ context.Context, _ string) (claims map[string]any, err error) { // apple does not appear to have any way of verifying access tokens diff --git a/pkg/identity/oauth/github/github.go b/pkg/identity/oauth/github/github.go index 8214e6987..72c810d1f 100644 --- a/pkg/identity/oauth/github/github.go +++ b/pkg/identity/oauth/github/github.go @@ -258,6 +258,14 @@ func (p *Provider) SignOut(_ http.ResponseWriter, _ *http.Request, _, _, _ strin return oidc.ErrSignoutNotImplemented } +func (p *Provider) DeviceAuth(_ context.Context) (*oauth2.DeviceAuthResponse, error) { + return nil, oidc.ErrDeviceAuthNotImplemented +} + +func (p *Provider) DeviceAccessToken(_ context.Context, _ *oauth2.DeviceAuthResponse, _ identity.State) (*oauth2.Token, error) { + return nil, oidc.ErrDeviceAuthNotImplemented +} + // VerifyAccessToken verifies an access token. func (p *Provider) VerifyAccessToken(ctx context.Context, rawAccessToken string) (claims map[string]any, err error) { claims = jwtutil.Claims(map[string]any{}) diff --git a/pkg/identity/oidc/errors.go b/pkg/identity/oidc/errors.go index be4ee9743..2b5f7d930 100644 --- a/pkg/identity/oidc/errors.go +++ b/pkg/identity/oidc/errors.go @@ -13,6 +13,10 @@ var ErrRevokeNotImplemented = errors.New("identity/oidc: revoke not implemented" // https://openid.net/specs/openid-connect-frontchannel-1_0.html#RPInitiated var ErrSignoutNotImplemented = errors.New("identity/oidc: end session not implemented") +// ErrDeviceAuthNotImplemented is returned when device auth is not implemented +// by an identity provider. +var ErrDeviceAuthNotImplemented = errors.New("identity/oidc: device auth not implemented") + // ErrMissingProviderURL is returned when an identity provider requires a provider url // does not receive one. var ErrMissingProviderURL = errors.New("identity/oidc: missing provider url") diff --git a/pkg/identity/oidc/oidc.go b/pkg/identity/oidc/oidc.go index d78e4c6a8..e36f8f9df 100644 --- a/pkg/identity/oidc/oidc.go +++ b/pkg/identity/oidc/oidc.go @@ -133,11 +133,30 @@ func (p *Provider) SignIn(w http.ResponseWriter, r *http.Request, state string) func (p *Provider) Authenticate(ctx context.Context, code string, v identity.State) (*oauth2.Token, error) { ctx, span := trace.Continue(ctx, "oidc: authenticate") defer span.End() + token, err := p.authenticate(ctx, v, func(oa *oauth2.Config) (*oauth2.Token, error) { + return p.exchange(ctx, code, oa) + }) + if err != nil { + span.SetStatus(codes.Error, err.Error()) + return nil, err + } + return token, nil +} - oauth2Token, idToken, err := p.exchange(ctx, code) +func (p *Provider) authenticate(ctx context.Context, v identity.State, exchange func(oa *oauth2.Config) (*oauth2.Token, error)) (*oauth2.Token, error) { + oa, err := p.GetOauthConfig() if err != nil { return nil, err } + oauth2Token, err := exchange(oa) + if err != nil { + return nil, err + } + + idToken, err := p.getIDToken(ctx, oauth2Token) + if err != nil { + return nil, fmt.Errorf("identity/oidc: failed getting id_token: %w", err) + } if rawIDToken, ok := oauth2Token.Extra("id_token").(string); ok { v.SetRawIDToken(rawIDToken) @@ -146,46 +165,29 @@ func (p *Provider) Authenticate(ctx context.Context, code string, v identity.Sta // hydrate `v` using claims inside the returned `id_token` // https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint if err := idToken.Claims(v); err != nil { - err := fmt.Errorf("identity/oidc: couldn't unmarshal extra claims %w", err) - span.SetStatus(codes.Error, err.Error()) - return nil, err + return nil, fmt.Errorf("identity/oidc: couldn't unmarshal extra claims %w", err) } if err := p.UpdateUserInfo(ctx, oauth2Token, v); err != nil { - err := fmt.Errorf("identity/oidc: couldn't update user info %w", err) - span.SetStatus(codes.Error, err.Error()) - return nil, err + return nil, fmt.Errorf("identity/oidc: couldn't update user info %w", err) } return oauth2Token, nil } -func (p *Provider) exchange(ctx context.Context, code string) (*oauth2.Token, *go_oidc.IDToken, error) { +func (p *Provider) exchange(ctx context.Context, code string, oa *oauth2.Config) (*oauth2.Token, error) { ctx, span := trace.Continue(ctx, "oidc: token exchange") defer span.End() - oa, err := p.GetOauthConfig() - if err != nil { - span.SetStatus(codes.Error, err.Error()) - return nil, nil, err - } - // Exchange converts an authorization code into a token. oauth2Token, err := oa.Exchange(ctx, code) if err != nil { err := fmt.Errorf("identity/oidc: token exchange failed: %w", err) span.SetStatus(codes.Error, err.Error()) - return nil, nil, err + return nil, err } - idToken, err := p.getIDToken(ctx, oauth2Token) - if err != nil { - err := fmt.Errorf("identity/oidc: failed getting id_token: %w", err) - span.SetStatus(codes.Error, err.Error()) - return nil, nil, err - } - - return oauth2Token, idToken, nil + return oauth2Token, nil } // UpdateUserInfo calls the OIDC (spec required) UserInfo Endpoint as well as any @@ -366,6 +368,40 @@ func (p *Provider) SignOut(w http.ResponseWriter, r *http.Request, idTokenHint, return nil } +func (p *Provider) DeviceAuth(ctx context.Context) (*oauth2.DeviceAuthResponse, error) { + ctx, span := trace.Continue(ctx, "oidc: DeviceAuth") + defer span.End() + + oa, err := p.GetOauthConfig() + if err != nil { + return nil, err + } + + opts := defaultAuthCodeOptions + for k, v := range p.AuthCodeOptions { + opts = append(opts, oauth2.SetAuthURLParam(k, v)) + } + + resp, err := oa.DeviceAuth(ctx, opts...) + if err != nil { + return nil, err + } + return resp, nil +} + +func (p *Provider) DeviceAccessToken(ctx context.Context, da *oauth2.DeviceAuthResponse, v identity.State) (*oauth2.Token, error) { + ctx, span := trace.Continue(ctx, "oidc: DeviceAccessToken") + defer span.End() + token, err := p.authenticate(ctx, v, func(oa *oauth2.Config) (*oauth2.Token, error) { + return oa.DeviceAccessToken(ctx, da) + }) + if err != nil { + span.SetStatus(codes.Error, err.Error()) + return nil, err + } + return token, nil +} + // VerifyAccessToken verifies an access token. func (p *Provider) VerifyAccessToken(ctx context.Context, rawAccessToken string) (claims map[string]any, err error) { pp, err := p.GetProvider() diff --git a/pkg/identity/providers.go b/pkg/identity/providers.go index e8e0339b3..430709062 100644 --- a/pkg/identity/providers.go +++ b/pkg/identity/providers.go @@ -41,6 +41,9 @@ type Authenticator interface { SignIn(w http.ResponseWriter, r *http.Request, state string) error SignOut(w http.ResponseWriter, r *http.Request, idTokenHint, authenticateSignedOutURL, redirectToURL string) error + + DeviceAuth(ctx context.Context) (*oauth2.DeviceAuthResponse, error) + DeviceAccessToken(ctx context.Context, r *oauth2.DeviceAuthResponse, state State) (*oauth2.Token, error) } // AuthenticatorConstructor makes an Authenticator from the given options.