2
0
Эх сурвалжийг харах

[spirv] Fix location assignment for composite stage IO types (#1098)

Scalars and vectors typically take only one location. But matrices
and composite types take multiple sequential locations.
Lei Zhang 7 жил өмнө
parent
commit
7b7510b2c5

+ 19 - 9
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -799,12 +799,19 @@ public:
   /// Uses the given location.
   /// Uses the given location.
   void useLoc(uint32_t loc) { usedLocs.set(loc); }
   void useLoc(uint32_t loc) { usedLocs.set(loc); }
 
 
-  /// Uses the next available location.
-  uint32_t useNextLoc() {
+  /// Uses the next |count| available location.
+  int useNextLocs(uint32_t count) {
     while (usedLocs[nextLoc])
     while (usedLocs[nextLoc])
       nextLoc++;
       nextLoc++;
-    usedLocs.set(nextLoc);
-    return nextLoc++;
+
+    int toUse = nextLoc;
+
+    for (uint32_t i = 0; i < count; ++i) {
+      assert(!usedLocs[nextLoc]);
+      usedLocs.set(nextLoc++);
+    }
+
+    return toUse;
   }
   }
 
 
   /// Returns true if the given location number is already used.
   /// Returns true if the given location number is already used.
@@ -982,7 +989,8 @@ bool DeclResultIdMapper::finalizeStageIOLocations(bool forInput) {
   }
   }
 
 
   for (const auto *var : vars)
   for (const auto *var : vars)
-    theBuilder.decorateLocation(var->getSpirvId(), locSet.useNextLoc());
+    theBuilder.decorateLocation(var->getSpirvId(),
+                                locSet.useNextLocs(var->getLocationCount()));
 
 
   return true;
   return true;
 }
 }
@@ -1263,9 +1271,11 @@ bool DeclResultIdMapper::createStageVars(const hlsl::SigPoint *sigPoint,
       typeId = theBuilder.getArrayType(typeId,
       typeId = theBuilder.getArrayType(typeId,
                                        theBuilder.getConstantUint32(arraySize));
                                        theBuilder.getConstantUint32(arraySize));
 
 
-    StageVar stageVar(sigPoint, semanticToUse->str, semanticToUse->semantic,
-                      semanticToUse->name, semanticToUse->index, builtinAttr,
-                      typeId);
+    StageVar stageVar(
+        sigPoint, semanticToUse->str, semanticToUse->semantic,
+        semanticToUse->name, semanticToUse->index, builtinAttr, typeId,
+        // For HS/DS/GS, we have already stripped the outmost arrayness on type.
+        typeTranslator.getLocationCount(type));
     const auto name = namePrefix.str() + "." + stageVar.getSemanticStr();
     const auto name = namePrefix.str() + "." + stageVar.getSemanticStr();
     const uint32_t varId =
     const uint32_t varId =
         createSpirvStageVar(&stageVar, decl, name, semanticToUse->loc);
         createSpirvStageVar(&stageVar, decl, name, semanticToUse->loc);
@@ -1713,7 +1723,7 @@ uint32_t DeclResultIdMapper::getBuiltinVar(spv::BuiltIn builtIn) {
 
 
   StageVar stageVar(sigPoint, /*semaStr=*/"", hlsl::Semantic::GetInvalid(),
   StageVar stageVar(sigPoint, /*semaStr=*/"", hlsl::Semantic::GetInvalid(),
                     /*semaName=*/"", /*semaIndex=*/0, /*builtinAttr=*/nullptr,
                     /*semaName=*/"", /*semaIndex=*/0, /*builtinAttr=*/nullptr,
-                    type);
+                    type, /*locCount=*/0);
 
 
   stageVar.setIsSpirvBuiltin();
   stageVar.setIsSpirvBuiltin();
   stageVar.setSpirvId(varId);
   stageVar.setSpirvId(varId);

+ 9 - 3
tools/clang/lib/SPIRV/DeclResultIdMapper.h

@@ -38,11 +38,13 @@ public:
   inline StageVar(const hlsl::SigPoint *sig, llvm::StringRef semaStr,
   inline StageVar(const hlsl::SigPoint *sig, llvm::StringRef semaStr,
                   const hlsl::Semantic *sema, llvm::StringRef semaName,
                   const hlsl::Semantic *sema, llvm::StringRef semaName,
                   uint32_t semaIndex, const VKBuiltInAttr *builtin,
                   uint32_t semaIndex, const VKBuiltInAttr *builtin,
-                  uint32_t type)
+
+                  uint32_t type, uint32_t locCount)
       : sigPoint(sig), semanticStr(semaStr), semantic(sema),
       : sigPoint(sig), semanticStr(semaStr), semantic(sema),
         semanticName(semaName), semanticIndex(semaIndex), builtinAttr(builtin),
         semanticName(semaName), semanticIndex(semaIndex), builtinAttr(builtin),
         typeId(type), valueId(0), isBuiltin(false),
         typeId(type), valueId(0), isBuiltin(false),
-        storageClass(spv::StorageClass::Max), location(nullptr) {
+        storageClass(spv::StorageClass::Max), location(nullptr),
+        locationCount(locCount) {
     isBuiltin = builtinAttr != nullptr;
     isBuiltin = builtinAttr != nullptr;
   }
   }
 
 
@@ -68,6 +70,8 @@ public:
   const VKLocationAttr *getLocationAttr() const { return location; }
   const VKLocationAttr *getLocationAttr() const { return location; }
   void setLocationAttr(const VKLocationAttr *loc) { location = loc; }
   void setLocationAttr(const VKLocationAttr *loc) { location = loc; }
 
 
+  uint32_t getLocationCount() const { return locationCount; }
+
 private:
 private:
   /// HLSL SigPoint. It uniquely identifies each set of parameters that may be
   /// HLSL SigPoint. It uniquely identifies each set of parameters that may be
   /// input or output for each entry point.
   /// input or output for each entry point.
@@ -92,6 +96,8 @@ private:
   spv::StorageClass storageClass;
   spv::StorageClass storageClass;
   /// Location assignment if input/output variable.
   /// Location assignment if input/output variable.
   const VKLocationAttr *location;
   const VKLocationAttr *location;
+  /// How many locations this stage variable takes.
+  uint32_t locationCount;
 };
 };
 
 
 class ResourceVar {
 class ResourceVar {
@@ -562,7 +568,7 @@ private:
   /// the children of this decl, and the children of this decl will be using
   /// the children of this decl, and the children of this decl will be using
   /// the semantic in inheritSemantic, with index increasing sequentially.
   /// the semantic in inheritSemantic, with index increasing sequentially.
   bool createStageVars(const hlsl::SigPoint *sigPoint, const NamedDecl *decl,
   bool createStageVars(const hlsl::SigPoint *sigPoint, const NamedDecl *decl,
-                       bool asInput, QualType type, uint32_t arraySize,
+                       bool asInput, QualType asType, uint32_t arraySize,
                        const llvm::StringRef namePrefix,
                        const llvm::StringRef namePrefix,
                        llvm::Optional<uint32_t> invocationId, uint32_t *value,
                        llvm::Optional<uint32_t> invocationId, uint32_t *value,
                        bool noWriteBack, SemanticInfo *inheritSemantic);
                        bool noWriteBack, SemanticInfo *inheritSemantic);

+ 84 - 0
tools/clang/lib/SPIRV/TypeTranslator.cpp

@@ -203,6 +203,90 @@ void TypeTranslator::popIntendedLiteralType() {
   intendedLiteralTypes.pop();
   intendedLiteralTypes.pop();
 }
 }
 
 
+uint32_t TypeTranslator::getLocationCount(QualType type) {
+  // See Vulkan spec 14.1.4. Location Assignment for the complete set of rules.
+
+  const auto canonicalType = type.getCanonicalType();
+  if (canonicalType != type)
+    return getLocationCount(canonicalType);
+
+  // Inputs and outputs of the following types consume a single interface
+  // location:
+  // * 16-bit scalar and vector types, and
+  // * 32-bit scalar and vector types, and
+  // * 64-bit scalar and 2-component vector types.
+
+  // 64-bit three- and four- component vectors consume two consecutive
+  // locations.
+
+  // Primitive types
+  if (isScalarType(type))
+    return 1;
+
+  // Vector types
+  {
+    QualType elemType = {};
+    uint32_t elemCount = {};
+    if (isVectorType(type, &elemType, &elemCount)) {
+      const auto *builtinType = elemType->getAs<BuiltinType>();
+      switch (builtinType->getKind()) {
+      case BuiltinType::Double:
+      case BuiltinType::LongLong:
+      case BuiltinType::ULongLong:
+        if (elemCount >= 3)
+          return 2;
+      }
+      return 1;
+    }
+  }
+
+  // If the declared input or output is an n * m 16- , 32- or 64- bit matrix,
+  // it will be assigned multiple locations starting with the location
+  // specified. The number of locations assigned for each matrix will be the
+  // same as for an n-element array of m-component vectors.
+
+  // Matrix types
+  {
+    QualType elemType = {};
+    uint32_t rowCount = 0, colCount = 0;
+    if (isMxNMatrix(type, &elemType, &rowCount, &colCount))
+      return getLocationCount(astContext.getExtVectorType(elemType, colCount)) *
+             rowCount;
+  }
+
+  // Typedefs
+  if (const auto *typedefType = type->getAs<TypedefType>())
+    return getLocationCount(typedefType->desugar());
+
+  // Reference types
+  if (const auto *refType = type->getAs<ReferenceType>())
+    return getLocationCount(refType->getPointeeType());
+
+  // Pointer types
+  if (const auto *ptrType = type->getAs<PointerType>())
+    return getLocationCount(ptrType->getPointeeType());
+
+  // If a declared input or output is an array of size n and each element takes
+  // m locations, it will be assigned m * n consecutive locations starting with
+  // the location specified.
+
+  // Array types
+  if (const auto *arrayType = astContext.getAsConstantArrayType(type))
+    return getLocationCount(arrayType->getElementType()) *
+           static_cast<uint32_t>(arrayType->getSize().getZExtValue());
+
+  // Struct type
+  if (const auto *structType = type->getAs<RecordType>()) {
+    assert(false && "all structs should already be flattened");
+    return 0;
+  }
+
+  emitError(
+      "calculating number of occupied locations for type %0 unimplemented")
+      << type;
+  return 0;
+}
+
 uint32_t TypeTranslator::translateType(QualType type, LayoutRule rule,
 uint32_t TypeTranslator::translateType(QualType type, LayoutRule rule,
                                        bool isRowMajor) {
                                        bool isRowMajor) {
   // We can only apply row_major to matrices or arrays of matrices.
   // We can only apply row_major to matrices or arrays of matrices.

+ 3 - 0
tools/clang/lib/SPIRV/TypeTranslator.h

@@ -221,6 +221,9 @@ public:
   llvm::SmallVector<const Decoration *, 4>
   llvm::SmallVector<const Decoration *, 4>
   getLayoutDecorations(const DeclContext *decl, LayoutRule rule);
   getLayoutDecorations(const DeclContext *decl, LayoutRule rule);
 
 
+  /// \brief Returns how many sequential locations are consumed by a given type.
+  uint32_t getLocationCount(QualType type);
+
 private:
 private:
   /// \brief Wrapper method to create an error message and report it
   /// \brief Wrapper method to create an error message and report it
   /// in the diagnostic engine associated with this consumer.
   /// in the diagnostic engine associated with this consumer.

+ 4 - 4
tools/clang/test/CodeGenSPIRV/bezier.domain.hlsl2spv

@@ -107,10 +107,10 @@ DS_OUTPUT BezierEvalDS( HS_CONSTANT_DATA_OUTPUT input,
 // OpDecorate %gl_TessCoord Patch
 // OpDecorate %gl_TessCoord Patch
 // OpDecorate %in_var_BEZIERPOS Location 0
 // OpDecorate %in_var_BEZIERPOS Location 0
 // OpDecorate %in_var_TANGENT Location 1
 // OpDecorate %in_var_TANGENT Location 1
-// OpDecorate %in_var_TANUCORNER Location 2
-// OpDecorate %in_var_TANVCORNER Location 3
-// OpDecorate %in_var_TANWEIGHTS Location 4
-// OpDecorate %in_var_TEXCOORD Location 5
+// OpDecorate %in_var_TANUCORNER Location 5
+// OpDecorate %in_var_TANVCORNER Location 9
+// OpDecorate %in_var_TANWEIGHTS Location 13
+// OpDecorate %in_var_TEXCOORD Location 14
 // OpDecorate %out_var_NORMAL Location 0
 // OpDecorate %out_var_NORMAL Location 0
 // OpDecorate %out_var_TEXCOORD Location 1
 // OpDecorate %out_var_TEXCOORD Location 1
 // OpDecorate %out_var_TANGENT Location 2
 // OpDecorate %out_var_TANGENT Location 2

+ 4 - 4
tools/clang/test/CodeGenSPIRV/bezier.hull.hlsl2spv

@@ -129,10 +129,10 @@ BEZIER_CONTROL_POINT SubDToBezierHS(InputPatch<VS_CONTROL_POINT_OUTPUT, MAX_POIN
 // OpDecorate %in_var_TANGENT Location 2
 // OpDecorate %in_var_TANGENT Location 2
 // OpDecorate %out_var_BEZIERPOS Location 0
 // OpDecorate %out_var_BEZIERPOS Location 0
 // OpDecorate %out_var_TANGENT Location 1
 // OpDecorate %out_var_TANGENT Location 1
-// OpDecorate %out_var_TANUCORNER Location 2
-// OpDecorate %out_var_TANVCORNER Location 3
-// OpDecorate %out_var_TANWEIGHTS Location 4
-// OpDecorate %out_var_TEXCOORD Location 5
+// OpDecorate %out_var_TANUCORNER Location 5
+// OpDecorate %out_var_TANVCORNER Location 9
+// OpDecorate %out_var_TANWEIGHTS Location 13
+// OpDecorate %out_var_TEXCOORD Location 14
 // %void = OpTypeVoid
 // %void = OpTypeVoid
 // %3 = OpTypeFunction %void
 // %3 = OpTypeFunction %void
 // %float = OpTypeFloat 32
 // %float = OpTypeFloat 32

+ 1 - 1
tools/clang/test/CodeGenSPIRV/spirv.interface.hs.hlsl

@@ -90,7 +90,7 @@ struct HsPcfOut
 // CHECK: OpDecorate %out_var_BAR Location 0
 // CHECK: OpDecorate %out_var_BAR Location 0
 // CHECK: OpDecorate %out_var_FOO Location 1
 // CHECK: OpDecorate %out_var_FOO Location 1
 // CHECK: OpDecorate %out_var_TEXCOORD Location 2
 // CHECK: OpDecorate %out_var_TEXCOORD Location 2
-// CHECK: OpDecorate %out_var_WEIGHT Location 3
+// CHECK: OpDecorate %out_var_WEIGHT Location 6
 
 
 // Input : clip0 + clip2         : 3 floats
 // Input : clip0 + clip2         : 3 floats
 // Input : cull3 + cull5         : 4 floats
 // Input : cull3 + cull5         : 4 floats

+ 46 - 0
tools/clang/test/CodeGenSPIRV/vk.location.composite.hlsl

@@ -0,0 +1,46 @@
+// Run: %dxc -T vs_6_0 -E main
+
+// CHECK: OpDecorate %in_var_A Location 0
+// CHECK: OpDecorate %in_var_B Location 1
+// CHECK: OpDecorate %in_var_C Location 2
+// CHECK: OpDecorate %in_var_D Location 4
+// CHECK: OpDecorate %in_var_E Location 6
+// CHECK: OpDecorate %in_var_F Location 8
+// CHECK: OpDecorate %in_var_G Location 16
+
+// CHECK: OpDecorate %out_var_A Location 0
+// CHECK: OpDecorate %out_var_B Location 2
+// CHECK: OpDecorate %out_var_C Location 3
+// CHECK: OpDecorate %out_var_D Location 4
+// CHECK: OpDecorate %out_var_E Location 5
+// CHECK: OpDecorate %out_var_F Location 11
+// CHECK: OpDecorate %out_var_G Location 13
+// CHECK: OpDecorate %out_var_H Location 14
+
+struct S {
+    half2x3  matrix2x3 : A; // 0 (+2)
+    float1x2 vector1x2 : B; // 2 (+1)
+    float3x1 vector3x1 : C; // 3 (+1)
+    float1x1 scalar1x1 : D; // 4 (+1)
+};
+
+struct T {
+    S        s;
+    float2x3 array1[3] : E; // 5  (+2*3)
+    half1x2  array2[2] : F; // 11 (+1*2)
+    half3x1  array3[1] : G; // 13 (+1*1)
+    float    array4[4] : H; // 14 (+1*4)
+};
+
+T main(
+    double    a   : A, // 0  (+1)
+    double2   b   : B, // 1  (+1)
+    double3   c   : C, // 2  (+2)
+    double4   d   : D, // 4  (+2)
+    double2x2 e   : E, // 6  (+1*2)
+    double2x3 f[2]: F, // 8  (+2*2*2)
+    double2x3 g   : G  // 16 (+2x2)
+) {
+    T t = (T)0;
+    return t;
+}

+ 3 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -1097,6 +1097,9 @@ TEST_F(FileTest, VulkanLocationInputExplicitOutputImplicit) {
 TEST_F(FileTest, VulkanLocationInputImplicitOutputExplicit) {
 TEST_F(FileTest, VulkanLocationInputImplicitOutputExplicit) {
   runFileTest("vk.location.exp-out.hlsl");
   runFileTest("vk.location.exp-out.hlsl");
 }
 }
+TEST_F(FileTest, VulkanLocationCompositeTypes) {
+  runFileTest("vk.location.composite.hlsl");
+}
 TEST_F(FileTest, VulkanLocationTooLarge) {
 TEST_F(FileTest, VulkanLocationTooLarge) {
   runFileTest("vk.location.large.hlsl", Expect::Failure);
   runFileTest("vk.location.large.hlsl", Expect::Failure);
 }
 }