PreciseVisitor.cpp 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. //===--- PreciseVisitor.cpp ------- Precise Visitor --------------*- C++ -*-==//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. #include "PreciseVisitor.h"
  10. #include "clang/SPIRV/AstTypeProbe.h"
  11. #include "clang/SPIRV/SpirvFunction.h"
  12. #include "clang/SPIRV/SpirvType.h"
  13. #include <stack>
  14. namespace {
  15. /// \brief Returns true if the given OpAccessChain instruction is accessing a
  16. /// precise variable, or accessing a precise member of a structure. Returns
  17. /// false otherwise.
  18. bool isAccessingPrecise(clang::spirv::SpirvAccessChain *inst) {
  19. using namespace clang::spirv;
  20. // If the access chain base is another access chain and so on, first flatten
  21. // them (from the bottom to the top). For example:
  22. // %x = OpAccessChain <type> %obj %int_1 %int_2
  23. // %y = OpAccessChain <type> %x %int_3 %int_4
  24. // %z = OpAccessChain <type> %y %int_5 %int_6
  25. // Should be flattened to:
  26. // %z = OpAccessChain <type> %obj %int_1 %int_2 %int_3 %int_4 %int_5 %int_6
  27. std::stack<SpirvInstruction *> indexes;
  28. SpirvInstruction *base = inst;
  29. while (auto *accessChain = llvm::dyn_cast<SpirvAccessChain>(base)) {
  30. for (auto iter = accessChain->getIndexes().rbegin();
  31. iter != accessChain->getIndexes().rend(); ++iter) {
  32. indexes.push(*iter);
  33. }
  34. base = accessChain->getBase();
  35. // If we reach a 'precise' base at any level, return true.
  36. if (base->isPrecise())
  37. return true;
  38. }
  39. // Start from the lowest level base (%obj in the above example), and step
  40. // forward using the 'indexes'. If a 'precise' structure field is discovered
  41. // at any point, return true.
  42. const SpirvType *baseType = base->getResultType();
  43. while (baseType && !indexes.empty()) {
  44. if (auto *vecType = llvm::dyn_cast<VectorType>(baseType)) {
  45. indexes.pop();
  46. baseType = vecType->getElementType();
  47. } else if (auto *matType = llvm::dyn_cast<MatrixType>(baseType)) {
  48. indexes.pop();
  49. baseType = matType->getVecType();
  50. } else if (auto *arrType = llvm::dyn_cast<ArrayType>(baseType)) {
  51. indexes.pop();
  52. baseType = arrType->getElementType();
  53. } else if (auto *raType = llvm::dyn_cast<RuntimeArrayType>(baseType)) {
  54. indexes.pop();
  55. baseType = raType->getElementType();
  56. } else if (auto *structType = llvm::dyn_cast<StructType>(baseType)) {
  57. SpirvInstruction *index = indexes.top();
  58. if (auto *constInt = llvm::dyn_cast<SpirvConstantInteger>(index)) {
  59. uint32_t indexValue =
  60. static_cast<uint32_t>(constInt->getValue().getZExtValue());
  61. auto fields = structType->getFields();
  62. assert(indexValue < fields.size());
  63. auto &fieldInfo = fields[indexValue];
  64. if (fieldInfo.isPrecise) {
  65. return true;
  66. } else {
  67. baseType = fieldInfo.type;
  68. indexes.pop();
  69. }
  70. } else {
  71. // Trying to index into a structure using a variable? This shouldn't be
  72. // happening.
  73. assert(false && "indexing into a struct with variable value");
  74. return false;
  75. }
  76. } else if (auto *ptrType = llvm::dyn_cast<SpirvPointerType>(baseType)) {
  77. // Note: no need to pop the stack here.
  78. baseType = ptrType->getPointeeType();
  79. } else {
  80. return false;
  81. }
  82. }
  83. return false;
  84. }
  85. } // anonymous namespace
  86. namespace clang {
  87. namespace spirv {
  88. bool PreciseVisitor::visit(SpirvFunction *fn, Phase phase) {
  89. // Before going through the function instructions
  90. if (phase == Visitor::Phase::Init) {
  91. curFnRetValPrecise = fn->isPrecise();
  92. }
  93. return true;
  94. }
  95. bool PreciseVisitor::visit(SpirvReturn *inst) {
  96. if (inst->hasReturnValue()) {
  97. inst->getReturnValue()->setPrecise(curFnRetValPrecise);
  98. }
  99. return true;
  100. }
  101. bool PreciseVisitor::visit(SpirvVariable *var) {
  102. if (var->hasInitializer())
  103. var->getInitializer()->setPrecise(var->isPrecise());
  104. return true;
  105. }
  106. bool PreciseVisitor::visit(SpirvSelect *inst) {
  107. inst->getTrueObject()->setPrecise(inst->isPrecise());
  108. inst->getFalseObject()->setPrecise(inst->isPrecise());
  109. return true;
  110. }
  111. bool PreciseVisitor::visit(SpirvVectorShuffle *inst) {
  112. // If the result of a vector shuffle is 'precise', the vectors from which the
  113. // elements are chosen should also be 'precise'.
  114. if (inst->isPrecise()) {
  115. auto *vec1 = inst->getVec1();
  116. auto *vec2 = inst->getVec2();
  117. const auto vec1Type = vec1->getAstResultType();
  118. const auto vec2Type = vec2->getAstResultType();
  119. uint32_t vec1Size;
  120. uint32_t vec2Size;
  121. (void)isVectorType(vec1Type, nullptr, &vec1Size);
  122. (void)isVectorType(vec2Type, nullptr, &vec2Size);
  123. bool vec1ElemUsed = false;
  124. bool vec2ElemUsed = false;
  125. for (auto component : inst->getComponents()) {
  126. if (component < vec1Size)
  127. vec1ElemUsed = true;
  128. else
  129. vec2ElemUsed = true;
  130. }
  131. if (vec1ElemUsed)
  132. vec1->setPrecise();
  133. if (vec2ElemUsed)
  134. vec2->setPrecise();
  135. }
  136. return true;
  137. }
  138. bool PreciseVisitor::visit(SpirvBitFieldExtract *inst) {
  139. inst->getBase()->setPrecise(inst->isPrecise());
  140. return true;
  141. }
  142. bool PreciseVisitor::visit(SpirvBitFieldInsert *inst) {
  143. inst->getBase()->setPrecise(inst->isPrecise());
  144. inst->getInsert()->setPrecise(inst->isPrecise());
  145. return true;
  146. }
  147. bool PreciseVisitor::visit(SpirvAtomic *inst) {
  148. if (inst->isPrecise() && inst->hasValue())
  149. inst->getValue()->setPrecise();
  150. return true;
  151. }
  152. bool PreciseVisitor::visit(SpirvCompositeConstruct *inst) {
  153. if (inst->isPrecise())
  154. for (auto *consituent : inst->getConstituents())
  155. consituent->setPrecise();
  156. return true;
  157. }
  158. bool PreciseVisitor::visit(SpirvCompositeExtract *inst) {
  159. inst->getComposite()->setPrecise(inst->isPrecise());
  160. return true;
  161. }
  162. bool PreciseVisitor::visit(SpirvCompositeInsert *inst) {
  163. inst->getComposite()->setPrecise(inst->isPrecise());
  164. inst->getObject()->setPrecise(inst->isPrecise());
  165. return true;
  166. }
  167. bool PreciseVisitor::visit(SpirvLoad *inst) {
  168. // If the instruction result is precise, the pointer we're loading from should
  169. // also be marked as precise.
  170. if (inst->isPrecise())
  171. inst->getPointer()->setPrecise();
  172. return true;
  173. }
  174. bool PreciseVisitor::visit(SpirvStore *inst) {
  175. // If the 'pointer' to which we are storing is marked as 'precise', the object
  176. // we are storing should also be marked as 'precise'.
  177. // Note that the 'pointer' may either be an 'OpVariable' or it might be the
  178. // result of one or more access chains (in which case we should figure out if
  179. // the 'base' of the access chain is 'precise').
  180. auto *ptr = inst->getPointer();
  181. auto *obj = inst->getObject();
  182. // The simple case (target is a precise variable).
  183. if (ptr->isPrecise()) {
  184. obj->setPrecise();
  185. return true;
  186. }
  187. if (auto *accessChain = llvm::dyn_cast<SpirvAccessChain>(ptr)) {
  188. if (isAccessingPrecise(accessChain)) {
  189. obj->setPrecise();
  190. return true;
  191. }
  192. }
  193. return true;
  194. }
  195. bool PreciseVisitor::visit(SpirvBinaryOp *inst) {
  196. bool isPrecise = inst->isPrecise();
  197. inst->getOperand1()->setPrecise(isPrecise);
  198. inst->getOperand2()->setPrecise(isPrecise);
  199. return true;
  200. }
  201. bool PreciseVisitor::visit(SpirvUnaryOp *inst) {
  202. inst->getOperand()->setPrecise(inst->isPrecise());
  203. return true;
  204. }
  205. bool PreciseVisitor::visit(SpirvNonUniformBinaryOp *inst) {
  206. inst->getArg1()->setPrecise(inst->isPrecise());
  207. inst->getArg2()->setPrecise(inst->isPrecise());
  208. return true;
  209. }
  210. bool PreciseVisitor::visit(SpirvNonUniformUnaryOp *inst) {
  211. inst->getArg()->setPrecise(inst->isPrecise());
  212. return true;
  213. }
  214. bool PreciseVisitor::visit(SpirvExtInst *inst) {
  215. if (inst->isPrecise())
  216. for (auto *operand : inst->getOperands())
  217. operand->setPrecise();
  218. return true;
  219. }
  220. } // end namespace spirv
  221. } // end namespace clang