DxilLegalizeEvalOperations.cpp 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // DxilLegalizeEvalOperations.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. ///////////////////////////////////////////////////////////////////////////////
  9. #include "dxc/HlslIntrinsicOp.h"
  10. #include "dxc/DXIL/DxilModule.h"
  11. #include "dxc/HLSL/DxilGenerationPass.h"
  12. #include "dxc/HLSL/HLOperations.h"
  13. #include "llvm/Pass.h"
  14. #include "llvm/IR/Function.h"
  15. #include "llvm/IR/Instruction.h"
  16. #include "llvm/IR/Instructions.h"
  17. #include "llvm/IR/Module.h"
  18. #include "llvm/Transforms/Utils/SSAUpdater.h"
  19. #include <unordered_set>
  20. #include <vector>
  21. using namespace llvm;
  22. using namespace hlsl;
// Make sure the sources of EvalOperations come from function parameters.
// This is needed in order to translate EvaluateAttribute operations that trace
// back to LoadInput operations during the translation stage. Promoting
// load/store instructions beforehand allows us to easily trace back to
// LoadInput from the function call.
  28. namespace {
  29. class DxilLegalizeEvalOperations : public ModulePass {
  30. public:
  31. static char ID; // Pass identification, replacement for typeid
  32. explicit DxilLegalizeEvalOperations() : ModulePass(ID) {}
  33. const char *getPassName() const override {
  34. return "DXIL Legalize EvalOperations";
  35. }
  36. bool runOnModule(Module &M) override {
  37. for (Function &F : M.getFunctionList()) {
  38. hlsl::HLOpcodeGroup group = hlsl::GetHLOpcodeGroup(&F);
  39. if (group != HLOpcodeGroup::NotHL) {
  40. std::vector<CallInst *> EvalFunctionCalls;
  41. // Find all EvaluateAttribute calls
  42. for (User *U : F.users()) {
  43. if (CallInst *CI = dyn_cast<CallInst>(U)) {
  44. IntrinsicOp evalOp =
  45. static_cast<IntrinsicOp>(hlsl::GetHLOpcode(CI));
  46. if (evalOp == IntrinsicOp::IOP_EvaluateAttributeAtSample ||
  47. evalOp == IntrinsicOp::IOP_EvaluateAttributeCentroid ||
  48. evalOp == IntrinsicOp::IOP_EvaluateAttributeSnapped ||
  49. evalOp == IntrinsicOp::IOP_GetAttributeAtVertex) {
  50. EvalFunctionCalls.push_back(CI);
  51. }
  52. }
  53. }
  54. if (EvalFunctionCalls.empty()) {
  55. continue;
  56. }
  57. // Start from the call instruction, find all allocas that this call
  58. // uses.
  59. std::unordered_set<AllocaInst *> allocas;
  60. for (CallInst *CI : EvalFunctionCalls) {
  61. FindAllocasForEvalOperations(CI, allocas);
  62. }
  63. SSAUpdater SSA;
  64. SmallVector<Instruction *, 4> Insts;
  65. for (AllocaInst *AI : allocas) {
  66. for (User *user : AI->users()) {
  67. if (isa<LoadInst>(user) || isa<StoreInst>(user)) {
  68. Insts.emplace_back(cast<Instruction>(user));
  69. }
  70. }
  71. LoadAndStorePromoter(Insts, SSA).run(Insts);
  72. Insts.clear();
  73. }
  74. }
  75. }
  76. return true;
  77. }
  78. private:
  79. void FindAllocasForEvalOperations(Value *val,
  80. std::unordered_set<AllocaInst *> &allocas);
  81. };
  82. char DxilLegalizeEvalOperations::ID = 0;
  83. // Find allocas for EvaluateAttribute operations
  84. void DxilLegalizeEvalOperations::FindAllocasForEvalOperations(
  85. Value *val, std::unordered_set<AllocaInst *> &allocas) {
  86. Value *CurVal = val;
  87. while (!isa<AllocaInst>(CurVal)) {
  88. if (CallInst *CI = dyn_cast<CallInst>(CurVal)) {
  89. CurVal = CI->getOperand(HLOperandIndex::kUnaryOpSrc0Idx);
  90. } else if (InsertElementInst *IE = dyn_cast<InsertElementInst>(CurVal)) {
  91. Value *arg0 =
  92. IE->getOperand(0); // Could be another insertelement or undef
  93. Value *arg1 = IE->getOperand(1);
  94. FindAllocasForEvalOperations(arg0, allocas);
  95. CurVal = arg1;
  96. } else if (ShuffleVectorInst *SV = dyn_cast<ShuffleVectorInst>(CurVal)) {
  97. Value *arg0 = SV->getOperand(0);
  98. Value *arg1 = SV->getOperand(1);
  99. FindAllocasForEvalOperations(
  100. arg0, allocas); // Shuffle vector could come from different allocas
  101. CurVal = arg1;
  102. } else if (ExtractElementInst *EE = dyn_cast<ExtractElementInst>(CurVal)) {
  103. CurVal = EE->getOperand(0);
  104. } else if (LoadInst *LI = dyn_cast<LoadInst>(CurVal)) {
  105. CurVal = LI->getOperand(0);
  106. } else {
  107. break;
  108. }
  109. }
  110. if (AllocaInst *AI = dyn_cast<AllocaInst>(CurVal)) {
  111. allocas.insert(AI);
  112. }
  113. }
  114. } // namespace
  115. ModulePass *llvm::createDxilLegalizeEvalOperationsPass() {
  116. return new DxilLegalizeEvalOperations();
  117. }
  118. INITIALIZE_PASS(DxilLegalizeEvalOperations,
  119. "hlsl-dxil-legalize-eval-operations",
  120. "DXIL legalize eval operations", false, false)