Просмотр исходного кода

Add support for specifying overload arg index in extension function (#3510)

Any extension function that includes the specical string "$o" in its name
needs to have the $o replaced with the type name of the overload. Previously
we used a default heuristic to select the overload type from a function
argument.

This commit adds support for explicitly setting the argument to use for the
overload name by appending a ":<ArgIndex>" to the overload marker. For example,
using a name like "my_special_function.$o:3" would take the overload type from
the third function argument.
David Peixotto 4 лет назад
Родитель
Сommit
58163b04cf
2 измененных файлов с 115 добавлено и 11 удалено
  1. 77 11
      lib/HLSL/HLOperationLowerExtension.cpp
  2. 38 0
      tools/clang/unittests/HLSL/ExtensionTest.cpp

+ 77 - 11
lib/HLSL/HLOperationLowerExtension.cpp

@@ -1142,13 +1142,21 @@ private:
     return name.size() > 0;
   }
 
+  typedef unsigned OverloadArgIndex;
+  static constexpr OverloadArgIndex DefaultOverloadIndex = std::numeric_limits<OverloadArgIndex>::max();
+
   // Choose the (return value or argument) type that determines the overload type
   // for the intrinsic call.
-  // For now we take the return type as the overload. If the return is void we
-  // take the first (non-opcode) argument as the overload type. We could extend the
-  // $o sytnax in the extension name to explicitly specify the overload slot (e.g.
-  // $o:3 would say the overload type is determined by parameter 3.
-  static Type *SelectOverloadSlot(CallInst *CI) {
+  // If the overload arg index was explicitly specified (see ParseOverloadArgIndex)
+  // then we use that arg to pick the overload name. Otherwise we pick a default
+  // where we take the return type as the overload. If the return is void we
+  // take the first (non-opcode) argument as the overload type.
+  static Type *SelectOverloadSlot(CallInst *CI, OverloadArgIndex ArgIndex) {
+   if (ArgIndex != DefaultOverloadIndex)
+    {
+      return CI->getArgOperand(ArgIndex)->getType();
+    }
+
     Type *ty = CI->getType();
     if (ty->isVoidTy()) {
       if (CI->getNumArgOperands() > 1)
@@ -1158,8 +1166,8 @@ private:
     return ty;
   }
 
-  static Type *GetOverloadType(CallInst *CI) {
-    Type *ty = SelectOverloadSlot(CI);
+  static Type *GetOverloadType(CallInst *CI, OverloadArgIndex ArgIndex) {
+    Type *ty = SelectOverloadSlot(CI, ArgIndex);
     if (ty->isVectorTy())
       ty = ty->getVectorElementType();
 
@@ -1174,19 +1182,77 @@ private:
       return typeName;
   }
 
-  static std::string GetOverloadTypeName(CallInst *CI) {
-    Type *ty = GetOverloadType(CI);
+  static std::string GetOverloadTypeName(CallInst *CI, OverloadArgIndex ArgIndex) {
+    Type *ty = GetOverloadType(CI, ArgIndex);
     return GetTypeName(ty);
   }
 
+  // Parse the arg index out of the overload marker (if any).
+  //
+  // The function names use a $o to indicate that the function is overloaded
+  // and we should replace $o with the overload type. The extension name can
+  // explicitly set which arg to use for the overload type by adding a colon
+  // and a number after the $o (e.g. $o:3 would say the overload type is
+  // determined by parameter 3).
+  //
+  // If we find an arg index after the overload marker we update the size
+  // of the marker to include the full parsed string size so that it can
+  // be replaced with the selected overload type.
+  //
+  static OverloadArgIndex ParseOverloadArgIndex(
+      const std::string& functionName,
+      size_t OverloadMarkerStartIndex,
+      size_t *pOverloadMarkerSize)
+  {
+      assert(OverloadMarkerStartIndex != std::string::npos);
+      size_t StartIndex = OverloadMarkerStartIndex + *pOverloadMarkerSize;
+
+      // Check if we have anything after the overload marker to parse.
+      if (StartIndex >= functionName.size())
+      {
+          return DefaultOverloadIndex;
+      }
+
+      // Does it start with a ':' ?
+      if (functionName[StartIndex] != ':')
+      {
+          return DefaultOverloadIndex;
+      }
+
+      // Skip past the :
+      ++StartIndex;
+
+      // Collect all the digits.
+      std::string Digits;
+      Digits.reserve(functionName.size() - StartIndex);
+      for (size_t i = StartIndex; i < functionName.size(); ++i)
+      {
+          char c = functionName[i];
+          if (!isdigit(c))
+          {
+              break;
+          }
+          Digits.push_back(c);
+      }
+
+      if (Digits.empty())
+      {
+          return DefaultOverloadIndex;
+      }
+
+      *pOverloadMarkerSize = *pOverloadMarkerSize + std::strlen(":") + Digits.size();
+      return std::stoi(Digits);
+  }
+
   // Find the occurence of the overload marker $o and replace it the the overload type name.
   static void ReplaceOverloadMarkerWithTypeName(std::string &functionName, CallInst *CI) {
     const char *OverloadMarker = "$o";
-    const size_t OverloadMarkerLength = 2;
+    size_t OverloadMarkerLength = 2;
 
     size_t pos = functionName.find(OverloadMarker);
     if (pos != std::string::npos) {
-      std::string typeName = GetOverloadTypeName(CI);
+      OverloadArgIndex ArgIndex = ParseOverloadArgIndex(functionName, pos, &OverloadMarkerLength);
+      std::string typeName = GetOverloadTypeName(CI, ArgIndex);
       functionName.replace(pos, OverloadMarkerLength, typeName);
     }
   }

+ 38 - 0
tools/clang/unittests/HLSL/ExtensionTest.cpp

@@ -145,6 +145,15 @@ static const HLSL_INTRINSIC_ARGUMENT TestMyTexture2DOp[] = {
   { "val", AR_QUAL_IN, 1, LITEMPLATE_VECTOR, 1, LICOMPTYPE_UINT, 1, 2},
 };
 
+// float = test_overload(float a, uint b, double c)
+static const HLSL_INTRINSIC_ARGUMENT TestOverloadArgs[] = {
+  { "test_overload", AR_QUAL_OUT, 0, LITEMPLATE_SCALAR, 0, LICOMPTYPE_NUMERIC, 1, IA_C },
+  { "a", AR_QUAL_IN, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_FLOAT, 1, IA_C },
+  { "b", AR_QUAL_IN, 2, LITEMPLATE_ANY, 2, LICOMPTYPE_UINT, 1, IA_C },
+  { "c", AR_QUAL_IN, 3, LITEMPLATE_SCALAR, 3, LICOMPTYPE_DOUBLE, 1, IA_C },
+};
+
+
 struct Intrinsic {
   LPCWSTR hlslName;
   const char *dxilName;
@@ -175,6 +184,9 @@ Intrinsic Intrinsics[] = {
   // counterpart for testing purposes.
   {L"test_unsigned","test_unsigned",   "n", { static_cast<unsigned>(hlsl::IntrinsicOp::IOP_min), false, true, false, -1, countof(TestUnsigned), TestUnsigned}},
   {L"wave_proc",    DEFAULT_NAME,      "r", { 16, false, true, true, -1, countof(WaveProcArgs), WaveProcArgs }},
+  {L"test_o_1",     "test_o_1.$o:1",   "r", { 18, false, true, true, -1, countof(TestOverloadArgs), TestOverloadArgs }},
+  {L"test_o_2",     "test_o_2.$o:2",   "r", { 19, false, true, true, -1, countof(TestOverloadArgs), TestOverloadArgs }},
+  {L"test_o_3",     "test_o_3.$o:3",   "r", { 20, false, true, true, -1, countof(TestOverloadArgs), TestOverloadArgs }},
 };
 
 Intrinsic BufferIntrinsics[] = {
@@ -530,6 +542,7 @@ public:
   TEST_METHOD(ResourceExtensionIntrinsicCustomLowering1)
   TEST_METHOD(ResourceExtensionIntrinsicCustomLowering2)
   TEST_METHOD(ResourceExtensionIntrinsicCustomLowering3)
+  TEST_METHOD(CustomOverloadArg1)
 };
 
 TEST_F(ExtensionTest, DefineWhenRegisteredThenPreserved) {
@@ -1182,3 +1195,28 @@ TEST_F(ExtensionTest, ResourceExtensionIntrinsicCustomLowering3) {
   };
   CheckMsgs(disassembly.c_str(), disassembly.length(), expected, 1, true);
 }
+
+TEST_F(ExtensionTest, CustomOverloadArg1) {
+  // Test that we pick the overload name based on the first arg.
+  Compiler c(m_dllSupport);
+  c.RegisterIntrinsicTable(new TestIntrinsicTable());
+  auto result = c.Compile(
+    "float main() : SV_Target {\n"
+    "  float o1 = test_o_1(1.0f, 2u, 4.0);\n"
+    "  float o2 = test_o_2(1.0f, 2u, 4.0);\n"
+    "  float o3 = test_o_3(1.0f, 2u, 4.0);\n"
+    "  return o1 + o2 + o3;\n"
+    "}\n",
+    { L"/Vd" }, {}
+  );
+  CheckOperationResultMsgs(result, {}, true, false);
+  std::string disassembly = c.Disassemble();
+
+  // The function name should match the first arg (float)
+  LPCSTR expected[] = {
+    "call float @test_o_1.float(i32 18, float 1.000000e+00, i32 2, double 4.000000e+00)",
+    "call float @test_o_2.i32(i32 18, float 1.000000e+00, i32 2, double 4.000000e+00)",
+    "call float @test_o_3.double(i32 18, float 1.000000e+00, i32 2, double 4.000000e+00)",
+  };
+  CheckMsgs(disassembly.c_str(), disassembly.length(), expected, 1, false);
+}