2
0
Эх сурвалжийг харах

[spirv] Support assigning arrays as a whole (#1038)

For array wholesale assignments, the rhs will be wrapped in a
FlatConversion implicit cast. This cast can be ignored since it
does not affect CodeGen.
Lei Zhang 7 жил өмнө
parent
commit
d90f31f177

+ 15 - 49
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -60,52 +60,6 @@ bool patchConstFuncTakesHullOutputPatch(FunctionDecl *pcf) {
 
 // TODO: Maybe we should move these type probing functions to TypeTranslator.
 
-/// Returns true if the two types can be treated as the same scalar type:
-/// * Having the same canonical type
-/// * Literal vs no-literal
-bool canTreatAsSameScalarType(QualType type1, QualType type2) {
-  return (type1.getCanonicalType() == type2.getCanonicalType()) ||
-         // Treat 'literal float' and 'float' as the same
-         (type1->isSpecificBuiltinType(BuiltinType::LitFloat) &&
-          type2->isFloatingType()) ||
-         (type2->isSpecificBuiltinType(BuiltinType::LitFloat) &&
-          type1->isFloatingType()) ||
-         // Treat 'literal int' and 'int'/'uint' as the same
-         (type1->isSpecificBuiltinType(BuiltinType::LitInt) &&
-          type2->isIntegerType() &&
-          // Disallow boolean types
-          !type2->isSpecificBuiltinType(BuiltinType::Bool)) ||
-         (type2->isSpecificBuiltinType(BuiltinType::LitInt) &&
-          type1->isIntegerType() &&
-          // Disallow boolean types
-          !type1->isSpecificBuiltinType(BuiltinType::Bool));
-}
-
-/// Returns true if the two types are the same scalar or vector type,
-/// disregarding constness and literalness.
-bool isSameScalarOrVecType(QualType type1, QualType type2) {
-  // Consider cases such as 'const bool' and 'bool' to be the same type.
-  type1.removeLocalConst();
-  type2.removeLocalConst();
-
-  {
-    QualType scalarType1 = {}, scalarType2 = {};
-    if (TypeTranslator::isScalarType(type1, &scalarType1) &&
-        TypeTranslator::isScalarType(type2, &scalarType2))
-      return canTreatAsSameScalarType(scalarType1, scalarType2);
-  }
-
-  {
-    QualType elemType1 = {}, elemType2 = {};
-    uint32_t count1 = {}, count2 = {};
-    if (TypeTranslator::isVectorType(type1, &elemType1, &count1) &&
-        TypeTranslator::isVectorType(type2, &elemType2, &count2))
-      return count1 == count2 && canTreatAsSameScalarType(elemType1, elemType2);
-  }
-
-  return false;
-}
-
 /// Returns true if the given type is a bool or vector of bool type.
 bool isBoolOrVecOfBoolType(QualType type) {
   QualType elemType = {};
@@ -2150,6 +2104,18 @@ SpirvEvalInfo SPIRVEmitter::doCastExpr(const CastExpr *expr) {
       if (subExprId)
         evalType = isSigned ? astContext.IntTy : astContext.UnsignedIntTy;
     }
+    // For assigning one array instance to another one with the same array type
+    // (regardless of constness and literalness), the rhs will be wrapped in a
+    // FlatConversion:
+    //  |- <lhs>
+    //  `- ImplicitCastExpr <FlatConversion>
+    //     `- ImplicitCastExpr <LValueToRValue>
+    //        `- <rhs>
+    // This FlatConversion does not affect CodeGen, so that we can ignore it.
+    else if (subExprType->isArrayType() &&
+             typeTranslator.isSameType(expr->getType(), subExprType)) {
+      return doExpr(subExpr);
+    }
 
     if (!subExprId)
       subExprId = doExpr(subExpr);
@@ -5471,7 +5437,7 @@ SpirvEvalInfo &SPIRVEmitter::turnIntoElementPtr(
 
 uint32_t SPIRVEmitter::castToBool(const uint32_t fromVal, QualType fromType,
                                   QualType toBoolType) {
-  if (isSameScalarOrVecType(fromType, toBoolType))
+  if (TypeTranslator::isSameScalarOrVecType(fromType, toBoolType))
     return fromVal;
 
   // Converting to bool means comparing with value zero.
@@ -5484,7 +5450,7 @@ uint32_t SPIRVEmitter::castToBool(const uint32_t fromVal, QualType fromType,
 
 uint32_t SPIRVEmitter::castToInt(const uint32_t fromVal, QualType fromType,
                                  QualType toIntType, SourceLocation srcLoc) {
-  if (isSameScalarOrVecType(fromType, toIntType))
+  if (TypeTranslator::isSameScalarOrVecType(fromType, toIntType))
     return fromVal;
 
   uint32_t intType = typeTranslator.translateType(toIntType);
@@ -5518,7 +5484,7 @@ uint32_t SPIRVEmitter::castToInt(const uint32_t fromVal, QualType fromType,
 uint32_t SPIRVEmitter::castToFloat(const uint32_t fromVal, QualType fromType,
                                    QualType toFloatType,
                                    SourceLocation srcLoc) {
-  if (isSameScalarOrVecType(fromType, toFloatType))
+  if (TypeTranslator::isSameScalarOrVecType(fromType, toFloatType))
     return fromVal;
 
   const uint32_t floatType = typeTranslator.translateType(toFloatType);

+ 70 - 0
tools/clang/lib/SPIRV/TypeTranslator.cpp

@@ -758,6 +758,76 @@ bool TypeTranslator::isSpirvAcceptableMatrixType(QualType type) {
   return isMxNMatrix(type, &elemType) && elemType->isFloatingType();
 }
 
+bool TypeTranslator::canTreatAsSameScalarType(QualType type1, QualType type2) {
+  // Treat const int/float the same as const int/float
+  type1.removeLocalConst();
+  type2.removeLocalConst();
+
+  return (type1.getCanonicalType() == type2.getCanonicalType()) ||
+         // Treat 'literal float' and 'float' as the same
+         (type1->isSpecificBuiltinType(BuiltinType::LitFloat) &&
+          type2->isFloatingType()) ||
+         (type2->isSpecificBuiltinType(BuiltinType::LitFloat) &&
+          type1->isFloatingType()) ||
+         // Treat 'literal int' and 'int'/'uint' as the same
+         (type1->isSpecificBuiltinType(BuiltinType::LitInt) &&
+          type2->isIntegerType() &&
+          // Disallow boolean types
+          !type2->isSpecificBuiltinType(BuiltinType::Bool)) ||
+         (type2->isSpecificBuiltinType(BuiltinType::LitInt) &&
+          type1->isIntegerType() &&
+          // Disallow boolean types
+          !type1->isSpecificBuiltinType(BuiltinType::Bool));
+}
+
+bool TypeTranslator::isSameScalarOrVecType(QualType type1, QualType type2) {
+  { // Scalar types
+    QualType scalarType1 = {}, scalarType2 = {};
+    if (TypeTranslator::isScalarType(type1, &scalarType1) &&
+        TypeTranslator::isScalarType(type2, &scalarType2))
+      return canTreatAsSameScalarType(scalarType1, scalarType2);
+  }
+
+  { // Vector types
+    QualType elemType1 = {}, elemType2 = {};
+    uint32_t count1 = {}, count2 = {};
+    if (TypeTranslator::isVectorType(type1, &elemType1, &count1) &&
+        TypeTranslator::isVectorType(type2, &elemType2, &count2))
+      return count1 == count2 && canTreatAsSameScalarType(elemType1, elemType2);
+  }
+
+  return false;
+}
+
+bool TypeTranslator::isSameType(QualType type1, QualType type2) {
+  if (isSameScalarOrVecType(type1, type2))
+    return true;
+
+  type1.removeLocalConst();
+  type2.removeLocalConst();
+
+  { // Matrix types
+    QualType elemType1 = {}, elemType2 = {};
+    uint32_t row1 = 0, row2 = 0, col1 = 0, col2 = 0;
+    if (TypeTranslator::isMxNMatrix(type1, &elemType1, &row1, &col1) &&
+        TypeTranslator::isMxNMatrix(type2, &elemType2, &row2, &col2))
+      return row1 == row2 && col1 == col2 &&
+             canTreatAsSameScalarType(elemType1, elemType2);
+  }
+
+  { // Array types
+    if (const auto *arrType1 = astContext.getAsConstantArrayType(type1))
+      if (const auto *arrType2 = astContext.getAsConstantArrayType(type2))
+        return hlsl::GetArraySize(type1) == hlsl::GetArraySize(type2) &&
+               isSameType(arrType1->getElementType(),
+                          arrType2->getElementType());
+  }
+
+  // TODO: support other types if needed
+
+  return false;
+}
+
 QualType TypeTranslator::getElementType(QualType type) {
   QualType elemType = {};
   (void)(isScalarType(type, &elemType) || isVectorType(type, &elemType) ||

+ 13 - 0
tools/clang/lib/SPIRV/TypeTranslator.h

@@ -173,6 +173,14 @@ public:
   /// counts.
   static bool isSpirvAcceptableMatrixType(QualType type);
 
+  /// \brief Returns true if the two types are the same scalar or vector type,
+  /// regardless of constness and literalness.
+  static bool isSameScalarOrVecType(QualType type1, QualType type2);
+
+  /// \brief Returns true if the two types are the same type, regardless of
+  /// constness and literalness.
+  bool isSameType(QualType type1, QualType type2);
+
   /// \brief Returns true if the given type can use relaxed precision
   /// decoration. Integer and float types with lower than 32 bits can be
   /// operated on with a relaxed precision.
@@ -227,6 +235,11 @@ private:
     return diags.Report(diagId);
   }
 
+  /// \brief Returns true if the two types can be treated as the same scalar
+  /// type, which means they have the same canonical type, regardless of
+  /// constnesss and literalness.
+  static bool canTreatAsSameScalarType(QualType type1, QualType type2);
+
   /// \brief Translates the given HLSL resource type into its SPIR-V
   /// instructions and returns the <result-id>. Returns 0 on failure.
   uint32_t translateResourceType(QualType type, LayoutRule rule);

+ 46 - 0
tools/clang/test/CodeGenSPIRV/cast.flat-conversion.no-op.hlsl

@@ -0,0 +1,46 @@
+// Run: %dxc -T ps_6_0 -E main
+
+cbuffer Data {
+    float    gScalars[1];
+    float4   gVecs[2];
+    float2x3 gMats[1];
+}
+
+struct T {
+    float    scalars[1];
+    float4   vecs[2];
+    float2x3 mats[1];
+};
+
+float4 main() : SV_Target {
+    T t;
+
+// CHECK:        [[gscalars_ptr:%\d+]] = OpAccessChain %_ptr_Uniform__arr_float_uint_1 %var_Data %int_0
+// CHECK-NEXT:   [[gscalars_val:%\d+]] = OpLoad %_arr_float_uint_1 [[gscalars_ptr]]
+// CHECK-NEXT:    [[scalars_ptr:%\d+]] = OpAccessChain %_ptr_Function__arr_float_uint_1_0 %t %int_0
+// CHECK-NEXT:      [[gscalars0:%\d+]] = OpCompositeExtract %float [[gscalars_val]] 0
+// CHECK-NEXT:   [[scalars0_ptr:%\d+]] = OpAccessChain %_ptr_Function_float [[scalars_ptr]] %uint_0
+// CHECK-NEXT:                           OpStore [[scalars0_ptr]] [[gscalars0]]
+    t.scalars = gScalars;
+
+// CHECK-NEXT: [[gvecs_ptr:%\d+]] = OpAccessChain %_ptr_Uniform__arr_v4float_uint_2 %var_Data %int_1
+// CHECK-NEXT: [[gvecs_val:%\d+]] = OpLoad %_arr_v4float_uint_2 [[gvecs_ptr]]
+// CHECK-NEXT:  [[vecs_ptr:%\d+]] = OpAccessChain %_ptr_Function__arr_v4float_uint_2_0 %t %int_1
+// CHECK-NEXT:    [[gvecs0:%\d+]] = OpCompositeExtract %v4float [[gvecs_val]] 0
+// CHECK-NEXT: [[vecs0_ptr:%\d+]] = OpAccessChain %_ptr_Function_v4float [[vecs_ptr]] %uint_0
+// CHECK-NEXT:                      OpStore [[vecs0_ptr]] [[gvecs0]]
+// CHECK-NEXT:    [[gvecs1:%\d+]] = OpCompositeExtract %v4float [[gvecs_val]] 1
+// CHECK-NEXT: [[vecs1_ptr:%\d+]] = OpAccessChain %_ptr_Function_v4float [[vecs_ptr]] %uint_1
+// CHECK-NEXT:                      OpStore [[vecs1_ptr]] [[gvecs1]]
+    t.vecs    = gVecs;
+
+// CHECK-NEXT: [[gmats_ptr:%\d+]] = OpAccessChain %_ptr_Uniform__arr_mat2v3float_uint_1 %var_Data %int_2
+// CHECK-NEXT: [[gmats_val:%\d+]] = OpLoad %_arr_mat2v3float_uint_1 [[gmats_ptr]]
+// CHECK-NEXT:  [[mats_ptr:%\d+]] = OpAccessChain %_ptr_Function__arr_mat2v3float_uint_1_0 %t %int_2
+// CHECK-NEXT:    [[gmats0:%\d+]] = OpCompositeExtract %mat2v3float [[gmats_val]] 0
+// CHECK-NEXT: [[mats0_ptr:%\d+]] = OpAccessChain %_ptr_Function_mat2v3float [[mats_ptr]] %uint_0
+// CHECK-NEXT:                      OpStore [[mats0_ptr]] [[gmats0]]
+    t.mats    = gMats;
+
+    return t.vecs[1];
+}

+ 3 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -323,6 +323,9 @@ TEST_F(FileTest, CastImplicitFlatConversion) {
 TEST_F(FileTest, CastFlatConversionStruct) {
   runFileTest("cast.flat-conversion.struct.hlsl");
 }
+TEST_F(FileTest, CastFlatConversionNoOp) {
+  runFileTest("cast.flat-conversion.no-op.hlsl");
+}
 TEST_F(FileTest, CastExplicitVecToMat) {
   runFileTest("cast.vec-to-mat.explicit.hlsl");
 }