diff --git a/authenticate/authenticate.go b/authenticate/authenticate.go index 5476a0999..b2c999ac6 100644 --- a/authenticate/authenticate.go +++ b/authenticate/authenticate.go @@ -18,6 +18,7 @@ import ( "github.com/pomerium/pomerium/internal/frontend" "github.com/pomerium/pomerium/internal/grpc" "github.com/pomerium/pomerium/internal/grpc/cache/client" + "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/identity" "github.com/pomerium/pomerium/internal/identity/oauth" "github.com/pomerium/pomerium/internal/sessions" @@ -151,7 +152,7 @@ func New(opts config.Options) (*Authenticate, error) { WrappedStore: cookieStore}) qpStore := queryparam.NewStore(encryptedEncoder, "pomerium_programmatic_token") - headerStore := header.NewStore(encryptedEncoder, "Pomerium") + headerStore := header.NewStore(encryptedEncoder, httputil.AuthorizationTypePomerium) redirectURL, _ := urlutil.DeepCopy(opts.AuthenticateURL) redirectURL.Path = opts.AuthenticateCallbackPath diff --git a/authorize/grpc.go b/authorize/grpc.go index f5cdece25..e42275ed8 100644 --- a/authorize/grpc.go +++ b/authorize/grpc.go @@ -6,7 +6,6 @@ import ( "io" "io/ioutil" "net/http" - "net/http/httptest" "net/url" "strings" @@ -15,7 +14,6 @@ import ( "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/sessions" - "github.com/pomerium/pomerium/internal/sessions/cookie" "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/urlutil" @@ -37,11 +35,12 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v2.CheckRe isForwardAuth := handleForwardAuth(opts, in) hattrs := in.GetAttributes().GetRequest().GetHttp() + hreq := getHTTPRequestFromCheckRequest(in) hdrs := getCheckRequestHeaders(in) var requestHeaders []*envoy_api_v2_core.HeaderValueOption - sess, sesserr := a.loadSessionFromCheckRequest(in) + sess, sesserr := loadSession(hreq, a.currentOptions.Load(), a.currentEncoder.Load()) if a.isExpired(sess) { log.Info().Msg("refreshing session") if newSession, err := a.refreshSession(ctx, sess); err == nil { @@ -162,27 +161,24 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v2.CheckRe } func (a *Authorize) getEnvoyRequestHeaders(rawSession []byte) ([]*envoy_api_v2_core.HeaderValueOption, error) { - cookieStore, err := a.getCookieStore() + cookieStore, err := getCookieStore(a.currentOptions.Load(), a.currentEncoder.Load()) if err != nil { return nil, err } - recorder := httptest.NewRecorder() - err = cookieStore.SaveSession(recorder, nil /* unused by cookie store */, string(rawSession)) + hdrs, err := getJWTSetCookieHeaders(cookieStore, rawSession) if err != nil { - return nil, fmt.Errorf("authorize: error saving cookie: %w", err) + return nil, err } var hvos []*envoy_api_v2_core.HeaderValueOption - for k, vs := range recorder.Header() { - for _, v := range vs { - hvos = append(hvos, &envoy_api_v2_core.HeaderValueOption{ - Header: &envoy_api_v2_core.HeaderValue{ - Key: "x-pomerium-" + k, - Value: v, - }, - }) - } + for k, v := range hdrs { + hvos = append(hvos, &envoy_api_v2_core.HeaderValueOption{ + Header: &envoy_api_v2_core.HeaderValue{ + Key: "x-pomerium-" + k, + Value: v, + }, + }) } return hvos, nil @@ -229,59 +225,22 @@ func (a *Authorize) refreshSession(ctx context.Context, rawSession []byte) (newS return newJwt, nil } -func (a *Authorize) loadSessionFromCheckRequest(req *envoy_service_auth_v2.CheckRequest) ([]byte, error) { - cookieStore, err := a.getCookieStore() - if err != nil { - return nil, err - } - - sess, err := cookieStore.LoadSession(&http.Request{ - Header: getCheckRequestHeaders(req), - }) - if err != nil { - return nil, err - } - - return []byte(sess), nil -} - func (a *Authorize) isExpired(rawSession []byte) bool { state := sessions.State{} err := a.currentEncoder.Load().Unmarshal(rawSession, &state) return err == nil && state.IsExpired() } -func (a *Authorize) getCookieStore() (sessions.SessionStore, error) { - opts := a.currentOptions.Load() - encoder := a.currentEncoder.Load() - - cookieOptions := &cookie.Options{ - Name: opts.CookieName, - Domain: opts.CookieDomain, - Secure: opts.CookieSecure, - HTTPOnly: opts.CookieHTTPOnly, - Expire: opts.CookieExpire, +func getHTTPRequestFromCheckRequest(req *envoy_service_auth_v2.CheckRequest) *http.Request { + hattrs := req.GetAttributes().GetRequest().GetHttp() + return &http.Request{ + Method: hattrs.GetMethod(), + URL: getCheckRequestURL(req), + Header: getCheckRequestHeaders(req), + Body: ioutil.NopCloser(strings.NewReader(hattrs.GetBody())), + Host: hattrs.GetHost(), + RequestURI: hattrs.GetPath(), } - - cookieStore, err := cookie.NewStore(cookieOptions, encoder) - if err != nil { - return nil, err - } - return cookieStore, nil -} - -func getFullURL(rawurl, host string) string { - u, err := url.Parse(rawurl) - if err != nil { - u = &url.URL{Path: rawurl} - } - if u.Host == "" { - u.Host = host - } - if u.Scheme == "" { - u.Scheme = "http" - } - return u.String() } func getCheckRequestHeaders(req *envoy_service_auth_v2.CheckRequest) map[string][]string { diff --git a/authorize/grpc_test.go b/authorize/grpc_test.go deleted file mode 100644 index b61830b55..000000000 --- a/authorize/grpc_test.go +++ /dev/null @@ -1,21 +0,0 @@ -package authorize - -import ( - "testing" -) - -func Test_getFullURL(t *testing.T) { - tests := []struct { - rawurl, host, expect string - }{ - {"https://www.example.com/admin", "", "https://www.example.com/admin"}, - {"https://www.example.com/admin", "example.com", "https://www.example.com/admin"}, - {"/admin", "example.com", "http://example.com/admin"}, - } - for _, tt := range tests { - actual := getFullURL(tt.rawurl, tt.host) - if actual != tt.expect { - t.Errorf("expected getFullURL(%s, %s) to be %s, but got %s", tt.rawurl, tt.host, tt.expect, actual) - } - } -} diff --git a/authorize/session.go b/authorize/session.go new file mode 100644 index 000000000..22257008a --- /dev/null +++ b/authorize/session.go @@ -0,0 +1,75 @@ +package authorize + +import ( + "errors" + "fmt" + "net/http" + "net/http/httptest" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/encoding" + "github.com/pomerium/pomerium/internal/httputil" + "github.com/pomerium/pomerium/internal/sessions" + "github.com/pomerium/pomerium/internal/sessions/cookie" + "github.com/pomerium/pomerium/internal/sessions/header" + "github.com/pomerium/pomerium/internal/sessions/queryparam" + "github.com/pomerium/pomerium/internal/urlutil" +) + +func loadSession(req *http.Request, options config.Options, encoder encoding.MarshalUnmarshaler) ([]byte, error) { + var loaders []sessions.SessionLoader + cookieStore, err := getCookieStore(options, encoder) + if err != nil { + return nil, err + } + loaders = append(loaders, + cookieStore, + header.NewStore(encoder, httputil.AuthorizationTypePomerium), + queryparam.NewStore(encoder, urlutil.QuerySession), + ) + + for _, loader := range loaders { + sess, err := loader.LoadSession(req) + if err != nil && !errors.Is(err, sessions.ErrNoSessionFound) { + return nil, err + } else if err == nil { + return []byte(sess), nil + } + } + + return nil, sessions.ErrNoSessionFound +} + +func getCookieStore(options config.Options, encoder encoding.MarshalUnmarshaler) (sessions.SessionStore, error) { + cookieOptions := &cookie.Options{ + Name: options.CookieName, + Domain: options.CookieDomain, + Secure: options.CookieSecure, + HTTPOnly: options.CookieHTTPOnly, + Expire: options.CookieExpire, + } + cookieStore, err := cookie.NewStore(cookieOptions, encoder) + if err != nil { + return nil, err + } + return cookieStore, nil +} + +func getJWTSetCookieHeaders(cookieStore sessions.SessionStore, rawjwt []byte) (map[string]string, error) { + recorder := httptest.NewRecorder() + err := cookieStore.SaveSession(recorder, nil /* unused by cookie store */, string(rawjwt)) + if err != nil { + return nil, fmt.Errorf("authorize: error saving cookie: %w", err) + } + + res := recorder.Result() + res.Body.Close() + + hdrs := make(map[string]string) + for k, vs := range res.Header { + for _, v := range vs { + hdrs[k] = v + } + } + return hdrs, nil +} diff --git a/authorize/session_test.go b/authorize/session_test.go new file mode 100644 index 000000000..74b4fd3ad --- /dev/null +++ b/authorize/session_test.go @@ -0,0 +1,110 @@ +package authorize + +import ( + "net/url" + "regexp" + "testing" + + envoy_service_auth_v2 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v2" + "github.com/stretchr/testify/assert" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/encoding/jws" + "github.com/pomerium/pomerium/internal/sessions" +) + +func TestLoadSession(t *testing.T) { + opts := *config.NewDefaultOptions() + encoder, err := jws.NewHS256Signer(nil, "example.com") + if !assert.NoError(t, err) { + return + } + state := &sessions.State{ + Email: "bob@example.com", + } + rawjwt, err := encoder.Marshal(state) + if !assert.NoError(t, err) { + return + } + + load := func(t *testing.T, hattrs *envoy_service_auth_v2.AttributeContext_HttpRequest) (*sessions.State, error) { + req := getHTTPRequestFromCheckRequest(&envoy_service_auth_v2.CheckRequest{ + Attributes: &envoy_service_auth_v2.AttributeContext{ + Request: &envoy_service_auth_v2.AttributeContext_Request{ + Http: hattrs, + }, + }, + }) + raw, err := loadSession(req, opts, encoder) + if err != nil { + return nil, err + } + var state sessions.State + err = encoder.Unmarshal(raw, &state) + if err != nil { + return nil, err + } + return &state, nil + } + + t.Run("cookie", func(t *testing.T) { + cookieStore, err := getCookieStore(opts, encoder) + if !assert.NoError(t, err) { + return + } + hdrs, err := getJWTSetCookieHeaders(cookieStore, rawjwt) + if !assert.NoError(t, err) { + return + } + cookie := regexp.MustCompile(`^([^;]+)(;.*)?$`).ReplaceAllString(hdrs["Set-Cookie"], "$1") + + hattrs := &envoy_service_auth_v2.AttributeContext_HttpRequest{ + Id: "req-1", + Method: "GET", + Headers: map[string]string{ + "Cookie": cookie, + }, + Path: "/hello/world", + Host: "example.com", + Scheme: "https", + } + sess, err := load(t, hattrs) + assert.NoError(t, err) + if assert.NotNil(t, sess) { + assert.Equal(t, "bob@example.com", sess.Email) + } + }) + t.Run("header", func(t *testing.T) { + hattrs := &envoy_service_auth_v2.AttributeContext_HttpRequest{ + Id: "req-1", + Method: "GET", + Headers: map[string]string{ + "Authorization": "Pomerium " + string(rawjwt), + }, + Path: "/hello/world", + Host: "example.com", + Scheme: "https", + } + sess, err := load(t, hattrs) + assert.NoError(t, err) + if assert.NotNil(t, sess) { + assert.Equal(t, "bob@example.com", sess.Email) + } + }) + t.Run("query param", func(t *testing.T) { + hattrs := &envoy_service_auth_v2.AttributeContext_HttpRequest{ + Id: "req-1", + Method: "GET", + Path: "/hello/world?" + url.Values{ + "pomerium_session": []string{string(rawjwt)}, + }.Encode(), + Host: "example.com", + Scheme: "https", + } + sess, err := load(t, hattrs) + assert.NoError(t, err) + if assert.NotNil(t, sess) { + assert.Equal(t, "bob@example.com", sess.Email) + } + }) +} diff --git a/internal/httputil/headers.go b/internal/httputil/headers.go index d97394aa6..f7f72265b 100644 --- a/internal/httputil/headers.go +++ b/internal/httputil/headers.go @@ -1,5 +1,8 @@ package httputil +// AuthorizationTypePomerium is for Authorization: Pomerium JWT... headers +const AuthorizationTypePomerium = "Pomerium" + // Pomerium headers contain information added to a request. const ( // HeaderPomeriumResponse is set when pomerium itself creates a response, diff --git a/internal/urlutil/query_params.go b/internal/urlutil/query_params.go index 9c4bafc29..94d7768aa 100644 --- a/internal/urlutil/query_params.go +++ b/internal/urlutil/query_params.go @@ -11,6 +11,7 @@ const ( QueryIsProgrammatic = "pomerium_programmatic" QueryForwardAuth = "pomerium_forward_auth" QueryPomeriumJWT = "pomerium_jwt" + QuerySession = "pomerium_session" QuerySessionEncrypted = "pomerium_session_encrypted" QueryRedirectURI = "pomerium_redirect_uri" QueryRefreshToken = "pomerium_refresh_token" diff --git a/proxy/proxy.go b/proxy/proxy.go index b98746aad..5f5c19039 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -127,7 +127,7 @@ func New(opts config.Options) (*Proxy, error) { sessionStore: cookieStore, sessionLoaders: []sessions.SessionLoader{ cookieStore, - header.NewStore(encoder, "Pomerium"), + header.NewStore(encoder, httputil.AuthorizationTypePomerium), queryparam.NewStore(encoder, "pomerium_session")}, templates: template.Must(frontend.NewTemplates()), jwtClaimHeaders: opts.JWTClaimsHeaders,