Selaa lähdekoodia

Prevent sinking coord calculation for sample (#3657)

merge #3655

The mark convergent pass is meant to prevent unwanted moving of
operations on derivative op input. It was previously only run on pixel
shaders. Because derivatives are supported in CS/MS/AS shaders as part
of shader model 6.6, it needs to be run on these stages for that target
too.

(cherry picked from commit 93a98982cf51eca3724840e196031437ae529bdc)
Greg Roth 4 vuotta sitten
vanhempi
commit
76e7647cc6

+ 2 - 1
lib/HLSL/DxilConvergent.cpp

@@ -47,7 +47,8 @@ public:
 
 
   bool runOnModule(Module &M) override {
   bool runOnModule(Module &M) override {
     if (M.HasHLModule()) {
     if (M.HasHLModule()) {
-      if (!M.GetHLModule().GetShaderModel()->IsPS())
+      const ShaderModel *SM = M.GetHLModule().GetShaderModel();
+      if (!SM->IsPS() && (!SM->IsSM66Plus() || (!SM->IsCS() && !SM->IsMS() && !SM->IsAS())))
         return false;
         return false;
     }
     }
     bool bUpdated = false;
     bool bUpdated = false;

+ 48 - 0
tools/clang/test/HLSLFileCheck/hlsl/objects/Texture/convergent_cs.hlsl

@@ -0,0 +1,48 @@
+// RUN: %dxc -E MainCS -T cs_6_6 %s | FileCheck %s
+// RUN: %dxc -E MainAS -T as_6_6 %s | FileCheck %s
+// RUN: %dxc -E MainMS -T ms_6_6 %s | FileCheck %s
+
+// Make sure add is not sunk into if.
+// Compute shader variant of convergent.hlsl
+
+// CHECK: add
+// CHECK: add
+// CHECK: icmp
+// CHECK-NEXT: br
+
+
+Texture2D<float4> tex;
+RWBuffer<float4> output;
+SamplerState s;
+
+void doit(uint ix, uint3 id){
+
+  float2 coord = id.xy + id.z;
+  float4 c = id.z;
+  if (id.z > 2) {
+    c += tex.Sample(s, coord);
+  }
+  output[ix] = c;
+
+}
+
+[numthreads(4,4,4)]
+void MainCS(uint ix : SV_GroupIndex, uint3 id : SV_GroupThreadID) {
+  doit(ix, id);
+}
+
+struct Payload { int nothing; };
+
+[numthreads(4,4,4)]
+void MainAS(uint ix : SV_GroupIndex, uint3 id : SV_GroupThreadID) {
+  doit(ix, id);
+  Payload pld = (Payload)0;
+  DispatchMesh(1,1,1,pld);
+}
+
+
+[numthreads(4,4,4)]
+[outputtopology("triangle")]
+void MainMS(uint ix : SV_GroupIndex, uint3 id : SV_GroupThreadID) {
+  doit(ix, id);
+}