mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-24 22:47:14 +02:00
mcp: redirect to upstream oauth2 for authentication (#5594)
This commit is contained in:
parent
5b024a8ada
commit
561b6040b5
4 changed files with 166 additions and 9 deletions
|
@ -34,6 +34,7 @@ type Handler struct {
|
||||||
trace oteltrace.TracerProvider
|
trace oteltrace.TracerProvider
|
||||||
storage *Storage
|
storage *Storage
|
||||||
cipher cipher.AEAD
|
cipher cipher.AEAD
|
||||||
|
relyingParties *OAuth2Configs
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(
|
func New(
|
||||||
|
@ -58,6 +59,7 @@ func New(
|
||||||
trace: tracerProvider,
|
trace: tracerProvider,
|
||||||
storage: NewStorage(client),
|
storage: NewStorage(client),
|
||||||
cipher: cipher,
|
cipher: cipher,
|
||||||
|
relyingParties: NewOAuthConfig(cfg, http.DefaultClient),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -68,8 +68,13 @@ func (srv *Handler) Authorize(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
loginURL, ok := srv.relyingParties.GetLoginURLForHost(r.Host, id)
|
||||||
|
if ok {
|
||||||
|
http.Redirect(w, r, loginURL, http.StatusFound)
|
||||||
|
} else {
|
||||||
srv.AuthorizationResponse(ctx, w, r, id, v)
|
srv.AuthorizationResponse(ctx, w, r, id, v)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// AuthorizationResponse generates the successful authorization response
|
// AuthorizationResponse generates the successful authorization response
|
||||||
// see https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-12#section-4.1.2
|
// see https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-12#section-4.1.2
|
||||||
|
|
93
internal/mcp/oauth_config.go
Normal file
93
internal/mcp/oauth_config.go
Normal file
|
@ -0,0 +1,93 @@
|
||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"path"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
type OAuth2Configs struct {
|
||||||
|
cfg *config.Config
|
||||||
|
prefix string
|
||||||
|
httpClient *http.Client
|
||||||
|
|
||||||
|
buildOnce sync.Once
|
||||||
|
perHost map[string]*oauth2.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewOAuthConfig(
|
||||||
|
cfg *config.Config,
|
||||||
|
httpClient *http.Client,
|
||||||
|
) *OAuth2Configs {
|
||||||
|
return &OAuth2Configs{
|
||||||
|
prefix: DefaultPrefix,
|
||||||
|
cfg: cfg,
|
||||||
|
httpClient: httpClient,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *OAuth2Configs) GetLoginURLForHost(host string, state string) (string, bool) {
|
||||||
|
r.buildOnce.Do(r.build)
|
||||||
|
|
||||||
|
cfg, ok := r.perHost[host]
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
return cfg.AuthCodeURL(state, oauth2.AccessTypeOffline), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *OAuth2Configs) build() {
|
||||||
|
r.perHost = BuildOAuthConfig(r.cfg, r.prefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildOAuthConfig builds a map of OAuth2 configs per host
|
||||||
|
func BuildOAuthConfig(cfg *config.Config, prefix string) map[string]*oauth2.Config {
|
||||||
|
configs := make(map[string]*oauth2.Config)
|
||||||
|
for policy := range cfg.Options.GetAllPolicies() {
|
||||||
|
if !policy.IsMCPServer() || policy.MCP.UpstreamOAuth2 == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
u, err := url.Parse(policy.GetFrom())
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
host := u.Hostname()
|
||||||
|
if _, ok := configs[host]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cfg := &oauth2.Config{
|
||||||
|
ClientID: policy.MCP.UpstreamOAuth2.ClientID,
|
||||||
|
ClientSecret: policy.MCP.UpstreamOAuth2.ClientSecret,
|
||||||
|
Endpoint: oauth2.Endpoint{
|
||||||
|
AuthURL: policy.MCP.UpstreamOAuth2.Endpoint.AuthURL,
|
||||||
|
TokenURL: policy.MCP.UpstreamOAuth2.Endpoint.TokenURL,
|
||||||
|
AuthStyle: authStyleEnum(policy.MCP.UpstreamOAuth2.Endpoint.AuthStyle),
|
||||||
|
},
|
||||||
|
RedirectURL: (&url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: host,
|
||||||
|
Path: path.Join(prefix, oauthCallbackEndpoint),
|
||||||
|
}).String(),
|
||||||
|
Scopes: policy.MCP.UpstreamOAuth2.Scopes,
|
||||||
|
}
|
||||||
|
configs[host] = cfg
|
||||||
|
}
|
||||||
|
return configs
|
||||||
|
}
|
||||||
|
|
||||||
|
func authStyleEnum(o config.OAuth2EndpointAuthStyle) oauth2.AuthStyle {
|
||||||
|
switch o {
|
||||||
|
case config.OAuth2EndpointAuthStyleInHeader:
|
||||||
|
return oauth2.AuthStyleInHeader
|
||||||
|
case config.OAuth2EndpointAuthStyleInParams:
|
||||||
|
return oauth2.AuthStyleInParams
|
||||||
|
default:
|
||||||
|
return oauth2.AuthStyleAutoDetect
|
||||||
|
}
|
||||||
|
}
|
57
internal/mcp/oauth_config_test.go
Normal file
57
internal/mcp/oauth_config_test.go
Normal file
|
@ -0,0 +1,57 @@
|
||||||
|
package mcp_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/google/go-cmp/cmp/cmpopts"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/internal/mcp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildOAuthConfig(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Options: &config.Options{
|
||||||
|
Policies: []config.Policy{
|
||||||
|
{
|
||||||
|
From: "https://regular.example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
From: "https://mcp1.example.com",
|
||||||
|
MCP: &config.MCP{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
From: "https://mcp2.example.com",
|
||||||
|
MCP: &config.MCP{
|
||||||
|
UpstreamOAuth2: &config.UpstreamOAuth2{
|
||||||
|
ClientID: "client_id",
|
||||||
|
ClientSecret: "client_secret",
|
||||||
|
Endpoint: config.OAuth2Endpoint{
|
||||||
|
AuthURL: "https://auth.example.com/auth",
|
||||||
|
TokenURL: "https://auth.example.com/token",
|
||||||
|
AuthStyle: config.OAuth2EndpointAuthStyleInParams,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
got := mcp.BuildOAuthConfig(cfg, "/prefix")
|
||||||
|
diff := cmp.Diff(got, map[string]*oauth2.Config{
|
||||||
|
"mcp2.example.com": {
|
||||||
|
ClientID: "client_id",
|
||||||
|
ClientSecret: "client_secret",
|
||||||
|
Endpoint: oauth2.Endpoint{
|
||||||
|
AuthURL: "https://auth.example.com/auth",
|
||||||
|
TokenURL: "https://auth.example.com/token",
|
||||||
|
AuthStyle: oauth2.AuthStyleInParams,
|
||||||
|
},
|
||||||
|
RedirectURL: "https://mcp2.example.com/prefix/oauth/callback",
|
||||||
|
},
|
||||||
|
}, cmpopts.IgnoreUnexported(oauth2.Config{}))
|
||||||
|
require.Empty(t, diff)
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue