Browse Source

FFI: Add 64 bit integer arithmetic.

Mike Pall 14 years ago
parent
commit
44935dae0d
6 changed files with 157 additions and 17 deletions
  1. 1 1
      src/Makefile.dep
  2. 111 16
      src/lib_ffi.c
  3. 40 0
      src/lj_cdata.c
  4. 2 0
      src/lj_cdata.h
  5. 1 0
      src/lj_errmsg.h
  6. 2 0
      src/lj_obj.h

+ 1 - 1
src/Makefile.dep

@@ -22,7 +22,7 @@ lib_debug.o: lib_debug.c lua.h luaconf.h lauxlib.h lualib.h lj_obj.h \
  lj_def.h lj_arch.h lj_err.h lj_errmsg.h lj_lib.h lj_libdef.h
 lib_ffi.o: lib_ffi.c lua.h luaconf.h lauxlib.h lualib.h lj_obj.h lj_def.h \
  lj_arch.h lj_gc.h lj_err.h lj_errmsg.h lj_str.h lj_ctype.h lj_cparse.h \
- lj_cdata.h lj_cconv.h lj_lib.h lj_libdef.h
+ lj_cdata.h lj_cconv.h lj_ff.h lj_ffdef.h lj_lib.h lj_libdef.h
 lib_init.o: lib_init.c lua.h luaconf.h lauxlib.h lualib.h lj_arch.h
 lib_io.o: lib_io.c lua.h luaconf.h lauxlib.h lualib.h lj_obj.h lj_def.h \
  lj_arch.h lj_gc.h lj_err.h lj_errmsg.h lj_str.h lj_ff.h lj_ffdef.h \

+ 111 - 16
src/lib_ffi.c

@@ -21,6 +21,7 @@
 #include "lj_cparse.h"
 #include "lj_cdata.h"
 #include "lj_cconv.h"
+#include "lj_ff.h"
 #include "lj_lib.h"
 
 /* -- C type checks ------------------------------------------------------- */
@@ -80,9 +81,8 @@ typedef struct FFIArith {
 } FFIArith;
 
 /* Check arguments for arithmetic metamethods. */
-static void ffi_checkarith(lua_State *L, FFIArith *fa)
+static void ffi_checkarith(lua_State *L, CTState *cts, FFIArith *fa)
 {
-  CTState *cts = ctype_cts(L);
   TValue *o = L->base;
   MSize i;
   if (o+1 >= L->top)
@@ -154,17 +154,18 @@ LJLIB_CF(ffi_meta___call)	LJLIB_REC(cdata_call)
 }
 
 /* Pointer arithmetic. */
