浏览代码

Change SROA_HLSL to not iterate through use if modifying them.

Tex Riddell 6 年之前
父节点
当前提交
69c685f3ea

+ 78 - 52
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -2707,8 +2707,6 @@ void SROA_Helper::RewriteForGEP(GEPOperator *GEP, IRBuilder<> &Builder) {
     assert(NewGEP->getType() == GEP->getType() && "type mismatch");
     
     GEP->replaceAllUsesWith(NewGEP);
-    if (isa<Instruction>(GEP))
-      DeadInsts.push_back(GEP);
   } else {
     // End at array of basic type.
     Type *Ty = GEP->getType()->getPointerElementType();
@@ -2725,22 +2723,16 @@ void SROA_Helper::RewriteForGEP(GEPOperator *GEP, IRBuilder<> &Builder) {
         NewGEPs.emplace_back(NewGEP);
       }
       const bool bAllowReplace = isa<AllocaInst>(OldVal);
-      if (SROA_Helper::LowerMemcpy(GEP, /*annoation*/ nullptr, typeSys, DL,
-                                   bAllowReplace)) {
-        if (GEP->user_empty() && isa<Instruction>(GEP))
-          DeadInsts.push_back(GEP);
-        return;
-      }
-      SROA_Helper helper(GEP, NewGEPs, DeadInsts, typeSys, DL);
-      helper.RewriteForScalarRepl(GEP, Builder);
-      for (Value *NewGEP : NewGEPs) {
-        if (NewGEP->user_empty() && isa<Instruction>(NewGEP)) {
-          // Delete unused newGEP.
-          cast<Instruction>(NewGEP)->eraseFromParent();
+      if (!SROA_Helper::LowerMemcpy(GEP, /*annoation*/ nullptr, typeSys, DL, bAllowReplace)) {
+        SROA_Helper helper(GEP, NewGEPs, DeadInsts, typeSys, DL);
+        helper.RewriteForScalarRepl(GEP, Builder);
+        for (Value *NewGEP : NewGEPs) {
+          if (NewGEP->user_empty() && isa<Instruction>(NewGEP)) {
+            // Delete unused newGEP.
+            cast<Instruction>(NewGEP)->eraseFromParent();
+          }
         }
       }
-      if (GEP->user_empty() && isa<Instruction>(GEP))
-        DeadInsts.push_back(GEP);
     } else {
       Value *vecIdx = NewArgs.back();
       if (ConstantInt *immVecIdx = dyn_cast<ConstantInt>(vecIdx)) {
@@ -2758,14 +2750,22 @@ void SROA_Helper::RewriteForGEP(GEPOperator *GEP, IRBuilder<> &Builder) {
         assert(NewGEP->getType() == GEP->getType() && "type mismatch");
 
         GEP->replaceAllUsesWith(NewGEP);
-        if (isa<Instruction>(GEP))
-          DeadInsts.push_back(GEP);
       } else {
         // dynamic vector indexing.
         assert(0 && "should not reach here");
       }
     }
   }
+
+  // Remove the use so that the caller can keep iterating over its other users
+  DXASSERT(GEP->user_empty(), "All uses of the GEP should have been eliminated");
+  if (isa<Instruction>(GEP)) {
+    GEP->setOperand(GEP->getPointerOperandIndex(), UndefValue::get(GEP->getPointerOperand()->getType()));
+    DeadInsts.push_back(GEP);
+  }
+  else {
+    cast<Constant>(GEP)->destroyConstant();
+  }
 }
 
 /// isVectorOrStructArray - Check if T is array of vector or struct.
@@ -2828,7 +2828,6 @@ void SROA_Helper::RewriteForLoad(LoadInst *LI) {
       Insert = Builder.CreateInsertElement(Insert, Load, i, "insert");
     }
     LI->replaceAllUsesWith(Insert);
-    DeadInsts.push_back(LI);
   } else if (isCompatibleAggregate(LIType, ValTy)) {
     if (isVectorOrStructArray(LIType)) {
       // Replace:
@@ -2846,7 +2845,6 @@ void SROA_Helper::RewriteForLoad(LoadInst *LI) {
       Value *newLd =
           LoadVectorOrStructArray(cast<ArrayType>(LIType), NewElts, idxList, Builder);
       LI->replaceAllUsesWith(newLd);
-      DeadInsts.push_back(LI);
     } else {
       // Replace:
       //   %res = load { i32, i32 }* %alloc
@@ -2880,11 +2878,14 @@ void SROA_Helper::RewriteForLoad(LoadInst *LI) {
       if (LIType->isStructTy()) {
         SimplifyStructValUsage(Insert, LdElts, DeadInsts);
       }
-      DeadInsts.push_back(LI);
     }
   } else {
     llvm_unreachable("other type don't need rewrite");
   }
+
+  // Remove the use so that the caller can keep iterating over its other users
+  LI->setOperand(LI->getPointerOperandIndex(), UndefValue::get(LI->getPointerOperand()->getType()));
+  DeadInsts.push_back(LI);
 }
 
 /// RewriteForStore - Replace OldVal with flattened NewElts in StoreInst.
@@ -2906,7 +2907,6 @@ void SROA_Helper::RewriteForStore(StoreInst *SI) {
       Value *Extract = Builder.CreateExtractElement(Val, i, Val->getName());
       Builder.CreateStore(Extract, NewElts[i]);
     }
-    DeadInsts.push_back(SI);
   } else if (isCompatibleAggregate(SIType, ValTy)) {
     if (isVectorOrStructArray(SIType)) {
       // Replace:
@@ -2936,7 +2936,6 @@ void SROA_Helper::RewriteForStore(StoreInst *SI) {
       SmallVector<Value *, 8> idxList;
       idxList.emplace_back(zero);
       StoreVectorOrStructArray(AT, Val, NewElts, idxList, Builder);
-      DeadInsts.push_back(SI);
     } else {
       // Replace:
       //   store { i32, i32 } %val, { i32, i32 }* %alloc
@@ -2959,11 +2958,14 @@ void SROA_Helper::RewriteForStore(StoreInst *SI) {
               Extract->getType(), {NewElts[i], Extract}, *M);
         }
       }
-      DeadInsts.push_back(SI);
     }
   } else {
     llvm_unreachable("other type don't need rewrite");
   }
+
+  // Remove the use so that the caller can keep iterating over its other users
+  SI->setOperand(SI->getPointerOperandIndex(), UndefValue::get(SI->getPointerOperand()->getType()));
+  DeadInsts.push_back(SI);
 }
 /// RewriteMemIntrin - MI is a memcpy/memset/memmove from or to AI.
 /// Rewrite it to copy or set the elements of the scalarized memory.
@@ -3006,6 +3008,10 @@ void SROA_Helper::RewriteMemIntrin(MemIntrinsic *MI, Value *OldV) {
            I != E; ++I)
         if (*I == MI)
           return;
+
+      // Remove the uses so that the caller can keep iterating over its other users
+      MI->setOperand(0, UndefValue::get(MI->getOperand(0)->getType()));
+      MI->setOperand(1, UndefValue::get(MI->getOperand(1)->getType()));
       DeadInsts.push_back(MI);
       return;
     }
@@ -3136,6 +3142,11 @@ void SROA_Helper::RewriteMemIntrin(MemIntrinsic *MI, Value *OldV) {
                               MI->isVolatile());
     }
   }
+
+  // Remove the use so that the caller can keep iterating over its other users
+  MI->setOperand(0, UndefValue::get(MI->getOperand(0)->getType()));
+  if (isa<MemTransferInst>(MI))
+    MI->setOperand(1, UndefValue::get(MI->getOperand(1)->getType()));
   DeadInsts.push_back(MI);
 }
 
@@ -3317,6 +3328,13 @@ void SROA_Helper::RewriteForAddrSpaceCast(Value *CE,
   }
   SROA_Helper helper(CE, NewCasts, DeadInsts, typeSys, DL);
   helper.RewriteForScalarRepl(CE, Builder);
+
+  // Remove the use so that the caller can keep iterating over its other users
+  DXASSERT(CE->user_empty(), "All uses of the addrspacecast should have been eliminated");
+  if (Instruction *I = dyn_cast<Instruction>(CE))
+    I->eraseFromParent();
+  else
+    cast<Constant>(CE)->destroyConstant();
 }
 
 /// RewriteForConstExpr - Rewrite the GEP which is ConstantExpr.
@@ -3335,10 +3353,6 @@ void SROA_Helper::RewriteForConstExpr(ConstantExpr *CE, IRBuilder<> &Builder) {
       return;
     }
   }
-  // Skip unused CE. 
-  if (CE->use_empty())
-    return;
-
   for (Value::use_iterator UI = CE->use_begin(), E = CE->use_end(); UI != E;) {
     Use &TheUse = *UI++;
     if (Instruction *I = dyn_cast<Instruction>(TheUse.getUser())) {
@@ -3352,37 +3366,49 @@ void SROA_Helper::RewriteForConstExpr(ConstantExpr *CE, IRBuilder<> &Builder) {
       RewriteForConstExpr(cast<ConstantExpr>(TheUse.getUser()), Builder);
     }
   }
+
+  // Remove the use so that the caller can keep iterating over its other users
+  DXASSERT(CE->user_empty(), "All uses of the constantexpr should have been eliminated");
+  CE->destroyConstant();
 }
 /// RewriteForScalarRepl - OldVal is being split into NewElts, so rewrite
 /// users of V, which references it, to use the separate elements.
 void SROA_Helper::RewriteForScalarRepl(Value *V, IRBuilder<> &Builder) {
-
-  for (Value::use_iterator UI = V->use_begin(), E = V->use_end(); UI != E;) {
-    Use &TheUse = *UI++;
-
+  // Don't iterate upon the uses explicitly because we'll be removing them,
+  // and potentially adding new ones (if expanding memcpys) during the iteration.
+  Use* PrevUse = nullptr;
+  while (!V->use_empty()) {
+    Use &TheUse = *V->use_begin();
+
+    DXASSERT_LOCALVAR(PrevUse, &TheUse != PrevUse,
+      "Infinite loop while SROA'ing value, use isn't getting eliminated.");
+    PrevUse = &TheUse;
+
+    // Each of these must either call ->eraseFromParent()
+    // or null out the use of V so that we make progress.
     if (ConstantExpr *CE = dyn_cast<ConstantExpr>(TheUse.getUser())) {
       RewriteForConstExpr(CE, Builder);
-      continue;
     }
-    Instruction *User = cast<Instruction>(TheUse.getUser());
-
-    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(User)) {
-      IRBuilder<> Builder(GEP);
-      RewriteForGEP(cast<GEPOperator>(GEP), Builder);
-    } else if (LoadInst *ldInst = dyn_cast<LoadInst>(User))
-      RewriteForLoad(ldInst);
-    else if (StoreInst *stInst = dyn_cast<StoreInst>(User))
-      RewriteForStore(stInst);
-    else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(User))
-      RewriteMemIntrin(MI, cast<Instruction>(V));
-    else if (CallInst *CI = dyn_cast<CallInst>(User)) 
-      RewriteCall(CI);
-    else if (BitCastInst *BCI = dyn_cast<BitCastInst>(User))
-      RewriteBitCast(BCI);
-    else if (AddrSpaceCastInst *CI = dyn_cast<AddrSpaceCastInst>(User)) {
-      RewriteForAddrSpaceCast(CI, Builder);
-    } else {
-      assert(0 && "not support.");
+    else {
+      Instruction *User = cast<Instruction>(TheUse.getUser());
+      if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(User)) {
+        IRBuilder<> Builder(GEP);
+        RewriteForGEP(cast<GEPOperator>(GEP), Builder);
+      } else if (LoadInst *ldInst = dyn_cast<LoadInst>(User))
+        RewriteForLoad(ldInst);
+      else if (StoreInst *stInst = dyn_cast<StoreInst>(User))
+        RewriteForStore(stInst);
+      else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(User))
+        RewriteMemIntrin(MI, V);
+      else if (CallInst *CI = dyn_cast<CallInst>(User)) 
+        RewriteCall(CI);
+      else if (BitCastInst *BCI = dyn_cast<BitCastInst>(User))
+        RewriteBitCast(BCI);
+      else if (AddrSpaceCastInst *CI = dyn_cast<AddrSpaceCastInst>(User)) {
+        RewriteForAddrSpaceCast(CI, Builder);
+      } else {
+        assert(0 && "not support.");
+      }
     }
   }
 }

+ 14 - 0
tools/clang/test/CodeGenHLSL/passes/sroa_hlsl/groupshared_array_struct_matrix_regression.hlsl

@@ -0,0 +1,14 @@
+// RUN: %dxc -E main -T vs_6_2 %s | FileCheck %s
+
+// Regression test for GitHub #1631, where SROA would generate more uses
+// of a value while processing it (due to expanding a memcpy) and fail
+// to process the new uses. This caused global structs of matrices to reach HLMatrixLower,
+// which couldn't handle them and would unexpectedly leave matrix intrinsics untouched.
+// Compilation would then fail with "error: Fail to lower matrix load/store."
+
+// CHECK: ret void
+
+struct S { int1x1 x, y; };
+groupshared S gs[1];
+void f(S s[1]) {}
+void main() { f(gs); }