Browse Source

Re-introduce the stack-reuse optimization

As predicted, with some extra thought it is possible to preserve the
stack-reuse optimization in lvm.c (the one that has to do with CIST_FRESH).

First, this improves performance by reducing stack usage for Lua-to-Lua
calls. Furthermore, it also fixes some bugs! For example, when we
disabled the optimization, we inadvertently broke vararg functions
(it had to do with the Protect macro messing with L->top).
Hugo Musso Gualandi 4 years ago
parent
commit
16c39df347
3 changed files with 53 additions and 122 deletions
  1. 23 31
      src/luaot-trampoline.c
  2. 27 21
      src/luaot.c
  3. 3 70
      src/lvm.c

+ 23 - 31
src/luaot-trampoline.c

@@ -676,6 +676,19 @@ void luaot_PrintOpcodeComment(Proto *f, int pc)
 //   - Jumps go into a trampoline
 //
 
+static
+void println_goto_ret()
+{
+    // Emits the code that lvm.c runs after its "ret" label: return if the frame is CIST_FRESH, otherwise switch to ci->previous and keep executing the caller.
+    // Call this in the generated opcodes wherever the interpreter version does "goto ret;"
+    println("        if (ci->callstatus & CIST_FRESH)");
+    println("            return;  /* end this frame */");
+    println("        else {");
+    println("            ci = ci->previous;");
+    println("            return luaV_execute(L, ci); /* continue running caller in this frame */"); // (!) differs from lvm.c, which does "goto returning" at this point
+    println("        }");
+}
+
 static
 void create_function(Proto *f)
 {
@@ -717,21 +730,6 @@ void create_function(Proto *f)
     println("  StkId ra;");
     printnl();
 
-    // If we are resuming a coroutine, jump to the savedpc.
-    // However, allowing coroutines hurts performance so we disable it by default.
-    if (enable_coroutines) {
-        println("  switch (pc - code) {");
-        for (int pc = 0; pc < f->sizecode; pc++) {
-            println("    case %d: goto label_%02d;", pc, pc);
-        }
-        println("  }");
-    } else {
-        println("  if (pc != code) {");
-        println("    luaG_runerror(L, \"This program was compiled without support for coroutines\\n\");");
-        println("  }");
-    }
-    printnl();
-
     println("  while (1) {");
     println("    switch (pc - code) {");
     for (int pc = 0; pc < f->sizecode; pc++) {
@@ -740,13 +738,8 @@ void create_function(Proto *f)
 
         luaot_PrintOpcodeComment(f, pc);
 
-        if (enable_coroutines) {
-            println("      case %d: label_%02d: {", pc, pc);
-            println("        aot_vmfetch(0x%08x);", instr);
-        } else {
-            println("      case %d: {", pc);
-            println("        aot_vmfetch(0x%08x);", instr);
-        }
+        println("      case %d: {", pc);
+        println("        aot_vmfetch(0x%08x);", instr);
 
         switch (op) {
             case OP_MOVE: {
@@ -1308,8 +1301,6 @@ void create_function(Proto *f)
                 break;
             }
             case OP_CALL: {
-                // We have to adapt this opcode to remove the optimization that
-                // reuses the luaV_execute stack frame. The goto startfunc.
                 println("        CallInfo *newci;");
                 println("        int b = GETARG_B(i);");
                 println("        int nresults = GETARG_C(i) - 1;");
@@ -1320,8 +1311,9 @@ void create_function(Proto *f)
                 println("        if ((newci = luaD_precall(L, ra, nresults)) == NULL)");
                 println("            updatetrap(ci);  /* C call; nothing else to be done */");
                 println("        else {");
-                println("            newci->callstatus = CIST_FRESH;");
-                println("            Protect(luaV_execute(L, newci));");//(!)
+                println("            ci = newci;");
+                println("            ci->callstatus = 0;  /* call re-uses 'luaV_execute' */");
+                println("            return luaV_execute(L, ci);"); // (!!!)
                 println("        }");
                 // FALLTHROUGH
                 break;
@@ -1353,11 +1345,11 @@ void create_function(Proto *f)
                 println("          ci->func -= delta;  /* restore 'func' (if vararg) */");
                 println("          luaD_poscall(L, ci, cast_int(L->top - ra));  /* finish caller */");
                 println("          updatetrap(ci);  /* 'luaD_poscall' can change hooks */");
-                println("          return;  /* caller returns after the tail call */");//(!)
+                println_goto_ret(); // (!)
                 println("        }");
                 println("        ci->func -= delta;  /* restore 'func' (if vararg) */");
                 println("        luaD_pretailcall(L, ci, ra, b);  /* prepare call frame */");
-                println("        return luaV_execute(L, ci); /* execute the callee */");//(!)
+                println("        return luaV_execute(L, ci); /* execute the callee */"); // (!)
                 // FALLTHROUGH
                 break;
             }
@@ -1379,7 +1371,7 @@ void create_function(Proto *f)
                 println("        L->top = ra + n;  /* set call for 'luaD_poscall' */");
                 println("        luaD_poscall(L, ci, n);");
                 println("        updatetrap(ci);  /* 'luaD_poscall' can change hooks */");
-                println("        return;"); //(!)
+                println_goto_ret(); // (!)
                 // FALLTHROUGH
                 break;
             }
@@ -1397,7 +1389,7 @@ void create_function(Proto *f)
                 println("          for (nres = ci->nresults; l_unlikely(nres > 0); nres--)");
                 println("            setnilvalue(s2v(L->top++));  /* all results are nil */");
                 println("        }");
-                println("        return;"); //(!)
+                println_goto_ret(); // (!)
                 // FALLTHROUGH
                 break;
             }
@@ -1420,7 +1412,7 @@ void create_function(Proto *f)
                 println("              setnilvalue(s2v(L->top++));");
                 println("          }");
                 println("        }");
-                println("        return;"); //(!)
+                println_goto_ret(); // (!)
                 // FALLTHROUGH
                 break;
             }

