Prechádzať zdrojové kódy

Support lowering extensions directly to dxil (#217)

This commit adds the 'dxil' strategy for lowering extensions. This
strategy will change the extension call into a call to a dxil intrinsic.
This is useful for targeting dxil intrinsics that are not exposed in hlsl.
David Peixotto 8 rokov pred
rodič
commit
76a4288913

+ 2 - 0
include/dxc/HLSL/HLOperationLowerExtension.h

@@ -37,6 +37,7 @@ namespace hlsl {
       Replicate,      // Scalarize the vector arguments and replicate the call.
       Pack,           // Convert the vector arguments into structs.
       Resource,       // Convert return value to resource return and explode vectors.
+      Dxil,           // Convert call to a dxil intrinsic.
     };
 
     // Create the lowering using the given strategy and custom codegen helper.
@@ -74,5 +75,6 @@ namespace hlsl {
     llvm::Value *Replicate(llvm::CallInst *CI);
     llvm::Value *Pack(llvm::CallInst *CI);
     llvm::Value *Resource(llvm::CallInst *CI);
+    llvm::Value *Dxil(llvm::CallInst *CI);
   };
 }

+ 7 - 0
include/dxc/HLSL/HLSLExtensionsCodegenHelper.h

@@ -10,6 +10,7 @@
 ///////////////////////////////////////////////////////////////////////////////
 
 #pragma once
+#include "dxc/HLSL/DxilOperations.h"
 #include <vector>
 #include <string>
 
@@ -65,6 +66,12 @@ public:
   // Get the name to use for the dxil intrinsic function.
   virtual std::string GetIntrinsicName(unsigned opcode) = 0;
 
+  // Get the dxil opcode the extension should use when lowering with
+  // dxil lowering strategy.
+  //
+  // Returns true if the opcode was successfully mapped to a dxil opcode.
+  virtual bool GetDxilOpcode(unsigned opcode, OP::OpCode &dxilOpcode) = 0;
+
   // Struct to hold a root signature that is read from a define.
   struct CustomRootSignature {
     std::string RootSignature;

+ 11 - 0
include/dxc/Support/DxcLangExtensionsHelper.h

@@ -122,6 +122,17 @@ public:
       return "";
   }
 
+  // Get the dxil opcode for the extension opcode if one exists.
+  // Return true if the opcode was mapped successfully.
+  bool GetDxilOpCode(UINT opcode, UINT &dxilOpcode) {
+    for (IDxcIntrinsicTable *table : m_intrinsicTables) {
+      if (SUCCEEDED(table->GetDxilOpCode(opcode, &dxilOpcode))) {
+        return true;
+      }
+    }
+    return false;
+  }
+
   // Result of validating a semantic define.
   // Stores any warning or error messages produced by the validator.
   // Successful validation means that there are no warning or error messages.

+ 1 - 0
include/dxc/dxcapi.h

@@ -1,3 +1,4 @@
+
 ///////////////////////////////////////////////////////////////////////////////
 //                                                                           //
 // dxcapi.h                                                                  //

+ 4 - 0
include/dxc/dxcapi.internal.h

@@ -134,6 +134,10 @@ public:
   // name. The string "$o" in the name will be replaced by the return type of the
   // intrinsic.
   virtual HRESULT STDMETHODCALLTYPE GetIntrinsicName(UINT opcode, LPCSTR *pName) = 0;
+
+  // Callback to support the 'dxil' lowering strategy.
+  // Returns the dxil opcode that the intrinsic should use for lowering.
+  virtual HRESULT STDMETHODCALLTYPE GetDxilOpCode(UINT opcode, UINT *pDxilOpcode) = 0;
 };
 
 struct __declspec(uuid("1d063e4f-515a-4d57-a12a-431f6a44cfb9"))

+ 37 - 1
lib/HLSL/HLOperationLowerExtension.cpp

@@ -35,6 +35,7 @@ ExtensionLowering::Strategy ExtensionLowering::GetStrategy(StringRef strategy) {
     case 'r': return Strategy::Replicate;
     case 'p': return Strategy::Pack;
     case 'm': return Strategy::Resource;
+    case 'd': return Strategy::Dxil;
     default: break;
   }
   return Strategy::Unknown;
@@ -46,6 +47,7 @@ llvm::StringRef ExtensionLowering::GetStrategyName(Strategy strategy) {
     case Strategy::Replicate:     return "r";
     case Strategy::Pack:          return "p";
     case Strategy::Resource:      return "m"; // m for resource method
+    case Strategy::Dxil:          return "d";
     default: break;
   }
   return "?";
@@ -65,6 +67,7 @@ llvm::Value *ExtensionLowering::Translate(llvm::CallInst *CI) {
   case Strategy::Replicate:     return Replicate(CI);
   case Strategy::Pack:          return Pack(CI);
   case Strategy::Resource:      return Resource(CI);
+  case Strategy::Dxil:          return Dxil(CI);
   default: break;
   }
   return Unknown(CI);
