pomerium/internal/testutil/mockidp/mockidp.go
Kenneth Jenkins 6a65c52a6c
refactor testenv mock IdP to also work standalone (#5678)
Refactor the testenv mock IdP implementation to split off the core
functionality from the testenv environment setup. Add a Start() method
to run the mock IdP as an httptest server, tied to a test lifecycle.
This also allows the mock IdP to be used in tests that do not start a
full Pomerium instance.
2025-06-27 14:04:58 -07:00

383 lines
9.1 KiB
Go

package mockidp
import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"strings"
"testing"
"time"
"github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v3/jwt"
"github.com/google/uuid"
"github.com/gorilla/mux"
"golang.org/x/oauth2"
"github.com/pomerium/pomerium/internal/encoding"
)
// IDP is a mock OpenID Connect identity provider for tests. It signs ID
// tokens with an ES256 key generated in New and serves discovery, JWKS,
// authorization, token, userinfo and (optionally) device authorization
// endpoints; see Register. Construct one with New.
type IDP struct {
// publicJWK is the public half of the signing key, published at
// /.well-known/jwks.json.
publicJWK jose.JSONWebKey
// signingKey is the ES256 private key used to sign ID tokens.
signingKey jose.SigningKey
// stateEncoder is not set by New and not read anywhere in this file.
// NOTE(review): confirm whether anything else uses it before removing it
// (it is currently the only user of the encoding import).
stateEncoder encoding.MarshalUnmarshaler
// userLookup maps each configured user's generated ID to that user.
userLookup map[string]*User
// enableDeviceAuth, when true, registers the device authorization endpoint,
// advertises it in the discovery document, and enables the device_code
// branch of the token endpoint.
enableDeviceAuth bool
}
// Config holds the options New uses to construct an IDP.
type Config struct {
// Users are the users the IdP knows about. New assigns each one a fresh
// random ID in place (overwriting any existing value) and indexes them so
// their names and IDs appear in userinfo and ID token claims.
Users []*User
// EnableDeviceAuth is meant to turn on the OAuth 2.0 device authorization
// flow (it takes effect through IDP.enableDeviceAuth).
EnableDeviceAuth bool
}
// New creates a mock IdP from cfg.
//
// Every call generates a fresh P-256 key pair. The private key signs ID
// tokens with ES256, and the public key is published through the JWKS
// endpoint. Its key ID is the hex-encoded SHA-256 JWK thumbprint.
//
// Each user in cfg.Users is assigned a new random ID in place, so the
// caller's *User values are modified. That ID is reported as the user's
// subject.
//
// New panics if key generation or thumbprinting fails.
func New(cfg Config) *IDP {
	privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		panic(err)
	}

	signingKey := jose.SigningKey{
		Algorithm: jose.ES256,
		Key:       privateKey,
	}

	publicJWK := jose.JSONWebKey{
		Key:       &privateKey.PublicKey,
		Algorithm: string(jose.ES256),
		Use:       "sig",
	}
	thumbprint, err := publicJWK.Thumbprint(crypto.SHA256)
	if err != nil {
		panic(err)
	}
	publicJWK.KeyID = hex.EncodeToString(thumbprint)

	userLookup := make(map[string]*User, len(cfg.Users))
	for _, user := range cfg.Users {
		user.ID = uuid.NewString()
		userLookup[user.ID] = user
	}

	return &IDP{
		publicJWK:  publicJWK,
		signingKey: signingKey,
		userLookup: userLookup,
		// Without this the option is silently ignored and the device flow
		// endpoints are never registered.
		enableDeviceAuth: cfg.EnableDeviceAuth,
	}
}
// Start serves the IdP on a new local httptest server and returns the
// server's base URL. The server is closed via t.Cleanup once the test (and
// its subtests) finish.
func (idp *IDP) Start(t *testing.T) string {
	router := mux.NewRouter()
	idp.Register(router)

	srv := httptest.NewServer(router)
	t.Cleanup(func() { srv.Close() })

	return srv.URL
}
// Register installs the IdP's HTTP handlers on router:
//
//   - /.well-known/jwks.json: the public signing key set
//   - /.well-known/openid-configuration: the OIDC discovery document
//   - /oidc/auth: authorization endpoint (login form)
//   - /oidc/token: token endpoint
//   - /oidc/userinfo: userinfo endpoint
//   - /oidc/device/code: device authorization endpoint, only when device
//     auth is enabled
//
// The issuer and endpoint URLs in the discovery document come from the
// incoming request's scheme and host. As a result, the document reflects
// whatever address the IdP is reached at.
func (idp *IDP) Register(router *mux.Router) {
	router.HandleFunc("/.well-known/jwks.json", func(w http.ResponseWriter, _ *http.Request) {
		serveJSON(w, &jose.JSONWebKeySet{
			Keys: []jose.JSONWebKey{idp.publicJWK},
		})
	})
	router.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
		rootURL := getRootURL(r)
		endpoint := func(path string) string {
			return rootURL.ResolveReference(&url.URL{Path: path}).String()
		}
		config := map[string]interface{}{
			"issuer":                 rootURL.String(),
			"authorization_endpoint": endpoint("/oidc/auth"),
			"token_endpoint":         endpoint("/oidc/token"),
			"jwks_uri":               endpoint("/.well-known/jwks.json"),
			"userinfo_endpoint":      endpoint("/oidc/userinfo"),
			"id_token_signing_alg_values_supported": []string{
				"ES256",
			},
		}
		if idp.enableDeviceAuth {
			config["device_authorization_endpoint"] = endpoint("/oidc/device/code")
		}
		// serveJSON sets Content-Type: application/json, which the OIDC
		// Discovery spec requires for this response.
		serveJSON(w, config)
	})
	router.HandleFunc("/oidc/auth", idp.handleAuth)
	if idp.enableDeviceAuth {
		router.HandleFunc("/oidc/device/code", idp.handleDeviceCode)
	}
	router.HandleFunc("/oidc/token", idp.handleToken)
	router.HandleFunc("/oidc/userinfo", idp.handleUserInfo)
}
// handleAuth implements the OIDC authorization endpoint.
//
// redirect_uri and client_id are required; either one missing, or an
// unparsable redirect_uri, yields 400 Bad Request.
//
// When an email is supplied (in the query or via the login form), the user
// is sent straight back to redirect_uri with a 302. The redirect carries a
// "code" that encodes the email and client ID, and it echoes the caller's
// "state". Without an email, the handler renders a minimal login form that
// POSTs the email back to this same URL.
func (idp *IDP) handleAuth(w http.ResponseWriter, r *http.Request) {
	rawRedirectURI := r.FormValue("redirect_uri")
	if rawRedirectURI == "" {
		http.Error(w, "missing redirect_uri", http.StatusBadRequest)
		return
	}
	redirectURI, err := url.Parse(rawRedirectURI)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	rawClientID := r.FormValue("client_id")
	if rawClientID == "" {
		http.Error(w, "missing client_id", http.StatusBadRequest)
		return
	}

	rawEmail := r.FormValue("email")
	if rawEmail != "" {
		code := state{
			Email:    rawEmail,
			ClientID: rawClientID,
		}.Encode()

		// RFC 6749 §3.1.2: a query component already present on the
		// redirect URI must be retained when adding response parameters,
		// so merge into it rather than replacing it.
		query := redirectURI.Query()
		query.Set("state", r.FormValue("state"))
		query.Set("code", code)
		redirectURI.RawQuery = query.Encode()
		// Redirect URIs must not carry a fragment; drop any that was sent.
		redirectURI.Fragment = ""
		redirectURI.RawFragment = ""

		http.Redirect(w, r, redirectURI.String(), http.StatusFound)
		return
	}

	serveHTML(w, `<!doctype html>
<html>
<head>
<title>Login</title>
</head>
<body>
<form method="POST" style="max-width: 200px">
<fieldset>
<legend>Login</legend>
<table>
<tbody>
<tr>
<th><label for="email">Email</label></th>
<td>
<input type="email" name="email" placeholder="email" />
</td>
</tr>
<tr>
<td colspan="2">
<input type="submit" />
</td>
</tr>
</tbody>
</table>
</fieldset>
</form>
</body>
</html>
`)
}
// handleToken implements the OIDC token endpoint.
//
// Supported requests:
//   - device_code grant (only when device auth is enabled): always issues
//     tokens for the fixed user fake.user@example.com and the request's
//     client_id.
//   - grant_type=refresh_token: the refresh token (an encoded state, as
//     issued by serveToken) is decoded back into its state, and new tokens
//     are issued for it.
//   - anything else is treated as an authorization code exchange, decoding
//     the "code" parameter produced by handleAuth.
//
// A code or refresh token that cannot be decoded yields 400 Bad Request.
func (idp *IDP) handleToken(w http.ResponseWriter, r *http.Request) {
	if idp.enableDeviceAuth && r.FormValue("device_code") != "" {
		idp.serveToken(w, r, &state{
			ClientID: r.FormValue("client_id"),
			Email:    "fake.user@example.com",
		})
		return
	}

	rawCode := r.FormValue("code")
	if r.FormValue("grant_type") == "refresh_token" {
		// Refresh requests carry the token in "refresh_token", not "code";
		// both are encoded the same way, so the same decoding applies.
		rawCode = r.FormValue("refresh_token")
	}

	st, err := decodeState(rawCode)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	idp.serveToken(w, r, st)
}
// serveToken writes a token response for state.
//
// The access token and refresh token are both the encoded state, which lets
// handleUserInfo decode the access token back into the user. The ID token is
// a JWT built from the state and signed with the IdP's key.
func (idp *IDP) serveToken(w http.ResponseWriter, r *http.Request, state *state) {
	encoded := state.Encode()
	signedIDToken := state.GetIDToken(r, idp.userLookup).Encode(idp.signingKey)

	response := map[string]interface{}{
		"token_type":    "Bearer",
		"access_token":  encoded,
		"refresh_token": encoded,
		"id_token":      signedIDToken,
	}
	serveJSON(w, response)
}
// handleUserInfo implements the OIDC userinfo endpoint.
//
// The access token is read from the Authorization header, with either the
// "Bearer " or the "token " prefix, and decoded back into the state it was
// issued for. The responses are:
//   - no header, or an unrecognized scheme: 401 Unauthorized
//   - an undecodable token: 403 Forbidden
func (idp *IDP) handleUserInfo(w http.ResponseWriter, r *http.Request) {
	header := r.Header.Get("Authorization")
	if header == "" {
		http.Error(w, "missing authorization header", http.StatusUnauthorized)
		return
	}

	var token string
	switch {
	case strings.HasPrefix(header, "Bearer "):
		token = strings.TrimPrefix(header, "Bearer ")
	case strings.HasPrefix(header, "token "):
		token = strings.TrimPrefix(header, "token ")
	default:
		http.Error(w, "missing bearer token", http.StatusUnauthorized)
		return
	}

	st, err := decodeState(token)
	if err != nil {
		http.Error(w, err.Error(), http.StatusForbidden)
		return
	}
	serveJSON(w, st.GetUserInfo(idp.userLookup))
}
// handleDeviceCode starts a simulated OAuth 2.0 device authorization flow.
//
// This is the bare minimum needed to exercise the flow:
//   - the same fixed device code and user code are returned every time;
//   - client_id is not verified;
//   - there is no real login step;
//   - the advertised verification URI (/oidc/device) is not served.
//
// The response says the codes expire in five minutes and the client should
// poll every second.
func (idp *IDP) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
	const (
		deviceCode = "GmRhmhcxhwAzkoEqiMEg_DnyEysNkuNhszIySk9eS"
		userCode   = "ABCD-EFGH"
	)

	verification := getRootURL(r).ResolveReference(&url.URL{Path: "/oidc/device"}) // note: not actually implemented
	verificationURI := verification.String()

	verification.RawQuery = "user_code=" + userCode
	verificationURIComplete := verification.String()

	resp := &oauth2.DeviceAuthResponse{
		DeviceCode:              deviceCode,
		UserCode:                userCode,
		VerificationURI:         verificationURI,
		VerificationURIComplete: verificationURIComplete,
		Expiry:                  time.Now().Add(5 * time.Minute),
		Interval:                1,
	}
	serveJSON(w, resp)
}
// getRootURL returns the scheme and host the request was addressed to, with
// no path, query, fragment or user info (e.g. "https://idp.example.com").
// The result is used as the OIDC issuer and as the base for endpoint URLs.
//
// The host is r.Host when set, otherwise r.URL.Host. The scheme is
// r.URL.Scheme when the request used an absolute URL. Otherwise it is
// "https" for TLS connections and "http" for plain ones.
func getRootURL(r *http.Request) *url.URL {
	host := r.URL.Host
	if r.Host != "" {
		host = r.Host
	}

	scheme := r.URL.Scheme
	if scheme == "" {
		scheme = "http"
		if r.TLS != nil {
			scheme = "https"
		}
	}

	// Build a fresh URL instead of copying r.URL, so the request's query
	// string (or user info/fragment) cannot leak into the issuer. A leaked
	// query would make the discovery issuer disagree with the ID token's iss.
	return &url.URL{Scheme: scheme, Host: host}
}
// serveHTML writes html as a 200 OK response with Content-Type "text/html"
// and an explicit Content-Length. Write errors are deliberately ignored,
// since there is nothing useful to do once the status has been sent.
func serveHTML(w http.ResponseWriter, html string) {
	headers := w.Header()
	headers.Set("Content-Type", "text/html")
	headers.Set("Content-Length", strconv.Itoa(len(html)))

	w.WriteHeader(http.StatusOK)
	_, _ = io.WriteString(w, html)
}
// serveJSON marshals obj and writes it as a 200 OK "application/json"
// response. If obj cannot be marshaled, the marshal error text is sent as a
// 500 Internal Server Error instead. Write errors are deliberately ignored.
func serveJSON(w http.ResponseWriter, obj interface{}) {
	payload, marshalErr := json.Marshal(obj)
	if marshalErr != nil {
		http.Error(w, marshalErr.Error(), http.StatusInternalServerError)
		return
	}

	headers := w.Header()
	headers.Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	_, _ = w.Write(payload)
}
// state is the payload behind the opaque credentials this IdP issues. The
// authorization code, access token and refresh token are each the padded
// base64url encoding of a state's JSON form.
type state struct {
	Email    string `json:"email"`
	ClientID string `json:"client_id"`
}

