From c1a522cd8206bdc682d1780779d3a0bedc40d50c Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Tue, 22 Nov 2022 10:26:35 -0700 Subject: [PATCH] proxy: add userinfo and webauthn endpoints (#3755) * proxy: add userinfo and webauthn endpoints * use TLD for RP id * use EffectiveTLDPlusOne * upgrade webauthn * fix test * Update internal/handlers/jwks.go Co-authored-by: bobby <1544881+desimone@users.noreply.github.com> Co-authored-by: bobby <1544881+desimone@users.noreply.github.com> --- authenticate/authenticate.go | 2 +- authenticate/handlers.go | 15 +- authenticate/handlers/handlers.go | 2 - authenticate/handlers_test.go | 2 +- authenticate/state.go | 9 -- go.mod | 6 +- go.sum | 14 +- internal/controlplane/http.go | 15 +- .../handlers/device-enrolled.go | 0 internal/handlers/handlers.go | 2 + internal/handlers/health_check.go | 20 +++ internal/handlers/health_check_test.go | 36 +++++ internal/handlers/jwks.go | 33 ++++ internal/handlers/jwks_test.go | 22 +++ .../handlers/signout.go | 0 .../handlers/userinfo.go | 0 .../handlers/webauthn/webauthn.go | 28 ++-- internal/handlers/well_known_pomerium.go | 29 ++++ internal/handlers/well_known_pomerium_test.go | 24 +++ internal/httputil/handlers.go | 60 ------- internal/httputil/handlers_test.go | 58 ------- pkg/webauthnutil/options.go | 15 +- pkg/webauthnutil/options_test.go | 25 ++- pkg/webauthnutil/webauthnutil.go | 30 ++++ pkg/webauthnutil/webauthnutil_test.go | 31 ++++ proxy/data.go | 151 ++++++++++++++++++ proxy/handlers.go | 34 ++-- proxy/handlers_test.go | 23 --- proxy/proxy.go | 3 + proxy/state.go | 21 +++ .../components/WebAuthnAuthenticateButton.tsx | 1 + ui/src/components/WebAuthnRegisterButton.tsx | 1 + ui/src/types/index.ts | 2 + 33 files changed, 498 insertions(+), 216 deletions(-) delete mode 100644 authenticate/handlers/handlers.go rename {authenticate => internal}/handlers/device-enrolled.go (100%) create mode 100644 internal/handlers/handlers.go create mode 100644 internal/handlers/health_check.go create mode 100644 internal/handlers/health_check_test.go create mode 100644 internal/handlers/jwks.go create mode 100644 internal/handlers/jwks_test.go rename {authenticate => internal}/handlers/signout.go (100%) rename {authenticate => internal}/handlers/userinfo.go (100%) rename {authenticate => internal}/handlers/webauthn/webauthn.go (94%) create mode 100644 internal/handlers/well_known_pomerium.go create mode 100644 internal/handlers/well_known_pomerium_test.go create mode 100644 pkg/webauthnutil/webauthnutil_test.go create mode 100644 proxy/data.go diff --git a/authenticate/authenticate.go b/authenticate/authenticate.go index 206c0c3b6..9748f9e9a 100644 --- a/authenticate/authenticate.go +++ b/authenticate/authenticate.go @@ -7,9 +7,9 @@ import ( "errors" "fmt" - "github.com/pomerium/pomerium/authenticate/handlers/webauthn" "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/atomicutil" + "github.com/pomerium/pomerium/internal/handlers/webauthn" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/pkg/cryptutil" ) diff --git a/authenticate/handlers.go b/authenticate/handlers.go index d81404f28..715f1d01b 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -18,8 +18,8 @@ import ( "github.com/pomerium/csrf" "github.com/pomerium/datasource/pkg/directory" - "github.com/pomerium/pomerium/authenticate/handlers" - "github.com/pomerium/pomerium/authenticate/handlers/webauthn" + "github.com/pomerium/pomerium/internal/handlers" + "github.com/pomerium/pomerium/internal/handlers/webauthn" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/identity" "github.com/pomerium/pomerium/internal/identity/manager" @@ -33,6 +33,7 @@ import ( "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpc/user" + "github.com/pomerium/pomerium/pkg/webauthnutil" ) // Handler returns the authenticate service's handler chain. @@ -544,7 +545,7 @@ func (a *Authenticate) getUserInfoData(r *http.Request) (handlers.UserInfoData, Id: pbSession.GetUserId(), } } - creationOptions, requestOptions, _ := a.webauthn.GetOptions(r.Context()) + creationOptions, requestOptions, _ := a.webauthn.GetOptions(r) data := handlers.UserInfoData{ CSRFToken: csrf.Token(r), @@ -715,15 +716,15 @@ func (a *Authenticate) getUser(ctx context.Context, userID string) (*user.User, return user.Get(ctx, client, userID) } -func (a *Authenticate) getWebauthnState(ctx context.Context) (*webauthn.State, error) { +func (a *Authenticate) getWebauthnState(r *http.Request) (*webauthn.State, error) { state := a.state.Load() - s, _, err := a.getCurrentSession(ctx) + s, _, err := a.getCurrentSession(r.Context()) if err != nil { return nil, err } - ss, err := a.getSessionFromCtx(ctx) + ss, err := a.getSessionFromCtx(r.Context()) if err != nil { return nil, err } @@ -752,7 +753,7 @@ func (a *Authenticate) getWebauthnState(ctx context.Context) (*webauthn.State, e Session: s, SessionState: ss, SessionStore: state.sessionStore, - RelyingParty: state.webauthnRelyingParty, + RelyingParty: webauthnutil.GetRelyingParty(r, state.dataBrokerClient), BrandingOptions: a.options.Load().BrandingOptions, }, nil } diff --git a/authenticate/handlers/handlers.go b/authenticate/handlers/handlers.go deleted file mode 100644 index 5e6ef3192..000000000 --- a/authenticate/handlers/handlers.go +++ /dev/null @@ -1,2 +0,0 @@ -// Package handlers contains various web handlers for the authenticate service. -package handlers diff --git a/authenticate/handlers_test.go b/authenticate/handlers_test.go index 8573538ff..17fd8313d 100644 --- a/authenticate/handlers_test.go +++ b/authenticate/handlers_test.go @@ -21,12 +21,12 @@ import ( "google.golang.org/grpc" "google.golang.org/protobuf/types/known/timestamppb" - "github.com/pomerium/pomerium/authenticate/handlers/webauthn" "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/atomicutil" "github.com/pomerium/pomerium/internal/encoding" "github.com/pomerium/pomerium/internal/encoding/jws" "github.com/pomerium/pomerium/internal/encoding/mock" + "github.com/pomerium/pomerium/internal/handlers/webauthn" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/identity" "github.com/pomerium/pomerium/internal/identity/oidc" diff --git a/authenticate/state.go b/authenticate/state.go index bc2c665ed..d9b77ef8f 100644 --- a/authenticate/state.go +++ b/authenticate/state.go @@ -18,8 +18,6 @@ import ( "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/grpc" "github.com/pomerium/pomerium/pkg/grpc/databroker" - "github.com/pomerium/pomerium/pkg/webauthnutil" - "github.com/pomerium/webauthn" ) var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn) @@ -46,8 +44,6 @@ type authenticateState struct { jwk *jose.JSONWebKeySet dataBrokerClient databroker.DataBrokerServiceClient - - webauthnRelyingParty *webauthn.RelyingParty } func newAuthenticateState() *authenticateState { @@ -153,10 +149,5 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err state.dataBrokerClient = databroker.NewDataBrokerServiceClient(dataBrokerConn) - state.webauthnRelyingParty = webauthn.NewRelyingParty( - authenticateURL.String(), - webauthnutil.NewCredentialStorage(state.dataBrokerClient), - ) - return state, nil } diff --git a/go.mod b/go.mod index 01e7b27a0..130c3cc7f 100644 --- a/go.mod +++ b/go.mod @@ -47,7 +47,7 @@ require ( github.com/peterbourgon/ff/v3 v3.3.0 github.com/pomerium/csrf v1.7.0 github.com/pomerium/datasource v0.18.2-0.20221108160055-c6134b5ed524 - github.com/pomerium/webauthn v0.0.0-20211014213840-422c7ce1077f + github.com/pomerium/webauthn v0.0.0-20221118023040-00a9c430578b github.com/prometheus/client_golang v1.14.0 github.com/prometheus/client_model v0.3.0 github.com/prometheus/common v0.37.0 @@ -126,7 +126,7 @@ require ( github.com/fatih/structtag v1.2.0 // indirect github.com/felixge/httpsnoop v1.0.2 // indirect github.com/firefart/nonamedreturns v1.0.4 // indirect - github.com/fxamacker/cbor/v2 v2.3.0 // indirect + github.com/fxamacker/cbor/v2 v2.4.0 // indirect github.com/fzipp/gocyclo v0.6.0 // indirect github.com/ghodss/yaml v1.0.0 // indirect github.com/go-critic/go-critic v0.6.5 // indirect @@ -155,7 +155,7 @@ require ( github.com/golangci/misspell v0.3.5 // indirect github.com/golangci/revgrep v0.0.0-20220804021717-745bb2f7c2e6 // indirect github.com/golangci/unconvert v0.0.0-20180507085042-28b1c447d1f4 // indirect - github.com/google/go-tpm v0.3.2 // indirect + github.com/google/go-tpm v0.3.3 // indirect github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect github.com/googleapis/enterprise-certificate-proxy v0.2.0 // indirect github.com/gordonklaus/ineffassign v0.0.0-20210914165742-4cc7213b9bc8 // indirect diff --git a/go.sum b/go.sum index d9bb85256..b79e14666 100644 --- a/go.sum +++ b/go.sum @@ -280,8 +280,8 @@ github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMo github.com/fsnotify/fsnotify v1.5.4/go.mod h1:OVB6XrOHzAwXMpEM7uPOzcehqUV2UqJxmVXmkdnm1bU= github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= -github.com/fxamacker/cbor/v2 v2.3.0 h1:aM45YGMctNakddNNAezPxDUpv38j44Abh+hifNuqXik= -github.com/fxamacker/cbor/v2 v2.3.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo= +github.com/fxamacker/cbor/v2 v2.4.0 h1:ri0ArlOR+5XunOP8CRUowT0pSJOwhW098ZCUyskZD88= +github.com/fxamacker/cbor/v2 v2.4.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo= github.com/fzipp/gocyclo v0.6.0 h1:lsblElZG7d3ALtGMx9fmxeTKZaLLpU8mET09yN4BBLo= github.com/fzipp/gocyclo v0.6.0/go.mod h1:rXPyn8fnlpa0R2csP/31uerbiVBugk5whMdlyaLkLoA= github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk= @@ -435,8 +435,8 @@ github.com/google/go-jsonnet v0.19.1 h1:MORxkrG0elylUqh36R4AcSPX0oZQa9hvI3lroN+k github.com/google/go-jsonnet v0.19.1/go.mod h1:5JVT33JVCoehdTj5Z2KJq1eIdt3Nb8PCmZ+W5D8U350= github.com/google/go-tpm v0.1.2-0.20190725015402-ae6dd98980d4/go.mod h1:H9HbmUG2YgV/PHITkO7p6wxEEj/v5nlsVWIwumwH2NI= github.com/google/go-tpm v0.3.0/go.mod h1:iVLWvrPp/bHeEkxTFi9WG6K9w0iy2yIszHwZGHPbzAw= -github.com/google/go-tpm v0.3.2 h1:3iQQ2dlEf+1no7CLlfLPYzxhQy7j2G/emBqU5okydaw= -github.com/google/go-tpm v0.3.2/go.mod h1:j71sMBTfp3X5jPHz852ZOfQMUOf65Gb/Th8pRmp7fvg= +github.com/google/go-tpm v0.3.3 h1:P/ZFNBZYXRxc+z7i5uyd8VP7MaDteuLZInzrH2idRGo= +github.com/google/go-tpm v0.3.3/go.mod h1:9Hyn3rgnzWF9XBWVk6ml6A6hNkbWjNFlDQL51BeghL4= github.com/google/go-tpm-tools v0.0.0-20190906225433-1614c142f845/go.mod h1:AVfHadzbdzHo54inR2x1v640jdi1YSi3NauM2DUsxk0= github.com/google/go-tpm-tools v0.2.0/go.mod h1:npUd03rQ60lxN7tzeBJreG38RvWwme2N1reF/eeiBk4= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -791,8 +791,8 @@ github.com/pomerium/csrf v1.7.0 h1:Qp4t6oyEod3svQtKfJZs589mdUTWKVf7q0PgCKYCshY= github.com/pomerium/csrf v1.7.0/go.mod h1:hAPZV47mEj2T9xFs+ysbum4l7SF1IdrryYaY6PdoIqw= github.com/pomerium/datasource v0.18.2-0.20221108160055-c6134b5ed524 h1:3YQY1sb54tEEbr0L73rjHkpLB0IB6qh3zl1+XQbMLis= github.com/pomerium/datasource v0.18.2-0.20221108160055-c6134b5ed524/go.mod h1:7fGbUYJnU8RcxZJvUvhukOIBv1G7LWDAHMfDxAf5+Y0= -github.com/pomerium/webauthn v0.0.0-20211014213840-422c7ce1077f h1:442shkoI4Oh4RHdzFaGma1t9Ji/T+8pfCxQQzmY5kj8= -github.com/pomerium/webauthn v0.0.0-20211014213840-422c7ce1077f/go.mod h1:wgH3ualWdXu/qwbhOoSQedXzco+38Iz7qKKGCJcKPXg= +github.com/pomerium/webauthn v0.0.0-20221118023040-00a9c430578b h1:oll/aOfJudnqFAwCvoXK9+WN2zVjTzHVPLXCggHQmHk= +github.com/pomerium/webauthn v0.0.0-20221118023040-00a9c430578b/go.mod h1:KswTenBBh4y1pmhU2dpm8VgJQCgSErCg7OOFTeebrNc= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= @@ -1280,7 +1280,6 @@ golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201201145000-ef89a241ccb3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201207223542-d4d67f95c62d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210104204734-6f8348627aad/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -1298,6 +1297,7 @@ golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210603125802-9665404d3644/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210629170331-7dc0b73dc9fb/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210816183151-1e6c022a8912/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/internal/controlplane/http.go b/internal/controlplane/http.go index 162254b47..d7ce27c5a 100644 --- a/internal/controlplane/http.go +++ b/internal/controlplane/http.go @@ -7,10 +7,11 @@ import ( "time" "github.com/CAFxX/httpcompression" - "github.com/gorilla/handlers" + gorillahandlers "github.com/gorilla/handlers" "github.com/gorilla/mux" "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/handlers" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/telemetry" @@ -37,7 +38,7 @@ func (srv *Server) addHTTPMiddleware(root *mux.Router, cfg *config.Config) { Str("path", r.URL.String()). Msg("http-request") })) - root.Use(handlers.RecoveryHandler()) + root.Use(gorillahandlers.RecoveryHandler()) root.Use(log.HeadersHandler(httputil.HeadersXForwarded)) root.Use(log.RemoteAddrHandler("ip")) root.Use(log.UserAgentHandler("user_agent")) @@ -59,10 +60,10 @@ func (srv *Server) mountCommonEndpoints(root *mux.Router, cfg *config.Config) er return fmt.Errorf("invalid signing key: %w", err) } - root.HandleFunc("/healthz", httputil.HealthCheck) - root.HandleFunc("/ping", httputil.HealthCheck) - root.Handle("/.well-known/pomerium", httputil.WellKnownPomeriumHandler(authenticateURL)) - root.Handle("/.well-known/pomerium/", httputil.WellKnownPomeriumHandler(authenticateURL)) - root.Path("/.well-known/pomerium/jwks.json").Methods(http.MethodGet).Handler(httputil.JWKSHandler(rawSigningKey)) + root.HandleFunc("/healthz", handlers.HealthCheck) + root.HandleFunc("/ping", handlers.HealthCheck) + root.Handle("/.well-known/pomerium", handlers.WellKnownPomerium(authenticateURL)) + root.Handle("/.well-known/pomerium/", handlers.WellKnownPomerium(authenticateURL)) + root.Path("/.well-known/pomerium/jwks.json").Methods(http.MethodGet).Handler(handlers.JWKSHandler(rawSigningKey)) return nil } diff --git a/authenticate/handlers/device-enrolled.go b/internal/handlers/device-enrolled.go similarity index 100% rename from authenticate/handlers/device-enrolled.go rename to internal/handlers/device-enrolled.go diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go new file mode 100644 index 000000000..cad896677 --- /dev/null +++ b/internal/handlers/handlers.go @@ -0,0 +1,2 @@ +// Package handlers contains HTTP handlers used by Pomerium. +package handlers diff --git a/internal/handlers/health_check.go b/internal/handlers/health_check.go new file mode 100644 index 000000000..3a77961ae --- /dev/null +++ b/internal/handlers/health_check.go @@ -0,0 +1,20 @@ +package handlers + +import ( + "fmt" + "net/http" +) + +// HealthCheck is a simple healthcheck handler that responds to GET and HEAD +// http requests. +func HealthCheck(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet && r.Method != http.MethodHead { + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + return + } + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + if r.Method == http.MethodGet { + fmt.Fprintln(w, http.StatusText(http.StatusOK)) + } +} diff --git a/internal/handlers/health_check_test.go b/internal/handlers/health_check_test.go new file mode 100644 index 000000000..bbf5b83cd --- /dev/null +++ b/internal/handlers/health_check_test.go @@ -0,0 +1,36 @@ +package handlers + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestHealthCheck(t *testing.T) { + t.Parallel() + tests := []struct { + name string + method string + + wantStatus int + }{ + {"good - Get", http.MethodGet, http.StatusOK}, + {"good - Head", http.MethodHead, http.StatusOK}, + {"bad - Options", http.MethodOptions, http.StatusMethodNotAllowed}, + {"bad - Put", http.MethodPut, http.StatusMethodNotAllowed}, + {"bad - Post", http.MethodPost, http.StatusMethodNotAllowed}, + {"bad - route miss", http.MethodGet, http.StatusOK}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := httptest.NewRequest(tt.method, "/", nil) + w := httptest.NewRecorder() + + HealthCheck(w, r) + if w.Code != tt.wantStatus { + t.Errorf("code differs. got %d want %d body: %s", w.Code, tt.wantStatus, w.Body.String()) + } + }) + } +} diff --git a/internal/handlers/jwks.go b/internal/handlers/jwks.go new file mode 100644 index 000000000..7abee06b0 --- /dev/null +++ b/internal/handlers/jwks.go @@ -0,0 +1,33 @@ +package handlers + +import ( + "encoding/base64" + "errors" + "net/http" + + "github.com/go-jose/go-jose/v3" + "github.com/rs/cors" + + "github.com/pomerium/pomerium/internal/httputil" + "github.com/pomerium/pomerium/pkg/cryptutil" +) + +// JWKSHandler returns the /.well-known/pomerium/jwks.json handler. +func JWKSHandler(rawSigningKey string) http.Handler { + return cors.AllowAll().Handler(httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { + var jwks jose.JSONWebKeySet + if rawSigningKey != "" { + decodedCert, err := base64.StdEncoding.DecodeString(rawSigningKey) + if err != nil { + return httputil.NewError(http.StatusInternalServerError, errors.New("bad base64 encoding for signing key")) + } + jwk, err := cryptutil.PublicJWKFromBytes(decodedCert) + if err != nil { + return httputil.NewError(http.StatusInternalServerError, errors.New("bad signing key")) + } + jwks.Keys = append(jwks.Keys, *jwk) + } + httputil.RenderJSON(w, http.StatusOK, jwks) + return nil + })) +} diff --git a/internal/handlers/jwks_test.go b/internal/handlers/jwks_test.go new file mode 100644 index 000000000..3c3442b89 --- /dev/null +++ b/internal/handlers/jwks_test.go @@ -0,0 +1,22 @@ +package handlers + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestJWKSHandler(t *testing.T) { + t.Parallel() + + t.Run("cors", func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodOptions, "/", nil) + r.Header.Set("Origin", "https://www.example.com") + r.Header.Set("Access-Control-Request-Method", "GET") + JWKSHandler("").ServeHTTP(w, r) + assert.Equal(t, http.StatusNoContent, w.Result().StatusCode) + }) +} diff --git a/authenticate/handlers/signout.go b/internal/handlers/signout.go similarity index 100% rename from authenticate/handlers/signout.go rename to internal/handlers/signout.go diff --git a/authenticate/handlers/userinfo.go b/internal/handlers/userinfo.go similarity index 100% rename from authenticate/handlers/userinfo.go rename to internal/handlers/userinfo.go diff --git a/authenticate/handlers/webauthn/webauthn.go b/internal/handlers/webauthn/webauthn.go similarity index 94% rename from authenticate/handlers/webauthn/webauthn.go rename to internal/handlers/webauthn/webauthn.go index 84e792cf2..976d6ff39 100644 --- a/authenticate/handlers/webauthn/webauthn.go +++ b/internal/handlers/webauthn/webauthn.go @@ -59,7 +59,7 @@ type State struct { } // A StateProvider provides state for the handler. -type StateProvider = func(context.Context) (*State, error) +type StateProvider = func(*http.Request) (*State, error) // Handler is the WebAuthn device handler. type Handler struct { @@ -74,17 +74,17 @@ func New(getState StateProvider) *Handler { } // GetOptions returns the creation and request options for WebAuthn. -func (h *Handler) GetOptions(ctx context.Context) ( +func (h *Handler) GetOptions(r *http.Request) ( creationOptions *webauthn.PublicKeyCredentialCreationOptions, requestOptions *webauthn.PublicKeyCredentialRequestOptions, err error, ) { - state, err := h.getState(ctx) + state, err := h.getState(r) if err != nil { return nil, nil, err } - return h.getOptions(ctx, state, webauthnutil.DefaultDeviceType) + return h.getOptions(r, state, webauthnutil.DefaultDeviceType) } // ServeHTTP serves the HTTP handler. @@ -92,33 +92,33 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { httputil.HandlerFunc(h.handle).ServeHTTP(w, r) } -func (h *Handler) getOptions(ctx context.Context, state *State, deviceTypeParam string) ( +func (h *Handler) getOptions(r *http.Request, state *State, deviceTypeParam string) ( creationOptions *webauthn.PublicKeyCredentialCreationOptions, requestOptions *webauthn.PublicKeyCredentialRequestOptions, err error, ) { // get the user information - u, err := user.Get(ctx, state.Client, state.Session.GetUserId()) + u, err := user.Get(r.Context(), state.Client, state.Session.GetUserId()) if err != nil { return nil, nil, err } // get the device credentials - knownDeviceCredentials, err := getKnownDeviceCredentials(ctx, state.Client, u.GetDeviceCredentialIds()...) + knownDeviceCredentials, err := getKnownDeviceCredentials(r.Context(), state.Client, u.GetDeviceCredentialIds()...) if err != nil { return nil, nil, err } // get the stored device type - deviceType := webauthnutil.GetDeviceType(ctx, state.Client, deviceTypeParam) + deviceType := webauthnutil.GetDeviceType(r.Context(), state.Client, deviceTypeParam) - creationOptions = webauthnutil.GenerateCreationOptions(state.SharedKey, deviceType, u) - requestOptions = webauthnutil.GenerateRequestOptions(state.SharedKey, deviceType, knownDeviceCredentials) + creationOptions = webauthnutil.GenerateCreationOptions(r, state.SharedKey, deviceType, u) + requestOptions = webauthnutil.GenerateRequestOptions(r, state.SharedKey, deviceType, knownDeviceCredentials) return creationOptions, requestOptions, nil } func (h *Handler) handle(w http.ResponseWriter, r *http.Request) error { - s, err := h.getState(r.Context()) + s, err := h.getState(r) if err != nil { return err } @@ -187,6 +187,7 @@ func (h *Handler) handleAuthenticate(w http.ResponseWriter, r *http.Request, sta } requestOptions, err := webauthnutil.GetRequestOptionsForCredential( + r, state.SharedKey, deviceType, knownDeviceCredentials, @@ -273,6 +274,7 @@ func (h *Handler) handleRegister(w http.ResponseWriter, r *http.Request, state * deviceType := webauthnutil.GetDeviceType(ctx, state.Client, deviceTypeParam) creationOptions, err := webauthnutil.GetCreationOptionsForCredential( + r, state.SharedKey, deviceType, u, @@ -387,14 +389,12 @@ func (h *Handler) handleUnregister(w http.ResponseWriter, r *http.Request, state } func (h *Handler) handleView(w http.ResponseWriter, r *http.Request, state *State) error { - ctx := r.Context() - deviceTypeParam := r.FormValue(urlutil.QueryDeviceType) if deviceTypeParam == "" { return errMissingDeviceType } - creationOptions, requestOptions, err := h.getOptions(ctx, state, deviceTypeParam) + creationOptions, requestOptions, err := h.getOptions(r, state, deviceTypeParam) if err != nil { return err } diff --git a/internal/handlers/well_known_pomerium.go b/internal/handlers/well_known_pomerium.go new file mode 100644 index 000000000..c9525deca --- /dev/null +++ b/internal/handlers/well_known_pomerium.go @@ -0,0 +1,29 @@ +package handlers + +import ( + "net/http" + "net/url" + + "github.com/rs/cors" + + "github.com/pomerium/csrf" + "github.com/pomerium/pomerium/internal/httputil" +) + +// WellKnownPomerium returns the /.well-known/pomerium handler. +func WellKnownPomerium(authenticateURL *url.URL) http.Handler { + return cors.AllowAll().Handler(httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { + wellKnownURLs := struct { + OAuth2Callback string `json:"authentication_callback_endpoint"` // RFC6749 + JSONWebKeySetURL string `json:"jwks_uri"` // RFC7517 + FrontchannelLogoutURI string `json:"frontchannel_logout_uri"` // https://openid.net/specs/openid-connect-frontchannel-1_0.html + }{ + authenticateURL.ResolveReference(&url.URL{Path: "/oauth2/callback"}).String(), + authenticateURL.ResolveReference(&url.URL{Path: "/.well-known/pomerium/jwks.json"}).String(), + authenticateURL.ResolveReference(&url.URL{Path: "/.pomerium/sign_out"}).String(), + } + w.Header().Set("X-CSRF-Token", csrf.Token(r)) + httputil.RenderJSON(w, http.StatusOK, wellKnownURLs) + return nil + })) +} diff --git a/internal/handlers/well_known_pomerium_test.go b/internal/handlers/well_known_pomerium_test.go new file mode 100644 index 000000000..fc690e882 --- /dev/null +++ b/internal/handlers/well_known_pomerium_test.go @@ -0,0 +1,24 @@ +package handlers + +import ( + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestWellKnownPomeriumHandler(t *testing.T) { + t.Parallel() + + t.Run("cors", func(t *testing.T) { + authenticateURL, _ := url.Parse("https://authenticate.example.com") + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodOptions, "/", nil) + r.Header.Set("Origin", authenticateURL.String()) + r.Header.Set("Access-Control-Request-Method", "GET") + WellKnownPomerium(authenticateURL).ServeHTTP(w, r) + assert.Equal(t, http.StatusNoContent, w.Result().StatusCode) + }) +} diff --git a/internal/httputil/handlers.go b/internal/httputil/handlers.go index f775f5dd2..03826de57 100644 --- a/internal/httputil/handlers.go +++ b/internal/httputil/handlers.go @@ -2,34 +2,12 @@ package httputil import ( "bytes" - "encoding/base64" "encoding/json" "errors" "fmt" "net/http" - "net/url" - - "github.com/go-jose/go-jose/v3" - "github.com/rs/cors" - - "github.com/pomerium/csrf" - "github.com/pomerium/pomerium/pkg/cryptutil" ) -// HealthCheck is a simple healthcheck handler that responds to GET and HEAD -// http requests. -func HealthCheck(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet && r.Method != http.MethodHead { - http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) - return - } - w.Header().Set("Content-Type", "text/plain") - w.WriteHeader(http.StatusOK) - if r.Method == http.MethodGet { - fmt.Fprintln(w, http.StatusText(http.StatusOK)) - } -} - // Redirect wraps the std libs's redirect method indicating that pomerium is // the origin of the response. func Redirect(w http.ResponseWriter, r *http.Request, url string, code int) { @@ -72,41 +50,3 @@ func (f HandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request) { e.ErrorResponse(r.Context(), w, r) } } - -// JWKSHandler returns the /.well-known/pomerium/jwks.json handler. -func JWKSHandler(rawSigningKey string) http.Handler { - return cors.AllowAll().Handler(HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { - var jwks jose.JSONWebKeySet - if rawSigningKey != "" { - decodedCert, err := base64.StdEncoding.DecodeString(rawSigningKey) - if err != nil { - return NewError(http.StatusInternalServerError, errors.New("bad signing key")) - } - jwk, err := cryptutil.PublicJWKFromBytes(decodedCert) - if err != nil { - return NewError(http.StatusInternalServerError, errors.New("bad signing key")) - } - jwks.Keys = append(jwks.Keys, *jwk) - } - RenderJSON(w, http.StatusOK, jwks) - return nil - })) -} - -// WellKnownPomeriumHandler returns the /.well-known/pomerium handler. -func WellKnownPomeriumHandler(authenticateURL *url.URL) http.Handler { - return cors.AllowAll().Handler(HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { - wellKnownURLs := struct { - OAuth2Callback string `json:"authentication_callback_endpoint"` // RFC6749 - JSONWebKeySetURL string `json:"jwks_uri"` // RFC7517 - FrontchannelLogoutURI string `json:"frontchannel_logout_uri"` // https://openid.net/specs/openid-connect-frontchannel-1_0.html - }{ - authenticateURL.ResolveReference(&url.URL{Path: "/oauth2/callback"}).String(), - authenticateURL.ResolveReference(&url.URL{Path: "/.well-known/pomerium/jwks.json"}).String(), - authenticateURL.ResolveReference(&url.URL{Path: "/.pomerium/sign_out"}).String(), - } - w.Header().Set("X-CSRF-Token", csrf.Token(r)) - RenderJSON(w, http.StatusOK, wellKnownURLs) - return nil - })) -} diff --git a/internal/httputil/handlers_test.go b/internal/httputil/handlers_test.go index 55e8ae2e5..60e4fda36 100644 --- a/internal/httputil/handlers_test.go +++ b/internal/httputil/handlers_test.go @@ -5,42 +5,11 @@ import ( "math" "net/http" "net/http/httptest" - "net/url" "testing" "github.com/google/go-cmp/cmp" - "github.com/stretchr/testify/assert" ) -func TestHealthCheck(t *testing.T) { - t.Parallel() - tests := []struct { - name string - method string - - wantStatus int - }{ - {"good - Get", http.MethodGet, http.StatusOK}, - {"good - Head", http.MethodHead, http.StatusOK}, - {"bad - Options", http.MethodOptions, http.StatusMethodNotAllowed}, - {"bad - Put", http.MethodPut, http.StatusMethodNotAllowed}, - {"bad - Post", http.MethodPost, http.StatusMethodNotAllowed}, - {"bad - route miss", http.MethodGet, http.StatusOK}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - r := httptest.NewRequest(tt.method, "/", nil) - w := httptest.NewRecorder() - - HealthCheck(w, r) - if w.Code != tt.wantStatus { - t.Errorf("code differs. got %d want %d body: %s", w.Code, tt.wantStatus, w.Body.String()) - } - }) - } -} - func TestRedirect(t *testing.T) { t.Parallel() tests := []struct { @@ -150,30 +119,3 @@ func TestRenderJSON(t *testing.T) { }) } } - -func TestJWKSHandler(t *testing.T) { - t.Parallel() - - t.Run("cors", func(t *testing.T) { - w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodOptions, "/", nil) - r.Header.Set("Origin", "https://www.example.com") - r.Header.Set("Access-Control-Request-Method", "GET") - JWKSHandler("").ServeHTTP(w, r) - assert.Equal(t, http.StatusNoContent, w.Result().StatusCode) - }) -} - -func TestWellKnownPomeriumHandler(t *testing.T) { - t.Parallel() - - t.Run("cors", func(t *testing.T) { - authenticateURL, _ := url.Parse("https://authenticate.example.com") - w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodOptions, "/", nil) - r.Header.Set("Origin", authenticateURL.String()) - r.Header.Set("Access-Control-Request-Method", "GET") - WellKnownPomeriumHandler(authenticateURL).ServeHTTP(w, r) - assert.Equal(t, http.StatusNoContent, w.Result().StatusCode) - }) -} diff --git a/pkg/webauthnutil/options.go b/pkg/webauthnutil/options.go index 5348617ca..88b018d12 100644 --- a/pkg/webauthnutil/options.go +++ b/pkg/webauthnutil/options.go @@ -3,6 +3,7 @@ package webauthnutil import ( "encoding/base64" "fmt" + "net/http" "time" "github.com/pomerium/webauthn" @@ -25,12 +26,14 @@ func GenerateChallenge(key []byte, expiry time.Time) cryptutil.SecureToken { // GenerateCreationOptions generates creation options for WebAuthn. func GenerateCreationOptions( + r *http.Request, key []byte, deviceType *device.Type, user *user.User, ) *webauthn.PublicKeyCredentialCreationOptions { expiry := time.Now().Add(ceremonyTimeout) return newCreationOptions( + r, GenerateChallenge(key, expiry).Bytes(), deviceType, user, @@ -39,12 +42,14 @@ func GenerateCreationOptions( // GenerateRequestOptions generates request options for WebAuthn. func GenerateRequestOptions( + r *http.Request, key []byte, deviceType *device.Type, knownDeviceCredentials []*device.Credential, ) *webauthn.PublicKeyCredentialRequestOptions { expiry := time.Now().Add(ceremonyTimeout) return newRequestOptions( + r, GenerateChallenge(key, expiry).Bytes(), deviceType, knownDeviceCredentials, @@ -54,6 +59,7 @@ func GenerateRequestOptions( // GetCreationOptionsForCredential gets the creation options for the public key creation credential. An error may be // returned if the challenge used to generate the credential is invalid. func GetCreationOptionsForCredential( + r *http.Request, key []byte, deviceType *device.Type, user *user.User, @@ -76,12 +82,13 @@ func GetCreationOptionsForCredential( return nil, err } - return newCreationOptions(challenge.Bytes(), deviceType, user), nil + return newCreationOptions(r, challenge.Bytes(), deviceType, user), nil } // GetRequestOptionsForCredential gets the request options for the public key request credential. An error may be // returned if the challenge used to generate the credential is invalid. func GetRequestOptionsForCredential( + r *http.Request, key []byte, deviceType *device.Type, knownDeviceCredentials []*device.Credential, @@ -104,11 +111,12 @@ func GetRequestOptionsForCredential( return nil, err } - return newRequestOptions(challenge.Bytes(), deviceType, knownDeviceCredentials), nil + return newRequestOptions(r, challenge.Bytes(), deviceType, knownDeviceCredentials), nil } // newCreationOptions gets the creation options for WebAuthn with the provided challenge. func newCreationOptions( + r *http.Request, challenge []byte, deviceType *device.Type, user *user.User, @@ -116,6 +124,7 @@ func newCreationOptions( options := &webauthn.PublicKeyCredentialCreationOptions{ RP: webauthn.PublicKeyCredentialRPEntity{ Name: rpName, + ID: GetEffectiveDomain(r), }, User: GetUserEntity(user), Challenge: challenge, @@ -133,6 +142,7 @@ func newCreationOptions( // newRequestOptions gets the request options for WebAuthn with the provided challenge. func newRequestOptions( + r *http.Request, challenge []byte, deviceType *device.Type, knownDeviceCredentials []*device.Credential, @@ -140,6 +150,7 @@ func newRequestOptions( options := &webauthn.PublicKeyCredentialRequestOptions{ Challenge: challenge, Timeout: ceremonyTimeout, + RPID: GetEffectiveDomain(r), } fillRequestUserVerificationRequirement( options, diff --git a/pkg/webauthnutil/options_test.go b/pkg/webauthnutil/options_test.go index 64276750f..be5a9e0d0 100644 --- a/pkg/webauthnutil/options_test.go +++ b/pkg/webauthnutil/options_test.go @@ -1,9 +1,11 @@ package webauthnutil import ( + "net/http" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/pomerium/pomerium/pkg/grpc/device" "github.com/pomerium/pomerium/pkg/grpc/user" @@ -12,14 +14,17 @@ import ( ) func TestGenerateCreationOptions(t *testing.T) { + r, err := http.NewRequest(http.MethodGet, "https://www.example.com", nil) + require.NoError(t, err) + t.Run("random challenge", func(t *testing.T) { key := []byte{1, 2, 3} - options1 := GenerateCreationOptions(key, predefinedDeviceTypes[DefaultDeviceType], &user.User{ + options1 := GenerateCreationOptions(r, key, predefinedDeviceTypes[DefaultDeviceType], &user.User{ Id: "example", Email: "test@example.com", Name: "Test User", }) - options2 := GenerateCreationOptions(key, predefinedDeviceTypes[DefaultDeviceType], &user.User{ + options2 := GenerateCreationOptions(r, key, predefinedDeviceTypes[DefaultDeviceType], &user.User{ Id: "example", Email: "test@example.com", Name: "Test User", @@ -28,7 +33,7 @@ func TestGenerateCreationOptions(t *testing.T) { }) t.Run(DefaultDeviceType, func(t *testing.T) { key := []byte{1, 2, 3} - options := GenerateCreationOptions(key, predefinedDeviceTypes[DefaultDeviceType], &user.User{ + options := GenerateCreationOptions(r, key, predefinedDeviceTypes[DefaultDeviceType], &user.User{ Id: "example", Email: "test@example.com", Name: "Test User", @@ -37,6 +42,7 @@ func TestGenerateCreationOptions(t *testing.T) { assert.Equal(t, &webauthn.PublicKeyCredentialCreationOptions{ RP: webauthn.PublicKeyCredentialRPEntity{ Name: "Pomerium", + ID: "example.com", }, User: webauthn.PublicKeyCredentialUserEntity{ ID: []byte{ @@ -63,15 +69,18 @@ func TestGenerateCreationOptions(t *testing.T) { } func TestGenerateRequestOptions(t *testing.T) { + r, err := http.NewRequest(http.MethodGet, "https://www.example.com", nil) + require.NoError(t, err) + t.Run("random challenge", func(t *testing.T) { key := []byte{1, 2, 3} - options1 := GenerateRequestOptions(key, predefinedDeviceTypes[DefaultDeviceType], nil) - options2 := GenerateRequestOptions(key, predefinedDeviceTypes[DefaultDeviceType], nil) + options1 := GenerateRequestOptions(r, key, predefinedDeviceTypes[DefaultDeviceType], nil) + options2 := GenerateRequestOptions(r, key, predefinedDeviceTypes[DefaultDeviceType], nil) assert.NotEqual(t, options1.Challenge, options2.Challenge) }) t.Run(DefaultDeviceType, func(t *testing.T) { key := []byte{1, 2, 3} - options := GenerateRequestOptions(key, predefinedDeviceTypes[DefaultDeviceType], []*device.Credential{ + options := GenerateRequestOptions(r, key, predefinedDeviceTypes[DefaultDeviceType], []*device.Credential{ {Id: "device1", Specifier: &device.Credential_Webauthn{Webauthn: &device.Credential_WebAuthn{ Id: []byte{4, 5, 6}, }}}, @@ -79,6 +88,7 @@ func TestGenerateRequestOptions(t *testing.T) { options.Challenge = nil assert.Equal(t, &webauthn.PublicKeyCredentialRequestOptions{ Timeout: 900000000000, + RPID: "example.com", AllowCredentials: []webauthn.PublicKeyCredentialDescriptor{ {Type: "public-key", ID: []byte{4, 5, 6}}, }, @@ -129,7 +139,8 @@ func TestFillPublicKeyCredentialParameters(t *testing.T) { }{ {"", 0, nil}, {"public-key", -7, &device.WebAuthnOptions_PublicKeyCredentialParameters{ - Type: device.WebAuthnOptions_PUBLIC_KEY, Alg: -7}}, + Type: device.WebAuthnOptions_PUBLIC_KEY, Alg: -7, + }}, } { params := new(webauthn.PublicKeyCredentialParameters) fillPublicKeyCredentialParameters(params, testCase.in) diff --git a/pkg/webauthnutil/webauthnutil.go b/pkg/webauthnutil/webauthnutil.go index c46b62fe5..e7546bf69 100644 --- a/pkg/webauthnutil/webauthnutil.go +++ b/pkg/webauthnutil/webauthnutil.go @@ -1,2 +1,32 @@ // Package webauthnutil contains types and functions for working with the webauthn package. package webauthnutil + +import ( + "net" + "net/http" + + "golang.org/x/net/publicsuffix" + + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/webauthn" +) + +// GetRelyingParty gets a RelyingParty for the given request and databroker client. +func GetRelyingParty(r *http.Request, client databroker.DataBrokerServiceClient) *webauthn.RelyingParty { + return webauthn.NewRelyingParty( + "https://"+GetEffectiveDomain(r), + NewCredentialStorage(client), + ) +} + +// GetEffectiveDomain returns the effective domain for an HTTP request. +func GetEffectiveDomain(r *http.Request) string { + h, _, err := net.SplitHostPort(r.Host) + if err != nil { + h = r.Host + } + if tld, err := publicsuffix.EffectiveTLDPlusOne(h); err == nil { + return tld + } + return h +} diff --git a/pkg/webauthnutil/webauthnutil_test.go b/pkg/webauthnutil/webauthnutil_test.go new file mode 100644 index 000000000..97b537e29 --- /dev/null +++ b/pkg/webauthnutil/webauthnutil_test.go @@ -0,0 +1,31 @@ +package webauthnutil + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetEffectiveDomain(t *testing.T) { + t.Parallel() + for _, tc := range []struct { + in string + expect string + }{ + {"https://www.example.com/some/path", "example.com"}, + {"https://www.example.com:8080/some/path", "example.com"}, + {"https://www.subdomain.example.com/some/path", "example.com"}, + {"https://example.com/some/path", "example.com"}, + } { + tc := tc + t.Run(tc.expect, func(t *testing.T) { + t.Parallel() + + r, err := http.NewRequest(http.MethodGet, tc.in, nil) + require.NoError(t, err) + assert.Equal(t, tc.expect, GetEffectiveDomain(r)) + }) + } +} diff --git a/proxy/data.go b/proxy/data.go new file mode 100644 index 000000000..cab09fb6d --- /dev/null +++ b/proxy/data.go @@ -0,0 +1,151 @@ +package proxy + +import ( + "context" + "net/http" + + "github.com/pomerium/csrf" + "github.com/pomerium/datasource/pkg/directory" + "github.com/pomerium/pomerium/internal/encoding/jws" + "github.com/pomerium/pomerium/internal/handlers" + "github.com/pomerium/pomerium/internal/handlers/webauthn" + "github.com/pomerium/pomerium/internal/httputil" + "github.com/pomerium/pomerium/internal/sessions" + "github.com/pomerium/pomerium/internal/urlutil" + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/grpc/user" + "github.com/pomerium/pomerium/pkg/webauthnutil" +) + +func (p *Proxy) getSession(ctx context.Context, sessionID string) (s *session.Session, isImpersonated bool, err error) { + client := p.state.Load().dataBrokerClient + + isImpersonated = false + s, err = session.Get(ctx, client, sessionID) + if s.GetImpersonateSessionId() != "" { + s, err = session.Get(ctx, client, s.GetImpersonateSessionId()) + isImpersonated = true + } + + return s, isImpersonated, err +} + +func (p *Proxy) getSessionState(r *http.Request) (sessions.State, error) { + state := p.state.Load() + + rawJWT, err := state.sessionStore.LoadSession(r) + if err != nil { + return sessions.State{}, err + } + + encoder, err := jws.NewHS256Signer(state.sharedKey) + if err != nil { + return sessions.State{}, err + } + + var sessionState sessions.State + if err := encoder.Unmarshal([]byte(rawJWT), &sessionState); err != nil { + return sessions.State{}, httputil.NewError(http.StatusBadRequest, err) + } + + return sessionState, nil +} + +func (p *Proxy) getUser(ctx context.Context, userID string) (*user.User, error) { + client := p.state.Load().dataBrokerClient + return user.Get(ctx, client, userID) +} + +func (p *Proxy) getUserInfoData(r *http.Request) (handlers.UserInfoData, error) { + options := p.currentOptions.Load() + state := p.state.Load() + + data := handlers.UserInfoData{ + CSRFToken: csrf.Token(r), + BrandingOptions: options.BrandingOptions, + } + + ss, err := p.getSessionState(r) + if err != nil { + return handlers.UserInfoData{}, err + } + + data.Session, data.IsImpersonated, err = p.getSession(r.Context(), ss.ID) + if err != nil { + data.Session = &session.Session{Id: ss.ID} + } + + data.User, err = p.getUser(r.Context(), data.Session.GetUserId()) + if err != nil { + data.User = &user.User{Id: data.Session.GetUserId()} + } + + data.WebAuthnCreationOptions, data.WebAuthnRequestOptions, _ = p.webauthn.GetOptions(r) + data.WebAuthnURL = urlutil.WebAuthnURL(r, urlutil.GetAbsoluteURL(r), state.sharedKey, r.URL.Query()) + p.fillEnterpriseUserInfoData(r.Context(), &data) + return data, nil +} + +func (p *Proxy) fillEnterpriseUserInfoData(ctx context.Context, data *handlers.UserInfoData) { + client := p.state.Load().dataBrokerClient + + res, _ := client.Get(ctx, &databroker.GetRequest{Type: "type.googleapis.com/pomerium.config.Config", Id: "dashboard"}) + data.IsEnterprise = res.GetRecord() != nil + if !data.IsEnterprise { + return + } + + data.DirectoryUser, _ = databroker.GetViaJSON[directory.User](ctx, client, directory.UserRecordType, data.Session.GetUserId()) + if data.DirectoryUser != nil { + for _, groupID := range data.DirectoryUser.GroupIDs { + directoryGroup, _ := databroker.GetViaJSON[directory.Group](ctx, client, directory.GroupRecordType, groupID) + if directoryGroup != nil { + data.DirectoryGroups = append(data.DirectoryGroups, directoryGroup) + } + } + } +} + +func (p *Proxy) getWebauthnState(r *http.Request) (*webauthn.State, error) { + options := p.currentOptions.Load() + state := p.state.Load() + + ss, err := p.getSessionState(r) + if err != nil { + return nil, err + } + + s, _, err := p.getSession(r.Context(), ss.ID) + if err != nil { + return nil, err + } + + authenticateURL, err := options.GetAuthenticateURL() + if err != nil { + return nil, err + } + + internalAuthenticateURL, err := options.GetInternalAuthenticateURL() + if err != nil { + return nil, err + } + + pomeriumDomains, err := options.GetAllRouteableHTTPDomains() + if err != nil { + return nil, err + } + + return &webauthn.State{ + AuthenticateURL: authenticateURL, + InternalAuthenticateURL: internalAuthenticateURL, + SharedKey: state.sharedKey, + Client: state.dataBrokerClient, + PomeriumDomains: pomeriumDomains, + Session: s, + SessionState: &ss, + SessionStore: state.sessionStore, + RelyingParty: webauthnutil.GetRelyingParty(r, state.dataBrokerClient), + BrandingOptions: options.BrandingOptions, + }, nil +} diff --git a/proxy/handlers.go b/proxy/handlers.go index e22980887..5c8ef42b2 100644 --- a/proxy/handlers.go +++ b/proxy/handlers.go @@ -10,6 +10,7 @@ import ( "github.com/gorilla/mux" + "github.com/pomerium/pomerium/internal/handlers" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/middleware" "github.com/pomerium/pomerium/internal/urlutil" @@ -22,9 +23,11 @@ func (p *Proxy) registerDashboardHandlers(r *mux.Router) *mux.Router { h.Use(middleware.SetHeaders(httputil.HeadersContentSecurityPolicy)) // special pomerium endpoints for users to view their session - h.Path("/").HandlerFunc(p.userInfo).Methods(http.MethodGet) - h.Path("/sign_out").Handler(httputil.HandlerFunc(p.SignOut)).Methods(http.MethodGet, http.MethodPost) + h.Path("/").Handler(httputil.HandlerFunc(p.userInfo)).Methods(http.MethodGet) + h.Path("/device-enrolled").Handler(httputil.HandlerFunc(p.deviceEnrolled)) h.Path("/jwt").Handler(httputil.HandlerFunc(p.jwtAssertion)).Methods(http.MethodGet) + h.Path("/sign_out").Handler(httputil.HandlerFunc(p.SignOut)).Methods(http.MethodGet, http.MethodPost) + h.Path("/webauthn").Handler(p.webauthn) // called following authenticate auth flow to grab a new or existing session // the route specific cookie is returned in a signed query params @@ -81,21 +84,22 @@ func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) error { return nil } -func (p *Proxy) userInfo(w http.ResponseWriter, r *http.Request) { - state := p.state.Load() - - redirectURL := urlutil.GetAbsoluteURL(r).String() - if ref := r.Header.Get(httputil.HeaderReferrer); ref != "" { - redirectURL = ref +func (p *Proxy) userInfo(w http.ResponseWriter, r *http.Request) error { + data, err := p.getUserInfoData(r) + if err != nil { + return err } + handlers.UserInfo(data).ServeHTTP(w, r) + return nil +} - uri := state.authenticateDashboardURL.ResolveReference(&url.URL{ - RawQuery: url.Values{ - urlutil.QueryRedirectURI: {redirectURL}, - }.Encode(), - }) - uri = urlutil.NewSignedURL(state.sharedKey, uri).Sign() - httputil.Redirect(w, r, uri.String(), http.StatusFound) +func (p *Proxy) deviceEnrolled(w http.ResponseWriter, r *http.Request) error { + data, err := p.getUserInfoData(r) + if err != nil { + return err + } + handlers.DeviceEnrolled(data).ServeHTTP(w, r) + return nil } // Callback handles the result of a successful call to the authenticate service diff --git a/proxy/handlers_test.go b/proxy/handlers_test.go index 3e6b1e517..d1b01049f 100644 --- a/proxy/handlers_test.go +++ b/proxy/handlers_test.go @@ -64,29 +64,6 @@ func TestProxy_Signout(t *testing.T) { } } -func TestProxy_userInfo(t *testing.T) { - opts := testOptions(t) - err := ValidateOptions(opts) - if err != nil { - t.Fatal(err) - } - proxy, err := New(&config.Config{Options: opts}) - if err != nil { - t.Fatal(err) - } - req := httptest.NewRequest(http.MethodGet, "/.pomerium/sign_out", nil) - rr := httptest.NewRecorder() - proxy.userInfo(rr, req) - if status := rr.Code; status != http.StatusFound { - t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusFound) - } - body := rr.Body.String() - want := proxy.state.Load().authenticateURL.String() - if !strings.Contains(body, want) { - t.Errorf("handler returned unexpected body: got %v want %s ", body, want) - } -} - func TestProxy_SignOut(t *testing.T) { t.Parallel() tests := []struct { diff --git a/proxy/proxy.go b/proxy/proxy.go index 85f382be4..10005ba51 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -13,6 +13,7 @@ import ( "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/atomicutil" + "github.com/pomerium/pomerium/internal/handlers/webauthn" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/telemetry/metrics" @@ -54,6 +55,7 @@ type Proxy struct { state *atomicutil.Value[*proxyState] currentOptions *atomicutil.Value[*config.Options] currentRouter *atomicutil.Value[*mux.Router] + webauthn *webauthn.Handler } // New takes a Proxy service from options and a validation function. @@ -69,6 +71,7 @@ func New(cfg *config.Config) (*Proxy, error) { currentOptions: config.NewAtomicOptions(), currentRouter: atomicutil.NewValue(httputil.NewRouter()), } + p.webauthn = webauthn.New(p.getWebauthnState) metrics.AddPolicyCountCallback("pomerium-proxy", func() int64 { return int64(len(p.currentOptions.Load().GetAllPolicies())) diff --git a/proxy/state.go b/proxy/state.go index b4ed3cdfd..73893722b 100644 --- a/proxy/state.go +++ b/proxy/state.go @@ -1,6 +1,7 @@ package proxy import ( + "context" "crypto/cipher" "net/url" @@ -10,8 +11,12 @@ import ( "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/sessions/cookie" "github.com/pomerium/pomerium/pkg/cryptutil" + "github.com/pomerium/pomerium/pkg/grpc" + "github.com/pomerium/pomerium/pkg/grpc/databroker" ) +var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn) + type proxyState struct { sharedKey []byte sharedCipher cipher.AEAD @@ -26,6 +31,8 @@ type proxyState struct { sessionStore sessions.SessionStore jwtClaimHeaders config.JWTClaimHeaders + dataBrokerClient databroker.DataBrokerServiceClient + programmaticRedirectDomainWhitelist []string } @@ -36,6 +43,7 @@ func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) { } state := new(proxyState) + state.sharedKey, err = cfg.Options.GetSharedKey() if err != nil { return nil, err @@ -81,6 +89,19 @@ func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) { if err != nil { return nil, err } + + dataBrokerConn, err := outboundGRPCConnection.Get(context.Background(), &grpc.OutboundOptions{ + OutboundPort: cfg.OutboundPort, + InstallationID: cfg.Options.InstallationID, + ServiceName: cfg.Options.Services, + SignedJWTKey: state.sharedKey, + }) + if err != nil { + return nil, err + } + + state.dataBrokerClient = databroker.NewDataBrokerServiceClient(dataBrokerConn) + state.programmaticRedirectDomainWhitelist = cfg.Options.ProgrammaticRedirectDomainWhitelist return state, nil diff --git a/ui/src/components/WebAuthnAuthenticateButton.tsx b/ui/src/components/WebAuthnAuthenticateButton.tsx index fb9b547d9..bf4b1ed40 100644 --- a/ui/src/components/WebAuthnAuthenticateButton.tsx +++ b/ui/src/components/WebAuthnAuthenticateButton.tsx @@ -28,6 +28,7 @@ async function authenticateCredential( challenge: decode(requestOptions?.challenge), timeout: requestOptions?.timeout, userVerification: requestOptions?.userVerification, + rpId: requestOptions?.rpId, }, }); return credential as CredentialForAuthenticate; diff --git a/ui/src/components/WebAuthnRegisterButton.tsx b/ui/src/components/WebAuthnRegisterButton.tsx index 05f25dd2d..44a1bfb29 100644 --- a/ui/src/components/WebAuthnRegisterButton.tsx +++ b/ui/src/components/WebAuthnRegisterButton.tsx @@ -39,6 +39,7 @@ async function createCredential( })), rp: { name: creationOptions?.rp?.name, + id: creationOptions?.rp?.id, }, timeout: creationOptions?.timeout, user: { diff --git a/ui/src/types/index.ts b/ui/src/types/index.ts index d060645cb..da54433ea 100644 --- a/ui/src/types/index.ts +++ b/ui/src/types/index.ts @@ -58,6 +58,7 @@ export type WebAuthnCreationOptions = { pubKeyCredParams: PublicKeyCredentialParameters[]; rp: { name: string; + id: string; }; timeout: number; user: { @@ -75,6 +76,7 @@ export type WebAuthnRequestOptions = { challenge: string; timeout: number; userVerification: UserVerificationRequirement; + rpId: string; }; // page data