浏览代码

Merged PR 31: Support select chain in TranslateHLCreateHandle.

Support select chain in TranslateHLCreateHandle.
Xiang_Li (XBox) 7 年之前
父节点
当前提交
88de36b2f5

+ 7 - 0
include/dxc/HLSL/DxilUtil.h

@@ -10,6 +10,7 @@
 ///////////////////////////////////////////////////////////////////////////////
 
 #pragma once
+#include <unordered_set>
 
 namespace llvm {
 class Type;
@@ -51,6 +52,12 @@ namespace dxilutil {
   // NewInst = phi A, B, C
   // Only support 1 operand now, other oerands should be Constant.
   llvm::Value * SelectOnOperation(llvm::Instruction *Inst, unsigned operandIdx);
+  // Collect all select operand used by Inst.
+  void CollectSelect(llvm::Instruction *Inst,
+                   std::unordered_set<llvm::Instruction *> &selectSet);
+  // If all operands are the same for a select inst, replace it with the operand.
+  bool MergeSelectOnSameValue(llvm::Instruction *SelInst, unsigned startOpIdx,
+                            unsigned numOperands);
   std::unique_ptr<llvm::Module> LoadModuleFromBitcode(llvm::StringRef BC,
     llvm::LLVMContext &Ctx, std::string &DiagStr);
   std::unique_ptr<llvm::Module> LoadModuleFromBitcode(llvm::MemoryBuffer *MB,

+ 3 - 48
lib/HLSL/DxilCondenseResources.cpp

@@ -1079,31 +1079,6 @@ bool DxilLowerCreateHandleForLib::PatchTBuffers(DxilModule &DM) {
 // phi1 = phi a1, b1, c1
 // NewInst = Add(phi0, phi1);
 namespace {
-void CollectSelect(llvm::Instruction *Inst,
-                   std::unordered_set<llvm::Instruction *> &selectSet) {
-  unsigned startOpIdx = 0;
-  // Skip Cond for Select.
-  if (isa<SelectInst>(Inst)) {
-    startOpIdx = 1;
-  } else if (!isa<PHINode>(Inst)) {
-    // Only check phi and select here.
-    return;
-  }
-  // Already add.
-  if (selectSet.count(Inst))
-    return;
-
-  selectSet.insert(Inst);
-
-  // Scan operand to add node which is phi/select.
-  unsigned numOperands = Inst->getNumOperands();
-  for (unsigned i = startOpIdx; i < numOperands; i++) {
-    Value *V = Inst->getOperand(i);
-    if (Instruction *I = dyn_cast<Instruction>(V)) {
-      CollectSelect(I, selectSet);
-    }
-  }
-}
 
 void CreateOperandSelect(Instruction *SelInst, Instruction *Prototype,
                          std::unordered_map<Instruction *, Instruction *>
@@ -1152,26 +1127,6 @@ void CreateOperandSelect(Instruction *SelInst, Instruction *Prototype,
   }
 }
 
-bool MergeSelectOnSameValue(Instruction *SelInst, unsigned startOpIdx,
-                            unsigned numOperands) {
-  Value *op0 = nullptr;
-  for (unsigned i = startOpIdx; i < numOperands; i++) {
-    Value *op = SelInst->getOperand(i);
-    if (i == startOpIdx) {
-      op0 = op;
-    } else {
-      if (op0 != op)
-        return false;
-    }
-  }
-  if (op0) {
-    SelInst->replaceAllUsesWith(op0);
-    SelInst->eraseFromParent();
-    return true;
-  }
-  return false;
-}
-
 void UpdateOperandSelect(Instruction *SelInst,
                          std::unordered_map<Instruction *, Instruction *>
                              &selInstToSelOperandInstMap,
@@ -1237,7 +1192,7 @@ void UpdateOperandSelect(Instruction *SelInst,
       opI->setOperand(j, selOp->getOperand(i));
     }
     // Remove select if all operand is the same.
-    if (!MergeSelectOnSameValue(opI, startOpIdx, numOperands) &&
+    if (!dxilutil::MergeSelectOnSameValue(opI, startOpIdx, numOperands) &&
         i != nonUniformOpIdx) {
       // Save nonUniform for later check.
       nonUniformOps.insert(opI);
@@ -1259,7 +1214,7 @@ void DxilLowerCreateHandleForLib::AddCreateHandleForPhiNodeAndSelect(
     for (User *HandleU : U->users()) {
       Instruction *I = cast<Instruction>(HandleU);
       if (!isa<CallInst>(I))
-        CollectSelect(I, resSelectSet);
+        dxilutil::CollectSelect(I, resSelectSet);
     }
   }
 
@@ -1319,7 +1274,7 @@ void DxilLowerCreateHandleForLib::AddCreateHandleForPhiNodeAndSelect(
       // Skip Cond for Select.
       if (SelectInst *Sel = dyn_cast<SelectInst>(I))
         startOpIdx = 1;
-      if (MergeSelectOnSameValue(I, startOpIdx, numOperands)) {
+      if (dxilutil::MergeSelectOnSameValue(I, startOpIdx, numOperands)) {
         nonUniformOps.erase(I);
         bUpdated = true;
       }

+ 139 - 14
lib/HLSL/DxilGenerationPass.cpp

@@ -344,9 +344,137 @@ private:
 }
 
 namespace {
+
+void CreateOperandSelect(Instruction *SelInst, Value *EmptyVal,
+                         std::unordered_map<Instruction *, Instruction *>
+                             &selInstToSelOperandInstMap) {
+  IRBuilder<> Builder(SelInst);
+
+  if (SelectInst *Sel = dyn_cast<SelectInst>(SelInst)) {
+    Instruction *newSel = cast<Instruction>(
+        Builder.CreateSelect(Sel->getCondition(), EmptyVal, EmptyVal));
+
+    selInstToSelOperandInstMap[SelInst] = newSel;
+  } else {
+    PHINode *Phi = cast<PHINode>(SelInst);
+    unsigned numIncoming = Phi->getNumIncomingValues();
+
+    // Don't replace constant int operand.
+    PHINode *newSel = Builder.CreatePHI(EmptyVal->getType(), numIncoming);
+    for (unsigned j = 0; j < numIncoming; j++) {
+      BasicBlock *BB = Phi->getIncomingBlock(j);
+      newSel->addIncoming(EmptyVal, BB);
+    }
+
+    selInstToSelOperandInstMap[SelInst] = newSel;
+  }
+}
+
+void UpdateOperandSelect(Instruction *SelInst, Instruction *Prototype,
+                         unsigned operandIdx,
+                         std::unordered_map<Instruction *, Instruction *>
+                             &selInstToSelOperandInstMap) {
+  unsigned numOperands = SelInst->getNumOperands();
+
+  unsigned startOpIdx = 0;
+  // Skip Cond for Select.
+  if (SelectInst *Sel = dyn_cast<SelectInst>(SelInst)) {
+    startOpIdx = 1;
+  }
+
+  Instruction *newSel = selInstToSelOperandInstMap[SelInst];
+  // Transform
+  // phi0 = phi a0, b0, c0
+  // phi1 = phi a1, b1, c1
+  // NewInst = Add(phi0, phi1);
+  //   into
+  // A = Add(a0, a1);
+  // B = Add(b0, b1);
+  // C = Add(c0, c1);
+  // NewSelInst = phi A, B, C
+  // Only support 1 operand now, other oerands should be Constant.
+
+  // Each operand of newInst is a clone of prototype inst.
+  // Now we set A operands based on operand 0 of phi0 and phi1.
+  for (unsigned i = startOpIdx; i < numOperands; i++) {
+    Instruction *selOp = cast<Instruction>(SelInst->getOperand(i));
+    auto it = selInstToSelOperandInstMap.find(selOp);
+    if (it != selInstToSelOperandInstMap.end()) {
+      // Operand is an select.
+      // Map to new created select inst.
+      Instruction *newSelOp = it->second;
+      newSel->setOperand(i, newSelOp);
+    } else {
+      // The operand is not select.
+      // just use it for prototype operand.
+      // Make sure function is the same.
+      Instruction *op = Prototype->clone();
+      op->setOperand(operandIdx, selOp);
+      if (PHINode *phi = dyn_cast<PHINode>(SelInst)) {
+        BasicBlock *BB = phi->getIncomingBlock(i);
+        IRBuilder<> TmpBuilder(BB->getTerminator());
+        TmpBuilder.Insert(op);
+      } else {
+        IRBuilder<> TmpBuilder(newSel);
+        TmpBuilder.Insert(op);
+      }
+      newSel->setOperand(i, op);
+    }
+  }
+}
+
 void TranslateHLCreateHandle(Function *F, hlsl::OP &hlslOP) {
   Value *opArg = hlslOP.GetU32Const(
       (unsigned)DXIL::OpCode::CreateHandleFromResourceStructForLib);
+
+  // Remove PhiNode createHandle first.
+  std::vector<Instruction *> resSelects;
+  std::unordered_set<llvm::Instruction *> selectSet;
+  for (auto U = F->user_begin(); U != F->user_end();) {
+    Value *user = *(U++);
+    if (!isa<Instruction>(user))
+      continue;
+    // must be call inst
+    CallInst *CI = cast<CallInst>(user);
+    Value *res = CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx);
+    if (isa<SelectInst>(res) || isa<PHINode>(res))
+      dxilutil::CollectSelect(cast<Instruction>(res), selectSet);
+  }
+
+  if (!selectSet.empty()) {
+    FunctionType *FT = F->getFunctionType();
+    Type *ResTy = FT->getParamType(HLOperandIndex::kUnaryOpSrc0Idx);
+
+    Value *UndefHandle = UndefValue::get(F->getReturnType());
+    std::unordered_map<Instruction *, Instruction *> handleMap;
+    for (Instruction *SelInst : selectSet) {
+      CreateOperandSelect(SelInst, UndefHandle, handleMap);
+    }
+
+    Value *UndefRes = UndefValue::get(ResTy);
+    std::unique_ptr<CallInst> PrototypeCall(
+        CallInst::Create(F, {opArg, UndefRes}));
+
+    for (Instruction *SelInst : selectSet) {
+      UpdateOperandSelect(SelInst, PrototypeCall.get(),
+                          HLOperandIndex::kUnaryOpSrc0Idx, handleMap);
+    }
+
+    // Replace createHandle on select with select on createHandle.
+    for (Instruction *SelInst : selectSet) {
+      Value *NewSel = handleMap[SelInst];
+      for (auto U = SelInst->user_begin(); U != SelInst->user_end();) {
+        Value *user = *(U++);
+        if (CallInst *CI = dyn_cast<CallInst>(user)) {
+          if (CI->getCalledFunction() == F) {
+            CI->replaceAllUsesWith(NewSel);
+            CI->eraseFromParent();
+          }
+        }
+      }
+    }
+  }
+
   for (auto U = F->user_begin(); U != F->user_end();) {
     Value *user = *(U++);
     if (!isa<Instruction>(user))
@@ -356,21 +484,18 @@ void TranslateHLCreateHandle(Function *F, hlsl::OP &hlslOP) {
     Value *res = CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx);
     Value *newHandle = nullptr;
     IRBuilder<> Builder(CI);
-    if (LoadInst *LI = dyn_cast<LoadInst>(res)) {
-      Function *createHandle =
-          hlslOP.GetOpFunc(DXIL::OpCode::CreateHandleFromResourceStructForLib,
-                           LI->getType());
-      newHandle = Builder.CreateCall(createHandle, {opArg, LI});
-    } else {
-      Function *createHandle =
-          hlslOP.GetOpFunc(DXIL::OpCode::CreateHandleFromResourceStructForLib,
-                           res->getType());
-      CallInst *newHandleCI = Builder.CreateCall(createHandle, {opArg, res});
-      // Change select/phi on operands into select/phi on operation.
-      newHandle =
-          dxilutil::SelectOnOperation(newHandleCI, HLOperandIndex::kUnaryOpSrc0Idx);
-    }
+    // Must be load.
+    LoadInst *LI = cast<LoadInst>(res);
+    Function *createHandle = hlslOP.GetOpFunc(
+        DXIL::OpCode::CreateHandleFromResourceStructForLib, LI->getType());
+    newHandle = Builder.CreateCall(createHandle, {opArg, LI});
+
     CI->replaceAllUsesWith(newHandle);
+    if (res->user_empty()) {
+      if (Instruction *I = dyn_cast<Instruction>(res))
+        I->eraseFromParent();
+    }
+
     CI->eraseFromParent();
   }
 }

+ 46 - 0
lib/HLSL/DxilUtil.cpp

@@ -176,6 +176,52 @@ void EmitResMappingError(Instruction *Res) {
   }
 }
 
+void CollectSelect(llvm::Instruction *Inst,
+                   std::unordered_set<llvm::Instruction *> &selectSet) {
+  unsigned startOpIdx = 0;
+  // Skip Cond for Select.
+  if (isa<SelectInst>(Inst)) {
+    startOpIdx = 1;
+  } else if (!isa<PHINode>(Inst)) {
+    // Only check phi and select here.
+    return;
+  }
+  // Already add.
+  if (selectSet.count(Inst))
+    return;
+
+  selectSet.insert(Inst);
+
+  // Scan operand to add node which is phi/select.
+  unsigned numOperands = Inst->getNumOperands();
+  for (unsigned i = startOpIdx; i < numOperands; i++) {
+    Value *V = Inst->getOperand(i);
+    if (Instruction *I = dyn_cast<Instruction>(V)) {
+      CollectSelect(I, selectSet);
+    }
+  }
+}
+
+bool MergeSelectOnSameValue(Instruction *SelInst, unsigned startOpIdx,
+                            unsigned numOperands) {
+  Value *op0 = nullptr;
+  for (unsigned i = startOpIdx; i < numOperands; i++) {
+    Value *op = SelInst->getOperand(i);
+    if (i == startOpIdx) {
+      op0 = op;
+    } else {
+      if (op0 != op)
+        return false;
+    }
+  }
+  if (op0) {
+    SelInst->replaceAllUsesWith(op0);
+    SelInst->eraseFromParent();
+    return true;
+  }
+  return false;
+}
+
 Value *SelectOnOperation(llvm::Instruction *Inst, unsigned operandIdx) {
   Instruction *prototype = Inst;
   for (unsigned i = 0; i < prototype->getNumOperands(); i++) {

+ 27 - 0
tools/clang/test/CodeGenHLSL/quick-test/res_select2.hlsl

@@ -0,0 +1,27 @@
+// RUN: %dxc -T lib_6_1 %s | FileCheck %s
+
+// Make sure no phi of resource.
+// CHECK-NOT: phi %class.RWBuffer
+// CHECK: phi %dx.types.Handle
+
+RWBuffer<float4> a;
+RWBuffer<float4> b;
+RWBuffer<float4> c;
+
+float4 test(int i, int j, int m) {
+  RWBuffer<float4> buf = c;
+  while (i > 9) {
+     while (j < 4) {
+        if (i < m)
+          buf = b;
+        buf[j] = i;
+        j++;
+     }
+     if (m > j)
+       buf = a;
+     buf[m] = i;
+     i--;
+  }
+  buf[i] = j;
+  return j;
+}