pomerium/internal/mcp/host_info.go

197 lines
4.4 KiB
Go

package mcp
import (
"context"
"fmt"
"iter"
"maps"
"net/http"
"net/url"
"path"
"sync"
"golang.org/x/oauth2"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/config"
oauth21proto "github.com/pomerium/pomerium/internal/oauth21/gen"
)
// HostInfo indexes the MCP policies of a configuration by hostname.
// The index is built lazily, on first use, from cfg and prefix and is
// reused by every later call.
type HostInfo struct {
// cfg is the configuration whose policies are indexed.
cfg *config.Config
// prefix is the path prefix joined with the OAuth callback endpoint when
// redirect URLs are built.
prefix string
// httpClient is the HTTP client supplied to NewHostInfo.
httpClient *http.Client
// buildOnce makes sure the maps below are populated a single time.
buildOnce sync.Once
// servers maps a hostname to the MCP server information found for it.
servers map[string]ServerHostInfo
// clients holds the hostnames whose policy is an MCP client.
clients map[string]ClientHostInfo
}
// ServerHostInfo describes one MCP server policy, as collected by
// BuildHostInfo.
type ServerHostInfo struct {
// Name, Description and LogoURL are copied from the policy.
Name string
Description string
LogoURL string
// Host is the hostname (without port) of the policy's "from" URL.
Host string
// URL is the policy's "from" URL, unmodified.
URL string
// Config is the upstream OAuth2 client configuration; it is nil when the
// policy does not define an upstream OAuth2 section.
Config *oauth2.Config
}
// ClientHostInfo marks a hostname whose policy is an MCP client.
// It carries no data; only its presence in the clients map matters.
type ClientHostInfo struct{}
// NewHostInfo returns a HostInfo for cfg that uses DefaultPrefix as its path
// prefix. No policy is inspected here: the per-host index is built on the
// first lookup.
func NewHostInfo(
	cfg *config.Config,
	httpClient *http.Client,
) *HostInfo {
	info := &HostInfo{cfg: cfg, httpClient: httpClient}
	info.prefix = DefaultPrefix
	return info
}
// CodeExchangeForHost exchanges an OAuth2 authorization code for a token
// using the upstream OAuth2 configuration registered for host.
//
// It returns an error when host has no upstream OAuth2 configuration, or
// whatever error the token exchange itself reports.
//
// The HTTP client given to NewHostInfo is used for the exchange: oauth2 picks
// its client from the context, so the client is attached under the
// oauth2.HTTPClient key. A client already placed in ctx by the caller is left
// untouched.
func (r *HostInfo) CodeExchangeForHost(
	ctx context.Context,
	host string,
	code string,
) (*oauth2.Token, error) {
	cfg, ok := r.getConfigForHost(host)
	if !ok {
		return nil, fmt.Errorf("no oauth2 config for host %s", host)
	}

	// Without this, the injected client is silently ignored and oauth2 falls
	// back to its default HTTP client.
	if r.httpClient != nil && ctx.Value(oauth2.HTTPClient) == nil {
		ctx = context.WithValue(ctx, oauth2.HTTPClient, r.httpClient)
	}
	return cfg.Exchange(ctx, code)
}
// IsMCPClientForHost reports whether host belongs to a policy that is an MCP
// client. The host index is built on first use.
func (r *HostInfo) IsMCPClientForHost(host string) bool {
	r.buildOnce.Do(r.build)

	_, isClient := r.clients[host]
	return isClient
}
// HasOAuth2ConfigForHost reports whether host is a known MCP server that has
// an upstream OAuth2 configuration.
func (r *HostInfo) HasOAuth2ConfigForHost(host string) bool {
	_, found := r.getConfigForHost(host)
	return found
}
// GetOAuth2ConfigForHost returns the upstream OAuth2 configuration registered
// for host. The boolean is false, and the configuration nil, when the host is
// unknown or has no upstream OAuth2 configuration.
func (r *HostInfo) GetOAuth2ConfigForHost(host string) (*oauth2.Config, bool) {
	return r.getConfigForHost(host)
}
// GetLoginURLForHost returns the authorization URL of the upstream OAuth2
// provider configured for host, carrying the given state and requesting
// offline access. It returns ("", false) when host has no upstream OAuth2
// configuration.
func (r *HostInfo) GetLoginURLForHost(host string, state string) (string, bool) {
	if oauthCfg, found := r.getConfigForHost(host); found {
		return oauthCfg.AuthCodeURL(state, oauth2.AccessTypeOffline), true
	}
	return "", false
}
// All returns an iterator over every known MCP server. Because the servers
// are kept in a map, the iteration order is not specified.
func (r *HostInfo) All() iter.Seq[ServerHostInfo] {
	r.buildOnce.Do(r.build)

	servers := r.servers
	return maps.Values(servers)
}
// getConfigForHost looks up the upstream OAuth2 configuration for host,
// building the host index first if that has not happened yet. It returns
// (nil, false) when the host is unknown or has no such configuration.
func (r *HostInfo) getConfigForHost(host string) (*oauth2.Config, bool) {
	r.buildOnce.Do(r.build)

	info, found := r.servers[host]
	if !found || info.Config == nil {
		return nil, false
	}
	return info.Config, true
}
// build populates the servers and clients maps from the stored configuration
// and prefix. It is invoked through buildOnce.
func (r *HostInfo) build() {
	servers, clients := BuildHostInfo(r.cfg, r.prefix)
	r.servers = servers
	r.clients = clients
}
// BuildHostInfo walks every policy in cfg and indexes the MCP-enabled ones by
// the hostname of their "from" URL.
//
// Policies that are MCP clients are recorded in the second map. All other MCP
// policies are recorded in the first map as servers, together with an
// oauth2.Config when the policy defines an upstream OAuth2 section. The
// redirect URL of that config is https://<hostname>/<prefix>/<callback
// endpoint>.
//
// Policies without an MCP section, or whose "from" URL does not parse, are
// skipped. When several server policies share a hostname, the first one
// encountered is kept.
func BuildHostInfo(cfg *config.Config, prefix string) (map[string]ServerHostInfo, map[string]ClientHostInfo) {
	serverByHost := map[string]ServerHostInfo{}
	clientByHost := map[string]ClientHostInfo{}

	for p := range cfg.Options.GetAllPolicies() {
		if p.MCP == nil {
			continue
		}

		from := p.GetFrom()
		parsed, err := url.Parse(from)
		if err != nil {
			// An unparsable "from" URL cannot be indexed; skip the policy.
			continue
		}
		hostname := parsed.Hostname()

		if p.IsMCPClient() {
			clientByHost[hostname] = ClientHostInfo{}
			continue
		}
		if _, seen := serverByHost[hostname]; seen {
			// The first server policy registered for a hostname wins.
			continue
		}

		var oauthCfg *oauth2.Config
		if upstream := p.MCP.GetServerUpstreamOAuth2(); upstream != nil {
			callback := url.URL{
				Scheme: "https",
				Host:   hostname,
				Path:   path.Join(prefix, oauthCallbackEndpoint),
			}
			oauthCfg = &oauth2.Config{
				ClientID:     upstream.ClientID,
				ClientSecret: upstream.ClientSecret,
				Endpoint: oauth2.Endpoint{
					AuthURL:   upstream.Endpoint.AuthURL,
					TokenURL:  upstream.Endpoint.TokenURL,
					AuthStyle: authStyleEnum(upstream.Endpoint.AuthStyle),
				},
				RedirectURL: callback.String(),
				Scopes:      upstream.Scopes,
			}
		}

		serverByHost[hostname] = ServerHostInfo{
			Name:        p.Name,
			Description: p.Description,
			LogoURL:     p.LogoURL,
			Host:        hostname,
			URL:         from,
			Config:      oauthCfg,
		}
	}

	return serverByHost, clientByHost
}
// authStyleEnum translates the configuration's endpoint auth style into the
// matching oauth2.AuthStyle. Any value other than "in header" or "in params"
// maps to oauth2.AuthStyleAutoDetect.
func authStyleEnum(o config.OAuth2EndpointAuthStyle) oauth2.AuthStyle {
	if o == config.OAuth2EndpointAuthStyleInHeader {
		return oauth2.AuthStyleInHeader
	}
	if o == config.OAuth2EndpointAuthStyleInParams {
		return oauth2.AuthStyleInParams
	}
	return oauth2.AuthStyleAutoDetect
}
// OAuth2TokenToPB converts an oauth2.Token into its protobuf TokenResponse
// form. RefreshToken and ExpiresIn are always set, even when empty or zero;
// ExpiresAt is set only when the token has a non-zero expiry.
// src must not be nil.
func OAuth2TokenToPB(src *oauth2.Token) *oauth21proto.TokenResponse {
	var expiresAt *timestamppb.Timestamp
	if !src.Expiry.IsZero() {
		expiresAt = timestamppb.New(src.Expiry)
	}

	return &oauth21proto.TokenResponse{
		AccessToken:  src.AccessToken,
		TokenType:    src.TokenType,
		RefreshToken: proto.String(src.RefreshToken),
		ExpiresIn:    proto.Int64(src.ExpiresIn),
		ExpiresAt:    expiresAt,
	}
}
// PBToOAuth2Token converts a protobuf TokenResponse back into an oauth2.Token.
// The access token, token type and refresh token are copied, and Expiry is
// set when the message carries an ExpiresAt timestamp. ExpiresIn is not
// copied back.
//
// Every field is read through its generated getter, so a nil src yields an
// empty token instead of a nil pointer dereference.
func PBToOAuth2Token(src *oauth21proto.TokenResponse) *oauth2.Token {
	token := &oauth2.Token{
		AccessToken:  src.GetAccessToken(),
		TokenType:    src.GetTokenType(),
		RefreshToken: src.GetRefreshToken(),
	}
	if expiresAt := src.GetExpiresAt(); expiresAt != nil {
		token.Expiry = expiresAt.AsTime()
	}
	return token
}