+ 27 - 21
src/luaot.c

@@ -684,6 +684,19 @@ int jump_target(Proto *f, int pc)
     return (pc+1) + GETARG_sJ(instr);
 }
 
+static
+void println_goto_ret()
+{
+    // Emits the code that lvm.c runs after its "ret" label: return if the frame is CIST_FRESH, otherwise switch to ci->previous and keep executing the caller.
+    // Call this in the generated opcodes wherever the interpreter version does "goto ret;"
+    println("    if (ci->callstatus & CIST_FRESH)");
+    println("        return;  /* end this frame */");
+    println("    else {");
+    println("        ci = ci->previous;");
+    println("        return luaV_execute(L, ci); /* continue running caller in this frame */"); // (!) differs from lvm.c, which does "goto returning" at this point
+    println("    }");
+}
+
 static
 void create_function(Proto *f)
 {
@@ -725,19 +738,13 @@ void create_function(Proto *f)
     println("  StkId ra;");
     printnl();
 
-    // If we are resuming a coroutine, jump to the savedpc.
-    // However, allowing coroutines hurts performance so we disable it by default.
-    if (enable_coroutines) {
-        println("  switch (pc - code) {");
-        for (int pc = 0; pc < f->sizecode; pc++) {
-            println("    case %d: goto label_%02d;", pc, pc);
-        }
-        println("  }");
-    } else {
-        println("  if (pc != code) {");
-        println("    luaG_runerror(L, \"This program was compiled without support for coroutines\\n\");");
-        println("  }");
+    // If we are returning from another function, or resuming a coroutine,
+    // jump back to where we left off.
+    println("  switch (pc - code) {");
+    for (int pc = 0; pc < f->sizecode; pc++) {
+        println("    case %d: goto label_%02d;", pc, pc);
     }
+    println("  }");
     printnl();
 
     for (int pc = 0; pc < f->sizecode; pc++) {
@@ -1225,8 +1232,6 @@ void create_function(Proto *f)
                 break;
             }
             case OP_CALL: {
-                // We have to adapt this opcode to remove the optimization that
-                // reuses the luaV_execute stack frame. The goto startfunc.
                 println("    CallInfo *newci;");
                 println("    int b = GETARG_B(i);");
                 println("    int nresults = GETARG_C(i) - 1;");
@@ -1237,8 +1242,9 @@ void create_function(Proto *f)
                 println("    if ((newci = luaD_precall(L, ra, nresults)) == NULL)");
                 println("        updatetrap(ci);  /* C call; nothing else to be done */");
                 println("    else {");
-                println("        newci->callstatus = CIST_FRESH;");
-                println("        Protect(luaV_execute(L, newci));");//(!)
+                println("        ci = newci;");
+                println("        ci->callstatus = 0;  /* call re-uses 'luaV_execute' */");
+                println("        return luaV_execute(L, ci);"); // (!!!)
                 println("    }");
                 break;
             }
@@ -1269,11 +1275,11 @@ void create_function(Proto *f)
                 println("      ci->func -= delta;  /* restore 'func' (if vararg) */");
                 println("      luaD_poscall(L, ci, cast_int(L->top - ra));  /* finish caller */");
                 println("      updatetrap(ci);  /* 'luaD_poscall' can change hooks */");
-                println("      return;  /* caller returns after the tail call */");//(!)
+                println_goto_ret(); // (!)
                 println("    }");
                 println("    ci->func -= delta;  /* restore 'func' (if vararg) */");
                 println("    luaD_pretailcall(L, ci, ra, b);  /* prepare call frame */");
-                println("    return luaV_execute(L, ci); /* execute the callee */");//(!)
+                println("    return luaV_execute(L, ci); /* execute the callee */"); // (!)
                 break;
             }
             case OP_RETURN: {
@@ -1294,7 +1300,7 @@ void create_function(Proto *f)
                 println("    L->top = ra + n;  /* set call for 'luaD_poscall' */");
                 println("    luaD_poscall(L, ci, n);");
                 println("    updatetrap(ci);  /* 'luaD_poscall' can change hooks */");
-                println("    return;"); //(!)
+                println_goto_ret();
                 break;
             }
             case OP_RETURN0: {
@@ -1311,7 +1317,7 @@ void create_function(Proto *f)
                 println("      for (nres = ci->nresults; l_unlikely(nres > 0); nres--)");
                 println("        setnilvalue(s2v(L->top++));  /* all results are nil */");
                 println("    }");
-                println("    return;"); //(!)
+                println_goto_ret();
                 break;
             }
             case OP_RETURN1: {
@@ -1333,7 +1339,7 @@ void create_function(Proto *f)
                 println("          setnilvalue(s2v(L->top++));");
                 println("      }");
                 println("    }");
-                println("    return;"); //(!)
+                println_goto_ret();
                 break;
             }
             case OP_FORLOOP: {

+ 3 - 70
src/lvm.c

@@ -1139,18 +1139,14 @@ void luaV_execute (lua_State *L, CallInfo *ci) {
 #if LUA_USE_JUMPTABLE
 #include "ljumptab.h"
 #endif
-#if LUAOT
+ startfunc:
   trap = L->hookmask;
+ returning:  /* trap already set */
   cl = clLvalue(s2v(ci->func));
-  lua_assert(ci->callstatus & CIST_FRESH);
+#if LUAOT
   if (cl->p->aot_implementation) {
       return cl->p->aot_implementation(L, ci);
   }
-#else
- startfunc:
-  trap = L->hookmask;
- returning:  /* trap already set */
-  cl = clLvalue(s2v(ci->func));
 #endif
   k = cl->p->k;
   pc = ci->u.l.savedpc;
@@ -1626,24 +1622,6 @@ void luaV_execute (lua_State *L, CallInfo *ci) {
         }
         vmbreak;
       }
-#if LUAOT
-      vmcase(OP_CALL) {
-        CallInfo *newci;
-        int b = GETARG_B(i);
-        int nresults = GETARG_C(i) - 1;
-        if (b != 0)  /* fixed number of arguments? */
-          L->top = ra + b;  /* top signals number of arguments */
-        /* else previous instruction set top */
-        savepc(L);  /* in case of errors */
-        if ((newci = luaD_precall(L, ra, nresults)) == NULL)
-          updatetrap(ci);  /* C call; nothing else to be done */
-        else {  /* Lua call: run function in this same C frame */
-          newci->callstatus = CIST_FRESH;
-          Protect(luaV_execute(L, newci));
-        }
-        vmbreak;
-      }
-#else
       vmcase(OP_CALL) {
         CallInfo *newci;
         int b = GETARG_B(i);
@@ -1661,44 +1639,6 @@ void luaV_execute (lua_State *L, CallInfo *ci) {
         }
         vmbreak;
       }
-#endif
-
-#if LUAOT
-      vmcase(OP_TAILCALL) {
-        int b = GETARG_B(i);  /* number of arguments + 1 (function) */
-        int nparams1 = GETARG_C(i);
-        /* delta is virtual 'func' - real 'func' (vararg functions) */
-        int delta = (nparams1) ? ci->u.l.nextraargs + nparams1 : 0;
-        if (b != 0)
-          L->top = ra + b;
-        else  /* previous instruction set top */
-          b = cast_int(L->top - ra);
-        savepc(ci);  /* several calls here can raise errors */
-        if (TESTARG_k(i)) {
-          luaF_closeupval(L, base);  /* close upvalues from current call */
-          lua_assert(L->tbclist < base);  /* no pending tbc variables */
-          lua_assert(base == ci->func + 1);
-        }
-        while (!ttisfunction(s2v(ra))) {  /* not a function? */
-          luaD_tryfuncTM(L, ra);  /* try '__call' metamethod */
-          b++;  /* there is now one extra argument */
-          checkstackGCp(L, 1, ra);
-        }
-        if (!ttisLclosure(s2v(ra))) {  /* C function? */
-          luaD_precall(L, ra, LUA_MULTRET);  /* call it */
-          updatetrap(ci);
-          updatestack(ci);  /* stack may have been relocated */
-          ci->func -= delta;  /* restore 'func' (if vararg) */
-          luaD_poscall(L, ci, cast_int(L->top - ra));  /* finish caller */
-          updatetrap(ci);  /* 'luaD_poscall' can change hooks */
-          goto ret;  /* caller returns after the tail call */
-        }
-        ci->func -= delta;  /* restore 'func' (if vararg) */
-        luaD_pretailcall(L, ci, ra, b);  /* prepare call frame */
-        return luaV_execute(L, ci); /* execute the callee */
-        vmbreak;
-      }
-#else
       vmcase(OP_TAILCALL) {
         int b = GETARG_B(i);  /* number of arguments + 1 (function) */
         int nparams1 = GETARG_C(i);
@@ -1731,9 +1671,7 @@ void luaV_execute (lua_State *L, CallInfo *ci) {
         ci->func -= delta;  /* restore 'func' (if vararg) */
         luaD_pretailcall(L, ci, ra, b);  /* prepare call frame */
         goto startfunc;  /* execute the callee */
-        vmbreak;
       }
-#endif
       vmcase(OP_RETURN) {
         int n = GETARG_B(i) - 1;  /* number of results */
         int nparams1 = GETARG_C(i);
@@ -1790,17 +1728,12 @@ void luaV_execute (lua_State *L, CallInfo *ci) {
           }
         }
        ret:  /* return from a Lua function */
-#if LUAOT
-        lua_assert(ci->callstatus & CIST_FRESH);
-        return;
-#else
         if (ci->callstatus & CIST_FRESH)
           return;  /* end this frame */
         else {
           ci = ci->previous;
           goto returning;  /* continue running caller in this frame */
         }
-#endif
       }
       vmcase(OP_FORLOOP) {
         if (ttisinteger(s2v(ra + 2))) {  /* integer loop? */