Browse Source

Translate extension name when no replication is needed (#127)

Previously, If an extension uses the replication lowering strategy
and a non-vector overload was chosen we would return the function
un-modified. This change makes sure that we still use the
custom lowering name for the extension if one was specified.
David Peixotto 8 years ago
parent
commit
db78f0400a

+ 0 - 14
include/dxc/HLSL/HLOperationLowerExtension.h

@@ -77,19 +77,5 @@ namespace hlsl {
     llvm::Value *Replicate(llvm::CallInst *CI);
     llvm::Value *Replicate(llvm::CallInst *CI);
     llvm::Value *Pack(llvm::CallInst *CI);
     llvm::Value *Pack(llvm::CallInst *CI);
     llvm::Value *Resource(llvm::CallInst *CI);
     llvm::Value *Resource(llvm::CallInst *CI);
-
-    // Translate the HL call by replicating the call for each vector element.
-    //
-    // For example,
-    //
-    //    <2xi32> %r = call @ext.foo(i32 %op, <2xi32> %v)
-    //    ==>
-    //    %r.1 = call @ext.foo.s(i32 %op, i32 %v.1)
-    //    %r.2 = call @ext.foo.s(i32 %op, i32 %v.2)
-    //    <2xi32> %r.v.1 = insertelement %r.1, 0, <2xi32> undef
-    //    <2xi32> %r.v.2 = insertelement %r.2, 1, %r.v.1
-    //
-    // You can then RAWU %r with %r.v.2. The RAWU is not done by the translate function.
-    static llvm::Value *TranslateReplicating(llvm::CallInst *CI, llvm::Function *ReplicatedFunction);
   };
   };
 }
 }

+ 17 - 9
lib/HLSL/HLOperationLowerExtension.cpp

@@ -320,19 +320,27 @@ private:
   }
   }
 };
 };
 
 
-Value *ExtensionLowering::TranslateReplicating(CallInst *CI, Function *ReplicatedFunction) {
+// Translate the HL call by replicating the call for each vector element.
+//
+// For example,
+//
+//    <2xi32> %r = call @ext.foo(i32 %op, <2xi32> %v)
+//    ==>
+//    %r.1 = call @ext.foo.s(i32 %op, i32 %v.1)
+//    %r.2 = call @ext.foo.s(i32 %op, i32 %v.2)
+//    <2xi32> %r.v.1 = insertelement %r.1, 0, <2xi32> undef
+//    <2xi32> %r.v.2 = insertelement %r.2, 1, %r.v.1
+//
+// You can then RAWU %r with %r.v.2. The RAWU is not done by the translate function.
+Value *ExtensionLowering::Replicate(CallInst *CI) {
+  Function *ReplicatedFunction = FunctionTranslator::GetLoweredFunction<ReplicatedFunctionTypeTranslator>(CI, *this);
   if (!ReplicatedFunction)
   if (!ReplicatedFunction)
-    return nullptr;
+    return NoTranslation(CI);
 
 
   ReplicateCall replicate(CI, *ReplicatedFunction);
   ReplicateCall replicate(CI, *ReplicatedFunction);
   return replicate.Generate();
   return replicate.Generate();
 }
 }
 
 
-Value *ExtensionLowering::Replicate(CallInst *CI) {
-  Function *ReplicatedFunction = FunctionTranslator::GetLoweredFunction<ReplicatedFunctionTypeTranslator>(CI, *this);
-  return TranslateReplicating(CI, ReplicatedFunction);
-}
-
 ///////////////////////////////////////////////////////////////////////////////
 ///////////////////////////////////////////////////////////////////////////////
 // Packed Lowering.
 // Packed Lowering.
 class PackCall {
 class PackCall {
@@ -436,7 +444,7 @@ class PackedFunctionTypeTranslator : public FunctionTypeTranslator {
 Value *ExtensionLowering::Pack(CallInst *CI) {
 Value *ExtensionLowering::Pack(CallInst *CI) {
   Function *PackedFunction = FunctionTranslator::GetLoweredFunction<PackedFunctionTypeTranslator>(CI, *this);
   Function *PackedFunction = FunctionTranslator::GetLoweredFunction<PackedFunctionTypeTranslator>(CI, *this);
   if (!PackedFunction)
   if (!PackedFunction)
-    return nullptr;
+    return NoTranslation(CI);
 
 
   PackCall pack(CI, *PackedFunction);
   PackCall pack(CI, *PackedFunction);
   Value *result = pack.Generate();
   Value *result = pack.Generate();
@@ -621,7 +629,7 @@ Value *ExtensionLowering::Resource(CallInst *CI) {
   ResourceFunctionTypeTranslator resourceTypeTranslator(m_handleMap, m_hlslOp);
   ResourceFunctionTypeTranslator resourceTypeTranslator(m_handleMap, m_hlslOp);
   Function *resourceFunction = FunctionTranslator::GetLoweredFunction(resourceTypeTranslator, CI, *this);
   Function *resourceFunction = FunctionTranslator::GetLoweredFunction(resourceTypeTranslator, CI, *this);
   if (!resourceFunction)
   if (!resourceFunction)
-    return nullptr;
+    return NoTranslation(CI);
 
 
   ResourceMethodCall explode(CI, *resourceFunction, m_handleMap);
   ResourceMethodCall explode(CI, *resourceFunction, m_handleMap);
   Value *result = explode.Generate();
   Value *result = explode.Generate();

+ 20 - 0
tools/clang/unittests/HLSL/ExtensionTest.cpp

@@ -403,6 +403,7 @@ public:
   TEST_METHOD(ReplicateLoweringWhenOnlyVectorIsResult);
   TEST_METHOD(ReplicateLoweringWhenOnlyVectorIsResult);
   TEST_METHOD(UnsignedOpcodeIsUnchanged);
   TEST_METHOD(UnsignedOpcodeIsUnchanged);
   TEST_METHOD(ResourceExtensionIntrinsic);
   TEST_METHOD(ResourceExtensionIntrinsic);
+  TEST_METHOD(NameLoweredWhenNoReplicationNeeded);
 };
 };
 
 
 TEST_F(ExtensionTest, DefineWhenRegisteredThenPreserved) {
 TEST_F(ExtensionTest, DefineWhenRegisteredThenPreserved) {
@@ -724,3 +725,22 @@ TEST_F(ExtensionTest, ResourceExtensionIntrinsic) {
   VERIFY_IS_TRUE(regex.isValid(regexErrors));
   VERIFY_IS_TRUE(regex.isValid(regexErrors));
   VERIFY_IS_TRUE(regex.match(disassembly));
   VERIFY_IS_TRUE(regex.match(disassembly));
 }
 }
+
+TEST_F(ExtensionTest, NameLoweredWhenNoReplicationNeeded) {
+  Compiler c(m_dllSupport);
+  c.RegisterIntrinsicTable(new TestIntrinsicTable());
+  c.Compile(
+    "int main(int v1 : V1) : SV_Target {\n"
+    "  return test_int(v1);\n"
+    "}\n",
+    { L"/Vd" }, {}
+  );
+  std::string disassembly = c.Disassemble();
+
+  // Make sure the name is still lowered even when no replication
+  // is needed because a non-vector overload of the function
+  // was used.
+  VERIFY_IS_TRUE(
+    disassembly.npos !=
+    disassembly.find("call i32 @test_int("));
+}