瀏覽代碼

Prevent sinking coord calculation for sample (#3657)

merge #3655

The mark convergent pass is meant to prevent unwanted moving of
operations on derivative op input. It was previously only run on pixel
shaders. Because derivatives are supported in CS/MS/AS shaders as part
of shader model 6.6, it needs to be run on these stages for that target
too.

(cherry picked from commit 93a98982cf51eca3724840e196031437ae529bdc)
Greg Roth 4 年之前
父節點
當前提交
76e7647cc6
共有 2 個文件被更改,包括 50 次插入1 次删除
  1. 2 1
      lib/HLSL/DxilConvergent.cpp
  2. 48 0
      tools/clang/test/HLSLFileCheck/hlsl/objects/Texture/convergent_cs.hlsl

+ 2 - 1
lib/HLSL/DxilConvergent.cpp

@@ -47,7 +47,8 @@ public:
 
   bool runOnModule(Module &M) override {
     if (M.HasHLModule()) {
-      if (!M.GetHLModule().GetShaderModel()->IsPS())
+      const ShaderModel *SM = M.GetHLModule().GetShaderModel();
+      if (!SM->IsPS() && (!SM->IsSM66Plus() || (!SM->IsCS() && !SM->IsMS() && !SM->IsAS())))
         return false;
     }
     bool bUpdated = false;

+ 48 - 0
tools/clang/test/HLSLFileCheck/hlsl/objects/Texture/convergent_cs.hlsl

@@ -0,0 +1,48 @@
+// RUN: %dxc -E MainCS -T cs_6_6 %s | FileCheck %s
+// RUN: %dxc -E MainAS -T as_6_6 %s | FileCheck %s
+// RUN: %dxc -E MainMS -T ms_6_6 %s | FileCheck %s
+
+// Make sure add is not sunk into if.
+// Compute shader variant of convergent.hlsl
+
+// CHECK: add
+// CHECK: add
+// CHECK: icmp
+// CHECK-NEXT: br
+
+
+Texture2D<float4> tex;
+RWBuffer<float4> output;
+SamplerState s;
+
+void doit(uint ix, uint3 id){
+
+  float2 coord = id.xy + id.z;
+  float4 c = id.z;
+  if (id.z > 2) {
+    c += tex.Sample(s, coord);
+  }
+  output[ix] = c;
+
+}
+
+[numthreads(4,4,4)]
+void MainCS(uint ix : SV_GroupIndex, uint3 id : SV_GroupThreadID) {
+  doit(ix, id);
+}
+
+struct Payload { int nothing; };
+
+[numthreads(4,4,4)]
+void MainAS(uint ix : SV_GroupIndex, uint3 id : SV_GroupThreadID) {
+  doit(ix, id);
+  Payload pld = (Payload)0;
+  DispatchMesh(1,1,1,pld);
+}
+
+
+[numthreads(4,4,4)]
+[outputtopology("triangle")]
+void MainMS(uint ix : SV_GroupIndex, uint3 id : SV_GroupThreadID) {
+  doit(ix, id);
+}