瀏覽代碼

[spirv] fix no-op matrix type cast bug (#2111)

Type cast float2x2 to int2x2 causes a validation error when emitting
SPIR-V. This CL fixes it by adding type conversion in InitListHandler.
Jaebaek Seo 6 年之前
父節點
當前提交
5b0d61aa05

+ 1 - 1
tools/clang/include/clang/SPIRV/AstTypeProbe.h

@@ -277,4 +277,4 @@ QualType getHLSLMatrixType(ASTContext &, Sema &, ClassTemplateDecl *,
 } // namespace spirv
 } // namespace spirv
 } // namespace clang
 } // namespace clang
 
 
-#endif // LLVM_CLANG_SPIRV_TYPEPROBE_H
+#endif // LLVM_CLANG_SPIRV_TYPEPROBE_H

+ 2 - 4
tools/clang/lib/SPIRV/InitListHandler.cpp

@@ -312,12 +312,10 @@ InitListHandler::createInitForMatrixType(QualType matrixType,
       uint32_t initRowCount = 0, initColCount = 0;
       uint32_t initRowCount = 0, initColCount = 0;
       hlsl::GetHLSLMatRowColCount(init->getAstResultType(), initRowCount,
       hlsl::GetHLSLMatRowColCount(init->getAstResultType(), initRowCount,
                                   initColCount);
                                   initColCount);
-
       if (rowCount == initRowCount && colCount == initColCount) {
       if (rowCount == initRowCount && colCount == initColCount) {
         initializers.pop_back();
         initializers.pop_back();
-        // TODO: We only support FP matrices now. Do type cast here after
-        // adding more matrix types.
-        return init;
+        return theEmitter.castToType(init, init->getAstResultType(), matrixType,
+                                     srcLoc);
       }
       }
     }
     }
   }
   }

+ 6 - 6
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -945,17 +945,17 @@ bool SpirvEmitter::loadIfAliasVarRef(const Expr *varExpr,
 SpirvInstruction *SpirvEmitter::castToType(SpirvInstruction *value,
 SpirvInstruction *SpirvEmitter::castToType(SpirvInstruction *value,
                                            QualType fromType, QualType toType,
                                            QualType fromType, QualType toType,
                                            SourceLocation srcLoc) {
                                            SourceLocation srcLoc) {
-  if (isFloatOrVecOfFloatType(toType))
+  if (isFloatOrVecMatOfFloatType(toType))
     return castToFloat(value, fromType, toType, srcLoc);
     return castToFloat(value, fromType, toType, srcLoc);
 
 
   // Order matters here. Bool (vector) values will also be considered as uint
   // Order matters here. Bool (vector) values will also be considered as uint
   // (vector) values. So given a bool (vector) argument, isUintOrVecOfUintType()
   // (vector) values. So given a bool (vector) argument, isUintOrVecOfUintType()
   // will also return true. We need to check bool before uint. The opposite is
   // will also return true. We need to check bool before uint. The opposite is
   // not true.
   // not true.
-  if (isBoolOrVecOfBoolType(toType))
+  if (isBoolOrVecMatOfBoolType(toType))
     return castToBool(value, fromType, toType);
     return castToBool(value, fromType, toType);
 
 
-  if (isSintOrVecOfSintType(toType) || isUintOrVecOfUintType(toType))
+  if (isSintOrVecMatOfSintType(toType) || isUintOrVecMatOfUintType(toType))
     return castToInt(value, fromType, toType, srcLoc);
     return castToInt(value, fromType, toType, srcLoc);
 
 
   emitError("casting to type %0 unimplemented", {}) << toType;
   emitError("casting to type %0 unimplemented", {}) << toType;
@@ -6117,7 +6117,7 @@ SpirvInstruction *SpirvEmitter::turnIntoElementPtr(
 SpirvInstruction *SpirvEmitter::castToBool(SpirvInstruction *fromVal,
 SpirvInstruction *SpirvEmitter::castToBool(SpirvInstruction *fromVal,
                                            QualType fromType,
                                            QualType fromType,
                                            QualType toBoolType) {
                                            QualType toBoolType) {
-  if (isSameScalarOrVecType(fromType, toBoolType))
+  if (isSameType(astContext, fromType, toBoolType))
     return fromVal;
     return fromVal;
 
 
   { // Special case handling for converting to a matrix of booleans.
   { // Special case handling for converting to a matrix of booleans.
@@ -6147,7 +6147,7 @@ SpirvInstruction *SpirvEmitter::castToBool(SpirvInstruction *fromVal,
 SpirvInstruction *SpirvEmitter::castToInt(SpirvInstruction *fromVal,
 SpirvInstruction *SpirvEmitter::castToInt(SpirvInstruction *fromVal,
                                           QualType fromType, QualType toIntType,
                                           QualType fromType, QualType toIntType,
                                           SourceLocation srcLoc) {
                                           SourceLocation srcLoc) {
-  if (isSameScalarOrVecType(fromType, toIntType))
+  if (isSameType(astContext, fromType, toIntType))
     return fromVal;
     return fromVal;
 
 
   if (isBoolOrVecOfBoolType(fromType)) {
   if (isBoolOrVecOfBoolType(fromType)) {
@@ -6254,7 +6254,7 @@ SpirvInstruction *SpirvEmitter::castToFloat(SpirvInstruction *fromVal,
                                             QualType fromType,
                                             QualType fromType,
                                             QualType toFloatType,
                                             QualType toFloatType,
                                             SourceLocation srcLoc) {
                                             SourceLocation srcLoc) {
-  if (isSameScalarOrVecType(fromType, toFloatType))
+  if (isSameType(astContext, fromType, toFloatType))
     return fromVal;
     return fromVal;
 
 
   if (isBoolOrVecOfBoolType(fromType)) {
   if (isBoolOrVecOfBoolType(fromType)) {

+ 14 - 0
tools/clang/test/CodeGenSPIRV/cast.no-op.matrix.float-to-int.hlsl

@@ -0,0 +1,14 @@
+// Run: %dxc -T ps_6_0 -E main
+
+void main() {
+  float2x2 a;
+  int4 b;
+
+// CHECK:        [[a:%\d+]] = OpLoad %mat2v2float %a
+// CHECK-NEXT: [[a_0:%\d+]] = OpCompositeExtract %v2float [[a]] 0
+// CHECK-NEXT: [[a_0:%\d+]] = OpConvertFToS %v2int [[a_0]]
+// CHECK-NEXT: [[a_1:%\d+]] = OpCompositeExtract %v2float [[a]] 1
+// CHECK-NEXT: [[a_1:%\d+]] = OpConvertFToS %v2int [[a_1]]
+// CHECK-NEXT:   [[a:%\d+]] = OpCompositeConstruct %_arr_v2int_uint_2 [[a_0]] [[a_1]]
+  b.zw = mul(int2x2(a), b.yx);
+}

+ 3 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -382,6 +382,9 @@ TEST_F(FileTest, OpTextureSampleAccess) {
 
 
 // For casting
 // For casting
 TEST_F(FileTest, CastNoOp) { runFileTest("cast.no-op.hlsl"); }
 TEST_F(FileTest, CastNoOp) { runFileTest("cast.no-op.hlsl"); }
+TEST_F(FileTest, CastNoOpMatrixFloatToInt) {
+  runFileTest("cast.no-op.matrix.float-to-int.hlsl");
+}
 TEST_F(FileTest, CastImplicit2Bool) { runFileTest("cast.2bool.implicit.hlsl"); }
 TEST_F(FileTest, CastImplicit2Bool) { runFileTest("cast.2bool.implicit.hlsl"); }
 TEST_F(FileTest, CastExplicit2Bool) { runFileTest("cast.2bool.explicit.hlsl"); }
 TEST_F(FileTest, CastExplicit2Bool) { runFileTest("cast.2bool.explicit.hlsl"); }
 TEST_F(FileTest, CastImplicit2SInt) { runFileTest("cast.2sint.implicit.hlsl"); }
 TEST_F(FileTest, CastImplicit2SInt) { runFileTest("cast.2sint.implicit.hlsl"); }