Browse Source

[dxil2spv] Translate createHandle and bufferLoad (#4389)

Add support for translating createHandle and bufferLoad DXIL operations
to SPIR-V instructions. The most significant limitation of this current
implementation is the naive translation of descriptor set and binding
numbers, but it is sufficient for simple passthrough shaders.
Natalie Chouinard 3 years ago
parent
commit
316b849cfa

+ 3 - 0
tools/clang/include/clang/SPIRV/SpirvBuilder.h

@@ -105,6 +105,9 @@ public:
   SpirvVariable *addFnVar(QualType valueType, SourceLocation,
                           llvm::StringRef name = "", bool isPrecise = false,
                           SpirvInstruction *init = nullptr);
+  SpirvVariable *addFnVar(const spirv::SpirvType *valueType, SourceLocation,
+                          llvm::StringRef name = "", bool isPrecise = false,
+                          SpirvInstruction *init = nullptr);
 
   /// \brief Ends building of the current function. All basic blocks constructed
   /// from the beginning or after ending the previous function will be collected

+ 12 - 0
tools/clang/lib/SPIRV/SpirvBuilder.cpp

@@ -147,6 +147,18 @@ SpirvVariable *SpirvBuilder::addFnVar(QualType valueType, SourceLocation loc,
   return var;
 }
 
+SpirvVariable *SpirvBuilder::addFnVar(const spirv::SpirvType *valueType,
+                                      SourceLocation loc, llvm::StringRef name,
+                                      bool isPrecise, SpirvInstruction *init) {
+  assert(function && "found detached local variable");
+  // TODO: Handle potential bindless array of an opaque type.
+  SpirvVariable *var = new (context) SpirvVariable(
+      valueType, loc, spv::StorageClass::Function, isPrecise, init);
+  var->setDebugName(name);
+  function->addVariable(var);
+  return var;
+}
+
 void SpirvBuilder::endFunction() {
   assert(function && "no active function");
   mod->addFunctionToListOfSortedModuleFunctions(function);

+ 55 - 14
tools/clang/test/Dxil2Spv/passthru-cs.ll

@@ -86,7 +86,7 @@ attributes #2 = { nounwind }
 ; ; SPIR-V
 ; ; Version: 1.0
 ; ; Generator: Google spiregg; 0
-; ; Bound: 22
+; ; Bound: 56
 ; ; Schema: 0
 ;                OpCapability Shader
 ;                OpMemoryModel Logical GLSL450
@@ -95,6 +95,11 @@ attributes #2 = { nounwind }
 ;                OpName %type_ByteAddressBuffer "type.ByteAddressBuffer"
 ;                OpName %type_RWByteAddressBuffer "type.RWByteAddressBuffer"
 ;                OpName %main "main"
+;                OpName %dx_types_ResRet_i32 "dx.types.ResRet.i32"
+;                OpDecorate %3 DescriptorSet 0
+;                OpDecorate %3 Binding 0
+;                OpDecorate %4 DescriptorSet 0
+;                OpDecorate %4 Binding 1
 ;                OpDecorate %gl_GlobalInvocationID BuiltIn GlobalInvocationId
 ;                OpDecorate %_runtimearr_uint ArrayStride 4
 ;                OpMemberDecorate %type_ByteAddressBuffer 0 Offset 0
@@ -105,29 +110,65 @@ attributes #2 = { nounwind }
 ;        %uint = OpTypeInt 32 0
 ;      %uint_0 = OpConstant %uint 0
 ;      %uint_2 = OpConstant %uint 2
-;      %v3uint = OpTypeVector %uint 3
-; %_ptr_Input_v3uint = OpTypePointer Input %v3uint
+;      %uint_1 = OpConstant %uint 1
+;      %uint_3 = OpConstant %uint 3
+;      %uint_4 = OpConstant %uint 4
 ; %_runtimearr_uint = OpTypeRuntimeArray %uint
 ; %type_ByteAddressBuffer = OpTypeStruct %_runtimearr_uint
 ; %_ptr_Uniform_type_ByteAddressBuffer = OpTypePointer Uniform %type_ByteAddressBuffer
 ; %type_RWByteAddressBuffer = OpTypeStruct %_runtimearr_uint
 ; %_ptr_Uniform_type_RWByteAddressBuffer = OpTypePointer Uniform %type_RWByteAddressBuffer
+;      %v3uint = OpTypeVector %uint 3
+; %_ptr_Input_v3uint = OpTypePointer Input %v3uint
 ;        %void = OpTypeVoid
-;          %16 = OpTypeFunction %void
+;          %19 = OpTypeFunction %void
+;         %int = OpTypeInt 32 1
+; %dx_types_ResRet_i32 = OpTypeStruct %int %int %int %int %int
+; %_ptr_Function_dx_types_ResRet_i32 = OpTypePointer Function %dx_types_ResRet_i32
 ; %_ptr_Input_uint = OpTypePointer Input %uint
+; %_ptr_Uniform_uint = OpTypePointer Uniform %uint
+; %_ptr_Function_int = OpTypePointer Function %int
+;           %3 = OpVariable %_ptr_Uniform_type_ByteAddressBuffer Uniform
+;           %4 = OpVariable %_ptr_Uniform_type_RWByteAddressBuffer Uniform
 ; %gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input
-;          %11 = OpVariable %_ptr_Uniform_type_ByteAddressBuffer Uniform
-;          %14 = OpVariable %_ptr_Uniform_type_RWByteAddressBuffer Uniform
-;        %main = OpFunction %void None %16
-;          %17 = OpLabel
-;          %19 = OpAccessChain %_ptr_Input_uint %gl_GlobalInvocationID %uint_0
-;          %20 = OpLoad %uint %19
-;          %21 = OpShiftLeftLogical %uint %20 %uint_2
+;        %main = OpFunction %void None %19
+;          %20 = OpLabel
+;          %24 = OpVariable %_ptr_Function_dx_types_ResRet_i32 Function
+;          %26 = OpAccessChain %_ptr_Input_uint %gl_GlobalInvocationID %uint_0
+;          %27 = OpLoad %uint %26
+;          %28 = OpShiftLeftLogical %uint %27 %uint_2
+;          %29 = OpIAdd %uint %28 %uint_0
+;          %31 = OpAccessChain %_ptr_Uniform_uint %3 %uint_0 %29
+;          %32 = OpLoad %uint %31
+;          %34 = OpAccessChain %_ptr_Function_int %24 %uint_0
+;          %35 = OpBitcast %int %32
+;                OpStore %34 %35
+;          %36 = OpIAdd %uint %28 %uint_1
+;          %37 = OpAccessChain %_ptr_Uniform_uint %3 %uint_0 %36
+;          %38 = OpLoad %uint %37
+;          %39 = OpAccessChain %_ptr_Function_int %24 %uint_1
+;          %40 = OpBitcast %int %38
+;                OpStore %39 %40
+;          %41 = OpIAdd %uint %28 %uint_2
+;          %42 = OpAccessChain %_ptr_Uniform_uint %3 %uint_0 %41
+;          %43 = OpLoad %uint %42
+;          %44 = OpAccessChain %_ptr_Function_int %24 %uint_2
+;          %45 = OpBitcast %int %43
+;                OpStore %44 %45
+;          %46 = OpIAdd %uint %28 %uint_3
+;          %47 = OpAccessChain %_ptr_Uniform_uint %3 %uint_0 %46
+;          %48 = OpLoad %uint %47
+;          %49 = OpAccessChain %_ptr_Function_int %24 %uint_3
+;          %50 = OpBitcast %int %48
+;                OpStore %49 %50
+;          %51 = OpIAdd %uint %28 %uint_4
+;          %52 = OpAccessChain %_ptr_Uniform_uint %3 %uint_0 %51
+;          %53 = OpLoad %uint %52
+;          %54 = OpAccessChain %_ptr_Function_int %24 %uint_4
+;          %55 = OpBitcast %int %53
+;                OpStore %54 %55
 ;                OpReturn
 ;                OpFunctionEnd
 ; CHECK-ERRORS:
-; error: Unhandled DXIL opcode: CreateHandle
-; error: Unhandled DXIL opcode: CreateHandle
-; error: Unhandled DXIL opcode: BufferLoad
 ; error: Unhandled LLVM instruction:   %6 = extractvalue %dx.types.ResRet.i32 %5, 0
 ; error: Unhandled DXIL opcode: BufferStore

+ 144 - 21
tools/clang/tools/dxil2spv/lib/dxil2spv.cpp

@@ -17,6 +17,7 @@
 #include "dxc/Support/ErrorCodes.h"
 #include "dxc/Support/Global.h"
 
+#include "clang/SPIRV/SpirvInstruction.h"
 #include "clang/SPIRV/SpirvType.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
@@ -114,6 +115,10 @@ int Translator::Run() {
   createStageIOVariables(program.GetInputSignature().GetElements(),
                          program.GetOutputSignature().GetElements());
 
+  // Add HLSL resources.
+  createModuleVariables(program.GetSRVs());
+  createModuleVariables(program.GetUAVs());
+
   // Create entry function.
   spirv::SpirvFunction *entryFunction =
       createEntryFunction(program.GetEntryFunction());
@@ -131,10 +136,6 @@ int Translator::Run() {
                                 {});
   }
 
-  // Add HLSL resources.
-  createModuleVariables(program.GetSRVs());
-  createModuleVariables(program.GetUAVs());
-
   // Contsruct the SPIR-V module.
   std::vector<uint32_t> m = spvBuilder.takeModuleForDxilToSpv();
 
@@ -204,8 +205,12 @@ void Translator::createModuleVariables(
     assert(hlslType->isPointerTy());
     llvm::Type *pointeeType =
         cast<llvm::PointerType>(hlslType)->getPointerElementType();
-    spvBuilder.addModuleVar(toSpirvType(pointeeType),
-                            spv::StorageClass::Uniform, false);
+    spirv::SpirvVariable *moduleVar = spvBuilder.addModuleVar(
+        toSpirvType(pointeeType), spv::StorageClass::Uniform, false);
+    spvBuilder.decorateDSetBinding(moduleVar, nextDescriptorSet,
+                                   nextBindingNo++);
+    resourceMap[{static_cast<unsigned>(resource->GetClass()),
+                 resource->GetID()}] = moduleVar;
   }
 }
 
@@ -263,6 +268,12 @@ void Translator::createInstruction(llvm::Instruction &instruction) {
     case hlsl::DXIL::OpCode::ThreadId: {
       createThreadIdInstruction(callInstruction);
     } break;
+    case hlsl::DXIL::OpCode::CreateHandle: {
+      createHandleInstruction(callInstruction);
+    } break;
+    case hlsl::DXIL::OpCode::BufferLoad: {
+      createBufferLoadInstruction(callInstruction);
+    } break;
     default: {
       emitError("Unhandled DXIL opcode: %0")
           << hlsl::OP::GetOpCodeName(dxilOpcode);
@@ -281,10 +292,7 @@ void Translator::createInstruction(llvm::Instruction &instruction) {
   }
   // Unhandled instruction type.
   else {
-    std::string instStr;
-    llvm::raw_string_ostream os(instStr);
-    instruction.print(os);
-    emitError("Unhandled LLVM instruction: %0") << os.str();
+    emitError("Unhandled LLVM instruction: %0", instruction);
   }
 }
 
@@ -344,8 +352,8 @@ void Translator::createStoreOutputInstruction(llvm::CallInst &instruction) {
   spirv::SpirvAccessChain *outputVarPtr =
       spvBuilder.createAccessChain(elemType, outputVar, {index}, {});
   spirv::SpirvInstruction *valueToStore =
-      instructionMap[instruction.getArgOperand(
-          hlsl::DXIL::OperandIndex::kStoreOutputValOpIdx)];
+      getSpirvInstruction(instruction.getArgOperand(
+          hlsl::DXIL::OperandIndex::kStoreOutputValOpIdx));
   spvBuilder.createStore(outputVarPtr, valueToStore, {});
 }
 
@@ -384,15 +392,8 @@ void Translator::createBinaryOpInstruction(llvm::BinaryOperator &instruction) {
   // Shift left instruction.
   case llvm::Instruction::Shl: {
     // Value to be shifted.
-    spirv::SpirvInstruction *val = instructionMap[instruction.getOperand(0)];
-    if (!val) {
-      std::string instStr;
-      llvm::raw_string_ostream os(instStr);
-      instruction.print(os);
-      emitError("Could not find translation of instruction operand 0: %0")
-          << os.str();
-      return;
-    }
+    spirv::SpirvInstruction *val =
+        getSpirvInstruction(instruction.getOperand(0));
 
     // Amount to shift by.
     const spirv::IntegerType *uint32 = spvContext.getUIntType(32);
@@ -412,6 +413,108 @@ void Translator::createBinaryOpInstruction(llvm::BinaryOperator &instruction) {
   instructionMap[&instruction] = result;
 }
 
+void Translator::createHandleInstruction(llvm::CallInst &instruction) {
+  unsigned resourceClass =
+      cast<llvm::ConstantInt>(
+          instruction.getArgOperand(
+              hlsl::DXIL::OperandIndex::kCreateHandleResClassOpIdx))
+          ->getLimitedValue();
+  unsigned resourceRangeId =
+      cast<llvm::ConstantInt>(
+          instruction.getArgOperand(
+              hlsl::DXIL::OperandIndex::kCreateHandleResIDOpIdx))
+          ->getLimitedValue();
+  spirv::SpirvVariable *inputVar =
+      resourceMap[{resourceClass, resourceRangeId}];
+  if (!inputVar) {
+    emitError("No resource found corresponding to handle: %0", instruction);
+    return;
+  }
+
+  instructionMap[&instruction] = inputVar;
+}
+
+void Translator::createBufferLoadInstruction(llvm::CallInst &instruction) {
+  // TODO: Extend this function to work with all buffer types on which it is
+  // used, not just ByteAddressBuffers.
+
+  // ByteAddressBuffers are represented as a struct with one member that is a
+  // runtime array of unsigned integers. The SPIR-V OpAccessChain instruction is
+  // then used to access that offset, and OpLoad is used to load integers
+  // into a corresponding struct.
+
+  // clang-format off
+  // For example, the following DXIL instruction:
+  //   %dx.types.ResRet.i32 = type { i32, i32, i32, i32, i32 } 
+  //   %ret = call %dx.types.ResRet.i32 @dx.op.bufferLoad.i32(i32 68, %dx.types.Handle %res, i32 %index, i32 undef)
+
+  // would translate to the following SPIR-V instructions:
+  //   %dx_types_ResRet_i32 = OpTypeStruct %int %int %int %int %int
+  //   %_ptr_Function_dx_types_ResRet_i32 = OpTypePointer Function %dx_types_ResRet_i32
+  //   %ret = OpVariable %_ptr_Function_dx_types_ResRet_i32 Function
+  //     %i = OpLoad %uint %index
+  // for %offset = {0, 1, 2, 3, 4}, repeat:
+  //    %v0 = OpIAdd %uint %i %offset
+  //    %v1 = OpAccessChain %_ptr_Uniform_uint %res %offset %v0
+  //    %v2 = OpLoad %uint %v1
+  //    %v3 = OpAccessChain %_ptr_Function_int %ret %offset
+  //    %v4 = OpBitcast %int %v2
+  //          OpStore %v3 %v4
+  // clang-format on
+
+  // Get module input variable corresponding to given DXIL handle.
+  spirv::SpirvInstruction *inputVar =
+      getSpirvInstruction(instruction.getArgOperand(
+          hlsl::DXIL::OperandIndex::kBufferLoadHandleOpIdx));
+
+  // Translate DXIL instruction return type (expected to be a struct of
+  // integers) to a SPIR-V type.
+  const spirv::SpirvType *returnType = toSpirvType(instruction.getType());
+  assert(isa<spirv::StructType>(returnType));
+  const spirv::StructType *structType = cast<spirv::StructType>(returnType);
+
+  // Create a return variable to initialize with values loaded from the buffer.
+  spirv::SpirvVariable *returnVar =
+      spvBuilder.addFnVar(structType, {}, "", false, nullptr);
+
+  // Translate indices into resource buffer to SPIR-V instructions.
+  auto uint32 = spvContext.getUIntType(32);
+  spirv::SpirvConstant *indexIntoStruct =
+      spvBuilder.getConstantInt(uint32, llvm::APInt(32, 0));
+  spirv::SpirvInstruction *baseArrayIndex =
+      getSpirvInstruction(instruction.getArgOperand(
+          hlsl::DXIL::OperandIndex::kBufferLoadCoord0OpIdx));
+
+  // Initialize each field in the struct.
+  for (size_t i = 0; i < structType->getFields().size(); i++) {
+    // Add offset for current field.
+    spirv::SpirvConstant *fieldOffset =
+        spvBuilder.getConstantInt(uint32, llvm::APInt(32, i));
+    spirv::SpirvInstruction *indexIntoArray = spvBuilder.createBinaryOp(
+        spv::Op::OpIAdd, uint32, baseArrayIndex, fieldOffset, {});
+
+    // Create access chain and load.
+    spirv::SpirvAccessChain *loadPtr = spvBuilder.createAccessChain(
+        uint32, inputVar, {indexIntoStruct, indexIntoArray}, {});
+    spirv::SpirvInstruction *loadInstr =
+        spvBuilder.createLoad(uint32, loadPtr, {});
+
+    // Create access chain and store.
+    const spirv::SpirvType *fieldType = structType->getFields()[i].type;
+    spirv::SpirvAccessChain *storePtr =
+        spvBuilder.createAccessChain(fieldType, returnVar, fieldOffset, {});
+    // LLVM types are signless, so type conversions are not 1-to-1. A bitcast on
+    // the unsigned integer may be necessary before storing.
+    spirv::SpirvInstruction *valToStore =
+        fieldType == uint32 ? loadInstr
+                            : spvBuilder.createUnaryOp(
+                                  spv::Op::OpBitcast, fieldType, loadInstr, {});
+    spvBuilder.createStore(storePtr, valToStore, {});
+  }
+
+  instructionMap[&instruction] = returnVar;
+}
+
 bool Translator::spirvToolsValidate(std::vector<uint32_t> *mod,
                                     std::string *messages) {
   spvtools::SpirvTools tools(featureManager.getTargetEnv());
@@ -500,6 +603,17 @@ const spirv::SpirvType *Translator::toSpirvType(llvm::StructType *structType) {
   return spvContext.getStructType(fields, name);
 }
 
+spirv::SpirvInstruction *
+Translator::getSpirvInstruction(llvm::Value *instruction) {
+  spirv::SpirvInstruction *spirvInstruction = instructionMap[instruction];
+  if (!spirvInstruction) {
+    emitError("Expected SPIR-V instruction not found for DXIL instruction: %0",
+              *instruction);
+    return nullptr;
+  }
+  return spirvInstruction;
+}
+
 template <unsigned N>
 DiagnosticBuilder Translator::emitError(const char (&message)[N]) {
   const auto diagId =
@@ -507,5 +621,14 @@ DiagnosticBuilder Translator::emitError(const char (&message)[N]) {
   return diagnosticsEngine.Report({}, diagId);
 }
 
+template <unsigned N>
+DiagnosticBuilder Translator::emitError(const char (&message)[N],
+                                        llvm::Value &value) {
+  std::string str;
+  llvm::raw_string_ostream os(str);
+  value.print(os);
+  return emitError(message) << os.str();
+}
+
 } // namespace dxil2spv
 } // namespace clang

+ 21 - 0
tools/clang/tools/dxil2spv/lib/dxil2spv.h

@@ -47,9 +47,17 @@ private:
   llvm::DenseMap<unsigned, spirv::SpirvVariable *> inputSignatureElementMap;
   llvm::DenseMap<unsigned, spirv::SpirvVariable *> outputSignatureElementMap;
 
+  // Map from HLSL resource class and range ID to corresponding SPIR-V variable.
+  llvm::DenseMap<std::pair<unsigned, unsigned>, spirv::SpirvVariable *>
+      resourceMap;
+
   // Map from DXIL instructions (values) to SPIR-V instructions.
   llvm::DenseMap<llvm::Value *, spirv::SpirvInstruction *> instructionMap;
 
+  // Get corresponding SPIR-V instruction for a given DXIL instruction, with
+  // error checking.
+  spirv::SpirvInstruction *getSpirvInstruction(llvm::Value *instruction);
+
   // Create SPIR-V stage IO variable from DXIL input and output signatures.
   void createStageIOVariables(
       const std::vector<std::unique_ptr<hlsl::DxilSignatureElement>>
@@ -73,6 +81,8 @@ private:
   void createStoreOutputInstruction(llvm::CallInst &instruction);
   void createThreadIdInstruction(llvm::CallInst &instruction);
   void createBinaryOpInstruction(llvm::BinaryOperator &instruction);
+  void createHandleInstruction(llvm::CallInst &instruction);
+  void createBufferLoadInstruction(llvm::CallInst &instruction);
 
   // SPIR-V Tools wrapper functions.
   bool spirvToolsValidate(std::vector<uint32_t> *mod, std::string *messages);
@@ -83,7 +93,18 @@ private:
   const spirv::SpirvType *toSpirvType(llvm::Type *llvmType);
   const spirv::SpirvType *toSpirvType(llvm::StructType *structType);
 
+  // TODO: These variables are used for a temporary hack to assign descriptor
+  // set and binding numbers that works only for the most simple cases (always
+  // use descriptor set 0, increment binding number for each resource). Further
+  // work is needed to translate non-trivial shaders.
+  unsigned nextDescriptorSet = 0;
+  unsigned nextBindingNo = 0;
+
+  // Helper diagnostic functions for emitting error messages.
   template <unsigned N> DiagnosticBuilder emitError(const char (&message)[N]);
+  template <unsigned N>
+  DiagnosticBuilder emitError(const char (&message)[N],
+                              llvm::Value &instruction);
 };
 
 } // namespace dxil2spv