authenticate: implement hpke-based login flow (#3779)

* urlutil: add time validation functions

* authenticate: implement hpke-based login flow

* fix import cycle

* fix tests

* log error

* fix callback url

* add idp param

* fix test

* fix test
This commit is contained in:
Caleb Doxsey 2022-12-05 15:31:07 -07:00 committed by GitHub
parent 8d1235a5cc
commit 57217af7dd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
25 changed files with 656 additions and 661 deletions

View file

@ -1,7 +1,7 @@
package proxy
import (
"encoding/base64"
"context"
"errors"
"fmt"
"io"
@ -9,12 +9,17 @@ import (
"net/url"
"github.com/gorilla/mux"
"google.golang.org/protobuf/encoding/protojson"
"github.com/pomerium/pomerium/internal/handlers"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/middleware"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/identity"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/hpke"
)
// registerDashboardHandlers returns the proxy service's ServeMux
@ -32,9 +37,6 @@ func (p *Proxy) registerDashboardHandlers(r *mux.Router) *mux.Router {
// called following authenticate auth flow to grab a new or existing session
// the route specific cookie is returned in a signed query params
c := r.PathPrefix(dashboardPath + "/callback").Subrouter()
c.Use(func(h http.Handler) http.Handler {
return middleware.ValidateSignature(p.state.Load().sharedKey)(h)
})
c.Path("/").Handler(httputil.HandlerFunc(p.Callback)).Methods(http.MethodGet)
// Programmatic API handlers and middleware
@ -105,55 +107,96 @@ func (p *Proxy) deviceEnrolled(w http.ResponseWriter, r *http.Request) error {
// Callback handles the result of a successful call to the authenticate service
// and is responsible setting per-route sessions.
func (p *Proxy) Callback(w http.ResponseWriter, r *http.Request) error {
redirectURLString := r.FormValue(urlutil.QueryRedirectURI)
encryptedSession := r.FormValue(urlutil.QuerySessionEncrypted)
state := p.state.Load()
options := p.currentOptions.Load()
redirectURL, err := urlutil.ParseAndValidateURL(redirectURLString)
if err := r.ParseForm(); err != nil {
return httputil.NewError(http.StatusBadRequest, err)
}
// decrypt the URL values
senderPublicKey, values, err := hpke.DecryptURLValues(state.hpkePrivateKey, r.Form)
if err != nil {
return httputil.NewError(http.StatusBadRequest, fmt.Errorf("invalid encrypted query string: %w", err))
}
// confirm this request came from the authenticate service
err = p.validateSenderPublicKey(r.Context(), senderPublicKey)
if err != nil {
return err
}
// validate that the request has not expired
err = urlutil.ValidateTimeParameters(values)
if err != nil {
return httputil.NewError(http.StatusBadRequest, err)
}
rawJWT, err := p.saveCallbackSession(w, r, encryptedSession)
profile, err := getProfileFromValues(values)
if err != nil {
return httputil.NewError(http.StatusBadRequest, err)
return err
}
ss := newSessionStateFromProfile(profile)
s, err := session.Get(r.Context(), state.dataBrokerClient, ss.ID)
if err != nil {
s = &session.Session{Id: ss.ID}
}
populateSessionFromProfile(s, profile, ss, options.CookieExpire)
u, err := user.Get(r.Context(), state.dataBrokerClient, ss.UserID())
if err != nil {
u = &user.User{Id: ss.UserID()}
}
populateUserFromProfile(u, profile, ss)
redirectURI, err := getRedirectURIFromValues(values)
if err != nil {
return err
}
// save the records
res, err := state.dataBrokerClient.Put(r.Context(), &databroker.PutRequest{
Records: []*databroker.Record{
databroker.NewRecord(s),
databroker.NewRecord(u),
},
})
if err != nil {
return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("proxy: error saving databroker records: %w", err))
}
ss.DatabrokerServerVersion = res.GetServerVersion()
for _, record := range res.GetRecords() {
if record.GetVersion() > ss.DatabrokerRecordVersion {
ss.DatabrokerRecordVersion = record.GetVersion()
}
}
// save the session state
rawJWT, err := state.encoder.Marshal(ss)
if err != nil {
return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("proxy: error marshaling session state: %w", err))
}
if err = state.sessionStore.SaveSession(w, r, rawJWT); err != nil {
return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("proxy: error saving session state: %w", err))
}
// if programmatic, encode the session jwt as a query param
if isProgrammatic := r.FormValue(urlutil.QueryIsProgrammatic); isProgrammatic == "true" {
q := redirectURL.Query()
if isProgrammatic := values.Get(urlutil.QueryIsProgrammatic); isProgrammatic == "true" {
q := redirectURI.Query()
q.Set(urlutil.QueryPomeriumJWT, string(rawJWT))
redirectURL.RawQuery = q.Encode()
redirectURI.RawQuery = q.Encode()
}
httputil.Redirect(w, r, redirectURL.String(), http.StatusFound)
// redirect
httputil.Redirect(w, r, redirectURI.String(), http.StatusFound)
return nil
}
// saveCallbackSession takes an encrypted per-route session token, decrypts
// it using the shared service key, then stores it the local session store.
func (p *Proxy) saveCallbackSession(w http.ResponseWriter, r *http.Request, enctoken string) ([]byte, error) {
state := p.state.Load()
// 1. extract the base64 encoded and encrypted JWT from query params
encryptedJWT, err := base64.URLEncoding.DecodeString(enctoken)
if err != nil {
return nil, fmt.Errorf("proxy: malfromed callback token: %w", err)
}
// 2. decrypt the JWT using the cipher using the _shared_ secret key
rawJWT, err := cryptutil.Decrypt(state.sharedCipher, encryptedJWT, nil)
if err != nil {
return nil, fmt.Errorf("proxy: callback token decrypt error: %w", err)
}
// 3. Save the decrypted JWT to the session store directly as a string, without resigning
if err = state.sessionStore.SaveSession(w, r, rawJWT); err != nil {
return nil, fmt.Errorf("proxy: callback session save failure: %w", err)
}
return rawJWT, nil
}
// ProgrammaticLogin returns a signed url that can be used to login
// using the authenticate service.
func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) error {
state := p.state.Load()
options := p.currentOptions.Load()
redirectURI, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI))
if err != nil {
@ -164,19 +207,32 @@ func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) error
return httputil.NewError(http.StatusBadRequest, errors.New("invalid redirect uri"))
}
idp, err := options.GetIdentityProviderForRequestURL(urlutil.GetAbsoluteURL(r).String())
if err != nil {
return httputil.NewError(http.StatusInternalServerError, err)
}
hpkeAuthenticateKey, err := state.authenticateKeyFetcher.FetchPublicKey(r.Context())
if err != nil {
return httputil.NewError(http.StatusInternalServerError, err)
}
signinURL := *state.authenticateSigninURL
callbackURI := urlutil.GetAbsoluteURL(r)
callbackURI.Path = dashboardPath + "/callback/"
q := signinURL.Query()
q.Set(urlutil.QueryCallbackURI, callbackURI.String())
q.Set(urlutil.QueryRedirectURI, redirectURI.String())
q.Set(urlutil.QueryIsProgrammatic, "true")
signinURL.RawQuery = q.Encode()
response := urlutil.NewSignedURL(state.sharedKey, &signinURL).String()
rawURL, err := handlers.BuildSignInURL(state.hpkePrivateKey, hpkeAuthenticateKey, &signinURL, redirectURI, idp.GetId())
if err != nil {
return httputil.NewError(http.StatusInternalServerError, err)
}
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusOK)
_, _ = io.WriteString(w, response)
_, _ = io.WriteString(w, rawURL)
return nil
}
@ -191,3 +247,44 @@ func (p *Proxy) jwtAssertion(w http.ResponseWriter, r *http.Request) error {
_, _ = io.WriteString(w, assertionJWT)
return nil
}
func (p *Proxy) validateSenderPublicKey(ctx context.Context, senderPublicKey *hpke.PublicKey) error {
state := p.state.Load()
authenticatePublicKey, err := state.authenticateKeyFetcher.FetchPublicKey(ctx)
if err != nil {
return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("hpke: error retrieving authenticate service public key: %w", err))
}
if !authenticatePublicKey.Equals(senderPublicKey) {
return httputil.NewError(http.StatusBadRequest, fmt.Errorf("hpke: invalid authenticate service public key"))
}
return nil
}
func getProfileFromValues(values url.Values) (*identity.Profile, error) {
rawProfile := values.Get(urlutil.QueryIdentityProfile)
if rawProfile == "" {
return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("missing %s", urlutil.QueryIdentityProfile))
}
var profile identity.Profile
err := protojson.Unmarshal([]byte(rawProfile), &profile)
if err != nil {
return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("invalid %s: %w", urlutil.QueryIdentityProfile, err))
}
return &profile, nil
}
func getRedirectURIFromValues(values url.Values) (*url.URL, error) {
rawRedirectURI := values.Get(urlutil.QueryRedirectURI)
if rawRedirectURI == "" {
return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("missing %s", urlutil.QueryRedirectURI))
}
redirectURI, err := urlutil.ParseAndValidateURL(rawRedirectURI)
if err != nil {
return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("invalid %s: %w", urlutil.QueryRedirectURI, err))
}
return redirectURI, nil
}

