Browse Source

Propagate literal types through loads. (#5339)

We currently do not do anything to determine the correct type for
variables with a literal type. This leads to type mismatches. In this
commit, I propagate the type the result of the load to the variable that
is loaded. This makes sure that the load is correct.

Fixes #5319
Steven Perron 2 years ago
parent
commit
bf2d6c073a

+ 21 - 0
tools/clang/lib/SPIRV/LiteralTypeVisitor.cpp

@@ -298,6 +298,27 @@ bool LiteralTypeVisitor::visit(SpirvNonUniformBinaryOp *inst) {
   return true;
 }
 
+bool LiteralTypeVisitor::visit(SpirvLoad *inst) {
+  auto *pointer = inst->getPointer();
+  if (!pointer->hasAstResultType())
+    return true;
+
+  QualType pointerType = pointer->getAstResultType();
+  if (!isLitTypeOrVecOfLitType(pointerType))
+    return true;
+
+  assert(inst->hasAstResultType());
+  QualType resultType = inst->getAstResultType();
+  assert(!isLitTypeOrVecOfLitType(resultType));
+
+  if (!canDeduceTypeFromLitType(pointerType, resultType))
+    return true;
+
+  QualType newPointerType = astContext.getPointerType(resultType);
+  pointer->setAstResultType(newPointerType);
+  return true;
+}
+
 bool LiteralTypeVisitor::visit(SpirvStore *inst) {
   auto *object = inst->getObject();
   auto *pointer = inst->getPointer();

+ 1 - 0
tools/clang/lib/SPIRV/LiteralTypeVisitor.h

@@ -34,6 +34,7 @@ public:
   bool visit(SpirvVectorShuffle *) override;
   bool visit(SpirvNonUniformUnaryOp *) override;
   bool visit(SpirvNonUniformBinaryOp *) override;
+  bool visit(SpirvLoad *) override;
   bool visit(SpirvStore *) override;
   bool visit(SpirvConstantComposite *) override;
   bool visit(SpirvCompositeConstruct *) override;

+ 59 - 0
tools/clang/test/CodeGenSPIRV/select.long.lit.hlsl2021.hlsl

@@ -0,0 +1,59 @@
+// RUN: %dxc -T ps_6_0 -HV 2021 -E main
+
+// Check that the literals get a 64-bit type, and the result of the select is
+// then cast to an unsigned 64-bit value.
+void foo(uint x) {
+// CHECK:      %foo = OpFunction
+// CHECK-NEXT: [[param:%\w+]] = OpFunctionParameter %_ptr_Function_uint
+// CHECK-NEXT: OpLabel
+// CHECK-NEXT: [[value:%\w+]] = OpVariable %_ptr_Function_ulong Function
+// CHECK-NEXT: [[temp:%\w+]] = OpVariable %_ptr_Function_long Function
+// CHECK-NEXT: [[ld:%\w+]] = OpLoad %uint [[param]]
+// CHECK-NEXT: [[cmp:%\w+]] = OpULessThan %bool [[ld]] %uint_64
+// CHECK-NEXT: OpSelectionMerge [[merge_bb:%\w+]] None
+// CHECK-NEXT: OpBranchConditional [[cmp]] [[true_bb:%\w+]] [[false_bb:%\w+]]
+// CHECK-NEXT: [[true_bb]] = OpLabel
+// CHECK-NEXT: OpStore [[temp]] %long_1
+// CHECK-NEXT: OpBranch [[merge_bb]]
+// CHECK-NEXT: [[false_bb]] = OpLabel
+// CHECK-NEXT: OpStore [[temp]] %long_0
+// CHECK-NEXT: OpBranch [[merge_bb]]
+// CHECK-NEXT: [[merge_bb]] = OpLabel
+// CHECK-NEXT: [[ld2:%\w+]] = OpLoad %long [[temp]]
+// CHECK-NEXT: [[res:%\w+]] = OpBitcast %ulong [[ld2]]
+// CHECK-NEXT: OpStore [[value]] [[res:%\w+]]
+  uint64_t value = x < 64 ? 1 : 0;
+}
+
+// Check that the literals get a 64-bit type, and the result of the select is
+// then cast to an signed 64-bit value. Note that the bitcast is redundant in
+// this case, but we add the bitcast before the type of the literal has been
+// determined, so we add it anyway.
+void bar(uint x) {
+// CHECK:      %bar = OpFunction
+// CHECK-NEXT: [[param:%\w+]] = OpFunctionParameter %_ptr_Function_uint
+// CHECK-NEXT: OpLabel
+// CHECK-NEXT: [[value:%\w+]] = OpVariable %_ptr_Function_long Function
+// CHECK-NEXT: [[temp:%\w+]] = OpVariable %_ptr_Function_long Function
+// CHECK-NEXT: [[ld:%\w+]] = OpLoad %uint [[param]]
+// CHECK-NEXT: [[cmp:%\w+]] = OpULessThan %bool [[ld]] %uint_64
+// CHECK-NEXT: OpSelectionMerge [[merge_bb:%\w+]] None
+// CHECK-NEXT: OpBranchConditional [[cmp]] [[true_bb:%\w+]] [[false_bb:%\w+]]
+// CHECK-NEXT: [[true_bb]] = OpLabel
+// CHECK-NEXT: OpStore [[temp]] %long_1
+// CHECK-NEXT: OpBranch [[merge_bb]]
+// CHECK-NEXT: [[false_bb]] = OpLabel
+// CHECK-NEXT: OpStore [[temp]] %long_0
+// CHECK-NEXT: OpBranch [[merge_bb]]
+// CHECK-NEXT: [[merge_bb]] = OpLabel
+// CHECK-NEXT: [[ld2:%\w+]] = OpLoad %long [[temp]]
+// CHECK-NEXT: [[res:%\w+]] = OpBitcast %long [[ld2]]
+// CHECK-NEXT: OpStore [[value]] [[res:%\w+]]
+  int64_t value = x < 64 ? 1 : 0;
+}
+
+void main() {
+  uint value;
+  foo(2);
+  bar(2);
+}

+ 59 - 0
tools/clang/test/CodeGenSPIRV/select.short.lit.hlsl2021.hlsl

@@ -0,0 +1,59 @@
+// RUN: %dxc -T ps_6_2 -HV 2021 -E main -enable-16bit-types
+
+// Check that the literals get a 16-bit type, and the result of the select is
+// then cast to an unsigned 16-bit value.
+void foo(uint x) {
+// CHECK:      %foo = OpFunction
+// CHECK-NEXT: [[param:%\w+]] = OpFunctionParameter %_ptr_Function_uint
+// CHECK-NEXT: OpLabel
+// CHECK-NEXT: [[value:%\w+]] = OpVariable %_ptr_Function_ushort Function
+// CHECK-NEXT: [[temp:%\w+]] = OpVariable %_ptr_Function_short Function
+// CHECK-NEXT: [[ld:%\w+]] = OpLoad %uint [[param]]
+// CHECK-NEXT: [[cmp:%\w+]] = OpULessThan %bool [[ld]] %uint_64
+// CHECK-NEXT: OpSelectionMerge [[merge_bb:%\w+]] None
+// CHECK-NEXT: OpBranchConditional [[cmp]] [[true_bb:%\w+]] [[false_bb:%\w+]]
+// CHECK-NEXT: [[true_bb]] = OpLabel
+// CHECK-NEXT: OpStore [[temp]] %short_1
+// CHECK-NEXT: OpBranch [[merge_bb]]
+// CHECK-NEXT: [[false_bb]] = OpLabel
+// CHECK-NEXT: OpStore [[temp]] %short_0
+// CHECK-NEXT: OpBranch [[merge_bb]]
+// CHECK-NEXT: [[merge_bb]] = OpLabel
+// CHECK-NEXT: [[ld2:%\w+]] = OpLoad %short [[temp]]
+// CHECK-NEXT: [[res:%\w+]] = OpBitcast %ushort [[ld2]]
+// CHECK-NEXT: OpStore [[value]] [[res:%\w+]]
+  uint16_t value = x < 64 ? 1 : 0;
+}
+
+// Check that the literals get a 16-bit type, and the result of the select is
+// then cast to an signed 16-bit value. Note that the bitcast is redundant in
+// this case, but we add the bitcast before the type of the literal has been
+// determined, so we add it anyway.
+void bar(uint x) {
+// CHECK:      %bar = OpFunction
+// CHECK-NEXT: [[param:%\w+]] = OpFunctionParameter %_ptr_Function_uint
+// CHECK-NEXT: OpLabel
+// CHECK-NEXT: [[value:%\w+]] = OpVariable %_ptr_Function_short Function
+// CHECK-NEXT: [[temp:%\w+]] = OpVariable %_ptr_Function_short Function
+// CHECK-NEXT: [[ld:%\w+]] = OpLoad %uint [[param]]
+// CHECK-NEXT: [[cmp:%\w+]] = OpULessThan %bool [[ld]] %uint_64
+// CHECK-NEXT: OpSelectionMerge [[merge_bb:%\w+]] None
+// CHECK-NEXT: OpBranchConditional [[cmp]] [[true_bb:%\w+]] [[false_bb:%\w+]]
+// CHECK-NEXT: [[true_bb]] = OpLabel
+// CHECK-NEXT: OpStore [[temp]] %short_1
+// CHECK-NEXT: OpBranch [[merge_bb]]
+// CHECK-NEXT: [[false_bb]] = OpLabel
+// CHECK-NEXT: OpStore [[temp]] %short_0
+// CHECK-NEXT: OpBranch [[merge_bb]]
+// CHECK-NEXT: [[merge_bb]] = OpLabel
+// CHECK-NEXT: [[ld2:%\w+]] = OpLoad %short [[temp]]
+// CHECK-NEXT: [[res:%\w+]] = OpBitcast %short [[ld2]]
+// CHECK-NEXT: OpStore [[value]] [[res:%\w+]]
+  int16_t value = x < 64 ? 1 : 0;
+}
+
+void main() {
+  uint value;
+  foo(2);
+  bar(2);
+}

+ 6 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -557,6 +557,12 @@ TEST_F(FileTest, CastLiteralTypeForTernary) {
 
 TEST_F(FileTest, SelectLongLit) { runFileTest("select.long.lit.hlsl"); }
 TEST_F(FileTest, SelectShortLit) { runFileTest("select.short.lit.hlsl"); }
+TEST_F(FileTest, SelectLongLit2021) {
+  runFileTest("select.long.lit.hlsl2021.hlsl");
+}
+TEST_F(FileTest, SelectShortLit2021) {
+  runFileTest("select.short.lit.hlsl2021.hlsl");
+}
 
 TEST_F(FileTest, CastLiteralTypeForTernary2021) {
   runFileTest("cast.literal-type.ternary.2021.hlsl");