diff --git a/internal/controlplane/luascripts/fix-misdirected.lua b/internal/controlplane/luascripts/fix-misdirected.lua index 1cc19d5e7..a3db3e67a 100644 --- a/internal/controlplane/luascripts/fix-misdirected.lua +++ b/internal/controlplane/luascripts/fix-misdirected.lua @@ -12,8 +12,12 @@ function envoy_on_response(response_handle) local headers = response_handle:headers() local dynamic_meta = response_handle:streamInfo():dynamicMetadata() - local authority = - dynamic_meta:get("envoy.filters.http.lua")["request.authority"] + local filter_meta = dynamic_meta:get("envoy.filters.http.lua") + if filter_meta == nil then + return + end + + local authority = filter_meta["request.authority"] local expected_authority = "%s" -- if we got a 404 (no route found) and the authority header doesn't match diff --git a/internal/controlplane/xds_lua_test.go b/internal/controlplane/xds_lua_test.go index 6eb9199ee..2fbf84996 100644 --- a/internal/controlplane/xds_lua_test.go +++ b/internal/controlplane/xds_lua_test.go @@ -10,6 +10,61 @@ import ( lua "github.com/yuin/gopher-lua" ) +func TestLuaFixMisdirected(t *testing.T) { + t.Run("request", func(t *testing.T) { + L := lua.NewState() + defer L.Close() + + bs, err := luaFS.ReadFile("luascripts/fix-misdirected.lua") + require.NoError(t, err) + + err = L.DoString(string(bs)) + require.NoError(t, err) + + headers := map[string]string{ + ":authority": "TEST", + } + metadata := map[string]interface{}{} + dynamicMetadata := map[string]map[string]interface{}{} + handle := newLuaResponseHandle(L, headers, metadata, dynamicMetadata) + + err = L.CallByParam(lua.P{ + Fn: L.GetGlobal("envoy_on_request"), + NRet: 0, + Protect: true, + }, handle) + require.NoError(t, err) + + assert.Equal(t, map[string]map[string]interface{}{ + "envoy.filters.http.lua": { + "request.authority": "TEST", + }, + }, dynamicMetadata) + }) + t.Run("empty metadata", func(t *testing.T) { + L := lua.NewState() + defer L.Close() + + bs, err := luaFS.ReadFile("luascripts/fix-misdirected.lua") + require.NoError(t, err) + + err = L.DoString(string(bs)) + require.NoError(t, err) + + headers := map[string]string{} + metadata := map[string]interface{}{} + dynamicMetadata := map[string]map[string]interface{}{} + handle := newLuaResponseHandle(L, headers, metadata, dynamicMetadata) + + err = L.CallByParam(lua.P{ + Fn: L.GetGlobal("envoy_on_response"), + NRet: 0, + Protect: true, + }, handle) + require.NoError(t, err) + }) +} + func TestLuaRewriteHeaders(t *testing.T) { L := lua.NewState() defer L.Close() @@ -37,7 +92,7 @@ func TestLuaRewriteHeaders(t *testing.T) { }, }, } - handle := newLuaResponseHandle(L, headers, metadata) + handle := newLuaResponseHandle(L, headers, metadata, nil) err = L.CallByParam(lua.P{ Fn: L.GetGlobal("envoy_on_response"), @@ -49,9 +104,12 @@ func TestLuaRewriteHeaders(t *testing.T) { assert.Equal(t, "https://frontend/one/some/uri/", headers["Location"]) } -func newLuaResponseHandle(L *lua.LState, headers map[string]string, metadata map[string]interface{}) lua.LValue { - typ := L.NewTable() - L.SetFuncs(typ, map[string]lua.LGFunction{ +func newLuaResponseHandle(L *lua.LState, + headers map[string]string, + metadata map[string]interface{}, + dynamicMetadata map[string]map[string]interface{}, +) lua.LValue { + return newLuaType(L, map[string]lua.LGFunction{ "headers": func(L *lua.LState) int { L.Push(newLuaHeaders(L, headers)) return 1 @@ -60,12 +118,11 @@ func newLuaResponseHandle(L *lua.LState, headers map[string]string, metadata map L.Push(newLuaMetadata(L, metadata)) return 1 }, + "streamInfo": func(L *lua.LState) int { + L.Push(newLuaStreamInfo(L, dynamicMetadata)) + return 1 + }, }) - L.SetField(typ, "__index", typ) - - tbl := L.NewTable() - L.SetMetatable(tbl, typ) - return tbl } func newLuaHeaders(L *lua.LState, headers map[string]string) lua.LValue { @@ -102,8 +159,7 @@ func newLuaHeaders(L *lua.LState, headers map[string]string) lua.LValue { } func newLuaMetadata(L *lua.LState, metadata map[string]interface{}) lua.LValue { - typ := L.NewTable() - L.SetFuncs(typ, map[string]lua.LGFunction{ + return newLuaType(L, map[string]lua.LGFunction{ "get": func(L *lua.LState) int { _ = L.CheckTable(1) key := L.CheckString(2) @@ -118,6 +174,53 @@ func newLuaMetadata(L *lua.LState, metadata map[string]interface{}) lua.LValue { return 1 }, }) +} + +func newLuaDynamicMetadata(L *lua.LState, metadata map[string]map[string]interface{}) lua.LValue { + return newLuaType(L, map[string]lua.LGFunction{ + "get": func(L *lua.LState) int { + _ = L.CheckTable(1) + key := L.CheckString(2) + + obj, ok := metadata[key] + if !ok { + L.Push(lua.LNil) + return 0 + } + + L.Push(toLua(L, obj)) + return 1 + }, + "set": func(L *lua.LState) int { + _ = L.CheckTable(1) + filterName := L.CheckString(2) + key := L.CheckString(3) + value := L.CheckAny(4) + + m, ok := metadata[filterName] + if !ok { + m = make(map[string]interface{}) + metadata[filterName] = m + } + m[key] = fromLua(L, value) + + return 0 + }, + }) +} + +func newLuaStreamInfo(L *lua.LState, dynamicMetadata map[string]map[string]interface{}) lua.LValue { + return newLuaType(L, map[string]lua.LGFunction{ + "dynamicMetadata": func(L *lua.LState) int { + L.Push(newLuaDynamicMetadata(L, dynamicMetadata)) + return 1 + }, + }) +} + +func newLuaType(L *lua.LState, funcs map[string]lua.LGFunction) lua.LValue { + typ := L.NewTable() + L.SetFuncs(typ, funcs) L.SetField(typ, "__index", typ) tbl := L.NewTable() @@ -125,6 +228,35 @@ func newLuaMetadata(L *lua.LState, metadata map[string]interface{}) lua.LValue { return tbl } +func fromLua(L *lua.LState, v lua.LValue) interface{} { + switch v.Type() { + case lua.LTNil: + return nil + case lua.LTBool: + return bool(v.(lua.LBool)) + case lua.LTNumber: + return float64(v.(lua.LNumber)) + case lua.LTString: + return string(v.(lua.LString)) + case lua.LTTable: + a := []interface{}{} + m := map[string]interface{}{} + v.(*lua.LTable).ForEach(func(key, value lua.LValue) { + if key.Type() == lua.LTNumber { + a = append(a, fromLua(L, value)) + } else { + m[lua.LVAsString(key)] = fromLua(L, value) + } + }) + if len(a) > 0 { + return a + } + return m + default: + panic("not supported") + } +} + func toLua(L *lua.LState, obj interface{}) lua.LValue { // send the object through JSON to remove custom types var normalized interface{}