ModuleBuilder.cpp 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515
  1. //===--- ModuleBuilder.cpp - SPIR-V builder implementation ----*- C++ -*---===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. #include "clang/SPIRV/ModuleBuilder.h"
  10. #include "spirv/1.0//spirv.hpp11"
  11. #include "clang/SPIRV/InstBuilder.h"
  12. #include "llvm/llvm_assert/assert.h"
  13. namespace clang {
  14. namespace spirv {
  15. ModuleBuilder::ModuleBuilder(SPIRVContext *C)
  16. : theContext(*C), theModule(), theFunction(nullptr), insertPoint(nullptr),
  17. instBuilder(nullptr), glslExtSetId(0) {
  18. instBuilder.setConsumer([this](std::vector<uint32_t> &&words) {
  19. this->constructSite = std::move(words);
  20. });
  21. }
  22. std::vector<uint32_t> ModuleBuilder::takeModule() {
  23. theModule.setBound(theContext.getNextId());
  24. std::vector<uint32_t> binary;
  25. auto ib = InstBuilder([&binary](std::vector<uint32_t> &&words) {
  26. binary.insert(binary.end(), words.begin(), words.end());
  27. });
  28. theModule.take(&ib);
  29. return binary;
  30. }
  31. uint32_t ModuleBuilder::beginFunction(uint32_t funcType, uint32_t returnType,
  32. llvm::StringRef funcName, uint32_t fId) {
  33. if (theFunction) {
  34. assert(false && "found nested function");
  35. return 0;
  36. }
  37. // If the caller doesn't supply a function <result-id>, we need to get one.
  38. if (!fId)
  39. fId = theContext.takeNextId();
  40. theFunction = llvm::make_unique<Function>(
  41. returnType, fId, spv::FunctionControlMask::MaskNone, funcType);
  42. theModule.addDebugName(fId, funcName);
  43. return fId;
  44. }
  45. uint32_t ModuleBuilder::addFnParam(uint32_t ptrType, llvm::StringRef name) {
  46. assert(theFunction && "found detached parameter");
  47. const uint32_t paramId = theContext.takeNextId();
  48. theFunction->addParameter(ptrType, paramId);
  49. theModule.addDebugName(paramId, name);
  50. return paramId;
  51. }
  52. uint32_t ModuleBuilder::addFnVar(uint32_t varType, llvm::StringRef name,
  53. llvm::Optional<uint32_t> init) {
  54. assert(theFunction && "found detached local variable");
  55. const uint32_t ptrType = getPointerType(varType, spv::StorageClass::Function);
  56. const uint32_t varId = theContext.takeNextId();
  57. theFunction->addVariable(ptrType, varId, init);
  58. theModule.addDebugName(varId, name);
  59. return varId;
  60. }
  61. bool ModuleBuilder::endFunction() {
  62. if (theFunction == nullptr) {
  63. assert(false && "no active function");
  64. return false;
  65. }
  66. // Move all basic blocks into the current function.
  67. // TODO: we should adjust the order the basic blocks according to
  68. // SPIR-V validation rules.
  69. for (auto &bb : basicBlocks) {
  70. theFunction->addBasicBlock(std::move(bb.second));
  71. }
  72. basicBlocks.clear();
  73. theModule.addFunction(std::move(theFunction));
  74. theFunction.reset(nullptr);
  75. insertPoint = nullptr;
  76. return true;
  77. }
  78. uint32_t ModuleBuilder::createBasicBlock(llvm::StringRef name,
  79. bool isReachable) {
  80. if (theFunction == nullptr) {
  81. assert(false && "found detached basic block");
  82. return 0;
  83. }
  84. const uint32_t labelId = theContext.takeNextId();
  85. basicBlocks[labelId] = llvm::make_unique<BasicBlock>(labelId, isReachable);
  86. // OpName instructions should not be added for unreachable basic blocks
  87. // because such blocks are *not* discovered by BlockReadableOrderVisitor and
  88. // therefore they are not emitted.
  89. // The newly created basic block is unreachable if specified by the caller,
  90. // or, if this block is being created by a block that is already unreachable.
  91. if (isReachable && (!insertPoint || insertPoint->isReachable()))
  92. theModule.addDebugName(labelId, name);
  93. return labelId;
  94. }
  95. void ModuleBuilder::addSuccessor(uint32_t successorLabel) {
  96. assert(insertPoint && "null insert point");
  97. insertPoint->addSuccessor(getBasicBlock(successorLabel));
  98. }
  99. void ModuleBuilder::setMergeTarget(uint32_t mergeLabel) {
  100. assert(insertPoint && "null insert point");
  101. insertPoint->setMergeTarget(getBasicBlock(mergeLabel));
  102. }
  103. void ModuleBuilder::setContinueTarget(uint32_t continueLabel) {
  104. assert(insertPoint && "null insert point");
  105. insertPoint->setContinueTarget(getBasicBlock(continueLabel));
  106. }
  107. void ModuleBuilder::setInsertPoint(uint32_t labelId) {
  108. insertPoint = getBasicBlock(labelId);
  109. }
  110. uint32_t
  111. ModuleBuilder::createCompositeConstruct(uint32_t resultType,
  112. llvm::ArrayRef<uint32_t> constituents) {
  113. assert(insertPoint && "null insert point");
  114. const uint32_t resultId = theContext.takeNextId();
  115. instBuilder.opCompositeConstruct(resultType, resultId, constituents).x();
  116. insertPoint->appendInstruction(std::move(constructSite));
  117. return resultId;
  118. }
  119. uint32_t
  120. ModuleBuilder::createCompositeExtract(uint32_t resultType, uint32_t composite,
  121. llvm::ArrayRef<uint32_t> indexes) {
  122. assert(insertPoint && "null insert point");
  123. const uint32_t resultId = theContext.takeNextId();
  124. instBuilder.opCompositeExtract(resultType, resultId, composite, indexes).x();
  125. insertPoint->appendInstruction(std::move(constructSite));
  126. return resultId;
  127. }
  128. uint32_t
  129. ModuleBuilder::createVectorShuffle(uint32_t resultType, uint32_t vector1,
  130. uint32_t vector2,
  131. llvm::ArrayRef<uint32_t> selectors) {
  132. assert(insertPoint && "null insert point");
  133. const uint32_t resultId = theContext.takeNextId();
  134. instBuilder.opVectorShuffle(resultType, resultId, vector1, vector2, selectors)
  135. .x();
  136. insertPoint->appendInstruction(std::move(constructSite));
  137. return resultId;
  138. }
  139. uint32_t ModuleBuilder::createLoad(uint32_t resultType, uint32_t pointer) {
  140. assert(insertPoint && "null insert point");
  141. const uint32_t resultId = theContext.takeNextId();
  142. instBuilder.opLoad(resultType, resultId, pointer, llvm::None).x();
  143. insertPoint->appendInstruction(std::move(constructSite));
  144. return resultId;
  145. }
  146. void ModuleBuilder::createStore(uint32_t address, uint32_t value) {
  147. assert(insertPoint && "null insert point");
  148. instBuilder.opStore(address, value, llvm::None).x();
  149. insertPoint->appendInstruction(std::move(constructSite));
  150. }
  151. uint32_t ModuleBuilder::createFunctionCall(uint32_t returnType,
  152. uint32_t functionId,
  153. llvm::ArrayRef<uint32_t> params) {
  154. assert(insertPoint && "null insert point");
  155. const uint32_t id = theContext.takeNextId();
  156. instBuilder.opFunctionCall(returnType, id, functionId, params).x();
  157. insertPoint->appendInstruction(std::move(constructSite));
  158. return id;
  159. }
  160. uint32_t ModuleBuilder::createAccessChain(uint32_t resultType, uint32_t base,
  161. llvm::ArrayRef<uint32_t> indexes) {
  162. assert(insertPoint && "null insert point");
  163. const uint32_t id = theContext.takeNextId();
  164. instBuilder.opAccessChain(resultType, id, base, indexes).x();
  165. insertPoint->appendInstruction(std::move(constructSite));
  166. return id;
  167. }
  168. uint32_t ModuleBuilder::createUnaryOp(spv::Op op, uint32_t resultType,
  169. uint32_t operand) {
  170. assert(insertPoint && "null insert point");
  171. const uint32_t id = theContext.takeNextId();
  172. instBuilder.unaryOp(op, resultType, id, operand).x();
  173. insertPoint->appendInstruction(std::move(constructSite));
  174. return id;
  175. }
  176. uint32_t ModuleBuilder::createBinaryOp(spv::Op op, uint32_t resultType,
  177. uint32_t lhs, uint32_t rhs) {
  178. assert(insertPoint && "null insert point");
  179. const uint32_t id = theContext.takeNextId();
  180. instBuilder.binaryOp(op, resultType, id, lhs, rhs).x();
  181. insertPoint->appendInstruction(std::move(constructSite));
  182. return id;
  183. }
  184. uint32_t ModuleBuilder::createSelect(uint32_t resultType, uint32_t condition,
  185. uint32_t trueValue, uint32_t falseValue) {
  186. assert(insertPoint && "null insert point");
  187. const uint32_t id = theContext.takeNextId();
  188. instBuilder.opSelect(resultType, id, condition, trueValue, falseValue).x();
  189. insertPoint->appendInstruction(std::move(constructSite));
  190. return id;
  191. }
  192. void ModuleBuilder::createSwitch(
  193. uint32_t mergeLabel, uint32_t selector, uint32_t defaultLabel,
  194. llvm::ArrayRef<std::pair<uint32_t, uint32_t>> target) {
  195. assert(insertPoint && "null insert point");
  196. // Create the OpSelectioMerege.
  197. instBuilder.opSelectionMerge(mergeLabel, spv::SelectionControlMask::MaskNone)
  198. .x();
  199. insertPoint->appendInstruction(std::move(constructSite));
  200. // Create the OpSwitch.
  201. instBuilder.opSwitch(selector, defaultLabel, target).x();
  202. insertPoint->appendInstruction(std::move(constructSite));
  203. }
  204. void ModuleBuilder::createKill() {
  205. assert(insertPoint && "null insert point");
  206. assert(!isCurrentBasicBlockTerminated());
  207. instBuilder.opKill().x();
  208. insertPoint->appendInstruction(std::move(constructSite));
  209. }
  210. void ModuleBuilder::createBranch(uint32_t targetLabel, uint32_t mergeBB,
  211. uint32_t continueBB,
  212. spv::LoopControlMask loopControl) {
  213. assert(insertPoint && "null insert point");
  214. if (mergeBB && continueBB) {
  215. instBuilder.opLoopMerge(mergeBB, continueBB, loopControl).x();
  216. insertPoint->appendInstruction(std::move(constructSite));
  217. }
  218. instBuilder.opBranch(targetLabel).x();
  219. insertPoint->appendInstruction(std::move(constructSite));
  220. }
  221. void ModuleBuilder::createConditionalBranch(
  222. uint32_t condition, uint32_t trueLabel, uint32_t falseLabel,
  223. uint32_t mergeLabel, uint32_t continueLabel,
  224. spv::SelectionControlMask selectionControl,
  225. spv::LoopControlMask loopControl) {
  226. assert(insertPoint && "null insert point");
  227. if (mergeLabel) {
  228. if (continueLabel) {
  229. instBuilder.opLoopMerge(mergeLabel, continueLabel, loopControl).x();
  230. insertPoint->appendInstruction(std::move(constructSite));
  231. } else {
  232. instBuilder.opSelectionMerge(mergeLabel, selectionControl).x();
  233. insertPoint->appendInstruction(std::move(constructSite));
  234. }
  235. }
  236. instBuilder.opBranchConditional(condition, trueLabel, falseLabel, {}).x();
  237. insertPoint->appendInstruction(std::move(constructSite));
  238. }
  239. void ModuleBuilder::createReturn() {
  240. assert(insertPoint && "null insert point");
  241. instBuilder.opReturn().x();
  242. insertPoint->appendInstruction(std::move(constructSite));
  243. }
  244. void ModuleBuilder::createReturnValue(uint32_t value) {
  245. assert(insertPoint && "null insert point");
  246. instBuilder.opReturnValue(value).x();
  247. insertPoint->appendInstruction(std::move(constructSite));
  248. }
  249. uint32_t ModuleBuilder::createExtInst(uint32_t resultType, uint32_t setId,
  250. uint32_t instId,
  251. llvm::ArrayRef<uint32_t> operands) {
  252. assert(insertPoint && "null insert point");
  253. uint32_t resultId = theContext.takeNextId();
  254. instBuilder.opExtInst(resultType, resultId, setId, instId, operands).x();
  255. insertPoint->appendInstruction(std::move(constructSite));
  256. return resultId;
  257. }
  258. void ModuleBuilder::addExecutionMode(uint32_t entryPointId,
  259. spv::ExecutionMode em,
  260. const std::vector<uint32_t> &params) {
  261. instBuilder.opExecutionMode(entryPointId, em);
  262. for (const auto &param : params) {
  263. instBuilder.literalInteger(param);
  264. }
  265. instBuilder.x();
  266. theModule.addExecutionMode(std::move(constructSite));
  267. }
  268. uint32_t ModuleBuilder::getGLSLExtInstSet() {
  269. if (glslExtSetId == 0) {
  270. glslExtSetId = theContext.takeNextId();
  271. theModule.addExtInstSet(glslExtSetId, "GLSL.std.450");
  272. }
  273. return glslExtSetId;
  274. }
  275. uint32_t ModuleBuilder::addStageIOVar(uint32_t type,
  276. spv::StorageClass storageClass,
  277. std::string name) {
  278. const uint32_t pointerType = getPointerType(type, storageClass);
  279. const uint32_t varId = theContext.takeNextId();
  280. instBuilder.opVariable(pointerType, varId, storageClass, llvm::None).x();
  281. theModule.addVariable(std::move(constructSite));
  282. theModule.addDebugName(varId, name);
  283. return varId;
  284. }
  285. uint32_t ModuleBuilder::addStageBuiltinVar(uint32_t type, spv::StorageClass sc,
  286. spv::BuiltIn builtin) {
  287. const uint32_t pointerType = getPointerType(type, sc);
  288. const uint32_t varId = theContext.takeNextId();
  289. instBuilder.opVariable(pointerType, varId, sc, llvm::None).x();
  290. theModule.addVariable(std::move(constructSite));
  291. // Decorate with the specified Builtin
  292. const Decoration *d = Decoration::getBuiltIn(theContext, builtin);
  293. theModule.addDecoration(*d, varId);
  294. return varId;
  295. }
  296. uint32_t ModuleBuilder::addFileVar(uint32_t type, llvm::StringRef name,
  297. llvm::Optional<uint32_t> init) {
  298. const uint32_t pointerType = getPointerType(type, spv::StorageClass::Private);
  299. const uint32_t varId = theContext.takeNextId();
  300. instBuilder.opVariable(pointerType, varId, spv::StorageClass::Private, init)
  301. .x();
  302. theModule.addVariable(std::move(constructSite));
  303. theModule.addDebugName(varId, name);
  304. return varId;
  305. }
  306. void ModuleBuilder::decorateLocation(uint32_t targetId, uint32_t location) {
  307. const Decoration *d =
  308. Decoration::getLocation(theContext, location, llvm::None);
  309. theModule.addDecoration(*d, targetId);
  310. }
  311. #define IMPL_GET_PRIMITIVE_TYPE(ty) \
  312. \
  313. uint32_t ModuleBuilder::get##ty##Type() { \
  314. const Type *type = Type::get##ty(theContext); \
  315. const uint32_t typeId = theContext.getResultIdForType(type); \
  316. theModule.addType(type, typeId); \
  317. return typeId; \
  318. \
  319. }
  320. IMPL_GET_PRIMITIVE_TYPE(Void)
  321. IMPL_GET_PRIMITIVE_TYPE(Bool)
  322. IMPL_GET_PRIMITIVE_TYPE(Int32)
  323. IMPL_GET_PRIMITIVE_TYPE(Uint32)
  324. IMPL_GET_PRIMITIVE_TYPE(Float32)
  325. #undef IMPL_GET_PRIMITIVE_TYPE
  326. uint32_t ModuleBuilder::getVecType(uint32_t elemType, uint32_t elemCount) {
  327. const Type *type = nullptr;
  328. switch (elemCount) {
  329. case 2:
  330. type = Type::getVec2(theContext, elemType);
  331. break;
  332. case 3:
  333. type = Type::getVec3(theContext, elemType);
  334. break;
  335. case 4:
  336. type = Type::getVec4(theContext, elemType);
  337. break;
  338. default:
  339. assert(false && "unhandled vector size");
  340. // Error found. Return 0 as the <result-id> directly.
  341. return 0;
  342. }
  343. const uint32_t typeId = theContext.getResultIdForType(type);
  344. theModule.addType(type, typeId);
  345. return typeId;
  346. }
  347. uint32_t ModuleBuilder::getMatType(uint32_t colType, uint32_t colCount) {
  348. const Type *type = Type::getMatrix(theContext, colType, colCount);
  349. const uint32_t typeId = theContext.getResultIdForType(type);
  350. theModule.addType(type, typeId);
  351. return typeId;
  352. }
  353. uint32_t ModuleBuilder::getPointerType(uint32_t pointeeType,
  354. spv::StorageClass storageClass) {
  355. const Type *type = Type::getPointer(theContext, storageClass, pointeeType);
  356. const uint32_t typeId = theContext.getResultIdForType(type);
  357. theModule.addType(type, typeId);
  358. return typeId;
  359. }
  360. uint32_t
  361. ModuleBuilder::getStructType(llvm::ArrayRef<uint32_t> fieldTypes,
  362. llvm::StringRef structName,
  363. llvm::ArrayRef<llvm::StringRef> fieldNames) {
  364. const Type *type = Type::getStruct(theContext, fieldTypes);
  365. bool isRegistered = false;
  366. const uint32_t typeId = theContext.getResultIdForType(type, &isRegistered);
  367. theModule.addType(type, typeId);
  368. // TODO: Probably we should check duplication and do nothing if trying to add
  369. // the same debug name for the same entity in addDebugName().
  370. if (!isRegistered) {
  371. theModule.addDebugName(typeId, structName);
  372. if (!fieldNames.empty()) {
  373. assert(fieldNames.size() == fieldTypes.size());
  374. for (uint32_t i = 0; i < fieldNames.size(); ++i)
  375. theModule.addDebugName(typeId, fieldNames[i],
  376. llvm::Optional<uint32_t>(i));
  377. }
  378. }
  379. return typeId;
  380. }
  381. uint32_t ModuleBuilder::getFunctionType(uint32_t returnType,
  382. llvm::ArrayRef<uint32_t> paramTypes) {
  383. const Type *type = Type::getFunction(theContext, returnType, paramTypes);
  384. const uint32_t typeId = theContext.getResultIdForType(type);
  385. theModule.addType(type, typeId);
  386. return typeId;
  387. }
  388. uint32_t ModuleBuilder::getConstantBool(bool value) {
  389. const uint32_t typeId = getBoolType();
  390. const Constant *constant = value ? Constant::getTrue(theContext, typeId)
  391. : Constant::getFalse(theContext, typeId);
  392. const uint32_t constId = theContext.getResultIdForConstant(constant);
  393. theModule.addConstant(constant, constId);
  394. return constId;
  395. }
  396. #define IMPL_GET_PRIMITIVE_CONST(builderTy, cppTy) \
  397. \
  398. uint32_t ModuleBuilder::getConstant##builderTy(cppTy value) { \
  399. const uint32_t typeId = get##builderTy##Type(); \
  400. const Constant *constant = \
  401. Constant::get##builderTy(theContext, typeId, value); \
  402. const uint32_t constId = theContext.getResultIdForConstant(constant); \
  403. theModule.addConstant(constant, constId); \
  404. return constId; \
  405. \
  406. }
  407. IMPL_GET_PRIMITIVE_CONST(Int32, int32_t)
  408. IMPL_GET_PRIMITIVE_CONST(Uint32, uint32_t)
  409. IMPL_GET_PRIMITIVE_CONST(Float32, float)
  410. #undef IMPL_GET_PRIMITIVE_VALUE
  411. uint32_t
  412. ModuleBuilder::getConstantComposite(uint32_t typeId,
  413. llvm::ArrayRef<uint32_t> constituents) {
  414. const Constant *constant =
  415. Constant::getComposite(theContext, typeId, constituents);
  416. const uint32_t constId = theContext.getResultIdForConstant(constant);
  417. theModule.addConstant(constant, constId);
  418. return constId;
  419. }
  420. uint32_t ModuleBuilder::getConstantNull(uint32_t typeId) {
  421. const Constant *constant = Constant::getNull(theContext, typeId);
  422. const uint32_t constId = theContext.getResultIdForConstant(constant);
  423. theModule.addConstant(constant, constId);
  424. return constId;
  425. }
  426. BasicBlock *ModuleBuilder::getBasicBlock(uint32_t labelId) {
  427. auto it = basicBlocks.find(labelId);
  428. if (it == basicBlocks.end()) {
  429. assert(false && "invalid <label-id>");
  430. return nullptr;
  431. }
  432. return it->second.get();
  433. }
  434. } // end namespace spirv
  435. } // end namespace clang