mirror of
https://github.com/pomerium/pomerium.git
synced 2025-07-27 21:49:12 +02:00
proxy: add userinfo and webauthn endpoints
This commit is contained in:
parent
6b5096b0fe
commit
5d64e158c7
26 changed files with 404 additions and 167 deletions
|
@ -7,9 +7,9 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/pomerium/pomerium/authenticate/handlers/webauthn"
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||
"github.com/pomerium/pomerium/internal/handlers/webauthn"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
)
|
||||
|
|
|
@ -18,8 +18,8 @@ import (
|
|||
|
||||
"github.com/pomerium/csrf"
|
||||
"github.com/pomerium/datasource/pkg/directory"
|
||||
"github.com/pomerium/pomerium/authenticate/handlers"
|
||||
"github.com/pomerium/pomerium/authenticate/handlers/webauthn"
|
||||
"github.com/pomerium/pomerium/internal/handlers"
|
||||
"github.com/pomerium/pomerium/internal/handlers/webauthn"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/identity"
|
||||
"github.com/pomerium/pomerium/internal/identity/manager"
|
||||
|
@ -33,6 +33,7 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
"github.com/pomerium/pomerium/pkg/webauthnutil"
|
||||
)
|
||||
|
||||
// Handler returns the authenticate service's handler chain.
|
||||
|
@ -544,7 +545,7 @@ func (a *Authenticate) getUserInfoData(r *http.Request) (handlers.UserInfoData,
|
|||
Id: pbSession.GetUserId(),
|
||||
}
|
||||
}
|
||||
creationOptions, requestOptions, _ := a.webauthn.GetOptions(r.Context())
|
||||
creationOptions, requestOptions, _ := a.webauthn.GetOptions(r)
|
||||
|
||||
data := handlers.UserInfoData{
|
||||
CSRFToken: csrf.Token(r),
|
||||
|
@ -715,15 +716,15 @@ func (a *Authenticate) getUser(ctx context.Context, userID string) (*user.User,
|
|||
return user.Get(ctx, client, userID)
|
||||
}
|
||||
|
||||
func (a *Authenticate) getWebauthnState(ctx context.Context) (*webauthn.State, error) {
|
||||
func (a *Authenticate) getWebauthnState(r *http.Request) (*webauthn.State, error) {
|
||||
state := a.state.Load()
|
||||
|
||||
s, _, err := a.getCurrentSession(ctx)
|
||||
s, _, err := a.getCurrentSession(r.Context())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ss, err := a.getSessionFromCtx(ctx)
|
||||
ss, err := a.getSessionFromCtx(r.Context())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -752,7 +753,7 @@ func (a *Authenticate) getWebauthnState(ctx context.Context) (*webauthn.State, e
|
|||
Session: s,
|
||||
SessionState: ss,
|
||||
SessionStore: state.sessionStore,
|
||||
RelyingParty: state.webauthnRelyingParty,
|
||||
RelyingParty: webauthnutil.GetRelyingParty(r, state.dataBrokerClient),
|
||||
BrandingOptions: a.options.Load().BrandingOptions,
|
||||
}, nil
|
||||
}
|
||||
|
|
|
@ -1,2 +0,0 @@
|
|||
// Package handlers contains various web handlers for the authenticate service.
|
||||
package handlers
|
|
@ -21,12 +21,12 @@ import (
|
|||
"google.golang.org/grpc"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/authenticate/handlers/webauthn"
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/encoding/jws"
|
||||
"github.com/pomerium/pomerium/internal/encoding/mock"
|
||||
"github.com/pomerium/pomerium/internal/handlers/webauthn"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/identity"
|
||||
"github.com/pomerium/pomerium/internal/identity/oidc"
|
||||
|
|
|
@ -18,8 +18,6 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/webauthnutil"
|
||||
"github.com/pomerium/webauthn"
|
||||
)
|
||||
|
||||
var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn)
|
||||
|
@ -46,8 +44,6 @@ type authenticateState struct {
|
|||
jwk *jose.JSONWebKeySet
|
||||
|
||||
dataBrokerClient databroker.DataBrokerServiceClient
|
||||
|
||||
webauthnRelyingParty *webauthn.RelyingParty
|
||||
}
|
||||
|
||||
func newAuthenticateState() *authenticateState {
|
||||
|
@ -153,10 +149,5 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err
|
|||
|
||||
state.dataBrokerClient = databroker.NewDataBrokerServiceClient(dataBrokerConn)
|
||||
|
||||
state.webauthnRelyingParty = webauthn.NewRelyingParty(
|
||||
authenticateURL.String(),
|
||||
webauthnutil.NewCredentialStorage(state.dataBrokerClient),
|
||||
)
|
||||
|
||||
return state, nil
|
||||
}
|
||||
|
|
2
go.sum
2
go.sum
|
@ -1085,7 +1085,6 @@ golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5y
|
|||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw=
|
||||
golang.org/x/crypto v0.2.0 h1:BRXPfhNivWL5Yq0BGQ39a2sW6t44aODpfxkWjYdzewE=
|
||||
golang.org/x/crypto v0.2.0/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
|
@ -1192,7 +1191,6 @@ golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug
|
|||
golang.org/x/net v0.0.0-20220630215102-69896b714898/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
|
||||
golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco=
|
||||
golang.org/x/net v0.2.0 h1:sZfSu1wtKLGlWI4ZZayP0ck9Y73K1ynO6gqzTdBVdPU=
|
||||
golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
|
|
|
@ -7,10 +7,11 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/CAFxX/httpcompression"
|
||||
"github.com/gorilla/handlers"
|
||||
gorillahandlers "github.com/gorilla/handlers"
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/handlers"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/telemetry"
|
||||
|
@ -37,7 +38,7 @@ func (srv *Server) addHTTPMiddleware(root *mux.Router, cfg *config.Config) {
|
|||
Str("path", r.URL.String()).
|
||||
Msg("http-request")
|
||||
}))
|
||||
root.Use(handlers.RecoveryHandler())
|
||||
root.Use(gorillahandlers.RecoveryHandler())
|
||||
root.Use(log.HeadersHandler(httputil.HeadersXForwarded))
|
||||
root.Use(log.RemoteAddrHandler("ip"))
|
||||
root.Use(log.UserAgentHandler("user_agent"))
|
||||
|
@ -59,10 +60,10 @@ func (srv *Server) mountCommonEndpoints(root *mux.Router, cfg *config.Config) er
|
|||
return fmt.Errorf("invalid signing key: %w", err)
|
||||
}
|
||||
|
||||
root.HandleFunc("/healthz", httputil.HealthCheck)
|
||||
root.HandleFunc("/ping", httputil.HealthCheck)
|
||||
root.Handle("/.well-known/pomerium", httputil.WellKnownPomeriumHandler(authenticateURL))
|
||||
root.Handle("/.well-known/pomerium/", httputil.WellKnownPomeriumHandler(authenticateURL))
|
||||
root.Path("/.well-known/pomerium/jwks.json").Methods(http.MethodGet).Handler(httputil.JWKSHandler(rawSigningKey))
|
||||
root.HandleFunc("/healthz", handlers.HealthCheck)
|
||||
root.HandleFunc("/ping", handlers.HealthCheck)
|
||||
root.Handle("/.well-known/pomerium", handlers.WellKnownPomerium(authenticateURL))
|
||||
root.Handle("/.well-known/pomerium/", handlers.WellKnownPomerium(authenticateURL))
|
||||
root.Path("/.well-known/pomerium/jwks.json").Methods(http.MethodGet).Handler(handlers.JWKSHandler(rawSigningKey))
|
||||
return nil
|
||||
}
|
||||
|
|
2
internal/handlers/handlers.go
Normal file
2
internal/handlers/handlers.go
Normal file
|
@ -0,0 +1,2 @@
|
|||
// Package handlers contains HTTP handlers used by Pomerium.
|
||||
package handlers
|
20
internal/handlers/health_check.go
Normal file
20
internal/handlers/health_check.go
Normal file
|
@ -0,0 +1,20 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// HealthCheck is a simple healthcheck handler that responds to GET and HEAD
|
||||
// http requests.
|
||||
func HealthCheck(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet && r.Method != http.MethodHead {
|
||||
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if r.Method == http.MethodGet {
|
||||
fmt.Fprintln(w, http.StatusText(http.StatusOK))
|
||||
}
|
||||
}
|
36
internal/handlers/health_check_test.go
Normal file
36
internal/handlers/health_check_test.go
Normal file
|
@ -0,0 +1,36 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHealthCheck(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
|
||||
wantStatus int
|
||||
}{
|
||||
{"good - Get", http.MethodGet, http.StatusOK},
|
||||
{"good - Head", http.MethodHead, http.StatusOK},
|
||||
{"bad - Options", http.MethodOptions, http.StatusMethodNotAllowed},
|
||||
{"bad - Put", http.MethodPut, http.StatusMethodNotAllowed},
|
||||
{"bad - Post", http.MethodPost, http.StatusMethodNotAllowed},
|
||||
{"bad - route miss", http.MethodGet, http.StatusOK},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := httptest.NewRequest(tt.method, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
HealthCheck(w, r)
|
||||
if w.Code != tt.wantStatus {
|
||||
t.Errorf("code differs. got %d want %d body: %s", w.Code, tt.wantStatus, w.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
33
internal/handlers/jwks.go
Normal file
33
internal/handlers/jwks.go
Normal file
|
@ -0,0 +1,33 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"github.com/rs/cors"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
)
|
||||
|
||||
// JWKSHandler returns the /.well-known/pomerium/jwks.json handler.
|
||||
func JWKSHandler(rawSigningKey string) http.Handler {
|
||||
return cors.AllowAll().Handler(httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
var jwks jose.JSONWebKeySet
|
||||
if rawSigningKey != "" {
|
||||
decodedCert, err := base64.StdEncoding.DecodeString(rawSigningKey)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusInternalServerError, errors.New("bad signing key"))
|
||||
}
|
||||
jwk, err := cryptutil.PublicJWKFromBytes(decodedCert)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusInternalServerError, errors.New("bad signing key"))
|
||||
}
|
||||
jwks.Keys = append(jwks.Keys, *jwk)
|
||||
}
|
||||
httputil.RenderJSON(w, http.StatusOK, jwks)
|
||||
return nil
|
||||
}))
|
||||
}
|
22
internal/handlers/jwks_test.go
Normal file
22
internal/handlers/jwks_test.go
Normal file
|
@ -0,0 +1,22 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestJWKSHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("cors", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||
r.Header.Set("Origin", "https://www.example.com")
|
||||
r.Header.Set("Access-Control-Request-Method", "GET")
|
||||
JWKSHandler("").ServeHTTP(w, r)
|
||||
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
|
||||
})
|
||||
}
|
|
@ -59,7 +59,7 @@ type State struct {
|
|||
}
|
||||
|
||||
// A StateProvider provides state for the handler.
|
||||
type StateProvider = func(context.Context) (*State, error)
|
||||
type StateProvider = func(*http.Request) (*State, error)
|
||||
|
||||
// Handler is the WebAuthn device handler.
|
||||
type Handler struct {
|
||||
|
@ -74,17 +74,17 @@ func New(getState StateProvider) *Handler {
|
|||
}
|
||||
|
||||
// GetOptions returns the creation and request options for WebAuthn.
|
||||
func (h *Handler) GetOptions(ctx context.Context) (
|
||||
func (h *Handler) GetOptions(r *http.Request) (
|
||||
creationOptions *webauthn.PublicKeyCredentialCreationOptions,
|
||||
requestOptions *webauthn.PublicKeyCredentialRequestOptions,
|
||||
err error,
|
||||
) {
|
||||
state, err := h.getState(ctx)
|
||||
state, err := h.getState(r)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return h.getOptions(ctx, state, webauthnutil.DefaultDeviceType)
|
||||
return h.getOptions(r.Context(), state, webauthnutil.DefaultDeviceType)
|
||||
}
|
||||
|
||||
// ServeHTTP serves the HTTP handler.
|
||||
|
@ -118,7 +118,7 @@ func (h *Handler) getOptions(ctx context.Context, state *State, deviceTypeParam
|
|||
}
|
||||
|
||||
func (h *Handler) handle(w http.ResponseWriter, r *http.Request) error {
|
||||
s, err := h.getState(r.Context())
|
||||
s, err := h.getState(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
29
internal/handlers/well_known_pomerium.go
Normal file
29
internal/handlers/well_known_pomerium.go
Normal file
|
@ -0,0 +1,29 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/rs/cors"
|
||||
|
||||
"github.com/pomerium/csrf"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
)
|
||||
|
||||
// WellKnownPomerium returns the /.well-known/pomerium handler.
|
||||
func WellKnownPomerium(authenticateURL *url.URL) http.Handler {
|
||||
return cors.AllowAll().Handler(httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
wellKnownURLs := struct {
|
||||
OAuth2Callback string `json:"authentication_callback_endpoint"` // RFC6749
|
||||
JSONWebKeySetURL string `json:"jwks_uri"` // RFC7517
|
||||
FrontchannelLogoutURI string `json:"frontchannel_logout_uri"` // https://openid.net/specs/openid-connect-frontchannel-1_0.html
|
||||
}{
|
||||
authenticateURL.ResolveReference(&url.URL{Path: "/oauth2/callback"}).String(),
|
||||
authenticateURL.ResolveReference(&url.URL{Path: "/.well-known/pomerium/jwks.json"}).String(),
|
||||
authenticateURL.ResolveReference(&url.URL{Path: "/.pomerium/sign_out"}).String(),
|
||||
}
|
||||
w.Header().Set("X-CSRF-Token", csrf.Token(r))
|
||||
httputil.RenderJSON(w, http.StatusOK, wellKnownURLs)
|
||||
return nil
|
||||
}))
|
||||
}
|
24
internal/handlers/well_known_pomerium_test.go
Normal file
24
internal/handlers/well_known_pomerium_test.go
Normal file
|
@ -0,0 +1,24 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestWellKnownPomeriumHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("cors", func(t *testing.T) {
|
||||
authenticateURL, _ := url.Parse("https://authenticate.example.com")
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||
r.Header.Set("Origin", authenticateURL.String())
|
||||
r.Header.Set("Access-Control-Request-Method", "GET")
|
||||
WellKnownPomerium(authenticateURL).ServeHTTP(w, r)
|
||||
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
|
||||
})
|
||||
}
|
|
@ -2,34 +2,12 @@ package httputil
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"github.com/rs/cors"
|
||||
|
||||
"github.com/pomerium/csrf"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
)
|
||||
|
||||
// HealthCheck is a simple healthcheck handler that responds to GET and HEAD
|
||||
// http requests.
|
||||
func HealthCheck(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet && r.Method != http.MethodHead {
|
||||
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if r.Method == http.MethodGet {
|
||||
fmt.Fprintln(w, http.StatusText(http.StatusOK))
|
||||
}
|
||||
}
|
||||
|
||||
// Redirect wraps the std libs's redirect method indicating that pomerium is
|
||||
// the origin of the response.
|
||||
func Redirect(w http.ResponseWriter, r *http.Request, url string, code int) {
|
||||
|
@ -72,41 +50,3 @@ func (f HandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
e.ErrorResponse(r.Context(), w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// JWKSHandler returns the /.well-known/pomerium/jwks.json handler.
|
||||
func JWKSHandler(rawSigningKey string) http.Handler {
|
||||
return cors.AllowAll().Handler(HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
var jwks jose.JSONWebKeySet
|
||||
if rawSigningKey != "" {
|
||||
decodedCert, err := base64.StdEncoding.DecodeString(rawSigningKey)
|
||||
if err != nil {
|
||||
return NewError(http.StatusInternalServerError, errors.New("bad signing key"))
|
||||
}
|
||||
jwk, err := cryptutil.PublicJWKFromBytes(decodedCert)
|
||||
if err != nil {
|
||||
return NewError(http.StatusInternalServerError, errors.New("bad signing key"))
|
||||
}
|
||||
jwks.Keys = append(jwks.Keys, *jwk)
|
||||
}
|
||||
RenderJSON(w, http.StatusOK, jwks)
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
// WellKnownPomeriumHandler returns the /.well-known/pomerium handler.
|
||||
func WellKnownPomeriumHandler(authenticateURL *url.URL) http.Handler {
|
||||
return cors.AllowAll().Handler(HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
wellKnownURLs := struct {
|
||||
OAuth2Callback string `json:"authentication_callback_endpoint"` // RFC6749
|
||||
JSONWebKeySetURL string `json:"jwks_uri"` // RFC7517
|
||||
FrontchannelLogoutURI string `json:"frontchannel_logout_uri"` // https://openid.net/specs/openid-connect-frontchannel-1_0.html
|
||||
}{
|
||||
authenticateURL.ResolveReference(&url.URL{Path: "/oauth2/callback"}).String(),
|
||||
authenticateURL.ResolveReference(&url.URL{Path: "/.well-known/pomerium/jwks.json"}).String(),
|
||||
authenticateURL.ResolveReference(&url.URL{Path: "/.pomerium/sign_out"}).String(),
|
||||
}
|
||||
w.Header().Set("X-CSRF-Token", csrf.Token(r))
|
||||
RenderJSON(w, http.StatusOK, wellKnownURLs)
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
|
|
@ -5,42 +5,11 @@ import (
|
|||
"math"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestHealthCheck(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
|
||||
wantStatus int
|
||||
}{
|
||||
{"good - Get", http.MethodGet, http.StatusOK},
|
||||
{"good - Head", http.MethodHead, http.StatusOK},
|
||||
{"bad - Options", http.MethodOptions, http.StatusMethodNotAllowed},
|
||||
{"bad - Put", http.MethodPut, http.StatusMethodNotAllowed},
|
||||
{"bad - Post", http.MethodPost, http.StatusMethodNotAllowed},
|
||||
{"bad - route miss", http.MethodGet, http.StatusOK},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := httptest.NewRequest(tt.method, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
HealthCheck(w, r)
|
||||
if w.Code != tt.wantStatus {
|
||||
t.Errorf("code differs. got %d want %d body: %s", w.Code, tt.wantStatus, w.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedirect(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
|
@ -150,30 +119,3 @@ func TestRenderJSON(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWKSHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("cors", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||
r.Header.Set("Origin", "https://www.example.com")
|
||||
r.Header.Set("Access-Control-Request-Method", "GET")
|
||||
JWKSHandler("").ServeHTTP(w, r)
|
||||
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
|
||||
})
|
||||
}
|
||||
|
||||
func TestWellKnownPomeriumHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("cors", func(t *testing.T) {
|
||||
authenticateURL, _ := url.Parse("https://authenticate.example.com")
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||
r.Header.Set("Origin", authenticateURL.String())
|
||||
r.Header.Set("Access-Control-Request-Method", "GET")
|
||||
WellKnownPomeriumHandler(authenticateURL).ServeHTTP(w, r)
|
||||
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -92,6 +92,14 @@ func GetAbsoluteURL(r *http.Request) *url.URL {
|
|||
return u
|
||||
}
|
||||
|
||||
// GetOrigin gets the Origin header for a request, or builds the origin based on the request host.
|
||||
func GetOrigin(r *http.Request) string {
|
||||
if v := r.Header.Get("Origin"); v != "" {
|
||||
return v
|
||||
}
|
||||
return "https://" + r.Host
|
||||
}
|
||||
|
||||
// GetDomainsForURL returns the available domains for given url.
|
||||
//
|
||||
// For standard HTTP (80)/HTTPS (443) ports, it returns `example.com` and `example.com:<port>`.
|
||||
|
|
|
@ -1,2 +1,15 @@
|
|||
// Package webauthnutil contains types and functions for working with the webauthn package.
|
||||
package webauthnutil
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/webauthn"
|
||||
)
|
||||
|
||||
// GetRelyingParty gets a RelyingParty for the given request and databroker client.
|
||||
func GetRelyingParty(r *http.Request, client databroker.DataBrokerServiceClient) *webauthn.RelyingParty {
|
||||
return webauthn.NewRelyingParty(urlutil.GetOrigin(r), NewCredentialStorage(client))
|
||||
}
|
||||
|
|
151
proxy/data.go
Normal file
151
proxy/data.go
Normal file
|
@ -0,0 +1,151 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/pomerium/csrf"
|
||||
"github.com/pomerium/datasource/pkg/directory"
|
||||
"github.com/pomerium/pomerium/internal/encoding/jws"
|
||||
"github.com/pomerium/pomerium/internal/handlers"
|
||||
"github.com/pomerium/pomerium/internal/handlers/webauthn"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
"github.com/pomerium/pomerium/pkg/webauthnutil"
|
||||
)
|
||||
|
||||
func (p *Proxy) getSession(ctx context.Context, sessionID string) (s *session.Session, isImpersonated bool, err error) {
|
||||
client := p.state.Load().dataBrokerClient
|
||||
|
||||
isImpersonated = false
|
||||
s, err = session.Get(ctx, client, sessionID)
|
||||
if s.GetImpersonateSessionId() != "" {
|
||||
s, err = session.Get(ctx, client, s.GetImpersonateSessionId())
|
||||
isImpersonated = true
|
||||
}
|
||||
|
||||
return s, isImpersonated, err
|
||||
}
|
||||
|
||||
func (p *Proxy) getSessionState(r *http.Request) (sessions.State, error) {
|
||||
state := p.state.Load()
|
||||
|
||||
rawJWT, err := state.sessionStore.LoadSession(r)
|
||||
if err != nil {
|
||||
return sessions.State{}, err
|
||||
}
|
||||
|
||||
encoder, err := jws.NewHS256Signer(state.sharedKey)
|
||||
if err != nil {
|
||||
return sessions.State{}, err
|
||||
}
|
||||
|
||||
var sessionState sessions.State
|
||||
if err := encoder.Unmarshal([]byte(rawJWT), &sessionState); err != nil {
|
||||
return sessions.State{}, httputil.NewError(http.StatusBadRequest, err)
|
||||
}
|
||||
|
||||
return sessionState, nil
|
||||
}
|
||||
|
||||
func (p *Proxy) getUser(ctx context.Context, userID string) (*user.User, error) {
|
||||
client := p.state.Load().dataBrokerClient
|
||||
return user.Get(ctx, client, userID)
|
||||
}
|
||||
|
||||
func (p *Proxy) getUserInfoData(r *http.Request) (handlers.UserInfoData, error) {
|
||||
options := p.currentOptions.Load()
|
||||
state := p.state.Load()
|
||||
|
||||
data := handlers.UserInfoData{
|
||||
CSRFToken: csrf.Token(r),
|
||||
BrandingOptions: options.BrandingOptions,
|
||||
}
|
||||
|
||||
ss, err := p.getSessionState(r)
|
||||
if err != nil {
|
||||
return handlers.UserInfoData{}, err
|
||||
}
|
||||
|
||||
data.Session, data.IsImpersonated, err = p.getSession(r.Context(), ss.ID)
|
||||
if err != nil {
|
||||
data.Session = &session.Session{Id: ss.ID}
|
||||
}
|
||||
|
||||
data.User, err = p.getUser(r.Context(), data.Session.GetUserId())
|
||||
if err != nil {
|
||||
data.User = &user.User{Id: data.Session.GetUserId()}
|
||||
}
|
||||
|
||||
data.WebAuthnCreationOptions, data.WebAuthnRequestOptions, _ = p.webauthn.GetOptions(r)
|
||||
data.WebAuthnURL = urlutil.WebAuthnURL(r, urlutil.GetAbsoluteURL(r), state.sharedKey, r.URL.Query())
|
||||
p.fillEnterpriseUserInfoData(r.Context(), &data)
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (p *Proxy) fillEnterpriseUserInfoData(ctx context.Context, data *handlers.UserInfoData) {
|
||||
client := p.state.Load().dataBrokerClient
|
||||
|
||||
res, _ := client.Get(ctx, &databroker.GetRequest{Type: "type.googleapis.com/pomerium.config.Config", Id: "dashboard"})
|
||||
data.IsEnterprise = res.GetRecord() != nil
|
||||
if !data.IsEnterprise {
|
||||
return
|
||||
}
|
||||
|
||||
data.DirectoryUser, _ = databroker.GetViaJSON[directory.User](ctx, client, directory.UserRecordType, data.Session.GetUserId())
|
||||
if data.DirectoryUser != nil {
|
||||
for _, groupID := range data.DirectoryUser.GroupIDs {
|
||||
directoryGroup, _ := databroker.GetViaJSON[directory.Group](ctx, client, directory.GroupRecordType, groupID)
|
||||
if directoryGroup != nil {
|
||||
data.DirectoryGroups = append(data.DirectoryGroups, directoryGroup)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Proxy) getWebauthnState(r *http.Request) (*webauthn.State, error) {
|
||||
options := p.currentOptions.Load()
|
||||
state := p.state.Load()
|
||||
|
||||
ss, err := p.getSessionState(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s, _, err := p.getSession(r.Context(), ss.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
authenticateURL, err := options.GetAuthenticateURL()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
internalAuthenticateURL, err := options.GetInternalAuthenticateURL()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pomeriumDomains, err := options.GetAllRouteableHTTPDomains()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &webauthn.State{
|
||||
AuthenticateURL: authenticateURL,
|
||||
InternalAuthenticateURL: internalAuthenticateURL,
|
||||
SharedKey: state.sharedKey,
|
||||
Client: state.dataBrokerClient,
|
||||
PomeriumDomains: pomeriumDomains,
|
||||
Session: s,
|
||||
SessionState: &ss,
|
||||
SessionStore: state.sessionStore,
|
||||
RelyingParty: webauthnutil.GetRelyingParty(r, state.dataBrokerClient),
|
||||
BrandingOptions: options.BrandingOptions,
|
||||
}, nil
|
||||
}
|
|
@ -10,6 +10,7 @@ import (
|
|||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/handlers"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/middleware"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
|
@ -22,9 +23,11 @@ func (p *Proxy) registerDashboardHandlers(r *mux.Router) *mux.Router {
|
|||
h.Use(middleware.SetHeaders(httputil.HeadersContentSecurityPolicy))
|
||||
|
||||
// special pomerium endpoints for users to view their session
|
||||
h.Path("/").HandlerFunc(p.userInfo).Methods(http.MethodGet)
|
||||
h.Path("/sign_out").Handler(httputil.HandlerFunc(p.SignOut)).Methods(http.MethodGet, http.MethodPost)
|
||||
h.Path("/").Handler(httputil.HandlerFunc(p.userInfo)).Methods(http.MethodGet)
|
||||
h.Path("/device-enrolled").Handler(httputil.HandlerFunc(p.deviceEnrolled))
|
||||
h.Path("/jwt").Handler(httputil.HandlerFunc(p.jwtAssertion)).Methods(http.MethodGet)
|
||||
h.Path("/sign_out").Handler(httputil.HandlerFunc(p.SignOut)).Methods(http.MethodGet, http.MethodPost)
|
||||
h.Path("/webauthn").Handler(p.webauthn)
|
||||
|
||||
// called following authenticate auth flow to grab a new or existing session
|
||||
// the route specific cookie is returned in a signed query params
|
||||
|
@ -81,21 +84,22 @@ func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *Proxy) userInfo(w http.ResponseWriter, r *http.Request) {
|
||||
state := p.state.Load()
|
||||
|
||||
redirectURL := urlutil.GetAbsoluteURL(r).String()
|
||||
if ref := r.Header.Get(httputil.HeaderReferrer); ref != "" {
|
||||
redirectURL = ref
|
||||
func (p *Proxy) userInfo(w http.ResponseWriter, r *http.Request) error {
|
||||
data, err := p.getUserInfoData(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
handlers.UserInfo(data).ServeHTTP(w, r)
|
||||
return nil
|
||||
}
|
||||
|
||||
uri := state.authenticateDashboardURL.ResolveReference(&url.URL{
|
||||
RawQuery: url.Values{
|
||||
urlutil.QueryRedirectURI: {redirectURL},
|
||||
}.Encode(),
|
||||
})
|
||||
uri = urlutil.NewSignedURL(state.sharedKey, uri).Sign()
|
||||
httputil.Redirect(w, r, uri.String(), http.StatusFound)
|
||||
func (p *Proxy) deviceEnrolled(w http.ResponseWriter, r *http.Request) error {
|
||||
data, err := p.getUserInfoData(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
handlers.DeviceEnrolled(data).ServeHTTP(w, r)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Callback handles the result of a successful call to the authenticate service
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||
"github.com/pomerium/pomerium/internal/handlers/webauthn"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||
|
@ -54,6 +55,7 @@ type Proxy struct {
|
|||
state *atomicutil.Value[*proxyState]
|
||||
currentOptions *atomicutil.Value[*config.Options]
|
||||
currentRouter *atomicutil.Value[*mux.Router]
|
||||
webauthn *webauthn.Handler
|
||||
}
|
||||
|
||||
// New takes a Proxy service from options and a validation function.
|
||||
|
@ -69,6 +71,7 @@ func New(cfg *config.Config) (*Proxy, error) {
|
|||
currentOptions: config.NewAtomicOptions(),
|
||||
currentRouter: atomicutil.NewValue(httputil.NewRouter()),
|
||||
}
|
||||
p.webauthn = webauthn.New(p.getWebauthnState)
|
||||
|
||||
metrics.AddPolicyCountCallback("pomerium-proxy", func() int64 {
|
||||
return int64(len(p.currentOptions.Load().GetAllPolicies()))
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/cipher"
|
||||
"net/url"
|
||||
|
||||
|
@ -10,8 +11,12 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/sessions/cookie"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
)
|
||||
|
||||
var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn)
|
||||
|
||||
type proxyState struct {
|
||||
sharedKey []byte
|
||||
sharedCipher cipher.AEAD
|
||||
|
@ -26,6 +31,8 @@ type proxyState struct {
|
|||
sessionStore sessions.SessionStore
|
||||
jwtClaimHeaders config.JWTClaimHeaders
|
||||
|
||||
dataBrokerClient databroker.DataBrokerServiceClient
|
||||
|
||||
programmaticRedirectDomainWhitelist []string
|
||||
}
|
||||
|
||||
|
@ -36,6 +43,7 @@ func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) {
|
|||
}
|
||||
|
||||
state := new(proxyState)
|
||||
|
||||
state.sharedKey, err = cfg.Options.GetSharedKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -81,6 +89,19 @@ func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dataBrokerConn, err := outboundGRPCConnection.Get(context.Background(), &grpc.OutboundOptions{
|
||||
OutboundPort: cfg.OutboundPort,
|
||||
InstallationID: cfg.Options.InstallationID,
|
||||
ServiceName: cfg.Options.Services,
|
||||
SignedJWTKey: state.sharedKey,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
state.dataBrokerClient = databroker.NewDataBrokerServiceClient(dataBrokerConn)
|
||||
|
||||
state.programmaticRedirectDomainWhitelist = cfg.Options.ProgrammaticRedirectDomainWhitelist
|
||||
|
||||
return state, nil
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue