Просмотр исходного кода

DynASM/ARM64: Add .long expr. Add .quad/.addr expr + refs.

Suggested by Dmitry Stogov, Hao Sun and Nick Gasson.
Mike Pall 4 лет назад
Родитель
Сommit
0f8a340c8c
2 измененных файлов с 84 добавлено и 23 удалено
  1. 30 6
      dynasm/dasm_arm64.h
  2. 54 17
      dynasm/dasm_arm64.lua

+ 30 - 6
dynasm/dasm_arm64.h

@@ -21,9 +21,9 @@ enum {
   /* The following actions need a buffer position. */
   DASM_ALIGN, DASM_REL_LG, DASM_LABEL_LG,
   /* The following actions also have an argument. */
-  DASM_REL_PC, DASM_LABEL_PC,
+  DASM_REL_PC, DASM_LABEL_PC, DASM_REL_A,
   DASM_IMM, DASM_IMM6, DASM_IMM12, DASM_IMM13W, DASM_IMM13X, DASM_IMML,
-  DASM_VREG,
+  DASM_IMMV, DASM_VREG,
   DASM__MAX
 };
 
@@ -249,7 +249,7 @@ void dasm_put(Dst_DECL, int start, ...)
 	n = (ins & 255); CK(n < D->maxsection, RANGE_SEC);
 	D->section = &D->sections[n]; goto stop;
       case DASM_ESC: p++; ofs += 4; break;
-      case DASM_REL_EXT: break;
+      case DASM_REL_EXT: if ((ins & 0x8000)) ofs += 8; break;
       case DASM_ALIGN: ofs += (ins & 255); b[pos++] = ofs; break;
       case DASM_REL_LG:
 	n = (ins & 2047) - 10; pl = D->lglabels + n;
@@ -270,6 +270,11 @@ void dasm_put(Dst_DECL, int start, ...)
 	  *pl = pos;
 	}
 	pos++;
+	if ((ins & 0x8000)) ofs += 8;
+	break;
+      case DASM_REL_A:
+	b[pos++] = n;
+	b[pos++] = va_arg(ap, int);
 	break;
       case DASM_LABEL_LG:
 	pl = D->lglabels + (ins & 2047) - 10; CKPL(lg, LG); goto putlabel;
@@ -321,6 +326,10 @@ void dasm_put(Dst_DECL, int start, ...)
 	b[pos++] = n;
 	break;
 	}
+      case DASM_IMMV:
+	ofs += 4;
+	b[pos++] = n;
+	break;
       case DASM_VREG:
 	CK(n < 32, RANGE_VREG);
 	b[pos++] = n;
@@ -381,8 +390,8 @@ int dasm_link(Dst_DECL, size_t *szp)
 	case DASM_REL_LG: case DASM_REL_PC: pos++; break;
 	case DASM_LABEL_LG: case DASM_LABEL_PC: b[pos++] += ofs; break;
 	case DASM_IMM: case DASM_IMM6: case DASM_IMM12: case DASM_IMM13W:
-	case DASM_IMML: case DASM_VREG: pos++; break;
-	case DASM_IMM13X: pos += 2; break;
+	case DASM_IMML: case DASM_IMMV: case DASM_VREG: pos++; break;
+	case DASM_IMM13X: case DASM_REL_A: pos += 2; break;
 	}
       }
       stop: (void)0;
@@ -433,7 +442,9 @@ int dasm_encode(Dst_DECL, void *buffer)
 	  break;
 	case DASM_REL_LG:
 	  if (n < 0) {
-	    n = (int)((ptrdiff_t)D->globals[-n] - (ptrdiff_t)cp + 4);
+	    ptrdiff_t na = (ptrdiff_t)D->globals[-n] - (ptrdiff_t)cp + 4;
+	    n = (int)na;
+	    CK((ptrdiff_t)n == na, RANGE_REL);
 	    goto patchrel;
 	  }
 	  /* fallthrough */
@@ -455,8 +466,18 @@ int dasm_encode(Dst_DECL, void *buffer)
 	  } else if ((ins & 0x1000)) {  /* TBZ, TBNZ */
 	    CK((n & 3) == 0 && ((n+0x00008000) >> 16) == 0, RANGE_REL);
 	    cp[-1] |= ((n << 3) & 0x0007ffe0);
+	  } else if ((ins & 0x8000)) {  /* absolute */
+	    cp[0] = (unsigned int)((ptrdiff_t)cp - 4 + n);
+	    cp[1] = (unsigned int)(((ptrdiff_t)cp - 4 + n) >> 32);
+	    cp += 2;
 	  }
 	  break;
+	case DASM_REL_A: {
+	  ptrdiff_t na = (((ptrdiff_t)(*b++) << 32) | (unsigned int)n) - (ptrdiff_t)cp + 4;
+	  n = (int)na;
+	  CK((ptrdiff_t)n == na, RANGE_REL);
+	  goto patchrel;
+	}
 	case DASM_LABEL_LG:
 	  ins &= 2047; if (ins >= 20) D->globals[ins-10] = (void *)(base + n);
 	  break;
@@ -482,6 +503,9 @@ int dasm_encode(Dst_DECL, void *buffer)
 	    ((n << (10-scale)) | 0x01000000) : ((n & 511) << 12);
 	  break;
 	  }
+	case DASM_IMMV:
+	  *cp++ = n;
+	  break;
 	case DASM_VREG:
 	  cp[-1] |= (n & 0x1f) << (ins & 0x1f);
 	  break;

+ 54 - 17
dynasm/dasm_arm64.lua

@@ -23,12 +23,12 @@ local _M = { _info = _info }
 local type, tonumber, pairs, ipairs = type, tonumber, pairs, ipairs
 local assert, setmetatable, rawget = assert, setmetatable, rawget
 local _s = string
-local sub, format, byte, char = _s.sub, _s.format, _s.byte, _s.char
+local format, byte, char = _s.format, _s.byte, _s.char
 local match, gmatch, gsub = _s.match, _s.gmatch, _s.gsub
 local concat, sort, insert = table.concat, table.sort, table.insert
 local bit = bit or require("bit")
 local band, shl, shr, sar = bit.band, bit.lshift, bit.rshift, bit.arshift
-local ror, tohex = bit.ror, bit.tohex
+local ror, tohex, tobit = bit.ror, bit.tohex, bit.tobit
 
 -- Inherited tables and callbacks.
 local g_opt, g_arch
@@ -39,7 +39,8 @@ local wline, werror, wfatal, wwarn
 local action_names = {
   "STOP", "SECTION", "ESC", "REL_EXT",
   "ALIGN", "REL_LG", "LABEL_LG",
-  "REL_PC", "LABEL_PC", "IMM", "IMM6", "IMM12", "IMM13W", "IMM13X", "IMML",
+  "REL_PC", "LABEL_PC", "REL_A",
+  "IMM", "IMM6", "IMM12", "IMM13W", "IMM13X", "IMML", "IMMV",
   "VREG",
 }
 
@@ -311,7 +312,7 @@ local function parse_number(n)
   local code = loadenv("return "..n)
   if code then
     local ok, y = pcall(code)
-    if ok then return y end
+    if ok and type(y) == "number" then return y end
   end
   return nil
 end
@@ -575,14 +576,14 @@ local function parse_load_pair(params, nparams, n, op)
 end
 
 local function parse_label(label, def)
-  local prefix = sub(label, 1, 2)
+  local prefix = label:sub(1, 2)
   -- =>label (pc label reference)
   if prefix == "=>" then
-    return "PC", 0, sub(label, 3)
+    return "PC", 0, label:sub(3)
   end
   -- ->name (global label reference)
   if prefix == "->" then
-    return "LG", map_global[sub(label, 3)]
+    return "LG", map_global[label:sub(3)]
   end
   if def then
     -- [1-9] (local label definition)
@@ -600,8 +601,11 @@ local function parse_label(label, def)
     if extname then
       return "EXT", map_extern[extname]
     end
+    -- &expr (pointer)
+    if label:sub(1, 1) == "&" then
+      return "A", 0, format("(ptrdiff_t)(%s)", label:sub(2))
+    end
   end
-  werror("bad label `"..label.."'")
 end
 
 local function branch_type(op)
@@ -895,14 +899,14 @@ end
 
 -- Handle opcodes defined with template strings.
 local function parse_template(params, template, nparams, pos)
-  local op = tonumber(sub(template, 1, 8), 16)
+  local op = tonumber(template:sub(1, 8), 16)
   local n = 1
   local rtt = {}
 
   parse_reg_type = false
 
   -- Process each character.
-  for p in gmatch(sub(template, 9), ".") do
+  for p in gmatch(template:sub(9), ".") do
     local q = params[n]
     if p == "D" then
       op = op + parse_reg(q, 0); n = n + 1
@@ -944,8 +948,14 @@ local function parse_template(params, template, nparams, pos)
 
     elseif p == "B" then
       local mode, v, s = parse_label(q, false); n = n + 1
+      if not mode then werror("bad label `"..q.."'") end
       local m = branch_type(op)
-      waction("REL_"..mode, v+m, s, 1)
+      if mode == "A" then
+	waction("REL_"..mode, v+m, format("(unsigned int)(%s)", s))
+	actargs[#actargs+1] = format("(unsigned int)((%s)>>32)", s)
+      else
+	waction("REL_"..mode, v+m, s, 1)
+      end
 
     elseif p == "I" then
       op = op + parse_imm12(q); n = n + 1
@@ -1050,23 +1060,50 @@ map_op[".label_1"] = function(params)
   if not params then return "[1-9] | ->global | =>pcexpr" end
   if secpos+1 > maxsecpos then wflush() end
   local mode, n, s = parse_label(params[1], true)
-  if mode == "EXT" then werror("bad label definition") end
+  if not mode or mode == "EXT" then werror("bad label definition") end
   waction("LABEL_"..mode, n, s, 1)
 end
 
 ------------------------------------------------------------------------------
 
 -- Pseudo-opcodes for data storage.
-map_op[".long_*"] = function(params)
+local function op_data(params)
   if not params then return "imm..." end
+  local sz = params.op == ".long" and 4 or 8
   for _,p in ipairs(params) do
-    local n = tonumber(p)
-    if not n then werror("bad immediate `"..p.."'") end
-    if n < 0 then n = n + 2^32 end
-    wputw(n)
+    local imm = parse_number(p)
+    if imm then
+      local n = tobit(imm)
+      if n == imm or (n < 0 and n + 2^32 == imm) then
+	wputw(n < 0 and n + 2^32 or n)
+	if sz == 8 then
+	  wputw(imm < 0 and 0xffffffff or 0)
+	end
+      elseif sz == 4 then
+	werror("bad immediate `"..p.."'")
+      else
+	imm = nil
+      end
+    end
+    if not imm then
+      local mode, v, s = parse_label(p, false)
+      if sz == 4 then
+	if mode then werror("label does not fit into .long") end
+	waction("IMMV", 0, p)
+      elseif mode and mode ~= "A" then
+	waction("REL_"..mode, v+0x8000, s, 1)
+      else
+	if mode == "A" then p = s end
+	waction("IMMV", 0, format("(unsigned int)(%s)", p))
+	waction("IMMV", 0, format("(unsigned int)((unsigned long long)(%s)>>32)", p))
+      end
+    end
     if secpos+2 > maxsecpos then wflush() end
   end
 end
+map_op[".long_*"] = op_data
+map_op[".quad_*"] = op_data
+map_op[".addr_*"] = op_data
 
 -- Alignment pseudo-opcode.
 map_op[".align_1"] = function(params)