Răsfoiți Sursa

[spirv] fix Interlocked member call type cast bug (#2121)

Jaebaek Seo 6 ani în urmă
părinte
comite
8d86efebe4

+ 6 - 2
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -2757,12 +2757,16 @@ SpirvInstruction *SpirvEmitter::processRWByteAddressBufferAtomicMethods(
       spvBuilder.createStore(doExpr(expr->getArg(3)), originalVal);
   } else {
     auto *value = doExpr(expr->getArg(1));
-    auto *originalVal = spvBuilder.createAtomicOp(
+    SpirvInstruction *originalVal = spvBuilder.createAtomicOp(
         translateAtomicHlslOpcodeToSpirvOpcode(opcode),
         astContext.UnsignedIntTy, ptr, spv::Scope::Device,
         spv::MemorySemanticsMask::MaskNone, value, srcLoc);
-    if (expr->getNumArgs() > 2)
+    if (expr->getNumArgs() > 2) {
+      originalVal = castToType(originalVal, astContext.UnsignedIntTy,
+                               expr->getArg(2)->getType(),
+                               expr->getArg(2)->getLocStart());
       spvBuilder.createStore(doExpr(expr->getArg(2)), originalVal);
+    }
   }
 
   return nullptr;

+ 12 - 0
tools/clang/test/CodeGenSPIRV/cast.2float.interlocked.hlsl

@@ -0,0 +1,12 @@
+// Run: %dxc -T ps_6_0 -E main
+
+RWByteAddressBuffer foo;
+groupshared float bar;
+
+void main() {
+// CHECK:      [[foo:%\d+]] = OpAccessChain %_ptr_Uniform_uint %foo %uint_0 {{%\d+}}
+// CHECK-NEXT: [[foo:%\d+]] = OpAtomicIAdd %uint [[foo]] %uint_1 %uint_0 %uint_42
+// CHECK-NEXT: [[foo:%\d+]] = OpConvertUToF %float [[foo]]
+// CHECK-NEXT:                OpStore %bar [[foo]]
+  foo.InterlockedAdd(16, 42, bar);
+}