Selaa lähdekoodia

Check multi component uav on load not store operations. (#372)

Check multi component uav on load not store operations.
Young Kim 8 vuotta sitten
vanhempi
commit
7f82a9e98c

+ 34 - 31
lib/HLSL/DxilModule.cpp

@@ -312,16 +312,25 @@ static bool IsResourceSingleComponent(llvm::Type *Ty) {
   return true;
   return true;
 }
 }
 
 
-bool DxilModule::ModuleHasMulticomponentUAVLoads() {
-  for (const auto &uav : GetUAVs()) {
-    const DxilResource *res = uav.get();
-    if (res->IsTypedBuffer() || res->IsAnyTexture()) {
-      if (!IsResourceSingleComponent(res->GetRetType())) {
-          return true;
-      }
+// Given a CreateHandle call, returns arbitrary ConstantInt rangeID
+// Note: HLSL is currently assuming that rangeID is a constant value, but this code is assuming
+// that it can be either constant, phi node, or select instruction
+static ConstantInt *GetArbitraryConstantRangeID(CallInst *handleCall) {
+  Value *rangeID =
+      handleCall->getArgOperand(DXIL::OperandIndex::kCreateHandleResIDOpIdx);
+  ConstantInt *ConstantRangeID = dyn_cast<ConstantInt>(rangeID);
+  while (ConstantRangeID == nullptr) {
+    if (ConstantInt *CI = dyn_cast<ConstantInt>(rangeID)) {
+      ConstantRangeID = CI;
+    } else if (PHINode *PN = dyn_cast<PHINode>(rangeID)) {
+      rangeID = PN->getIncomingValue(0);
+    } else if (SelectInst *SI = dyn_cast<SelectInst>(rangeID)) {
+      rangeID = SI->getTrueValue();
+    } else {
+      return nullptr;
     }
     }
   }
   }
-  return false;
+  return ConstantRangeID;
 }
 }
 
 
 void DxilModule::CollectShaderFlags(ShaderFlags &Flags) {
 void DxilModule::CollectShaderFlags(ShaderFlags &Flags) {
@@ -345,8 +354,6 @@ void DxilModule::CollectShaderFlags(ShaderFlags &Flags) {
     GetValidatorVersion(valMajor, valMinor);
     GetValidatorVersion(valMajor, valMinor);
     hasMulticomponentUAVLoadsBackCompat = valMajor <= 1 && valMinor == 0;
     hasMulticomponentUAVLoadsBackCompat = valMajor <= 1 && valMinor == 0;
   }
   }
-  if (!hasMulticomponentUAVLoadsBackCompat)
-    hasMulticomponentUAVLoads = ModuleHasMulticomponentUAVLoads();
 
 
   Type *int16Ty = Type::getInt16Ty(GetCtx());
   Type *int16Ty = Type::getInt16Ty(GetCtx());
   Type *int64Ty = Type::getInt64Ty(GetCtx());
   Type *int64Ty = Type::getInt64Ty(GetCtx());
@@ -414,7 +421,6 @@ void DxilModule::CollectShaderFlags(ShaderFlags &Flags) {
           case DXIL::OpCode::BufferLoad:
           case DXIL::OpCode::BufferLoad:
           case DXIL::OpCode::TextureLoad: {
           case DXIL::OpCode::TextureLoad: {
             if (hasMulticomponentUAVLoads) continue;
             if (hasMulticomponentUAVLoads) continue;
-            if (!hasMulticomponentUAVLoadsBackCompat) continue;
             // This is the old-style computation (overestimating requirements).
             // This is the old-style computation (overestimating requirements).
             Value *resHandle = CI->getArgOperand(DXIL::OperandIndex::kBufferStoreHandleOpIdx);
             Value *resHandle = CI->getArgOperand(DXIL::OperandIndex::kBufferStoreHandleOpIdx);
             CallInst *handleCall = cast<CallInst>(resHandle);
             CallInst *handleCall = cast<CallInst>(resHandle);
@@ -425,28 +431,25 @@ void DxilModule::CollectShaderFlags(ShaderFlags &Flags) {
               DXIL::ResourceClass resClass = static_cast<DXIL::ResourceClass>(
               DXIL::ResourceClass resClass = static_cast<DXIL::ResourceClass>(
                 resClassArg->getLimitedValue());
                 resClassArg->getLimitedValue());
               if (resClass == DXIL::ResourceClass::UAV) {
               if (resClass == DXIL::ResourceClass::UAV) {
-                // For DXIL, all uav load is multi component load.
-                hasMulticomponentUAVLoads = true;
-              }
-            }
-            else if (PHINode *resClassPhi = dyn_cast<
-              PHINode>(handleCall->getArgOperand(
-                DXIL::OperandIndex::kCreateHandleResClassOpIdx))) {
-              unsigned numOperands = resClassPhi->getNumOperands();
-              for (unsigned i = 0; i < numOperands; i++) {
-                if (ConstantInt *resClassArg = dyn_cast<ConstantInt>(
-                  resClassPhi->getIncomingValue(i))) {
-                  DXIL::ResourceClass resClass =
-                    static_cast<DXIL::ResourceClass>(
-                      resClassArg->getLimitedValue());
-                  if (resClass == DXIL::ResourceClass::UAV) {
-                    // For DXIL, all uav load is multi component load.
-                    hasMulticomponentUAVLoads = true;
-                    break;
+                // Validator 1.0 assumes that all uav load is multi component load.
+                if (hasMulticomponentUAVLoadsBackCompat) {
+                  hasMulticomponentUAVLoads = true;
+                  continue;
+                }
+                else {
+                  ConstantInt *rangeID = GetArbitraryConstantRangeID(handleCall);
+                  if (rangeID) {
+                      DxilResource resource = GetUAV(rangeID->getLimitedValue());
+                      if (!IsResourceSingleComponent(resource.GetRetType())) {
+                          hasMulticomponentUAVLoads = true;
+                      }
                   }
                   }
                 }
                 }
               }
               }
             }
             }
+            else {
+                DXASSERT(false, "Resource class must be constant.");
+            }
           } break;
           } break;
           case DXIL::OpCode::Fma:
           case DXIL::OpCode::Fma:
             hasDoubleExtension |= isDouble;
             hasDoubleExtension |= isDouble;
@@ -941,12 +944,12 @@ static void CollectUsedResource(Value *resID,
   } else if (SelectInst *SI = dyn_cast<SelectInst>(resID)) {
   } else if (SelectInst *SI = dyn_cast<SelectInst>(resID)) {
     CollectUsedResource(SI->getTrueValue(), usedResID);
     CollectUsedResource(SI->getTrueValue(), usedResID);
     CollectUsedResource(SI->getFalseValue(), usedResID);
     CollectUsedResource(SI->getFalseValue(), usedResID);
-  } else {
-    PHINode *Phi = cast<PHINode>(resID);
+  } else if (PHINode *Phi = dyn_cast<PHINode>(resID)) {
     for (Use &U : Phi->incoming_values()) {
     for (Use &U : Phi->incoming_values()) {
       CollectUsedResource(U.get(), usedResID);
       CollectUsedResource(U.get(), usedResID);
     }
     }
   }
   }
+  // TODO: resID could be other types of instructions depending on the compiler optimization.
 }
 }
 
 
 static void ConvertUsedResource(std::unordered_set<unsigned> &immResID,
 static void ConvertUsedResource(std::unordered_set<unsigned> &immResID,

+ 1 - 1
tools/clang/test/CodeGenHLSL/multiUAVLoad5.hlsl

@@ -1,4 +1,4 @@
-// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+// RUN: %dxc -E main -T cs_6_0 %s | FileCheck %s
 
 
 // CHECK-NOT: Typed UAV Load Additional Formats
 // CHECK-NOT: Typed UAV Load Additional Formats
 
 

+ 1 - 1
tools/clang/test/CodeGenHLSL/multiUAVLoad6.hlsl

@@ -1,6 +1,6 @@
 // RUN: %dxc -E main -T cs_6_0 %s | FileCheck %s
 // RUN: %dxc -E main -T cs_6_0 %s | FileCheck %s
 
 
-// CHECK: Typed UAV Load Additional Formats
+// CHECK-NOT: Typed UAV Load Additional Formats
 
 
 RWBuffer<int> buf[5][1];
 RWBuffer<int> buf[5][1];
 RWBuffer<int2> buf_2[3];
 RWBuffer<int2> buf_2[3];

+ 1 - 0
tools/clang/unittests/HLSL/CompilerTest.cpp

@@ -3006,6 +3006,7 @@ TEST_F(CompilerTest, CodeGenMultiUAVLoad4) {
 }
 }
 
 
 TEST_F(CompilerTest, CodeGenMultiUAVLoad5) {
 TEST_F(CompilerTest, CodeGenMultiUAVLoad5) {
+  if (m_ver.SkipDxil_1_1_Test()) return;
   CodeGenTestCheck(L"..\\CodeGenHLSL\\multiUAVLoad5.hlsl");
   CodeGenTestCheck(L"..\\CodeGenHLSL\\multiUAVLoad5.hlsl");
 }
 }