@@ -190,7 +193,7 @@ llvm::Value *ExtensionLowering::NoTranslation(CallInst *CI) {
 ///////////////////////////////////////////////////////////////////////////////
 // Replicated Lowering.
 enum {
-  NO_COMMON_VECTOR_SIZE = 0xFFFFFFFF,
+  NO_COMMON_VECTOR_SIZE = 0x0,
 };
 // Find the vector size that will be used for replication.
 // The function call will be replicated once for each element of the vector
@@ -609,6 +612,39 @@ Value *ExtensionLowering::Resource(CallInst *CI) {
   return result;
 }
 
+///////////////////////////////////////////////////////////////////////////////
+// Dxil Lowering.
+
+Value *ExtensionLowering::Dxil(CallInst *CI) {
+  // Map the extension opcode to the corresponding dxil opcode.
+  unsigned extOpcode = GetHLOpcode(CI);
+  OP::OpCode dxilOpcode;
+  if (!m_helper->GetDxilOpcode(extOpcode, dxilOpcode))
+    return nullptr;
+
+  // Find the dxil function based on the overload type.
+  Type *overloadTy = m_hlslOp.GetOverloadType(dxilOpcode, CI->getCalledFunction());
+  Function *F = m_hlslOp.GetOpFunc(dxilOpcode, overloadTy->getScalarType());
+
+  // Update the opcode in the original call so we can just copy it below.
+  // We are about to delete this call anyway.
+  CI->setOperand(0, m_hlslOp.GetI32Const(static_cast<unsigned>(dxilOpcode)));
+
+  // Create the new call.
+  Value *result = nullptr;
+  if (overloadTy->isVectorTy()) {
+    ReplicateCall replicate(CI, *F);
+    result = replicate.Generate();
+  }
+  else {
+    IRBuilder<> builder(CI);
+    SmallVector<Value *, 8> args(CI->arg_operands().begin(), CI->arg_operands().end());
+    result = builder.CreateCall(F, args);
+  }
+
+  return result;
+}
+
 ///////////////////////////////////////////////////////////////////////////////
 // Computing Extension Names.
 

+ 11 - 0
tools/clang/tools/dxcompiler/dxcompilerobj.cpp

@@ -1903,6 +1903,17 @@ public:
     return m_langExtensionsHelper.GetIntrinsicName(opcode);
   }
   
