SpirvModule.cpp 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. //===--- SpirvModule.cpp - SPIR-V Module Implementation ----------*- C++ -*-==//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. #include "clang/SPIRV/SpirvModule.h"
  10. #include "clang/SPIRV/SpirvFunction.h"
  11. #include "clang/SPIRV/SpirvVisitor.h"
  12. namespace clang {
  13. namespace spirv {
  14. SpirvModule::SpirvModule()
  15. : capabilities({}), extensions({}), extInstSets({}), memoryModel(nullptr),
  16. entryPoints({}), executionModes({}), moduleProcesses({}), decorations({}),
  17. constants({}), variables({}), functions({}), debugInstructions({}) {}
  18. SpirvModule::~SpirvModule() {
  19. for (auto *cap : capabilities)
  20. cap->releaseMemory();
  21. for (auto *ext : extensions)
  22. ext->releaseMemory();
  23. for (auto *set : extInstSets)
  24. set->releaseMemory();
  25. if (memoryModel)
  26. memoryModel->releaseMemory();
  27. for (auto *entry : entryPoints)
  28. entry->releaseMemory();
  29. for (auto *exec : executionModes)
  30. exec->releaseMemory();
  31. for (auto *str : constStrings)
  32. str->releaseMemory();
  33. for (auto *d : sources)
  34. d->releaseMemory();
  35. for (auto *mp : moduleProcesses)
  36. mp->releaseMemory();
  37. for (auto *decoration : decorations)
  38. decoration->releaseMemory();
  39. for (auto *constant : constants)
  40. constant->releaseMemory();
  41. for (auto *var : variables)
  42. var->releaseMemory();
  43. for (auto *di : debugInstructions)
  44. di->releaseMemory();
  45. for (auto *f : allFunctions)
  46. f->~SpirvFunction();
  47. }
  48. bool SpirvModule::invokeVisitor(Visitor *visitor, bool reverseOrder) {
  49. // Note: It is debatable whether reverse order of visiting the module should
  50. // reverse everything in this method. For the time being, we just reverse the
  51. // order of the function visitors, and keeping everything else the same.
  52. // For example, it is not clear what the value would be of vising the last
  53. // function first. We can update this methodology if needed.
  54. if (!visitor->visit(this, Visitor::Phase::Init))
  55. return false;
  56. if (reverseOrder) {
  57. // Reverse order of a SPIR-V module.
  58. // Our transformations do not cross function bounaries, therefore the order
  59. // of visiting functions is not important.
  60. for (auto iter = functions.rbegin(); iter != functions.rend(); ++iter) {
  61. auto *fn = *iter;
  62. if (!fn->invokeVisitor(visitor, reverseOrder))
  63. return false;
  64. }
  65. for (auto iter = debugInstructions.rbegin();
  66. iter != debugInstructions.rend(); ++iter) {
  67. auto *debugInstruction = *iter;
  68. if (!debugInstruction->invokeVisitor(visitor))
  69. return false;
  70. }
  71. for (auto iter = variables.rbegin(); iter != variables.rend(); ++iter) {
  72. auto *var = *iter;
  73. if (!var->invokeVisitor(visitor))
  74. return false;
  75. }
  76. for (auto iter = constants.rbegin(); iter != constants.rend(); ++iter) {
  77. auto *constant = *iter;
  78. if (!constant->invokeVisitor(visitor))
  79. return false;
  80. }
  81. // Since SetVector doesn't have 'rbegin()' and 'rend()' methods, we use
  82. // manual indexing.
  83. for (auto decorIndex = decorations.size(); decorIndex > 0; --decorIndex) {
  84. auto *decoration = decorations[decorIndex - 1];
  85. if (!decoration->invokeVisitor(visitor))
  86. return false;
  87. }
  88. for (auto iter = moduleProcesses.rbegin(); iter != moduleProcesses.rend();
  89. ++iter) {
  90. auto *moduleProcess = *iter;
  91. if (!moduleProcess->invokeVisitor(visitor))
  92. return false;
  93. }
  94. if (!sources.empty())
  95. for (auto iter = sources.rbegin(); iter != sources.rend(); ++iter) {
  96. auto *source = *iter;
  97. if (!source->invokeVisitor(visitor))
  98. return false;
  99. }
  100. for (auto iter = constStrings.rbegin(); iter != constStrings.rend();
  101. ++iter) {
  102. if (!(*iter)->invokeVisitor(visitor))
  103. return false;
  104. }
  105. for (auto iter = executionModes.rbegin(); iter != executionModes.rend();
  106. ++iter) {
  107. auto *execMode = *iter;
  108. if (!execMode->invokeVisitor(visitor))
  109. return false;
  110. }
  111. for (auto iter = entryPoints.rbegin(); iter != entryPoints.rend(); ++iter) {
  112. auto *entryPoint = *iter;
  113. if (!entryPoint->invokeVisitor(visitor))
  114. return false;
  115. }
  116. if (!memoryModel->invokeVisitor(visitor))
  117. return false;
  118. for (auto iter = extInstSets.rbegin(); iter != extInstSets.rend(); ++iter) {
  119. auto *extInstSet = *iter;
  120. if (!extInstSet->invokeVisitor(visitor))
  121. return false;
  122. }
  123. // Since SetVector doesn't have 'rbegin()' and 'rend()' methods, we use
  124. // manual indexing.
  125. for (auto extIndex = extensions.size(); extIndex > 0; --extIndex) {
  126. auto *extension = extensions[extIndex - 1];
  127. if (!extension->invokeVisitor(visitor))
  128. return false;
  129. }
  130. // Since SetVector doesn't have 'rbegin()' and 'rend()' methods, we use
  131. // manual indexing.
  132. for (auto capIndex = capabilities.size(); capIndex > 0; --capIndex) {
  133. auto *capability = capabilities[capIndex - 1];
  134. if (!capability->invokeVisitor(visitor))
  135. return false;
  136. }
  137. }
  138. // Traverse the regular order of a SPIR-V module.
  139. else {
  140. for (auto *cap : capabilities)
  141. if (!cap->invokeVisitor(visitor))
  142. return false;
  143. for (auto ext : extensions)
  144. if (!ext->invokeVisitor(visitor))
  145. return false;
  146. for (auto extInstSet : extInstSets)
  147. if (!extInstSet->invokeVisitor(visitor))
  148. return false;
  149. if (!memoryModel->invokeVisitor(visitor))
  150. return false;
  151. for (auto entryPoint : entryPoints)
  152. if (!entryPoint->invokeVisitor(visitor))
  153. return false;
  154. for (auto execMode : executionModes)
  155. if (!execMode->invokeVisitor(visitor))
  156. return false;
  157. for (auto *str : constStrings)
  158. if (!str->invokeVisitor(visitor))
  159. return false;
  160. if (!sources.empty())
  161. for (auto *source : sources)
  162. if (!source->invokeVisitor(visitor))
  163. return false;
  164. for (auto moduleProcess : moduleProcesses)
  165. if (!moduleProcess->invokeVisitor(visitor))
  166. return false;
  167. for (auto decoration : decorations)
  168. if (!decoration->invokeVisitor(visitor))
  169. return false;
  170. for (auto constant : constants)
  171. if (!constant->invokeVisitor(visitor))
  172. return false;
  173. for (auto var : variables)
  174. if (!var->invokeVisitor(visitor))
  175. return false;
  176. for (auto *debugInstruction : debugInstructions)
  177. if (!debugInstruction->invokeVisitor(visitor))
  178. return false;
  179. for (auto fn : functions)
  180. if (!fn->invokeVisitor(visitor, reverseOrder))
  181. return false;
  182. }
  183. if (!visitor->visit(this, Visitor::Phase::Done))
  184. return false;
  185. return true;
  186. }
  187. void SpirvModule::addFunctionToListOfSortedModuleFunctions(SpirvFunction *fn) {
  188. assert(fn && "cannot add null function to the module");
  189. functions.push_back(fn);
  190. }
  191. void SpirvModule::addFunction(SpirvFunction *fn) {
  192. assert(fn && "cannot add null function to the module");
  193. allFunctions.insert(fn);
  194. }
  195. bool SpirvModule::addCapability(SpirvCapability *cap) {
  196. assert(cap && "cannot add null capability to the module");
  197. return capabilities.insert(cap);
  198. }
  199. void SpirvModule::setMemoryModel(SpirvMemoryModel *model) {
  200. assert(model && "cannot set a null memory model");
  201. memoryModel = model;
  202. }
  203. void SpirvModule::addEntryPoint(SpirvEntryPoint *ep) {
  204. assert(ep && "cannot add null as an entry point");
  205. entryPoints.push_back(ep);
  206. }
  207. void SpirvModule::addExecutionMode(SpirvExecutionMode *em) {
  208. assert(em && "cannot add null execution mode");
  209. executionModes.push_back(em);
  210. }
  211. bool SpirvModule::addExtension(SpirvExtension *ext) {
  212. assert(ext && "cannot add null extension");
  213. return extensions.insert(ext);
  214. }
  215. void SpirvModule::addExtInstSet(SpirvExtInstImport *set) {
  216. assert(set && "cannot add null extended instruction set");
  217. extInstSets.push_back(set);
  218. }
  219. SpirvExtInstImport *SpirvModule::getExtInstSet(llvm::StringRef name) {
  220. // We expect very few (usually 1) extended instruction sets to exist in the
  221. // module, so this is not expensive.
  222. auto found = std::find_if(extInstSets.begin(), extInstSets.end(),
  223. [name](const SpirvExtInstImport *set) {
  224. return set->getExtendedInstSetName() == name;
  225. });
  226. if (found != extInstSets.end())
  227. return *found;
  228. return nullptr;
  229. }
  230. void SpirvModule::addVariable(SpirvVariable *var) {
  231. assert(var && "cannot add null variable to the module");
  232. variables.push_back(var);
  233. }
  234. void SpirvModule::addDecoration(SpirvDecoration *decor) {
  235. assert(decor && "cannot add null decoration to the module");
  236. decorations.insert(decor);
  237. }
  238. void SpirvModule::addConstant(SpirvConstant *constant) {
  239. assert(constant);
  240. constants.push_back(constant);
  241. }
  242. void SpirvModule::addString(SpirvString *str) {
  243. assert(str);
  244. constStrings.push_back(str);
  245. }
  246. void SpirvModule::addSource(SpirvSource *src) {
  247. assert(src);
  248. sources.push_back(src);
  249. }
  250. void SpirvModule::addDebugInfo(SpirvDebugInstruction *info) {
  251. assert(info);
  252. debugInstructions.push_back(info);
  253. }
  254. void SpirvModule::addModuleProcessed(SpirvModuleProcessed *p) {
  255. assert(p);
  256. moduleProcesses.push_back(p);
  257. }
  258. } // end namespace spirv
  259. } // end namespace clang