Browse Source

Fix Multiple Component UAV Load Flag Check (#356)

This change is to fix the bug on fixing multi component UAV Load Flag Check by iterating over UAV resources in a DxilModule and check if a given resource is 1) either texture or typed buffer and 2) multi component resource. This change also changes shader flag collection stage after all necessary passes are complete from higher level DXIL to DXIL.
Young Kim 8 năm trước cách đây
mục cha
commit
2c568d00db

+ 4 - 0
include/dxc/HLSL/DxilModule.h

@@ -291,6 +291,10 @@ public:
   ShaderFlags m_ShaderFlags;
   void CollectShaderFlags(ShaderFlags &Flags);
 
+  // Check if DxilModule contains multi component UAV Loads.
+  // This funciton must be called after unused resources are removed from DxilModule
+  bool ModuleHasMulticomponentUAVLoads();
+
   // Compute shader.
   unsigned m_NumThreads[3];
 

+ 4 - 2
lib/HLSL/DxilGenerationPass.cpp

@@ -171,12 +171,12 @@ void InitDxilModuleFromHLModule(HLModule &H, DxilModule &M, bool HasDebugInfo) {
   //bool m_bEnableDoublePrecision;
   //bool m_bEnableDoubleExtensions;
   //bool m_bEnableMinPrecision;
-  M.CollectShaderFlags();
+  //M.CollectShaderFlags();
 
   //bool m_bForceEarlyDepthStencil;
   //bool m_bEnableRawAndStructuredBuffers;
   //bool m_bEnableMSAD;
-  M.m_ShaderFlags.SetAllResourcesBound(H.GetHLOptions().bAllResourcesBound);
+  //M.m_ShaderFlags.SetAllResourcesBound(H.GetHLOptions().bAllResourcesBound);
 
   // Compute shader.
   if (FnProps != nullptr && FnProps->shaderKind == DXIL::ShaderKind::Compute) {
@@ -235,6 +235,8 @@ void InitDxilModuleFromHLModule(HLModule &H, DxilModule &M, bool HasDebugInfo) {
   // Keep llvm used.
   M.EmitLLVMUsed();
 
+  M.m_ShaderFlags.SetAllResourcesBound(H.GetHLOptions().bAllResourcesBound);
+
   // Update Validator Version
   M.UpgradeToMinValidatorVersion();
 }

+ 36 - 35
lib/HLSL/DxilModule.cpp

@@ -291,6 +291,40 @@ unsigned DxilModule::GetGlobalFlags() const {
   return Flags;
 }
 
+static bool IsResourceSingleComponent(llvm::Type *Ty) {
+  if (llvm::ArrayType *arrType = llvm::dyn_cast<llvm::ArrayType>(Ty)) {
+    if (arrType->getArrayNumElements() > 1) {
+      return false;
+    }
+    return IsResourceSingleComponent(arrType->getArrayElementType());
+  } else if (llvm::StructType *structType =
+                 llvm::dyn_cast<llvm::StructType>(Ty)) {
+    if (structType->getStructNumElements() > 1) {
+      return false;
+    }
+    return IsResourceSingleComponent(structType->getStructElementType(0));
+  } else if (llvm::VectorType *vectorType =
+                 llvm::dyn_cast<llvm::VectorType>(Ty)) {
+    if (vectorType->getNumElements() > 1) {
+      return false;
+    }
+    return IsResourceSingleComponent(vectorType->getVectorElementType());
+  }
+  return true;
+}
+
+bool DxilModule::ModuleHasMulticomponentUAVLoads() {
+  for (const auto &uav : GetUAVs()) {
+    const DxilResource *res = uav.get();
+    if (res->IsTypedBuffer() || res->IsAnyTexture()) {
+      if (!IsResourceSingleComponent(res->GetRetType())) {
+          return true;
+      }
+    }
+  }
+  return false;
+}
+
 void DxilModule::CollectShaderFlags(ShaderFlags &Flags) {
   bool hasDouble = false;
   // ddiv dfma drcp d2i d2u i2d u2d.
@@ -301,9 +335,10 @@ void DxilModule::CollectShaderFlags(ShaderFlags &Flags) {
   bool hasWaveOps = false;
   bool hasCheckAccessFully = false;
   bool hasMSAD = false;
-  bool hasMulticomponentUAVLoads = false;
   bool hasInnerCoverage = false;
   bool hasViewID = false;
+  bool hasMulticomponentUAVLoads = ModuleHasMulticomponentUAVLoads();
+
   Type *int16Ty = Type::getInt16Ty(GetCtx());
   Type *int64Ty = Type::getInt64Ty(GetCtx());
 
@@ -367,40 +402,6 @@ void DxilModule::CollectShaderFlags(ShaderFlags &Flags) {
           case DXIL::OpCode::Msad:
             hasMSAD = true;
             break;
-          case DXIL::OpCode::BufferLoad:
-          case DXIL::OpCode::TextureLoad: {
-            Value *resHandle = CI->getArgOperand(DXIL::OperandIndex::kBufferStoreHandleOpIdx);
-            CallInst *handleCall = cast<CallInst>(resHandle);
-
-            if (ConstantInt *resClassArg =
-                    dyn_cast<ConstantInt>(handleCall->getArgOperand(
-                        DXIL::OperandIndex::kCreateHandleResClassOpIdx))) {
-              DXIL::ResourceClass resClass = static_cast<DXIL::ResourceClass>(
-                  resClassArg->getLimitedValue());
-              if (resClass == DXIL::ResourceClass::UAV) {
-                // For DXIL, all uav load is multi component load.
-                hasMulticomponentUAVLoads = true;
-              }
-            } else if (PHINode *resClassPhi = dyn_cast<
-                           PHINode>(handleCall->getArgOperand(
-                           DXIL::OperandIndex::kCreateHandleResClassOpIdx))) {
-              unsigned numOperands = resClassPhi->getNumOperands();
-              for (unsigned i = 0; i < numOperands; i++) {
-                if (ConstantInt *resClassArg = dyn_cast<ConstantInt>(
-                        resClassPhi->getIncomingValue(i))) {
-                  DXIL::ResourceClass resClass =
-                      static_cast<DXIL::ResourceClass>(
-                          resClassArg->getLimitedValue());
-                  if (resClass == DXIL::ResourceClass::UAV) {
-                    // For DXIL, all uav load is multi component load.
-                    hasMulticomponentUAVLoads = true;
-                    break;
-                  }
-                }
-              }
-            }
-
-          } break;
           case DXIL::OpCode::Fma:
             hasDoubleExtension |= isDouble;
             break;

+ 11 - 0
tools/clang/test/CodeGenHLSL/multiUAVLoad1.hlsl

@@ -0,0 +1,11 @@
+// RUN: %dxc -E main -T cs_6_0 %s | FileCheck %s
+
+// CHECK: Typed UAV Load Additional Formats
+
+RWBuffer<float4> g_buf : register(u0);
+[numthreads(8,8,1)]
+void main(uint GI : SV_GroupIndex) {
+    uint addr = GI * 4;
+    float4 val = g_buf.Load(addr);
+    g_buf[addr] = val + 1;
+}

+ 11 - 0
tools/clang/test/CodeGenHLSL/multiUAVLoad2.hlsl

@@ -0,0 +1,11 @@
+// RUN: %dxc -E main -T cs_6_0 %s | FileCheck %s
+
+// CHECK-NOT: Typed UAV Load Additional Formats
+
+RWBuffer<float1> g_buf : register(u0);
+[numthreads(8,8,1)]
+void main(uint GI : SV_GroupIndex) {
+    uint addr = GI * 4;
+    float1 val = g_buf.Load(addr);
+    g_buf[addr] = val + 1;
+}

+ 15 - 0
tools/clang/test/CodeGenHLSL/multiUAVLoad3.hlsl

@@ -0,0 +1,15 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// CHECK: Typed UAV Load Additional Formats
+
+struct PSInput {
+    float4 position : SV_POSITION;
+    float2 uv : TEXCOORD;
+};
+
+RWTexture2D<float2> g_tex : register(u0);
+
+float4 main(PSInput input) : SV_TARGET {
+    float2 val = g_tex.Load(input.uv);
+    return val.xyxx;
+}

+ 15 - 0
tools/clang/test/CodeGenHLSL/multiUAVLoad4.hlsl

@@ -0,0 +1,15 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// CHECK-NOT: Typed UAV Load Additional Formats
+
+struct PSInput {
+    float4 position : SV_POSITION;
+    float2 uv : TEXCOORD;
+};
+
+RWTexture2D<float> g_tex : register(u0);
+
+float4 main(PSInput input) : SV_TARGET {
+    float val = g_tex.Load(input.uv);
+    return val;
+}

+ 11 - 0
tools/clang/test/CodeGenHLSL/multiUAVLoad5.hlsl

@@ -0,0 +1,11 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// CHECK-NOT: Typed UAV Load Additional Formats
+
+RWByteAddressBuffer g_bab : register(u0);
+[numthreads(8,8,1)]
+void main(uint GI : SV_GroupIndex) {
+    uint addr = GI * 4;
+    uint val = g_bab.Load(addr);
+    g_bab.Store(addr, val + 1);
+}

+ 16 - 0
tools/clang/test/CodeGenHLSL/multiUAVLoad6.hlsl

@@ -0,0 +1,16 @@
+// RUN: %dxc -E main -T cs_6_0 %s | FileCheck %s
+
+// CHECK: Typed UAV Load Additional Formats
+
+RWBuffer<int> buf[5][1];
+RWBuffer<int2> buf_2[3];
+int ref;
+
+[numthreads(64,1,1)]
+void main(uint tid : SV_DispatchThreadID)
+{
+  if (ref > 0)
+    buf[1][0][1] = 3;
+  else
+    buf_2[2][3] = 3;
+}

+ 1 - 2
tools/clang/test/CodeGenHLSL/uint64_2.hlsl

@@ -1,10 +1,9 @@
 // RUN: %dxc -E main -T cs_6_0 -not_use_legacy_cbuf_load  %s | FileCheck %s
 
-// CHECK: Typed UAV Load Additional Formats
 // CHECK: 64-Bit integer
 // CHECK: dx.op.bufferStore.i32
 // CHECK: dx.op.bufferStore.i32
-// CHECK: !{i32 0, i64 1056768
+// CHECK: !{i32 0, i64 1048576
 
 // Note: a change in the internal layout will produce
 // a difference in the serialized flags, eg:

+ 30 - 0
tools/clang/unittests/HLSL/CompilerTest.cpp

@@ -545,6 +545,12 @@ public:
   TEST_METHOD(CodeGenMinprec6)
   TEST_METHOD(CodeGenMinprec7)
   TEST_METHOD(CodeGenMinprecCast)
+  TEST_METHOD(CodeGenMultiUAVLoad1)
+  TEST_METHOD(CodeGenMultiUAVLoad2)
+  TEST_METHOD(CodeGenMultiUAVLoad3)
+  TEST_METHOD(CodeGenMultiUAVLoad4)
+  TEST_METHOD(CodeGenMultiUAVLoad5)
+  TEST_METHOD(CodeGenMultiUAVLoad6)
   TEST_METHOD(CodeGenMultiStream)
   TEST_METHOD(CodeGenMultiStream2)
   TEST_METHOD(CodeGenNeg1)
@@ -2944,6 +2950,30 @@ TEST_F(CompilerTest, CodeGenMinprecCast) {
   CodeGenTest(L"..\\CodeGenHLSL\\minprec_cast.hlsl");
 }
 
+TEST_F(CompilerTest, CodeGenMultiUAVLoad1) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\multiUAVLoad1.hlsl");
+}
+
+TEST_F(CompilerTest, CodeGenMultiUAVLoad2) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\multiUAVLoad2.hlsl");
+}
+
+TEST_F(CompilerTest, CodeGenMultiUAVLoad3) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\multiUAVLoad3.hlsl");
+}
+
+TEST_F(CompilerTest, CodeGenMultiUAVLoad4) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\multiUAVLoad4.hlsl");
+}
+
+TEST_F(CompilerTest, CodeGenMultiUAVLoad5) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\multiUAVLoad5.hlsl");
+}
+
+TEST_F(CompilerTest, CodeGenMultiUAVLoad6) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\multiUAVLoad6.hlsl");
+}
+
 TEST_F(CompilerTest, CodeGenMultiStream) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\multiStreamGS.hlsl");
 }