CGLoopInfo.cpp 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. //===---- CGLoopInfo.cpp - LLVM CodeGen for loop metadata -*- C++ -*-------===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. #include "CGLoopInfo.h"
  10. #include "clang/AST/Attr.h"
  11. #include "clang/Sema/LoopHint.h"
  12. #include "llvm/IR/BasicBlock.h"
  13. #include "llvm/IR/Constants.h"
  14. #include "llvm/IR/InstrTypes.h"
  15. #include "llvm/IR/Instructions.h"
  16. #include "llvm/IR/Metadata.h"
  17. using namespace clang::CodeGen;
  18. using namespace llvm;
  19. static MDNode *createMetadata(LLVMContext &Ctx, const LoopAttributes &Attrs) {
  20. if (!Attrs.IsParallel && Attrs.VectorizerWidth == 0 &&
  21. Attrs.VectorizerUnroll == 0 &&
  22. Attrs.HlslLoop == false && // HLSL Change
  23. Attrs.HlslUnrollCount == 0 && // HLSL Change
  24. Attrs.VectorizerEnable == LoopAttributes::VecUnspecified)
  25. return nullptr;
  26. SmallVector<Metadata *, 4> Args;
  27. // Reserve operand 0 for loop id self reference.
  28. auto TempNode = MDNode::getTemporary(Ctx, None);
  29. Args.push_back(TempNode.get());
  30. // Setting vectorizer.width
  31. if (Attrs.VectorizerWidth > 0) {
  32. Metadata *Vals[] = {MDString::get(Ctx, "llvm.loop.vectorize.width"),
  33. ConstantAsMetadata::get(ConstantInt::get(
  34. Type::getInt32Ty(Ctx), Attrs.VectorizerWidth))};
  35. Args.push_back(MDNode::get(Ctx, Vals));
  36. }
  37. // Setting vectorizer.unroll
  38. if (Attrs.VectorizerUnroll > 0) {
  39. Metadata *Vals[] = {MDString::get(Ctx, "llvm.loop.interleave.count"),
  40. ConstantAsMetadata::get(ConstantInt::get(
  41. Type::getInt32Ty(Ctx), Attrs.VectorizerUnroll))};
  42. Args.push_back(MDNode::get(Ctx, Vals));
  43. }
  44. // Setting vectorizer.enable
  45. if (Attrs.VectorizerEnable != LoopAttributes::VecUnspecified) {
  46. Metadata *Vals[] = {
  47. MDString::get(Ctx, "llvm.loop.vectorize.enable"),
  48. ConstantAsMetadata::get(ConstantInt::get(
  49. Type::getInt1Ty(Ctx),
  50. (Attrs.VectorizerEnable == LoopAttributes::VecEnable)))};
  51. Args.push_back(MDNode::get(Ctx, Vals));
  52. }
  53. // HLSL Change Begins.
  54. if (Attrs.HlslLoop) {
  55. // Disable unroll.
  56. SmallVector<Metadata *, 1> DisableOperands;
  57. DisableOperands.push_back(MDString::get(Ctx, "llvm.loop.unroll.disable"));
  58. MDNode *DisableNode = MDNode::get(Ctx, DisableOperands);
  59. Args.push_back(DisableNode);
  60. }
  61. else if (Attrs.HlslUnrollCount) {
  62. if (Attrs.HlslUnrollCount == 1) {
  63. // Full unroll.
  64. SmallVector<Metadata *, 1> FullOperands;
  65. FullOperands.push_back(MDString::get(Ctx, "llvm.loop.unroll.full"));
  66. MDNode *FullNode = MDNode::get(Ctx, FullOperands);
  67. Args.push_back(FullNode);
  68. } else {
  69. Metadata *Vals[] = {MDString::get(Ctx, "llvm.loop.unroll.count"),
  70. ConstantAsMetadata::get(ConstantInt::get(
  71. Type::getInt32Ty(Ctx), Attrs.HlslUnrollCount))};
  72. Args.push_back(MDNode::get(Ctx, Vals));
  73. }
  74. }
  75. // HLSL Change Ends.
  76. // Set the first operand to itself.
  77. MDNode *LoopID = MDNode::get(Ctx, Args);
  78. LoopID->replaceOperandWith(0, LoopID);
  79. return LoopID;
  80. }
  81. LoopAttributes::LoopAttributes(bool IsParallel)
  82. : IsParallel(IsParallel), VectorizerEnable(LoopAttributes::VecUnspecified),
  83. VectorizerWidth(0), VectorizerUnroll(0),
  84. HlslLoop(false), HlslUnrollCount(0) {} // HLSL Change
  85. void LoopAttributes::clear() {
  86. IsParallel = false;
  87. VectorizerWidth = 0;
  88. VectorizerUnroll = 0;
  89. VectorizerEnable = LoopAttributes::VecUnspecified;
  90. HlslLoop = false; // HLSL Change
  91. HlslUnrollCount = 0; // HLSL Change
  92. }
  93. LoopInfo::LoopInfo(BasicBlock *Header, const LoopAttributes &Attrs)
  94. : LoopID(nullptr), Header(Header), Attrs(Attrs) {
  95. LoopID = createMetadata(Header->getContext(), Attrs);
  96. }
  97. void LoopInfoStack::push(BasicBlock *Header,
  98. ArrayRef<const clang::Attr *> Attrs) {
  99. for (const auto *Attr : Attrs) {
  100. const LoopHintAttr *LH = dyn_cast<LoopHintAttr>(Attr);
  101. // HLSL Change Begins
  102. if (dyn_cast<HLSLLoopAttr>(Attr)) {
  103. setHlslLoop(true);
  104. } else if (const HLSLUnrollAttr *UnrollAttr =
  105. dyn_cast<HLSLUnrollAttr>(Attr)) {
  106. unsigned count = UnrollAttr->getCount();
  107. setHlslUnrollCount(count);
  108. }
  109. // HLSL Change Ends
  110. // Skip non loop hint attributes
  111. if (!LH)
  112. continue;
  113. LoopHintAttr::OptionType Option = LH->getOption();
  114. LoopHintAttr::LoopHintState State = LH->getState();
  115. switch (Option) {
  116. case LoopHintAttr::Vectorize:
  117. case LoopHintAttr::Interleave:
  118. if (State == LoopHintAttr::AssumeSafety) {
  119. // Apply "llvm.mem.parallel_loop_access" metadata to load/stores.
  120. setParallel(true);
  121. }
  122. break;
  123. case LoopHintAttr::VectorizeWidth:
  124. case LoopHintAttr::InterleaveCount:
  125. case LoopHintAttr::Unroll:
  126. case LoopHintAttr::UnrollCount:
  127. // Nothing to do here for these loop hints.
  128. break;
  129. }
  130. }
  131. Active.push_back(LoopInfo(Header, StagedAttrs));
  132. // Clear the attributes so nested loops do not inherit them.
  133. StagedAttrs.clear();
  134. }
  135. void LoopInfoStack::pop() {
  136. assert(!Active.empty() && "No active loops to pop");
  137. Active.pop_back();
  138. }
  139. void LoopInfoStack::InsertHelper(Instruction *I) const {
  140. if (!hasInfo())
  141. return;
  142. const LoopInfo &L = getInfo();
  143. if (!L.getLoopID())
  144. return;
  145. if (TerminatorInst *TI = dyn_cast<TerminatorInst>(I)) {
  146. for (unsigned i = 0, ie = TI->getNumSuccessors(); i < ie; ++i)
  147. if (TI->getSuccessor(i) == L.getHeader()) {
  148. TI->setMetadata("llvm.loop", L.getLoopID());
  149. break;
  150. }
  151. return;
  152. }
  153. if (L.getAttributes().IsParallel && I->mayReadOrWriteMemory())
  154. I->setMetadata("llvm.mem.parallel_loop_access", L.getLoopID());
  155. }