Browse Source

Put mmbin inside the binop

Avoids a bunch of branch statements on the main function body.
Hugo Musso Gualandi 3 years ago
parent
commit
c188b98016
2 changed files with 126 additions and 81 deletions
  1. 115 58
      src/functions_header.c
  2. 11 23
      src/luaot_functions.c

+ 115 - 58
src/functions_header.c

@@ -38,15 +38,13 @@ typedef struct {
   if (ttisinteger(v1)) {  \
     lua_Integer iv1 = ivalue(v1);  \
     setivalue(s2v(ra), iop(L, iv1, imm));  \
-    return 1; \
   }  \
   else if (ttisfloat(v1)) {  \
     lua_Number nb = fltvalue(v1);  \
     lua_Number fimm = cast_num(imm);  \
     setfltvalue(s2v(ra), fop(L, nb, fimm)); \
-    return 1; \
  } else { \
-    return 0; \
+    luaot_mmbin(L,ctx,pc); \
  }}
 
 #undef op_arithf_aux
@@ -54,9 +52,8 @@ typedef struct {
   lua_Number n1; lua_Number n2;  \
   if (tonumberns(v1, n1) && tonumberns(v2, n2)) {  \
     setfltvalue(s2v(ra), fop(L, n1, n2));  \
-    return 1; \
   } else { \
-    return 0; \
+    luaot_mmbin(L,ctx,pc); \
   }}
 
 #undef op_arith_aux
@@ -64,7 +61,6 @@ typedef struct {
   if (ttisinteger(v1) && ttisinteger(v2)) {  \
     lua_Integer i1 = ivalue(v1); lua_Integer i2 = ivalue(v2);  \
     setivalue(s2v(ra), iop(L, i1, i2));  \
-    return 1; \
   }  \
   else op_arithf_aux(L,fop); }
 
@@ -73,18 +69,16 @@ typedef struct {
   lua_Integer i2 = ivalue(v2);  \
   if (tointegerns(v1, &i1)) {  \
     setivalue(s2v(ra), op(i1, i2));  \
-    return 1; \
   } else { \
-    return 0; \
+    luaot_mmbin(L,ctx,pc); \
   }}
 
 #define op_bitwise_aux(L,op) {  \
   lua_Integer i1; lua_Integer i2;  \
   if (tointegerns(v1, &i1) && tointegerns(v2, &i2)) {  \
     setivalue(s2v(ra), op(i1, i2));  \
-    return 1; \
   } else { \
-    return 0; \
+    luaot_mmbin(L,ctx,pc); \
   }}
 
 //
@@ -94,43 +88,38 @@ typedef struct {
 #define luaot_arithI(L,f) { \
   TValue *v1 = vRB(i);  \
   int imm = GETARG_sC(i);  \
-  if ((f)(L, ra, v1, imm)) \
-    goto LUAOT_SKIP1; \
+  (f)(L, ctx, pc, ra, v1, imm); \
 }
 
 #define luaot_arithK(L, f) { \
   TValue *v1 = vRB(i); \
-  TValue *v2 = KC(i); lua_assert(ttisnumber(v2)); \
-  if ((f)(L, ra, v1, v2)) \
-    goto LUAOT_SKIP1; \
+  TValue *v2 = KC(i); \
+  lua_assert(ttisnumber(v2)); \
+  (f)(L, ctx, pc, ra, v1, v2); \
 }
 
 #define luaot_arith(L, f) { \
   TValue *v1 = vRB(i); \
   TValue *v2 = vRC(i); \
-  if ((f)(L, ra, v1, v2)) \
-    goto LUAOT_SKIP1; \
+  (f)(L, ctx, pc, ra, v1, v2); \
 }
 
 #define luaot_arithf(L, f) { \
-  TValue *v1 = vRB(i);  \
-  TValue *v2 = vRC(i);  \
-  if ((f)(L, ra, v1, v2)) \
-    goto LUAOT_SKIP1; \
+  TValue *v1 = vRB(i); \
+  TValue *v2 = vRC(i); \
+  (f)(L, ctx, pc, ra, v1, v2); \
 }
 
 #define luaot_bitwiseK(L, f) { \
-  TValue *v1 = vRB(i);  \
+  TValue *v1 = vRB(i); \
   TValue *v2 = KC(i);  \
-  if ((f)(L, ra, v1, v2)) \
-    goto LUAOT_SKIP1; \
+  (f)(L, ctx, pc, ra, v1, v2); \
 }
 
 #define luaot_bitwise(L, f) {  \
-  TValue *v1 = vRB(i);  \
-  TValue *v2 = vRC(i);  \
-  if ((f)(L, ra, v1, v2)) \
-    goto LUAOT_SKIP1; \
+  TValue *v1 = vRB(i); \
+  TValue *v2 = vRC(i); \
+  (f)(L, ctx, pc, ra, v1, v2); \
 }
 
 /*
@@ -234,6 +223,50 @@ void luaot_vmfetch_trap(lua_State *L, LuaotExecuteState *ctx, const Instruction
     updatebase(ctx->ci);  /* correct stack */ \
 }
 
+static
+void luaot_mmbin(lua_State *L, LuaotExecuteState *ctx, const Instruction *pc)
+{
+    Instruction i;
+    StkId ra;
+    aot_vmfetch(*pc);
+    pc = pc + 1;
+
+    switch (GET_OPCODE(i)) {
+        case OP_MMBIN: {
+            Instruction pi = *(pc - 2);  /* original arith. expression */
+            TValue *rb = vRB(i);
+            TMS tm = (TMS)GETARG_C(i);
+            StkId result = RA(pi);
+            lua_assert(OP_ADD <= GET_OPCODE(pi) && GET_OPCODE(pi) <= OP_SHR);
+            Protect(luaT_trybinTM(L, s2v(ra), rb, result, tm));
+            break;
+        }
+        case OP_MMBINI: {
+            Instruction pi = *(pc - 2);  /* original arith. expression */
+            int imm = GETARG_sB(i);
+            TMS tm = (TMS)GETARG_C(i);
+            int flip = GETARG_k(i);
+            StkId result = RA(pi);
+            Protect(luaT_trybiniTM(L, s2v(ra), imm, flip, result, tm));
+            break;
+        }
+        case OP_MMBINK: {
+            Instruction pi = *(pc - 2);  /* original arith. expression */
+            TValue *imm = KB(i);
+            TMS tm = (TMS)GETARG_C(i);
+            int flip = GETARG_k(i);
+            StkId result = RA(pi);
+            Protect(luaT_trybinassocTM(L, s2v(ra), imm, flip, result, tm));
+            break;
+        }
+        default: {
+            lua_assert(0);
+            break;
+        }
+    }
+}
+
+
 static
 void luaot_MOVE(lua_State *L,
                 StkId ra, StkId rb)
@@ -473,171 +506,195 @@ void luaot_SELF(lua_State *L, LuaotExecuteState *ctx, const Instruction *pc,
 }
 
 static
-int luaot_ADDI(lua_State *L, StkId ra, TValue *v1, int imm)
+void luaot_ADDI(lua_State *L, LuaotExecuteState *ctx, const Instruction *pc,
+               StkId ra, TValue *v1, int imm)
 {
     op_arithI_aux(L, l_addi, luai_numadd);
 }
 
 static
-int luaot_ADDK(lua_State *L, StkId ra, TValue *v1, TValue *v2)
+void luaot_ADDK(lua_State *L, LuaotExecuteState *ctx, const Instruction *pc,
+               StkId ra, TValue *v1, TValue *v2)
 {
     op_arith_aux(L, l_addi, luai_numadd);
 }
 
 static
-int luaot_SUBK(lua_State *L, StkId ra, TValue *v1, TValue *v2)
+void luaot_SUBK(lua_State *L, LuaotExecuteState *ctx, const Instruction *pc,
+               StkId ra, TValue *v1, TValue *v2)
 {
     op_arith_aux(L, l_subi, luai_numsub);
 }
 
 static
-int luaot_MULK(lua_State *L, StkId ra, TValue *v1, TValue *v2)
+void luaot_MULK(lua_State *L, LuaotExecuteState *ctx, const Instruction *pc,
+               StkId ra, TValue *v1, TValue *v2)
 {
     op_arith_aux(L, l_muli, luai_nummul);
 }
 
 static
-int luaot_MODK(lua_State *L, StkId ra, TValue *v1, TValue *v2)
+void luaot_MODK(lua_State *L, LuaotExecuteState *ctx, const Instruction *pc,
+               StkId ra, TValue *v1, TValue *v2)
 {
     op_arith_aux(L, luaV_mod, luaV_modf);
 }
 
 static
-int luaot_POWK(lua_State *L, StkId ra, TValue *v1, TValue *v2)
+void luaot_POWK(lua_State *L, LuaotExecuteState *ctx, const Instruction *pc,
+               StkId ra, TValue *v1, TValue *v2)
 {
     op_arithf_aux(L, luai_numpow);
 }
 
 static
-int luaot_DIVK(lua_State *L, StkId ra, TValue *v1, TValue *v2)
+void luaot_DIVK(lua_State *L, LuaotExecuteState *ctx, const Instruction *pc,
+               StkId ra, TValue *v1, TValue *v2)
 {
     op_arithf_aux(L, luai_numdiv);
 }
 
 static
-int luaot_IDIVK(lua_State *L, StkId ra, TValue *v1, TValue *v2)
+void luaot_IDIVK(lua_State *L, LuaotExecuteState *ctx, const Instruction *pc,
+                StkId ra, TValue *v1, TValue *v2)
 {
     op_arith_aux(L, luaV_idiv, luai_numidiv);
 }
 
 static
-int luaot_BANDK(lua_State *L, StkId ra, TValue *v1, TValue *v2)
+void luaot_BANDK(lua_State *L, LuaotExecuteState *ctx, const Instruction *pc,
+                StkId ra, TValue *v1, TValue *v2)
 {
     op_bitwiseK_aux(L, l_band);
 }
 
 static
-int luaot_BORK(lua_State *L, StkId ra, TValue *v1, TValue *v2)
+void luaot_BORK(lua_State *L, LuaotExecuteState *ctx, const Instruction *pc,
+               StkId ra, TValue *v1, TValue *v2)
 {
     op_bitwiseK_aux(L, l_bor);
 }
 
 static
-int luaot_BXORK(lua_State *L, StkId ra, TValue *v1, TValue *v2)
+void luaot_BXORK(lua_State *L, LuaotExecuteState *ctx, const Instruction *pc,
+                StkId ra, TValue *v1, TValue *v2)
 {
     op_bitwiseK_aux(L, l_bxor);
 }
 
 static
-int luaot_SHRI(lua_State *L, StkId ra, TValue *rb, int ic)
+void luaot_SHRI(lua_State *L, LuaotExecuteState *ctx, const Instruction *pc,
+               StkId ra, TValue *rb, int ic)
 {
     lua_Integer ib;
     if (tointegerns(rb, &ib)) {
         setivalue(s2v(ra), luaV_shiftl(ib, -ic));
-        return 1;
     } else {
-        return 0;
+        luaot_mmbin(L, ctx, pc);
     }
 }
 
 static
-int luaot_SHLI(lua_State *L,
+void luaot_SHLI(lua_State *L, LuaotExecuteState *ctx, const Instruction *pc,
+
                StkId ra, TValue *rb, int ic)
 {
     lua_Integer ib;
     if (tointegerns(rb, &ib)) {
         setivalue(s2v(ra), luaV_shiftl(ic, ib));
-        return 1;
     } else {
-        return 0;
+        luaot_mmbin(L, ctx, pc);
     }
 }
 
 static
-int luaot_ADD(lua_State *L, StkId ra, TValue *v1, TValue *v2)
+void luaot_ADD(lua_State *L, LuaotExecuteState *ctx, const Instruction *pc,
+              StkId ra, TValue *v1, TValue *v2)
 {
     op_arith_aux(L, l_addi, luai_numadd);
 }
 
 static
-int luaot_SUB(lua_State *L, StkId ra, TValue *v1, TValue *v2)
+void luaot_SUB(lua_State *L, LuaotExecuteState *ctx, const Instruction *pc,
+              StkId ra, TValue *v1, TValue *v2)
 {
     op_arith_aux(L, l_subi, luai_numsub);
 }
 
 static
-int luaot_MUL(lua_State *L, StkId ra, TValue *v1, TValue *v2)
+void luaot_MUL(lua_State *L, LuaotExecuteState *ctx, const Instruction *pc,
+              StkId ra, TValue *v1, TValue *v2)
 {
     op_arith_aux(L, l_muli, luai_nummul);
 }
 
 static
-int luaot_MOD(lua_State *L, StkId ra, TValue *v1, TValue *v2)
+void luaot_MOD(lua_State *L, LuaotExecuteState *ctx, const Instruction *pc,
+              StkId ra, TValue *v1, TValue *v2)
 {
     op_arith_aux(L, luaV_mod, luaV_modf);
 }
 
 static
-int luaot_POW(lua_State *L, StkId ra, TValue *v1, TValue *v2)
+void luaot_POW(lua_State *L, LuaotExecuteState *ctx, const Instruction *pc,
+              StkId ra, TValue *v1, TValue *v2)
 {
     op_arithf_aux(L, luai_numpow);
 }
 
 static
-int luaot_DIV(lua_State *L, StkId ra, TValue *v1, TValue *v2)
+void luaot_DIV(lua_State *L, LuaotExecuteState *ctx, const Instruction *pc,
+              StkId ra, TValue *v1, TValue *v2)
 {
     op_arithf_aux(L, luai_numdiv);
 }
 
 static
-int luaot_IDIV(lua_State *L, StkId ra, TValue *v1, TValue *v2)
+void luaot_IDIV(lua_State *L, LuaotExecuteState *ctx, const Instruction *pc,
+               StkId ra, TValue *v1, TValue *v2)
 {
     op_arith_aux(L, luaV_idiv, luai_numidiv);
 }
 
 static
-int luaot_BAND(lua_State *L, StkId ra, TValue *v1, TValue *v2)
+void luaot_BAND(lua_State *L, LuaotExecuteState *ctx, const Instruction *pc,
+               StkId ra, TValue *v1, TValue *v2)
 {
     op_bitwise_aux(L, l_band);
 }
 
 static
-int luaot_BOR(lua_State *L, StkId ra, TValue *v1, TValue *v2)
+void luaot_BOR(lua_State *L, LuaotExecuteState *ctx, const Instruction *pc,
+              StkId ra, TValue *v1, TValue *v2)
 {
     op_bitwise_aux(L, l_bor);
 }
 
 static
-int luaot_BXOR(lua_State *L, StkId ra, TValue *v1, TValue *v2)
+void luaot_BXOR(lua_State *L, LuaotExecuteState *ctx, const Instruction *pc,
+               StkId ra, TValue *v1, TValue *v2)
 {
     op_bitwise_aux(L, l_bxor);
 }
 
 static
-int luaot_SHR(lua_State *L, StkId ra, TValue *v1, TValue *v2)
+void luaot_SHR(lua_State *L, LuaotExecuteState *ctx, const Instruction *pc,
+              StkId ra, TValue *v1, TValue *v2)
 {
     op_bitwise_aux(L, luaV_shiftr);
 }
 
 static
-int luaot_SHL(lua_State *L, StkId ra, TValue *v1, TValue *v2)
+void luaot_SHL(lua_State *L, LuaotExecuteState *ctx, const Instruction *pc,
+              StkId ra, TValue *v1, TValue *v2)
 {
     op_bitwise_aux(L, luaV_shiftl);
 }
 
 #if 0
 static
-int luaot_(lua_State *L, LuaotExecuteState *ctx, const Instruction *pc,
+void luaot_(lua_State *L, LuaotExecuteState *ctx, const Instruction *pc,
+           LuaotExecuteState *ctx, const Instruction *pc,
            StkId ra, TValue *v1, int c)
 {
 }

+ 11 - 23
src/luaot_functions.c

@@ -80,6 +80,12 @@ void create_function(Proto *f)
 
         luaot_PrintOpcodeComment(f, pc);
 
+        if (op == OP_MMBIN || op == OP_MMBINI || op == OP_MMBINK) {
+            println("  label_%02d: {}", pc);
+            printnl();
+            continue;
+        }
+
         // While an instruction is executing, the program counter typically
         // points towards the next instruction. There are some corner cases
         // where the program counter getss adjusted mid-instruction, but I
@@ -273,16 +279,13 @@ void create_function(Proto *f)
             case OP_SHRI: {
                 println("    TValue *rb = vRB(i);");
                 println("    int ic = GETARG_sC(i);");
-                println("    if (luaot_SHRI(L, ra, rb, ic))");
-                println("      goto LUAOT_SKIP1;");
+                println("    luaot_SHRI(L, ctx, pc, ra, rb, ic);");
                 break;
             }
             case OP_SHLI: {
                 println("    TValue *rb = vRB(i);");
                 println("    int ic = GETARG_sC(i);");
-                println("    if (luaot_SHLI(L, ra, rb, ic))");
-                println("      goto LUAOT_SKIP1;");
-                println("    luaot_arithI(L, luaot_SHRL);");
+                println("    luaot_SHLI(L, ctx, pc, ra, rb, ic);");
                 break;
             }
             case OP_ADD: {
@@ -334,30 +337,15 @@ void create_function(Proto *f)
                 break;
             }
             case OP_MMBIN: {
-                println("    Instruction pi = 0x%08x; /* original arith. expression */", f->code[pc-1]);
-                println("    TValue *rb = vRB(i);");
-                println("    TMS tm = (TMS)GETARG_C(i);");
-                println("    StkId result = RA(pi);");
-                println("    lua_assert(OP_ADD <= GET_OPCODE(pi) && GET_OPCODE(pi) <= OP_SHR);");
-                println("    Protect(luaT_trybinTM(L, s2v(ra), rb, result, tm));");
+                // Inlined in previous opcode
                 break;
             }
             case OP_MMBINI: {
-                println("    Instruction pi = 0x%0x;  /* original arith. expression */", f->code[pc-1]);
-                println("    int imm = GETARG_sB(i);");
-                println("    TMS tm = (TMS)GETARG_C(i);");
-                println("    int flip = GETARG_k(i);");
-                println("    StkId result = RA(pi);");
-                println("    Protect(luaT_trybiniTM(L, s2v(ra), imm, flip, result, tm));");
+                // Inlined in previous opcode
                 break;
             }
             case OP_MMBINK: {
-                println("    Instruction pi = 0x%08x;  /* original arith. expression */", f->code[pc-1]);
-                println("    TValue *imm = KB(i);");
-                println("    TMS tm = (TMS)GETARG_C(i);");
-                println("    int flip = GETARG_k(i);");
-                println("    StkId result = RA(pi);");
-                println("    Protect(luaT_trybinassocTM(L, s2v(ra), imm, flip, result, tm));");
+                // Inlined in previous opcode
                 break;
             }
             case OP_UNM: {