mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-30 17:37:25 +02:00
proxy: add userinfo and webauthn endpoints (#3755)
* proxy: add userinfo and webauthn endpoints * use TLD for RP id * use EffectiveTLDPlusOne * upgrade webauthn * fix test * Update internal/handlers/jwks.go Co-authored-by: bobby <1544881+desimone@users.noreply.github.com> Co-authored-by: bobby <1544881+desimone@users.noreply.github.com>
This commit is contained in:
parent
81053ac8ef
commit
c1a522cd82
33 changed files with 498 additions and 216 deletions
|
@ -7,10 +7,11 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/CAFxX/httpcompression"
|
||||
"github.com/gorilla/handlers"
|
||||
gorillahandlers "github.com/gorilla/handlers"
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/handlers"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/telemetry"
|
||||
|
@ -37,7 +38,7 @@ func (srv *Server) addHTTPMiddleware(root *mux.Router, cfg *config.Config) {
|
|||
Str("path", r.URL.String()).
|
||||
Msg("http-request")
|
||||
}))
|
||||
root.Use(handlers.RecoveryHandler())
|
||||
root.Use(gorillahandlers.RecoveryHandler())
|
||||
root.Use(log.HeadersHandler(httputil.HeadersXForwarded))
|
||||
root.Use(log.RemoteAddrHandler("ip"))
|
||||
root.Use(log.UserAgentHandler("user_agent"))
|
||||
|
@ -59,10 +60,10 @@ func (srv *Server) mountCommonEndpoints(root *mux.Router, cfg *config.Config) er
|
|||
return fmt.Errorf("invalid signing key: %w", err)
|
||||
}
|
||||
|
||||
root.HandleFunc("/healthz", httputil.HealthCheck)
|
||||
root.HandleFunc("/ping", httputil.HealthCheck)
|
||||
root.Handle("/.well-known/pomerium", httputil.WellKnownPomeriumHandler(authenticateURL))
|
||||
root.Handle("/.well-known/pomerium/", httputil.WellKnownPomeriumHandler(authenticateURL))
|
||||
root.Path("/.well-known/pomerium/jwks.json").Methods(http.MethodGet).Handler(httputil.JWKSHandler(rawSigningKey))
|
||||
root.HandleFunc("/healthz", handlers.HealthCheck)
|
||||
root.HandleFunc("/ping", handlers.HealthCheck)
|
||||
root.Handle("/.well-known/pomerium", handlers.WellKnownPomerium(authenticateURL))
|
||||
root.Handle("/.well-known/pomerium/", handlers.WellKnownPomerium(authenticateURL))
|
||||
root.Path("/.well-known/pomerium/jwks.json").Methods(http.MethodGet).Handler(handlers.JWKSHandler(rawSigningKey))
|
||||
return nil
|
||||
}
|
||||
|
|
15
internal/handlers/device-enrolled.go
Normal file
15
internal/handlers/device-enrolled.go
Normal file
|
@ -0,0 +1,15 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/ui"
|
||||
)
|
||||
|
||||
// DeviceEnrolled displays an HTML page informing the user that they've successfully enrolled a device.
|
||||
func DeviceEnrolled(data UserInfoData) http.Handler {
|
||||
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
return ui.ServePage(w, r, "DeviceEnrolled", data.ToJSON())
|
||||
})
|
||||
}
|
2
internal/handlers/handlers.go
Normal file
2
internal/handlers/handlers.go
Normal file
|
@ -0,0 +1,2 @@
|
|||
// Package handlers contains HTTP handlers used by Pomerium.
|
||||
package handlers
|
20
internal/handlers/health_check.go
Normal file
20
internal/handlers/health_check.go
Normal file
|
@ -0,0 +1,20 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// HealthCheck is a simple healthcheck handler that responds to GET and HEAD
|
||||
// http requests.
|
||||
func HealthCheck(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet && r.Method != http.MethodHead {
|
||||
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if r.Method == http.MethodGet {
|
||||
fmt.Fprintln(w, http.StatusText(http.StatusOK))
|
||||
}
|
||||
}
|
36
internal/handlers/health_check_test.go
Normal file
36
internal/handlers/health_check_test.go
Normal file
|
@ -0,0 +1,36 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHealthCheck(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
|
||||
wantStatus int
|
||||
}{
|
||||
{"good - Get", http.MethodGet, http.StatusOK},
|
||||
{"good - Head", http.MethodHead, http.StatusOK},
|
||||
{"bad - Options", http.MethodOptions, http.StatusMethodNotAllowed},
|
||||
{"bad - Put", http.MethodPut, http.StatusMethodNotAllowed},
|
||||
{"bad - Post", http.MethodPost, http.StatusMethodNotAllowed},
|
||||
{"bad - route miss", http.MethodGet, http.StatusOK},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := httptest.NewRequest(tt.method, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
HealthCheck(w, r)
|
||||
if w.Code != tt.wantStatus {
|
||||
t.Errorf("code differs. got %d want %d body: %s", w.Code, tt.wantStatus, w.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
33
internal/handlers/jwks.go
Normal file
33
internal/handlers/jwks.go
Normal file
|
@ -0,0 +1,33 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"github.com/rs/cors"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
)
|
||||
|
||||
// JWKSHandler returns the /.well-known/pomerium/jwks.json handler.
|
||||
func JWKSHandler(rawSigningKey string) http.Handler {
|
||||
return cors.AllowAll().Handler(httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
var jwks jose.JSONWebKeySet
|
||||
if rawSigningKey != "" {
|
||||
decodedCert, err := base64.StdEncoding.DecodeString(rawSigningKey)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusInternalServerError, errors.New("bad base64 encoding for signing key"))
|
||||
}
|
||||
jwk, err := cryptutil.PublicJWKFromBytes(decodedCert)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusInternalServerError, errors.New("bad signing key"))
|
||||
}
|
||||
jwks.Keys = append(jwks.Keys, *jwk)
|
||||
}
|
||||
httputil.RenderJSON(w, http.StatusOK, jwks)
|
||||
return nil
|
||||
}))
|
||||
}
|
22
internal/handlers/jwks_test.go
Normal file
22
internal/handlers/jwks_test.go
Normal file
|
@ -0,0 +1,22 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestJWKSHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("cors", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||
r.Header.Set("Origin", "https://www.example.com")
|
||||
r.Header.Set("Access-Control-Request-Method", "GET")
|
||||
JWKSHandler("").ServeHTTP(w, r)
|
||||
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
|
||||
})
|
||||
}
|
27
internal/handlers/signout.go
Normal file
27
internal/handlers/signout.go
Normal file
|
@ -0,0 +1,27 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/ui"
|
||||
)
|
||||
|
||||
// SignOutConfirmData is the data for the SignOutConfirm page.
|
||||
type SignOutConfirmData struct {
|
||||
URL string
|
||||
}
|
||||
|
||||
// ToJSON converts the data into a JSON map.
|
||||
func (data SignOutConfirmData) ToJSON() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"url": data.URL,
|
||||
}
|
||||
}
|
||||
|
||||
// SignOutConfirm returns a handler that renders the sign out confirm page.
|
||||
func SignOutConfirm(data SignOutConfirmData) http.Handler {
|
||||
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
return ui.ServePage(w, r, "SignOutConfirm", data.ToJSON())
|
||||
})
|
||||
}
|
65
internal/handlers/userinfo.go
Normal file
65
internal/handlers/userinfo.go
Normal file
|
@ -0,0 +1,65 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
"github.com/pomerium/datasource/pkg/directory"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
"github.com/pomerium/pomerium/ui"
|
||||
"github.com/pomerium/webauthn"
|
||||
)
|
||||
|
||||
// UserInfoData is the data for the UserInfo page.
|
||||
type UserInfoData struct {
|
||||
CSRFToken string
|
||||
IsImpersonated bool
|
||||
Session *session.Session
|
||||
User *user.User
|
||||
|
||||
IsEnterprise bool
|
||||
DirectoryUser *directory.User
|
||||
DirectoryGroups []*directory.Group
|
||||
|
||||
WebAuthnCreationOptions *webauthn.PublicKeyCredentialCreationOptions
|
||||
WebAuthnRequestOptions *webauthn.PublicKeyCredentialRequestOptions
|
||||
WebAuthnURL string
|
||||
|
||||
BrandingOptions httputil.BrandingOptions
|
||||
}
|
||||
|
||||
// ToJSON converts the data into a JSON map.
|
||||
func (data UserInfoData) ToJSON() map[string]any {
|
||||
m := map[string]any{}
|
||||
m["csrfToken"] = data.CSRFToken
|
||||
m["isImpersonated"] = data.IsImpersonated
|
||||
if bs, err := protojson.Marshal(data.Session); err == nil {
|
||||
m["session"] = json.RawMessage(bs)
|
||||
}
|
||||
if bs, err := protojson.Marshal(data.User); err == nil {
|
||||
m["user"] = json.RawMessage(bs)
|
||||
}
|
||||
m["isEnterprise"] = data.IsEnterprise
|
||||
if data.DirectoryUser != nil {
|
||||
m["directoryUser"] = data.DirectoryUser
|
||||
}
|
||||
if len(data.DirectoryGroups) > 0 {
|
||||
m["directoryGroups"] = data.DirectoryGroups
|
||||
}
|
||||
m["webAuthnCreationOptions"] = data.WebAuthnCreationOptions
|
||||
m["webAuthnRequestOptions"] = data.WebAuthnRequestOptions
|
||||
m["webAuthnUrl"] = data.WebAuthnURL
|
||||
httputil.AddBrandingOptionsToMap(m, data.BrandingOptions)
|
||||
return m
|
||||
}
|
||||
|
||||
// UserInfo returns a handler that renders the user info page.
|
||||
func UserInfo(data UserInfoData) http.Handler {
|
||||
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
return ui.ServePage(w, r, "UserInfo", data.ToJSON())
|
||||
})
|
||||
}
|
550
internal/handlers/webauthn/webauthn.go
Normal file
550
internal/handlers/webauthn/webauthn.go
Normal file
|
@ -0,0 +1,550 @@
|
|||
// Package webauthn contains handlers for the WebAuthn flow in authenticate.
|
||||
package webauthn
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/encoding/jws"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/middleware"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/device"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
"github.com/pomerium/pomerium/pkg/webauthnutil"
|
||||
"github.com/pomerium/pomerium/ui"
|
||||
"github.com/pomerium/webauthn"
|
||||
)
|
||||
|
||||
const maxAuthenticateResponses = 5
|
||||
|
||||
var (
|
||||
errMissingDeviceCredentialID = httputil.NewError(http.StatusBadRequest, errors.New(
|
||||
urlutil.QueryDeviceCredentialID+" is a required parameter"))
|
||||
errMissingDeviceType = httputil.NewError(http.StatusBadRequest, errors.New(
|
||||
urlutil.QueryDeviceType+" is a required parameter"))
|
||||
errMissingRedirectURI = httputil.NewError(http.StatusBadRequest, errors.New(
|
||||
urlutil.QueryRedirectURI+" is a required parameter"))
|
||||
errInvalidDeviceCredential = httputil.NewError(http.StatusBadRequest, errors.New(
|
||||
"invalid device credential"))
|
||||
)
|
||||
|
||||
// State is the state needed by the Handler to handle requests.
|
||||
type State struct {
|
||||
AuthenticateURL *url.URL
|
||||
InternalAuthenticateURL *url.URL
|
||||
Client databroker.DataBrokerServiceClient
|
||||
PomeriumDomains []string
|
||||
RelyingParty *webauthn.RelyingParty
|
||||
Session *session.Session
|
||||
SessionState *sessions.State
|
||||
SessionStore sessions.SessionStore
|
||||
SharedKey []byte
|
||||
BrandingOptions httputil.BrandingOptions
|
||||
}
|
||||
|
||||
// A StateProvider provides state for the handler.
|
||||
type StateProvider = func(*http.Request) (*State, error)
|
||||
|
||||
// Handler is the WebAuthn device handler.
|
||||
type Handler struct {
|
||||
getState StateProvider
|
||||
}
|
||||
|
||||
// New creates a new Handler.
|
||||
func New(getState StateProvider) *Handler {
|
||||
return &Handler{
|
||||
getState: getState,
|
||||
}
|
||||
}
|
||||
|
||||
// GetOptions returns the creation and request options for WebAuthn.
|
||||
func (h *Handler) GetOptions(r *http.Request) (
|
||||
creationOptions *webauthn.PublicKeyCredentialCreationOptions,
|
||||
requestOptions *webauthn.PublicKeyCredentialRequestOptions,
|
||||
err error,
|
||||
) {
|
||||
state, err := h.getState(r)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return h.getOptions(r, state, webauthnutil.DefaultDeviceType)
|
||||
}
|
||||
|
||||
// ServeHTTP serves the HTTP handler.
|
||||
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
httputil.HandlerFunc(h.handle).ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func (h *Handler) getOptions(r *http.Request, state *State, deviceTypeParam string) (
|
||||
creationOptions *webauthn.PublicKeyCredentialCreationOptions,
|
||||
requestOptions *webauthn.PublicKeyCredentialRequestOptions,
|
||||
err error,
|
||||
) {
|
||||
// get the user information
|
||||
u, err := user.Get(r.Context(), state.Client, state.Session.GetUserId())
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// get the device credentials
|
||||
knownDeviceCredentials, err := getKnownDeviceCredentials(r.Context(), state.Client, u.GetDeviceCredentialIds()...)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// get the stored device type
|
||||
deviceType := webauthnutil.GetDeviceType(r.Context(), state.Client, deviceTypeParam)
|
||||
|
||||
creationOptions = webauthnutil.GenerateCreationOptions(r, state.SharedKey, deviceType, u)
|
||||
requestOptions = webauthnutil.GenerateRequestOptions(r, state.SharedKey, deviceType, knownDeviceCredentials)
|
||||
return creationOptions, requestOptions, nil
|
||||
}
|
||||
|
||||
func (h *Handler) handle(w http.ResponseWriter, r *http.Request) error {
|
||||
s, err := h.getState(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = middleware.ValidateRequestURL(
|
||||
urlutil.GetExternalRequest(s.InternalAuthenticateURL, s.AuthenticateURL, r),
|
||||
s.SharedKey,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch {
|
||||
case r.Method == "GET":
|
||||
return h.handleView(w, r, s)
|
||||
case r.FormValue("action") == "authenticate":
|
||||
return h.handleAuthenticate(w, r, s)
|
||||
case r.FormValue("action") == "register":
|
||||
return h.handleRegister(w, r, s)
|
||||
case r.FormValue("action") == "unregister":
|
||||
return h.handleUnregister(w, r, s)
|
||||
}
|
||||
|
||||
return httputil.NewError(http.StatusNotFound, errors.New(http.StatusText(http.StatusNotFound)))
|
||||
}
|
||||
|
||||
func (h *Handler) handleAuthenticate(w http.ResponseWriter, r *http.Request, state *State) error {
|
||||
ctx := r.Context()
|
||||
|
||||
deviceTypeParam := r.FormValue(urlutil.QueryDeviceType)
|
||||
if deviceTypeParam == "" {
|
||||
return errMissingDeviceType
|
||||
}
|
||||
|
||||
redirectURIParam := r.FormValue(urlutil.QueryRedirectURI)
|
||||
if redirectURIParam == "" {
|
||||
return errMissingRedirectURI
|
||||
}
|
||||
|
||||
responseParam := r.FormValue("authenticate_response")
|
||||
var credential webauthn.PublicKeyAssertionCredential
|
||||
err := json.Unmarshal([]byte(responseParam), &credential)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, errors.New("invalid authenticate response"))
|
||||
}
|
||||
credentialJSON, err := json.Marshal(credential)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Set the UserHandle which won't typically be filled in by the client
|
||||
credential.Response.UserHandle = webauthnutil.GetUserEntityID(state.Session.GetUserId())
|
||||
|
||||
// get the user information
|
||||
u, err := user.Get(ctx, state.Client, state.Session.GetUserId())
|
||||
if err != nil {
|
||||
return fmt.Errorf("error retrieving user record: %w", err)
|
||||
}
|
||||
|
||||
// get the stored device type
|
||||
deviceType := webauthnutil.GetDeviceType(ctx, state.Client, deviceTypeParam)
|
||||
|
||||
// get the device credentials
|
||||
knownDeviceCredentials, err := getKnownDeviceCredentials(ctx, state.Client, u.GetDeviceCredentialIds()...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error retrieving webauthn known device credentials: %w", err)
|
||||
}
|
||||
|
||||
requestOptions, err := webauthnutil.GetRequestOptionsForCredential(
|
||||
r,
|
||||
state.SharedKey,
|
||||
deviceType,
|
||||
knownDeviceCredentials,
|
||||
&credential,
|
||||
)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, fmt.Errorf("invalid register options: %w", err))
|
||||
}
|
||||
|
||||
serverCredential, err := state.RelyingParty.VerifyAuthenticationCeremony(
|
||||
ctx,
|
||||
requestOptions,
|
||||
&credential,
|
||||
)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, fmt.Errorf("error verifying registration: %w", err))
|
||||
}
|
||||
|
||||
// store the authenticate response
|
||||
for _, deviceCredential := range knownDeviceCredentials {
|
||||
webauthnCredential := deviceCredential.GetWebauthn()
|
||||
if webauthnCredential == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if !bytes.Equal(webauthnCredential.Id, serverCredential.ID) {
|
||||
continue
|
||||
}
|
||||
|
||||
// add the response to the list and cap it, removing the oldest responses
|
||||
webauthnCredential.AuthenticateResponse = append(webauthnCredential.AuthenticateResponse, credentialJSON)
|
||||
for len(webauthnCredential.AuthenticateResponse) > maxAuthenticateResponses {
|
||||
webauthnCredential.AuthenticateResponse = webauthnCredential.AuthenticateResponse[1:]
|
||||
}
|
||||
|
||||
// store the updated device credential
|
||||
err = device.PutCredential(ctx, state.Client, deviceCredential)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// update the session
|
||||
state.Session.DeviceCredentials = append(state.Session.DeviceCredentials, &session.Session_DeviceCredential{
|
||||
TypeId: deviceType.GetId(),
|
||||
Credential: &session.Session_DeviceCredential_Id{
|
||||
Id: webauthnutil.GetDeviceCredentialID(serverCredential.ID),
|
||||
},
|
||||
})
|
||||
return h.saveSessionAndRedirect(w, r, state, redirectURIParam)
|
||||
}
|
||||
|
||||
func (h *Handler) handleRegister(w http.ResponseWriter, r *http.Request, state *State) error {
|
||||
ctx := r.Context()
|
||||
|
||||
deviceTypeParam := r.FormValue(urlutil.QueryDeviceType)
|
||||
if deviceTypeParam == "" {
|
||||
return errMissingDeviceType
|
||||
}
|
||||
|
||||
redirectURIParam := r.FormValue(urlutil.QueryRedirectURI)
|
||||
if redirectURIParam == "" {
|
||||
return errMissingRedirectURI
|
||||
}
|
||||
|
||||
responseParam := r.FormValue("register_response")
|
||||
var credential webauthn.PublicKeyCreationCredential
|
||||
err := json.Unmarshal([]byte(responseParam), &credential)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, errors.New("invalid register response"))
|
||||
}
|
||||
credentialJSON, err := json.Marshal(credential)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// get the user information
|
||||
u, err := user.Get(ctx, state.Client, state.Session.GetUserId())
|
||||
if err != nil {
|
||||
return fmt.Errorf("error retrieving user record: %w", err)
|
||||
}
|
||||
|
||||
// get the stored device type
|
||||
deviceType := webauthnutil.GetDeviceType(ctx, state.Client, deviceTypeParam)
|
||||
|
||||
creationOptions, err := webauthnutil.GetCreationOptionsForCredential(
|
||||
r,
|
||||
state.SharedKey,
|
||||
deviceType,
|
||||
u,
|
||||
&credential,
|
||||
)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, fmt.Errorf("invalid register options: %w", err))
|
||||
}
|
||||
creationOptionsJSON, err := json.Marshal(creationOptions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
serverCredential, err := state.RelyingParty.VerifyRegistrationCeremony(
|
||||
ctx,
|
||||
creationOptions,
|
||||
&credential,
|
||||
)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, fmt.Errorf("error verifying registration: %w", err))
|
||||
}
|
||||
|
||||
deviceCredentialID := webauthnutil.GetDeviceCredentialID(serverCredential.ID)
|
||||
|
||||
deviceEnrollment, err := getOrCreateDeviceEnrollment(ctx, r, state, deviceType.GetId(), deviceCredentialID, u)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// save the credential
|
||||
deviceCredential := &device.Credential{
|
||||
Id: deviceCredentialID,
|
||||
TypeId: deviceType.GetId(),
|
||||
EnrollmentId: deviceEnrollment.GetId(),
|
||||
UserId: u.GetId(),
|
||||
Specifier: &device.Credential_Webauthn{
|
||||
Webauthn: &device.Credential_WebAuthn{
|
||||
Id: serverCredential.ID,
|
||||
PublicKey: serverCredential.PublicKey,
|
||||
|
||||
RegisterOptions: creationOptionsJSON,
|
||||
RegisterResponse: credentialJSON,
|
||||
},
|
||||
},
|
||||
}
|
||||
err = device.PutCredential(ctx, state.Client, deviceCredential)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// save the user
|
||||
u.AddDeviceCredentialID(deviceCredential.GetId())
|
||||
_, err = databroker.Put(ctx, state.Client, u)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// update the session
|
||||
state.Session.DeviceCredentials = append(state.Session.DeviceCredentials, &session.Session_DeviceCredential{
|
||||
TypeId: deviceType.GetId(),
|
||||
Credential: &session.Session_DeviceCredential_Id{
|
||||
Id: webauthnutil.GetDeviceCredentialID(serverCredential.ID),
|
||||
},
|
||||
})
|
||||
|
||||
return h.saveSessionAndRedirect(w, r, state, redirectURIParam)
|
||||
}
|
||||
|
||||
func (h *Handler) handleUnregister(w http.ResponseWriter, r *http.Request, state *State) error {
|
||||
ctx := r.Context()
|
||||
|
||||
// get the user information
|
||||
u, err := user.Get(ctx, state.Client, state.Session.GetUserId())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
deviceCredentialID := r.FormValue(urlutil.QueryDeviceCredentialID)
|
||||
if deviceCredentialID == "" {
|
||||
return errMissingDeviceCredentialID
|
||||
}
|
||||
|
||||
// ensure we only allow removing a device credential the user owns
|
||||
if !u.HasDeviceCredentialID(deviceCredentialID) {
|
||||
return errInvalidDeviceCredential
|
||||
}
|
||||
|
||||
// delete the credential
|
||||
deviceCredential, err := device.DeleteCredential(ctx, state.Client, deviceCredentialID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// delete the corresponding enrollment
|
||||
_, err = device.DeleteEnrollment(ctx, state.Client, deviceCredential.GetEnrollmentId())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// remove the credential from the user
|
||||
u.RemoveDeviceCredentialID(deviceCredentialID)
|
||||
_, err = databroker.Put(ctx, state.Client, u)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// remove the credential from the session
|
||||
state.Session.RemoveDeviceCredentialID(deviceCredentialID)
|
||||
return h.saveSessionAndRedirect(w, r, state, urlutil.GetAbsoluteURL(r).ResolveReference(&url.URL{
|
||||
Path: "/.pomerium",
|
||||
}).String())
|
||||
}
|
||||
|
||||
func (h *Handler) handleView(w http.ResponseWriter, r *http.Request, state *State) error {
|
||||
deviceTypeParam := r.FormValue(urlutil.QueryDeviceType)
|
||||
if deviceTypeParam == "" {
|
||||
return errMissingDeviceType
|
||||
}
|
||||
|
||||
creationOptions, requestOptions, err := h.getOptions(r, state, deviceTypeParam)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m := map[string]interface{}{
|
||||
"creationOptions": creationOptions,
|
||||
"requestOptions": requestOptions,
|
||||
"selfUrl": r.URL.String(),
|
||||
}
|
||||
httputil.AddBrandingOptionsToMap(m, state.BrandingOptions)
|
||||
return ui.ServePage(w, r, "WebAuthnRegistration", m)
|
||||
}
|
||||
|
||||
func (h *Handler) saveSessionAndRedirect(w http.ResponseWriter, r *http.Request, state *State, rawRedirectURI string) error {
|
||||
// save the session to the databroker
|
||||
res, err := session.Put(r.Context(), state.Client, state.Session)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// add databroker versions to the session cookie and save
|
||||
state.SessionState.DatabrokerServerVersion = res.GetServerVersion()
|
||||
state.SessionState.DatabrokerRecordVersion = res.GetRecord().GetVersion()
|
||||
err = state.SessionStore.SaveSession(w, r, state.SessionState)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// if the redirect URL is for a URL we don't control, just do a plain redirect
|
||||
if !isURLForPomerium(state.PomeriumDomains, rawRedirectURI) {
|
||||
httputil.Redirect(w, r, rawRedirectURI, http.StatusFound)
|
||||
return nil
|
||||
}
|
||||
|
||||
// sign+encrypt the session JWT
|
||||
encoder, err := jws.NewHS256Signer(state.SharedKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
signedJWT, err := encoder.Marshal(state.SessionState)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cipher, err := cryptutil.NewAEADCipher(state.SharedKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
encryptedJWT := cryptutil.Encrypt(cipher, signedJWT, nil)
|
||||
encodedJWT := base64.URLEncoding.EncodeToString(encryptedJWT)
|
||||
|
||||
// redirect to the proxy callback URL with the session
|
||||
callbackURL, err := urlutil.GetCallbackURLForRedirectURI(r, encodedJWT, rawRedirectURI)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
signedCallbackURL := urlutil.NewSignedURL(state.SharedKey, callbackURL)
|
||||
httputil.Redirect(w, r, signedCallbackURL.String(), http.StatusFound)
|
||||
return nil
|
||||
}
|
||||
|
||||
func getKnownDeviceCredentials(
|
||||
ctx context.Context,
|
||||
client databroker.DataBrokerServiceClient,
|
||||
deviceCredentialIDs ...string,
|
||||
) ([]*device.Credential, error) {
|
||||
var knownDeviceCredentials []*device.Credential
|
||||
for _, deviceCredentialID := range deviceCredentialIDs {
|
||||
deviceCredential, err := device.GetCredential(ctx, client, deviceCredentialID)
|
||||
if status.Code(err) == codes.NotFound {
|
||||
// ignore missing devices
|
||||
continue
|
||||
} else if err != nil {
|
||||
return nil, httputil.NewError(http.StatusInternalServerError,
|
||||
fmt.Errorf("error retrieving device credential: %w", err))
|
||||
}
|
||||
knownDeviceCredentials = append(knownDeviceCredentials, deviceCredential)
|
||||
}
|
||||
return knownDeviceCredentials, nil
|
||||
}
|
||||
|
||||
func getOrCreateDeviceEnrollment(
|
||||
ctx context.Context,
|
||||
r *http.Request,
|
||||
state *State,
|
||||
deviceTypeID string,
|
||||
deviceCredentialID string,
|
||||
u *user.User,
|
||||
) (*device.Enrollment, error) {
|
||||
var deviceEnrollment *device.Enrollment
|
||||
|
||||
enrollmentTokenParam := r.FormValue(urlutil.QueryEnrollmentToken)
|
||||
if enrollmentTokenParam == "" {
|
||||
// create a new enrollment
|
||||
deviceEnrollment = &device.Enrollment{
|
||||
Id: uuid.New().String(),
|
||||
TypeId: deviceTypeID,
|
||||
UserId: u.GetId(),
|
||||
}
|
||||
} else {
|
||||
// use an existing enrollment
|
||||
deviceEnrollmentID, err := webauthnutil.ParseAndVerifyEnrollmentToken(state.SharedKey, enrollmentTokenParam)
|
||||
if err != nil {
|
||||
return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("invalid enrollment token: %w", err))
|
||||
}
|
||||
|
||||
deviceEnrollment, err = device.GetEnrollment(ctx, state.Client, deviceEnrollmentID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if deviceEnrollment.GetTypeId() != deviceTypeID {
|
||||
return nil, httputil.NewError(http.StatusForbidden, fmt.Errorf("invalid enrollment token: wrong device type"))
|
||||
}
|
||||
|
||||
if deviceEnrollment.GetUserId() != u.GetId() {
|
||||
return nil, httputil.NewError(http.StatusForbidden, fmt.Errorf("invalid enrollment token: wrong user id"))
|
||||
}
|
||||
|
||||
if deviceEnrollment.GetEnrolledAt().IsValid() {
|
||||
return nil, httputil.NewError(http.StatusForbidden, fmt.Errorf("invalid enrollment token: already used for existing credential"))
|
||||
}
|
||||
}
|
||||
|
||||
deviceEnrollment.CredentialId = deviceCredentialID
|
||||
deviceEnrollment.EnrolledAt = timestamppb.Now()
|
||||
deviceEnrollment.UserAgent = r.UserAgent()
|
||||
deviceEnrollment.IpAddress = httputil.GetClientIPAddress(r)
|
||||
|
||||
err := device.PutEnrollment(ctx, state.Client, deviceEnrollment)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return deviceEnrollment, nil
|
||||
}
|
||||
|
||||
func isURLForPomerium(pomeriumDomains []string, rawURI string) bool {
|
||||
uri, err := urlutil.ParseAndValidateURL(rawURI)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, domain := range pomeriumDomains {
|
||||
if urlutil.StripPort(domain) == urlutil.StripPort(uri.Host) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
29
internal/handlers/well_known_pomerium.go
Normal file
29
internal/handlers/well_known_pomerium.go
Normal file
|
@ -0,0 +1,29 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/rs/cors"
|
||||
|
||||
"github.com/pomerium/csrf"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
)
|
||||
|
||||
// WellKnownPomerium returns the /.well-known/pomerium handler.
|
||||
func WellKnownPomerium(authenticateURL *url.URL) http.Handler {
|
||||
return cors.AllowAll().Handler(httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
wellKnownURLs := struct {
|
||||
OAuth2Callback string `json:"authentication_callback_endpoint"` // RFC6749
|
||||
JSONWebKeySetURL string `json:"jwks_uri"` // RFC7517
|
||||
FrontchannelLogoutURI string `json:"frontchannel_logout_uri"` // https://openid.net/specs/openid-connect-frontchannel-1_0.html
|
||||
}{
|
||||
authenticateURL.ResolveReference(&url.URL{Path: "/oauth2/callback"}).String(),
|
||||
authenticateURL.ResolveReference(&url.URL{Path: "/.well-known/pomerium/jwks.json"}).String(),
|
||||
authenticateURL.ResolveReference(&url.URL{Path: "/.pomerium/sign_out"}).String(),
|
||||
}
|
||||
w.Header().Set("X-CSRF-Token", csrf.Token(r))
|
||||
httputil.RenderJSON(w, http.StatusOK, wellKnownURLs)
|
||||
return nil
|
||||
}))
|
||||
}
|
24
internal/handlers/well_known_pomerium_test.go
Normal file
24
internal/handlers/well_known_pomerium_test.go
Normal file
|
@ -0,0 +1,24 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestWellKnownPomeriumHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("cors", func(t *testing.T) {
|
||||
authenticateURL, _ := url.Parse("https://authenticate.example.com")
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||
r.Header.Set("Origin", authenticateURL.String())
|
||||
r.Header.Set("Access-Control-Request-Method", "GET")
|
||||
WellKnownPomerium(authenticateURL).ServeHTTP(w, r)
|
||||
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
|
||||
})
|
||||
}
|
|
@ -2,34 +2,12 @@ package httputil
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"github.com/rs/cors"
|
||||
|
||||
"github.com/pomerium/csrf"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
)
|
||||
|
||||
// HealthCheck is a simple healthcheck handler that responds to GET and HEAD
|
||||
// http requests.
|
||||
func HealthCheck(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet && r.Method != http.MethodHead {
|
||||
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if r.Method == http.MethodGet {
|
||||
fmt.Fprintln(w, http.StatusText(http.StatusOK))
|
||||
}
|
||||
}
|
||||
|
||||
// Redirect wraps the std libs's redirect method indicating that pomerium is
|
||||
// the origin of the response.
|
||||
func Redirect(w http.ResponseWriter, r *http.Request, url string, code int) {
|
||||
|
@ -72,41 +50,3 @@ func (f HandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
e.ErrorResponse(r.Context(), w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// JWKSHandler returns the /.well-known/pomerium/jwks.json handler.
|
||||
func JWKSHandler(rawSigningKey string) http.Handler {
|
||||
return cors.AllowAll().Handler(HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
var jwks jose.JSONWebKeySet
|
||||
if rawSigningKey != "" {
|
||||
decodedCert, err := base64.StdEncoding.DecodeString(rawSigningKey)
|
||||
if err != nil {
|
||||
return NewError(http.StatusInternalServerError, errors.New("bad signing key"))
|
||||
}
|
||||
jwk, err := cryptutil.PublicJWKFromBytes(decodedCert)
|
||||
if err != nil {
|
||||
return NewError(http.StatusInternalServerError, errors.New("bad signing key"))
|
||||
}
|
||||
jwks.Keys = append(jwks.Keys, *jwk)
|
||||
}
|
||||
RenderJSON(w, http.StatusOK, jwks)
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
// WellKnownPomeriumHandler returns the /.well-known/pomerium handler.
|
||||
func WellKnownPomeriumHandler(authenticateURL *url.URL) http.Handler {
|
||||
return cors.AllowAll().Handler(HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
wellKnownURLs := struct {
|
||||
OAuth2Callback string `json:"authentication_callback_endpoint"` // RFC6749
|
||||
JSONWebKeySetURL string `json:"jwks_uri"` // RFC7517
|
||||
FrontchannelLogoutURI string `json:"frontchannel_logout_uri"` // https://openid.net/specs/openid-connect-frontchannel-1_0.html
|
||||
}{
|
||||
authenticateURL.ResolveReference(&url.URL{Path: "/oauth2/callback"}).String(),
|
||||
authenticateURL.ResolveReference(&url.URL{Path: "/.well-known/pomerium/jwks.json"}).String(),
|
||||
authenticateURL.ResolveReference(&url.URL{Path: "/.pomerium/sign_out"}).String(),
|
||||
}
|
||||
w.Header().Set("X-CSRF-Token", csrf.Token(r))
|
||||
RenderJSON(w, http.StatusOK, wellKnownURLs)
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
|
|
@ -5,42 +5,11 @@ import (
|
|||
"math"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestHealthCheck(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
|
||||
wantStatus int
|
||||
}{
|
||||
{"good - Get", http.MethodGet, http.StatusOK},
|
||||
{"good - Head", http.MethodHead, http.StatusOK},
|
||||
{"bad - Options", http.MethodOptions, http.StatusMethodNotAllowed},
|
||||
{"bad - Put", http.MethodPut, http.StatusMethodNotAllowed},
|
||||
{"bad - Post", http.MethodPost, http.StatusMethodNotAllowed},
|
||||
{"bad - route miss", http.MethodGet, http.StatusOK},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := httptest.NewRequest(tt.method, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
HealthCheck(w, r)
|
||||
if w.Code != tt.wantStatus {
|
||||
t.Errorf("code differs. got %d want %d body: %s", w.Code, tt.wantStatus, w.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedirect(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
|
@ -150,30 +119,3 @@ func TestRenderJSON(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWKSHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("cors", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||
r.Header.Set("Origin", "https://www.example.com")
|
||||
r.Header.Set("Access-Control-Request-Method", "GET")
|
||||
JWKSHandler("").ServeHTTP(w, r)
|
||||
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
|
||||
})
|
||||
}
|
||||
|
||||
func TestWellKnownPomeriumHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("cors", func(t *testing.T) {
|
||||
authenticateURL, _ := url.Parse("https://authenticate.example.com")
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||
r.Header.Set("Origin", authenticateURL.String())
|
||||
r.Header.Set("Access-Control-Request-Method", "GET")
|
||||
WellKnownPomeriumHandler(authenticateURL).ServeHTTP(w, r)
|
||||
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue