فهرست منبع

new opcodes, resolve calls addresses, call stack alignment on 16bytes

Nicolas Cannasse 10 سال پیش
والد
کامیت
d5cf18f673
5 فایل تغییر یافته، به همراه 210 افزوده شده و 50 حذف شده
  1. 44 19
      src/code.c
  2. 5 4
      src/hl.h
  3. 139 14
      src/jit.c
  4. 16 13
      src/module.c
  5. 6 0
      src/opcodes.h

+ 44 - 19
src/code.c

@@ -26,6 +26,11 @@
 #define OP_END };
 #include "opcodes.h"
 
+#define OP(n,_) #n, // each OP(name,nargs) entry expands to the opcode name as a string literal
+#define OP_BEGIN static const char *hl_op_names[] = { // opens the file-local opcode-name table
+#define OP_END }; // closes the table
+#include "opcodes.h" // expanded with the three macros above to populate hl_op_names
+
 typedef struct {
 	const unsigned char *b;
 	int size;
@@ -109,7 +114,7 @@ static int hl_read_uindex( hl_reader *r ) {
 	return i;
 }
 
-hl_type *hl_get_type( hl_reader *r ) {
+static hl_type *hl_get_type( hl_reader *r ) {
 	int i = INDEX();
 	if( i < 0 || i >= r->code->ntypes ) {
 		ERROR("Invalid type index");
@@ -118,7 +123,7 @@ hl_type *hl_get_type( hl_reader *r ) {
 	return r->code->types + i;
 }
 
-const char *hl_get_string( hl_reader *r ) {
+static const char *hl_get_string( hl_reader *r ) {
 	int i = INDEX();
 	if( i < 0 || i >= r->code->nstrings ) {
 		ERROR("Invalid string index");
@@ -127,7 +132,7 @@ const char *hl_get_string( hl_reader *r ) {
 	return r->code->strings[i];
 }
 
-int hl_get_global( hl_reader *r ) {
+static int hl_get_global( hl_reader *r ) {
 	int g = INDEX();
 	if( g < 0 || g >= r->code->nglobals ) {
 		ERROR("Invalid global index");
@@ -136,7 +141,7 @@ int hl_get_global( hl_reader *r ) {
 	return g;
 }
 
-void hl_read_type( hl_reader *r, hl_type *t ) {
+static void hl_read_type( hl_reader *r, hl_type *t ) {
 	int i;
 	t->kind = READ();
 	if( t->kind >= HLAST ) {
@@ -160,7 +165,7 @@ void hl_read_type( hl_reader *r, hl_type *t ) {
 	}
 }
 
-void hl_read_opcode( hl_reader *r, hl_function *f, hl_opcode *o ) {
+static void hl_read_opcode( hl_reader *r, hl_function *f, hl_opcode *o ) {
 	o->op = (hl_op)READ();
 	if( o->op >= OLast ) {
 		ERROR("Invalid opcode");
@@ -181,40 +186,54 @@ void hl_read_opcode( hl_reader *r, hl_function *f, hl_opcode *o ) {
 		o->p2 = INDEX();
 		o->p3 = INDEX();
 		break;
-	default:
+	case 4:
+		o->p1 = INDEX();
+		o->p2 = INDEX();
+		o->p3 = INDEX();
+		o->extra = (int*)INDEX();
+		break;
+	case -1:
 		switch( o->op ) {
-		case OCall2:
-			o->p1 = INDEX();
-			o->p2 = INDEX();
-			o->p3 = INDEX();
-			o->extra = (void*)INDEX();
-			break;
 		case OCallN:
 			{
-				int *args, i;
+				int i;
 				o->p1 = INDEX();
 				o->p2 = INDEX();
 				o->p3 = READ();
-				args = (int*)hl_malloc(&r->code->alloc,sizeof(int) * o->p3);
-				if( args == NULL ) {
+				o->extra = (int*)hl_malloc(&r->code->alloc,sizeof(int) * o->p3);
+				if( o->extra == NULL ) {
 					ERROR("Out of memory");
 					return;
 				}
 				for(i=0;i<o->p3;i++)
-					args[i] = INDEX();
-				o->extra = args;
+					o->extra[i] = INDEX();
 			}
 			break;
 		default:
 			ERROR("Don't know how to process opcode");
 			break;
 		}
+	default:
+		{
+			int i, size = hl_op_nargs[o->op] - 3;
+			o->p1 = INDEX();
+			o->p2 = INDEX();
+			o->p3 = INDEX();
+			o->extra = (int*)hl_malloc(&r->code->alloc,sizeof(int) * size);
+			if( o->extra == NULL ) {
+				ERROR("Out of memory");
+				return;
+			}
+			for(i=0;i<size;i++)
+				o->extra[i] = INDEX();
+		}
+		break;
 	}
 }
 
-void hl_read_function( hl_reader *r, hl_function *f ) {
+static void hl_read_function( hl_reader *r, hl_function *f ) {
 	int i;
-	f->index = UINDEX();
+	f->global = UINDEX();
 	f->nregs = UINDEX();
 	f->nops = UINDEX();
 	f->regs = (hl_type**)hl_malloc(&r->code->alloc, f->nregs * sizeof(hl_type*));
@@ -239,6 +258,12 @@ void hl_read_function( hl_reader *r, hl_function *f ) {
 #define EXIT(msg) { ERROR(msg); CHK_ERROR(); }
 #define ALLOC(v,ptr,count) { v = (ptr *)hl_zalloc(&c->alloc,count*sizeof(ptr)); if( v == NULL ) EXIT("Out of memory"); }
 
+const char *hl_op_name( int op ) { // Returns the entry of hl_op_names for opcode `op`, or "UnknownOp" when `op` is out of range.
+	if( op < 0 || op >= OLast ) // reject values outside [0, OLast) before indexing the table
+		return "UnknownOp";
+	return hl_op_names[op];
+}
+
 hl_code *hl_code_read( const unsigned char *data, int size ) {
 	hl_reader _r = { data, size, 0, 0, NULL };	
 	hl_reader *r = &_r;

+ 5 - 4
src/hl.h

@@ -108,13 +108,13 @@ typedef struct {
 	int p1;
 	int p2;
 	int p3;
-	void *extra;
+	int *extra;
 } hl_opcode;
 
 typedef struct hl_ptr_list hl_ptr_list;
 
 typedef struct {
-	int index;
+	int global;
 	int nregs;
 	int nops;
 	hl_type **regs;
@@ -166,14 +166,15 @@ int hl_word_size( hl_type *t ); // same as hl_type_size, but round to the next w
 
 hl_code *hl_code_read( const unsigned char *data, int size );
 void hl_code_free( hl_code *c );
+const char* hl_op_name( int op );
 
 hl_module *hl_module_alloc( hl_code *code );
+int hl_module_init( hl_module *m );
 void hl_module_free( hl_module *m );
 
 jit_ctx *hl_jit_alloc();
 void hl_jit_free( jit_ctx *ctx );
 int hl_jit_function( jit_ctx *ctx, hl_module *m, hl_function *f );
-int hl_module_init( hl_module *m );
-void *hl_jit_code( jit_ctx *ctx );
+void *hl_jit_code( jit_ctx *ctx, hl_module *m );
 
 #endif

+ 139 - 14
src/jit.c

@@ -69,7 +69,6 @@
 #define XCall_d(delta)			B(0xE8); W(delta)
 #define XPush_r(r)				B(0x50+(r))
 #define XPush_c(cst)			B(0x68); W(cst)
-//XPush_p
 #define XPush_p(reg,idx)		OP_ADDR(0xFF,idx,reg,6)
 #define XAdd_rc(reg,cst)		if IS_SBYTE(cst) { OP_RM(0x83,3,0,reg); B(cst); } else { OP_RM(0x81,3,0,reg); W(cst); }
 #define XAdd_rr(dst,src)		OP_RM(0x03,3,dst,src)
@@ -100,6 +99,8 @@
 //XNeg_r
 //XNeg_p
 
+#define XNop()					B(0x90)				
+
 #define XTest_rc(r,cst)			if( r == Eax ) { B(0xA9); W(cst); } else { B(0xF7); MOD_RM(3,0,r); W(cst); }
 #define XTest_rr(r,src)			B(0x85); MOD_RM(3,r,src)
 #define XAnd_rc(r,cst)			if( r == Eax ) { B(0x25); W(cst); } else { B(0x81); MOD_RM(3,4,r); W(cst); }
@@ -173,17 +174,24 @@ struct jit_ctx {
 	unsigned char *startBuf;
 	int bufSize;
 	int totalRegsSize;
+	int *globalToFunction;
+	int functionPos;
 	hl_module *m;
 	hl_function *f;
 	jlist *jumps;
-	hl_alloc alloc;
+	jlist *calls;
+	hl_alloc falloc; // cleared per-function
+	hl_alloc galloc;
 };
 
 static void jit_buf( jit_ctx *ctx ) {
 	if( ctx->buf.b - ctx->startBuf > ctx->bufSize - MAX_OP_SIZE ) {
 		int nsize = ctx->bufSize ? (ctx->bufSize * 4) / 3 : ctx->f->nops * 4;
-		unsigned char *nbuf = (unsigned char*)malloc(nsize);
-		int curpos = ctx->buf.b - ctx->startBuf;
+		unsigned char *nbuf;
+		int curpos;
+		if( nsize < ctx->bufSize + MAX_OP_SIZE * 4 ) nsize = ctx->bufSize + MAX_OP_SIZE * 4;
+		curpos = ctx->buf.b - ctx->startBuf;
+		nbuf = (unsigned char*)malloc(nsize);
 		// TODO : check nbuf
 		if( ctx->startBuf ) {
 			memcpy(nbuf,ctx->startBuf,curpos);
@@ -222,6 +230,38 @@ static void op_callr( jit_ctx *ctx, int r, int rf, int size ) {
 	STORE(r, Eax);
 }
 
+static void op_callg( jit_ctx *ctx, int r, int g, int size ) { // Emits a call to the function referenced by global `g`, then adds `size` back to Esp and does STORE(r, Eax).
+	int fid = ctx->globalToFunction[g]; // negative: g maps to no function/native; >= nfunctions: native (table is filled in hl_jit_function)
+	if( fid < 0 ) {
+		// not a static function or native, load it at runtime
+		XMov_ra(Eax, (int_val)(ctx->m->globals_data + ctx->m->globals_indexes[g]));
+		XCall_r(Eax);
+	} else if( fid >= ctx->m->code->nfunctions ) {
+		// native function, already resolved
+		XMov_rc(Eax, *(int_val*)(ctx->m->globals_data + ctx->m->globals_indexes[g])); // the global's value is read now, while generating code, and passed as a constant operand
+		XCall_r(Eax);
+	} else {
+		int cpos = ctx->buf.b - ctx->startBuf; // offset of the call instruction inside the JIT buffer
+		if( ctx->m->functions_ptrs[fid] ) {
+			// already compiled
+			XCall_d((int_val)ctx->m->functions_ptrs[fid] - (cpos + 5)); // functions_ptrs[fid] is used as a buffer offset here (hl_module_init stores offsets while compiling); cpos + 5 is the first byte after the call
+		} else if( ctx->m->code->functions + fid == ctx->f ) {
+			// our current function
+			XCall_d(ctx->functionPos - (cpos + 5)); // functionPos is the buffer offset recorded by hl_jit_function just before op_enter
+		} else {
+			// stage for later
+			jlist *j = (jlist*)hl_malloc(&ctx->galloc,sizeof(jlist)); // galloc is only released in hl_jit_free, so the entry is still there when hl_jit_code runs; NOTE(review): result is not NULL-checked
+			j->pos = cpos;
+			j->target = g; // global index; hl_jit_code maps it back to a function through globalToFunction
+			j->next = ctx->calls;
+			ctx->calls = j;
+			XCall_d(0); // placeholder displacement, overwritten by hl_jit_code
+		}
+	}
+	XAdd_rc(Esp, size); // release the argument bytes (and any padding included in `size`)
+	STORE(r, Eax); // presumably writes the call result into virtual register r — confirm against the STORE macro
+}
+
 static void op_enter( jit_ctx *ctx ) {
 	XPush_r(Ebp);
 	XMov_rr(Ebp, Esp);
@@ -256,8 +296,21 @@ static int *do_jump( jit_ctx *ctx, hl_op op ) {
 		XJump(JAlways,j);
 		break;
 	case OGte:
+	case OJGte:
 		XJump(JGte,j);
 		break;
+	case OLt:
+	case OJLt:
+		XJump(JLt,j);
+		break;
+	case OEq:
+	case OJEq:
+		XJump(JEq,j);
+		break;
+	case ONotEq:
+	case OJNeq:
+		XJump(JNeq,j);
+		break;
 	default:
 		j = NULL;
 		printf("Unknown JUMP %d\n",op);
@@ -286,7 +339,7 @@ static void op_cmp( jit_ctx *ctx, hl_opcode *op ) {
 
 static void register_jump( jit_ctx *ctx, int *p, int target ) {
 	int pos = (int_val)p - (int_val)ctx->startBuf; 
-	jlist *j = (jlist*)hl_malloc(&ctx->alloc, sizeof(jlist));
+	jlist *j = (jlist*)hl_malloc(&ctx->falloc, sizeof(jlist));
 	j->pos = pos;
 	j->target = target;
 	j->next = ctx->jumps;
@@ -297,7 +350,8 @@ jit_ctx *hl_jit_alloc() {
 	jit_ctx *ctx = (jit_ctx*)malloc(sizeof(jit_ctx));
 	if( ctx == NULL ) return NULL;
 	memset(ctx,0,sizeof(jit_ctx));
-	hl_alloc_init(&ctx->alloc);
+	hl_alloc_init(&ctx->falloc);
+	hl_alloc_init(&ctx->galloc);
 	return ctx;
 }
 
@@ -306,16 +360,35 @@ void hl_jit_free( jit_ctx *ctx ) {
 	free(ctx->regsSize);
 	free(ctx->opsPos);
 	free(ctx->startBuf);
-	hl_free(&ctx->alloc);
+	hl_free(&ctx->falloc);
+	hl_free(&ctx->galloc);
 	free(ctx);
 }
 
+int pad_stack( jit_ctx *ctx, int size ) { // Emits a `sub Esp,pad` when needed so that size + totalRegsSize + two words is a multiple of 16; returns `size` plus that padding.
+	int total = size + ctx->totalRegsSize + HL_WSIZE * 2; // EIP+EBP
+	if( total & 15 ) { // low four bits set: total is not a multiple of 16
+		int pad = 16 - (total & 15);
+		XSub_rc(Esp,pad);
+		size += pad; // callers pass the returned value to op_callr/op_callg, so the padding is counted in the size handed to the call helper
+	}
+	return size;
+}
+
 int hl_jit_function( jit_ctx *ctx, hl_module *m, hl_function *f ) {
 	int i, j, size = 0;
 	int codePos = ctx->buf.b - ctx->startBuf;
-	int nargs = m->code->globals[f->index]->nargs;
+	int nargs = m->code->globals[f->global]->nargs;
 	ctx->m = m;
 	ctx->f = f;
+	if( !ctx->globalToFunction ) {
+		ctx->globalToFunction = (int*)malloc(sizeof(int)*m->code->nglobals);
+		memset(ctx->globalToFunction,0xFF,sizeof(int)*m->code->nglobals);
+		for(i=0;i<m->code->nfunctions;i++)
+			ctx->globalToFunction[(m->code->functions + i)->global] = i;
+		for(i=0;i<m->code->nnatives;i++)
+			ctx->globalToFunction[(m->code->natives + i)->global] = i + m->code->nfunctions;
+	}
 	if( f->nregs > ctx->maxRegs ) {
 		free(ctx->regsPos);
 		free(ctx->regsSize);
@@ -352,6 +425,7 @@ int hl_jit_function( jit_ctx *ctx, hl_module *m, hl_function *f ) {
 	}
 	ctx->totalRegsSize = size;
 	jit_buf(ctx);
+	ctx->functionPos = ctx->buf.b - ctx->startBuf;
 	op_enter(ctx);
 	ctx->opsPos[0] = 0;
 	for(i=0;i<f->nops;i++) {
@@ -374,12 +448,43 @@ int hl_jit_function( jit_ctx *ctx, hl_module *m, hl_function *f ) {
 		case OCallN:
 			size = 0;
 			for(j=o->p3-1;j>=0;j--) {
-				int r = ((int*)o->extra)[j];
+				int r = o->extra[j];
+				size += ctx->regsSize[r];
+			}
+			size = pad_stack(ctx,size);
+			for(j=o->p3-1;j>=0;j--) {
+				int r = o->extra[j];
+				if( (j & 7) == 0 ) jit_buf(ctx);
 				op_pushr(ctx, r);
-				size += 4; // regsize !
 			}
 			op_callr(ctx, o->p1, o->p2, size);
 			break;
+		case OCall1:
+			size = pad_stack(ctx,ctx->regsSize[o->p3]);
+			op_pushr(ctx, o->p3);
+			op_callg(ctx, o->p1, o->p2, size);
+			break;
+		case OCall2:
+			size = pad_stack(ctx,ctx->regsSize[o->p3] + ctx->regsSize[(int)(int_val)o->extra]);
+			op_pushr(ctx, (int)(int_val)o->extra);
+			op_pushr(ctx, o->p3);
+			op_callg(ctx, o->p1, o->p2, size);
+			break;
+		case OCall3:
+			size = pad_stack(ctx,ctx->regsSize[o->p3] + ctx->regsSize[o->extra[0]] + ctx->regsSize[o->extra[1]]);
+			op_pushr(ctx, o->extra[1]);
+			op_pushr(ctx, o->extra[0]);
+			op_pushr(ctx, o->p3);
+			op_callg(ctx, o->p1, o->p2, size);
+			break;
+		case OCall4:
+			size = pad_stack(ctx,ctx->regsSize[o->p3] + ctx->regsSize[o->extra[0]] + ctx->regsSize[o->extra[1]] + ctx->regsSize[o->extra[2]]);
+			op_pushr(ctx, o->extra[2]);
+			op_pushr(ctx, o->extra[1]);
+			op_pushr(ctx, o->extra[0]);
+			op_pushr(ctx, o->p3);
+			op_callg(ctx, o->p1, o->p2, size);
+			break;
 		case OSub:
 			op_sub(ctx, o->p1, o->p2, o->p3);
 			break;
@@ -395,6 +500,13 @@ int hl_jit_function( jit_ctx *ctx, hl_module *m, hl_function *f ) {
 			XJump(JZero,jump);
 			register_jump(ctx,jump,(i + 1) + o->p2);
 			break;
+		case OJLt:
+			LOAD(Eax,o->p1);
+			LOAD(Ecx,o->p2);
+			XCmp_rr(Eax, Ecx);
+			jump = do_jump(ctx,o->op);
+			register_jump(ctx,jump,(i + 1) + o->p3);
+			break;
 		case OToAny:
 			op_mov(ctx,o->p1,o->p2); // TODO
 			break;
@@ -402,7 +514,7 @@ int hl_jit_function( jit_ctx *ctx, hl_module *m, hl_function *f ) {
 			op_ret(ctx, o->p1);
 			break;
 		default:
-			printf("Don't know how to jit op #%d\n",o->op);
+			printf("Don't know how to jit %s(%d)\n",hl_op_name(o->op),o->op);
 			return -1;
 		}
 		ctx->opsPos[i+1] = ctx->buf.b - ctx->startBuf;
@@ -416,17 +528,30 @@ int hl_jit_function( jit_ctx *ctx, hl_module *m, hl_function *f ) {
 		}
 		ctx->jumps = NULL;
 	}
+	// add nops padding
+	while( (ctx->buf.b - ctx->startBuf) & 15 )
+		XNop();
 	// reset tmp allocator
-	hl_free(&ctx->alloc);
+	hl_free(&ctx->falloc);
 	return codePos;
 }
 
 void *hl_alloc_executable_memory( int size );
 
-void *hl_jit_code( jit_ctx *ctx ) {
+void *hl_jit_code( jit_ctx *ctx, hl_module *m ) { // Copies the generated code into memory from hl_alloc_executable_memory and patches the calls staged in ctx->calls; returns NULL if that allocation fails.
+	jlist *c;
 	int size = ctx->buf.b - ctx->startBuf; // number of bytes emitted into the JIT buffer
-	void *code = hl_alloc_executable_memory(size);
+	unsigned char *code = (unsigned char*)hl_alloc_executable_memory(size);
 	if( code == NULL ) return NULL;
 	memcpy(code,ctx->startBuf,size);
+	// patch calls
+	c = ctx->calls;
+	while( c ) {
+		int fid = ctx->globalToFunction[c->target]; // c->target is the global index recorded by op_callg
+		int fpos = (int_val)m->functions_ptrs[fid]; // NOTE(review): relies on functions_ptrs still holding buffer offsets, as hl_module_init leaves them before calling this function
+		*(int*)(code + c->pos + 1) = fpos - (c->pos + 5); // displacement starts one byte after the call opcode and is relative to the end of the 5-byte instruction
+		c = c->next;
+	}
 	return code;
 }
+

+ 16 - 13
src/module.c

@@ -65,33 +65,36 @@ static void do_log( int i ) {
 int hl_module_init( hl_module *m ) { // Resets function globals, binds natives, JIT-compiles every function and stores the resulting addresses in the globals; returns 1 on success, 0 on failure.
 	int i;
 	jit_ctx *ctx;
+	// RESET globals
+	for(i=0;i<m->code->nglobals;i++) {
+		hl_type *t = m->code->globals[i];
+		if( t->kind == HFUN ) *(fptr*)(m->globals_data + m->globals_indexes[i]) = null_function;
+	}
+	// INIT natives
+	for(i=0;i<m->code->nnatives;i++) {
+		hl_native *n = m->code->natives + i;
+		*(void**)(m->globals_data + m->globals_indexes[n->global]) = do_log; // every native is bound to do_log here; done before JIT because op_callg reads this value at code-generation time
+	}
 	// JIT
 	ctx = hl_jit_alloc();
 	if( ctx == NULL )
 		return 0;
 	for(i=0;i<m->code->nfunctions;i++) {
 		int f = hl_jit_function(ctx, m, m->code->functions+i);
-		if( f < 0 ) return 0;
+		if( f < 0 ) {
+			hl_jit_free(ctx); // release the JIT context on the failure path as well
+			return 0;
+		}
 		m->functions_ptrs[i] = (void*)f; // temporarily holds the function's offset in the JIT buffer
 	}
-	m->jit_code = hl_jit_code(ctx);
+	m->jit_code = hl_jit_code(ctx, m); // NOTE(review): a NULL result is not checked before it is used as a base address below
 	for(i=0;i<m->code->nfunctions;i++)
 		m->functions_ptrs[i] = ((unsigned char*)m->jit_code) + ((int_val)m->functions_ptrs[i]); // turn each stored offset into an address inside jit_code
 	hl_jit_free(ctx);
-	// INIT globals
-	for(i=0;i<m->code->nglobals;i++) {
-		hl_type *t = m->code->globals[i];
-		if( t->kind == HFUN ) *(fptr*)(m->globals_data + m->globals_indexes[i]) = null_function;
-	}
 	// INIT functions
 	for(i=0;i<m->code->nfunctions;i++) {
 		hl_function *f = m->code->functions + i;
-		*(void**)(m->globals_data + m->globals_indexes[f->index]) = m->functions_ptrs[i];
-	}
-	// INIT natives
-	for(i=0;i<m->code->nnatives;i++) {
-		hl_native *n = m->code->natives + i;
-		*(void**)(m->globals_data + m->globals_indexes[n->global]) = do_log;
+		*(void**)(m->globals_data + m->globals_indexes[f->global]) = m->functions_ptrs[i]; // publish the compiled address in the function's global slot
 	}
 	return 1;
 }

+ 6 - 0
src/opcodes.h

@@ -41,6 +41,8 @@ OP_BEGIN
 	OP(OCall0,2)
 	OP(OCall1,3)
 	OP(OCall2,4)
+	OP(OCall3,5)
+	OP(OCall4,6)
 	OP(OCallN,-1)
 	OP(OGetGlobal, 2)
 	OP(OSetGlobal,2)
@@ -53,6 +55,10 @@ OP_BEGIN
 	OP(OJFalse,2)
 	OP(OJNull,2)
 	OP(OJNotNull,2)
+	OP(OJLt,3)
+	OP(OJGte,3)
+	OP(OJEq,3)
+	OP(OJNeq,3)
 	OP(OJAlways,1)
 	OP(OToAny,2)
 	// --