// decodeState parses a value produced by state.Encode.
//
// It returns an error in two cases. The first is when rawCode is not valid
// padded base64url; that error is a base64.CorruptInputError. The second is
// when the decoded bytes cannot be unmarshaled into a state, which includes
// the empty string.
func decodeState(rawCode string) (*state, error) {
	bs, err := base64.URLEncoding.DecodeString(rawCode)
	if err != nil {
		// Report malformed base64 directly. Otherwise it surfaces as a
		// confusing JSON error about truncated input.
		return nil, err
	}

	var st state
	if err := json.Unmarshal(bs, &st); err != nil {
		return nil, err
	}
	return &st, nil
}

// Encode returns the padded base64url encoding of state's JSON form; it is
// the inverse of decodeState.
func (state state) Encode() string {
	// Marshaling a struct of plain string fields cannot fail.
	bs, _ := json.Marshal(state)
	return base64.URLEncoding.EncodeToString(bs)
}
// GetIDToken builds the (not yet signed) ID token claims for state.
//
//   - iss is the root URL of r.
//   - aud is the client ID the state was issued to.
//   - iat is now, and exp is 365 days from now.
//   - The user claims come from GetUserInfo.
func (state state) GetIDToken(r *http.Request, users map[string]*User) *idToken {
	claims := state.GetUserInfo(users)
	issuer := getRootURL(r).String()
	expiresAt := jwt.NewNumericDate(time.Now().Add(365 * 24 * time.Hour))
	issuedAt := jwt.NewNumericDate(time.Now())

	return &idToken{
		userInfo: claims,
		Issuer:   issuer,
		Audience: state.ClientID,
		Expiry:   expiresAt,
		IssuedAt: issuedAt,
	}
}
// GetUserInfo returns the userinfo claims for state.
//
// If a configured user has the state's email, the claims use that user's
// details:
//   - sub is the user's generated ID.
//   - given_name and family_name are the first and last names.
//   - name is the two joined with a space, with no stray leading or
//     trailing space when either is empty.
//
// Otherwise sub is the email itself and the name claims are empty.
func (state state) GetUserInfo(users map[string]*User) *userInfo {
	info := &userInfo{
		Subject: state.Email,
		Email:   state.Email,
	}
	for _, u := range users {
		if u.Email != state.Email {
			continue
		}
		info.Subject = u.ID
		info.Name = strings.TrimSpace(u.FirstName + " " + u.LastName)
		info.FamilyName = u.LastName
		info.GivenName = u.FirstName
		// Emails are expected to be unique; no need to keep scanning.
		break
	}
	return info
}
// userInfo holds the standard OIDC claims this IdP reports for a user. It is
// the body of the userinfo response, and it is embedded in the ID token
// claims (idToken).
type userInfo struct {
// Subject is the "sub" claim: the configured user's ID, or the email when
// no configured user matches.
Subject string `json:"sub"`
Name string `json:"name"`
Email string `json:"email"`
FamilyName string `json:"family_name"`
GivenName string `json:"given_name"`
}
// idToken is the claim set signed into ID tokens. It contains the embedded
// userInfo claims, which encoding/json flattens into the top level, plus
// the registered iss, aud, exp and iat claims.
type idToken struct {
*userInfo
Issuer string `json:"iss"`
Audience string `json:"aud"`
Expiry *jwt.NumericDate `json:"exp"`
IssuedAt *jwt.NumericDate `json:"iat"`
}
// Encode signs the token's claims with signingKey and returns the compact
// JWS serialization, with header typ "JWT". It panics if the signer cannot
// be created or signing fails.
func (token *idToken) Encode(signingKey jose.SigningKey) string {
	opts := new(jose.SignerOptions).WithType("JWT")
	signer, err := jose.NewSigner(signingKey, opts)
	if err != nil {
		panic(err)
	}

	compact, err := jwt.Signed(signer).Claims(token).CompactSerialize()
	if err != nil {
		panic(err)
	}
	return compact
}
// User describes a user known to the mock IdP (see Config.Users).
type User struct {
// ID is overwritten by New with a random UUID; it is reported as the
// user's "sub" claim.
ID string
// Email is compared with the email entered at login to find this user.
Email string
// FirstName and LastName supply the given_name and family_name claims and
// are combined to form the name claim.
FirstName string
LastName string
}