add unit tests, update internal policy evaluator

This commit is contained in:
Kenneth Jenkins 2024-08-29 19:14:47 -07:00
parent 8d38e6b47d
commit 5ff3e9794e
4 changed files with 154 additions and 3 deletions

View file

@ -238,9 +238,15 @@ func (e *Evaluator) Evaluate(ctx context.Context, req *Request) (*Result, error)
return res, nil
}
// Internal endpoints that require a logged-in user.
var internalPathsNeedingLogin = map[string]struct{}{
"/.pomerium/jwt": {},
"/.pomerium/user": {},
"/.pomerium/webauthn": {},
}
func (e *Evaluator) evaluateInternal(_ context.Context, req *Request) (*PolicyResponse, error) {
// these endpoints require a logged-in user
if req.HTTP.Path == "/.pomerium/webauthn" || req.HTTP.Path == "/.pomerium/jwt" {
if _, needsLogin := internalPathsNeedingLogin[req.HTTP.Path]; needsLogin {
if req.Session.ID == "" {
return &PolicyResponse{
Allow: NewRuleResult(false, criteria.ReasonUserUnauthenticated),

View file

@ -554,6 +554,72 @@ func TestEvaluator(t *testing.T) {
})
}
func TestEvaluator_EvaluateInternal(t *testing.T) {
ctx := context.Background()
store := store.New()
evaluator, err := New(ctx, store, nil)
require.NoError(t, err)
// Internal paths that do not require login.
for _, path := range []string{
"/.pomerium/",
"/.pomerium/device-enrolled",
"/.pomerium/sign_out",
} {
t.Run(path, func(t *testing.T) {
req := Request{
IsInternal: true,
HTTP: RequestHTTP{
Path: path,
},
}
result, err := evaluator.Evaluate(ctx, &req)
require.NoError(t, err)
assert.Equal(t, RuleResult{
Value: true,
Reasons: criteria.NewReasons(criteria.ReasonPomeriumRoute),
AdditionalData: map[string]any{},
}, result.Allow)
assert.Equal(t, RuleResult{}, result.Deny)
})
}
// Internal paths that do require login.
for _, path := range []string{
"/.pomerium/jwt",
"/.pomerium/user",
"/.pomerium/webauthn",
} {
t.Run(path, func(t *testing.T) {
req := Request{
IsInternal: true,
HTTP: RequestHTTP{
Path: path,
},
}
result, err := evaluator.Evaluate(ctx, &req)
require.NoError(t, err)
assert.Equal(t, RuleResult{
Value: false,
Reasons: criteria.NewReasons(criteria.ReasonUserUnauthenticated),
AdditionalData: map[string]any{},
}, result.Allow)
assert.Equal(t, RuleResult{}, result.Deny)
// Simulate a logged-in user by setting a non-empty session ID.
req.Session.ID = "123456"
result, err = evaluator.Evaluate(ctx, &req)
require.NoError(t, err)
assert.Equal(t, RuleResult{
Value: true,
Reasons: criteria.NewReasons(criteria.ReasonPomeriumRoute),
AdditionalData: map[string]any{},
}, result.Allow)
assert.Equal(t, RuleResult{}, result.Deny)
})
}
}
func TestPolicyEvaluatorReuse(t *testing.T) {
ctx := context.Background()

View file

@ -182,7 +182,9 @@ func userInfoFromJWT(rawJWT string) map[string]any {
}
var payload map[string]any
if parsed.UnsafeClaimsWithoutVerification(&payload) != nil || payload["sub"] == "" {
if parsed.UnsafeClaimsWithoutVerification(&payload) != nil {
return nil
} else if sub, ok := payload["sub"].(string); !ok || sub == "" {
return nil
}

View file

@ -2,13 +2,16 @@ package proxy
import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/atomicutil"
@ -183,3 +186,77 @@ func TestProxy_jwt(t *testing.T) {
assert.Equal(t, "application/jwt", w.Header().Get("Content-Type"))
assert.Equal(t, w.Body.String(), rawJWT)
}
func TestProxy_jsonUserInfo(t *testing.T) {
proxy := &Proxy{
state: atomicutil.NewValue(&proxyState{}),
}
t.Run("no_jwt", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/.pomerium/user", nil)
w := httptest.NewRecorder()
err := proxy.jsonUserInfo(w, req)
assert.ErrorContains(t, err, "not found")
})
t.Run("no_sub_claim", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/.pomerium/user", nil)
req.Header.Set("X-Pomerium-Jwt-Assertion", "eyJ0eXAiOiJKV1QiLCJhbGciOiJub25lIn0.eyJmb28iOiJiYXIifQ.")
w := httptest.NewRecorder()
err := proxy.jsonUserInfo(w, req)
assert.ErrorContains(t, err, "not found")
})
t.Run("valid_jwt", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/.pomerium/user", nil)
req.Header.Set("X-Pomerium-Jwt-Assertion",
"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTY3MDg4OTI0MSwiZXhwIjoxNjcwODkyODQxfQ.YoROB12_-a8VxikPqrYOA576pLYoLFeGwXAOWCGpXgM")
w := httptest.NewRecorder()
err := proxy.jsonUserInfo(w, req)
require.NoError(t, err)
result := w.Result()
assert.Equal(t, http.StatusOK, result.StatusCode)
assert.Equal(t, "application/json", result.Header.Get("Content-Type"))
b, _ := io.ReadAll(result.Body)
assert.Equal(t, `{"admin":true,"name":"John Doe","sub":"1234567890"}`, string(b))
})
}
// The /.pomerium/jwt endpoint should be registered only if explicitly enabled.
func TestProxy_registerDashboardHandlers_jwtEndpoint(t *testing.T) {
proxy := &Proxy{
state: atomicutil.NewValue(&proxyState{}),
}
req := httptest.NewRequest(http.MethodGet, "/.pomerium/jwt", nil)
rawJWT := "eyJ0eXAiOiJKV1QiLCJhbGciOiJub25lIn0.eyJzdWIiOiIxMjM0NTY3ODkwIn0."
req.Header.Set("X-Pomerium-Jwt-Assertion", rawJWT)
t.Run("disabled", func(t *testing.T) {
opts := config.NewDefaultOptions()
opts.RuntimeFlags[config.RuntimeFlagPomeriumJWTEndpoint] = false
m := mux.NewRouter()
proxy.registerDashboardHandlers(m, opts)
w := httptest.NewRecorder()
m.ServeHTTP(w, req)
result := w.Result()
assert.Equal(t, http.StatusNotFound, result.StatusCode)
assert.Equal(t, "text/plain; charset=utf-8", result.Header.Get("Content-Type"))
b, _ := io.ReadAll(result.Body)
assert.Equal(t, "404 page not found\n", string(b))
})
t.Run("enabled", func(t *testing.T) {
opts := config.NewDefaultOptions()
opts.RuntimeFlags[config.RuntimeFlagPomeriumJWTEndpoint] = true
m := mux.NewRouter()
proxy.registerDashboardHandlers(m, opts)
w := httptest.NewRecorder()
m.ServeHTTP(w, req)
result := w.Result()
assert.Equal(t, http.StatusOK, result.StatusCode)
assert.Equal(t, "application/jwt", result.Header.Get("Content-Type"))
b, _ := io.ReadAll(result.Body)
assert.Equal(t, rawJWT, string(b))
})
}