|
@@ -17,6 +17,7 @@
|
|
|
#include "dxc/Support/ErrorCodes.h"
|
|
|
#include "dxc/Support/Global.h"
|
|
|
|
|
|
+#include "clang/SPIRV/SpirvInstruction.h"
|
|
|
#include "clang/SPIRV/SpirvType.h"
|
|
|
#include "llvm/IR/LLVMContext.h"
|
|
|
#include "llvm/IR/Module.h"
|
|
@@ -114,6 +115,10 @@ int Translator::Run() {
|
|
|
createStageIOVariables(program.GetInputSignature().GetElements(),
|
|
|
program.GetOutputSignature().GetElements());
|
|
|
|
|
|
+ // Add HLSL resources.
|
|
|
+ createModuleVariables(program.GetSRVs());
|
|
|
+ createModuleVariables(program.GetUAVs());
|
|
|
+
|
|
|
// Create entry function.
|
|
|
spirv::SpirvFunction *entryFunction =
|
|
|
createEntryFunction(program.GetEntryFunction());
|
|
@@ -131,10 +136,6 @@ int Translator::Run() {
|
|
|
{});
|
|
|
}
|
|
|
|
|
|
- // Add HLSL resources.
|
|
|
- createModuleVariables(program.GetSRVs());
|
|
|
- createModuleVariables(program.GetUAVs());
|
|
|
-
|
|
|
// Contsruct the SPIR-V module.
|
|
|
std::vector<uint32_t> m = spvBuilder.takeModuleForDxilToSpv();
|
|
|
|
|
@@ -204,8 +205,12 @@ void Translator::createModuleVariables(
|
|
|
assert(hlslType->isPointerTy());
|
|
|
llvm::Type *pointeeType =
|
|
|
cast<llvm::PointerType>(hlslType)->getPointerElementType();
|
|
|
- spvBuilder.addModuleVar(toSpirvType(pointeeType),
|
|
|
- spv::StorageClass::Uniform, false);
|
|
|
+ spirv::SpirvVariable *moduleVar = spvBuilder.addModuleVar(
|
|
|
+ toSpirvType(pointeeType), spv::StorageClass::Uniform, false);
|
|
|
+ spvBuilder.decorateDSetBinding(moduleVar, nextDescriptorSet,
|
|
|
+ nextBindingNo++);
|
|
|
+ resourceMap[{static_cast<unsigned>(resource->GetClass()),
|
|
|
+ resource->GetID()}] = moduleVar;
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -263,6 +268,12 @@ void Translator::createInstruction(llvm::Instruction &instruction) {
|
|
|
case hlsl::DXIL::OpCode::ThreadId: {
|
|
|
createThreadIdInstruction(callInstruction);
|
|
|
} break;
|
|
|
+ case hlsl::DXIL::OpCode::CreateHandle: {
|
|
|
+ createHandleInstruction(callInstruction);
|
|
|
+ } break;
|
|
|
+ case hlsl::DXIL::OpCode::BufferLoad: {
|
|
|
+ createBufferLoadInstruction(callInstruction);
|
|
|
+ } break;
|
|
|
default: {
|
|
|
emitError("Unhandled DXIL opcode: %0")
|
|
|
<< hlsl::OP::GetOpCodeName(dxilOpcode);
|
|
@@ -281,10 +292,7 @@ void Translator::createInstruction(llvm::Instruction &instruction) {
|
|
|
}
|
|
|
// Unhandled instruction type.
|
|
|
else {
|
|
|
- std::string instStr;
|
|
|
- llvm::raw_string_ostream os(instStr);
|
|
|
- instruction.print(os);
|
|
|
- emitError("Unhandled LLVM instruction: %0") << os.str();
|
|
|
+ emitError("Unhandled LLVM instruction: %0", instruction);
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -344,8 +352,8 @@ void Translator::createStoreOutputInstruction(llvm::CallInst &instruction) {
|
|
|
spirv::SpirvAccessChain *outputVarPtr =
|
|
|
spvBuilder.createAccessChain(elemType, outputVar, {index}, {});
|
|
|
spirv::SpirvInstruction *valueToStore =
|
|
|
- instructionMap[instruction.getArgOperand(
|
|
|
- hlsl::DXIL::OperandIndex::kStoreOutputValOpIdx)];
|
|
|
+ getSpirvInstruction(instruction.getArgOperand(
|
|
|
+ hlsl::DXIL::OperandIndex::kStoreOutputValOpIdx));
|
|
|
spvBuilder.createStore(outputVarPtr, valueToStore, {});
|
|
|
}
|
|
|
|
|
@@ -384,15 +392,8 @@ void Translator::createBinaryOpInstruction(llvm::BinaryOperator &instruction) {
|
|
|
// Shift left instruction.
|
|
|
case llvm::Instruction::Shl: {
|
|
|
// Value to be shifted.
|
|
|
- spirv::SpirvInstruction *val = instructionMap[instruction.getOperand(0)];
|
|
|
- if (!val) {
|
|
|
- std::string instStr;
|
|
|
- llvm::raw_string_ostream os(instStr);
|
|
|
- instruction.print(os);
|
|
|
- emitError("Could not find translation of instruction operand 0: %0")
|
|
|
- << os.str();
|
|
|
- return;
|
|
|
- }
|
|
|
+ spirv::SpirvInstruction *val =
|
|
|
+ getSpirvInstruction(instruction.getOperand(0));
|
|
|
|
|
|
// Amount to shift by.
|
|
|
const spirv::IntegerType *uint32 = spvContext.getUIntType(32);
|
|
@@ -412,6 +413,108 @@ void Translator::createBinaryOpInstruction(llvm::BinaryOperator &instruction) {
|
|
|
instructionMap[&instruction] = result;
|
|
|
}
|
|
|
|
|
|
+void Translator::createHandleInstruction(llvm::CallInst &instruction) {
|
|
|
+ unsigned resourceClass =
|
|
|
+ cast<llvm::ConstantInt>(
|
|
|
+ instruction.getArgOperand(
|
|
|
+ hlsl::DXIL::OperandIndex::kCreateHandleResClassOpIdx))
|
|
|
+ ->getLimitedValue();
|
|
|
+ unsigned resourceRangeId =
|
|
|
+ cast<llvm::ConstantInt>(
|
|
|
+ instruction.getArgOperand(
|
|
|
+ hlsl::DXIL::OperandIndex::kCreateHandleResIDOpIdx))
|
|
|
+ ->getLimitedValue();
|
|
|
+ spirv::SpirvVariable *inputVar =
|
|
|
+ resourceMap[{resourceClass, resourceRangeId}];
|
|
|
+ if (!inputVar) {
|
|
|
+ emitError("No resource found corresponding to handle: %0", instruction);
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ instructionMap[&instruction] = inputVar;
|
|
|
+}
|
|
|
+
|
|
|
+void Translator::createBufferLoadInstruction(llvm::CallInst &instruction) {
|
|
|
+ // TODO: Extend this function to work with all buffer types on which it is
|
|
|
+ // used, not just ByteAddressBuffers.
|
|
|
+
|
|
|
+ // ByteAddressBuffers are represented as a struct with one member that is a
|
|
|
+ // runtime array of unsigned integers. The SPIR-V OpAccessChain instruction is
|
|
|
+ // then used to access that offset, and OpLoad is used to load integers
|
|
|
+ // into a corresponding struct.
|
|
|
+
|
|
|
+ // clang-format off
|
|
|
+ // For example, the following DXIL instruction:
|
|
|
+ // %dx.types.ResRet.i32 = type { i32, i32, i32, i32, i32 }
|
|
|
+ // %ret = call %dx.types.ResRet.i32 @dx.op.bufferLoad.i32(i32 68, %dx.types.Handle %res, i32 %index, i32 undef)
|
|
|
+
|
|
|
+ // would translate to the following SPIR-V instructions:
|
|
|
+ // %dx_types_ResRet_i32 = OpTypeStruct %int %int %int %int %int
|
|
|
+ // %_ptr_Function_dx_types_ResRet_i32 = OpTypePointer Function %dx_types_ResRet_i32
|
|
|
+ // %ret = OpVariable %_ptr_Function_dx_types_ResRet_i32 Function
|
|
|
+ // %i = OpLoad %uint %index
|
|
|
+ // for %offset = {0, 1, 2, 3, 4}, repeat:
|
|
|
+ // %v0 = OpIAdd %uint %i %offset
|
|
|
+ // %v1 = OpAccessChain %_ptr_Uniform_uint %res %offset %v0
|
|
|
+ // %v2 = OpLoad %uint %v1
|
|
|
+ // %v3 = OpAccessChain %_ptr_Function_int %ret %offset
|
|
|
+ // %v4 = OpBitcast %int %v2
|
|
|
+ // OpStore %v3 %v4
|
|
|
+ // clang-format on
|
|
|
+
|
|
|
+ // Get module input variable corresponding to given DXIL handle.
|
|
|
+ spirv::SpirvInstruction *inputVar =
|
|
|
+ getSpirvInstruction(instruction.getArgOperand(
|
|
|
+ hlsl::DXIL::OperandIndex::kBufferLoadHandleOpIdx));
|
|
|
+
|
|
|
+ // Translate DXIL instruction return type (expected to be a struct of
|
|
|
+ // integers) to a SPIR-V type.
|
|
|
+ const spirv::SpirvType *returnType = toSpirvType(instruction.getType());
|
|
|
+ assert(isa<spirv::StructType>(returnType));
|
|
|
+ const spirv::StructType *structType = cast<spirv::StructType>(returnType);
|
|
|
+
|
|
|
+ // Create a return variable to initialize with values loaded from the buffer.
|
|
|
+ spirv::SpirvVariable *returnVar =
|
|
|
+ spvBuilder.addFnVar(structType, {}, "", false, nullptr);
|
|
|
+
|
|
|
+ // Translate indices into resource buffer to SPIR-V instructions.
|
|
|
+ auto uint32 = spvContext.getUIntType(32);
|
|
|
+ spirv::SpirvConstant *indexIntoStruct =
|
|
|
+ spvBuilder.getConstantInt(uint32, llvm::APInt(32, 0));
|
|
|
+ spirv::SpirvInstruction *baseArrayIndex =
|
|
|
+ getSpirvInstruction(instruction.getArgOperand(
|
|
|
+ hlsl::DXIL::OperandIndex::kBufferLoadCoord0OpIdx));
|
|
|
+
|
|
|
+ // Initialize each field in the struct.
|
|
|
+ for (size_t i = 0; i < structType->getFields().size(); i++) {
|
|
|
+ // Add offset for current field.
|
|
|
+ spirv::SpirvConstant *fieldOffset =
|
|
|
+ spvBuilder.getConstantInt(uint32, llvm::APInt(32, i));
|
|
|
+ spirv::SpirvInstruction *indexIntoArray = spvBuilder.createBinaryOp(
|
|
|
+ spv::Op::OpIAdd, uint32, baseArrayIndex, fieldOffset, {});
|
|
|
+
|
|
|
+ // Create access chain and load.
|
|
|
+ spirv::SpirvAccessChain *loadPtr = spvBuilder.createAccessChain(
|
|
|
+ uint32, inputVar, {indexIntoStruct, indexIntoArray}, {});
|
|
|
+ spirv::SpirvInstruction *loadInstr =
|
|
|
+ spvBuilder.createLoad(uint32, loadPtr, {});
|
|
|
+
|
|
|
+ // Create access chain and store.
|
|
|
+ const spirv::SpirvType *fieldType = structType->getFields()[i].type;
|
|
|
+ spirv::SpirvAccessChain *storePtr =
|
|
|
+ spvBuilder.createAccessChain(fieldType, returnVar, fieldOffset, {});
|
|
|
+ // LLVM types are signless, so type conversions are not 1-to-1. A bitcast on
|
|
|
+ // the unsigned integer may be necessary before storing.
|
|
|
+ spirv::SpirvInstruction *valToStore =
|
|
|
+ fieldType == uint32 ? loadInstr
|
|
|
+ : spvBuilder.createUnaryOp(
|
|
|
+ spv::Op::OpBitcast, fieldType, loadInstr, {});
|
|
|
+ spvBuilder.createStore(storePtr, valToStore, {});
|
|
|
+ }
|
|
|
+
|
|
|
+ instructionMap[&instruction] = returnVar;
|
|
|
+}
|
|
|
+
|
|
|
bool Translator::spirvToolsValidate(std::vector<uint32_t> *mod,
|
|
|
std::string *messages) {
|
|
|
spvtools::SpirvTools tools(featureManager.getTargetEnv());
|
|
@@ -500,6 +603,17 @@ const spirv::SpirvType *Translator::toSpirvType(llvm::StructType *structType) {
|
|
|
return spvContext.getStructType(fields, name);
|
|
|
}
|
|
|
|
|
|
+spirv::SpirvInstruction *
|
|
|
+Translator::getSpirvInstruction(llvm::Value *instruction) {
|
|
|
+ spirv::SpirvInstruction *spirvInstruction = instructionMap[instruction];
|
|
|
+ if (!spirvInstruction) {
|
|
|
+ emitError("Expected SPIR-V instruction not found for DXIL instruction: %0",
|
|
|
+ *instruction);
|
|
|
+ return nullptr;
|
|
|
+ }
|
|
|
+ return spirvInstruction;
|
|
|
+}
|
|
|
+
|
|
|
template <unsigned N>
|
|
|
DiagnosticBuilder Translator::emitError(const char (&message)[N]) {
|
|
|
const auto diagId =
|
|
@@ -507,5 +621,14 @@ DiagnosticBuilder Translator::emitError(const char (&message)[N]) {
|
|
|
return diagnosticsEngine.Report({}, diagId);
|
|
|
}
|
|
|
|
|
|
+template <unsigned N>
|
|
|
+DiagnosticBuilder Translator::emitError(const char (&message)[N],
|
|
|
+ llvm::Value &value) {
|
|
|
+ std::string str;
|
|
|
+ llvm::raw_string_ostream os(str);
|
|
|
+ value.print(os);
|
|
|
+ return emitError(message) << os.str();
|
|
|
+}
|
|
|
+
|
|
|
} // namespace dxil2spv
|
|
|
} // namespace clang
|