ソースを参照

[spirv] Support HLSL 'register(c#)' annotation. (#1912)

Ehsan 6 年 前
コミット
ebb720ced3

+ 16 - 0
docs/SPIR-V.rst

@@ -879,6 +879,22 @@ to a struct memeber affects all variables of the struct type in question. So
 sharing the same struct definition having ``[[vk::offset]]`` annotations means
 also sharing the layout.
 
+For global variables (which are collected into the ``$Globals`` cbuffer), you
+can use the native HLSL ``:register(c#)`` attribute. Note that ``[[vk::offset]]``
+and ``:packoffset`` cannot be applied to these variables.
+
+If ``register(cX)`` is used on any global variable, the offset for that variable
+is set to ``X * 16``, and the offset for all other global variables without the
+``register(c#)`` annotation will be set to the next available address after
+the highest explicit address. For example:
+
+.. code:: hlsl
+
+  float x : register(c10);   // Offset = 160 (10 * 16)
+  float y;                   // Offset = 164 (160 + 4)
+  float z: register(c1);     // Offset = 16  (1  * 16)
+
+
 These attributes give great flexibility but also responsibility to the
 developer; the compiler will just take in what is specified in the source code
 and emit it to SPIR-V with no error checking.

+ 5 - 6
tools/clang/include/clang/SPIRV/SpirvType.h

@@ -401,11 +401,10 @@ public:
   public:
     FieldInfo(QualType astType_, llvm::StringRef name_ = "",
               clang::VKOffsetAttr *offset = nullptr,
-              hlsl::ConstantPacking *packOffset = nullptr)
+              hlsl::ConstantPacking *packOffset = nullptr,
+              const hlsl::RegisterAssignment *regC = nullptr)
         : astType(astType_), name(name_), vkOffsetAttr(offset),
-          packOffsetAttr(packOffset) {}
-
-    bool operator==(const FieldInfo &that) const;
+          packOffsetAttr(packOffset), registerC(regC) {}
 
     // The field's type.
     QualType astType;
@@ -415,6 +414,8 @@ public:
     clang::VKOffsetAttr *vkOffsetAttr;
     // :packoffset() annotations associated with this field.
     hlsl::ConstantPacking *packOffsetAttr;
+    // :register(c#) annotations associated with this field.
+    const hlsl::RegisterAssignment *registerC;
   };
 
   HybridStructType(
@@ -430,8 +431,6 @@ public:
   llvm::StringRef getStructName() const { return getName(); }
   StructInterfaceType getInterfaceType() const { return interfaceType; }
 
-  bool operator==(const HybridStructType &that) const;
-
 private:
   // Reflection is heavily used in graphics pipelines. Reflection relies on
   // struct names and field names. That basically means we cannot ignore these

+ 14 - 1
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -253,6 +253,15 @@ const hlsl::RegisterAssignment *getResourceBinding(const NamedDecl *decl) {
   return nullptr;
 }
 
+/// \brief Returns the stage variable's 'register(c#) assignment for the given
+/// Decl. Return nullptr if the given variable does not have such assignment.
+const hlsl::RegisterAssignment *getRegisterCAssignment(const NamedDecl *decl) {
+  const auto *regAssignment = getResourceBinding(decl);
+  if (regAssignment)
+    return regAssignment->RegisterType == 'c' ? regAssignment : nullptr;
+  return nullptr;
+}
+
 /// \brief Returns true if the given declaration has a primitive type qualifier.
 /// Returns false otherwise.
 inline bool hasGSPrimitiveTypeQualifier(const Decl *decl) {
@@ -699,13 +708,17 @@ SpirvVariable *DeclResultIdMapper::createStructOrStructArrayVarOfExplicitLayout(
     assert(isa<VarDecl>(subDecl) || isa<FieldDecl>(subDecl));
     const auto *declDecl = cast<DeclaratorDecl>(subDecl);
 
+    // In case 'register(c#)' annotation is placed on a global variable.
+    const hlsl::RegisterAssignment *registerC =
+        forGlobals ? getRegisterCAssignment(declDecl) : nullptr;
+
     // All fields are qualified with const. It will affect the debug name.
     // We don't need it here.
     auto varType = declDecl->getType();
     varType.removeLocalConst();
     HybridStructType::FieldInfo info(varType, declDecl->getName(),
                                      declDecl->getAttr<VKOffsetAttr>(),
-                                     getPackOffset(declDecl));
+                                     getPackOffset(declDecl), registerC);
     fields.push_back(info);
   }
 

+ 54 - 4
tools/clang/lib/SPIRV/LowerTypeVisitor.cpp

@@ -374,7 +374,8 @@ const SpirvType *LowerTypeVisitor::lowerType(QualType type,
       fields.push_back(HybridStructType::FieldInfo(
           field->getType(), field->getName(),
           /*vkoffset*/ field->getAttr<VKOffsetAttr>(),
-          /*packoffset*/ getPackOffset(field)));
+          /*packoffset*/ getPackOffset(field),
+          /*RegisterAssignment*/ nullptr));
     }
 
     auto loweredFields = populateLayoutInformation(fields, rule);
@@ -667,12 +668,44 @@ LowerTypeVisitor::populateLayoutInformation(
 
   // The resulting vector of fields with proper layout information.
   llvm::SmallVector<StructType::FieldInfo, 4> loweredFields;
+  llvm::SmallVector<StructType::FieldInfo, 4> result;
+
+  using RegisterFieldPair =
+      std::pair<uint32_t, const HybridStructType::FieldInfo *>;
+  struct RegisterFieldPairLess {
+    bool operator()(const RegisterFieldPair &obj1,
+                    const RegisterFieldPair &obj2) const {
+      return obj1.first < obj2.first;
+    }
+  };
+  std::set<RegisterFieldPair, RegisterFieldPairLess> registerCSet;
+  std::vector<const HybridStructType::FieldInfo *> sortedFields;
+  llvm::DenseMap<const HybridStructType::FieldInfo *, uint32_t> fieldToIndexMap;
+
+  // First, check to see if any of the structure members had 'register(c#)'
+  // location semantics. If so, members that do not have the 'register(c#)'
+  // assignment should be allocated after the *highest explicit address*.
+  // Example:
+  // float x : register(c10);   // Offset = 160 (10 * 16)
+  // float y;                   // Offset = 164 (160 + 4)
+  // float z: register(c1);     // Offset = 16  (1  * 16)
+  for (const auto &field : fields)
+    if (field.registerC)
+      registerCSet.insert(
+          RegisterFieldPair(field.registerC->RegisterNumber, &field));
+  for (const auto &pair : registerCSet)
+    sortedFields.push_back(pair.second);
+  for (const auto &field : fields)
+    if (!field.registerC)
+      sortedFields.push_back(&field);
 
   uint32_t offset = 0;
-  for (const auto field : fields) {
+  for (const auto *fieldPtr : sortedFields) {
     // The field can only be FieldDecl (for normal structs) or VarDecl (for
     // HLSLBufferDecls).
+    const auto field = *fieldPtr;
     auto fieldType = field.astType;
+    fieldToIndexMap[fieldPtr] = loweredFields.size();
 
     // Lower the field type fist. This call will populate proper matrix
     // majorness information.
@@ -689,7 +722,7 @@ LowerTypeVisitor::populateLayoutInformation(
     std::tie(memberAlignment, memberSize) = alignmentCalc.getAlignmentAndSize(
         fieldType, rule, /*isRowMajor*/ llvm::None, &stride);
 
-    // The next avaiable location after layouting the previos members
+    // The next avaiable location after laying out the previous members
     const uint32_t nextLoc = offset;
 
     if (rule == SpirvLayoutRule::RelaxedGLSLStd140 ||
@@ -719,6 +752,19 @@ LowerTypeVisitor::populateLayoutInformation(
         offset = packOffset;
       }
     }
+    // The :register(c#) annotation takes precedence over normal layout
+    // calculation.
+    else if (field.registerC) {
+      offset = 16 * field.registerC->RegisterNumber;
+      // Do minimal check to make sure the offset specified by :register(c#)
+      // does not cause overlap.
+      if (offset < nextLoc) {
+        emitError(
+            "found offset overlap when processing register(c%0) assignment",
+            field.registerC->Loc)
+            << field.registerC->RegisterNumber;
+      }
+    }
 
     // Each structure-type member must have an Offset Decoration.
     loweredField.offset = offset;
@@ -750,7 +796,11 @@ LowerTypeVisitor::populateLayoutInformation(
     loweredFields.push_back(loweredField);
   }
 
-  return loweredFields;
+  // Re-order the sorted fields back to their original order.
+  for (const auto &field : fields)
+    result.push_back(loweredFields[fieldToIndexMap[&field]]);
+
+  return result;
 }
 
 } // namespace spirv

+ 0 - 18
tools/clang/lib/SPIRV/SpirvType.cpp

@@ -227,24 +227,6 @@ HybridStructType::HybridStructType(
       fields(fieldsVec.begin(), fieldsVec.end()), readOnly(isReadOnly),
       interfaceType(iface) {}
 
-bool HybridStructType::FieldInfo::
-operator==(const HybridStructType::FieldInfo &that) const {
-  return astType == that.astType &&
-         // vkOffsetAttr may be nullptr. If not, should have the same offset.
-         (vkOffsetAttr == that.vkOffsetAttr ||
-          vkOffsetAttr->getOffset() == that.vkOffsetAttr->getOffset()) &&
-         // packOffsetAttr may be nullptr. If not, should have the same offset.
-         (packOffsetAttr == that.packOffsetAttr ||
-          (packOffsetAttr->Subcomponent == that.packOffsetAttr->Subcomponent &&
-           packOffsetAttr->ComponentOffset ==
-               that.packOffsetAttr->ComponentOffset));
-}
-
-bool HybridStructType::operator==(const HybridStructType &that) const {
-  return fields == that.fields && getName() == that.getName() &&
-         readOnly == that.readOnly && interfaceType == that.interfaceType;
-}
-
 FunctionType::FunctionType(const SpirvType *ret,
                            llvm::ArrayRef<const SpirvType *> param)
     : SpirvType(TK_Function), returnType(ret),

+ 38 - 0
tools/clang/test/CodeGenSPIRV/vk.layout.register-c.all.hlsl

@@ -0,0 +1,38 @@
+// Run: %dxc -T vs_6_0 -E main -fvk-use-dx-layout
+
+// CHECK: OpMemberDecorate %type__Globals 0 Offset 0
+// CHECK: OpMemberDecorate %type__Globals 1 Offset 16
+// CHECK: OpMemberDecorate %type__Globals 2 Offset 32
+// CHECK: OpMemberDecorate %type__Globals 3 Offset 64
+// CHECK: OpMemberDecorate %type__Globals 4 Offset 96
+// CHECK: OpMemberDecorate %type__Globals 5 Offset 144
+// CHECK: OpMemberDecorate %type__Globals 6 Offset 160
+// CHECK: OpMemberDecorate %type__Globals 7 Offset 192
+// CHECK: OpMemberDecorate %type__Globals 8 Offset 240
+// CHECK: OpMemberDecorate %type__Globals 9 Offset 288
+// CHECK: OpMemberDecorate %type__Globals 10 Offset 336
+
+float  x       : register(c0);  // Offset:   0   Size:  4
+float  y       : register(c1);  // Offset:  16   Size:  4
+float  z       : register(c2);  // Offset:  32   Size:  4
+float  w       : register(c4);  // Offset:  64   Size:  4
+float2 xy      : register(c6);  // Offset:  96   Size:  8
+float3 xyz     : register(c9);  // Offset: 144   Size: 12
+float4 xyzw    : register(c10); // Offset: 160   Size: 16
+float4 arr4[3] : register(c12); // Offset: 192   Size: 48
+float2 arr2[3] : register(c15); // Offset: 240   Size: 40
+float3 arr3[3] : register(c18); // Offset: 288   Size: 44
+float  s       : register(c21); // Offset: 336   Size:  4
+
+float4 main(float4 Pos : Position) : SV_Position
+{
+  float4 output = Pos;
+  output.x    += x + s;
+  output.y    += y;
+  output.z    += z;
+  output.w    += w;
+  output.xy   += xy + arr2[0];
+  output.xyz  += xyz + arr3[1];
+  output.xyzw += xyzw + arr4[2];
+  return output;
+}

+ 30 - 0
tools/clang/test/CodeGenSPIRV/vk.layout.register-c.error.hlsl

@@ -0,0 +1,30 @@
+// Run: %dxc -T vs_6_0 -E main -fvk-use-dx-layout
+
+// CHECK: 15:18: error: found offset overlap when processing register(c8) assignment
+// CHECK: 16:18: error: found offset overlap when processing register(c9) assignment
+// CHECK: 17:18: error: found offset overlap when processing register(c10) assignment
+
+float  x       : register(c0);
+float  y       : register(c1);
+float  z       : register(c2);
+float  w       : register(c3);
+float2 xy      : register(c4);
+float3 xyz     : register(c5);
+float4 xyzw    : register(c6);
+float4 arr4[3] : register(c7);
+float2 arr2[3] : register(c8);   // This should generate an overlap error with the previous line
+float3 arr3[3] : register(c9);   // This should generate an overlap error with the previous line
+float  s       : register(c10);  // This should generate an overlap error with the previous line
+
+float4 main(float4 Pos : Position) : SV_Position
+{
+  float4 output = Pos;
+  output.x    += x + s;
+  output.y    += y;
+  output.z    += z;
+  output.w    += w;
+  output.xy   += xy + arr2[0];
+  output.xyz  += xyz + arr3[1];
+  output.xyzw += xyzw + arr4[2];
+  return output;
+}

+ 38 - 0
tools/clang/test/CodeGenSPIRV/vk.layout.register-c.mixed.hlsl

@@ -0,0 +1,38 @@
+// Run: %dxc -T vs_6_0 -E main -fvk-use-dx-layout
+
+// CHECK: OpMemberDecorate %type__Globals 0 Offset 200
+// CHECK: OpMemberDecorate %type__Globals 1 Offset 0
+// CHECK: OpMemberDecorate %type__Globals 2 Offset 32
+// CHECK: OpMemberDecorate %type__Globals 3 Offset 16
+// CHECK: OpMemberDecorate %type__Globals 4 Offset 144
+// CHECK: OpMemberDecorate %type__Globals 5 Offset 48
+// CHECK: OpMemberDecorate %type__Globals 6 Offset 208
+// CHECK: OpMemberDecorate %type__Globals 7 Offset 64
+// CHECK: OpMemberDecorate %type__Globals 8 Offset 160
+// CHECK: OpMemberDecorate %type__Globals 9 Offset 224
+// CHECK: OpMemberDecorate %type__Globals 10 Offset 268
+
+float x                      ;  // Offset:  200   Size:     4
+float y        : register(c0);  // Offset:    0   Size:     4
+float z        : register(c2);  // Offset:   32   Size:     4
+float w        : register(c1);  // Offset:   16   Size:     4
+float2 xy      : register(c9);  // Offset:  144   Size:     8
+float3 xyz     : register(c3);  // Offset:   48   Size:    12
+float4 xyzw                  ;  // Offset:  208   Size:    16
+float4 arr4[3] : register(c4);  // Offset:   64   Size:    48
+float2 arr2[3] : register(c10); // Offset:  160   Size:    40
+float3 arr3[3]               ;  // Offset:  224   Size:    44
+float s                      ;  // Offset:  268   Size:     4
+
+float4 main(float4 Pos : Position) : SV_Position
+{
+  float4 output = Pos;
+  output.x    += x + s;
+  output.y    += y;
+  output.z    += z;
+  output.w    += w;
+  output.xy   += xy + arr2[0];
+  output.xyz  += xyz + arr3[1];
+  output.xyzw += xyzw + arr4[2];
+  return output;
+}

+ 18 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -1637,6 +1637,24 @@ TEST_F(FileTest, VulkanLayoutCBufferScalar) {
   runFileTest("vk.layout.cbuffer.scalar.hlsl");
 }
 
+TEST_F(FileTest, VulkanLayoutRegisterCAll) {
+  // :register(c#) used on all global variables.
+  setDxLayout();
+  runFileTest("vk.layout.register-c.all.hlsl");
+}
+
+TEST_F(FileTest, VulkanLayoutRegisterCMixed) {
+  // :register(c#) used only on some global variables.
+  setDxLayout();
+  runFileTest("vk.layout.register-c.mixed.hlsl");
+}
+
+TEST_F(FileTest, VulkanLayoutRegisterCError) {
+  // :register(c#) causing offset overlap for global variables.
+  setDxLayout();
+  runFileTest("vk.layout.register-c.error.hlsl", Expect::Failure);
+}
+
 TEST_F(FileTest, VulkanSubpassInput) { runFileTest("vk.subpass-input.hlsl"); }
 TEST_F(FileTest, VulkanSubpassInputBinding) {
   runFileTest("vk.subpass-input.binding.hlsl");