package proxy import ( "bytes" "context" "io" "net/http" "net/http/httptest" "net/url" "testing" "github.com/google/go-cmp/cmp" "github.com/gorilla/mux" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/atomicutil" "github.com/pomerium/pomerium/internal/encoding/jws" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/urlutil" ) func TestProxy_SignOut(t *testing.T) { t.Parallel() tests := []struct { name string verb string redirectURL string wantStatus int }{ {"good post", http.MethodPost, "https://test.example", http.StatusFound}, {"good get", http.MethodGet, "https://test.example", http.StatusFound}, {"good empty default", http.MethodGet, "", http.StatusFound}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { opts := testOptions(t) p, err := New(context.Background(), &config.Config{Options: opts}) if err != nil { t.Fatal(err) } postForm := url.Values{} postForm.Add(urlutil.QueryRedirectURI, tt.redirectURL) uri := &url.URL{Path: "/"} query, _ := url.ParseQuery(uri.RawQuery) if tt.verb == http.MethodGet { query.Add(urlutil.QueryRedirectURI, tt.redirectURL) uri.RawQuery = query.Encode() } r := httptest.NewRequest(tt.verb, uri.String(), bytes.NewBufferString(postForm.Encode())) w := httptest.NewRecorder() if tt.verb == http.MethodPost { r.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") } p.SignOut(w, r) if status := w.Code; status != tt.wantStatus { t.Errorf("status code: got %v want %v", status, tt.wantStatus) } u, err := urlutil.ParseAndValidateURL(w.HeaderMap.Get("Location")) if assert.NoError(t, err) { assert.Equal(t, "/.pomerium/sign_out", u.Path) } }) } } func TestProxy_ProgrammaticLogin(t *testing.T) { t.Parallel() opts := testOptions(t) tests := []struct { name string options *config.Options method string scheme string host string path string headers map[string]string qp map[string]string wantStatus int wantBody string }{ { "good body not checked", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", nil, map[string]string{urlutil.QueryRedirectURI: "http://localhost"}, http.StatusOK, "", }, { "good body not checked", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", nil, map[string]string{urlutil.QueryRedirectURI: "http://localhost"}, http.StatusOK, "", }, { "router miss, bad redirect_uri query", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", nil, map[string]string{"bad_redirect_uri": "http://localhost"}, http.StatusNotFound, "", }, { "bad redirect_uri missing scheme", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", nil, map[string]string{urlutil.QueryRedirectURI: "localhost"}, http.StatusBadRequest, "{\"Status\":400}\n", }, { "bad redirect_uri not whitelisted", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", nil, map[string]string{urlutil.QueryRedirectURI: "https://example.com"}, http.StatusBadRequest, "{\"Status\":400}\n", }, { "bad http method", opts, http.MethodPost, "https", "corp.example.example", "/.pomerium/api/v1/login", nil, map[string]string{urlutil.QueryRedirectURI: "http://localhost"}, http.StatusMethodNotAllowed, "", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { p, err := New(context.Background(), &config.Config{Options: tt.options}) if err != nil { t.Fatal(err) } redirectURI := &url.URL{Scheme: tt.scheme, Host: tt.host, Path: tt.path} queryString := redirectURI.Query() for k, v := range tt.qp { queryString.Set(k, v) } redirectURI.RawQuery = queryString.Encode() r := httptest.NewRequest(tt.method, redirectURI.String(), nil) r.Header.Set("Accept", "application/json") w := httptest.NewRecorder() router := httputil.NewRouter() router = p.registerDashboardHandlers(router, config.NewDefaultOptions()) router.ServeHTTP(w, r) if status := w.Code; status != tt.wantStatus { t.Errorf("status code: got %v want %v", status, tt.wantStatus) t.Errorf("\n%+v", w.Body.String()) } if tt.wantBody != "" { body := w.Body.String() if diff := cmp.Diff(body, tt.wantBody); diff != "" { t.Errorf("wrong body\n%s", diff) } } }) } } func TestProxy_jwt(t *testing.T) { // without upstream headers being set req, _ := http.NewRequest(http.MethodGet, "https://www.example.com/.pomerium/jwt", nil) w := httptest.NewRecorder() proxy := &Proxy{ state: atomicutil.NewValue(&proxyState{}), } err := proxy.jwtAssertion(w, req) if !assert.Error(t, err) { return } // with upstream request headers being set rawJWT := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTY3MDg4OTI0MSwiZXhwIjoxNjcwODkyODQxfQ.YoROB12_-a8VxikPqrYOA576pLYoLFeGwXAOWCGpXgM" req, _ = http.NewRequest(http.MethodGet, "https://www.example.com/.pomerium/jwt", nil) w = httptest.NewRecorder() req.Header.Set(httputil.HeaderPomeriumJWTAssertion, rawJWT) err = proxy.jwtAssertion(w, req) if !assert.NoError(t, err) { return } assert.Equal(t, "application/jwt", w.Header().Get("Content-Type")) assert.Equal(t, w.Body.String(), rawJWT) } func TestProxy_jsonUserInfo(t *testing.T) { proxy := &Proxy{ state: atomicutil.NewValue(&proxyState{}), } t.Run("no_jwt", func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/.pomerium/user", nil) w := httptest.NewRecorder() err := proxy.jsonUserInfo(w, req) assert.ErrorContains(t, err, "not found") }) t.Run("no_sub_claim", func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/.pomerium/user", nil) req.Header.Set("X-Pomerium-Jwt-Assertion", "eyJ0eXAiOiJKV1QiLCJhbGciOiJub25lIn0.eyJmb28iOiJiYXIifQ.") w := httptest.NewRecorder() err := proxy.jsonUserInfo(w, req) assert.ErrorContains(t, err, "not found") }) t.Run("valid_jwt", func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/.pomerium/user", nil) req.Header.Set("X-Pomerium-Jwt-Assertion", "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTY3MDg4OTI0MSwiZXhwIjoxNjcwODkyODQxfQ.YoROB12_-a8VxikPqrYOA576pLYoLFeGwXAOWCGpXgM") w := httptest.NewRecorder() err := proxy.jsonUserInfo(w, req) require.NoError(t, err) result := w.Result() assert.Equal(t, http.StatusOK, result.StatusCode) assert.Equal(t, "application/json", result.Header.Get("Content-Type")) b, _ := io.ReadAll(result.Body) assert.Equal(t, `{"admin":true,"name":"John Doe","sub":"1234567890"}`, string(b)) }) } // The /.pomerium/jwt endpoint should be registered only if explicitly enabled. func TestProxy_registerDashboardHandlers_jwtEndpoint(t *testing.T) { proxy, err := New(context.Background(), &config.Config{Options: config.NewDefaultOptions()}) require.NoError(t, err) req := httptest.NewRequest(http.MethodGet, "/.pomerium/jwt", nil) rawJWT := "eyJ0eXAiOiJKV1QiLCJhbGciOiJub25lIn0.eyJzdWIiOiIxMjM0NTY3ODkwIn0." req.Header.Set("X-Pomerium-Jwt-Assertion", rawJWT) t.Run("disabled", func(t *testing.T) { opts := config.NewDefaultOptions() opts.RuntimeFlags[config.RuntimeFlagPomeriumJWTEndpoint] = false m := mux.NewRouter() proxy.registerDashboardHandlers(m, opts) w := httptest.NewRecorder() m.ServeHTTP(w, req) result := w.Result() assert.Equal(t, http.StatusNotFound, result.StatusCode) assert.Equal(t, "text/plain; charset=utf-8", result.Header.Get("Content-Type")) b, _ := io.ReadAll(result.Body) assert.Equal(t, "404 page not found\n", string(b)) }) t.Run("enabled", func(t *testing.T) { opts := config.NewDefaultOptions() opts.RuntimeFlags[config.RuntimeFlagPomeriumJWTEndpoint] = true m := mux.NewRouter() proxy.registerDashboardHandlers(m, opts) w := httptest.NewRecorder() m.ServeHTTP(w, req) result := w.Result() assert.Equal(t, http.StatusOK, result.StatusCode) assert.Equal(t, "application/jwt", result.Header.Get("Content-Type")) b, _ := io.ReadAll(result.Body) assert.Equal(t, rawJWT, string(b)) }) } func TestLoadSessionState(t *testing.T) { t.Parallel() t.Run("no session", func(t *testing.T) { t.Parallel() opts := testOptions(t) proxy, err := New(context.Background(), &config.Config{Options: opts}) require.NoError(t, err) r := httptest.NewRequest(http.MethodGet, "/.pomerium/", nil) w := httptest.NewRecorder() proxy.ServeHTTP(w, r) assert.Equal(t, http.StatusOK, w.Code) assert.Contains(t, w.Body.String(), "window.POMERIUM_DATA") assert.NotContains(t, w.Body.String(), "___SESSION_ID___") }) t.Run("cookie session", func(t *testing.T) { t.Parallel() opts := testOptions(t) proxy, err := New(context.Background(), &config.Config{Options: opts}) require.NoError(t, err) session := encodeSession(t, opts, &sessions.State{ ID: "___SESSION_ID___", }) r := httptest.NewRequest(http.MethodGet, "/.pomerium/", nil) r.AddCookie(&http.Cookie{ Name: opts.CookieName, Domain: opts.CookieDomain, Value: session, }) w := httptest.NewRecorder() proxy.ServeHTTP(w, r) assert.Equal(t, http.StatusOK, w.Code) assert.Contains(t, w.Body.String(), "___SESSION_ID___") }) t.Run("header session", func(t *testing.T) { t.Parallel() opts := testOptions(t) proxy, err := New(context.Background(), &config.Config{Options: opts}) require.NoError(t, err) session := encodeSession(t, opts, &sessions.State{ ID: "___SESSION_ID___", }) r := httptest.NewRequest(http.MethodGet, "/.pomerium/", nil) r.Header.Set("Authorization", "Bearer Pomerium-"+session) w := httptest.NewRecorder() proxy.ServeHTTP(w, r) assert.Equal(t, http.StatusOK, w.Code) assert.Contains(t, w.Body.String(), "___SESSION_ID___") }) } func encodeSession(t *testing.T, opts *config.Options, state *sessions.State) string { sharedKey, err := opts.GetSharedKey() require.NoError(t, err) encoder, err := jws.NewHS256Signer(sharedKey) require.NoError(t, err) sessionBS, err := encoder.Marshal(state) require.NoError(t, err) return string(sessionBS) }