ModuleBuilder.cpp 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383
  1. //===--- ModuleBuilder.cpp - SPIR-V builder implementation ----*- C++ -*---===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. #include "clang/SPIRV/ModuleBuilder.h"
  10. #include "spirv/1.0//spirv.hpp11"
  11. #include "clang/SPIRV/InstBuilder.h"
  12. #include "llvm/llvm_assert/assert.h"
  13. namespace clang {
  14. namespace spirv {
  15. ModuleBuilder::ModuleBuilder(SPIRVContext *C)
  16. : theContext(*C), theModule(), theFunction(nullptr), insertPoint(nullptr),
  17. instBuilder(nullptr) {
  18. instBuilder.setConsumer([this](std::vector<uint32_t> &&words) {
  19. this->constructSite = std::move(words);
  20. });
  21. }
  22. std::vector<uint32_t> ModuleBuilder::takeModule() {
  23. theModule.setBound(theContext.getNextId());
  24. std::vector<uint32_t> binary;
  25. auto ib = InstBuilder([&binary](std::vector<uint32_t> &&words) {
  26. binary.insert(binary.end(), words.begin(), words.end());
  27. });
  28. theModule.take(&ib);
  29. return std::move(binary);
  30. }
  31. uint32_t ModuleBuilder::beginFunction(uint32_t funcType, uint32_t returnType,
  32. llvm::StringRef funcName) {
  33. if (theFunction) {
  34. assert(false && "found nested function");
  35. return 0;
  36. }
  37. const uint32_t fId = theContext.takeNextId();
  38. theFunction = llvm::make_unique<Function>(
  39. returnType, fId, spv::FunctionControlMask::MaskNone, funcType);
  40. theModule.addDebugName(fId, funcName);
  41. return fId;
  42. }
  43. uint32_t ModuleBuilder::addFnParameter(uint32_t type, llvm::StringRef name) {
  44. assert(theFunction && "found detached parameter");
  45. const uint32_t pointerType =
  46. getPointerType(type, spv::StorageClass::Function);
  47. const uint32_t paramId = theContext.takeNextId();
  48. theFunction->addParameter(pointerType, paramId);
  49. theModule.addDebugName(paramId, name);
  50. return paramId;
  51. }
  52. uint32_t ModuleBuilder::addFnVariable(uint32_t type, llvm::StringRef name,
  53. llvm::Optional<uint32_t> init) {
  54. assert(theFunction && "found detached local variable");
  55. const uint32_t varId = theContext.takeNextId();
  56. theFunction->addVariable(type, varId, init);
  57. theModule.addDebugName(varId, name);
  58. return varId;
  59. }
  60. bool ModuleBuilder::endFunction() {
  61. if (theFunction == nullptr) {
  62. assert(false && "no active function");
  63. return false;
  64. }
  65. // Move all basic blocks into the current function.
  66. // TODO: we should adjust the order the basic blocks according to
  67. // SPIR-V validation rules.
  68. for (auto &bb : basicBlocks) {
  69. theFunction->addBasicBlock(std::move(bb.second));
  70. }
  71. basicBlocks.clear();
  72. theModule.addFunction(std::move(theFunction));
  73. theFunction.reset(nullptr);
  74. insertPoint = nullptr;
  75. return true;
  76. }
  77. uint32_t ModuleBuilder::createBasicBlock(llvm::StringRef name) {
  78. if (theFunction == nullptr) {
  79. assert(false && "found detached basic block");
  80. return 0;
  81. }
  82. const uint32_t labelId = theContext.takeNextId();
  83. basicBlocks[labelId] = llvm::make_unique<BasicBlock>(labelId);
  84. theModule.addDebugName(labelId, name);
  85. return labelId;
  86. }
  87. void ModuleBuilder::addSuccessor(uint32_t successorLabel) {
  88. assert(insertPoint && "null insert point");
  89. insertPoint->addSuccessor(getBasicBlock(successorLabel));
  90. }
  91. void ModuleBuilder::setMergeTarget(uint32_t mergeLabel) {
  92. assert(insertPoint && "null insert point");
  93. insertPoint->setMergeTarget(getBasicBlock(mergeLabel));
  94. }
  95. void ModuleBuilder::setContinueTarget(uint32_t continueLabel) {
  96. assert(insertPoint && "null insert point");
  97. insertPoint->setContinueTarget(getBasicBlock(continueLabel));
  98. }
  99. void ModuleBuilder::setInsertPoint(uint32_t labelId) {
  100. insertPoint = getBasicBlock(labelId);
  101. }
  102. uint32_t
  103. ModuleBuilder::createCompositeConstruct(uint32_t resultType,
  104. llvm::ArrayRef<uint32_t> constituents) {
  105. assert(insertPoint && "null insert point");
  106. const uint32_t resultId = theContext.takeNextId();
  107. instBuilder.opCompositeConstruct(resultType, resultId, constituents).x();
  108. insertPoint->appendInstruction(std::move(constructSite));
  109. return resultId;
  110. }
  111. uint32_t ModuleBuilder::createLoad(uint32_t resultType, uint32_t pointer) {
  112. assert(insertPoint && "null insert point");
  113. const uint32_t resultId = theContext.takeNextId();
  114. instBuilder.opLoad(resultType, resultId, pointer, llvm::None).x();
  115. insertPoint->appendInstruction(std::move(constructSite));
  116. return resultId;
  117. }
  118. void ModuleBuilder::createStore(uint32_t address, uint32_t value) {
  119. assert(insertPoint && "null insert point");
  120. instBuilder.opStore(address, value, llvm::None).x();
  121. insertPoint->appendInstruction(std::move(constructSite));
  122. }
  123. uint32_t ModuleBuilder::createAccessChain(uint32_t resultType, uint32_t base,
  124. llvm::ArrayRef<uint32_t> indexes) {
  125. assert(insertPoint && "null insert point");
  126. const uint32_t id = theContext.takeNextId();
  127. instBuilder.opAccessChain(resultType, id, base, indexes).x();
  128. insertPoint->appendInstruction(std::move(constructSite));
  129. return id;
  130. }
  131. uint32_t ModuleBuilder::createBinaryOp(spv::Op op, uint32_t resultType,
  132. uint32_t lhs, uint32_t rhs) {
  133. assert(insertPoint && "null insert point");
  134. const uint32_t id = theContext.takeNextId();
  135. instBuilder.binaryOp(op, resultType, id, lhs, rhs).x();
  136. insertPoint->appendInstruction(std::move(constructSite));
  137. return id;
  138. }
  139. void ModuleBuilder::createBranch(uint32_t targetLabel) {
  140. assert(insertPoint && "null insert point");
  141. instBuilder.opBranch(targetLabel).x();
  142. insertPoint->appendInstruction(std::move(constructSite));
  143. }
  144. void ModuleBuilder::createConditionalBranch(uint32_t condition,
  145. uint32_t trueLabel,
  146. uint32_t falseLabel,
  147. uint32_t mergeLabel,
  148. uint32_t continueLabel) {
  149. assert(insertPoint && "null insert point");
  150. if (mergeLabel) {
  151. if (continueLabel) {
  152. instBuilder
  153. .opLoopMerge(mergeLabel, continueLabel,
  154. spv::LoopControlMask::MaskNone)
  155. .x();
  156. insertPoint->appendInstruction(std::move(constructSite));
  157. } else {
  158. instBuilder
  159. .opSelectionMerge(mergeLabel, spv::SelectionControlMask::MaskNone)
  160. .x();
  161. insertPoint->appendInstruction(std::move(constructSite));
  162. }
  163. }
  164. instBuilder.opBranchConditional(condition, trueLabel, falseLabel, {}).x();
  165. insertPoint->appendInstruction(std::move(constructSite));
  166. }
  167. void ModuleBuilder::createReturn() {
  168. assert(insertPoint && "null insert point");
  169. instBuilder.opReturn().x();
  170. insertPoint->appendInstruction(std::move(constructSite));
  171. }
  172. void ModuleBuilder::createReturnValue(uint32_t value) {
  173. assert(insertPoint && "null insert point");
  174. instBuilder.opReturnValue(value).x();
  175. insertPoint->appendInstruction(std::move(constructSite));
  176. }
  177. void ModuleBuilder::addExecutionMode(uint32_t entryPointId,
  178. spv::ExecutionMode em,
  179. const std::vector<uint32_t> &params) {
  180. instBuilder.opExecutionMode(entryPointId, em);
  181. for (const auto &param : params) {
  182. instBuilder.literalInteger(param);
  183. }
  184. instBuilder.x();
  185. theModule.addExecutionMode(std::move(constructSite));
  186. }
  187. uint32_t ModuleBuilder::addStageIOVariable(uint32_t type,
  188. spv::StorageClass storageClass) {
  189. const uint32_t pointerType = getPointerType(type, storageClass);
  190. const uint32_t varId = theContext.takeNextId();
  191. instBuilder.opVariable(pointerType, varId, storageClass, llvm::None).x();
  192. theModule.addVariable(std::move(constructSite));
  193. return varId;
  194. }
  195. uint32_t ModuleBuilder::addStageBuiltinVariable(uint32_t type,
  196. spv::BuiltIn builtin) {
  197. spv::StorageClass sc = spv::StorageClass::Input;
  198. switch (builtin) {
  199. case spv::BuiltIn::Position:
  200. case spv::BuiltIn::PointSize:
  201. // TODO: add the rest output builtins
  202. sc = spv::StorageClass::Output;
  203. break;
  204. default:
  205. break;
  206. }
  207. const uint32_t pointerType = getPointerType(type, sc);
  208. const uint32_t varId = theContext.takeNextId();
  209. instBuilder.opVariable(pointerType, varId, sc, llvm::None).x();
  210. theModule.addVariable(std::move(constructSite));
  211. // Decorate with the specified Builtin
  212. const Decoration *d = Decoration::getBuiltIn(theContext, builtin);
  213. theModule.addDecoration(*d, varId);
  214. return varId;
  215. }
  216. void ModuleBuilder::decorateLocation(uint32_t targetId, uint32_t location) {
  217. const Decoration *d =
  218. Decoration::getLocation(theContext, location, llvm::None);
  219. theModule.addDecoration(*d, targetId);
  220. }
  221. #define IMPL_GET_PRIMITIVE_TYPE(ty) \
  222. \
  223. uint32_t ModuleBuilder::get##ty##Type() { \
  224. const Type *type = Type::get##ty(theContext); \
  225. const uint32_t typeId = theContext.getResultIdForType(type); \
  226. theModule.addType(type, typeId); \
  227. return typeId; \
  228. \
  229. }
  230. IMPL_GET_PRIMITIVE_TYPE(Void)
  231. IMPL_GET_PRIMITIVE_TYPE(Bool)
  232. IMPL_GET_PRIMITIVE_TYPE(Int32)
  233. IMPL_GET_PRIMITIVE_TYPE(Uint32)
  234. IMPL_GET_PRIMITIVE_TYPE(Float32)
  235. #undef IMPL_GET_PRIMITIVE_TYPE
  236. uint32_t ModuleBuilder::getVecType(uint32_t elemType, uint32_t elemCount) {
  237. const Type *type = nullptr;
  238. switch (elemCount) {
  239. case 2:
  240. type = Type::getVec2(theContext, elemType);
  241. break;
  242. case 3:
  243. type = Type::getVec3(theContext, elemType);
  244. break;
  245. case 4:
  246. type = Type::getVec4(theContext, elemType);
  247. break;
  248. default:
  249. assert(false && "unhandled vector size");
  250. // Error found. Return 0 as the <result-id> directly.
  251. return 0;
  252. }
  253. const uint32_t typeId = theContext.getResultIdForType(type);
  254. theModule.addType(type, typeId);
  255. return typeId;
  256. }
  257. uint32_t ModuleBuilder::getPointerType(uint32_t pointeeType,
  258. spv::StorageClass storageClass) {
  259. const Type *type = Type::getPointer(theContext, storageClass, pointeeType);
  260. const uint32_t typeId = theContext.getResultIdForType(type);
  261. theModule.addType(type, typeId);
  262. return typeId;
  263. }
  264. uint32_t ModuleBuilder::getStructType(llvm::ArrayRef<uint32_t> fieldTypes) {
  265. const Type *type = Type::getStruct(theContext, fieldTypes);
  266. const uint32_t typeId = theContext.getResultIdForType(type);
  267. theModule.addType(type, typeId);
  268. return typeId;
  269. }
  270. uint32_t
  271. ModuleBuilder::getFunctionType(uint32_t returnType,
  272. const std::vector<uint32_t> &paramTypes) {
  273. const Type *type = Type::getFunction(theContext, returnType, paramTypes);
  274. const uint32_t typeId = theContext.getResultIdForType(type);
  275. theModule.addType(type, typeId);
  276. return typeId;
  277. }
  278. uint32_t ModuleBuilder::getConstantBool(bool value) {
  279. const uint32_t typeId = getBoolType();
  280. const Constant *constant = value ? Constant::getTrue(theContext, typeId)
  281. : Constant::getFalse(theContext, typeId);
  282. const uint32_t constId = theContext.getResultIdForConstant(constant);
  283. theModule.addConstant(constant, constId);
  284. return constId;
  285. }
  286. #define IMPL_GET_PRIMITIVE_CONST(builderTy, cppTy) \
  287. \
  288. uint32_t ModuleBuilder::getConstant##builderTy(cppTy value) { \
  289. const uint32_t typeId = get##builderTy##Type(); \
  290. const Constant *constant = \
  291. Constant::get##builderTy(theContext, typeId, value); \
  292. const uint32_t constId = theContext.getResultIdForConstant(constant); \
  293. theModule.addConstant(constant, constId); \
  294. return constId; \
  295. \
  296. }
  297. IMPL_GET_PRIMITIVE_CONST(Int32, int32_t)
  298. IMPL_GET_PRIMITIVE_CONST(Uint32, uint32_t)
  299. IMPL_GET_PRIMITIVE_CONST(Float32, float)
  300. #undef IMPL_GET_PRIMITIVE_VALUE
  301. uint32_t
  302. ModuleBuilder::getConstantComposite(uint32_t typeId,
  303. llvm::ArrayRef<uint32_t> constituents) {
  304. const Constant *constant =
  305. Constant::getComposite(theContext, typeId, constituents);
  306. const uint32_t constId = theContext.getResultIdForConstant(constant);
  307. theModule.addConstant(constant, constId);
  308. return constId;
  309. }
  310. BasicBlock *ModuleBuilder::getBasicBlock(uint32_t labelId) {
  311. auto it = basicBlocks.find(labelId);
  312. if (it == basicBlocks.end()) {
  313. assert(false && "invalid <label-id>");
  314. return nullptr;
  315. }
  316. return it->second.get();
  317. }
  318. } // end namespace spirv
  319. } // end namespace clang