Browse Source

DynASM/x64: Add full VREG support.

Contributed by Peter Cawley.
Mike Pall 9 years ago
parent
commit
a687a60eaa
2 changed files with 104 additions and 36 deletions
  1. 27 6
      dynasm/dasm_x86.h
  2. 77 30
      dynasm/dasm_x86.lua

+ 27 - 6
dynasm/dasm_x86.h

@@ -170,7 +170,7 @@ void dasm_put(Dst_DECL, int start, ...)
   dasm_State *D = Dst_REF;
   dasm_State *D = Dst_REF;
   dasm_ActList p = D->actionlist + start;
   dasm_ActList p = D->actionlist + start;
   dasm_Section *sec = D->section;
   dasm_Section *sec = D->section;
-  int pos = sec->pos, ofs = sec->ofs, mrm = 4;
+  int pos = sec->pos, ofs = sec->ofs, mrm = -1;
   int *b;
   int *b;
 
 
   if (pos >= sec->epos) {
   if (pos >= sec->epos) {
@@ -193,7 +193,7 @@ void dasm_put(Dst_DECL, int start, ...)
       b[pos++] = n;
       b[pos++] = n;
       switch (action) {
       switch (action) {
       case DASM_DISP:
       case DASM_DISP:
-	if (n == 0) { if ((mrm&7) == 4) mrm = p[-2]; if ((mrm&7) != 5) break; }
+	if (n == 0) { if (mrm < 0) mrm = p[-2]; if ((mrm&7) != 5) break; }
       case DASM_IMM_DB: if (((n+128)&-256) == 0) goto ob;
       case DASM_IMM_DB: if (((n+128)&-256) == 0) goto ob;
       case DASM_REL_A: /* Assumes ptrdiff_t is int. !x64 */
       case DASM_REL_A: /* Assumes ptrdiff_t is int. !x64 */
       case DASM_IMM_D: ofs += 4; break;
       case DASM_IMM_D: ofs += 4; break;
@@ -203,10 +203,17 @@ void dasm_put(Dst_DECL, int start, ...)
       case DASM_IMM_W: CK((n&-65536) == 0, RANGE_I); ofs += 2; break;
       case DASM_IMM_W: CK((n&-65536) == 0, RANGE_I); ofs += 2; break;
       case DASM_SPACE: p++; ofs += n; break;
       case DASM_SPACE: p++; ofs += n; break;
       case DASM_SETLABEL: b[pos-2] = -0x40000000; break;  /* Neg. label ofs. */
       case DASM_SETLABEL: b[pos-2] = -0x40000000; break;  /* Neg. label ofs. */
-      case DASM_VREG: CK((n&-8) == 0 && (n != 4 || (*p&1) == 0), RANGE_VREG);
-	if (*p++ == 1 && *p == DASM_DISP) mrm = n; continue;
+      case DASM_VREG: CK((n&-16) == 0 && (n != 4 || (*p>>5) != 2), RANGE_VREG);
+	if (*p < 0x40 && p[1] == DASM_DISP) mrm = n;
+	if (*p < 0x20 && (n&7) == 4) ofs++;
+	switch ((*p++ >> 3) & 3) {
+	case 3: n |= b[pos-3];
+	case 2: n |= b[pos-2];
+	case 1: if (n <= 7) { b[pos-1] |= 0x10; ofs--; }
+	}
+	continue;
       }
       }
-      mrm = 4;
+      mrm = -1;
     } else {
     } else {
       int *pl, n;
       int *pl, n;
       switch (action) {
       switch (action) {
@@ -393,7 +400,21 @@ int dasm_encode(Dst_DECL, void *buffer)
 	case DASM_IMM_W: dasmw(n); break;
 	case DASM_IMM_W: dasmw(n); break;
 	case DASM_VREG: {
 	case DASM_VREG: {
 	  int t = *p++;
 	  int t = *p++;
-	  if (t >= 5) n <<= 4; else if (t >= 2) n <<= 3;
+	  unsigned char *ex = cp - (t&7);
+	  if ((n & 8) && t < 0xa0) {
+	    if (*ex & 0x80) ex[1] ^= 0x20 << (t>>6); else *ex ^= 1 << (t>>6);
+	  } else if (n & 0x10) {
+	    if (*ex & 0x80) {
+	      *ex = 0xc5; ex[1] = (ex[1] & 0x80) | ex[2]; ex += 2;
+	    }
+	    while (++ex < cp) ex[-1] = *ex;
+	    if (mark) mark--;
+	    cp--;
+	  }
+	  n &= 7;
+	  if (t >= 0xc0) n <<= 4;
+	  else if (t >= 0x40) n <<= 3;
+	  else if (n == 4 && t < 0x20) { cp[-1] ^= n; *cp++ = 0x20; }
 	  cp[-1] ^= n;
 	  cp[-1] ^= n;
 	  break;
 	  break;
 	}
 	}

+ 77 - 30
dynasm/dasm_x86.lua

@@ -41,7 +41,7 @@ local action_names = {
   -- int arg, 1 buffer pos:
   -- int arg, 1 buffer pos:
   "DISP",  "IMM_S", "IMM_B", "IMM_W", "IMM_D",  "IMM_WB", "IMM_DB",
   "DISP",  "IMM_S", "IMM_B", "IMM_W", "IMM_D",  "IMM_WB", "IMM_DB",
   -- action arg (1 byte), int arg, 1 buffer pos (reg/num):
   -- action arg (1 byte), int arg, 1 buffer pos (reg/num):
-  "VREG", "SPACE", -- !x64: VREG support NYI.
+  "VREG", "SPACE",
   -- ptrdiff_t arg, 1 buffer pos (address): !x64
   -- ptrdiff_t arg, 1 buffer pos (address): !x64
   "SETLABEL", "REL_A",
   "SETLABEL", "REL_A",
   -- action arg (1 byte) or int arg, 2 buffer pos (link, offset):
   -- action arg (1 byte) or int arg, 2 buffer pos (link, offset):
@@ -83,6 +83,21 @@ local actargs = { 0 }
 -- Current number of section buffer positions for dasm_put().
 -- Current number of section buffer positions for dasm_put().
 local secpos = 1
 local secpos = 1
 
 
+-- VREG kind encodings, pre-shifted by 5 bits.
+-- Each value is a multiple of 0x20, which leaves the low 5 bits of the
+-- VREG action argument byte free: wvreg() adds the shrink count times 8
+-- and the psz offset on top of the kind before emitting the byte.
+-- Note that "modrm.rm.r", "opcode" and "sib.base" share the same encoding.
+local map_vreg = {
+  ["modrm.rm.m"] = 0x00,
+  ["modrm.rm.r"] = 0x20,
+  ["opcode"] =     0x20,
+  ["sib.base"] =   0x20,
+  ["sib.index"] =  0x40,
+  ["modrm.reg"] =  0x80,
+  ["vex.v"] =      0xa0,
+  ["imm.hi"] =     0xc0,
+}
+
+-- Current number of VREG actions contributing to REX/VEX shrinkage.
+-- Incremented by wvreg() for kinds below its `sk' threshold; folded into
+-- the action byte and reset to zero by the next non-deferred wvreg() call.
+local vreg_shrink_count = 0
+
 ------------------------------------------------------------------------------
 ------------------------------------------------------------------------------
 
 
 -- Compute action numbers for action names.
 -- Compute action numbers for action names.
@@ -134,6 +149,21 @@ local function waction(action, a, num)
   if a or num then secpos = secpos + (num or 1) end
   if a or num then secpos = secpos + (num or 1) end
 end
 end
 
 
+-- Optionally add a VREG action.
+--   kind:  key into map_vreg, selects the base value of the action byte.
+--   vreg:  variable register expression; nothing is emitted if nil/false.
+--   psz:   offset added to the action byte (defaults to 0).
+--   sk:    shrink threshold; a kind encoded below it counts towards
+--          vreg_shrink_count (defaults to 0, i.e. nothing counts).
+--   defer: if set, the accumulated shrink count is left for a later
+--          wvreg() call instead of being written into this action byte.
+local function wvreg(kind, vreg, psz, sk, defer)
+  if not vreg then return end
+  waction("VREG", vreg)
+  -- Report the offending kind (not the register expression) on a bad lookup.
+  local b = assert(map_vreg[kind], "bad vreg kind `"..tostring(kind).."'")
+  if b < (sk or 0) then
+    vreg_shrink_count = vreg_shrink_count + 1
+  end
+  if not defer then
+    -- Fold the pending shrink count into bits 3-4 and start over.
+    b = b + vreg_shrink_count * 8
+    vreg_shrink_count = 0
+  end
+  wputxb(b + (psz or 0))
+end
+
 -- Add call to embedded DynASM C code.
 -- Add call to embedded DynASM C code.
 local function wcall(func, args)
 local function wcall(func, args)
   wline(format("dasm_%s(Dst, %s);", func, concat(args, ", ")), true)
   wline(format("dasm_%s(Dst, %s);", func, concat(args, ", ")), true)
@@ -326,6 +356,7 @@ mkrmap("w", "Rw", {"ax", "cx", "dx", "bx", "sp", "bp", "si", "di"})
 mkrmap("b", "Rb", {"al", "cl", "dl", "bl", "ah", "ch", "dh", "bh"})
 mkrmap("b", "Rb", {"al", "cl", "dl", "bl", "ah", "ch", "dh", "bh"})
 map_reg_valid_index[map_archdef.esp] = false
 map_reg_valid_index[map_archdef.esp] = false
 if x64 then map_reg_valid_index[map_archdef.rsp] = false end
 if x64 then map_reg_valid_index[map_archdef.rsp] = false end
+if x64 then map_reg_needrex[map_archdef.Rb] = true end
 map_archdef["Ra"] = "@"..addrsize
 map_archdef["Ra"] = "@"..addrsize
 
 
 -- FP registers (internally tword sized, but use "f" as operand size).
 -- FP registers (internally tword sized, but use "f" as operand size).
@@ -463,16 +494,24 @@ local function wputszarg(sz, n)
 end
 end
 
 
 -- Put multi-byte opcode with operand-size dependent modifications.
 -- Put multi-byte opcode with operand-size dependent modifications.
-local function wputop(sz, op, rex, vex)
+local function wputop(sz, op, rex, vex, vregr, vregxb)
+  local psz, sk = 0, nil
   if vex then
   if vex then
     local tail
     local tail
     if vex.m == 1 and band(rex, 11) == 0 then
     if vex.m == 1 and band(rex, 11) == 0 then
-      wputb(0xc5)
+      if x64 and vregxb then
+	sk = map_vreg["modrm.reg"]
+      else
+	wputb(0xc5)
       tail = shl(bxor(band(rex, 4), 4), 5)
       tail = shl(bxor(band(rex, 4), 4), 5)
-    else
+      psz = 3
+      end
+    end
+    if not tail then
       wputb(0xc4)
       wputb(0xc4)
       wputb(shl(bxor(band(rex, 7), 7), 5) + vex.m)
       wputb(shl(bxor(band(rex, 7), 7), 5) + vex.m)
       tail = shl(band(rex, 8), 4)
       tail = shl(band(rex, 8), 4)
+      psz = 4
     end
     end
     local reg, vreg = 0, nil
     local reg, vreg = 0, nil
     if vex.v then
     if vex.v then
@@ -482,12 +521,18 @@ local function wputop(sz, op, rex, vex)
     end
     end
     if sz == "y" or vex.l then tail = tail + 4 end
     if sz == "y" or vex.l then tail = tail + 4 end
     wputb(tail + shl(bxor(reg, 15), 3) + vex.p)
     wputb(tail + shl(bxor(reg, 15), 3) + vex.p)
-    if vreg then waction("VREG", vreg); wputxb(4) end
+    wvreg("vex.v", vreg)
     rex = 0
     rex = 0
     if op >= 256 then werror("bad vex opcode") end
     if op >= 256 then werror("bad vex opcode") end
+  else
+    if rex ~= 0 then
+      if not x64 then werror("bad operand size") end
+    elseif (vregr or vregxb) and x64 then
+      rex = 0x10
+      sk = map_vreg["vex.v"]
+    end
   end
   end
   local r
   local r
-  if rex ~= 0 and not x64 then werror("bad operand size") end
   if sz == "w" then wputb(102) end
   if sz == "w" then wputb(102) end
   -- Needs >32 bit numbers, but only for crc32 eax, word [ebx]
   -- Needs >32 bit numbers, but only for crc32 eax, word [ebx]
   if op >= 4294967296 then r = op%4294967296 wputb((op-r)/4294967296) op = r end
   if op >= 4294967296 then r = op%4294967296 wputb((op-r)/4294967296) op = r end
@@ -496,20 +541,20 @@ local function wputop(sz, op, rex, vex)
     if rex ~= 0 then
     if rex ~= 0 then
       local opc3 = band(op, 0xffff00)
       local opc3 = band(op, 0xffff00)
       if opc3 == 0x0f3a00 or opc3 == 0x0f3800 then
       if opc3 == 0x0f3a00 or opc3 == 0x0f3800 then
-	wputb(64 + band(rex, 15)); rex = 0
+	wputb(64 + band(rex, 15)); rex = 0; psz = 2
       end
       end
     end
     end
-    wputb(shr(op, 16)); op = band(op, 0xffff)
+    wputb(shr(op, 16)); op = band(op, 0xffff); psz = psz + 1
   end
   end
   if op >= 256 then
   if op >= 256 then
     local b = shr(op, 8)
     local b = shr(op, 8)
-    if b == 15 and rex ~= 0 then wputb(64 + band(rex, 15)); rex = 0 end
-    wputb(b)
-    op = band(op, 255)
+    if b == 15 and rex ~= 0 then wputb(64 + band(rex, 15)); rex = 0; psz = 2 end
+    wputb(b); op = band(op, 255); psz = psz + 1
   end
   end
-  if rex ~= 0 then wputb(64 + band(rex, 15)) end
+  if rex ~= 0 then wputb(64 + band(rex, 15)); psz = 2 end
   if sz == "b" then op = op - 1 end
   if sz == "b" then op = op - 1 end
   wputb(op)
   wputb(op)
+  return psz, sk
 end
 end
 
 
 -- Put ModRM or SIB formatted byte.
 -- Put ModRM or SIB formatted byte.
@@ -519,7 +564,7 @@ local function wputmodrm(m, s, rm, vs, vrm)
 end
 end
 
 
 -- Put ModRM/SIB plus optional displacement.
 -- Put ModRM/SIB plus optional displacement.
-local function wputmrmsib(t, imark, s, vsreg)
+local function wputmrmsib(t, imark, s, vsreg, psz, sk)
   local vreg, vxreg
   local vreg, vxreg
   local reg, xreg = t.reg, t.xreg
   local reg, xreg = t.reg, t.xreg
   if reg and reg < 0 then reg = 0; vreg = t.vreg end
   if reg and reg < 0 then reg = 0; vreg = t.vreg end
@@ -529,8 +574,8 @@ local function wputmrmsib(t, imark, s, vsreg)
   -- Register mode.
   -- Register mode.
   if sub(t.mode, 1, 1) == "r" then
   if sub(t.mode, 1, 1) == "r" then
     wputmodrm(3, s, reg)
     wputmodrm(3, s, reg)
-    if vsreg then waction("VREG", vsreg); wputxb(2) end
-    if vreg then waction("VREG", vreg); wputxb(0) end
+    wvreg("modrm.reg", vsreg, psz+1, sk, vreg)
+    wvreg("modrm.rm.r", vreg, psz+1, sk)
     return
     return
   end
   end
 
 
@@ -544,21 +589,22 @@ local function wputmrmsib(t, imark, s, vsreg)
       -- [xreg*xsc+disp] -> (0, s, esp) (xsc, xreg, ebp)
       -- [xreg*xsc+disp] -> (0, s, esp) (xsc, xreg, ebp)
       wputmodrm(0, s, 4)
       wputmodrm(0, s, 4)
       if imark == "I" then waction("MARK") end
       if imark == "I" then waction("MARK") end
-      if vsreg then waction("VREG", vsreg); wputxb(2) end
+      wvreg("modrm.reg", vsreg, psz+1, sk, vxreg)
       wputmodrm(t.xsc, xreg, 5)
       wputmodrm(t.xsc, xreg, 5)
-      if vxreg then waction("VREG", vxreg); wputxb(3) end
+      wvreg("sib.index", vxreg, psz+2, sk)
     else
     else
       -- Pure 32 bit displacement.
       -- Pure 32 bit displacement.
       if x64 and tdisp ~= "table" then
       if x64 and tdisp ~= "table" then
 	wputmodrm(0, s, 4) -- [disp] -> (0, s, esp) (0, esp, ebp)
 	wputmodrm(0, s, 4) -- [disp] -> (0, s, esp) (0, esp, ebp)
+	wvreg("modrm.reg", vsreg, psz+1, sk)
 	if imark == "I" then waction("MARK") end
 	if imark == "I" then waction("MARK") end
 	wputmodrm(0, 4, 5)
 	wputmodrm(0, 4, 5)
       else
       else
 	riprel = x64
 	riprel = x64
 	wputmodrm(0, s, 5) -- [disp|rip-label] -> (0, s, ebp)
 	wputmodrm(0, s, 5) -- [disp|rip-label] -> (0, s, ebp)
+	wvreg("modrm.reg", vsreg, psz+1, sk)
 	if imark == "I" then waction("MARK") end
 	if imark == "I" then waction("MARK") end
       end
       end
-      if vsreg then waction("VREG", vsreg); wputxb(2) end
     end
     end
     if riprel then -- Emit rip-relative displacement.
     if riprel then -- Emit rip-relative displacement.
       if match("UWSiI", imark) then
       if match("UWSiI", imark) then
@@ -586,16 +632,16 @@ local function wputmrmsib(t, imark, s, vsreg)
   if xreg or band(reg, 7) == 4 then
   if xreg or band(reg, 7) == 4 then
     wputmodrm(m or 2, s, 4) -- ModRM.
     wputmodrm(m or 2, s, 4) -- ModRM.
     if m == nil or imark == "I" then waction("MARK") end
     if m == nil or imark == "I" then waction("MARK") end
-    if vsreg then waction("VREG", vsreg); wputxb(2) end
+    wvreg("modrm.reg", vsreg, psz+1, sk, vxreg or vreg)
     wputmodrm(t.xsc or 0, xreg or 4, reg) -- SIB.
     wputmodrm(t.xsc or 0, xreg or 4, reg) -- SIB.
-    if vxreg then waction("VREG", vxreg); wputxb(3) end
-    if vreg then waction("VREG", vreg); wputxb(1) end
+    wvreg("sib.index", vxreg, psz+2, sk, vreg)
+    wvreg("sib.base", vreg, psz+2, sk)
   else
   else
     wputmodrm(m or 2, s, reg) -- ModRM.
     wputmodrm(m or 2, s, reg) -- ModRM.
     if (imark == "I" and (m == 1 or m == 2)) or
     if (imark == "I" and (m == 1 or m == 2)) or
        (m == nil and (vsreg or vreg)) then waction("MARK") end
        (m == nil and (vsreg or vreg)) then waction("MARK") end
-    if vsreg then waction("VREG", vsreg); wputxb(2) end
-    if vreg then waction("VREG", vreg); wputxb(1) end
+    wvreg("modrm.reg", vsreg, psz+1, sk, vreg)
+    wvreg("modrm.rm.m", vreg, psz+1, sk)
   end
   end
 
 
   -- Put displacement.
   -- Put displacement.
@@ -1761,10 +1807,11 @@ local function dopattern(pat, args, sz, op, needrex)
       if t.xreg and t.xreg > 7 then rex = rex + 2 end
       if t.xreg and t.xreg > 7 then rex = rex + 2 end
       if s > 7 then rex = rex + 4 end
       if s > 7 then rex = rex + 4 end
       if needrex then rex = rex + 16 end
       if needrex then rex = rex + 16 end
-      wputop(szov, opcode, rex, vex); opcode = nil
+      local psz, sk = wputop(szov, opcode, rex, vex, s < 0, t.vreg or t.vxreg)
+      opcode = nil
       local imark = sub(pat, -1) -- Force a mark (ugly).
       local imark = sub(pat, -1) -- Force a mark (ugly).
       -- Put ModRM/SIB with regno/last digit as spare.
       -- Put ModRM/SIB with regno/last digit as spare.
-      wputmrmsib(t, imark, s, addin and addin.vreg)
+      wputmrmsib(t, imark, s, addin and addin.vreg, psz, sk)
       addin = nil
       addin = nil
     elseif map_vexarg[c] ~= nil then -- Encode using VEX prefix
     elseif map_vexarg[c] ~= nil then -- Encode using VEX prefix
       local b = band(opcode, 255); opcode = shr(opcode, 8)
       local b = band(opcode, 255); opcode = shr(opcode, 8)
@@ -1791,8 +1838,8 @@ local function dopattern(pat, args, sz, op, needrex)
 	if szov == "q" and rex == 0 then rex = rex + 8 end
 	if szov == "q" and rex == 0 then rex = rex + 8 end
 	if needrex then rex = rex + 16 end
 	if needrex then rex = rex + 16 end
 	if addin and addin.reg == -1 then
 	if addin and addin.reg == -1 then
-	  wputop(szov, opcode - 7, rex, vex)
-	  waction("VREG", addin.vreg); wputxb(0)
+	  local psz, sk = wputop(szov, opcode - 7, rex, vex, true)
+	  wvreg("opcode", addin.vreg, psz, sk)
 	else
 	else
 	  if addin and addin.reg > 7 then rex = rex + 1 end
 	  if addin and addin.reg > 7 then rex = rex + 1 end
 	  wputop(szov, opcode, rex, vex)
 	  wputop(szov, opcode, rex, vex)
@@ -1836,7 +1883,7 @@ local function dopattern(pat, args, sz, op, needrex)
 	  local reg = a.reg
 	  local reg = a.reg
 	  if reg < 0 then
 	  if reg < 0 then
 	    wputb(0)
 	    wputb(0)
-	    waction("VREG", a.vreg); wputxb(5)
+	    wvreg("imm.hi", a.vreg)
 	  else
 	  else
 	    wputb(shl(reg, 4))
 	    wputb(shl(reg, 4))
 	  end
 	  end
@@ -1988,8 +2035,8 @@ if x64 then
 	rex = a.reg > 7 and 9 or 8
 	rex = a.reg > 7 and 9 or 8
       end
       end
     end
     end
-    wputop(sz, opcode, rex)
-    if vreg then waction("VREG", vreg); wputxb(0) end
+    local psz, sk = wputop(sz, opcode, rex, nil, vreg)
+    wvreg("opcode", vreg, psz, sk)
     waction("IMM_D", format("(unsigned int)(%s)", op64))
     waction("IMM_D", format("(unsigned int)(%s)", op64))
     waction("IMM_D", format("(unsigned int)((%s)>>32)", op64))
     waction("IMM_D", format("(unsigned int)((%s)>>32)", op64))
   end
   end