瀏覽代碼

spir-v: fix missing typecast on HLSL21 ternary (#4935)

In the AST, the type of the ternary in the case `i < 10 ? 1 : 0`
is `literal int`.

In the cast logic, we had a check saying "if it's a literal, ignore the
sign for now". Then a pass would determine the correct SPIR-V sign depending
on the variable the literal was stored to.

When introducing HLSL-2021, short-circuiting was allowed.
This means we have to evaluate only the valid side of the ternary.
This was done using a temporary register, storing the branch evaluation
result.

So instead of having:
`literal -> result variable (sign=U)`
we had:
`literal -> tmp variable (sign=?) -> result variable (sign=U)`

In the first case, we could generate the load correctly, as we had the
sign of the variable. In the second one, we could not, because the tmp
variable sign was unknown.

Now, why cannot we know the variable sign? because the AST type is
'literal int', and has no sign information. Adding more complexity to
the pass determining types could solve it, but doesn't seem like good
solution. After all, the AST tells us when to add a cast! So we should
simply follow the AST, and generate the cast when asked to.

This is not a perfect fix. I believe a good fix should be to remove this
'literal int' type, and keep sign information, like clang is doing
today, but it would impact the DXIL side.
Nathan Gauër 2 年之前
父節點
當前提交
c524f5e175

+ 21 - 13
tools/clang/lib/SPIRV/AstTypeProbe.cpp

@@ -568,16 +568,7 @@ bool canTreatAsSameScalarType(QualType type1, QualType type2) {
          (type1->isSpecificBuiltinType(BuiltinType::LitFloat) &&
           type2->isFloatingType()) ||
          (type2->isSpecificBuiltinType(BuiltinType::LitFloat) &&
-          type1->isFloatingType()) ||
-         // Treat 'literal int' and 'int'/'uint' as the same
-         (type1->isSpecificBuiltinType(BuiltinType::LitInt) &&
-          type2->isIntegerType() &&
-          // Disallow boolean types
-          !type2->isSpecificBuiltinType(BuiltinType::Bool)) ||
-         (type2->isSpecificBuiltinType(BuiltinType::LitInt) &&
-          type1->isIntegerType() &&
-          // Disallow boolean types
-          !type1->isSpecificBuiltinType(BuiltinType::Bool));
+          type1->isFloatingType());
 }
 
 bool canFitIntoOneRegister(const ASTContext &astContext, QualType structType,
@@ -738,9 +729,26 @@ bool isSameScalarOrVecType(QualType type1, QualType type2) {
   { // Vector types
     QualType elemType1 = {}, elemType2 = {};
     uint32_t count1 = {}, count2 = {};
-    if (isVectorType(type1, &elemType1, &count1) &&
-        isVectorType(type2, &elemType2, &count2))
-      return count1 == count2 && canTreatAsSameScalarType(elemType1, elemType2);
+    if (!isVectorType(type1, &elemType1, &count1) ||
+        !isVectorType(type2, &elemType2, &count2))
+      return false;
+
+    if (count1 != count2)
+      return false;
+
+    // That's a corner case we had to add to solve #4727.
+    // Normally, clang doesn't have the 'literal type', thus we can rely on
+    // direct type check. But this flavor of the AST has this 'literal int' type
+    // that is sign-less (nor signed or unsigned), until usage. Obviously,
+    // int(3) == literal int (3), but since they are considered different in the
+    // AST, we must check explicitly. Note: this is only valid here, as this is
+    // related to a vector size. Considering int == literal int elsewhere could
+    // break codegen, as SPIR-V does need explicit signedness.
+    return canTreatAsSameScalarType(elemType1, elemType2) ||
+           (elemType1->isIntegerType() &&
+            elemType2->isSpecificBuiltinType(BuiltinType::LitInt)) ||
+           (elemType2->isIntegerType() &&
+            elemType1->isSpecificBuiltinType(BuiltinType::LitInt));
   }
 
   return false;

+ 5 - 0
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -7946,6 +7946,11 @@ SpirvInstruction *SpirvEmitter::castToInt(SpirvInstruction *fromVal,
                                    srcRange);
   }
 
+  if (fromType->isSpecificBuiltinType(BuiltinType::LitInt)) {
+    return spvBuilder.createUnaryOp(spv::Op::OpBitcast, toIntType, fromVal,
+                                    srcLoc, srcRange);
+  }
+
   if (isSintOrVecOfSintType(fromType) || isUintOrVecOfUintType(fromType)) {
     // First convert the source to the bitwidth of the destination if necessary.
     QualType convertedType = {};

+ 23 - 0
tools/clang/test/CodeGenSPIRV/cast.literal-type.ternary.2021.hlsl

@@ -0,0 +1,23 @@
+// RUN: %dxc -T ps_6_0 -HV 2021 -E main
+
+// The 'out' argument in the function should be handled correctly when deducing
+// the literal type, even in HLSL 2021 (with shortcicuiting).
+void foo(out uint value, uint x) {
+  // CHECK:   [[cond:%\d+]] = OpULessThan %bool {{%\d+}} %uint_64
+  // CHECK:                   OpBranchConditional [[cond]] [[ternary_lhs:%\w+]] [[ternary_rhs:%\w+]]
+  // CHECK: [[ternary_lhs]] = OpLabel
+  // CHECK:                   OpStore [[tmp:%\w+]] %int_1
+  // CHECK:                   OpBranch [[merge:%\w+]]
+  // CHECK: [[ternary_rhs]] = OpLabel
+  // CHECK:                   OpStore [[tmp:%\w+]] %int_0
+  // CHECK:                   OpBranch [[merge]]
+  // CHECK:       [[merge]] = OpLabel
+  // CHECK:    [[res:%\d+]] = OpLoad %int [[tmp]]
+  // CHECK:        {{%\d+}} = OpBitcast %uint [[res]]
+  value = x < 64 ? 1 : 0;
+}
+
+void main() {
+  uint value;
+  foo(value, 2);
+}

+ 2 - 1
tools/clang/test/CodeGenSPIRV/cast.literal-type.ternary.hlsl

@@ -4,7 +4,8 @@
 // the literal type.
 void foo(out uint value, uint x) {
   // CHECK: [[cond:%\d+]] = OpULessThan %bool {{%\d+}} %uint_64
-  // CHECK:      {{%\d+}} = OpSelect %uint [[cond]] %uint_1 %uint_0
+  // CHECK: [[result:%\d+]] = OpSelect %int [[cond]] %int_1 %int_0
+  // CHECK: {{%\d+}} = OpBitcast %uint [[result]]
   value = x < 64 ? 1 : 0;
 }
 

+ 7 - 4
tools/clang/test/CodeGenSPIRV/intrinsics.select.hlsl

@@ -109,7 +109,8 @@ void main() {
   // CHECK:      [[c_long:%\d+]] = OpSelect %long {{%\d+}} %long_3000000000 %long_4000000000
   double c = select(cond, 3000000000, 4000000000);
 
-  // CHECK:      [[d_int:%\d+]] = OpSelect %uint {{%\d+}} %uint_1 %uint_0
+  // CHECK:      [[d_int:%\d+]] = OpSelect %int {{%\d+}} %int_1 %int_0
+  // CHECK-NEXT:       {{%\d+}} = OpBitcast %uint [[d_int]]
   uint d = select(cond, 1, 0);
 
   float2x3 e;
@@ -130,8 +131,9 @@ void main() {
   // CHECK-NEXT:                OpStore %g [[temp]]
   float2x3 g = select(cond, e, f);
 
-  // CHECK:      [[inner:%\d+]] = OpSelect %uint {{%\d+}} %uint_1 %uint_2
-  // CHECK-NEXT:       {{%\d+}} = OpSelect %uint {{%\d+}} %uint_9 [[inner]]
+  // CHECK:       [[inner:%\d+]] = OpSelect %int {{%\d+}} %int_1 %int_2
+  // CHECK:      [[outter:%\d+]] = OpSelect %int {{%\d+}} %int_9 [[inner]]
+  // CHECK-NEXT:       {{%\d+}} = OpBitcast %uint [[outter]]
   uint h = select(cond, 9, select(cond, 1, 2));
 
   //CHECK:      [[i_int:%\d+]] = OpSelect %int {{%\d+}} %int_1 %int_0
@@ -170,7 +172,8 @@ void main() {
 //
 // The literal integer type should be deduced from the function return type.
 //
-// CHECK: OpSelect %uint {{%\d+}} %uint_1 %uint_2
+// CHECK:      [[result:%\d+]] = OpSelect %int {{%\d+}} %int_1 %int_2
+// CHECK-NEXT:        {{%\d+}} = OpBitcast %uint [[result]]
 uint zoo() {
   bool cond;
   return select(cond, 1, 2);

+ 7 - 4
tools/clang/test/CodeGenSPIRV/ternary-op.cond-op.hlsl

@@ -114,7 +114,8 @@ void main() {
   // CHECK:      [[c_long:%\d+]] = OpSelect %long {{%\d+}} %long_3000000000 %long_4000000000
   double c = cond ? 3000000000 : 4000000000;
 
-  // CHECK:      [[d_int:%\d+]] = OpSelect %uint {{%\d+}} %uint_1 %uint_0
+  // CHECK:      [[d_int:%\d+]] = OpSelect %int {{%\d+}} %int_1 %int_0
+  // CHECK:      {{%\d+}} = OpBitcast %uint [[d_int]]
   uint d = cond ? 1 : 0;
 
   float2x3 e;
@@ -135,8 +136,9 @@ void main() {
   // CHECK-NEXT:                OpStore %g [[temp]]
   float2x3 g = cond ? e : f;
 
-  // CHECK:      [[inner:%\d+]] = OpSelect %uint {{%\d+}} %uint_1 %uint_2
-  // CHECK-NEXT:       {{%\d+}} = OpSelect %uint {{%\d+}} %uint_9 [[inner]]
+  // CHECK:       [[inner:%\d+]] = OpSelect %int {{%\d+}} %int_1 %int_2
+  // CHECK-NEXT: [[outter:%\d+]] = OpSelect %int {{%\d+}} %int_9 [[inner]]
+  // CHECK-NEXT:        {{%\d+}} = OpBitcast %uint [[outter]]
   uint h = cond ? 9 : (cond ? 1 : 2);
 
   //CHECK:      [[i_int:%\d+]] = OpSelect %int {{%\d+}} %int_1 %int_0
@@ -221,7 +223,8 @@ void main() {
 //
 // The literal integer type should be deduced from the function return type.
 //
-// CHECK: OpSelect %uint {{%\d+}} %uint_1 %uint_2
+// CHECK:      [[result:%\d+]] = OpSelect %int {{%\d+}} %int_1 %int_2
+// CHECK-NEXT:                   OpBitcast %uint [[result]]
 uint zoo() {
   bool cond;
   return cond ? 1 : 2;

+ 4 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -544,6 +544,10 @@ TEST_F(FileTest, CastLiteralTypeForTernary) {
   runFileTest("cast.literal-type.ternary.hlsl");
 }
 
+TEST_F(FileTest, CastLiteralTypeForTernary2021) {
+  runFileTest("cast.literal-type.ternary.2021.hlsl");
+}
+
 // For vector/matrix splatting and trunction
 TEST_F(FileTest, CastTruncateVector) { runFileTest("cast.vector.trunc.hlsl"); }
 TEST_F(FileTest, CastTruncateMatrix) { runFileTest("cast.matrix.trunc.hlsl"); }