mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-04 01:09:36 +02:00
authenticate: implement hpke-based login flow (#3779)
* urlutil: add time validation functions * authenticate: implement hpke-based login flow * fix import cycle * fix tests * log error * fix callback url * add idp param * fix test * fix test
This commit is contained in:
parent
8d1235a5cc
commit
57217af7dd
25 changed files with 656 additions and 661 deletions
|
@ -1,7 +1,7 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
@ -9,12 +9,17 @@ import (
|
|||
"net/url"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/handlers"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/middleware"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/identity"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
"github.com/pomerium/pomerium/pkg/hpke"
|
||||
)
|
||||
|
||||
// registerDashboardHandlers returns the proxy service's ServeMux
|
||||
|
@ -32,9 +37,6 @@ func (p *Proxy) registerDashboardHandlers(r *mux.Router) *mux.Router {
|
|||
// called following authenticate auth flow to grab a new or existing session
|
||||
// the route specific cookie is returned in a signed query params
|
||||
c := r.PathPrefix(dashboardPath + "/callback").Subrouter()
|
||||
c.Use(func(h http.Handler) http.Handler {
|
||||
return middleware.ValidateSignature(p.state.Load().sharedKey)(h)
|
||||
})
|
||||
c.Path("/").Handler(httputil.HandlerFunc(p.Callback)).Methods(http.MethodGet)
|
||||
|
||||
// Programmatic API handlers and middleware
|
||||
|
@ -105,55 +107,96 @@ func (p *Proxy) deviceEnrolled(w http.ResponseWriter, r *http.Request) error {
|
|||
// Callback handles the result of a successful call to the authenticate service
|
||||
// and is responsible setting per-route sessions.
|
||||
func (p *Proxy) Callback(w http.ResponseWriter, r *http.Request) error {
|
||||
redirectURLString := r.FormValue(urlutil.QueryRedirectURI)
|
||||
encryptedSession := r.FormValue(urlutil.QuerySessionEncrypted)
|
||||
state := p.state.Load()
|
||||
options := p.currentOptions.Load()
|
||||
|
||||
redirectURL, err := urlutil.ParseAndValidateURL(redirectURLString)
|
||||
if err := r.ParseForm(); err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
}
|
||||
|
||||
// decrypt the URL values
|
||||
senderPublicKey, values, err := hpke.DecryptURLValues(state.hpkePrivateKey, r.Form)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, fmt.Errorf("invalid encrypted query string: %w", err))
|
||||
}
|
||||
|
||||
// confirm this request came from the authenticate service
|
||||
err = p.validateSenderPublicKey(r.Context(), senderPublicKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// validate that the request has not expired
|
||||
err = urlutil.ValidateTimeParameters(values)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
}
|
||||
|
||||
rawJWT, err := p.saveCallbackSession(w, r, encryptedSession)
|
||||
profile, err := getProfileFromValues(values)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
return err
|
||||
}
|
||||
|
||||
ss := newSessionStateFromProfile(profile)
|
||||
s, err := session.Get(r.Context(), state.dataBrokerClient, ss.ID)
|
||||
if err != nil {
|
||||
s = &session.Session{Id: ss.ID}
|
||||
}
|
||||
populateSessionFromProfile(s, profile, ss, options.CookieExpire)
|
||||
u, err := user.Get(r.Context(), state.dataBrokerClient, ss.UserID())
|
||||
if err != nil {
|
||||
u = &user.User{Id: ss.UserID()}
|
||||
}
|
||||
populateUserFromProfile(u, profile, ss)
|
||||
|
||||
redirectURI, err := getRedirectURIFromValues(values)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// save the records
|
||||
res, err := state.dataBrokerClient.Put(r.Context(), &databroker.PutRequest{
|
||||
Records: []*databroker.Record{
|
||||
databroker.NewRecord(s),
|
||||
databroker.NewRecord(u),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("proxy: error saving databroker records: %w", err))
|
||||
}
|
||||
ss.DatabrokerServerVersion = res.GetServerVersion()
|
||||
for _, record := range res.GetRecords() {
|
||||
if record.GetVersion() > ss.DatabrokerRecordVersion {
|
||||
ss.DatabrokerRecordVersion = record.GetVersion()
|
||||
}
|
||||
}
|
||||
|
||||
// save the session state
|
||||
rawJWT, err := state.encoder.Marshal(ss)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("proxy: error marshaling session state: %w", err))
|
||||
}
|
||||
if err = state.sessionStore.SaveSession(w, r, rawJWT); err != nil {
|
||||
return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("proxy: error saving session state: %w", err))
|
||||
}
|
||||
|
||||
// if programmatic, encode the session jwt as a query param
|
||||
if isProgrammatic := r.FormValue(urlutil.QueryIsProgrammatic); isProgrammatic == "true" {
|
||||
q := redirectURL.Query()
|
||||
if isProgrammatic := values.Get(urlutil.QueryIsProgrammatic); isProgrammatic == "true" {
|
||||
q := redirectURI.Query()
|
||||
q.Set(urlutil.QueryPomeriumJWT, string(rawJWT))
|
||||
redirectURL.RawQuery = q.Encode()
|
||||
redirectURI.RawQuery = q.Encode()
|
||||
}
|
||||
httputil.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
||||
|
||||
// redirect
|
||||
httputil.Redirect(w, r, redirectURI.String(), http.StatusFound)
|
||||
return nil
|
||||
}
|
||||
|
||||
// saveCallbackSession takes an encrypted per-route session token, decrypts
|
||||
// it using the shared service key, then stores it the local session store.
|
||||
func (p *Proxy) saveCallbackSession(w http.ResponseWriter, r *http.Request, enctoken string) ([]byte, error) {
|
||||
state := p.state.Load()
|
||||
|
||||
// 1. extract the base64 encoded and encrypted JWT from query params
|
||||
encryptedJWT, err := base64.URLEncoding.DecodeString(enctoken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("proxy: malfromed callback token: %w", err)
|
||||
}
|
||||
// 2. decrypt the JWT using the cipher using the _shared_ secret key
|
||||
rawJWT, err := cryptutil.Decrypt(state.sharedCipher, encryptedJWT, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("proxy: callback token decrypt error: %w", err)
|
||||
}
|
||||
// 3. Save the decrypted JWT to the session store directly as a string, without resigning
|
||||
if err = state.sessionStore.SaveSession(w, r, rawJWT); err != nil {
|
||||
return nil, fmt.Errorf("proxy: callback session save failure: %w", err)
|
||||
}
|
||||
return rawJWT, nil
|
||||
}
|
||||
|
||||
// ProgrammaticLogin returns a signed url that can be used to login
|
||||
// using the authenticate service.
|
||||
func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) error {
|
||||
state := p.state.Load()
|
||||
options := p.currentOptions.Load()
|
||||
|
||||
redirectURI, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI))
|
||||
if err != nil {
|
||||
|
@ -164,19 +207,32 @@ func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) error
|
|||
return httputil.NewError(http.StatusBadRequest, errors.New("invalid redirect uri"))
|
||||
}
|
||||
|
||||
idp, err := options.GetIdentityProviderForRequestURL(urlutil.GetAbsoluteURL(r).String())
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
hpkeAuthenticateKey, err := state.authenticateKeyFetcher.FetchPublicKey(r.Context())
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
signinURL := *state.authenticateSigninURL
|
||||
callbackURI := urlutil.GetAbsoluteURL(r)
|
||||
callbackURI.Path = dashboardPath + "/callback/"
|
||||
q := signinURL.Query()
|
||||
q.Set(urlutil.QueryCallbackURI, callbackURI.String())
|
||||
q.Set(urlutil.QueryRedirectURI, redirectURI.String())
|
||||
q.Set(urlutil.QueryIsProgrammatic, "true")
|
||||
signinURL.RawQuery = q.Encode()
|
||||
response := urlutil.NewSignedURL(state.sharedKey, &signinURL).String()
|
||||
|
||||
rawURL, err := handlers.BuildSignInURL(state.hpkePrivateKey, hpkeAuthenticateKey, &signinURL, redirectURI, idp.GetId())
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = io.WriteString(w, response)
|
||||
_, _ = io.WriteString(w, rawURL)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -191,3 +247,44 @@ func (p *Proxy) jwtAssertion(w http.ResponseWriter, r *http.Request) error {
|
|||
_, _ = io.WriteString(w, assertionJWT)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Proxy) validateSenderPublicKey(ctx context.Context, senderPublicKey *hpke.PublicKey) error {
|
||||
state := p.state.Load()
|
||||
|
||||
authenticatePublicKey, err := state.authenticateKeyFetcher.FetchPublicKey(ctx)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("hpke: error retrieving authenticate service public key: %w", err))
|
||||
}
|
||||
|
||||
if !authenticatePublicKey.Equals(senderPublicKey) {
|
||||
return httputil.NewError(http.StatusBadRequest, fmt.Errorf("hpke: invalid authenticate service public key"))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getProfileFromValues(values url.Values) (*identity.Profile, error) {
|
||||
rawProfile := values.Get(urlutil.QueryIdentityProfile)
|
||||
if rawProfile == "" {
|
||||
return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("missing %s", urlutil.QueryIdentityProfile))
|
||||
}
|
||||
|
||||
var profile identity.Profile
|
||||
err := protojson.Unmarshal([]byte(rawProfile), &profile)
|
||||
if err != nil {
|
||||
return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("invalid %s: %w", urlutil.QueryIdentityProfile, err))
|
||||
}
|
||||
return &profile, nil
|
||||
}
|
||||
|
||||
func getRedirectURIFromValues(values url.Values) (*url.URL, error) {
|
||||
rawRedirectURI := values.Get(urlutil.QueryRedirectURI)
|
||||
if rawRedirectURI == "" {
|
||||
return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("missing %s", urlutil.QueryRedirectURI))
|
||||
}
|
||||
redirectURI, err := urlutil.ParseAndValidateURL(rawRedirectURI)
|
||||
if err != nil {
|
||||
return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("invalid %s: %w", urlutil.QueryRedirectURI, err))
|
||||
}
|
||||
return redirectURI, nil
|
||||
}
|
||||
|
|
|
@ -2,30 +2,20 @@ package proxy
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/encoding/mock"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
mstore "github.com/pomerium/pomerium/internal/sessions/mock"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const goodEncryptionString = "KBEjQ9rnCxaAX-GOqetGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="
|
||||
|
||||
func TestProxy_RobotsTxt(t *testing.T) {
|
||||
proxy := Proxy{}
|
||||
req := httptest.NewRequest(http.MethodGet, "/robots.txt", nil)
|
||||
|
@ -40,29 +30,6 @@ func TestProxy_RobotsTxt(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestProxy_Signout(t *testing.T) {
|
||||
opts := testOptions(t)
|
||||
err := ValidateOptions(opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
proxy, err := New(&config.Config{Options: opts})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodGet, "/.pomerium/sign_out", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
proxy.SignOut(rr, req)
|
||||
if status := rr.Code; status != http.StatusFound {
|
||||
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusFound)
|
||||
}
|
||||
body := rr.Body.String()
|
||||
want := proxy.state.Load().authenticateURL.String()
|
||||
if !strings.Contains(body, want) {
|
||||
t.Errorf("handler returned unexpected body: got %v want %s ", body, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxy_SignOut(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
|
@ -104,165 +71,6 @@ func TestProxy_SignOut(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestProxy_Callback(t *testing.T) {
|
||||
t.Parallel()
|
||||
opts := testOptions(t)
|
||||
tests := []struct {
|
||||
name string
|
||||
options *config.Options
|
||||
|
||||
method string
|
||||
|
||||
scheme string
|
||||
host string
|
||||
path string
|
||||
|
||||
headers map[string]string
|
||||
qp map[string]string
|
||||
|
||||
cipher encoding.MarshalUnmarshaler
|
||||
sessionStore sessions.SessionStore
|
||||
wantStatus int
|
||||
wantBody string
|
||||
}{
|
||||
{
|
||||
"good",
|
||||
opts,
|
||||
http.MethodGet,
|
||||
"http",
|
||||
"example.com",
|
||||
"/",
|
||||
nil,
|
||||
map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString},
|
||||
&mock.Encoder{MarshalResponse: []byte("x")},
|
||||
&mstore.Store{Session: &sessions.State{}},
|
||||
http.StatusFound,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"good programmatic",
|
||||
opts,
|
||||
http.MethodGet,
|
||||
"http",
|
||||
"example.com",
|
||||
"/",
|
||||
nil,
|
||||
map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString},
|
||||
&mock.Encoder{MarshalResponse: []byte("x")},
|
||||
&mstore.Store{Session: &sessions.State{}},
|
||||
http.StatusFound,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"bad decrypt",
|
||||
opts,
|
||||
http.MethodGet,
|
||||
"http",
|
||||
"example.com",
|
||||
"/",
|
||||
nil,
|
||||
map[string]string{urlutil.QuerySessionEncrypted: "KBEjQ9rnCxaAX-GOqexGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="},
|
||||
&mock.Encoder{MarshalResponse: []byte("x")},
|
||||
&mstore.Store{Session: &sessions.State{}},
|
||||
http.StatusBadRequest,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"bad save session",
|
||||
opts,
|
||||
http.MethodGet,
|
||||
"http",
|
||||
"example.com",
|
||||
"/",
|
||||
nil,
|
||||
map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString},
|
||||
&mock.Encoder{MarshalResponse: []byte("x")},
|
||||
&mstore.Store{SaveError: errors.New("hi")},
|
||||
http.StatusBadRequest,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"bad base64",
|
||||
opts,
|
||||
http.MethodGet,
|
||||
"http",
|
||||
"example.com",
|
||||
"/",
|
||||
nil,
|
||||
map[string]string{urlutil.QuerySessionEncrypted: "^"},
|
||||
&mock.Encoder{MarshalResponse: []byte("x")},
|
||||
&mstore.Store{Session: &sessions.State{}},
|
||||
http.StatusBadRequest,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"malformed redirect",
|
||||
opts,
|
||||
http.MethodGet,
|
||||
"http",
|
||||
"example.com",
|
||||
"/",
|
||||
nil,
|
||||
nil,
|
||||
&mock.Encoder{},
|
||||
&mstore.Store{Session: &sessions.State{}},
|
||||
http.StatusBadRequest,
|
||||
"",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p, err := New(&config.Config{Options: tt.options})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
p.OnConfigChange(context.Background(), &config.Config{Options: tt.options})
|
||||
state := p.state.Load()
|
||||
state.encoder = tt.cipher
|
||||
state.sessionStore = tt.sessionStore
|
||||
redirectURI := &url.URL{Scheme: tt.scheme, Host: tt.host, Path: tt.path}
|
||||
queryString := redirectURI.Query()
|
||||
for k, v := range tt.qp {
|
||||
queryString.Set(k, v)
|
||||
}
|
||||
redirectURI.RawQuery = queryString.Encode()
|
||||
|
||||
uri := &url.URL{Path: "/"}
|
||||
if tt.qp != nil {
|
||||
qu := uri.Query()
|
||||
for k, v := range tt.qp {
|
||||
qu.Set(k, v)
|
||||
}
|
||||
qu.Set(urlutil.QueryRedirectURI, redirectURI.String())
|
||||
uri.RawQuery = qu.Encode()
|
||||
}
|
||||
|
||||
r := httptest.NewRequest(tt.method, uri.String(), nil)
|
||||
|
||||
r.Header.Set("Accept", "application/json")
|
||||
if len(tt.headers) != 0 {
|
||||
for k, v := range tt.headers {
|
||||
r.Header.Set(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
httputil.HandlerFunc(p.Callback).ServeHTTP(w, r)
|
||||
if status := w.Code; status != tt.wantStatus {
|
||||
t.Errorf("status code: got %v want %v", status, tt.wantStatus)
|
||||
t.Errorf("\n%+v", w.Body.String())
|
||||
}
|
||||
|
||||
if tt.wantBody != "" {
|
||||
body := w.Body.String()
|
||||
if diff := cmp.Diff(body, tt.wantBody); diff != "" {
|
||||
t.Errorf("wrong body\n%s", diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxy_ProgrammaticLogin(t *testing.T) {
|
||||
t.Parallel()
|
||||
opts := testOptions(t)
|
||||
|
@ -360,155 +168,6 @@ func TestProxy_ProgrammaticLogin(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestProxy_ProgrammaticCallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
opts := testOptions(t)
|
||||
tests := []struct {
|
||||
name string
|
||||
options *config.Options
|
||||
|
||||
method string
|
||||
|
||||
redirectURI string
|
||||
|
||||
headers map[string]string
|
||||
qp map[string]string
|
||||
|
||||
cipher encoding.MarshalUnmarshaler
|
||||
sessionStore sessions.SessionStore
|
||||
wantStatus int
|
||||
wantBody string
|
||||
}{
|
||||
{
|
||||
"good",
|
||||
opts,
|
||||
http.MethodGet,
|
||||
"http://pomerium.io/",
|
||||
nil,
|
||||
map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString},
|
||||
&mock.Encoder{MarshalResponse: []byte("x")},
|
||||
&mstore.Store{Session: &sessions.State{}},
|
||||
http.StatusFound,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"good programmatic",
|
||||
opts,
|
||||
http.MethodGet,
|
||||
"http://pomerium.io/",
|
||||
nil,
|
||||
map[string]string{
|
||||
urlutil.QueryIsProgrammatic: "true",
|
||||
urlutil.QueryCallbackURI: "ok",
|
||||
urlutil.QuerySessionEncrypted: goodEncryptionString,
|
||||
},
|
||||
&mock.Encoder{MarshalResponse: []byte("x")},
|
||||
&mstore.Store{Session: &sessions.State{}},
|
||||
http.StatusFound,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"bad decrypt",
|
||||
opts,
|
||||
http.MethodGet,
|
||||
"http://pomerium.io/",
|
||||
nil,
|
||||
map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString + cryptutil.NewBase64Key()},
|
||||
&mock.Encoder{MarshalResponse: []byte("x")},
|
||||
&mstore.Store{Session: &sessions.State{}},
|
||||
http.StatusBadRequest,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"bad save session",
|
||||
opts,
|
||||
http.MethodGet,
|
||||
"http://pomerium.io/",
|
||||
nil,
|
||||
map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString},
|
||||
&mock.Encoder{MarshalResponse: []byte("x")},
|
||||
&mstore.Store{SaveError: errors.New("hi")},
|
||||
http.StatusBadRequest,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"bad base64",
|
||||
opts,
|
||||
http.MethodGet,
|
||||
"http://pomerium.io/",
|
||||
nil,
|
||||
map[string]string{urlutil.QuerySessionEncrypted: "^"},
|
||||
&mock.Encoder{MarshalResponse: []byte("x")},
|
||||
&mstore.Store{Session: &sessions.State{}},
|
||||
http.StatusBadRequest,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"malformed redirect",
|
||||
opts,
|
||||
http.MethodGet,
|
||||
"http://pomerium.io/",
|
||||
nil,
|
||||
nil,
|
||||
&mock.Encoder{},
|
||||
&mstore.Store{Session: &sessions.State{}},
|
||||
http.StatusBadRequest,
|
||||
"",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p, err := New(&config.Config{Options: tt.options})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
p.OnConfigChange(context.Background(), &config.Config{Options: tt.options})
|
||||
state := p.state.Load()
|
||||
state.encoder = tt.cipher
|
||||
state.sessionStore = tt.sessionStore
|
||||
redirectURI, _ := url.Parse(tt.redirectURI)
|
||||
queryString := redirectURI.Query()
|
||||
for k, v := range tt.qp {
|
||||
queryString.Set(k, v)
|
||||
}
|
||||
redirectURI.RawQuery = queryString.Encode()
|
||||
|
||||
uri := &url.URL{Path: "/"}
|
||||
if tt.qp != nil {
|
||||
qu := uri.Query()
|
||||
for k, v := range tt.qp {
|
||||
qu.Set(k, v)
|
||||
}
|
||||
qu.Set(urlutil.QueryRedirectURI, redirectURI.String())
|
||||
uri.RawQuery = qu.Encode()
|
||||
}
|
||||
|
||||
r := httptest.NewRequest(tt.method, uri.String(), nil)
|
||||
|
||||
r.Header.Set("Accept", "application/json")
|
||||
if len(tt.headers) != 0 {
|
||||
for k, v := range tt.headers {
|
||||
r.Header.Set(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
httputil.HandlerFunc(p.Callback).ServeHTTP(w, r)
|
||||
if status := w.Code; status != tt.wantStatus {
|
||||
t.Errorf("status code: got %v want %v", status, tt.wantStatus)
|
||||
t.Errorf("\n%+v", w.Body.String())
|
||||
}
|
||||
|
||||
if tt.wantBody != "" {
|
||||
body := w.Body.String()
|
||||
if diff := cmp.Diff(body, tt.wantBody); diff != "" {
|
||||
t.Errorf("wrong body\n%s", diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxy_jwt(t *testing.T) {
|
||||
// without upstream headers being set
|
||||
req, _ := http.NewRequest("GET", "https://www.example.com/.pomerium/jwt", nil)
|
||||
|
|
79
proxy/identity_profile.go
Normal file
79
proxy/identity_profile.go
Normal file
|
@ -0,0 +1,79 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/identity"
|
||||
"github.com/pomerium/pomerium/internal/identity/manager"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
identitypb "github.com/pomerium/pomerium/pkg/grpc/identity"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
)
|
||||
|
||||
func newSessionStateFromProfile(p *identitypb.Profile) *sessions.State {
|
||||
claims := p.GetClaims().AsMap()
|
||||
|
||||
ss := sessions.NewState(p.GetProviderId())
|
||||
|
||||
// set the subject
|
||||
if v, ok := claims["sub"]; ok {
|
||||
ss.Subject = fmt.Sprint(v)
|
||||
} else if v, ok := claims["user"]; ok {
|
||||
ss.Subject = fmt.Sprint(v)
|
||||
}
|
||||
|
||||
// set the oid
|
||||
if v, ok := claims["oid"]; ok {
|
||||
ss.OID = fmt.Sprint(v)
|
||||
}
|
||||
|
||||
return ss
|
||||
}
|
||||
|
||||
func populateSessionFromProfile(s *session.Session, p *identitypb.Profile, ss *sessions.State, cookieExpire time.Duration) {
|
||||
claims := p.GetClaims().AsMap()
|
||||
oauthToken := new(oauth2.Token)
|
||||
_ = json.Unmarshal(p.GetOauthToken(), oauthToken)
|
||||
|
||||
s.UserId = ss.UserID()
|
||||
s.IssuedAt = timestamppb.Now()
|
||||
s.AccessedAt = timestamppb.Now()
|
||||
s.ExpiresAt = timestamppb.New(time.Now().Add(cookieExpire))
|
||||
s.IdToken = &session.IDToken{
|
||||
Issuer: ss.Issuer,
|
||||
Subject: ss.Subject,
|
||||
ExpiresAt: timestamppb.New(time.Now().Add(cookieExpire)),
|
||||
IssuedAt: timestamppb.Now(),
|
||||
Raw: string(p.GetIdToken()),
|
||||
}
|
||||
s.OauthToken = manager.ToOAuthToken(oauthToken)
|
||||
if s.Claims == nil {
|
||||
s.Claims = make(map[string]*structpb.ListValue)
|
||||
}
|
||||
for k, vs := range identity.Claims(claims).Flatten().ToPB() {
|
||||
s.Claims[k] = vs
|
||||
}
|
||||
}
|
||||
|
||||
func populateUserFromProfile(u *user.User, p *identitypb.Profile, ss *sessions.State) {
|
||||
claims := p.GetClaims().AsMap()
|
||||
if v, ok := claims["name"]; ok {
|
||||
u.Name = fmt.Sprint(v)
|
||||
}
|
||||
if v, ok := claims["email"]; ok {
|
||||
u.Email = fmt.Sprint(v)
|
||||
}
|
||||
if u.Claims == nil {
|
||||
u.Claims = make(map[string]*structpb.ListValue)
|
||||
}
|
||||
for k, vs := range identity.Claims(claims).Flatten().ToPB() {
|
||||
u.Claims[k] = vs
|
||||
}
|
||||
}
|
|
@ -9,13 +9,15 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/handlers"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func testOptions(t *testing.T) *config.Options {
|
||||
t.Helper()
|
||||
|
||||
opts := config.NewDefaultOptions()
|
||||
opts.AuthenticateURLString = "https://authenticate.example"
|
||||
|
||||
to, err := config.ParseWeightedUrls("https://example.example")
|
||||
require.NoError(t, err)
|
||||
|
@ -28,6 +30,13 @@ func testOptions(t *testing.T) *config.Options {
|
|||
opts.SharedKey = "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="
|
||||
opts.CookieSecret = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw="
|
||||
|
||||
htpkePrivateKey, err := opts.GetHPKEPrivateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
authnSrv := httptest.NewServer(handlers.JWKSHandler(opts.SigningKey, htpkePrivateKey.PublicKey()))
|
||||
t.Cleanup(authnSrv.Close)
|
||||
opts.AuthenticateURLString = authnSrv.URL
|
||||
|
||||
require.NoError(t, opts.Validate())
|
||||
|
||||
return opts
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/hpke"
|
||||
)
|
||||
|
||||
var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn)
|
||||
|
@ -26,10 +27,12 @@ type proxyState struct {
|
|||
authenticateSigninURL *url.URL
|
||||
authenticateRefreshURL *url.URL
|
||||
|
||||
encoder encoding.MarshalUnmarshaler
|
||||
cookieSecret []byte
|
||||
sessionStore sessions.SessionStore
|
||||
jwtClaimHeaders config.JWTClaimHeaders
|
||||
encoder encoding.MarshalUnmarshaler
|
||||
cookieSecret []byte
|
||||
sessionStore sessions.SessionStore
|
||||
jwtClaimHeaders config.JWTClaimHeaders
|
||||
hpkePrivateKey *hpke.PrivateKey
|
||||
authenticateKeyFetcher hpke.KeyFetcher
|
||||
|
||||
dataBrokerClient databroker.DataBrokerServiceClient
|
||||
|
||||
|
@ -44,11 +47,24 @@ func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) {
|
|||
|
||||
state := new(proxyState)
|
||||
|
||||
authenticateURL, err := cfg.Options.GetAuthenticateURL()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
state.sharedKey, err = cfg.Options.GetSharedKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
state.hpkePrivateKey, err = cfg.Options.GetHPKEPrivateKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
state.authenticateKeyFetcher = hpke.NewKeyFetcher(authenticateURL.ResolveReference(&url.URL{
|
||||
Path: "/.well-known/pomerium/jwks.json",
|
||||
}).String())
|
||||
|
||||
state.sharedCipher, err = cryptutil.NewAEADCipher(state.sharedKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue