Explorar el Código

const folding on dxil.convergent.marker. (#2523)

Xiang Li hace 5 años
padre
commit
0bd0afe693

+ 2 - 0
include/dxc/HLSL/DxilConvergent.h

@@ -6,9 +6,11 @@
 // License. See LICENSE.TXT for details.                                     //
 //                                                                           //
 ///////////////////////////////////////////////////////////////////////////////
+#pragma once
 
 namespace llvm {
   class Value;
+  class Function;
 }
 
 namespace hlsl {

+ 15 - 0
include/dxc/HLSL/DxilConvergentName.h

@@ -0,0 +1,15 @@
+///////////////////////////////////////////////////////////////////////////////
+//                                                                           //
+// DxilConvergentName.h                                                      //
+// Copyright (C) Microsoft Corporation. All rights reserved.                 //
+// This file is distributed under the University of Illinois Open Source     //
+// License. See LICENSE.TXT for details.                                     //
+//                                                                           //
+//  Expose helper function name to avoid link issue with spirv.              //
+//                                                                           //
+///////////////////////////////////////////////////////////////////////////////
+#pragma once
+
+namespace hlsl {
+  static char *kConvergentFunctionPrefix = "dxil.convergent.marker.";
+}

+ 22 - 2
lib/Analysis/DxilConstantFolding.cpp

@@ -36,10 +36,23 @@
 #include <functional>
 
 #include "dxc/DXIL/DXIL.h"
-
+#include "dxc/HLSL/DxilConvergentName.h"
 using namespace llvm;
 using namespace hlsl;
 
+namespace {
+
+bool IsConvergentMarker(const Function *F) {
+  return F->getName().startswith(kConvergentFunctionPrefix);
+}
+
+bool IsConvergentMarker(const char *Name) {
+  StringRef RName = Name;
+  return RName.startswith(kConvergentFunctionPrefix);
+}
+
+} // namespace
+
 // Check if the given function is a dxil intrinsic and if so extract the
 // opcode for the instrinsic being called.
 static bool GetDxilOpcode(StringRef Name, ArrayRef<Constant *> Operands, OP::OpCode &out) {
@@ -535,6 +548,12 @@ Constant *hlsl::ConstantFoldScalarCall(StringRef Name, Type *Ty, ArrayRef<Consta
     else if (Ty->isIntegerTy()) {
       return ConstantFoldIntIntrinsic(opcode, Ty, IntrinsicOperands);
     }
+  } else if (IsConvergentMarker(Name.data())) {
+    assert(RawOperands.size() == 1);
+    if (ConstantInt *C = dyn_cast<ConstantInt>(RawOperands[0]))
+      return C;
+    if (ConstantFP *C = dyn_cast<ConstantFP>(RawOperands[0]))
+      return C;
   }
 
   return hlsl::ConstantFoldScalarCallExt(Name, Ty, RawOperands);
@@ -550,7 +569,8 @@ bool hlsl::CanConstantFoldCallTo(const Function *F) {
     assert(!OP::IsDxilOpFunc(F) && "dx.op function with no dxil module?");
     return false;
   }
-
+  if (IsConvergentMarker(F))
+    return true;
   // Lookup opcode class in dxil module. Set default value to invalid class.
   OP::OpCodeClass opClass = OP::OpCodeClass::NumOpClasses;
   const bool found = F->getParent()->GetDxilModule().GetOP()->GetOpCodeClass(F, opClass);

+ 1 - 4
lib/HLSL/DxilConvergent.cpp

@@ -24,14 +24,11 @@
 #include "dxc/HLSL/HLModule.h"
 #include "dxc/HLSL/DxilConvergent.h"
 #include "dxc/HlslIntrinsicOp.h"
+#include "dxc/HLSL/DxilConvergentName.h"
 
 using namespace llvm;
 using namespace hlsl;
 
-namespace {
-const StringRef kConvergentFunctionPrefix = "dxil.convergent.marker.";
-}
-
 bool hlsl::IsConvergentMarker(Value *V) {
   CallInst *CI = dyn_cast<CallInst>(V);
   if (!CI)

+ 41 - 0
tools/clang/test/HLSLFileCheck/hlsl/functions/misc/convergent_const_folding.hlsl

@@ -0,0 +1,41 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// Make sure dxil.convergent.marker will be const folding.
+// CHECK:@dx.op.storeOutput
+
+SamplerState ss1;
+SamplerState ss0;
+
+struct Option
+{
+ bool cond;
+};
+
+
+Texture2D<float4> Tex;
+
+float4 ps(
+ float2 uv,
+ bool cond
+)
+{
+
+ Option op;
+
+ op.cond = cond;
+
+ if (op.cond)
+ {  
+  float c = op.cond ? 0.0f : 1;
+  uv = Tex.Sample(ss0, c).xy;
+ }
+
+ SamplerState texSampler = (op.cond?ss0:ss1);
+ return Tex.Sample(texSampler, uv);
+}
+
+float4 main(float2 uv :UV) :SV_Target
+{
+  return ps(uv, 0);
+
+}