mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-06 02:09:15 +02:00
refactor testenv mock IdP to also work standalone (#5678)
Refactor the testenv mock IdP implementation to split off the core functionality from the testenv environment setup. Add a Start() method to run the mock IdP as an httptest server, tied to a test lifecycle. This allows the mock IdP to be used also in tests that do not start a full Pomerium instance.
This commit is contained in:
parent
717a7bdf5a
commit
6a65c52a6c
3 changed files with 402 additions and 372 deletions
|
@ -2,48 +2,25 @@ package scenarios
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto"
|
|
||||||
"crypto/ecdsa"
|
|
||||||
"crypto/elliptic"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/base64"
|
|
||||||
"encoding/hex"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/go-jose/go-jose/v3"
|
|
||||||
"github.com/go-jose/go-jose/v3/jwt"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"golang.org/x/oauth2"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/encoding"
|
|
||||||
"github.com/pomerium/pomerium/internal/encoding/jws"
|
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
|
||||||
"github.com/pomerium/pomerium/internal/testenv"
|
"github.com/pomerium/pomerium/internal/testenv"
|
||||||
"github.com/pomerium/pomerium/internal/testenv/upstreams"
|
"github.com/pomerium/pomerium/internal/testenv/upstreams"
|
||||||
"github.com/pomerium/pomerium/internal/testenv/values"
|
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||||
|
"github.com/pomerium/pomerium/internal/testutil/mockidp"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/identity"
|
"github.com/pomerium/pomerium/pkg/grpc/identity"
|
||||||
)
|
)
|
||||||
|
|
||||||
type IDP struct {
|
type IDP struct {
|
||||||
IDPOptions
|
IDPOptions
|
||||||
id values.Value[string]
|
id values.Value[string]
|
||||||
url values.Value[string]
|
url values.Value[string]
|
||||||
publicJWK jose.JSONWebKey
|
mockIDP *mockidp.IDP
|
||||||
signingKey jose.SigningKey
|
|
||||||
|
|
||||||
stateEncoder encoding.MarshalUnmarshaler
|
|
||||||
userLookup map[string]*User
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type IDPOptions struct {
|
type IDPOptions struct {
|
||||||
|
@ -92,9 +69,9 @@ func (idp *IDP) Attach(ctx context.Context) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
router := upstreams.HTTP(tlsConfig, upstreams.WithDisplayName("IDP"))
|
up := upstreams.HTTP(tlsConfig, upstreams.WithDisplayName("IDP"))
|
||||||
|
|
||||||
idp.url = values.Bind2(idpURL, router.Addr(), func(urlStr string, addr string) string {
|
idp.url = values.Bind2(idpURL, up.Addr(), func(urlStr string, addr string) string {
|
||||||
u, _ := url.Parse(urlStr)
|
u, _ := url.Parse(urlStr)
|
||||||
host, _, _ := net.SplitHostPort(u.Host)
|
host, _, _ := net.SplitHostPort(u.Host)
|
||||||
_, port, err := net.SplitHostPort(addr)
|
_, port, err := net.SplitHostPort(addr)
|
||||||
|
@ -105,9 +82,6 @@ func (idp *IDP) Attach(ctx context.Context) {
|
||||||
Host: fmt.Sprintf("%s:%s", host, port),
|
Host: fmt.Sprintf("%s:%s", host, port),
|
||||||
}).String()
|
}).String()
|
||||||
})
|
})
|
||||||
var err error
|
|
||||||
idp.stateEncoder, err = jws.NewHS256Signer(env.SharedSecret())
|
|
||||||
env.Require().NoError(err)
|
|
||||||
|
|
||||||
idp.id = values.Bind2(idp.url, env.AuthenticateURL(), func(idpUrl, authUrl string) string {
|
idp.id = values.Bind2(idp.url, env.AuthenticateURL(), func(idpUrl, authUrl string) string {
|
||||||
provider := identity.Provider{
|
provider := identity.Provider{
|
||||||
|
@ -121,37 +95,9 @@ func (idp *IDP) Attach(ctx context.Context) {
|
||||||
return provider.Hash()
|
return provider.Hash()
|
||||||
})
|
})
|
||||||
|
|
||||||
router.Handle("/.well-known/jwks.json", func(w http.ResponseWriter, _ *http.Request) {
|
idp.mockIDP.Register(up.Router())
|
||||||
_ = json.NewEncoder(w).Encode(&jose.JSONWebKeySet{
|
|
||||||
Keys: []jose.JSONWebKey{idp.publicJWK},
|
|
||||||
})
|
|
||||||
})
|
|
||||||
router.Handle("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
log.Ctx(ctx).Debug().Str("method", r.Method).Str("uri", r.RequestURI).Send()
|
|
||||||
rootURL, _ := url.Parse(idp.url.Value())
|
|
||||||
config := map[string]interface{}{
|
|
||||||
"issuer": rootURL.String(),
|
|
||||||
"authorization_endpoint": rootURL.ResolveReference(&url.URL{Path: "/oidc/auth"}).String(),
|
|
||||||
"token_endpoint": rootURL.ResolveReference(&url.URL{Path: "/oidc/token"}).String(),
|
|
||||||
"jwks_uri": rootURL.ResolveReference(&url.URL{Path: "/.well-known/jwks.json"}).String(),
|
|
||||||
"userinfo_endpoint": rootURL.ResolveReference(&url.URL{Path: "/oidc/userinfo"}).String(),
|
|
||||||
"id_token_signing_alg_values_supported": []string{
|
|
||||||
"ES256",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
if idp.enableDeviceAuth {
|
|
||||||
config["device_authorization_endpoint"] = rootURL.ResolveReference(&url.URL{Path: "/oidc/device/code"}).String()
|
|
||||||
}
|
|
||||||
serveJSON(w, config)
|
|
||||||
})
|
|
||||||
router.Handle("/oidc/auth", idp.HandleAuth)
|
|
||||||
router.Handle("/oidc/token", idp.HandleToken)
|
|
||||||
router.Handle("/oidc/userinfo", idp.HandleUserInfo)
|
|
||||||
if idp.enableDeviceAuth {
|
|
||||||
router.Handle("/oidc/device/code", idp.HandleDeviceCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
env.AddUpstream(router)
|
env.AddUpstream(up)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Modify implements testenv.Modifier.
|
// Modify implements testenv.Modifier.
|
||||||
|
@ -165,323 +111,19 @@ func (idp *IDP) Modify(cfg *config.Config) {
|
||||||
|
|
||||||
var _ testenv.Modifier = (*IDP)(nil)
|
var _ testenv.Modifier = (*IDP)(nil)
|
||||||
|
|
||||||
func NewIDP(users []*User, opts ...IDPOption) *IDP {
|
func NewIDP(users []*mockidp.User, opts ...IDPOption) *IDP {
|
||||||
options := IDPOptions{
|
options := IDPOptions{
|
||||||
enableTLS: true,
|
enableTLS: true,
|
||||||
}
|
}
|
||||||
options.apply(opts...)
|
options.apply(opts...)
|
||||||
|
|
||||||
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
publicKey := &privateKey.PublicKey
|
|
||||||
|
|
||||||
signingKey := jose.SigningKey{
|
|
||||||
Algorithm: jose.ES256,
|
|
||||||
Key: privateKey,
|
|
||||||
}
|
|
||||||
publicJWK := jose.JSONWebKey{
|
|
||||||
Key: publicKey,
|
|
||||||
Algorithm: string(jose.ES256),
|
|
||||||
Use: "sig",
|
|
||||||
}
|
|
||||||
thumbprint, err := publicJWK.Thumbprint(crypto.SHA256)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
publicJWK.KeyID = hex.EncodeToString(thumbprint)
|
|
||||||
|
|
||||||
userLookup := map[string]*User{}
|
|
||||||
for _, user := range users {
|
|
||||||
user.ID = uuid.NewString()
|
|
||||||
userLookup[user.ID] = user
|
|
||||||
}
|
|
||||||
return &IDP{
|
return &IDP{
|
||||||
IDPOptions: options,
|
IDPOptions: options,
|
||||||
publicJWK: publicJWK,
|
mockIDP: mockidp.New(mockidp.Config{
|
||||||
signingKey: signingKey,
|
Users: users,
|
||||||
userLookup: userLookup,
|
EnableDeviceAuth: options.enableDeviceAuth,
|
||||||
|
}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandleAuth handles the auth flow for OIDC.
|
type User = mockidp.User
|
||||||
func (idp *IDP) HandleAuth(w http.ResponseWriter, r *http.Request) {
|
|
||||||
rawRedirectURI := r.FormValue("redirect_uri")
|
|
||||||
if rawRedirectURI == "" {
|
|
||||||
http.Error(w, "missing redirect_uri", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
redirectURI, err := url.Parse(rawRedirectURI)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
rawClientID := r.FormValue("client_id")
|
|
||||||
if rawClientID == "" {
|
|
||||||
http.Error(w, "missing client_id", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
rawEmail := r.FormValue("email")
|
|
||||||
if rawEmail != "" {
|
|
||||||
http.Redirect(w, r, redirectURI.ResolveReference(&url.URL{
|
|
||||||
RawQuery: (url.Values{
|
|
||||||
"state": {r.FormValue("state")},
|
|
||||||
"code": {State{
|
|
||||||
Email: rawEmail,
|
|
||||||
ClientID: rawClientID,
|
|
||||||
}.Encode()},
|
|
||||||
}).Encode(),
|
|
||||||
}).String(), http.StatusFound)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
serveHTML(w, `<!doctype html>
|
|
||||||
<html>
|
|
||||||
<head>
|
|
||||||
<title>Login</title>
|
|
||||||
</head>
|
|
||||||
<body>
|
|
||||||
<form method="POST" style="max-width: 200px">
|
|
||||||
<fieldset>
|
|
||||||
<legend>Login</legend>
|
|
||||||
|
|
||||||
<table>
|
|
||||||
<tbody>
|
|
||||||
<tr>
|
|
||||||
<th><label for="email">Email</label></th>
|
|
||||||
<td>
|
|
||||||
<input type="email" name="email" placeholder="email" />
|
|
||||||
</td>
|
|
||||||
</tr>
|
|
||||||
<tr>
|
|
||||||
<td colspan="2">
|
|
||||||
<input type="submit" />
|
|
||||||
</td>
|
|
||||||
</tr>
|
|
||||||
</tbody>
|
|
||||||
</table>
|
|
||||||
|
|
||||||
</fieldset>
|
|
||||||
</form>
|
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
`)
|
|
||||||
}
|
|
||||||
|
|
||||||
// HandleToken handles the token flow for OIDC.
|
|
||||||
func (idp *IDP) HandleToken(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if idp.enableDeviceAuth && r.FormValue("device_code") != "" {
|
|
||||||
idp.serveToken(w, r, &State{
|
|
||||||
ClientID: r.FormValue("client_id"),
|
|
||||||
Email: "fake.user@example.com",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
rawCode := r.FormValue("code")
|
|
||||||
state, err := DecodeState(rawCode)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
idp.serveToken(w, r, state)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (idp *IDP) serveToken(w http.ResponseWriter, r *http.Request, state *State) {
|
|
||||||
serveJSON(w, map[string]interface{}{
|
|
||||||
"access_token": state.Encode(),
|
|
||||||
"refresh_token": state.Encode(),
|
|
||||||
"token_type": "Bearer",
|
|
||||||
"id_token": state.GetIDToken(r, idp.userLookup).Encode(idp.signingKey),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// HandleUserInfo handles retrieving the user info.
|
|
||||||
func (idp *IDP) HandleUserInfo(w http.ResponseWriter, r *http.Request) {
|
|
||||||
authz := r.Header.Get("Authorization")
|
|
||||||
if authz == "" {
|
|
||||||
http.Error(w, "missing authorization header", http.StatusUnauthorized)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.HasPrefix(authz, "Bearer ") {
|
|
||||||
authz = authz[len("Bearer "):]
|
|
||||||
} else if strings.HasPrefix(authz, "token ") {
|
|
||||||
authz = authz[len("token "):]
|
|
||||||
} else {
|
|
||||||
http.Error(w, "missing bearer token", http.StatusUnauthorized)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
state, err := DecodeState(authz)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusForbidden)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
serveJSON(w, state.GetUserInfo(idp.userLookup))
|
|
||||||
}
|
|
||||||
|
|
||||||
// HandleDeviceCode initiates a device auth code flow.
|
|
||||||
//
|
|
||||||
// This is the bare minimum to simulate the device auth code flow. There is no client_id
|
|
||||||
// verification or any actual login.
|
|
||||||
func (idp *IDP) HandleDeviceCode(w http.ResponseWriter, _ *http.Request) {
|
|
||||||
deviceCode := "GmRhmhcxhwAzkoEqiMEg_DnyEysNkuNhszIySk9eS"
|
|
||||||
userCode := "ABCD-EFGH"
|
|
||||||
|
|
||||||
rootURL, _ := url.Parse(idp.url.Value())
|
|
||||||
u := rootURL.ResolveReference(&url.URL{Path: "/oidc/device"}) // note: not actually implemented
|
|
||||||
verificationURI := u.String()
|
|
||||||
u.RawQuery = "user_code=" + userCode
|
|
||||||
verificationURIComplete := u.String()
|
|
||||||
|
|
||||||
serveJSON(w, &oauth2.DeviceAuthResponse{
|
|
||||||
DeviceCode: deviceCode,
|
|
||||||
UserCode: userCode,
|
|
||||||
VerificationURI: verificationURI,
|
|
||||||
VerificationURIComplete: verificationURIComplete,
|
|
||||||
Expiry: time.Now().Add(5 * time.Minute),
|
|
||||||
Interval: 1,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
type RootURLKey struct{}
|
|
||||||
|
|
||||||
var rootURLKey RootURLKey
|
|
||||||
|
|
||||||
// WithRootURL sets the Root URL in a context.
|
|
||||||
func WithRootURL(ctx context.Context, rootURL *url.URL) context.Context {
|
|
||||||
return context.WithValue(ctx, rootURLKey, rootURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getRootURL(r *http.Request) *url.URL {
|
|
||||||
if u, ok := r.Context().Value(rootURLKey).(*url.URL); ok {
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
u := *r.URL
|
|
||||||
if r.Host != "" {
|
|
||||||
u.Host = r.Host
|
|
||||||
}
|
|
||||||
if u.Scheme == "" {
|
|
||||||
if r.TLS != nil {
|
|
||||||
u.Scheme = "https"
|
|
||||||
} else {
|
|
||||||
u.Scheme = "http"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
u.Path = ""
|
|
||||||
return &u
|
|
||||||
}
|
|
||||||
|
|
||||||
func serveHTML(w http.ResponseWriter, html string) {
|
|
||||||
w.Header().Set("Content-Type", "text/html")
|
|
||||||
w.Header().Set("Content-Length", strconv.Itoa(len(html)))
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
_, _ = io.WriteString(w, html)
|
|
||||||
}
|
|
||||||
|
|
||||||
func serveJSON(w http.ResponseWriter, obj interface{}) {
|
|
||||||
bs, err := json.Marshal(obj)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
_, _ = w.Write(bs)
|
|
||||||
}
|
|
||||||
|
|
||||||
type State struct {
|
|
||||||
Email string `json:"email"`
|
|
||||||
ClientID string `json:"client_id"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func DecodeState(rawCode string) (*State, error) {
|
|
||||||
var state State
|
|
||||||
bs, _ := base64.URLEncoding.DecodeString(rawCode)
|
|
||||||
err := json.Unmarshal(bs, &state)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &state, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (state State) Encode() string {
|
|
||||||
bs, _ := json.Marshal(state)
|
|
||||||
return base64.URLEncoding.EncodeToString(bs)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (state State) GetIDToken(r *http.Request, users map[string]*User) *IDToken {
|
|
||||||
token := &IDToken{
|
|
||||||
UserInfo: state.GetUserInfo(users),
|
|
||||||
|
|
||||||
Issuer: getRootURL(r).String(),
|
|
||||||
Audience: state.ClientID,
|
|
||||||
Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour * 24 * 365)),
|
|
||||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
|
||||||
}
|
|
||||||
return token
|
|
||||||
}
|
|
||||||
|
|
||||||
func (state State) GetUserInfo(users map[string]*User) *UserInfo {
|
|
||||||
userInfo := &UserInfo{
|
|
||||||
Subject: state.Email,
|
|
||||||
Email: state.Email,
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, u := range users {
|
|
||||||
if u.Email == state.Email {
|
|
||||||
userInfo.Subject = u.ID
|
|
||||||
userInfo.Name = u.FirstName + " " + u.LastName
|
|
||||||
userInfo.FamilyName = u.LastName
|
|
||||||
userInfo.GivenName = u.FirstName
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return userInfo
|
|
||||||
}
|
|
||||||
|
|
||||||
type UserInfo struct {
|
|
||||||
Subject string `json:"sub"`
|
|
||||||
Name string `json:"name"`
|
|
||||||
Email string `json:"email"`
|
|
||||||
FamilyName string `json:"family_name"`
|
|
||||||
GivenName string `json:"given_name"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type IDToken struct {
|
|
||||||
*UserInfo
|
|
||||||
|
|
||||||
Issuer string `json:"iss"`
|
|
||||||
Audience string `json:"aud"`
|
|
||||||
Expiry *jwt.NumericDate `json:"exp"`
|
|
||||||
IssuedAt *jwt.NumericDate `json:"iat"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (token *IDToken) Encode(signingKey jose.SigningKey) string {
|
|
||||||
sig, err := jose.NewSigner(signingKey, (&jose.SignerOptions{}).WithType("JWT"))
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
str, err := jwt.Signed(sig).Claims(token).CompactSerialize()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
return str
|
|
||||||
}
|
|
||||||
|
|
||||||
type User struct {
|
|
||||||
ID string
|
|
||||||
Email string
|
|
||||||
FirstName string
|
|
||||||
LastName string
|
|
||||||
}
|
|
||||||
|
|
|
@ -181,6 +181,7 @@ type HTTPUpstream interface {
|
||||||
|
|
||||||
Handle(path string, f func(http.ResponseWriter, *http.Request)) *mux.Route
|
Handle(path string, f func(http.ResponseWriter, *http.Request)) *mux.Route
|
||||||
HandleWS(path string, upgrader websocket.Upgrader, f func(conn *websocket.Conn) error) *mux.Route
|
HandleWS(path string, upgrader websocket.Upgrader, f func(conn *websocket.Conn) error) *mux.Route
|
||||||
|
Router() *mux.Router
|
||||||
|
|
||||||
Get(r testenv.Route, opts ...RequestOption) (*http.Response, error)
|
Get(r testenv.Route, opts ...RequestOption) (*http.Response, error)
|
||||||
Post(r testenv.Route, opts ...RequestOption) (*http.Response, error)
|
Post(r testenv.Route, opts ...RequestOption) (*http.Response, error)
|
||||||
|
@ -244,6 +245,10 @@ func (h *httpUpstream) Handle(path string, f func(http.ResponseWriter, *http.Req
|
||||||
return h.router.HandleFunc(path, f)
|
return h.router.HandleFunc(path, f)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *httpUpstream) Router() *mux.Router {
|
||||||
|
return h.router
|
||||||
|
}
|
||||||
|
|
||||||
// Router implements HTTPUpstream.
|
// Router implements HTTPUpstream.
|
||||||
func (h *httpUpstream) HandleWS(path string, upgrader websocket.Upgrader, f func(*websocket.Conn) error) *mux.Route {
|
func (h *httpUpstream) HandleWS(path string, upgrader websocket.Upgrader, f func(*websocket.Conn) error) *mux.Route {
|
||||||
return h.router.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
|
return h.router.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
383
internal/testutil/mockidp/mockidp.go
Normal file
383
internal/testutil/mockidp/mockidp.go
Normal file
|
@ -0,0 +1,383 @@
|
||||||
|
package mockidp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-jose/go-jose/v3"
|
||||||
|
"github.com/go-jose/go-jose/v3/jwt"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/encoding"
|
||||||
|
)
|
||||||
|
|
||||||
|
type IDP struct {
|
||||||
|
publicJWK jose.JSONWebKey
|
||||||
|
signingKey jose.SigningKey
|
||||||
|
|
||||||
|
stateEncoder encoding.MarshalUnmarshaler
|
||||||
|
userLookup map[string]*User
|
||||||
|
|
||||||
|
enableDeviceAuth bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
Users []*User
|
||||||
|
EnableDeviceAuth bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(cfg Config) *IDP {
|
||||||
|
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
publicKey := &privateKey.PublicKey
|
||||||
|
|
||||||
|
signingKey := jose.SigningKey{
|
||||||
|
Algorithm: jose.ES256,
|
||||||
|
Key: privateKey,
|
||||||
|
}
|
||||||
|
publicJWK := jose.JSONWebKey{
|
||||||
|
Key: publicKey,
|
||||||
|
Algorithm: string(jose.ES256),
|
||||||
|
Use: "sig",
|
||||||
|
}
|
||||||
|
thumbprint, err := publicJWK.Thumbprint(crypto.SHA256)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
publicJWK.KeyID = hex.EncodeToString(thumbprint)
|
||||||
|
|
||||||
|
userLookup := map[string]*User{}
|
||||||
|
for _, user := range cfg.Users {
|
||||||
|
user.ID = uuid.NewString()
|
||||||
|
userLookup[user.ID] = user
|
||||||
|
}
|
||||||
|
return &IDP{
|
||||||
|
publicJWK: publicJWK,
|
||||||
|
signingKey: signingKey,
|
||||||
|
userLookup: userLookup,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (idp *IDP) Start(t *testing.T) string {
|
||||||
|
r := mux.NewRouter()
|
||||||
|
idp.Register(r)
|
||||||
|
server := httptest.NewServer(r)
|
||||||
|
t.Cleanup(server.Close)
|
||||||
|
return server.URL
|
||||||
|
}
|
||||||
|
|
||||||
|
func (idp *IDP) Register(router *mux.Router) {
|
||||||
|
router.HandleFunc("/.well-known/jwks.json", func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
_ = json.NewEncoder(w).Encode(&jose.JSONWebKeySet{
|
||||||
|
Keys: []jose.JSONWebKey{idp.publicJWK},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
router.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
rootURL := getRootURL(r)
|
||||||
|
config := map[string]interface{}{
|
||||||
|
"issuer": rootURL.String(),
|
||||||
|
"authorization_endpoint": rootURL.ResolveReference(&url.URL{Path: "/oidc/auth"}).String(),
|
||||||
|
"token_endpoint": rootURL.ResolveReference(&url.URL{Path: "/oidc/token"}).String(),
|
||||||
|
"jwks_uri": rootURL.ResolveReference(&url.URL{Path: "/.well-known/jwks.json"}).String(),
|
||||||
|
"userinfo_endpoint": rootURL.ResolveReference(&url.URL{Path: "/oidc/userinfo"}).String(),
|
||||||
|
"id_token_signing_alg_values_supported": []string{
|
||||||
|
"ES256",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if idp.enableDeviceAuth {
|
||||||
|
config["device_authorization_endpoint"] = rootURL.ResolveReference(&url.URL{Path: "/oidc/device/code"}).String()
|
||||||
|
}
|
||||||
|
_ = json.NewEncoder(w).Encode(config)
|
||||||
|
})
|
||||||
|
router.HandleFunc("/oidc/auth", idp.handleAuth)
|
||||||
|
if idp.enableDeviceAuth {
|
||||||
|
router.HandleFunc("/oidc/device/code", idp.handleDeviceCode)
|
||||||
|
}
|
||||||
|
router.HandleFunc("/oidc/token", idp.handleToken)
|
||||||
|
router.HandleFunc("/oidc/userinfo", idp.handleUserInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleAuth handles the auth flow for OIDC.
|
||||||
|
func (idp *IDP) handleAuth(w http.ResponseWriter, r *http.Request) {
|
||||||
|
rawRedirectURI := r.FormValue("redirect_uri")
|
||||||
|
if rawRedirectURI == "" {
|
||||||
|
http.Error(w, "missing redirect_uri", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
redirectURI, err := url.Parse(rawRedirectURI)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rawClientID := r.FormValue("client_id")
|
||||||
|
if rawClientID == "" {
|
||||||
|
http.Error(w, "missing client_id", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rawEmail := r.FormValue("email")
|
||||||
|
if rawEmail != "" {
|
||||||
|
http.Redirect(w, r, redirectURI.ResolveReference(&url.URL{
|
||||||
|
RawQuery: (url.Values{
|
||||||
|
"state": {r.FormValue("state")},
|
||||||
|
"code": {state{
|
||||||
|
Email: rawEmail,
|
||||||
|
ClientID: rawClientID,
|
||||||
|
}.Encode()},
|
||||||
|
}).Encode(),
|
||||||
|
}).String(), http.StatusFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
serveHTML(w, `<!doctype html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<title>Login</title>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<form method="POST" style="max-width: 200px">
|
||||||
|
<fieldset>
|
||||||
|
<legend>Login</legend>
|
||||||
|
|
||||||
|
<table>
|
||||||
|
<tbody>
|
||||||
|
<tr>
|
||||||
|
<th><label for="email">Email</label></th>
|
||||||
|
<td>
|
||||||
|
<input type="email" name="email" placeholder="email" />
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td colspan="2">
|
||||||
|
<input type="submit" />
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
|
||||||
|
</fieldset>
|
||||||
|
</form>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleToken handles the token flow for OIDC.
|
||||||
|
func (idp *IDP) handleToken(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if idp.enableDeviceAuth && r.FormValue("device_code") != "" {
|
||||||
|
idp.serveToken(w, r, &state{
|
||||||
|
ClientID: r.FormValue("client_id"),
|
||||||
|
Email: "fake.user@example.com",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rawCode := r.FormValue("code")
|
||||||
|
state, err := decodeState(rawCode)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
idp.serveToken(w, r, state)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (idp *IDP) serveToken(w http.ResponseWriter, r *http.Request, state *state) {
|
||||||
|
serveJSON(w, map[string]interface{}{
|
||||||
|
"access_token": state.Encode(),
|
||||||
|
"refresh_token": state.Encode(),
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"id_token": state.GetIDToken(r, idp.userLookup).Encode(idp.signingKey),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleUserInfo handles retrieving the user info.
|
||||||
|
func (idp *IDP) handleUserInfo(w http.ResponseWriter, r *http.Request) {
|
||||||
|
authz := r.Header.Get("Authorization")
|
||||||
|
if authz == "" {
|
||||||
|
http.Error(w, "missing authorization header", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(authz, "Bearer ") {
|
||||||
|
authz = authz[len("Bearer "):]
|
||||||
|
} else if strings.HasPrefix(authz, "token ") {
|
||||||
|
authz = authz[len("token "):]
|
||||||
|
} else {
|
||||||
|
http.Error(w, "missing bearer token", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
state, err := decodeState(authz)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
serveJSON(w, state.GetUserInfo(idp.userLookup))
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleDeviceCode initiates a device auth code flow.
|
||||||
|
//
|
||||||
|
// This is the bare minimum to simulate the device auth code flow. There is no client_id
|
||||||
|
// verification or any actual login.
|
||||||
|
func (idp *IDP) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
|
||||||
|
deviceCode := "GmRhmhcxhwAzkoEqiMEg_DnyEysNkuNhszIySk9eS"
|
||||||
|
userCode := "ABCD-EFGH"
|
||||||
|
|
||||||
|
rootURL := getRootURL(r)
|
||||||
|
u := rootURL.ResolveReference(&url.URL{Path: "/oidc/device"}) // note: not actually implemented
|
||||||
|
verificationURI := u.String()
|
||||||
|
u.RawQuery = "user_code=" + userCode
|
||||||
|
verificationURIComplete := u.String()
|
||||||
|
|
||||||
|
serveJSON(w, &oauth2.DeviceAuthResponse{
|
||||||
|
DeviceCode: deviceCode,
|
||||||
|
UserCode: userCode,
|
||||||
|
VerificationURI: verificationURI,
|
||||||
|
VerificationURIComplete: verificationURIComplete,
|
||||||
|
Expiry: time.Now().Add(5 * time.Minute),
|
||||||
|
Interval: 1,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func getRootURL(r *http.Request) *url.URL {
|
||||||
|
u := *r.URL
|
||||||
|
if r.Host != "" {
|
||||||
|
u.Host = r.Host
|
||||||
|
}
|
||||||
|
if u.Scheme == "" {
|
||||||
|
if r.TLS != nil {
|
||||||
|
u.Scheme = "https"
|
||||||
|
} else {
|
||||||
|
u.Scheme = "http"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
u.Path = ""
|
||||||
|
return &u
|
||||||
|
}
|
||||||
|
|
||||||
|
func serveHTML(w http.ResponseWriter, html string) {
|
||||||
|
w.Header().Set("Content-Type", "text/html")
|
||||||
|
w.Header().Set("Content-Length", strconv.Itoa(len(html)))
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = io.WriteString(w, html)
|
||||||
|
}
|
||||||
|
|
||||||
|
func serveJSON(w http.ResponseWriter, obj interface{}) {
|
||||||
|
bs, err := json.Marshal(obj)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write(bs)
|
||||||
|
}
|
||||||
|
|
||||||
|
type state struct {
|
||||||
|
Email string `json:"email"`
|
||||||
|
ClientID string `json:"client_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeState(rawCode string) (*state, error) {
|
||||||
|
var state state
|
||||||
|
bs, _ := base64.URLEncoding.DecodeString(rawCode)
|
||||||
|
err := json.Unmarshal(bs, &state)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &state, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state state) Encode() string {
|
||||||
|
bs, _ := json.Marshal(state)
|
||||||
|
return base64.URLEncoding.EncodeToString(bs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state state) GetIDToken(r *http.Request, users map[string]*User) *idToken {
|
||||||
|
token := &idToken{
|
||||||
|
userInfo: state.GetUserInfo(users),
|
||||||
|
|
||||||
|
Issuer: getRootURL(r).String(),
|
||||||
|
Audience: state.ClientID,
|
||||||
|
Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour * 24 * 365)),
|
||||||
|
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||||
|
}
|
||||||
|
return token
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state state) GetUserInfo(users map[string]*User) *userInfo {
|
||||||
|
userInfo := &userInfo{
|
||||||
|
Subject: state.Email,
|
||||||
|
Email: state.Email,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, u := range users {
|
||||||
|
if u.Email == state.Email {
|
||||||
|
userInfo.Subject = u.ID
|
||||||
|
userInfo.Name = u.FirstName + " " + u.LastName
|
||||||
|
userInfo.FamilyName = u.LastName
|
||||||
|
userInfo.GivenName = u.FirstName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return userInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
type userInfo struct {
|
||||||
|
Subject string `json:"sub"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
FamilyName string `json:"family_name"`
|
||||||
|
GivenName string `json:"given_name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type idToken struct {
|
||||||
|
*userInfo
|
||||||
|
|
||||||
|
Issuer string `json:"iss"`
|
||||||
|
Audience string `json:"aud"`
|
||||||
|
Expiry *jwt.NumericDate `json:"exp"`
|
||||||
|
IssuedAt *jwt.NumericDate `json:"iat"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (token *idToken) Encode(signingKey jose.SigningKey) string {
|
||||||
|
sig, err := jose.NewSigner(signingKey, (&jose.SignerOptions{}).WithType("JWT"))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
str, err := jwt.Signed(sig).Claims(token).CompactSerialize()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return str
|
||||||
|
}
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
ID string
|
||||||
|
Email string
|
||||||
|
FirstName string
|
||||||
|
LastName string
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue