Browse Source

Implement down-recursion.

Mike Pall 15 years ago
parent
commit
e7b737aa12
7 changed files with 114 additions and 13 deletions
  1. 7 1
      src/lj_asm.c
  2. 5 0
      src/lj_bc.h
  3. 2 5
      src/lj_dispatch.c
  4. 3 0
      src/lj_jit.h
  5. 39 2
      src/lj_record.c
  6. 57 5
      src/lj_trace.c
  7. 1 0
      src/lj_traceerr.h

+ 7 - 1
src/lj_asm.c

@@ -3212,8 +3212,14 @@ static void asm_tail_link(ASMState *as)
 
   if (as->T->link == TRACE_INTERP) {
     /* Setup fixed registers for exit to interpreter. */
+    const BCIns *pc = snap_pc(as->T->snapmap[snap->mapofs + snap->nent]);
+    if (bc_op(*pc) == BC_JLOOP) {  /* NYI: find a better way to do this. */
+      BCIns *retpc = &as->J->trace[bc_d(*pc)]->startins;
+      if (bc_isret(bc_op(*retpc)))
+	pc = retpc;
+    }
     emit_loada(as, RID_DISPATCH, J2GG(as->J)->dispatch);
-    emit_loada(as, RID_PC, snap_pc(as->T->snapmap[snap->mapofs + snap->nent]));
+    emit_loada(as, RID_PC, pc);
   } else if (baseslot) {
     /* Save modified BASE for linking to trace with higher start frame. */
     emit_setgl(as, RID_BASE, jit_base);

+ 5 - 0
src/lj_bc.h

@@ -245,6 +245,11 @@ typedef enum {
   (BCM##ma|(BCM##mb<<3)|(BCM##mc<<7)|(MM_##mm<<11)),
 #define BCMODE_FF	0
 
+static LJ_AINLINE int bc_isret(BCOp op)  /* Nonzero iff op is RETM, RET, RET0 or RET1. */
+{
+  return (op == BC_RETM || op == BC_RET || op == BC_RET0 || op == BC_RET1);
+}
+
 LJ_DATA const uint16_t lj_bc_mode[];
 LJ_DATA const uint16_t lj_bc_ofs[];
 

+ 2 - 5
src/lj_dispatch.c

@@ -380,11 +380,8 @@ void LJ_FASTCALL lj_dispatch_ins(lua_State *L, const BCIns *pc)
       L->top = L->base + slots;  /* Fix top again. */
     }
   }
-  if ((g->hookmask & LUA_MASKRET)) {
-    BCOp op = bc_op(pc[-1]);
-    if (op == BC_RETM || op == BC_RET || op == BC_RET0 || op == BC_RET1)
-      callhook(L, LUA_HOOKRET, -1);
-  }
+  if ((g->hookmask & LUA_MASKRET) && bc_isret(bc_op(pc[-1])))
+    callhook(L, LUA_HOOKRET, -1);
 }
 
 /* Initialize call. Ensure stack space and clear missing parameters. */

+ 3 - 0
src/lj_jit.h

@@ -287,6 +287,9 @@ typedef struct jit_State {
   TraceNo parent;	/* Parent of current side trace (0 for root traces). */
   ExitNo exitno;	/* Exit number in parent of current side trace. */
 
+  BCIns *patchpc;	/* PC for pending re-patch. */
+  BCIns patchins;	/* Instruction for pending re-patch. */
+
   TValue errinfo;	/* Additional info element for trace errors. */
 
   MCode *mcarea;	/* Base of current mcode area. */

+ 39 - 2
src/lj_record.c

@@ -522,6 +522,29 @@ static void rec_tailcall(jit_State *J, BCReg func, ptrdiff_t nargs)
     lj_trace_err(J, LJ_TRERR_LUNROLL);
 }
 
+/* Check unroll limits for down-recursion. Returns 1 if the limit is exceeded, else 0. */
+static int check_downrec_unroll(jit_State *J, GCproto *pt)
+{
+  IRRef ptref;
+  for (ptref = J->chain[IR_KGC]; ptref; ptref = IR(ptref)->prev)  /* Scan KGC constants for pt. */
+    if (ir_kgc(IR(ptref)) == obj2gco(pt)) {
+      int count = 0;
+      IRRef ref;
+      for (ref = J->chain[IR_RETF]; ref; ref = IR(ref)->prev)  /* Count RETFs whose op1 is this constant. */
+	if (IR(ref)->op1 == ptref)
+	  count++;
+      if (count) {
+	if (J->pc == J->startpc) {  /* Current PC is the trace start PC. */
+	  if (count + J->tailcalled > J->param[JIT_P_recunroll])
+	    return 1;
+	} else {  /* Any other PC: raise the DOWNREC trace error. */
+	  lj_trace_err(J, LJ_TRERR_DOWNREC);
+	}
+      }
+    }
+  return 0;
+}
+
 /* Record return. */
 static void rec_ret(jit_State *J, BCReg rbase, ptrdiff_t gotresults)
 {
@@ -545,6 +568,15 @@ static void rec_ret(jit_State *J, BCReg rbase, ptrdiff_t gotresults)
     BCIns callins = *(frame_pc(frame)-1);
     ptrdiff_t nresults = bc_b(callins) ? (ptrdiff_t)bc_b(callins)-1 :gotresults;
     BCReg cbase = bc_a(callins);
+    GCproto *pt = funcproto(frame_func(frame - (cbase+1)));
+    if (J->pt && frame == J->L->base - 1) {
+      if (J->framedepth == 0 && check_downrec_unroll(J, pt)) {
+	J->maxslot = rbase + nresults;
+	rec_stop(J, J->curtrace);  /* Down-recursion. */
+	return;
+      }
+      lj_snap_add(J);
+    }
     for (i = 0; i < nresults; i++)  /* Adjust results. */
       J->base[i-1] = i < gotresults ? J->base[rbase+i] : TREF_NIL;
     J->maxslot = cbase+(BCReg)nresults;
@@ -553,11 +585,10 @@ static void rec_ret(jit_State *J, BCReg rbase, ptrdiff_t gotresults)
       lua_assert(J->baseslot > cbase+1);
       J->baseslot -= cbase+1;
       J->base -= cbase+1;
-    } else if (J->parent == 0) {
+    } else if (J->parent == 0 && !bc_isret(bc_op(J->cur.startins))) {
       /* Return to lower frame would leave the loop in a root trace. */
       lj_trace_err(J, LJ_TRERR_LLEAVE);
     } else {  /* Return to lower frame. Guard for the target we return to. */
-      GCproto *pt = funcproto(frame_func(frame - (cbase+1)));
       TRef trpt = lj_ir_kgc(J, obj2gco(pt), IRT_PROTO);
       TRef trpc = lj_ir_kptr(J, (void *)frame_pc(frame));
       emitir(IRTG(IR_RETF, IRT_PTR), trpt, trpc);
@@ -2285,6 +2316,12 @@ static const BCIns *rec_setup_root(jit_State *J)
     J->maxslot = ra;
     pc++;
     break;
+  case BC_RET:
+  case BC_RET0:
+  case BC_RET1:
+    /* No bytecode range check for down-recursive root traces. */
+    J->maxslot = ra + bc_d(ins);
+    break;
   case BC_FUNCF:
     /* No bytecode range check for root traces started by a hot call. */
     J->maxslot = J->pt->numparams;

+ 57 - 5
src/lj_trace.c

@@ -357,6 +357,8 @@ static void trace_start(jit_State *J)
   if ((J->pt->flags & PROTO_NO_JIT)) {  /* JIT disabled for this proto? */
     if (J->parent == 0) {
       /* Lazy bytecode patching to disable hotcount events. */
+      lua_assert(bc_op(*J->pc) == BC_FORL || bc_op(*J->pc) == BC_ITERL ||
+		 bc_op(*J->pc) == BC_LOOP || bc_op(*J->pc) == BC_FUNCF);
       setbc_op(J->pc, (int)bc_op(*J->pc)+(int)BC_ILOOP-(int)BC_LOOP);
       J->pt->flags |= PROTO_HAS_ILOOP;
     }
@@ -416,10 +418,16 @@ static void trace_stop(jit_State *J)
     /* Patch bytecode of starting instruction in root trace. */
     setbc_op(pc, (int)op+(int)BC_JLOOP-(int)BC_LOOP);
     setbc_d(pc, J->curtrace);
+  addroot:
     /* Add to root trace chain in prototype. */
     J->cur.nextroot = pt->trace;
     pt->trace = (TraceNo1)J->curtrace;
     break;
+  case BC_RET:
+  case BC_RET0:
+  case BC_RET1:
+    *pc = BCINS_AD(BC_JLOOP, J->cur.snap[0].nslots, J->curtrace);
+    goto addroot;
   case BC_JMP:
     /* Patch exit branch in parent to side trace entry. */
     lua_assert(J->parent != 0 && J->cur.root != 0);
@@ -450,6 +458,21 @@ static void trace_stop(jit_State *J)
   );
 }
 
+/* Start a new root trace for down-recursion. Returns 1 if restarted, 0 if not (RETM). */
+static int trace_downrec(jit_State *J)
+{
+  /* Restart recording at the return instruction. */
+  lua_assert(J->pt != NULL);
+  lua_assert(bc_isret(bc_op(*J->pc)));
+  if (bc_op(*J->pc) == BC_RETM)
+    return 0;  /* NYI: down-recursion with RETM. */
+  J->parent = 0;  /* Parent 0 marks a root trace. */
+  J->exitno = 0;
+  J->state = LJ_TRACE_RECORD;  /* Set before the call: trace_start() may change state. */
+  trace_start(J);
+  return 1;
+}
+
 /* Abort tracing. */
 static int trace_abort(jit_State *J)
 {
@@ -463,7 +486,7 @@ static int trace_abort(jit_State *J)
     return 1;  /* Retry ASM with new MCode area. */
   }
   /* Penalize or blacklist starting bytecode instruction. */
-  if (J->parent == 0)
+  if (J->parent == 0 && !bc_isret(bc_op(J->cur.startins)))
     penalty_pc(J, &gcref(J->cur.startpt)->pt, (BCIns *)J->startpc, e);
   if (J->curtrace) {  /* Is there anything to abort? */
     ptrdiff_t errobj = savestack(L, L->top-1);  /* Stack may be resized. */
@@ -493,17 +516,29 @@ static int trace_abort(jit_State *J)
     J->curtrace = 0;
   }
   L->top--;  /* Remove error object */
-  if (e == LJ_TRERR_MCODEAL)
+  if (e == LJ_TRERR_DOWNREC)
+    return trace_downrec(J);
+  else if (e == LJ_TRERR_MCODEAL)
     lj_trace_flushall(L);
   return 0;
 }
 
+/* Perform pending re-patch of a bytecode instruction. No-op if none is pending. */
+static LJ_AINLINE void trace_pendpatch(jit_State *J, int force)
+{
+  if (LJ_UNLIKELY(J->patchpc) && (force || J->chain[IR_RETF])) {  /* Unless forced, only once the RETF chain is non-empty. */
+    *J->patchpc = J->patchins;  /* Write the saved instruction back. */
+    J->patchpc = NULL;  /* Nothing pending anymore. */
+  }
+}
+
 /* State machine for the trace compiler. Protected callback. */
 static TValue *trace_state(lua_State *L, lua_CFunction dummy, void *ud)
 {
   jit_State *J = (jit_State *)ud;
   UNUSED(dummy);
   do {
+  retry:
     switch (J->state) {
     case LJ_TRACE_START:
       J->state = LJ_TRACE_RECORD;  /* trace_start() may change state. */
@@ -512,6 +547,7 @@ static TValue *trace_state(lua_State *L, lua_CFunction dummy, void *ud)
       break;
 
     case LJ_TRACE_RECORD:
+      trace_pendpatch(J, 0);
       setvmstate(J2G(J), RECORD);
       lj_vmevent_send(L, RECORD,
 	setintV(L->top++, J->curtrace);
@@ -523,6 +559,7 @@ static TValue *trace_state(lua_State *L, lua_CFunction dummy, void *ud)
       break;
 
     case LJ_TRACE_END:
+      trace_pendpatch(J, 1);
       J->loopref = 0;
       if ((J->flags & JIT_F_OPT_LOOP) &&
 	  J->cur.link == J->curtrace && J->framedepth + J->retdepth == 0) {
@@ -551,8 +588,9 @@ static TValue *trace_state(lua_State *L, lua_CFunction dummy, void *ud)
       setintV(L->top++, (int32_t)LJ_TRERR_RECERR);
       /* fallthrough */
     case LJ_TRACE_ERR:
+      trace_pendpatch(J, 1);
       if (trace_abort(J))
-	break;  /* Retry. */
+	goto retry;
       setvmstate(J2G(J), INTERP);
       J->state = LJ_TRACE_IDLE;
       lj_dispatch_update(J2G(J));
@@ -627,6 +665,7 @@ int LJ_FASTCALL lj_trace_exit(jit_State *J, void *exptr)
   lua_State *L = J->L;
   ExitDataCP exd;
   int errcode;
+  const BCIns *pc;
   exd.J = J;
   exd.exptr = exptr;
   errcode = lj_vm_cpcall(L, NULL, &exd, trace_exit_cp);
@@ -651,8 +690,21 @@ int LJ_FASTCALL lj_trace_exit(jit_State *J, void *exptr)
     }
   );
 
-  trace_hotside(J, exd.pc);
-  setcframe_pc(cframe_raw(L->cframe), exd.pc);
+  pc = exd.pc;
+  trace_hotside(J, pc);
+  if (bc_op(*pc) == BC_JLOOP) {
+    BCIns *retpc = &J->trace[bc_d(*pc)]->startins;
+    if (bc_isret(bc_op(*retpc))) {
+      if (J->state == LJ_TRACE_RECORD) {
+	J->patchins = *pc;
+	J->patchpc = (BCIns *)pc;
+	*J->patchpc = *retpc;
+      } else {
+	pc = retpc;
+      }
+    }
+  }
+  setcframe_pc(cframe_raw(L->cframe), pc);
   return 0;
 }
 

+ 1 - 0
src/lj_traceerr.h

@@ -22,6 +22,7 @@ TREDEF(LUNROLL,	"loop unroll limit reached")
 TREDEF(BADTYPE,	"bad argument type")
 TREDEF(CJITOFF,	"call to JIT-disabled function")
 TREDEF(CUNROLL,	"call unroll limit reached")
+TREDEF(DOWNREC,	"down-recursion, restarting")
 TREDEF(NYIVF,	"NYI: vararg function")
 TREDEF(NYICF,	"NYI: C function %p")
 TREDEF(NYIFF,	"NYI: FastFunc %s")