mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-01 11:26:29 +02:00
proxy: remove unused handlers (#1317)
proxy: remove unused handlers authenticate: remove unused references to refresh_token Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
This commit is contained in:
parent
82deafee63
commit
c1b3b45d12
9 changed files with 63 additions and 235 deletions
|
@ -107,11 +107,9 @@ func (a *Authenticate) wellKnown(w http.ResponseWriter, r *http.Request) error {
|
||||||
// RFC7517 document, which contains the client's public keys.
|
// RFC7517 document, which contains the client's public keys.
|
||||||
JSONWebKeySetURL string `json:"jwks_uri"`
|
JSONWebKeySetURL string `json:"jwks_uri"`
|
||||||
OAuth2Callback string `json:"authentication_callback_endpoint"`
|
OAuth2Callback string `json:"authentication_callback_endpoint"`
|
||||||
ProgrammaticRefreshAPI string `json:"api_refresh_endpoint"`
|
|
||||||
}{
|
}{
|
||||||
state.redirectURL.ResolveReference(&url.URL{Path: "/.well-known/pomerium/jwks.json"}).String(),
|
state.redirectURL.ResolveReference(&url.URL{Path: "/.well-known/pomerium/jwks.json"}).String(),
|
||||||
state.redirectURL.ResolveReference(&url.URL{Path: "/oauth2/callback"}).String(),
|
state.redirectURL.ResolveReference(&url.URL{Path: "/oauth2/callback"}).String(),
|
||||||
state.redirectURL.ResolveReference(&url.URL{Path: "/api/v1/refresh"}).String(),
|
|
||||||
}
|
}
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||||
|
@ -234,17 +232,6 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
|
||||||
|
|
||||||
if r.FormValue(urlutil.QueryIsProgrammatic) == "true" {
|
if r.FormValue(urlutil.QueryIsProgrammatic) == "true" {
|
||||||
newSession.Programmatic = true
|
newSession.Programmatic = true
|
||||||
|
|
||||||
pbSession, err := session.Get(ctx, state.dataBrokerClient, s.ID)
|
|
||||||
if err != nil {
|
|
||||||
return httputil.NewError(http.StatusBadRequest, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
encSession, err := state.encryptedEncoder.Marshal(pbSession.GetOauthToken())
|
|
||||||
if err != nil {
|
|
||||||
return httputil.NewError(http.StatusBadRequest, err)
|
|
||||||
}
|
|
||||||
callbackParams.Set(urlutil.QueryRefreshToken, string(encSession))
|
|
||||||
callbackParams.Set(urlutil.QueryIsProgrammatic, "true")
|
callbackParams.Set(urlutil.QueryIsProgrammatic, "true")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -548,7 +548,7 @@ func TestWellKnownEndpoint(t *testing.T) {
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
h.ServeHTTP(rr, req)
|
h.ServeHTTP(rr, req)
|
||||||
body := rr.Body.String()
|
body := rr.Body.String()
|
||||||
expected := `{"jwks_uri":"https://auth.example.com/.well-known/pomerium/jwks.json","authentication_callback_endpoint":"https://auth.example.com/oauth2/callback","api_refresh_endpoint":"https://auth.example.com/api/v1/refresh"}`
|
expected := `{"jwks_uri":"https://auth.example.com/.well-known/pomerium/jwks.json","authentication_callback_endpoint":"https://auth.example.com/oauth2/callback"}`
|
||||||
assert.Equal(t, body, expected)
|
assert.Equal(t, body, expected)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@ func TestDashboard(t *testing.T) {
|
||||||
t.Run("user dashboard", func(t *testing.T) {
|
t.Run("user dashboard", func(t *testing.T) {
|
||||||
client := testcluster.NewHTTPClient()
|
client := testcluster.NewHTTPClient()
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io/.pomerium", nil)
|
req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io/.pomerium/", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -31,6 +31,22 @@ func TestDashboard(t *testing.T) {
|
||||||
|
|
||||||
assert.Equal(t, http.StatusFound, res.StatusCode, "unexpected status code")
|
assert.Equal(t, http.StatusFound, res.StatusCode, "unexpected status code")
|
||||||
})
|
})
|
||||||
|
t.Run("dashboard strict slash redirect", func(t *testing.T) {
|
||||||
|
client := testcluster.NewHTTPClient()
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io/.pomerium", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := client.Do(req)
|
||||||
|
if !assert.NoError(t, err, "unexpected http error") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusMovedPermanently, res.StatusCode, "unexpected status code")
|
||||||
|
})
|
||||||
t.Run("image asset", func(t *testing.T) {
|
t.Run("image asset", func(t *testing.T) {
|
||||||
client := testcluster.NewHTTPClient()
|
client := testcluster.NewHTTPClient()
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,6 @@ const (
|
||||||
QuerySession = "pomerium_session"
|
QuerySession = "pomerium_session"
|
||||||
QuerySessionEncrypted = "pomerium_session_encrypted"
|
QuerySessionEncrypted = "pomerium_session_encrypted"
|
||||||
QueryRedirectURI = "pomerium_redirect_uri"
|
QueryRedirectURI = "pomerium_redirect_uri"
|
||||||
QueryRefreshToken = "pomerium_refresh_token"
|
|
||||||
QueryAccessTokenID = "pomerium_session_access_token_id"
|
QueryAccessTokenID = "pomerium_session_access_token_id"
|
||||||
QueryAudience = "pomerium_session_audience"
|
QueryAudience = "pomerium_session_audience"
|
||||||
QueryProgrammaticToken = "pomerium_programmatic_token"
|
QueryProgrammaticToken = "pomerium_programmatic_token"
|
||||||
|
|
|
@ -7,11 +7,9 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/pomerium/csrf"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/middleware"
|
"github.com/pomerium/pomerium/internal/middleware"
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
)
|
)
|
||||||
|
@ -20,23 +18,7 @@ import (
|
||||||
func (p *Proxy) registerDashboardHandlers(r *mux.Router) *mux.Router {
|
func (p *Proxy) registerDashboardHandlers(r *mux.Router) *mux.Router {
|
||||||
h := r.PathPrefix(dashboardPath).Subrouter()
|
h := r.PathPrefix(dashboardPath).Subrouter()
|
||||||
h.Use(middleware.SetHeaders(httputil.HeadersContentSecurityPolicy))
|
h.Use(middleware.SetHeaders(httputil.HeadersContentSecurityPolicy))
|
||||||
// 1. Retrieve the user session and add it to the request context
|
|
||||||
h.Use(func(h http.Handler) http.Handler {
|
|
||||||
return sessions.RetrieveSession(p.state.Load().sessionStore)(h)
|
|
||||||
})
|
|
||||||
// 2. AuthN - Verify the user is authenticated. Set email, group, & id headers
|
|
||||||
h.Use(p.AuthenticateSession)
|
|
||||||
// 3. Enforce CSRF protections for any non-idempotent http method
|
|
||||||
h.Use(func(h http.Handler) http.Handler {
|
|
||||||
opts := p.currentOptions.Load()
|
|
||||||
state := p.state.Load()
|
|
||||||
return csrf.Protect(
|
|
||||||
state.cookieSecret,
|
|
||||||
csrf.Secure(opts.CookieSecure),
|
|
||||||
csrf.CookieName(fmt.Sprintf("%s_csrf", opts.CookieName)),
|
|
||||||
csrf.ErrorHandler(httputil.HandlerFunc(httputil.CSRFFailureHandler)),
|
|
||||||
)(h)
|
|
||||||
})
|
|
||||||
// dashboard endpoints can be used by user's to view, or modify their session
|
// dashboard endpoints can be used by user's to view, or modify their session
|
||||||
h.Path("/").HandlerFunc(p.UserDashboard).Methods(http.MethodGet)
|
h.Path("/").HandlerFunc(p.UserDashboard).Methods(http.MethodGet)
|
||||||
h.Path("/sign_out").HandlerFunc(p.SignOut).Methods(http.MethodGet, http.MethodPost)
|
h.Path("/sign_out").HandlerFunc(p.SignOut).Methods(http.MethodGet, http.MethodPost)
|
||||||
|
@ -48,13 +30,8 @@ func (p *Proxy) registerDashboardHandlers(r *mux.Router) *mux.Router {
|
||||||
c.Use(func(h http.Handler) http.Handler {
|
c.Use(func(h http.Handler) http.Handler {
|
||||||
return middleware.ValidateSignature(p.state.Load().sharedKey)(h)
|
return middleware.ValidateSignature(p.state.Load().sharedKey)(h)
|
||||||
})
|
})
|
||||||
|
|
||||||
c.Path("/").
|
|
||||||
Handler(httputil.HandlerFunc(p.ProgrammaticCallback)).
|
|
||||||
Methods(http.MethodGet).
|
|
||||||
Queries(urlutil.QueryIsProgrammatic, "true")
|
|
||||||
|
|
||||||
c.Path("/").Handler(httputil.HandlerFunc(p.Callback)).Methods(http.MethodGet)
|
c.Path("/").Handler(httputil.HandlerFunc(p.Callback)).Methods(http.MethodGet)
|
||||||
|
|
||||||
// Programmatic API handlers and middleware
|
// Programmatic API handlers and middleware
|
||||||
a := r.PathPrefix(dashboardPath + "/api").Subrouter()
|
a := r.PathPrefix(dashboardPath + "/api").Subrouter()
|
||||||
// login api handler generates a user-navigable login url to authenticate
|
// login api handler generates a user-navigable login url to authenticate
|
||||||
|
@ -92,7 +69,7 @@ func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) {
|
||||||
httputil.Redirect(w, r, urlutil.NewSignedURL(state.sharedKey, &signoutURL).String(), http.StatusFound)
|
httputil.Redirect(w, r, urlutil.NewSignedURL(state.sharedKey, &signoutURL).String(), http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UserDashboard redirects to the authenticate dasbhoard.
|
// UserDashboard redirects to the authenticate dashboard.
|
||||||
func (p *Proxy) UserDashboard(w http.ResponseWriter, r *http.Request) {
|
func (p *Proxy) UserDashboard(w http.ResponseWriter, r *http.Request) {
|
||||||
state := p.state.Load()
|
state := p.state.Load()
|
||||||
|
|
||||||
|
@ -115,10 +92,23 @@ func (p *Proxy) Callback(w http.ResponseWriter, r *http.Request) error {
|
||||||
redirectURLString := r.FormValue(urlutil.QueryRedirectURI)
|
redirectURLString := r.FormValue(urlutil.QueryRedirectURI)
|
||||||
encryptedSession := r.FormValue(urlutil.QuerySessionEncrypted)
|
encryptedSession := r.FormValue(urlutil.QuerySessionEncrypted)
|
||||||
|
|
||||||
if _, err := p.saveCallbackSession(w, r, encryptedSession); err != nil {
|
redirectURL, err := urlutil.ParseAndValidateURL(redirectURLString)
|
||||||
|
if err != nil {
|
||||||
return httputil.NewError(http.StatusBadRequest, err)
|
return httputil.NewError(http.StatusBadRequest, err)
|
||||||
}
|
}
|
||||||
httputil.Redirect(w, r, redirectURLString, http.StatusFound)
|
|
||||||
|
rawJWT, err := p.saveCallbackSession(w, r, encryptedSession)
|
||||||
|
if err != nil {
|
||||||
|
return httputil.NewError(http.StatusBadRequest, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// if programmatic, encode the session jwt as a query param
|
||||||
|
if isProgrammatic := r.FormValue(urlutil.QueryIsProgrammatic); isProgrammatic == "true" {
|
||||||
|
q := redirectURL.Query()
|
||||||
|
q.Set(urlutil.QueryPomeriumJWT, string(rawJWT))
|
||||||
|
redirectURL.RawQuery = q.Encode()
|
||||||
|
}
|
||||||
|
httputil.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -168,28 +158,3 @@ func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) error
|
||||||
w.Write([]byte(response))
|
w.Write([]byte(response))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProgrammaticCallback handles a successful call to the authenticate service.
|
|
||||||
// In addition to returning the individual route session (JWT) it also returns
|
|
||||||
// the refresh token.
|
|
||||||
func (p *Proxy) ProgrammaticCallback(w http.ResponseWriter, r *http.Request) error {
|
|
||||||
redirectURLString := r.FormValue(urlutil.QueryRedirectURI)
|
|
||||||
encryptedSession := r.FormValue(urlutil.QuerySessionEncrypted)
|
|
||||||
|
|
||||||
redirectURL, err := urlutil.ParseAndValidateURL(redirectURLString)
|
|
||||||
if err != nil {
|
|
||||||
return httputil.NewError(http.StatusBadRequest, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
rawJWT, err := p.saveCallbackSession(w, r, encryptedSession)
|
|
||||||
if err != nil {
|
|
||||||
return httputil.NewError(http.StatusBadRequest, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
q := redirectURL.Query()
|
|
||||||
q.Set(urlutil.QueryPomeriumJWT, string(rawJWT))
|
|
||||||
q.Set(urlutil.QueryRefreshToken, r.FormValue(urlutil.QueryRefreshToken))
|
|
||||||
redirectURL.RawQuery = q.Encode()
|
|
||||||
httputil.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
|
@ -64,6 +64,29 @@ func TestProxy_Signout(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProxy_UserDashboard(t *testing.T) {
|
||||||
|
opts := testOptions(t)
|
||||||
|
err := ValidateOptions(opts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
proxy, err := New(&config.Config{Options: opts})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/.pomerium/sign_out", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
proxy.UserDashboard(rr, req)
|
||||||
|
if status := rr.Code; status != http.StatusFound {
|
||||||
|
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusFound)
|
||||||
|
}
|
||||||
|
body := rr.Body.String()
|
||||||
|
want := proxy.state.Load().authenticateURL.String()
|
||||||
|
if !strings.Contains(body, want) {
|
||||||
|
t.Errorf("handler returned unexpected body: got %v want %s ", body, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestProxy_SignOut(t *testing.T) {
|
func TestProxy_SignOut(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
|
@ -105,13 +128,6 @@ func TestProxy_SignOut(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
func uriParseHelper(s string) *url.URL {
|
|
||||||
uri, err := url.Parse(s)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
return uri
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProxy_Callback(t *testing.T) {
|
func TestProxy_Callback(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
@ -464,7 +480,7 @@ func TestProxy_ProgrammaticCallback(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
httputil.HandlerFunc(p.ProgrammaticCallback).ServeHTTP(w, r)
|
httputil.HandlerFunc(p.Callback).ServeHTTP(w, r)
|
||||||
if status := w.Code; status != tt.wantStatus {
|
if status := w.Code; status != tt.wantStatus {
|
||||||
t.Errorf("status code: got %v want %v", status, tt.wantStatus)
|
t.Errorf("status code: got %v want %v", status, tt.wantStatus)
|
||||||
t.Errorf("\n%+v", w.Body.String())
|
t.Errorf("\n%+v", w.Body.String())
|
||||||
|
|
|
@ -14,8 +14,6 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type authorizeResponse struct {
|
type authorizeResponse struct {
|
||||||
|
@ -23,35 +21,6 @@ type authorizeResponse struct {
|
||||||
statusCode int32
|
statusCode int32
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthenticateSession is middleware to enforce a valid authentication
|
|
||||||
// session state is retrieved from the users's request context.
|
|
||||||
func (p *Proxy) AuthenticateSession(next http.Handler) http.Handler {
|
|
||||||
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
|
||||||
ctx, span := trace.StartSpan(r.Context(), "proxy.AuthenticateSession")
|
|
||||||
defer span.End()
|
|
||||||
|
|
||||||
if _, err := sessions.FromContext(ctx); err != nil {
|
|
||||||
log.FromRequest(r).Debug().Err(err).Msg("proxy: session state")
|
|
||||||
return p.redirectToSignin(w, r)
|
|
||||||
}
|
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Proxy) redirectToSignin(w http.ResponseWriter, r *http.Request) error {
|
|
||||||
state := p.state.Load()
|
|
||||||
|
|
||||||
signinURL := *state.authenticateSigninURL
|
|
||||||
q := signinURL.Query()
|
|
||||||
q.Set(urlutil.QueryRedirectURI, urlutil.GetAbsoluteURL(r).String())
|
|
||||||
signinURL.RawQuery = q.Encode()
|
|
||||||
log.FromRequest(r).Debug().Str("url", signinURL.String()).Msg("proxy: redirectToSignin")
|
|
||||||
httputil.Redirect(w, r, urlutil.NewSignedURL(state.sharedKey, &signinURL).String(), http.StatusFound)
|
|
||||||
state.sessionStore.ClearSession(w, r)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Proxy) isAuthorized(w http.ResponseWriter, r *http.Request) (*authorizeResponse, error) {
|
func (p *Proxy) isAuthorized(w http.ResponseWriter, r *http.Request) (*authorizeResponse, error) {
|
||||||
state := p.state.Load()
|
state := p.state.Load()
|
||||||
|
|
||||||
|
@ -104,20 +73,6 @@ func (p *Proxy) isAuthorized(w http.ResponseWriter, r *http.Request) (*authorize
|
||||||
return ar, nil
|
return ar, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetResponseHeaders sets a map of response headers.
|
|
||||||
func SetResponseHeaders(headers map[string]string) func(next http.Handler) http.Handler {
|
|
||||||
return func(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
ctx, span := trace.StartSpan(r.Context(), "proxy.SetResponseHeaders")
|
|
||||||
defer span.End()
|
|
||||||
for key, val := range headers {
|
|
||||||
r.Header.Set(key, val)
|
|
||||||
}
|
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// jwtClaimMiddleware logs and propagates JWT claim information via request headers
|
// jwtClaimMiddleware logs and propagates JWT claim information via request headers
|
||||||
//
|
//
|
||||||
// if returnJWTInfo is set to true, it will also return JWT claim information in the response
|
// if returnJWTInfo is set to true, it will also return JWT claim information in the response
|
||||||
|
|
|
@ -1,89 +1,17 @@
|
||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
|
||||||
"gopkg.in/square/go-jose.v2/jwt"
|
"gopkg.in/square/go-jose.v2/jwt"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/encoding"
|
|
||||||
"github.com/pomerium/pomerium/internal/encoding/jws"
|
"github.com/pomerium/pomerium/internal/encoding/jws"
|
||||||
"github.com/pomerium/pomerium/internal/encoding/mock"
|
|
||||||
"github.com/pomerium/pomerium/internal/identity"
|
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
mstore "github.com/pomerium/pomerium/internal/sessions/mock"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestProxy_AuthenticateSession(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
|
||||||
fmt.Fprint(w, http.StatusText(http.StatusOK))
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
})
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
refreshRespStatus int
|
|
||||||
errOnFailure bool
|
|
||||||
session sessions.SessionStore
|
|
||||||
ctxError error
|
|
||||||
provider identity.Authenticator
|
|
||||||
encoder encoding.MarshalUnmarshaler
|
|
||||||
refreshURL string
|
|
||||||
|
|
||||||
wantStatus int
|
|
||||||
}{
|
|
||||||
{"good", 200, false, &mstore.Store{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, nil, identity.MockProvider{}, &mock.Encoder{}, "", http.StatusOK},
|
|
||||||
{"invalid session", 200, false, &mstore.Store{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, errors.New("hi"), identity.MockProvider{}, &mock.Encoder{}, "", http.StatusFound},
|
|
||||||
{"expired", 200, false, &mstore.Store{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, &mock.Encoder{}, "", http.StatusFound},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(tt.refreshRespStatus)
|
|
||||||
fmt.Fprintln(w, "REFRESH GOOD")
|
|
||||||
}))
|
|
||||||
defer ts.Close()
|
|
||||||
rURL := ts.URL
|
|
||||||
if tt.refreshURL != "" {
|
|
||||||
rURL = tt.refreshURL
|
|
||||||
}
|
|
||||||
|
|
||||||
a := Proxy{
|
|
||||||
state: newAtomicProxyState(&proxyState{
|
|
||||||
sharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
|
|
||||||
cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="),
|
|
||||||
authenticateURL: uriParseHelper("https://authenticate.corp.example"),
|
|
||||||
authenticateSigninURL: uriParseHelper("https://authenticate.corp.example/sign_in"),
|
|
||||||
authenticateRefreshURL: uriParseHelper(rURL),
|
|
||||||
sessionStore: tt.session,
|
|
||||||
encoder: tt.encoder,
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
state, _ := tt.session.LoadSession(r)
|
|
||||||
ctx := r.Context()
|
|
||||||
ctx = sessions.NewContext(ctx, state, tt.ctxError)
|
|
||||||
r = r.WithContext(ctx)
|
|
||||||
r.Header.Set("Accept", "application/json")
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
got := a.jwtClaimMiddleware(false)(a.AuthenticateSession(fn))
|
|
||||||
got.ServeHTTP(w, r)
|
|
||||||
if status := w.Code; status != tt.wantStatus {
|
|
||||||
t.Errorf("AuthenticateSession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_jwtClaimMiddleware(t *testing.T) {
|
func Test_jwtClaimMiddleware(t *testing.T) {
|
||||||
claimHeaders := []string{"email", "groups", "missing"}
|
claimHeaders := []string{"email", "groups", "missing"}
|
||||||
sharedKey := "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="
|
sharedKey := "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="
|
||||||
|
@ -125,39 +53,3 @@ func Test_jwtClaimMiddleware(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxy_SetResponseHeaders(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
|
||||||
var sb strings.Builder
|
|
||||||
for k, v := range r.Header {
|
|
||||||
k = strings.ToLower(k)
|
|
||||||
for _, h := range v {
|
|
||||||
sb.WriteString(fmt.Sprintf("%v: %v\n", k, h))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
fmt.Fprint(w, sb.String())
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
})
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
setHeaders map[string]string
|
|
||||||
wantHeaders string
|
|
||||||
}{
|
|
||||||
{"good", map[string]string{"x-gonna": "give-it-to-ya"}, "x-gonna: give-it-to-ya\n"},
|
|
||||||
{"nil", nil, ""},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
|
|
||||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
got := SetResponseHeaders(tt.setHeaders)(fn)
|
|
||||||
got.ServeHTTP(w, r)
|
|
||||||
if diff := cmp.Diff(w.Body.String(), tt.wantHeaders); diff != "" {
|
|
||||||
t.Errorf("SetResponseHeaders() :\n %s", diff)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -25,9 +25,8 @@ args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
class PomeriumSession:
|
class PomeriumSession:
|
||||||
def __init__(self, jwt, refresh_token):
|
def __init__(self, jwt):
|
||||||
self.jwt = jwt
|
self.jwt = jwt
|
||||||
self.refresh_token = refresh_token
|
|
||||||
|
|
||||||
def to_json(self):
|
def to_json(self):
|
||||||
return json.dumps(self.__dict__, indent=2)
|
return json.dumps(self.__dict__, indent=2)
|
||||||
|
@ -55,7 +54,6 @@ class Callback(http.server.BaseHTTPRequestHandler):
|
||||||
path_qp = urllib.parse.parse_qs(path)
|
path_qp = urllib.parse.parse_qs(path)
|
||||||
session = PomeriumSession(
|
session = PomeriumSession(
|
||||||
path_qp.get("pomerium_jwt")[0],
|
path_qp.get("pomerium_jwt")[0],
|
||||||
path_qp.get("pomerium_refresh_token")[0],
|
|
||||||
)
|
)
|
||||||
done = True
|
done = True
|
||||||
response = session.to_json().encode()
|
response = session.to_json().encode()
|
||||||
|
|
Loading…
Add table
Reference in a new issue