mirror of
https://github.com/pomerium/pomerium.git
synced 2025-07-31 23:41:09 +02:00
refactor testenv mock IdP to also work standalone (#5678)
Refactor the testenv mock IdP implementation to split off the core functionality from the testenv environment setup. Add a Start() method to run the mock IdP as an httptest server, tied to a test lifecycle. This allows the mock IdP to be used also in tests that do not start a full Pomerium instance.
This commit is contained in:
parent
717a7bdf5a
commit
6a65c52a6c
3 changed files with 402 additions and 372 deletions
|
@ -2,48 +2,25 @@ package scenarios
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"github.com/go-jose/go-jose/v3/jwt"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/encoding/jws"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/testenv"
|
||||
"github.com/pomerium/pomerium/internal/testenv/upstreams"
|
||||
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||
"github.com/pomerium/pomerium/internal/testutil/mockidp"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/identity"
|
||||
)
|
||||
|
||||
type IDP struct {
|
||||
IDPOptions
|
||||
id values.Value[string]
|
||||
url values.Value[string]
|
||||
publicJWK jose.JSONWebKey
|
||||
signingKey jose.SigningKey
|
||||
|
||||
stateEncoder encoding.MarshalUnmarshaler
|
||||
userLookup map[string]*User
|
||||
id values.Value[string]
|
||||
url values.Value[string]
|
||||
mockIDP *mockidp.IDP
|
||||
}
|
||||
|
||||
type IDPOptions struct {
|
||||
|
@ -92,9 +69,9 @@ func (idp *IDP) Attach(ctx context.Context) {
|
|||
})
|
||||
}
|
||||
|
||||
router := upstreams.HTTP(tlsConfig, upstreams.WithDisplayName("IDP"))
|
||||
up := upstreams.HTTP(tlsConfig, upstreams.WithDisplayName("IDP"))
|
||||
|
||||
idp.url = values.Bind2(idpURL, router.Addr(), func(urlStr string, addr string) string {
|
||||
idp.url = values.Bind2(idpURL, up.Addr(), func(urlStr string, addr string) string {
|
||||
u, _ := url.Parse(urlStr)
|
||||
host, _, _ := net.SplitHostPort(u.Host)
|
||||
_, port, err := net.SplitHostPort(addr)
|
||||
|
@ -105,9 +82,6 @@ func (idp *IDP) Attach(ctx context.Context) {
|
|||
Host: fmt.Sprintf("%s:%s", host, port),
|
||||
}).String()
|
||||
})
|
||||
var err error
|
||||
idp.stateEncoder, err = jws.NewHS256Signer(env.SharedSecret())
|
||||
env.Require().NoError(err)
|
||||
|
||||
idp.id = values.Bind2(idp.url, env.AuthenticateURL(), func(idpUrl, authUrl string) string {
|
||||
provider := identity.Provider{
|
||||
|
@ -121,37 +95,9 @@ func (idp *IDP) Attach(ctx context.Context) {
|
|||
return provider.Hash()
|
||||
})
|
||||
|
||||
router.Handle("/.well-known/jwks.json", func(w http.ResponseWriter, _ *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(&jose.JSONWebKeySet{
|
||||
Keys: []jose.JSONWebKey{idp.publicJWK},
|
||||
})
|
||||
})
|
||||
router.Handle("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
|
||||
log.Ctx(ctx).Debug().Str("method", r.Method).Str("uri", r.RequestURI).Send()
|
||||
rootURL, _ := url.Parse(idp.url.Value())
|
||||
config := map[string]interface{}{
|
||||
"issuer": rootURL.String(),
|
||||
"authorization_endpoint": rootURL.ResolveReference(&url.URL{Path: "/oidc/auth"}).String(),
|
||||
"token_endpoint": rootURL.ResolveReference(&url.URL{Path: "/oidc/token"}).String(),
|
||||
"jwks_uri": rootURL.ResolveReference(&url.URL{Path: "/.well-known/jwks.json"}).String(),
|
||||
"userinfo_endpoint": rootURL.ResolveReference(&url.URL{Path: "/oidc/userinfo"}).String(),
|
||||
"id_token_signing_alg_values_supported": []string{
|
||||
"ES256",
|
||||
},
|
||||
}
|
||||
if idp.enableDeviceAuth {
|
||||
config["device_authorization_endpoint"] = rootURL.ResolveReference(&url.URL{Path: "/oidc/device/code"}).String()
|
||||
}
|
||||
serveJSON(w, config)
|
||||
})
|
||||
router.Handle("/oidc/auth", idp.HandleAuth)
|
||||
router.Handle("/oidc/token", idp.HandleToken)
|
||||
router.Handle("/oidc/userinfo", idp.HandleUserInfo)
|
||||
if idp.enableDeviceAuth {
|
||||
router.Handle("/oidc/device/code", idp.HandleDeviceCode)
|
||||
}
|
||||
idp.mockIDP.Register(up.Router())
|
||||
|
||||
env.AddUpstream(router)
|
||||
env.AddUpstream(up)
|
||||
}
|
||||
|
||||
// Modify implements testenv.Modifier.
|
||||
|
@ -165,323 +111,19 @@ func (idp *IDP) Modify(cfg *config.Config) {
|
|||
|
||||
var _ testenv.Modifier = (*IDP)(nil)
|
||||
|
||||
func NewIDP(users []*User, opts ...IDPOption) *IDP {
|
||||
func NewIDP(users []*mockidp.User, opts ...IDPOption) *IDP {
|
||||
options := IDPOptions{
|
||||
enableTLS: true,
|
||||
}
|
||||
options.apply(opts...)
|
||||
|
||||
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
publicKey := &privateKey.PublicKey
|
||||
|
||||
signingKey := jose.SigningKey{
|
||||
Algorithm: jose.ES256,
|
||||
Key: privateKey,
|
||||
}
|
||||
publicJWK := jose.JSONWebKey{
|
||||
Key: publicKey,
|
||||
Algorithm: string(jose.ES256),
|
||||
Use: "sig",
|
||||
}
|
||||
thumbprint, err := publicJWK.Thumbprint(crypto.SHA256)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
publicJWK.KeyID = hex.EncodeToString(thumbprint)
|
||||
|
||||
userLookup := map[string]*User{}
|
||||
for _, user := range users {
|
||||
user.ID = uuid.NewString()
|
||||
userLookup[user.ID] = user
|
||||
}
|
||||
return &IDP{
|
||||
IDPOptions: options,
|
||||
publicJWK: publicJWK,
|
||||
signingKey: signingKey,
|
||||
userLookup: userLookup,
|
||||
mockIDP: mockidp.New(mockidp.Config{
|
||||
Users: users,
|
||||
EnableDeviceAuth: options.enableDeviceAuth,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
// HandleAuth handles the auth flow for OIDC.
|
||||
func (idp *IDP) HandleAuth(w http.ResponseWriter, r *http.Request) {
|
||||
rawRedirectURI := r.FormValue("redirect_uri")
|
||||
if rawRedirectURI == "" {
|
||||
http.Error(w, "missing redirect_uri", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
redirectURI, err := url.Parse(rawRedirectURI)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
rawClientID := r.FormValue("client_id")
|
||||
if rawClientID == "" {
|
||||
http.Error(w, "missing client_id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
rawEmail := r.FormValue("email")
|
||||
if rawEmail != "" {
|
||||
http.Redirect(w, r, redirectURI.ResolveReference(&url.URL{
|
||||
RawQuery: (url.Values{
|
||||
"state": {r.FormValue("state")},
|
||||
"code": {State{
|
||||
Email: rawEmail,
|
||||
ClientID: rawClientID,
|
||||
}.Encode()},
|
||||
}).Encode(),
|
||||
}).String(), http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
serveHTML(w, `<!doctype html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Login</title>
|
||||
</head>
|
||||
<body>
|
||||
<form method="POST" style="max-width: 200px">
|
||||
<fieldset>
|
||||
<legend>Login</legend>
|
||||
|
||||
<table>
|
||||
<tbody>
|
||||
<tr>
|
||||
<th><label for="email">Email</label></th>
|
||||
<td>
|
||||
<input type="email" name="email" placeholder="email" />
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td colspan="2">
|
||||
<input type="submit" />
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
</fieldset>
|
||||
</form>
|
||||
</body>
|
||||
</html>
|
||||
`)
|
||||
}
|
||||
|
||||
// HandleToken handles the token flow for OIDC.
|
||||
func (idp *IDP) HandleToken(w http.ResponseWriter, r *http.Request) {
|
||||
if idp.enableDeviceAuth && r.FormValue("device_code") != "" {
|
||||
idp.serveToken(w, r, &State{
|
||||
ClientID: r.FormValue("client_id"),
|
||||
Email: "fake.user@example.com",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
rawCode := r.FormValue("code")
|
||||
state, err := DecodeState(rawCode)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
idp.serveToken(w, r, state)
|
||||
}
|
||||
|
||||
func (idp *IDP) serveToken(w http.ResponseWriter, r *http.Request, state *State) {
|
||||
serveJSON(w, map[string]interface{}{
|
||||
"access_token": state.Encode(),
|
||||
"refresh_token": state.Encode(),
|
||||
"token_type": "Bearer",
|
||||
"id_token": state.GetIDToken(r, idp.userLookup).Encode(idp.signingKey),
|
||||
})
|
||||
}
|
||||
|
||||
// HandleUserInfo handles retrieving the user info.
|
||||
func (idp *IDP) HandleUserInfo(w http.ResponseWriter, r *http.Request) {
|
||||
authz := r.Header.Get("Authorization")
|
||||
if authz == "" {
|
||||
http.Error(w, "missing authorization header", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if strings.HasPrefix(authz, "Bearer ") {
|
||||
authz = authz[len("Bearer "):]
|
||||
} else if strings.HasPrefix(authz, "token ") {
|
||||
authz = authz[len("token "):]
|
||||
} else {
|
||||
http.Error(w, "missing bearer token", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
state, err := DecodeState(authz)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
serveJSON(w, state.GetUserInfo(idp.userLookup))
|
||||
}
|
||||
|
||||
// HandleDeviceCode initiates a device auth code flow.
|
||||
//
|
||||
// This is the bare minimum to simulate the device auth code flow. There is no client_id
|
||||
// verification or any actual login.
|
||||
func (idp *IDP) HandleDeviceCode(w http.ResponseWriter, _ *http.Request) {
|
||||
deviceCode := "GmRhmhcxhwAzkoEqiMEg_DnyEysNkuNhszIySk9eS"
|
||||
userCode := "ABCD-EFGH"
|
||||
|
||||
rootURL, _ := url.Parse(idp.url.Value())
|
||||
u := rootURL.ResolveReference(&url.URL{Path: "/oidc/device"}) // note: not actually implemented
|
||||
verificationURI := u.String()
|
||||
u.RawQuery = "user_code=" + userCode
|
||||
verificationURIComplete := u.String()
|
||||
|
||||
serveJSON(w, &oauth2.DeviceAuthResponse{
|
||||
DeviceCode: deviceCode,
|
||||
UserCode: userCode,
|
||||
VerificationURI: verificationURI,
|
||||
VerificationURIComplete: verificationURIComplete,
|
||||
Expiry: time.Now().Add(5 * time.Minute),
|
||||
Interval: 1,
|
||||
})
|
||||
}
|
||||
|
||||
type RootURLKey struct{}
|
||||
|
||||
var rootURLKey RootURLKey
|
||||
|
||||
// WithRootURL sets the Root URL in a context.
|
||||
func WithRootURL(ctx context.Context, rootURL *url.URL) context.Context {
|
||||
return context.WithValue(ctx, rootURLKey, rootURL)
|
||||
}
|
||||
|
||||
func getRootURL(r *http.Request) *url.URL {
|
||||
if u, ok := r.Context().Value(rootURLKey).(*url.URL); ok {
|
||||
return u
|
||||
}
|
||||
|
||||
u := *r.URL
|
||||
if r.Host != "" {
|
||||
u.Host = r.Host
|
||||
}
|
||||
if u.Scheme == "" {
|
||||
if r.TLS != nil {
|
||||
u.Scheme = "https"
|
||||
} else {
|
||||
u.Scheme = "http"
|
||||
}
|
||||
}
|
||||
u.Path = ""
|
||||
return &u
|
||||
}
|
||||
|
||||
func serveHTML(w http.ResponseWriter, html string) {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.Header().Set("Content-Length", strconv.Itoa(len(html)))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = io.WriteString(w, html)
|
||||
}
|
||||
|
||||
func serveJSON(w http.ResponseWriter, obj interface{}) {
|
||||
bs, err := json.Marshal(obj)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(bs)
|
||||
}
|
||||
|
||||
type State struct {
|
||||
Email string `json:"email"`
|
||||
ClientID string `json:"client_id"`
|
||||
}
|
||||
|
||||
func DecodeState(rawCode string) (*State, error) {
|
||||
var state State
|
||||
bs, _ := base64.URLEncoding.DecodeString(rawCode)
|
||||
err := json.Unmarshal(bs, &state)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &state, nil
|
||||
}
|
||||
|
||||
func (state State) Encode() string {
|
||||
bs, _ := json.Marshal(state)
|
||||
return base64.URLEncoding.EncodeToString(bs)
|
||||
}
|
||||
|
||||
func (state State) GetIDToken(r *http.Request, users map[string]*User) *IDToken {
|
||||
token := &IDToken{
|
||||
UserInfo: state.GetUserInfo(users),
|
||||
|
||||
Issuer: getRootURL(r).String(),
|
||||
Audience: state.ClientID,
|
||||
Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour * 24 * 365)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
||||
func (state State) GetUserInfo(users map[string]*User) *UserInfo {
|
||||
userInfo := &UserInfo{
|
||||
Subject: state.Email,
|
||||
Email: state.Email,
|
||||
}
|
||||
|
||||
for _, u := range users {
|
||||
if u.Email == state.Email {
|
||||
userInfo.Subject = u.ID
|
||||
userInfo.Name = u.FirstName + " " + u.LastName
|
||||
userInfo.FamilyName = u.LastName
|
||||
userInfo.GivenName = u.FirstName
|
||||
}
|
||||
}
|
||||
|
||||
return userInfo
|
||||
}
|
||||
|
||||
type UserInfo struct {
|
||||
Subject string `json:"sub"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
FamilyName string `json:"family_name"`
|
||||
GivenName string `json:"given_name"`
|
||||
}
|
||||
|
||||
type IDToken struct {
|
||||
*UserInfo
|
||||
|
||||
Issuer string `json:"iss"`
|
||||
Audience string `json:"aud"`
|
||||
Expiry *jwt.NumericDate `json:"exp"`
|
||||
IssuedAt *jwt.NumericDate `json:"iat"`
|
||||
}
|
||||
|
||||
func (token *IDToken) Encode(signingKey jose.SigningKey) string {
|
||||
sig, err := jose.NewSigner(signingKey, (&jose.SignerOptions{}).WithType("JWT"))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
str, err := jwt.Signed(sig).Claims(token).CompactSerialize()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID string
|
||||
Email string
|
||||
FirstName string
|
||||
LastName string
|
||||
}
|
||||
type User = mockidp.User
|
||||
|
|
|
@ -181,6 +181,7 @@ type HTTPUpstream interface {
|
|||
|
||||
Handle(path string, f func(http.ResponseWriter, *http.Request)) *mux.Route
|
||||
HandleWS(path string, upgrader websocket.Upgrader, f func(conn *websocket.Conn) error) *mux.Route
|
||||
Router() *mux.Router
|
||||
|
||||
Get(r testenv.Route, opts ...RequestOption) (*http.Response, error)
|
||||
Post(r testenv.Route, opts ...RequestOption) (*http.Response, error)
|
||||
|
@ -244,6 +245,10 @@ func (h *httpUpstream) Handle(path string, f func(http.ResponseWriter, *http.Req
|
|||
return h.router.HandleFunc(path, f)
|
||||
}
|
||||
|
||||
func (h *httpUpstream) Router() *mux.Router {
|
||||
return h.router
|
||||
}
|
||||
|
||||
// Router implements HTTPUpstream.
|
||||
func (h *httpUpstream) HandleWS(path string, upgrader websocket.Upgrader, f func(*websocket.Conn) error) *mux.Route {
|
||||
return h.router.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
|
||||
|
|
383
internal/testutil/mockidp/mockidp.go
Normal file
383
internal/testutil/mockidp/mockidp.go
Normal file
|
@ -0,0 +1,383 @@
|
|||
package mockidp
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"github.com/go-jose/go-jose/v3/jwt"
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/mux"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
)
|
||||
|
||||
type IDP struct {
|
||||
publicJWK jose.JSONWebKey
|
||||
signingKey jose.SigningKey
|
||||
|
||||
stateEncoder encoding.MarshalUnmarshaler
|
||||
userLookup map[string]*User
|
||||
|
||||
enableDeviceAuth bool
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Users []*User
|
||||
EnableDeviceAuth bool
|
||||
}
|
||||
|
||||
func New(cfg Config) *IDP {
|
||||
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
publicKey := &privateKey.PublicKey
|
||||
|
||||
signingKey := jose.SigningKey{
|
||||
Algorithm: jose.ES256,
|
||||
Key: privateKey,
|
||||
}
|
||||
publicJWK := jose.JSONWebKey{
|
||||
Key: publicKey,
|
||||
Algorithm: string(jose.ES256),
|
||||
Use: "sig",
|
||||
}
|
||||
thumbprint, err := publicJWK.Thumbprint(crypto.SHA256)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
publicJWK.KeyID = hex.EncodeToString(thumbprint)
|
||||
|
||||
userLookup := map[string]*User{}
|
||||
for _, user := range cfg.Users {
|
||||
user.ID = uuid.NewString()
|
||||
userLookup[user.ID] = user
|
||||
}
|
||||
return &IDP{
|
||||
publicJWK: publicJWK,
|
||||
signingKey: signingKey,
|
||||
userLookup: userLookup,
|
||||
}
|
||||
}
|
||||
|
||||
func (idp *IDP) Start(t *testing.T) string {
|
||||
r := mux.NewRouter()
|
||||
idp.Register(r)
|
||||
server := httptest.NewServer(r)
|
||||
t.Cleanup(server.Close)
|
||||
return server.URL
|
||||
}
|
||||
|
||||
func (idp *IDP) Register(router *mux.Router) {
|
||||
router.HandleFunc("/.well-known/jwks.json", func(w http.ResponseWriter, _ *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(&jose.JSONWebKeySet{
|
||||
Keys: []jose.JSONWebKey{idp.publicJWK},
|
||||
})
|
||||
})
|
||||
router.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
|
||||
rootURL := getRootURL(r)
|
||||
config := map[string]interface{}{
|
||||
"issuer": rootURL.String(),
|
||||
"authorization_endpoint": rootURL.ResolveReference(&url.URL{Path: "/oidc/auth"}).String(),
|
||||
"token_endpoint": rootURL.ResolveReference(&url.URL{Path: "/oidc/token"}).String(),
|
||||
"jwks_uri": rootURL.ResolveReference(&url.URL{Path: "/.well-known/jwks.json"}).String(),
|
||||
"userinfo_endpoint": rootURL.ResolveReference(&url.URL{Path: "/oidc/userinfo"}).String(),
|
||||
"id_token_signing_alg_values_supported": []string{
|
||||
"ES256",
|
||||
},
|
||||
}
|
||||
if idp.enableDeviceAuth {
|
||||
config["device_authorization_endpoint"] = rootURL.ResolveReference(&url.URL{Path: "/oidc/device/code"}).String()
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(config)
|
||||
})
|
||||
router.HandleFunc("/oidc/auth", idp.handleAuth)
|
||||
if idp.enableDeviceAuth {
|
||||
router.HandleFunc("/oidc/device/code", idp.handleDeviceCode)
|
||||
}
|
||||
router.HandleFunc("/oidc/token", idp.handleToken)
|
||||
router.HandleFunc("/oidc/userinfo", idp.handleUserInfo)
|
||||
}
|
||||
|
||||
// handleAuth handles the auth flow for OIDC.
|
||||
func (idp *IDP) handleAuth(w http.ResponseWriter, r *http.Request) {
|
||||
rawRedirectURI := r.FormValue("redirect_uri")
|
||||
if rawRedirectURI == "" {
|
||||
http.Error(w, "missing redirect_uri", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
redirectURI, err := url.Parse(rawRedirectURI)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
rawClientID := r.FormValue("client_id")
|
||||
if rawClientID == "" {
|
||||
http.Error(w, "missing client_id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
rawEmail := r.FormValue("email")
|
||||
if rawEmail != "" {
|
||||
http.Redirect(w, r, redirectURI.ResolveReference(&url.URL{
|
||||
RawQuery: (url.Values{
|
||||
"state": {r.FormValue("state")},
|
||||
"code": {state{
|
||||
Email: rawEmail,
|
||||
ClientID: rawClientID,
|
||||
}.Encode()},
|
||||
}).Encode(),
|
||||
}).String(), http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
serveHTML(w, `<!doctype html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Login</title>
|
||||
</head>
|
||||
<body>
|
||||
<form method="POST" style="max-width: 200px">
|
||||
<fieldset>
|
||||
<legend>Login</legend>
|
||||
|
||||
<table>
|
||||
<tbody>
|
||||
<tr>
|
||||
<th><label for="email">Email</label></th>
|
||||
<td>
|
||||
<input type="email" name="email" placeholder="email" />
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td colspan="2">
|
||||
<input type="submit" />
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
</fieldset>
|
||||
</form>
|
||||
</body>
|
||||
</html>
|
||||
`)
|
||||
}
|
||||
|
||||
// handleToken handles the token flow for OIDC.
|
||||
func (idp *IDP) handleToken(w http.ResponseWriter, r *http.Request) {
|
||||
if idp.enableDeviceAuth && r.FormValue("device_code") != "" {
|
||||
idp.serveToken(w, r, &state{
|
||||
ClientID: r.FormValue("client_id"),
|
||||
Email: "fake.user@example.com",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
rawCode := r.FormValue("code")
|
||||
state, err := decodeState(rawCode)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
idp.serveToken(w, r, state)
|
||||
}
|
||||
|
||||
func (idp *IDP) serveToken(w http.ResponseWriter, r *http.Request, state *state) {
|
||||
serveJSON(w, map[string]interface{}{
|
||||
"access_token": state.Encode(),
|
||||
"refresh_token": state.Encode(),
|
||||
"token_type": "Bearer",
|
||||
"id_token": state.GetIDToken(r, idp.userLookup).Encode(idp.signingKey),
|
||||
})
|
||||
}
|
||||
|
||||
// handleUserInfo handles retrieving the user info.
|
||||
func (idp *IDP) handleUserInfo(w http.ResponseWriter, r *http.Request) {
|
||||
authz := r.Header.Get("Authorization")
|
||||
if authz == "" {
|
||||
http.Error(w, "missing authorization header", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if strings.HasPrefix(authz, "Bearer ") {
|
||||
authz = authz[len("Bearer "):]
|
||||
} else if strings.HasPrefix(authz, "token ") {
|
||||
authz = authz[len("token "):]
|
||||
} else {
|
||||
http.Error(w, "missing bearer token", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
state, err := decodeState(authz)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
serveJSON(w, state.GetUserInfo(idp.userLookup))
|
||||
}
|
||||
|
||||
// handleDeviceCode initiates a device auth code flow.
|
||||
//
|
||||
// This is the bare minimum to simulate the device auth code flow. There is no client_id
|
||||
// verification or any actual login.
|
||||
func (idp *IDP) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
|
||||
deviceCode := "GmRhmhcxhwAzkoEqiMEg_DnyEysNkuNhszIySk9eS"
|
||||
userCode := "ABCD-EFGH"
|
||||
|
||||
rootURL := getRootURL(r)
|
||||
u := rootURL.ResolveReference(&url.URL{Path: "/oidc/device"}) // note: not actually implemented
|
||||
verificationURI := u.String()
|
||||
u.RawQuery = "user_code=" + userCode
|
||||
verificationURIComplete := u.String()
|
||||
|
||||
serveJSON(w, &oauth2.DeviceAuthResponse{
|
||||
DeviceCode: deviceCode,
|
||||
UserCode: userCode,
|
||||
VerificationURI: verificationURI,
|
||||
VerificationURIComplete: verificationURIComplete,
|
||||
Expiry: time.Now().Add(5 * time.Minute),
|
||||
Interval: 1,
|
||||
})
|
||||
}
|
||||
|
||||
func getRootURL(r *http.Request) *url.URL {
|
||||
u := *r.URL
|
||||
if r.Host != "" {
|
||||
u.Host = r.Host
|
||||
}
|
||||
if u.Scheme == "" {
|
||||
if r.TLS != nil {
|
||||
u.Scheme = "https"
|
||||
} else {
|
||||
u.Scheme = "http"
|
||||
}
|
||||
}
|
||||
u.Path = ""
|
||||
return &u
|
||||
}
|
||||
|
||||
func serveHTML(w http.ResponseWriter, html string) {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.Header().Set("Content-Length", strconv.Itoa(len(html)))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = io.WriteString(w, html)
|
||||
}
|
||||
|
||||
func serveJSON(w http.ResponseWriter, obj interface{}) {
|
||||
bs, err := json.Marshal(obj)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(bs)
|
||||
}
|
||||
|
||||
type state struct {
|
||||
Email string `json:"email"`
|
||||
ClientID string `json:"client_id"`
|
||||
}
|
||||
|
||||
func decodeState(rawCode string) (*state, error) {
|
||||
var state state
|
||||
bs, _ := base64.URLEncoding.DecodeString(rawCode)
|
||||
err := json.Unmarshal(bs, &state)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &state, nil
|
||||
}
|
||||
|
||||
func (state state) Encode() string {
|
||||
bs, _ := json.Marshal(state)
|
||||
return base64.URLEncoding.EncodeToString(bs)
|
||||
}
|
||||
|
||||
func (state state) GetIDToken(r *http.Request, users map[string]*User) *idToken {
|
||||
token := &idToken{
|
||||
userInfo: state.GetUserInfo(users),
|
||||
|
||||
Issuer: getRootURL(r).String(),
|
||||
Audience: state.ClientID,
|
||||
Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour * 24 * 365)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
||||
func (state state) GetUserInfo(users map[string]*User) *userInfo {
|
||||
userInfo := &userInfo{
|
||||
Subject: state.Email,
|
||||
Email: state.Email,
|
||||
}
|
||||
|
||||
for _, u := range users {
|
||||
if u.Email == state.Email {
|
||||
userInfo.Subject = u.ID
|
||||
userInfo.Name = u.FirstName + " " + u.LastName
|
||||
userInfo.FamilyName = u.LastName
|
||||
userInfo.GivenName = u.FirstName
|
||||
}
|
||||
}
|
||||
|
||||
return userInfo
|
||||
}
|
||||
|
||||
type userInfo struct {
|
||||
Subject string `json:"sub"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
FamilyName string `json:"family_name"`
|
||||
GivenName string `json:"given_name"`
|
||||
}
|
||||
|
||||
type idToken struct {
|
||||
*userInfo
|
||||
|
||||
Issuer string `json:"iss"`
|
||||
Audience string `json:"aud"`
|
||||
Expiry *jwt.NumericDate `json:"exp"`
|
||||
IssuedAt *jwt.NumericDate `json:"iat"`
|
||||
}
|
||||
|
||||
func (token *idToken) Encode(signingKey jose.SigningKey) string {
|
||||
sig, err := jose.NewSigner(signingKey, (&jose.SignerOptions{}).WithType("JWT"))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
str, err := jwt.Signed(sig).Claims(token).CompactSerialize()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID string
|
||||
Email string
|
||||
FirstName string
|
||||
LastName string
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue