mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-24 14:37:12 +02:00
mcp: redirect to upstream oauth2 for authentication (#5594)
This commit is contained in:
parent
5b024a8ada
commit
561b6040b5
4 changed files with 166 additions and 9 deletions
|
@ -34,6 +34,7 @@ type Handler struct {
|
|||
trace oteltrace.TracerProvider
|
||||
storage *Storage
|
||||
cipher cipher.AEAD
|
||||
relyingParties *OAuth2Configs
|
||||
}
|
||||
|
||||
func New(
|
||||
|
@ -58,6 +59,7 @@ func New(
|
|||
trace: tracerProvider,
|
||||
storage: NewStorage(client),
|
||||
cipher: cipher,
|
||||
relyingParties: NewOAuthConfig(cfg, http.DefaultClient),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -68,8 +68,13 @@ func (srv *Handler) Authorize(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
loginURL, ok := srv.relyingParties.GetLoginURLForHost(r.Host, id)
|
||||
if ok {
|
||||
http.Redirect(w, r, loginURL, http.StatusFound)
|
||||
} else {
|
||||
srv.AuthorizationResponse(ctx, w, r, id, v)
|
||||
}
|
||||
}
|
||||
|
||||
// AuthorizationResponse generates the successful authorization response
|
||||
// see https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-12#section-4.1.2
|
||||
|
|
93
internal/mcp/oauth_config.go
Normal file
93
internal/mcp/oauth_config.go
Normal file
|
@ -0,0 +1,93 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
)
|
||||
|
||||
type OAuth2Configs struct {
|
||||
cfg *config.Config
|
||||
prefix string
|
||||
httpClient *http.Client
|
||||
|
||||
buildOnce sync.Once
|
||||
perHost map[string]*oauth2.Config
|
||||
}
|
||||
|
||||
func NewOAuthConfig(
|
||||
cfg *config.Config,
|
||||
httpClient *http.Client,
|
||||
) *OAuth2Configs {
|
||||
return &OAuth2Configs{
|
||||
prefix: DefaultPrefix,
|
||||
cfg: cfg,
|
||||
httpClient: httpClient,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *OAuth2Configs) GetLoginURLForHost(host string, state string) (string, bool) {
|
||||
r.buildOnce.Do(r.build)
|
||||
|
||||
cfg, ok := r.perHost[host]
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return cfg.AuthCodeURL(state, oauth2.AccessTypeOffline), true
|
||||
}
|
||||
|
||||
func (r *OAuth2Configs) build() {
|
||||
r.perHost = BuildOAuthConfig(r.cfg, r.prefix)
|
||||
}
|
||||
|
||||
// BuildOAuthConfig builds a map of OAuth2 configs per host
|
||||
func BuildOAuthConfig(cfg *config.Config, prefix string) map[string]*oauth2.Config {
|
||||
configs := make(map[string]*oauth2.Config)
|
||||
for policy := range cfg.Options.GetAllPolicies() {
|
||||
if !policy.IsMCPServer() || policy.MCP.UpstreamOAuth2 == nil {
|
||||
continue
|
||||
}
|
||||
u, err := url.Parse(policy.GetFrom())
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
host := u.Hostname()
|
||||
if _, ok := configs[host]; ok {
|
||||
continue
|
||||
}
|
||||
cfg := &oauth2.Config{
|
||||
ClientID: policy.MCP.UpstreamOAuth2.ClientID,
|
||||
ClientSecret: policy.MCP.UpstreamOAuth2.ClientSecret,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: policy.MCP.UpstreamOAuth2.Endpoint.AuthURL,
|
||||
TokenURL: policy.MCP.UpstreamOAuth2.Endpoint.TokenURL,
|
||||
AuthStyle: authStyleEnum(policy.MCP.UpstreamOAuth2.Endpoint.AuthStyle),
|
||||
},
|
||||
RedirectURL: (&url.URL{
|
||||
Scheme: "https",
|
||||
Host: host,
|
||||
Path: path.Join(prefix, oauthCallbackEndpoint),
|
||||
}).String(),
|
||||
Scopes: policy.MCP.UpstreamOAuth2.Scopes,
|
||||
}
|
||||
configs[host] = cfg
|
||||
}
|
||||
return configs
|
||||
}
|
||||
|
||||
func authStyleEnum(o config.OAuth2EndpointAuthStyle) oauth2.AuthStyle {
|
||||
switch o {
|
||||
case config.OAuth2EndpointAuthStyleInHeader:
|
||||
return oauth2.AuthStyleInHeader
|
||||
case config.OAuth2EndpointAuthStyleInParams:
|
||||
return oauth2.AuthStyleInParams
|
||||
default:
|
||||
return oauth2.AuthStyleAutoDetect
|
||||
}
|
||||
}
|
57
internal/mcp/oauth_config_test.go
Normal file
57
internal/mcp/oauth_config_test.go
Normal file
|
@ -0,0 +1,57 @@
|
|||
package mcp_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/mcp"
|
||||
)
|
||||
|
||||
func TestBuildOAuthConfig(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Options: &config.Options{
|
||||
Policies: []config.Policy{
|
||||
{
|
||||
From: "https://regular.example.com",
|
||||
},
|
||||
{
|
||||
From: "https://mcp1.example.com",
|
||||
MCP: &config.MCP{},
|
||||
},
|
||||
{
|
||||
From: "https://mcp2.example.com",
|
||||
MCP: &config.MCP{
|
||||
UpstreamOAuth2: &config.UpstreamOAuth2{
|
||||
ClientID: "client_id",
|
||||
ClientSecret: "client_secret",
|
||||
Endpoint: config.OAuth2Endpoint{
|
||||
AuthURL: "https://auth.example.com/auth",
|
||||
TokenURL: "https://auth.example.com/token",
|
||||
AuthStyle: config.OAuth2EndpointAuthStyleInParams,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
got := mcp.BuildOAuthConfig(cfg, "/prefix")
|
||||
diff := cmp.Diff(got, map[string]*oauth2.Config{
|
||||
"mcp2.example.com": {
|
||||
ClientID: "client_id",
|
||||
ClientSecret: "client_secret",
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: "https://auth.example.com/auth",
|
||||
TokenURL: "https://auth.example.com/token",
|
||||
AuthStyle: oauth2.AuthStyleInParams,
|
||||
},
|
||||
RedirectURL: "https://mcp2.example.com/prefix/oauth/callback",
|
||||
},
|
||||
}, cmpopts.IgnoreUnexported(oauth2.Config{}))
|
||||
require.Empty(t, diff)
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue