Ver Fonte

[spirv] Translate FlatConversion implicit cast. (#750)

Ehsan há 7 anos atrás
pai
commit
529f6ee6bf

+ 122 - 2
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -1480,14 +1480,134 @@ SpirvEvalInfo SPIRVEmitter::doCastExpr(const CastExpr *expr) {
   case CastKind::CK_FunctionToPointerDecay:
     // Just need to return the function id
     return doExpr(subExpr);
+  case CastKind::CK_FlatConversion: {
+    // Optimization: we can use OpConstantNull for cases where we want to
+    // initialize an entire data structure to zeros.
+    llvm::APSInt intValue;
+    if (subExpr->EvaluateAsInt(intValue, astContext, Expr::SE_NoSideEffects) &&
+        intValue.getExtValue() == 0) {
+      return theBuilder.getConstantNull(typeTranslator.translateType(toType));
+    } else {
+      return processFlatConversion(toType, subExpr->getType(), doExpr(subExpr));
+    }
+  }
   default:
-    emitError("ImplictCast Kind '%0' is not supported yet.")
-        << expr->getCastKindName();
+    emitError("ImplictCast Kind '%0' is not supported yet.",
+              expr->getLocStart())
+        << expr->getCastKindName() << expr->getSourceRange();
     expr->dump();
     return 0;
   }
 }
 
+uint32_t SPIRVEmitter::processFlatConversion(const QualType type,
+                                             const QualType initType,
+                                             const uint32_t initId) {
+  // Try to translate the canonical type first
+  const auto canonicalType = type.getCanonicalType();
+  if (canonicalType != type)
+    return processFlatConversion(canonicalType, initType, initId);
+
+  // Primitive types
+  {
+    QualType ty = {};
+    if (TypeTranslator::isScalarType(type, &ty)) {
+      if (const auto *builtinType = ty->getAs<BuiltinType>()) {
+        switch (builtinType->getKind()) {
+        case BuiltinType::Void: {
+          emitError("cannot create a constant of void type");
+          return 0;
+        }
+        case BuiltinType::Bool:
+          return castToBool(initId, initType, ty);
+          // int, min16int (short), and min12int are all translated to 32-bit
+          // signed integers in SPIR-V.
+        case BuiltinType::Int:
+        case BuiltinType::Short:
+        case BuiltinType::Min12Int:
+        case BuiltinType::UShort:
+        case BuiltinType::UInt:
+          return castToInt(initId, initType, ty);
+          // float, min16float (half), and min10float are all translated to
+          // 32-bit float in SPIR-V.
+        case BuiltinType::Float:
+        case BuiltinType::Half:
+        case BuiltinType::Min10Float:
+          return castToFloat(initId, initType, ty);
+        default:
+          emitError("flat conversion of type %0 unimplemented")
+              << builtinType->getTypeClassName();
+          return 0;
+        }
+      }
+    }
+  }
+  // Vector types
+  {
+    QualType elemType = {};
+    uint32_t elemCount = {};
+    if (TypeTranslator::isVectorType(type, &elemType, &elemCount)) {
+      const uint32_t elemId = processFlatConversion(elemType, initType, initId);
+      llvm::SmallVector<uint32_t, 4> constituents(size_t(elemCount), elemId);
+      return theBuilder.createCompositeConstruct(
+          typeTranslator.translateType(type), constituents);
+    }
+  }
+
+  // Matrix types
+  {
+    QualType elemType = {};
+    uint32_t rowCount = 0, colCount = 0;
+    if (TypeTranslator::isMxNMatrix(type, &elemType, &rowCount, &colCount)) {
+      if (!elemType->isFloatingType()) {
+        emitError("non-floating-point matrix type unimplemented");
+        return 0;
+      }
+
+      // By default HLSL matrices are row major, while SPIR-V matrices are
+      // column major. We are mapping what HLSL semantically mean a row into a
+      // column here.
+      const uint32_t vecType = theBuilder.getVecType(
+          typeTranslator.translateType(elemType), colCount);
+      const uint32_t elemId = processFlatConversion(elemType, initType, initId);
+      const llvm::SmallVector<uint32_t, 4> constituents(size_t(colCount),
+                                                        elemId);
+      const uint32_t colId =
+          theBuilder.createCompositeConstruct(vecType, constituents);
+      const llvm::SmallVector<uint32_t, 4> rows(size_t(rowCount), colId);
+      return theBuilder.createCompositeConstruct(
+          typeTranslator.translateType(type), rows);
+    }
+  }
+
+  // Struct type
+  if (const auto *structType = type->getAs<RecordType>()) {
+    const auto *decl = structType->getDecl();
+    llvm::SmallVector<uint32_t, 4> fields;
+    for (const auto *field : decl->fields())
+      fields.push_back(
+          processFlatConversion(field->getType(), initType, initId));
+    return theBuilder.createCompositeConstruct(
+        typeTranslator.translateType(type), fields);
+  }
+
+  // Array type
+  if (const auto *arrayType = astContext.getAsConstantArrayType(type)) {
+    const auto size =
+        static_cast<uint32_t>(arrayType->getSize().getZExtValue());
+    const uint32_t elemId =
+        processFlatConversion(arrayType->getElementType(), initType, initId);
+    llvm::SmallVector<uint32_t, 4> constituents(size_t(size), elemId);
+    return theBuilder.createCompositeConstruct(
+        typeTranslator.translateType(type), constituents);
+  }
+
+  emitError("flat conversion of type %0 unimplemented")
+      << type->getTypeClassName();
+  type->dump();
+  return 0;
+}
+
 SpirvEvalInfo
 SPIRVEmitter::doCompoundAssignOperator(const CompoundAssignOperator *expr) {
   const auto opcode = expr->getOpcode();

+ 7 - 0
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -348,6 +348,13 @@ private:
   /// one will be a vector of size N.
   uint32_t getMatElemValueOne(QualType type);
 
+private:
+  /// \brief Performs a FlatConversion implicit cast. Fills an instance of the
+  /// given type with initializer <result-id>. The initializer is of type
+  /// initType.
+  uint32_t processFlatConversion(const QualType type, const QualType initType,
+                                 uint32_t initId);
+
 private:
   /// Translates the given frontend APValue into its SPIR-V equivalent for the
   /// given targetType.

+ 56 - 0
tools/clang/test/CodeGenSPIRV/cast.flat-conversion.implicit.hlsl

@@ -0,0 +1,56 @@
+// Run: %dxc -T ps_6_0 -E main
+
+struct VSOutput {
+  float4   sv_pos     : SV_POSITION;
+  uint3    normal     : NORMAL;
+  int2     tex_coord  : TEXCOORD;
+  bool     mybool[2]  : MYBOOL;
+  int      arr[5]     : MYARRAY;
+  float2x3 mat2x3     : MYMATRIX;
+};
+
+
+// CHECK: [[nullVSOutput:%\d+]] = OpConstantNull %VSOutput
+
+
+void main() {
+  int x = 3;
+
+// CHECK: OpStore %output1 [[nullVSOutput]]
+  VSOutput output1 = (VSOutput)0;
+
+// TODO: Avoid OpBitCast from 'literal int' to 'int'
+//
+// CHECK:                [[f1:%\d+]] = OpConvertSToF %float %int_1
+// CHECK-NEXT:         [[v4f1:%\d+]] = OpCompositeConstruct %v4float [[f1]] [[f1]] [[f1]] [[f1]]
+// CHECK-NEXT:           [[u1:%\d+]] = OpBitcast %uint %int_1
+// CHECK-NEXT:         [[v3u1:%\d+]] = OpCompositeConstruct %v3uint [[u1]] [[u1]] [[u1]]
+// CHECK-NEXT:         [[i1_0:%\d+]] = OpBitcast %int %int_1
+// CHECK-NEXT:         [[v2i1:%\d+]] = OpCompositeConstruct %v2int [[i1_0]] [[i1_0]]
+// CHECK-NEXT:        [[bool1:%\d+]] = OpINotEqual %bool %int_1 %int_0
+// CHECK-NEXT:    [[arr2bool1:%\d+]] = OpCompositeConstruct %_arr_bool_uint_2 [[bool1]] [[bool1]]
+// CHECK-NEXT:         [[i1_1:%\d+]] = OpBitcast %int %int_1
+// CHECK-NEXT:       [[arr5i1:%\d+]] = OpCompositeConstruct %_arr_int_uint_5 [[i1_1]] [[i1_1]] [[i1_1]] [[i1_1]] [[i1_1]]
+// CHECK-NEXT:         [[f1_1:%\d+]] = OpConvertSToF %float %int_1
+// CHECK-NEXT:         [[col3:%\d+]] = OpCompositeConstruct %v3float [[f1_1]] [[f1_1]] [[f1_1]]
+// CHECK-NEXT:    [[matFloat1:%\d+]] = OpCompositeConstruct %mat2v3float [[col3]] [[col3]]
+// CHECK-NEXT: [[flatConvert1:%\d+]] = OpCompositeConstruct %VSOutput [[v4f1]] [[v3u1]] [[v2i1]] [[arr2bool1]] [[arr5i1]] [[matFloat1]]
+// CHECK-NEXT:                         OpStore %output2 [[flatConvert1]]
+  VSOutput output2 = (VSOutput)1;
+
+// CHECK:                [[x:%\d+]] = OpLoad %int %x
+// CHECK-NEXT:       [[floatX:%\d+]] = OpConvertSToF %float [[x]]
+// CHECK-NEXT:         [[v4fX:%\d+]] = OpCompositeConstruct %v4float [[floatX]] [[floatX]] [[floatX]] [[floatX]]
+// CHECK-NEXT:        [[uintX:%\d+]] = OpBitcast %uint [[x]]
+// CHECK-NEXT:         [[v3uX:%\d+]] = OpCompositeConstruct %v3uint [[uintX]] [[uintX]] [[uintX]]
+// CHECK-NEXT:         [[v2iX:%\d+]] = OpCompositeConstruct %v2int [[x]] [[x]]
+// CHECK-NEXT:        [[boolX:%\d+]] = OpINotEqual %bool [[x]] %int_0
+// CHECK-NEXT:    [[arr2boolX:%\d+]] = OpCompositeConstruct %_arr_bool_uint_2 [[boolX]] [[boolX]]
+// CHECK-NEXT:       [[arr5iX:%\d+]] = OpCompositeConstruct %_arr_int_uint_5 [[x]] [[x]] [[x]] [[x]] [[x]]
+// CHECK-NEXT:      [[floatX2:%\d+]] = OpConvertSToF %float [[x]]
+// CHECK-NEXT:         [[v3fX:%\d+]] = OpCompositeConstruct %v3float [[floatX2]] [[floatX2]] [[floatX2]]
+// CHECK-NEXT:    [[matFloatX:%\d+]] = OpCompositeConstruct %mat2v3float [[v3fX]] [[v3fX]]
+// CHECK-NEXT: [[flatConvert2:%\d+]] = OpCompositeConstruct %VSOutput [[v4fX]] [[v3uX]] [[v2iX]] [[arr2boolX]] [[arr5iX]] [[matFloatX]]
+// CHECK-NEXT:                         OpStore %output3 [[flatConvert2]]
+  VSOutput output3 = (VSOutput)x;
+}

+ 3 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -270,6 +270,9 @@ TEST_F(FileTest, CastImplicit2UInt) { runFileTest("cast.2uint.implicit.hlsl"); }
 TEST_F(FileTest, CastExplicit2UInt) { runFileTest("cast.2uint.explicit.hlsl"); }
 TEST_F(FileTest, CastImplicit2FP) { runFileTest("cast.2fp.implicit.hlsl"); }
 TEST_F(FileTest, CastExplicit2FP) { runFileTest("cast.2fp.explicit.hlsl"); }
+TEST_F(FileTest, CastImplicitFlatConversion) {
+  runFileTest("cast.flat-conversion.implicit.hlsl");
+}
 
 // For vector/matrix splatting and trunction
 TEST_F(FileTest, CastTruncateVector) { runFileTest("cast.vector.trunc.hlsl"); }