Prechádzať zdrojové kódy

[spirv] Fix missing coverage for literal types. (#1934)

Ehsan 6 rokov pred
rodič
commit
f3c03360b3

+ 14 - 0
tools/clang/lib/SPIRV/LiteralTypeVisitor.cpp

@@ -39,6 +39,15 @@ bool LiteralTypeVisitor::canDeduceTypeFromLitType(QualType litType,
                                                   QualType newType) {
   if (litType == QualType() || newType == QualType() || litType == newType)
     return false;
+
+  // The 'inout' and 'out' function arguments are of a reference type.
+  // For example: 'uint &'.
+  // We should first remove such reference from QualType (if any).
+  if (const auto *refType = litType->getAs<ReferenceType>())
+    litType = refType->getPointeeType();
+  if (const auto *refType = newType->getAs<ReferenceType>())
+    newType = refType->getPointeeType();
+
   if (!isLitTypeOrVecOfLitType(litType))
     return false;
   if (isLitTypeOrVecOfLitType(newType))
@@ -399,6 +408,11 @@ bool LiteralTypeVisitor::visit(SpirvAccessChain *inst) {
                           ? astContext.IntTy
                           : astContext.UnsignedIntTy);
       }
+    } else {
+      tryToUpdateInstLitType(index,
+                             index->getAstResultType()->isSignedIntegerType()
+                                 ? astContext.IntTy
+                                 : astContext.UnsignedIntTy);
     }
   }
   return true;

+ 12 - 0
tools/clang/test/CodeGenSPIRV/cast.literal-type.array-subscript.hlsl

@@ -0,0 +1,12 @@
+// Run: %dxc -T cs_6_0 -E main
+
+RWStructuredBuffer<uint> Out;
+groupshared uint Mem[1];
+[numthreads(1, 1, 1)]
+void main() {
+  // CHECK: [[sub:%\d+]] = OpISub %int %int_1 %int_1
+  // CHECK:     {{%\d+}} = OpAccessChain %_ptr_Workgroup_uint %Mem [[sub]]
+  Mem[1 - 1] = 0;
+  Out[0] = Mem[0];
+}
+

+ 14 - 0
tools/clang/test/CodeGenSPIRV/cast.literal-type.ternary.hlsl

@@ -0,0 +1,14 @@
+// Run: %dxc -T ps_6_0 -E main
+
+// The 'out' argument in the function should be handled correctly when deducing
+// the literal type.
+void foo(out uint value, uint x) {
+  // CHECK: [[cond:%\d+]] = OpULessThan %bool {{%\d+}} %uint_64
+  // CHECK:      {{%\d+}} = OpSelect %uint [[cond]] %uint_1 %uint_0
+  value = x < 64 ? 1 : 0;
+}
+
+void main() {
+  uint value;
+  foo(value, 2);
+}

+ 8 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -397,6 +397,14 @@ TEST_F(FileTest, CastExplicitVecToMat) {
 }
 TEST_F(FileTest, CastBitwidth) { runFileTest("cast.bitwidth.hlsl"); }
 
+TEST_F(FileTest, CastLiteralTypeForArraySubscript) {
+  runFileTest("cast.literal-type.array-subscript.hlsl");
+}
+
+TEST_F(FileTest, CastLiteralTypeForTernary) {
+  runFileTest("cast.literal-type.ternary.hlsl");
+}
+
 // For vector/matrix splatting and trunction
 TEST_F(FileTest, CastTruncateVector) { runFileTest("cast.vector.trunc.hlsl"); }
 TEST_F(FileTest, CastTruncateMatrix) { runFileTest("cast.matrix.trunc.hlsl"); }