Browse Source

[spirv] Use decorations (rather than layout rule) to distinguish types.

 structure types with same members and different layout rules may still
 be the same type. We should not emit two different types for them.
Ehsan Nasiri 6 years ago
parent
commit
405f27d159

+ 101 - 43
tools/clang/include/clang/SPIRV/EmitVisitor.h

@@ -22,24 +22,65 @@ class SpirvBasicBlock;
 class SpirvType;
 class SpirvType;
 class SpirvBuilder;
 class SpirvBuilder;
 
 
-// Provides DenseMapInfo for SpirvLayoutRule so that we can use it as key to
-// DenseMap.
-//
-// Mostly from DenseMapInfo<unsigned> in DenseMapInfo.h.
-struct SpirvLayoutRuleDenseMapInfo {
-  static inline SpirvLayoutRule getEmptyKey() { return SpirvLayoutRule::Max; }
-  static inline SpirvLayoutRule getTombstoneKey() {
-    return SpirvLayoutRule::Max;
-  }
-  static unsigned getHashValue(const SpirvLayoutRule &Val) {
-    return static_cast<unsigned>(Val) * 37U;
-  }
-  static bool isEqual(const SpirvLayoutRule &LHS, const SpirvLayoutRule &RHS) {
-    return LHS == RHS;
-  }
-};
-
 class EmitTypeHandler {
 class EmitTypeHandler {
+public:
+  struct DecorationInfo {
+    DecorationInfo(spv::Decoration decor, llvm::ArrayRef<uint32_t> params = {},
+                   llvm::Optional<uint32_t> index = llvm::None)
+        : decoration(decor), decorationParams(params.begin(), params.end()),
+          memberIndex(index) {}
+
+    bool operator==(const DecorationInfo &other) const {
+      return decoration == other.decoration &&
+             decorationParams == other.decorationParams &&
+             memberIndex.hasValue() == other.memberIndex.hasValue() &&
+             (!memberIndex.hasValue() ||
+              memberIndex.getValue() == other.memberIndex.getValue());
+    }
+
+    spv::Decoration decoration;
+    llvm::SmallVector<uint32_t, 4> decorationParams;
+    llvm::Optional<uint32_t> memberIndex;
+  };
+
+  using DecorationList = llvm::SmallVector<DecorationInfo, 4>;
+
+  // Provides DenseMapInfo for SpirvLayoutRule so that we can use it as key to
+  // DenseMap.
+  struct DecorationSetDenseMapInfo {
+    static inline DecorationList getEmptyKey() {
+      return {DecorationInfo(spv::Decoration::Max)};
+    }
+    static inline DecorationList getTombstoneKey() {
+      return {DecorationInfo(spv::Decoration::Max)};
+    }
+    static unsigned getHashValue(const DecorationList &Val) {
+      unsigned hashValue = Val.size();
+      for (auto &decorationInfo : Val)
+        hashValue += static_cast<unsigned>(decorationInfo.decoration);
+
+      return hashValue;
+    }
+
+    static bool isEqual(const DecorationList &LHS, const DecorationList &RHS) {
+      // Must have the same number of decorations.
+      if (LHS.size() != RHS.size())
+        return false;
+
+      // Order of decorations does not matter.
+      for (auto &dec : LHS) {
+        auto found = std::find_if(
+            RHS.begin(), RHS.end(),
+            [&dec](const DecorationInfo &otherDec) { return dec == otherDec; });
+
+        if (found == RHS.end())
+          return false;
+      }
+
+      return true;
+    }
+  };
+
 public:
 public:
   EmitTypeHandler(ASTContext &astCtx, SpirvBuilder &builder,
   EmitTypeHandler(ASTContext &astCtx, SpirvBuilder &builder,
                   std::vector<uint32_t> *debugVec,
                   std::vector<uint32_t> *debugVec,
@@ -57,36 +98,38 @@ public:
   EmitTypeHandler(const EmitTypeHandler &) = delete;
   EmitTypeHandler(const EmitTypeHandler &) = delete;
   EmitTypeHandler &operator=(const EmitTypeHandler &) = delete;
   EmitTypeHandler &operator=(const EmitTypeHandler &) = delete;
 
 
-  // Emits OpDecorate (or OpMemberDecorate if memberIndex is non-zero)
-  // targetting the given type. Uses the given decoration kind and its
-  // parameters.
-  void emitDecoration(uint32_t typeResultId, spv::Decoration,
-                      llvm::ArrayRef<uint32_t> decorationParams,
-                      llvm::Optional<uint32_t> memberIndex = llvm::None);
-
   // Emits the instruction for the given type into the typeConstantBinary and
   // Emits the instruction for the given type into the typeConstantBinary and
-  // returns the result-id for the type.
+  // returns the result-id for the type. If the type has already been emitted,
+  // it only returns its result-id.
+  //
+  // If any names are associated with the type (or its members in case of
+  // structs), the OpName/OpMemberNames will also be emitted.
+  //
+  // If any decorations apply to the type, it also emits the decoration
+  // instructions into the annotationsBinary.
   uint32_t emitType(const SpirvType *, SpirvLayoutRule);
   uint32_t emitType(const SpirvType *, SpirvLayoutRule);
 
 
-  uint32_t getResultIdForType(const SpirvType *, SpirvLayoutRule,
-                              bool *alreadyExists);
-
 private:
 private:
   void initTypeInstruction(spv::Op op);
   void initTypeInstruction(spv::Op op);
   void finalizeTypeInstruction();
   void finalizeTypeInstruction();
 
 
-  // Methods associated with layout calculations ----
-
-  // TODO: This function should be merged into the Type class hierarchy.
-  std::pair<uint32_t, uint32_t> getAlignmentAndSize(const SpirvType *type,
-                                                    SpirvLayoutRule rule,
-                                                    uint32_t *stride);
+  // Figures out the decorations that apply to the given type with the given
+  // layout rule, and populates the given decoration set.
+  void getDecorationsForType(const SpirvType *type, SpirvLayoutRule rule,
+                             DecorationList *decorations);
 
 
-  void alignUsingHLSLRelaxedLayout(const SpirvType *fieldType,
-                                   uint32_t fieldSize, uint32_t fieldAlignment,
-                                   uint32_t *currentOffset);
+  // Returns the result-id for the given type and decorations. If a type with
+  // the same decorations have already been used, it returns the existing
+  // result-id. If not, creates a new result-id for such type and returns it.
+  uint32_t getResultIdForType(const SpirvType *, const DecorationList &,
+                              bool *alreadyExists);
 
 
-  void emitLayoutDecorations(const StructType *, SpirvLayoutRule);
+  // Emits OpDecorate (or OpMemberDecorate if memberIndex is non-zero)
+  // targetting the given type. Uses the given decoration kind and its
+  // parameters.
+  void emitDecoration(uint32_t typeResultId, spv::Decoration,
+                      llvm::ArrayRef<uint32_t> decorationParams,
+                      llvm::Optional<uint32_t> memberIndex = llvm::None);
 
 
   // Emits an OpName (if memberIndex is not provided) or OpMemberName (if
   // Emits an OpName (if memberIndex is not provided) or OpMemberName (if
   // memberIndex is provided) for the given target result-id.
   // memberIndex is provided) for the given target result-id.
@@ -104,6 +147,21 @@ private:
     return obj->getResultId();
     return obj->getResultId();
   }
   }
 
 
+  // ---- Methods associated with layout calculations ----
+
+  std::pair<uint32_t, uint32_t> getAlignmentAndSize(const SpirvType *type,
+                                                    SpirvLayoutRule rule,
+                                                    uint32_t *stride);
+
+  void alignUsingHLSLRelaxedLayout(const SpirvType *fieldType,
+                                   uint32_t fieldSize, uint32_t fieldAlignment,
+                                   uint32_t *currentOffset);
+
+  // Adds the layout decorations for the given type and layout rule to the given
+  // vector of decorations.
+  void getLayoutDecorations(const StructType *, SpirvLayoutRule,
+                            DecorationList *);
+
 private:
 private:
   /// Emits error to the diagnostic engine associated with this visitor.
   /// Emits error to the diagnostic engine associated with this visitor.
   template <unsigned N>
   template <unsigned N>
@@ -124,11 +182,11 @@ private:
   std::vector<uint32_t> *typeConstantBinary;
   std::vector<uint32_t> *typeConstantBinary;
   std::function<uint32_t()> takeNextIdFunction;
   std::function<uint32_t()> takeNextIdFunction;
 
 
-  // emittedTypes is a map that caches the <result-id> of types in order to
-  // avoid translating a type multiple times.
-  using LayoutRuleToTypeIdMap =
-      llvm::DenseMap<SpirvLayoutRule, uint32_t, SpirvLayoutRuleDenseMapInfo>;
-  llvm::DenseMap<const SpirvType *, LayoutRuleToTypeIdMap> emittedTypes;
+  // emittedTypes is a map that caches the result-id of types with a given list
+  // of decorations in order to avoid emitting an identical type multiple times.
+  using DecorationSetToTypeIdMap =
+      llvm::DenseMap<DecorationList, uint32_t, DecorationSetDenseMapInfo>;
+  llvm::DenseMap<const SpirvType *, DecorationSetToTypeIdMap> emittedTypes;
 };
 };
 
 
 /// \breif The visitor class that emits the SPIR-V words from the in-memory
 /// \breif The visitor class that emits the SPIR-V words from the in-memory

+ 72 - 62
tools/clang/lib/SPIRV/EmitVisitor.cpp

@@ -1052,39 +1052,80 @@ void EmitTypeHandler::finalizeTypeInstruction() {
 }
 }
 
 
 uint32_t EmitTypeHandler::getResultIdForType(const SpirvType *type,
 uint32_t EmitTypeHandler::getResultIdForType(const SpirvType *type,
-                                             SpirvLayoutRule rule,
+                                             const DecorationList &decs,
                                              bool *alreadyExists) {
                                              bool *alreadyExists) {
   assert(alreadyExists);
   assert(alreadyExists);
-
-  // Note: Layout rules only affect struct types. Therefore, for non-struct
-  // types, we must use the same result-id regardless of the layout rule.
-  if (!isa<StructType>(type))
-    rule = SpirvLayoutRule::Void;
-
-  // Check if the type has already been emitted.
   auto foundType = emittedTypes.find(type);
   auto foundType = emittedTypes.find(type);
   if (foundType != emittedTypes.end()) {
   if (foundType != emittedTypes.end()) {
-    auto foundLayoutRule = foundType->second.find(rule);
-    if (foundLayoutRule != foundType->second.end()) {
+    auto foundDecorationSet = foundType->second.find(decs);
+    if (foundDecorationSet != foundType->second.end()) {
       *alreadyExists = true;
       *alreadyExists = true;
-      return foundLayoutRule->second;
+      return foundDecorationSet->second;
     }
     }
   }
   }
 
 
   *alreadyExists = false;
   *alreadyExists = false;
   const uint32_t id = takeNextIdFunction();
   const uint32_t id = takeNextIdFunction();
-  emittedTypes[type][rule] = id;
+  emittedTypes[type][decs] = id;
   return id;
   return id;
 }
 }
 
 
+void EmitTypeHandler::getDecorationsForType(const SpirvType *type,
+                                            SpirvLayoutRule rule,
+                                            DecorationList *decs) {
+  // Array types
+  if (const auto *arrayType = dyn_cast<ArrayType>(type)) {
+    // ArrayStride decoration is needed for array types, but we won't have
+    // stride information for structured/byte buffers since they contain runtime
+    // arrays.
+    if (rule != SpirvLayoutRule::Void &&
+        !isAKindOfStructuredOrByteBuffer(type)) {
+      uint32_t stride = 0;
+      (void)getAlignmentAndSize(type, rule, &stride);
+      decs->push_back(DecorationInfo(spv::Decoration::ArrayStride, {stride}));
+    }
+  }
+  // RuntimeArray types
+  else if (const auto *raType = dyn_cast<RuntimeArrayType>(type)) {
+    // ArrayStride decoration is needed for runtime array types.
+    if (rule != SpirvLayoutRule::Void) {
+      uint32_t stride = 0;
+      (void)getAlignmentAndSize(type, rule, &stride);
+      decs->push_back(DecorationInfo(spv::Decoration::ArrayStride, {stride}));
+    }
+  }
+  // Structure types
+  else if (const auto *structType = dyn_cast<StructType>(type)) {
+    llvm::ArrayRef<StructType::FieldInfo> fields = structType->getFields();
+    size_t numFields = fields.size();
+
+    // Emit the layout decorations for the structure.
+    getLayoutDecorations(structType, rule, decs);
+
+    // Emit NonWritable decorations
+    if (structType->isReadOnly())
+      for (size_t i = 0; i < numFields; ++i)
+        decs->push_back(DecorationInfo(spv::Decoration::NonWritable, {}, i));
+
+    // Emit Block or BufferBlock decorations if necessary.
+    auto interfaceType = structType->getInterfaceType();
+    if (interfaceType == StructInterfaceType::StorageBuffer)
+      decs->push_back(DecorationInfo(spv::Decoration::BufferBlock, {}));
+    else if (interfaceType == StructInterfaceType::UniformBuffer)
+      decs->push_back(DecorationInfo(spv::Decoration::Block, {}));
+  }
+
+  // We currently only have decorations for arrays, runtime arrays, and
+  // structure types.
+}
+
 uint32_t EmitTypeHandler::emitType(const SpirvType *type,
 uint32_t EmitTypeHandler::emitType(const SpirvType *type,
                                    SpirvLayoutRule rule) {
                                    SpirvLayoutRule rule) {
-  //
-  // TODO: This method is currently missing decorations for types completely.
-  //
-
+  // First get the decorations that would apply to this type.
   bool alreadyExists = false;
   bool alreadyExists = false;
-  const uint32_t id = getResultIdForType(type, rule, &alreadyExists);
+  DecorationList decs;
+  getDecorationsForType(type, rule, &decs);
+  const uint32_t id = getResultIdForType(type, decs, &alreadyExists);
 
 
   // If the type has already been emitted, we just need to return its
   // If the type has already been emitted, we just need to return its
   // <result-id>.
   // <result-id>.
@@ -1188,16 +1229,6 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type,
     curTypeInst.push_back(elemTypeId);
     curTypeInst.push_back(elemTypeId);
     curTypeInst.push_back(getOrAssignResultId<SpirvInstruction>(constant));
     curTypeInst.push_back(getOrAssignResultId<SpirvInstruction>(constant));
     finalizeTypeInstruction();
     finalizeTypeInstruction();
-
-    // ArrayStride decoration is needed for array types, but we won't have
-    // stride information for structured/byte buffers since they contain runtime
-    // arrays.
-    if (rule != SpirvLayoutRule::Void &&
-        !isAKindOfStructuredOrByteBuffer(type)) {
-      uint32_t stride = 0;
-      (void)getAlignmentAndSize(type, rule, &stride);
-      emitDecoration(id, spv::Decoration::ArrayStride, {stride});
-    }
   }
   }
   // RuntimeArray types
   // RuntimeArray types
   else if (const auto *raType = dyn_cast<RuntimeArrayType>(type)) {
   else if (const auto *raType = dyn_cast<RuntimeArrayType>(type)) {
@@ -1206,13 +1237,6 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type,
     curTypeInst.push_back(id);
     curTypeInst.push_back(id);
     curTypeInst.push_back(elemTypeId);
     curTypeInst.push_back(elemTypeId);
     finalizeTypeInstruction();
     finalizeTypeInstruction();
-
-    // ArrayStride decoration is needed for runtime array types.
-    if (rule != SpirvLayoutRule::Void) {
-      uint32_t stride = 0;
-      (void)getAlignmentAndSize(type, rule, &stride);
-      emitDecoration(id, spv::Decoration::ArrayStride, {stride});
-    }
   }
   }
   // Structure types
   // Structure types
   else if (const auto *structType = dyn_cast<StructType>(type)) {
   else if (const auto *structType = dyn_cast<StructType>(type)) {
@@ -1234,21 +1258,6 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type,
     for (auto fieldTypeId : fieldTypeIds)
     for (auto fieldTypeId : fieldTypeIds)
       curTypeInst.push_back(fieldTypeId);
       curTypeInst.push_back(fieldTypeId);
     finalizeTypeInstruction();
     finalizeTypeInstruction();
-
-    // Emit the layout decorations for the structure.
-    emitLayoutDecorations(structType, rule);
-
-    // Emit NonWritable decorations
-    if (structType->isReadOnly())
-      for (size_t i = 0; i < numFields; ++i)
-        emitDecoration(id, spv::Decoration::NonWritable, {}, i);
-
-    // Emit Block or BufferBlock decorations if necessary.
-    auto interfaceType = structType->getInterfaceType();
-    if (interfaceType == StructInterfaceType::StorageBuffer)
-      emitDecoration(id, spv::Decoration::BufferBlock, {});
-    else if (interfaceType == StructInterfaceType::UniformBuffer)
-      emitDecoration(id, spv::Decoration::Block, {});
   }
   }
   // Pointer types
   // Pointer types
   else if (const auto *ptrType = dyn_cast<SpirvPointerType>(type)) {
   else if (const auto *ptrType = dyn_cast<SpirvPointerType>(type)) {
@@ -1285,6 +1294,11 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type,
     llvm_unreachable("unhandled type in emitType");
     llvm_unreachable("unhandled type in emitType");
   }
   }
 
 
+  // Finally, emit decorations for the type into the annotationsBinary.
+  for (auto &decorInfo : decs)
+    emitDecoration(id, decorInfo.decoration, decorInfo.decorationParams,
+                   decorInfo.memberIndex);
+
   return id;
   return id;
 }
 }
 
 
@@ -1570,14 +1584,10 @@ void EmitTypeHandler::alignUsingHLSLRelaxedLayout(const SpirvType *fieldType,
   }
   }
 }
 }
 
 
-void EmitTypeHandler::emitLayoutDecorations(const StructType *structType,
-                                            SpirvLayoutRule rule) {
-  // Decorations for a type can be emitted after the type itself has been
-  // visited, because we need the result-id of the type as the target of the
-  // decoration.
-  bool visited = false;
-  const uint32_t typeResultId = getResultIdForType(structType, rule, &visited);
-  assert(visited);
+void EmitTypeHandler::getLayoutDecorations(const StructType *structType,
+                                           SpirvLayoutRule rule,
+                                           DecorationList *decs) {
+  assert(decs);
 
 
   uint32_t offset = 0, index = 0;
   uint32_t offset = 0, index = 0;
   for (auto &field : structType->getFields()) {
   for (auto &field : structType->getFields()) {
@@ -1618,7 +1628,7 @@ void EmitTypeHandler::emitLayoutDecorations(const StructType *structType,
     }
     }
 
 
     // Each structure-type member must have an Offset Decoration.
     // Each structure-type member must have an Offset Decoration.
-    emitDecoration(typeResultId, spv::Decoration::Offset, {offset}, index);
+    decs->push_back(DecorationInfo(spv::Decoration::Offset, {offset}, index));
     offset += memberSize;
     offset += memberSize;
 
 
     // Each structure-type member that is a matrix or array-of-matrices must be
     // Each structure-type member that is a matrix or array-of-matrices must be
@@ -1640,19 +1650,19 @@ void EmitTypeHandler::emitLayoutDecorations(const StructType *structType,
         std::tie(memberAlignment, memberSize) =
         std::tie(memberAlignment, memberSize) =
             getAlignmentAndSize(fieldType, rule, &stride);
             getAlignmentAndSize(fieldType, rule, &stride);
 
 
-        emitDecoration(typeResultId, spv::Decoration::MatrixStride, {stride},
-                       index);
+        decs->push_back(
+            DecorationInfo(spv::Decoration::MatrixStride, {stride}, index));
 
 
         // We need to swap the RowMajor and ColMajor decorations since HLSL
         // We need to swap the RowMajor and ColMajor decorations since HLSL
         // matrices are conceptually row-major while SPIR-V are conceptually
         // matrices are conceptually row-major while SPIR-V are conceptually
         // column-major.
         // column-major.
         if (matType->isRowMajorMat()) {
         if (matType->isRowMajorMat()) {
-          emitDecoration(typeResultId, spv::Decoration::ColMajor, {}, index);
+          decs->push_back(DecorationInfo(spv::Decoration::ColMajor, {}, index));
         } else {
         } else {
           // If the source code has neither row_major nor column_major
           // If the source code has neither row_major nor column_major
           // annotated, it should be treated as column_major since that's the
           // annotated, it should be treated as column_major since that's the
           // default.
           // default.
-          emitDecoration(typeResultId, spv::Decoration::RowMajor, {}, index);
+          decs->push_back(DecorationInfo(spv::Decoration::RowMajor, {}, index));
         }
         }
       }
       }
     }
     }

+ 2 - 0
tools/clang/lib/SPIRV/SpirvBuilder.cpp

@@ -222,6 +222,7 @@ SpirvBuilder::createAccessChain(QualType resultType, SpirvInstruction *base,
   auto *instruction =
   auto *instruction =
       new (context) SpirvAccessChain(resultType, /*id*/ 0, loc, base, indexes);
       new (context) SpirvAccessChain(resultType, /*id*/ 0, loc, base, indexes);
   instruction->setStorageClass(base->getStorageClass());
   instruction->setStorageClass(base->getStorageClass());
+  instruction->setLayoutRule(base->getLayoutRule());
   insertPoint->addInstruction(instruction);
   insertPoint->addInstruction(instruction);
   return instruction;
   return instruction;
 }
 }
@@ -234,6 +235,7 @@ SpirvAccessChain *SpirvBuilder::createAccessChain(
       SpirvAccessChain(/*QualType*/ {}, /*id*/ 0, loc, base, indexes);
       SpirvAccessChain(/*QualType*/ {}, /*id*/ 0, loc, base, indexes);
   instruction->setResultType(resultType);
   instruction->setResultType(resultType);
   instruction->setStorageClass(base->getStorageClass());
   instruction->setStorageClass(base->getStorageClass());
+  instruction->setLayoutRule(base->getLayoutRule());
   insertPoint->addInstruction(instruction);
   insertPoint->addInstruction(instruction);
   return instruction;
   return instruction;
 }
 }