+  virtual bool GetDxilOpcode(UINT opcode, OP::OpCode &dxilOpcode) override {
+    UINT dop = static_cast<UINT>(OP::OpCode::NumOpCodes);
+    if (m_langExtensionsHelper.GetDxilOpCode(opcode, dop)) {
+      if (dop < static_cast<UINT>(OP::OpCode::NumOpCodes)) {
+        dxilOpcode = static_cast<OP::OpCode>(dop);
+        return true;
+      }
+    }
+    return false;
+  }
+
   virtual HLSLExtensionsCodegenHelper::CustomRootSignature::Status GetCustomRootSignature(CustomRootSignature *out) {
     // Find macro definition in preprocessor.
     Preprocessor &pp = m_CI.getPreprocessor();

+ 87 - 0
tools/clang/unittests/HLSL/ExtensionTest.cpp

@@ -15,6 +15,7 @@
 #include "dxc/dxcapi.internal.h"
 #include "dxc/HLSL/HLOperationLowerExtension.h"
 #include "dxc/HlslIntrinsicOp.h"
+#include "dxc/HLSL/DxilOperations.h"
 #include "llvm/Support/Regex.h"
 
 ///////////////////////////////////////////////////////////////////////////////
@@ -97,6 +98,20 @@ static const HLSL_INTRINSIC_ARGUMENT TestMyBufferOp[] = {
   { "addr", AR_QUAL_IN, 1, LITEMPLATE_VECTOR, 1, LICOMPTYPE_UINT, 1, 2},
 };
 
+// bool<> = test_isinf(float<> x)
+static const HLSL_INTRINSIC_ARGUMENT TestIsInf[] = {
+  { "test_isinf", AR_QUAL_OUT, 0, LITEMPLATE_VECTOR, 0, LICOMPTYPE_BOOL, 1, IA_C },
+  { "x", AR_QUAL_IN, 1, LITEMPLATE_VECTOR, 1, LICOMPTYPE_FLOAT, 1, IA_C},
+};
+
+// int = test_ibfe(uint width, uint offset, uint val)
+static const HLSL_INTRINSIC_ARGUMENT TestIBFE[] = {
+  { "test_ibfe", AR_QUAL_OUT, 0, LITEMPLATE_SCALAR, 0, LICOMPTYPE_INT, 1, 1 },
+  { "width",  AR_QUAL_IN, 1, LITEMPLATE_SCALAR, 1, LICOMPTYPE_UINT, 1, 1},
+  { "offset", AR_QUAL_IN, 1, LITEMPLATE_SCALAR, 1, LICOMPTYPE_UINT, 1, 1},
+  { "val",    AR_QUAL_IN, 1, LITEMPLATE_SCALAR, 1, LICOMPTYPE_UINT, 1, 1},
+};
+
 struct Intrinsic {
   LPCWSTR hlslName;
   const char *dxilName;
@@ -121,6 +136,8 @@ Intrinsic Intrinsics[] = {
   {L"test_pack_3",  "test_pack_3.$o",  "p", {  9, false, true, -1, countof(TestFnPack3), TestFnPack3}},
   {L"test_pack_4",  "test_pack_4.$o",  "p", { 10, false, true, -1, countof(TestFnPack4), TestFnPack4}},
   {L"test_rand",    "test_rand",       "r", { 11, false, false,-1, countof(TestRand), TestRand}},
+  {L"test_isinf",   "test_isinf",      "d", { 13, true,  true, -1, countof(TestIsInf), TestIsInf}},
+  {L"test_ibfe",    "test_ibfe",       "d", { 14, true,  true, -1, countof(TestIBFE), TestIBFE}},
   // Make this intrinsic have the same opcode as an hlsl intrinsic with an unsigned
   // counterpart for testing purposes.
   {L"test_unsigned","test_unsigned",   "n", { static_cast<unsigned>(hlsl::IntrinsicOp::IOP_min), false, true, -1, countof(TestUnsigned), TestUnsigned}},
@@ -259,6 +276,19 @@ public:
     return S_OK;
   }
 
+  __override HRESULT STDMETHODCALLTYPE
+  GetDxilOpCode(UINT opcode, _Outptr_ UINT *pDxilOpcode) {
+    if (opcode == 13) {
+      *pDxilOpcode = static_cast<UINT>(hlsl::OP::OpCode::IsInf);
+      return S_OK;
+    }
+    else if (opcode == 14) {
+      *pDxilOpcode = static_cast<UINT>(hlsl::OP::OpCode::Ibfe);
+      return S_OK;
+    }
+    return E_FAIL;
+  }
+
   Intrinsic *FindByOpcode(UINT opcode) {
     IntrinsicTable::SearchResult result;
     for (const IntrinsicTable &table : m_tables) {
@@ -404,6 +434,9 @@ public:
   TEST_METHOD(UnsignedOpcodeIsUnchanged);
   TEST_METHOD(ResourceExtensionIntrinsic);
   TEST_METHOD(NameLoweredWhenNoReplicationNeeded);
+  TEST_METHOD(DxilLoweringVector1);
+  TEST_METHOD(DxilLoweringVector2);
+  TEST_METHOD(DxilLoweringScalar);
 };
 
 TEST_F(ExtensionTest, DefineWhenRegisteredThenPreserved) {
@@ -744,3 +777,57 @@ TEST_F(ExtensionTest, NameLoweredWhenNoReplicationNeeded) {
     disassembly.npos !=
     disassembly.find("call i32 @test_int("));
 }
+
+TEST_F(ExtensionTest, DxilLoweringVector1) {
+  Compiler c(m_dllSupport);
+  c.RegisterIntrinsicTable(new TestIntrinsicTable());
+  c.Compile(
+    "int main(float v1 : V1) : SV_Target {\n"
+    "  return test_isinf(v1);\n"
+    "}\n",
+    { L"/Vd" }, {}
+  );
+  std::string disassembly = c.Disassemble();
+
+  // Check that the extension was lowered to the correct dxil intrinsic.
+  static_assert(9 == (unsigned)hlsl::OP::OpCode::IsInf, "isinf opcode changed?");
+  VERIFY_IS_TRUE(
+    disassembly.npos !=
+    disassembly.find("call i1 @dx.op.isSpecialFloat.f32(i32 9"));
+}
+
+TEST_F(ExtensionTest, DxilLoweringVector2) {
+  Compiler c(m_dllSupport);
+  c.RegisterIntrinsicTable(new TestIntrinsicTable());
+  c.Compile(
+    "int2 main(float2 v1 : V1) : SV_Target {\n"
+    "  return test_isinf(v1);\n"
+    "}\n",
+    { L"/Vd" }, {}
+  );
+  std::string disassembly = c.Disassemble();
+
+  // Check that the extension was lowered to the correct dxil intrinsic.
+  static_assert(9 == (unsigned)hlsl::OP::OpCode::IsInf, "isinf opcode changed?");
+  VERIFY_IS_TRUE(
+    disassembly.npos !=
+    disassembly.find("call i1 @dx.op.isSpecialFloat.f32(i32 9"));
+}
+
+TEST_F(ExtensionTest, DxilLoweringScalar) {
+  Compiler c(m_dllSupport);
+  c.RegisterIntrinsicTable(new TestIntrinsicTable());
+  c.Compile(
+    "int main(uint v1 : V1, uint v2 : V2, uint v3 : V3) : SV_Target {\n"
+    "  return test_ibfe(v1, v2, v3);\n"
+    "}\n",
+    { L"/Vd" }, {}
+  );
+  std::string disassembly = c.Disassemble();
+
+  // Check that the extension was lowered to the correct dxil intrinsic.
+  static_assert(51 == (unsigned)hlsl::OP::OpCode::Ibfe, "ibfe opcode changed?");
+  VERIFY_IS_TRUE(
+    disassembly.npos !=
+    disassembly.find("call i32 @dx.op.tertiary.i32(i32 51"));
+}