Forráskód Böngészése

[spirv] Avoid generating duplicated debug names for images (#597)

Lei Zhang 8 éve
szülő
commit
10a4a84f68

+ 7 - 2
tools/clang/include/clang/SPIRV/Structure.h

@@ -20,6 +20,7 @@
 
 #include <deque>
 #include <memory>
+#include <set>
 #include <string>
 #include <vector>
 
@@ -240,6 +241,9 @@ struct DebugName {
   inline DebugName(uint32_t id, std::string targetName,
                    llvm::Optional<uint32_t> index = llvm::None);
 
+  bool operator==(const DebugName &that) const;
+  bool operator<(const DebugName &that) const;
+
   const uint32_t targetId;
   const std::string name;
   const llvm::Optional<uint32_t> memberIndex;
@@ -344,7 +348,7 @@ private:
   std::vector<EntryPoint> entryPoints;
   std::vector<Instruction> executionModes;
   // TODO: source code debug information
-  std::vector<DebugName> debugNames;
+  std::set<DebugName> debugNames;
   llvm::SetVector<std::pair<uint32_t, const Decoration *>> decorations;
 
   // Note that types and constants are interdependent; Types like arrays have
@@ -488,8 +492,9 @@ void SPIRVModule::addExecutionMode(Instruction &&execMode) {
 
 void SPIRVModule::addDebugName(uint32_t targetId, llvm::StringRef name,
                                llvm::Optional<uint32_t> memberIndex) {
+
   if (!name.empty()) {
-    debugNames.emplace_back(targetId, name, memberIndex);
+    debugNames.insert(DebugName(targetId, name, memberIndex));
   }
 }
 

+ 34 - 31
tools/clang/lib/SPIRV/ModuleBuilder.cpp

@@ -598,8 +598,6 @@ ModuleBuilder::getStructType(llvm::ArrayRef<uint32_t> fieldTypes,
   bool isRegistered = false;
   const uint32_t typeId = theContext.getResultIdForType(type, &isRegistered);
   theModule.addType(type, typeId);
-  // TODO: Probably we should check duplication and do nothing if trying to add
-  // the same debug name for the same entity in addDebugName().
   if (!isRegistered) {
     theModule.addDebugName(typeId, structName);
     if (!fieldNames.empty()) {
@@ -642,7 +640,8 @@ uint32_t ModuleBuilder::getImageType(uint32_t sampledType, spv::Dim dim,
                                      spv::ImageFormat format) {
   const Type *type = Type::getImage(theContext, sampledType, dim, depth,
                                     isArray, ms, sampled, format);
-  const uint32_t typeId = theContext.getResultIdForType(type);
+  bool isRegistered = false;
+  const uint32_t typeId = theContext.getResultIdForType(type, &isRegistered);
   theModule.addType(type, typeId);
 
   switch (format) {
@@ -678,35 +677,39 @@ uint32_t ModuleBuilder::getImageType(uint32_t sampledType, spv::Dim dim,
   if (dim == spv::Dim::Buffer)
     requireCapability(spv::Capability::SampledBuffer);
 
-  const char *dimStr = "";
-  switch (dim) {
-  case spv::Dim::Dim1D:
-    dimStr = "1d.";
-    break;
-  case spv::Dim::Dim2D:
-    dimStr = "2d.";
-    break;
-  case spv::Dim::Dim3D:
-    dimStr = "3d.";
-    break;
-  case spv::Dim::Cube:
-    dimStr = "cube.";
-    break;
-  case spv::Dim::Rect:
-    dimStr = "rect.";
-    break;
-  case spv::Dim::Buffer:
-    dimStr = "buffer.";
-    break;
-  case spv::Dim::SubpassData:
-    dimStr = "subpass.";
-    break;
-  default:
-    break;
+  // Skip constructing the debug name if we have already done it before.
+  if (!isRegistered) {
+    const char *dimStr = "";
+    switch (dim) {
+    case spv::Dim::Dim1D:
+      dimStr = "1d.";
+      break;
+    case spv::Dim::Dim2D:
+      dimStr = "2d.";
+      break;
+    case spv::Dim::Dim3D:
+      dimStr = "3d.";
+      break;
+    case spv::Dim::Cube:
+      dimStr = "cube.";
+      break;
+    case spv::Dim::Rect:
+      dimStr = "rect.";
+      break;
+    case spv::Dim::Buffer:
+      dimStr = "buffer.";
+      break;
+    case spv::Dim::SubpassData:
+      dimStr = "subpass.";
+      break;
+    default:
+      break;
+    }
+
+    std::string name =
+        std::string("type.") + dimStr + "image" + (isArray ? ".array" : "");
+    theModule.addDebugName(typeId, name);
   }
-  std::string name =
-      std::string("type.") + dimStr + "image" + (isArray ? ".array" : "");
-  theModule.addDebugName(typeId, name);
 
   return typeId;
 }

+ 42 - 10
tools/clang/lib/SPIRV/Structure.cpp

@@ -144,9 +144,9 @@ void Function::take(InstBuilder *builder) {
   // validation rules.
   std::vector<BasicBlock *> orderedBlocks;
   if (!blocks.empty()) {
-    BlockReadableOrderVisitor(
-        [&orderedBlocks](BasicBlock *block) { orderedBlocks.push_back(block); })
-        .visit(blocks.front().get());
+    BlockReadableOrderVisitor([&orderedBlocks](BasicBlock *block) {
+      orderedBlocks.push_back(block);
+    }).visit(blocks.front().get());
   }
 
   // Write out all basic blocks.
@@ -168,9 +168,9 @@ void Function::addVariable(uint32_t varType, uint32_t varId,
 
 void Function::getReachableBasicBlocks(std::vector<BasicBlock *> *bbVec) const {
   if (!blocks.empty()) {
-    BlockReadableOrderVisitor(
-        [&bbVec](BasicBlock *block) { bbVec->push_back(block); })
-        .visit(blocks.front().get());
+    BlockReadableOrderVisitor([&bbVec](BasicBlock *block) {
+      bbVec->push_back(block);
+    }).visit(blocks.front().get());
   }
 }
 
@@ -191,6 +191,37 @@ void Header::collect(const WordConsumer &consumer) {
   consumer(std::move(words));
 }
 
+bool DebugName::operator==(const DebugName &that) const {
+  if (targetId == that.targetId && name == that.name) {
+    if (memberIndex.hasValue()) {
+      return that.memberIndex.hasValue() &&
+             memberIndex.getValue() == that.memberIndex.getValue();
+    }
+    return !that.memberIndex.hasValue();
+  }
+  return false;
+}
+
+bool DebugName::operator<(const DebugName &that) const {
+  // Sort according to target id first
+  if (targetId != that.targetId)
+    return targetId < that.targetId;
+
+  if (memberIndex.hasValue()) {
+    // Sort member decorations according to member index
+    if (that.memberIndex.hasValue())
+      return memberIndex.getValue() < that.memberIndex.getValue();
+    // Decorations on the id itself goes before those on its members
+    return false;
+  }
+
+  // Decorations on the id itself goes before those on its members
+  if (that.memberIndex.hasValue())
+    return true;
+
+  return name < that.name;
+}
+
 // === Module implementations ===
 
 bool SPIRVModule::isEmpty() const {
@@ -255,10 +286,11 @@ void SPIRVModule::take(InstBuilder *builder) {
     consumer(inst.take());
   }
 
-  // BasicBlock debug names should be emitted only for blocks that are reachable.
+  // BasicBlock debug names should be emitted only for blocks that are
+  // reachable.
   // The debug name for a basic block is stored in the basic block object.
-  std::vector<BasicBlock*> reachableBasicBlocks;
-  for (const auto& fn : functions)
+  std::vector<BasicBlock *> reachableBasicBlocks;
+  for (const auto &fn : functions)
     fn->getReachableBasicBlocks(&reachableBasicBlocks);
   for (BasicBlock *bb : reachableBasicBlocks)
     if (!bb->getDebugName().empty())
@@ -335,7 +367,7 @@ void SPIRVModule::takeConstantForArrayType(const Type &arrType,
   // If it finds the constant, feeds it into the consumer, and removes it
   // from the constants collection.
   constants.remove_if([&consumer, arrayLengthResultId](
-                          std::pair<const Constant *, uint32_t> &item) {
+      std::pair<const Constant *, uint32_t> &item) {
     const bool isArrayLengthConstant = (item.second == arrayLengthResultId);
     if (isArrayLengthConstant)
       consumer(item.first->withResultId(item.second));

+ 1 - 1
tools/clang/test/CodeGenSPIRV/constant-ps.hlsl2spv

@@ -16,9 +16,9 @@ float4 main(): SV_Target
 // OpEntryPoint Fragment %main "main" %out_var_SV_Target
 // OpExecutionMode %main OriginUpperLeft
 // OpName %bb_entry "bb.entry"
+// OpName %src_main "src.main"
 // OpName %main "main"
 // OpName %out_var_SV_Target "out.var.SV_Target"
-// OpName %src_main "src.main"
 // OpDecorate %out_var_SV_Target Location 0
 // %void = OpTypeVoid
 // %3 = OpTypeFunction %void

+ 1 - 1
tools/clang/test/CodeGenSPIRV/empty-void-main.hlsl2spv

@@ -15,8 +15,8 @@ void main()
 // OpEntryPoint Fragment %main "main"
 // OpExecutionMode %main OriginUpperLeft
 // OpName %bb_entry "bb.entry"
-// OpName %main "main"
 // OpName %src_main "src.main"
+// OpName %main "main"
 // %void = OpTypeVoid
 // %3 = OpTypeFunction %void
 // %main = OpFunction %void None %3

+ 1 - 1
tools/clang/test/CodeGenSPIRV/passthru-cs.hlsl2spv

@@ -32,9 +32,9 @@ void main( uint3 DTid : SV_DispatchThreadID )
 // OpName %Buffer0 "Buffer0"
 // OpName %type_RWByteAddressBuffer "type.RWByteAddressBuffer"
 // OpName %BufferOut "BufferOut"
+// OpName %src_main "src.main"
 // OpName %main "main"
 // OpName %param_var_DTid "param.var.DTid"
-// OpName %src_main "src.main"
 // OpName %DTid "DTid"
 // OpName %word "word"
 // OpDecorate %_runtimearr_uint ArrayStride 4

+ 1 - 1
tools/clang/test/CodeGenSPIRV/passthru-ps.hlsl2spv

@@ -16,11 +16,11 @@ float4 main(float4 input: COLOR): SV_Target
 // OpEntryPoint Fragment %main "main" %in_var_COLOR %out_var_SV_Target
 // OpExecutionMode %main OriginUpperLeft
 // OpName %bb_entry "bb.entry"
+// OpName %src_main "src.main"
 // OpName %main "main"
 // OpName %param_var_input "param.var.input"
 // OpName %in_var_COLOR "in.var.COLOR"
 // OpName %out_var_SV_Target "out.var.SV_Target"
-// OpName %src_main "src.main"
 // OpName %input "input"
 // OpDecorate %in_var_COLOR Location 0
 // OpDecorate %out_var_SV_Target Location 0

+ 1 - 1
tools/clang/test/CodeGenSPIRV/passthru-vs.hlsl2spv

@@ -23,6 +23,7 @@ PSInput VSmain(float4 position: POSITION, float4 color: COLOR) {
 // OpMemoryModel Logical GLSL450
 // OpEntryPoint Vertex %VSmain "VSmain" %in_var_POSITION %in_var_COLOR %gl_Position %out_var_COLOR
 // OpName %bb_entry "bb.entry"
+// OpName %src_VSmain "src.VSmain"
 // OpName %VSmain "VSmain"
 // OpName %param_var_position "param.var.position"
 // OpName %in_var_POSITION "in.var.POSITION"
@@ -32,7 +33,6 @@ PSInput VSmain(float4 position: POSITION, float4 color: COLOR) {
 // OpMemberName %PSInput 0 "position"
 // OpMemberName %PSInput 1 "color"
 // OpName %out_var_COLOR "out.var.COLOR"
-// OpName %src_VSmain "src.VSmain"
 // OpName %position "position"
 // OpName %color "color"
 // OpName %result "result"