Pārlūkot izejas kodu

[DXIL-2017-06] Check multi component uav on load not store operations. (#373)

Check multi component uav on load not store operations.
Young Kim 8 gadi atpakaļ
vecāks
revīzija
3eb1daa46e

+ 61 - 11
lib/HLSL/DxilModule.cpp

@@ -312,16 +312,25 @@ static bool IsResourceSingleComponent(llvm::Type *Ty) {
   return true;
 }
 
-bool DxilModule::ModuleHasMulticomponentUAVLoads() {
-  for (const auto &uav : GetUAVs()) {
-    const DxilResource *res = uav.get();
-    if (res->IsTypedBuffer() || res->IsAnyTexture()) {
-      if (!IsResourceSingleComponent(res->GetRetType())) {
-          return true;
-      }
+// Given a CreateHandle call, returns arbitrary ConstantInt rangeID
+// Note: HLSL is currently assuming that rangeID is a constant value, but this code is assuming
+// that it can be either constant, phi node, or select instruction
+static ConstantInt *GetArbitraryConstantRangeID(CallInst *handleCall) {
+  Value *rangeID =
+      handleCall->getArgOperand(DXIL::OperandIndex::kCreateHandleResIDOpIdx);
+  ConstantInt *ConstantRangeID = dyn_cast<ConstantInt>(rangeID);
+  while (ConstantRangeID == nullptr) {
+    if (ConstantInt *CI = dyn_cast<ConstantInt>(rangeID)) {
+      ConstantRangeID = CI;
+    } else if (PHINode *PN = dyn_cast<PHINode>(rangeID)) {
+      rangeID = PN->getIncomingValue(0);
+    } else if (SelectInst *SI = dyn_cast<SelectInst>(rangeID)) {
+      rangeID = SI->getTrueValue();
+    } else {
+      return nullptr;
     }
   }
-  return false;
+  return ConstantRangeID;
 }
 
 void DxilModule::CollectShaderFlags(ShaderFlags &Flags) {
@@ -336,7 +345,15 @@ void DxilModule::CollectShaderFlags(ShaderFlags &Flags) {
   bool hasMSAD = false;
   bool hasInnerCoverage = false;
   bool hasViewID = false;
-  bool hasMulticomponentUAVLoads = ModuleHasMulticomponentUAVLoads();
+  bool hasMulticomponentUAVLoads = false;
+  bool hasMulticomponentUAVLoadsBackCompat = false;
+
+  // Try to maintain compatibility with a v1.0 validator if that's what we have.
+  {
+    unsigned valMajor, valMinor;
+    GetValidatorVersion(valMajor, valMinor);
+    hasMulticomponentUAVLoadsBackCompat = valMajor <= 1 && valMinor == 0;
+  }
 
   Type *int16Ty = Type::getInt16Ty(GetCtx());
   Type *int64Ty = Type::getInt64Ty(GetCtx());
@@ -401,6 +418,39 @@ void DxilModule::CollectShaderFlags(ShaderFlags &Flags) {
           case DXIL::OpCode::Msad:
             hasMSAD = true;
             break;
+          case DXIL::OpCode::BufferLoad:
+          case DXIL::OpCode::TextureLoad: {
+            if (hasMulticomponentUAVLoads) continue;
+            // This is the old-style computation (overestimating requirements).
+            Value *resHandle = CI->getArgOperand(DXIL::OperandIndex::kBufferStoreHandleOpIdx);
+            CallInst *handleCall = cast<CallInst>(resHandle);
+
+            if (ConstantInt *resClassArg =
+              dyn_cast<ConstantInt>(handleCall->getArgOperand(
+                DXIL::OperandIndex::kCreateHandleResClassOpIdx))) {
+              DXIL::ResourceClass resClass = static_cast<DXIL::ResourceClass>(
+                resClassArg->getLimitedValue());
+              if (resClass == DXIL::ResourceClass::UAV) {
+                // Validator 1.0 assumes that all uav load is multi component load.
+                if (hasMulticomponentUAVLoadsBackCompat) {
+                  hasMulticomponentUAVLoads = true;
+                  continue;
+                }
+                else {
+                  ConstantInt *rangeID = GetArbitraryConstantRangeID(handleCall);
+                  if (rangeID) {
+                      DxilResource resource = GetUAV(rangeID->getLimitedValue());
+                      if (!IsResourceSingleComponent(resource.GetRetType())) {
+                          hasMulticomponentUAVLoads = true;
+                      }
+                  }
+                }
+              }
+            }
+            else {
+                DXASSERT(false, "Resource class must be constant.");
+            }
+          } break;
           case DXIL::OpCode::Fma:
             hasDoubleExtension |= isDouble;
             break;
@@ -793,12 +843,12 @@ static void CollectUsedResource(Value *resID,
   } else if (SelectInst *SI = dyn_cast<SelectInst>(resID)) {
     CollectUsedResource(SI->getTrueValue(), usedResID);
     CollectUsedResource(SI->getFalseValue(), usedResID);
-  } else {
-    PHINode *Phi = cast<PHINode>(resID);
+  } else if (PHINode *Phi = dyn_cast<PHINode>(resID)) {
     for (Use &U : Phi->incoming_values()) {
       CollectUsedResource(U.get(), usedResID);
     }
   }
+  // TODO: resID could be other types of instructions depending on the compiler optimization.
 }
 
 static void ConvertUsedResource(std::unordered_set<unsigned> &immResID,

+ 1 - 1
tools/clang/test/CodeGenHLSL/multiUAVLoad5.hlsl

@@ -1,4 +1,4 @@
-// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+// RUN: %dxc -E main -T cs_6_0 %s | FileCheck %s
 
 // CHECK-NOT: Typed UAV Load Additional Formats
 

+ 1 - 1
tools/clang/test/CodeGenHLSL/multiUAVLoad6.hlsl

@@ -1,6 +1,6 @@
 // RUN: %dxc -E main -T cs_6_0 %s | FileCheck %s
 
-// CHECK: Typed UAV Load Additional Formats
+// CHECK-NOT: Typed UAV Load Additional Formats
 
 RWBuffer<int> buf[5][1];
 RWBuffer<int2> buf_2[3];

+ 5 - 0
tools/clang/unittests/HLSL/CompilerTest.cpp

@@ -2924,6 +2924,7 @@ TEST_F(CompilerTest, CodeGenMultiUAVLoad1) {
 }
 
 TEST_F(CompilerTest, CodeGenMultiUAVLoad2) {
+  if (m_ver.SkipDxil_1_1_Test()) return;
   CodeGenTestCheck(L"..\\CodeGenHLSL\\multiUAVLoad2.hlsl");
 }
 
@@ -2932,14 +2933,17 @@ TEST_F(CompilerTest, CodeGenMultiUAVLoad3) {
 }
 
 TEST_F(CompilerTest, CodeGenMultiUAVLoad4) {
+  if (m_ver.SkipDxil_1_1_Test()) return;
   CodeGenTestCheck(L"..\\CodeGenHLSL\\multiUAVLoad4.hlsl");
 }
 
 TEST_F(CompilerTest, CodeGenMultiUAVLoad5) {
+  if (m_ver.SkipDxil_1_1_Test()) return;
   CodeGenTestCheck(L"..\\CodeGenHLSL\\multiUAVLoad5.hlsl");
 }
 
 TEST_F(CompilerTest, CodeGenMultiUAVLoad6) {
+  if (m_ver.SkipDxil_1_1_Test()) return;
   CodeGenTestCheck(L"..\\CodeGenHLSL\\multiUAVLoad6.hlsl");
 }
 
@@ -3408,6 +3412,7 @@ TEST_F(CompilerTest, CodeGenUint64_1) {
 }
 
 TEST_F(CompilerTest, CodeGenUint64_2) {
+  if (m_ver.SkipDxil_1_1_Test()) return;
   CodeGenTestCheck(L"..\\CodeGenHLSL\\uint64_2.hlsl");
 }