فهرست منبع

[spirv] Fix translation for WaveReadLaneAt. (#3056)

* [spirv] Fix translation for WaveReadLaneAt

WaveReadLaneAt should be translated to OpGroupNonUniformShuffle, not OpGroupNonUniformBroadcast.

* Add comments.
Ehsan 5 سال پیش
والد
کامیت
28af56515a

+ 4 - 0
tools/clang/lib/SPIRV/CapabilityVisitor.cpp

@@ -450,6 +450,10 @@ bool CapabilityVisitor::visitInstruction(SpirvInstruction *instr) {
   case spv::Op::OpGroupNonUniformBroadcastFirst:
     addCapability(spv::Capability::GroupNonUniformBallot);
     break;
+  case spv::Op::OpGroupNonUniformShuffle:
+  case spv::Op::OpGroupNonUniformShuffleXor:
+    addCapability(spv::Capability::GroupNonUniformShuffle);
+    break;
   case spv::Op::OpGroupNonUniformIAdd:
   case spv::Op::OpGroupNonUniformFAdd:
   case spv::Op::OpGroupNonUniformIMul:

+ 4 - 1
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -7768,8 +7768,11 @@ SpirvInstruction *SpirvEmitter::processWaveBroadcast(const CallExpr *callExpr) {
   auto *value = doExpr(callExpr->getArg(0));
   const QualType retType = callExpr->getCallReturnType(astContext);
   if (numArgs == 2)
+    // WaveReadLaneAt is in fact not a broadcast operation (even though its name
+    // might incorrectly suggest so). The proper mapping to SPIR-V for
+    // it is OpGroupNonUniformShuffle, *not* OpGroupNonUniformBroadcast.
     return spvBuilder.createGroupNonUniformBinaryOp(
-        spv::Op::OpGroupNonUniformBroadcast, retType, spv::Scope::Subgroup,
+        spv::Op::OpGroupNonUniformShuffle, retType, spv::Scope::Subgroup,
         value, doExpr(callExpr->getArg(1)), srcLoc);
   else
     return spvBuilder.createGroupNonUniformUnaryOp(

+ 4 - 4
tools/clang/test/CodeGenSPIRV/sm6.wave-read-lane-at.hlsl

@@ -10,7 +10,7 @@ struct S {
 
 
 RWStructuredBuffer<S> values;
 
-// CHECK: OpCapability GroupNonUniformBallot
+// CHECK: OpCapability GroupNonUniformShuffle
 
 [numthreads(32, 1, 1)]
 void main(uint3 id: SV_DispatchThreadID) {
@@ -21,12 +21,12 @@ void main(uint3 id: SV_DispatchThreadID) {
        int val3 = values[x].val3;
 
 // CHECK:      [[val1:%\d+]] = OpLoad %v4float %val1
-// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformBroadcast %v4float %uint_3 [[val1]] %uint_15
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformShuffle %v4float %uint_3 [[val1]] %uint_15
     values[x].val1 = WaveReadLaneAt(val1, 15);
 // CHECK:      [[val2:%\d+]] = OpLoad %v3uint %val2
-// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformBroadcast %v3uint %uint_3 [[val2]] %uint_42
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformShuffle %v3uint %uint_3 [[val2]] %uint_42
     values[x].val2 = WaveReadLaneAt(val2, 42);
 // CHECK:      [[val3:%\d+]] = OpLoad %int %val3
-// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformBroadcast %int %uint_3 [[val3]] %uint_15
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformShuffle %int %uint_3 [[val3]] %uint_15
     values[x].val3 = WaveReadLaneAt(val3, 15);
 }

+ 4 - 4
tools/clang/test/CodeGenSPIRV/sm6.wave-read-lane-at.vulkan1.2.hlsl

@@ -10,7 +10,7 @@ struct S {
 
 
 RWStructuredBuffer<S> values;
 
-// CHECK: OpCapability GroupNonUniformBallot
+// CHECK: OpCapability GroupNonUniformShuffle
 
 [numthreads(32, 1, 1)]
 void main(uint3 id: SV_DispatchThreadID) {
@@ -21,12 +21,12 @@ void main(uint3 id: SV_DispatchThreadID) {
        int val3 = values[x].val3;
 
 // CHECK:      [[val1:%\d+]] = OpLoad %v4float %val1
-// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformBroadcast %v4float %uint_3 [[val1]] %uint_15
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformShuffle %v4float %uint_3 [[val1]] %uint_15
     values[x].val1 = WaveReadLaneAt(val1, 15);
 // CHECK:      [[val2:%\d+]] = OpLoad %v3uint %val2
-// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformBroadcast %v3uint %uint_3 [[val2]] %uint_42
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformShuffle %v3uint %uint_3 [[val2]] %uint_42
     values[x].val2 = WaveReadLaneAt(val2, 42);
 // CHECK:      [[val3:%\d+]] = OpLoad %int %val3
-// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformBroadcast %int %uint_3 [[val3]] %uint_15
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformShuffle %int %uint_3 [[val3]] %uint_15
     values[x].val3 = WaveReadLaneAt(val3, 15);
 }

+ 1 - 1
tools/clang/test/CodeGenSPIRV/spirv.debug.opline.intrinsic.vulkan1.1.hlsl

@@ -49,7 +49,7 @@
   WavePrefixCountBits(i == 1);
 
 // CHECK:      OpLine [[file]] 53 3
-// CHECK-NEXT: OpGroupNonUniformBroadcast %int %uint_3
+// CHECK-NEXT: OpGroupNonUniformShuffle %int %uint_3
   WaveReadLaneAt(i, 15);
 
 // CHECK:      OpLine [[file]] 57 3