View file

@ -2,30 +2,20 @@ package proxy
import (
"bytes"
"context"
"errors"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/encoding/mock"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/sessions"
mstore "github.com/pomerium/pomerium/internal/sessions/mock"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
)
const goodEncryptionString = "KBEjQ9rnCxaAX-GOqetGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="
func TestProxy_RobotsTxt(t *testing.T) {
proxy := Proxy{}
req := httptest.NewRequest(http.MethodGet, "/robots.txt", nil)
@ -40,29 +30,6 @@ func TestProxy_RobotsTxt(t *testing.T) {
}
}
func TestProxy_Signout(t *testing.T) {
opts := testOptions(t)
err := ValidateOptions(opts)
if err != nil {
t.Fatal(err)
}
proxy, err := New(&config.Config{Options: opts})
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest(http.MethodGet, "/.pomerium/sign_out", nil)
rr := httptest.NewRecorder()
proxy.SignOut(rr, req)
if status := rr.Code; status != http.StatusFound {
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusFound)
}
body := rr.Body.String()
want := proxy.state.Load().authenticateURL.String()
if !strings.Contains(body, want) {
t.Errorf("handler returned unexpected body: got %v want %s ", body, want)
}
}
func TestProxy_SignOut(t *testing.T) {
t.Parallel()
tests := []struct {
@ -104,165 +71,6 @@ func TestProxy_SignOut(t *testing.T) {
}
}
func TestProxy_Callback(t *testing.T) {
t.Parallel()
opts := testOptions(t)
tests := []struct {
name string
options *config.Options
method string
scheme string
host string
path string
headers map[string]string
qp map[string]string
cipher encoding.MarshalUnmarshaler
sessionStore sessions.SessionStore
wantStatus int
wantBody string
}{
{
"good",
opts,
http.MethodGet,
"http",
"example.com",
"/",
nil,
map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString},
&mock.Encoder{MarshalResponse: []byte("x")},
&mstore.Store{Session: &sessions.State{}},
http.StatusFound,
"",
},
{
"good programmatic",
opts,
http.MethodGet,
"http",
"example.com",
"/",
nil,
map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString},
&mock.Encoder{MarshalResponse: []byte("x")},
&mstore.Store{Session: &sessions.State{}},
http.StatusFound,
"",
},
{
"bad decrypt",
opts,
http.MethodGet,
"http",
"example.com",
"/",
nil,
map[string]string{urlutil.QuerySessionEncrypted: "KBEjQ9rnCxaAX-GOqexGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="},
&mock.Encoder{MarshalResponse: []byte("x")},
&mstore.Store{Session: &sessions.State{}},
http.StatusBadRequest,
"",
},
{
"bad save session",
opts,
http.MethodGet,
"http",
"example.com",
"/",
nil,
map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString},
&mock.Encoder{MarshalResponse: []byte("x")},
&mstore.Store{SaveError: errors.New("hi")},
http.StatusBadRequest,
"",
},
{
"bad base64",
opts,
http.MethodGet,
"http",
"example.com",
"/",
nil,
map[string]string{urlutil.QuerySessionEncrypted: "^"},
&mock.Encoder{MarshalResponse: []byte("x")},
&mstore.Store{Session: &sessions.State{}},
http.StatusBadRequest,
"",
},
{
"malformed redirect",
opts,
http.MethodGet,
"http",
"example.com",
"/",
nil,
nil,
&mock.Encoder{},
&mstore.Store{Session: &sessions.State{}},
http.StatusBadRequest,
"",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p, err := New(&config.Config{Options: tt.options})
if err != nil {
t.Fatal(err)
}
p.OnConfigChange(context.Background(), &config.Config{Options: tt.options})
state := p.state.Load()
state.encoder = tt.cipher
state.sessionStore = tt.sessionStore
redirectURI := &url.URL{Scheme: tt.scheme, Host: tt.host, Path: tt.path}
queryString := redirectURI.Query()
for k, v := range tt.qp {
queryString.Set(k, v)
}
redirectURI.RawQuery = queryString.Encode()
uri := &url.URL{Path: "/"}
if tt.qp != nil {
qu := uri.Query()
for k, v := range tt.qp {
qu.Set(k, v)
}
qu.Set(urlutil.QueryRedirectURI, redirectURI.String())
uri.RawQuery = qu.Encode()
}
r := httptest.NewRequest(tt.method, uri.String(), nil)
r.Header.Set("Accept", "application/json")
if len(tt.headers) != 0 {
for k, v := range tt.headers {
r.Header.Set(k, v)
}
}
w := httptest.NewRecorder()
httputil.HandlerFunc(p.Callback).ServeHTTP(w, r)
if status := w.Code; status != tt.wantStatus {
t.Errorf("status code: got %v want %v", status, tt.wantStatus)
t.Errorf("\n%+v", w.Body.String())
}
if tt.wantBody != "" {
body := w.Body.String()
if diff := cmp.Diff(body, tt.wantBody); diff != "" {
t.Errorf("wrong body\n%s", diff)
}
}
})
}
}
func TestProxy_ProgrammaticLogin(t *testing.T) {
t.Parallel()
opts := testOptions(t)
@ -360,155 +168,6 @@ func TestProxy_ProgrammaticLogin(t *testing.T) {
}
}
func TestProxy_ProgrammaticCallback(t *testing.T) {
t.Parallel()
opts := testOptions(t)
tests := []struct {
name string
options *config.Options
method string
redirectURI string
headers map[string]string
qp map[string]string
cipher encoding.MarshalUnmarshaler
sessionStore sessions.SessionStore
wantStatus int
wantBody string
}{
{
"good",
opts,
http.MethodGet,
"http://pomerium.io/",
nil,
map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString},
&mock.Encoder{MarshalResponse: []byte("x")},
&mstore.Store{Session: &sessions.State{}},
http.StatusFound,
"",
},
{
"good programmatic",
opts,
http.MethodGet,
"http://pomerium.io/",
nil,
map[string]string{
urlutil.QueryIsProgrammatic: "true",
urlutil.QueryCallbackURI: "ok",
urlutil.QuerySessionEncrypted: goodEncryptionString,
},
&mock.Encoder{MarshalResponse: []byte("x")},
&mstore.Store{Session: &sessions.State{}},
http.StatusFound,
"",
},
{
"bad decrypt",
opts,
http.MethodGet,
"http://pomerium.io/",
nil,
map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString + cryptutil.NewBase64Key()},
&mock.Encoder{MarshalResponse: []byte("x")},
&mstore.Store{Session: &sessions.State{}},
http.StatusBadRequest,
"",
},
{
"bad save session",
opts,
http.MethodGet,
"http://pomerium.io/",
nil,
map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString},
&mock.Encoder{MarshalResponse: []byte("x")},
&mstore.Store{SaveError: errors.New("hi")},
http.StatusBadRequest,
"",
},
{
"bad base64",
opts,
http.MethodGet,
"http://pomerium.io/",
nil,
map[string]string{urlutil.QuerySessionEncrypted: "^"},
&mock.Encoder{MarshalResponse: []byte("x")},
&mstore.Store{Session: &sessions.State{}},
http.StatusBadRequest,
"",
},
{
"malformed redirect",
opts,
http.MethodGet,
"http://pomerium.io/",
nil,
nil,
&mock.Encoder{},
&mstore.Store{Session: &sessions.State{}},
http.StatusBadRequest,
"",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p, err := New(&config.Config{Options: tt.options})
if err != nil {
t.Fatal(err)
}
p.OnConfigChange(context.Background(), &config.Config{Options: tt.options})
state := p.state.Load()
state.encoder = tt.cipher
state.sessionStore = tt.sessionStore
redirectURI, _ := url.Parse(tt.redirectURI)
queryString := redirectURI.Query()
for k, v := range tt.qp {
queryString.Set(k, v)
}
redirectURI.RawQuery = queryString.Encode()
uri := &url.URL{Path: "/"}
if tt.qp != nil {
qu := uri.Query()
for k, v := range tt.qp {
qu.Set(k, v)
}
qu.Set(urlutil.QueryRedirectURI, redirectURI.String())
uri.RawQuery = qu.Encode()
}
r := httptest.NewRequest(tt.method, uri.String(), nil)
r.Header.Set("Accept", "application/json")
if len(tt.headers) != 0 {
for k, v := range tt.headers {
r.Header.Set(k, v)
}
}
w := httptest.NewRecorder()
httputil.HandlerFunc(p.Callback).ServeHTTP(w, r)
if status := w.Code; status != tt.wantStatus {
t.Errorf("status code: got %v want %v", status, tt.wantStatus)
t.Errorf("\n%+v", w.Body.String())
}
if tt.wantBody != "" {
body := w.Body.String()
if diff := cmp.Diff(body, tt.wantBody); diff != "" {
t.Errorf("wrong body\n%s", diff)
}
}
})
}
}
func TestProxy_jwt(t *testing.T) {
// without upstream headers being set
req, _ := http.NewRequest("GET", "https://www.example.com/.pomerium/jwt", nil)

79
proxy/identity_profile.go Normal file
View file

@ -0,0 +1,79 @@
package proxy
import (
"encoding/json"
"fmt"
"time"
"golang.org/x/oauth2"
"google.golang.org/protobuf/types/known/structpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/internal/identity/manager"
"github.com/pomerium/pomerium/internal/sessions"
identitypb "github.com/pomerium/pomerium/pkg/grpc/identity"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
)
func newSessionStateFromProfile(p *identitypb.Profile) *sessions.State {
claims := p.GetClaims().AsMap()
ss := sessions.NewState(p.GetProviderId())
// set the subject
if v, ok := claims["sub"]; ok {
ss.Subject = fmt.Sprint(v)
} else if v, ok := claims["user"]; ok {
ss.Subject = fmt.Sprint(v)
}
// set the oid
if v, ok := claims["oid"]; ok {
ss.OID = fmt.Sprint(v)
}
return ss
}
func populateSessionFromProfile(s *session.Session, p *identitypb.Profile, ss *sessions.State, cookieExpire time.Duration) {
claims := p.GetClaims().AsMap()
oauthToken := new(oauth2.Token)
_ = json.Unmarshal(p.GetOauthToken(), oauthToken)
s.UserId = ss.UserID()
s.IssuedAt = timestamppb.Now()
s.AccessedAt = timestamppb.Now()
s.ExpiresAt = timestamppb.New(time.Now().Add(cookieExpire))
s.IdToken = &session.IDToken{
Issuer: ss.Issuer,
Subject: ss.Subject,
ExpiresAt: timestamppb.New(time.Now().Add(cookieExpire)),
IssuedAt: timestamppb.Now(),
Raw: string(p.GetIdToken()),
}
s.OauthToken = manager.ToOAuthToken(oauthToken)
if s.Claims == nil {
s.Claims = make(map[string]*structpb.ListValue)
}
for k, vs := range identity.Claims(claims).Flatten().ToPB() {
s.Claims[k] = vs
}
}
func populateUserFromProfile(u *user.User, p *identitypb.Profile, ss *sessions.State) {
claims := p.GetClaims().AsMap()
if v, ok := claims["name"]; ok {
u.Name = fmt.Sprint(v)
}
if v, ok := claims["email"]; ok {
u.Email = fmt.Sprint(v)
}
if u.Claims == nil {
u.Claims = make(map[string]*structpb.ListValue)
}
for k, vs := range identity.Claims(claims).Flatten().ToPB() {
u.Claims[k] = vs
}
}

View file

@ -9,13 +9,15 @@ import (
"time"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/handlers"
"github.com/stretchr/testify/require"
)
func testOptions(t *testing.T) *config.Options {
t.Helper()
opts := config.NewDefaultOptions()
opts.AuthenticateURLString = "https://authenticate.example"
to, err := config.ParseWeightedUrls("https://example.example")
require.NoError(t, err)
@ -28,6 +30,13 @@ func testOptions(t *testing.T) *config.Options {
opts.SharedKey = "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="
opts.CookieSecret = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw="
htpkePrivateKey, err := opts.GetHPKEPrivateKey()
require.NoError(t, err)
authnSrv := httptest.NewServer(handlers.JWKSHandler(opts.SigningKey, htpkePrivateKey.PublicKey()))
t.Cleanup(authnSrv.Close)
opts.AuthenticateURLString = authnSrv.URL
require.NoError(t, opts.Validate())
return opts

View file

@ -13,6 +13,7 @@ import (
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/hpke"
)
var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn)
@ -26,10 +27,12 @@ type proxyState struct {
authenticateSigninURL *url.URL
authenticateRefreshURL *url.URL
encoder encoding.MarshalUnmarshaler
cookieSecret []byte
sessionStore sessions.SessionStore
jwtClaimHeaders config.JWTClaimHeaders
encoder encoding.MarshalUnmarshaler
cookieSecret []byte
sessionStore sessions.SessionStore
jwtClaimHeaders config.JWTClaimHeaders
hpkePrivateKey *hpke.PrivateKey
authenticateKeyFetcher hpke.KeyFetcher
dataBrokerClient databroker.DataBrokerServiceClient
@ -44,11 +47,24 @@ func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) {
state := new(proxyState)
authenticateURL, err := cfg.Options.GetAuthenticateURL()
if err != nil {
return nil, err
}
state.sharedKey, err = cfg.Options.GetSharedKey()
if err != nil {
return nil, err
}
state.hpkePrivateKey, err = cfg.Options.GetHPKEPrivateKey()
if err != nil {
return nil, err
}
state.authenticateKeyFetcher = hpke.NewKeyFetcher(authenticateURL.ResolveReference(&url.URL{
Path: "/.well-known/pomerium/jwks.json",
}).String())
state.sharedCipher, err = cryptutil.NewAEADCipher(state.sharedKey)
if err != nil {
return nil, err