-static int ffi_arith_ptr(lua_State *L, FFIArith *fa, int sub)
+static int ffi_arith_ptr(lua_State *L, CTState *cts, FFIArith *fa, MMS mm)
 {
-  CTState *cts = ctype_cts(L);
   CType *ctp = fa->ct[0];
   uint8_t *pp = fa->p[0];
   ptrdiff_t idx;
   CTSize sz;
   CTypeID id;
   GCcdata *cd;
+  if (!(mm == MM_add || mm == MM_sub))
+    return 0;
   if (ctype_isptr(ctp->info) || ctype_isrefarray(ctp->info)) {
-    if (sub &&
+    if (mm == MM_sub &&
 	(ctype_isptr(fa->ct[1]->info) || ctype_isrefarray(fa->ct[1]->info))) {
       /* Pointer difference. */
       intptr_t diff;
@@ -184,11 +185,13 @@ static int ffi_arith_ptr(lua_State *L, FFIArith *fa, int sub)
       setnumV(L->top-1, (lua_Number)diff);
       return 1;
     }
+    if (!ctype_isnum(fa->ct[1]->info)) return 0;
     lj_cconv_ct_ct(cts, ctype_get(cts, CTID_INT_PSZ), fa->ct[1],
 		   (uint8_t *)&idx, fa->p[1], 0);
-    if (sub) idx = -idx;
-  } else if (!sub &&
+    if (mm == MM_sub) idx = -idx;
+  } else if (mm == MM_add &&
       (ctype_isptr(fa->ct[1]->info) || ctype_isrefarray(fa->ct[1]->info))) {
+    if (!ctype_isnum(ctp->info)) return 0;
     /* Swap pointer and index. */
     ctp = fa->ct[1]; pp = fa->p[1];
     lj_cconv_ct_ct(cts, ctype_get(cts, CTID_INT_PSZ), fa->ct[0],
@@ -210,22 +213,114 @@ static int ffi_arith_ptr(lua_State *L, FFIArith *fa, int sub)
   return 1;
 }
 
-LJLIB_CF(ffi_meta___add)
+/* 64 bit integer arithmetic. */
+static int ffi_arith_int64(lua_State *L, CTState *cts, FFIArith *fa, MMS mm)
 {
+  if (ctype_isnum(fa->ct[0]->info) && fa->ct[0]->size <= 8 &&
+      ctype_isnum(fa->ct[1]->info) && fa->ct[1]->size <= 8) {
+    CTypeID id = (((fa->ct[0]->info & CTF_UNSIGNED) && fa->ct[0]->size == 8) ||
+		  ((fa->ct[1]->info & CTF_UNSIGNED) && fa->ct[1]->size == 8)) ?
+		 CTID_UINT64 : CTID_INT64;
+    CType *ct = ctype_get(cts, id);
+    GCcdata *cd;
+    uint64_t u0, u1, *up;
+    lj_cconv_ct_ct(cts, ct, fa->ct[0], (uint8_t *)&u0, fa->p[0], 0);
+    if (mm != MM_unm)
+      lj_cconv_ct_ct(cts, ct, fa->ct[1], (uint8_t *)&u1, fa->p[1], 0);
+    if ((mm == MM_div || mm == MM_mod)) {
+      if (u1 == 0) {  /* Division by zero. */
+	if (u0 == 0)
+	  setnanV(L->top-1);
+	else if (id == CTID_INT64 && (int64_t)u0 < 0)
+	  setminfV(L->top-1);
+	else
+	  setpinfV(L->top-1);
+	return 1;
+      } else if (id == CTID_INT64 && (int64_t)u1 == -1 &&
+		 u0 == U64x(80000000,00000000)) {  /* MIN64 / -1. */
+	if (mm == MM_div) id = CTID_UINT64; else u0 = 0;
+	mm = MM_unm;  /* Result is 0x8000000000000000ULL or 0LL. */
+      }
+    }
+    cd = lj_cdata_new(cts, id, 8);
+    up = (uint64_t *)cdataptr(cd);
+    setcdataV(L, L->top-1, cd);
+    switch (mm) {
+    case MM_add: *up = u0 + u1; break;
+    case MM_sub: *up = u0 - u1; break;
+    case MM_mul: *up = u0 * u1; break;
+    case MM_div:
+      if (id == CTID_INT64)
+	*up = (uint64_t)((int64_t)u0 / (int64_t)u1);
+      else
+	*up = u0 / u1;
+      break;
+    case MM_mod:
+      if (id == CTID_INT64)
+	*up = (uint64_t)((int64_t)u0 % (int64_t)u1);
+      else
+	*up = u0 % u1;
+      break;
+    case MM_pow: *up = lj_cdata_powi64(u0, u1, (id == CTID_UINT64)); break;
+    case MM_unm: *up = -u0; break;
+    default: lua_assert(0); break;
+    }
+    return 1;
+  }
+  return 0;
+}
+
+/* cdata arithmetic. */
+static int ffi_arith(lua_State *L)
+{
+  CTState *cts = ctype_cts(L);
   FFIArith fa;
-  ffi_checkarith(L, &fa);
-  if (!ffi_arith_ptr(L, &fa, 0))
-    lj_err_caller(L, LJ_ERR_FFI_INVTYPE);
+  MMS mm = (MMS)(curr_func(L)->c.ffid - (int)FF_ffi_meta___add + (int)MM_add);
+  ffi_checkarith(L, cts, &fa);
+  if (!ffi_arith_int64(L, cts, &fa, mm) &&
+      !ffi_arith_ptr(L, cts, &fa, mm)) {
+    const char *repr[2];
+    int i;
+    for (i = 0; i < 2; i++)
+      repr[i] = strdata(lj_ctype_repr(L, ctype_typeid(cts, fa.ct[i]), NULL));
+    lj_err_callerv(L, LJ_ERR_FFI_BADARITH, repr[0], repr[1]);
+  }
   return 1;
 }
 
+LJLIB_CF(ffi_meta___add)
+{
+  return ffi_arith(L);
+}
+
 LJLIB_CF(ffi_meta___sub)
 {
-  FFIArith fa;
-  ffi_checkarith(L, &fa);
-  if (!ffi_arith_ptr(L, &fa, 1))
-    lj_err_caller(L, LJ_ERR_FFI_INVTYPE);
-  return 1;
+  return ffi_arith(L);
+}
+
+LJLIB_CF(ffi_meta___mul)
+{
+  return ffi_arith(L);
+}
+
+LJLIB_CF(ffi_meta___div)
+{
+  return ffi_arith(L);
+}
+
+LJLIB_CF(ffi_meta___mod)
+{
+  return ffi_arith(L);
+}
+
+LJLIB_CF(ffi_meta___pow)
+{
+  return ffi_arith(L);
+}
+
+LJLIB_CF(ffi_meta___unm)
+{
+  return ffi_arith(L);
 }
 
 LJLIB_CF(ffi_meta___tostring)

+ 40 - 0
src/lj_cdata.c

@@ -230,4 +230,44 @@ void lj_cdata_set(CTState *cts, CType *d, uint8_t *dp, TValue *o, CTInfo qual)
   lj_cconv_ct_tv(cts, d, dp, o, 0);
 }
 
+/* -- 64 bit integer arithmetic helpers ----------------------------------- */
+
+/* 64 bit integer x^k. */
+uint64_t lj_cdata_powi64(uint64_t x, uint64_t k, int isunsigned)
+{
+  uint64_t y = 0;
+  int sign = 0;
+  if (k == 0)
+    return 1;
+  if (!isunsigned) {
+    if ((int64_t)k < 0) {
+      if (x == 0)
+	return U64x(7fffffff,ffffffff);
+      else if (x == 1)
+	return 1;
+      else if ((int64_t)x == -1)
+	return (k & 1) ? -1 : 1;
+      else
+	return 0;
+    }
+    if ((int64_t)x < 0) {
+      x = -x;
+      sign = (k & 1);
+    }
+  }
+  for (; (k & 1) == 0; k >>= 1) x *= x;
+  y = x;
+  if ((k >>= 1) != 0) {
+    for (;;) {
+      x *= x;
+      if (k == 1) break;
+      if (k & 1) y *= x;
+      k >>= 1;
+    }
+    y *= x;
+  }
+  if (sign) y = (uint64_t)-(int64_t)y;
+  return y;
+}
+
 #endif

+ 2 - 0
src/lj_cdata.h

@@ -66,6 +66,8 @@ LJ_FUNC void lj_cdata_get(CTState *cts, CType *s, TValue *o, uint8_t *sp);
 LJ_FUNC void lj_cdata_set(CTState *cts, CType *d, uint8_t *dp, TValue *o,
 			  CTInfo qual);
 
+LJ_FUNC uint64_t lj_cdata_powi64(uint64_t x, uint64_t k, int isunsigned);
+
 #endif
 
 #endif

+ 1 - 0
src/lj_errmsg.h

@@ -147,6 +147,7 @@ ERRDEF(FFI_BADTAG,	"undeclared or implicit tag " LUA_QS)
 ERRDEF(FFI_REDEF,	"attempt to redefine " LUA_QS)
 ERRDEF(FFI_INITOV,	"too many initializers for " LUA_QS)
 ERRDEF(FFI_BADCONV,	"cannot convert " LUA_QS " to " LUA_QS)
+ERRDEF(FFI_BADARITH,	"attempt to perform arithmetic on " LUA_QS " and " LUA_QS)
 ERRDEF(FFI_BADMEMBER,	LUA_QS " has no member named " LUA_QS)
 ERRDEF(FFI_BADIDX,	LUA_QS " cannot be indexed")
 ERRDEF(FFI_WRCONST,	"attempt to write to constant location")

+ 2 - 0
src/lj_obj.h

@@ -738,6 +738,8 @@ define_setV(setudataV, GCudata, LJ_TUDATA)
 
 #define setnumV(o, x)		((o)->n = (x))
 #define setnanV(o)		((o)->u64 = U64x(fff80000,00000000))
+#define setpinfV(o)		((o)->u64 = U64x(7ff00000,00000000))
+#define setminfV(o)		((o)->u64 = U64x(fff00000,00000000))
 #define setintV(o, i)		((o)->n = cast_num((int32_t)(i)))
 
 /* Copy tagged values. */