ModuleBuilder.cpp 34 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974
  1. //===--- ModuleBuilder.cpp - SPIR-V builder implementation ----*- C++ -*---===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. #include "clang/SPIRV/ModuleBuilder.h"
  10. #include "spirv/1.0//spirv.hpp11"
  11. #include "clang/SPIRV/InstBuilder.h"
  12. #include "llvm/llvm_assert/assert.h"
  13. namespace clang {
  14. namespace spirv {
  15. ModuleBuilder::ModuleBuilder(SPIRVContext *C)
  16. : theContext(*C), theModule(), theFunction(nullptr), insertPoint(nullptr),
  17. instBuilder(nullptr), glslExtSetId(0) {
  18. instBuilder.setConsumer([this](std::vector<uint32_t> &&words) {
  19. this->constructSite = std::move(words);
  20. });
  21. }
  22. std::vector<uint32_t> ModuleBuilder::takeModule() {
  23. theModule.setBound(theContext.getNextId());
  24. std::vector<uint32_t> binary;
  25. auto ib = InstBuilder([&binary](std::vector<uint32_t> &&words) {
  26. binary.insert(binary.end(), words.begin(), words.end());
  27. });
  28. theModule.take(&ib);
  29. return binary;
  30. }
  31. uint32_t ModuleBuilder::beginFunction(uint32_t funcType, uint32_t returnType,
  32. llvm::StringRef funcName, uint32_t fId) {
  33. if (theFunction) {
  34. assert(false && "found nested function");
  35. return 0;
  36. }
  37. // If the caller doesn't supply a function <result-id>, we need to get one.
  38. if (!fId)
  39. fId = theContext.takeNextId();
  40. theFunction = llvm::make_unique<Function>(
  41. returnType, fId, spv::FunctionControlMask::MaskNone, funcType);
  42. theModule.addDebugName(fId, funcName);
  43. return fId;
  44. }
  45. uint32_t ModuleBuilder::addFnParam(uint32_t ptrType, llvm::StringRef name) {
  46. assert(theFunction && "found detached parameter");
  47. const uint32_t paramId = theContext.takeNextId();
  48. theFunction->addParameter(ptrType, paramId);
  49. theModule.addDebugName(paramId, name);
  50. return paramId;
  51. }
  52. uint32_t ModuleBuilder::addFnVar(uint32_t varType, llvm::StringRef name,
  53. llvm::Optional<uint32_t> init) {
  54. assert(theFunction && "found detached local variable");
  55. const uint32_t ptrType = getPointerType(varType, spv::StorageClass::Function);
  56. const uint32_t varId = theContext.takeNextId();
  57. theFunction->addVariable(ptrType, varId, init);
  58. theModule.addDebugName(varId, name);
  59. return varId;
  60. }
  61. bool ModuleBuilder::endFunction() {
  62. if (theFunction == nullptr) {
  63. assert(false && "no active function");
  64. return false;
  65. }
  66. // Move all basic blocks into the current function.
  67. // TODO: we should adjust the order the basic blocks according to
  68. // SPIR-V validation rules.
  69. for (auto &bb : basicBlocks) {
  70. theFunction->addBasicBlock(std::move(bb.second));
  71. }
  72. basicBlocks.clear();
  73. theModule.addFunction(std::move(theFunction));
  74. theFunction.reset(nullptr);
  75. insertPoint = nullptr;
  76. return true;
  77. }
  78. uint32_t ModuleBuilder::createBasicBlock(llvm::StringRef name) {
  79. if (theFunction == nullptr) {
  80. assert(false && "found detached basic block");
  81. return 0;
  82. }
  83. const uint32_t labelId = theContext.takeNextId();
  84. basicBlocks[labelId] = llvm::make_unique<BasicBlock>(labelId, name);
  85. return labelId;
  86. }
  87. void ModuleBuilder::addSuccessor(uint32_t successorLabel) {
  88. assert(insertPoint && "null insert point");
  89. insertPoint->addSuccessor(getBasicBlock(successorLabel));
  90. }
  91. void ModuleBuilder::setMergeTarget(uint32_t mergeLabel) {
  92. assert(insertPoint && "null insert point");
  93. insertPoint->setMergeTarget(getBasicBlock(mergeLabel));
  94. }
  95. void ModuleBuilder::setContinueTarget(uint32_t continueLabel) {
  96. assert(insertPoint && "null insert point");
  97. insertPoint->setContinueTarget(getBasicBlock(continueLabel));
  98. }
  99. void ModuleBuilder::setInsertPoint(uint32_t labelId) {
  100. insertPoint = getBasicBlock(labelId);
  101. }
  102. uint32_t
  103. ModuleBuilder::createCompositeConstruct(uint32_t resultType,
  104. llvm::ArrayRef<uint32_t> constituents) {
  105. assert(insertPoint && "null insert point");
  106. const uint32_t resultId = theContext.takeNextId();
  107. instBuilder.opCompositeConstruct(resultType, resultId, constituents).x();
  108. insertPoint->appendInstruction(std::move(constructSite));
  109. return resultId;
  110. }
  111. uint32_t
  112. ModuleBuilder::createCompositeExtract(uint32_t resultType, uint32_t composite,
  113. llvm::ArrayRef<uint32_t> indexes) {
  114. assert(insertPoint && "null insert point");
  115. const uint32_t resultId = theContext.takeNextId();
  116. instBuilder.opCompositeExtract(resultType, resultId, composite, indexes).x();
  117. insertPoint->appendInstruction(std::move(constructSite));
  118. return resultId;
  119. }
  120. uint32_t
  121. ModuleBuilder::createVectorShuffle(uint32_t resultType, uint32_t vector1,
  122. uint32_t vector2,
  123. llvm::ArrayRef<uint32_t> selectors) {
  124. assert(insertPoint && "null insert point");
  125. const uint32_t resultId = theContext.takeNextId();
  126. instBuilder.opVectorShuffle(resultType, resultId, vector1, vector2, selectors)
  127. .x();
  128. insertPoint->appendInstruction(std::move(constructSite));
  129. return resultId;
  130. }
  131. uint32_t ModuleBuilder::createLoad(uint32_t resultType, uint32_t pointer) {
  132. assert(insertPoint && "null insert point");
  133. const uint32_t resultId = theContext.takeNextId();
  134. instBuilder.opLoad(resultType, resultId, pointer, llvm::None).x();
  135. insertPoint->appendInstruction(std::move(constructSite));
  136. return resultId;
  137. }
  138. void ModuleBuilder::createStore(uint32_t address, uint32_t value) {
  139. assert(insertPoint && "null insert point");
  140. instBuilder.opStore(address, value, llvm::None).x();
  141. insertPoint->appendInstruction(std::move(constructSite));
  142. }
  143. uint32_t ModuleBuilder::createFunctionCall(uint32_t returnType,
  144. uint32_t functionId,
  145. llvm::ArrayRef<uint32_t> params) {
  146. assert(insertPoint && "null insert point");
  147. const uint32_t id = theContext.takeNextId();
  148. instBuilder.opFunctionCall(returnType, id, functionId, params).x();
  149. insertPoint->appendInstruction(std::move(constructSite));
  150. return id;
  151. }
  152. uint32_t ModuleBuilder::createAccessChain(uint32_t resultType, uint32_t base,
  153. llvm::ArrayRef<uint32_t> indexes) {
  154. assert(insertPoint && "null insert point");
  155. const uint32_t id = theContext.takeNextId();
  156. instBuilder.opAccessChain(resultType, id, base, indexes).x();
  157. insertPoint->appendInstruction(std::move(constructSite));
  158. return id;
  159. }
  160. uint32_t ModuleBuilder::createUnaryOp(spv::Op op, uint32_t resultType,
  161. uint32_t operand) {
  162. assert(insertPoint && "null insert point");
  163. const uint32_t id = theContext.takeNextId();
  164. instBuilder.unaryOp(op, resultType, id, operand).x();
  165. insertPoint->appendInstruction(std::move(constructSite));
  166. return id;
  167. }
  168. uint32_t ModuleBuilder::createBinaryOp(spv::Op op, uint32_t resultType,
  169. uint32_t lhs, uint32_t rhs) {
  170. assert(insertPoint && "null insert point");
  171. const uint32_t id = theContext.takeNextId();
  172. instBuilder.binaryOp(op, resultType, id, lhs, rhs).x();
  173. insertPoint->appendInstruction(std::move(constructSite));
  174. return id;
  175. }
  176. uint32_t ModuleBuilder::createAtomicOp(spv::Op opcode, uint32_t resultType,
  177. uint32_t orignalValuePtr,
  178. uint32_t scopeId,
  179. uint32_t memorySemanticsId,
  180. uint32_t valueToOp) {
  181. assert(insertPoint && "null insert point");
  182. const uint32_t id = theContext.takeNextId();
  183. switch (opcode) {
  184. case spv::Op::OpAtomicIAdd:
  185. instBuilder.opAtomicIAdd(resultType, id, orignalValuePtr, scopeId,
  186. memorySemanticsId, valueToOp);
  187. break;
  188. case spv::Op::OpAtomicISub:
  189. instBuilder.opAtomicISub(resultType, id, orignalValuePtr, scopeId,
  190. memorySemanticsId, valueToOp);
  191. break;
  192. case spv::Op::OpAtomicAnd:
  193. instBuilder.opAtomicAnd(resultType, id, orignalValuePtr, scopeId,
  194. memorySemanticsId, valueToOp);
  195. break;
  196. case spv::Op::OpAtomicOr:
  197. instBuilder.opAtomicOr(resultType, id, orignalValuePtr, scopeId,
  198. memorySemanticsId, valueToOp);
  199. break;
  200. case spv::Op::OpAtomicXor:
  201. instBuilder.opAtomicXor(resultType, id, orignalValuePtr, scopeId,
  202. memorySemanticsId, valueToOp);
  203. break;
  204. case spv::Op::OpAtomicUMax:
  205. instBuilder.opAtomicUMax(resultType, id, orignalValuePtr, scopeId,
  206. memorySemanticsId, valueToOp);
  207. break;
  208. case spv::Op::OpAtomicUMin:
  209. instBuilder.opAtomicUMin(resultType, id, orignalValuePtr, scopeId,
  210. memorySemanticsId, valueToOp);
  211. break;
  212. case spv::Op::OpAtomicSMax:
  213. instBuilder.opAtomicSMax(resultType, id, orignalValuePtr, scopeId,
  214. memorySemanticsId, valueToOp);
  215. break;
  216. case spv::Op::OpAtomicSMin:
  217. instBuilder.opAtomicSMin(resultType, id, orignalValuePtr, scopeId,
  218. memorySemanticsId, valueToOp);
  219. break;
  220. case spv::Op::OpAtomicExchange:
  221. instBuilder.opAtomicExchange(resultType, id, orignalValuePtr, scopeId,
  222. memorySemanticsId, valueToOp);
  223. break;
  224. default:
  225. assert(false && "unimplemented atomic opcode");
  226. }
  227. instBuilder.x();
  228. insertPoint->appendInstruction(std::move(constructSite));
  229. return id;
  230. }
  231. uint32_t ModuleBuilder::createAtomicCompareExchange(
  232. uint32_t resultType, uint32_t orignalValuePtr, uint32_t scopeId,
  233. uint32_t equalMemorySemanticsId, uint32_t unequalMemorySemanticsId,
  234. uint32_t valueToOp, uint32_t comparator) {
  235. assert(insertPoint && "null insert point");
  236. const uint32_t id = theContext.takeNextId();
  237. instBuilder.opAtomicCompareExchange(
  238. resultType, id, orignalValuePtr, scopeId, equalMemorySemanticsId,
  239. unequalMemorySemanticsId, valueToOp, comparator);
  240. instBuilder.x();
  241. insertPoint->appendInstruction(std::move(constructSite));
  242. return id;
  243. }
  244. spv::ImageOperandsMask ModuleBuilder::composeImageOperandsMask(
  245. uint32_t bias, uint32_t lod, const std::pair<uint32_t, uint32_t> &grad,
  246. uint32_t constOffset, uint32_t varOffset, uint32_t constOffsets,
  247. uint32_t sample, llvm::SmallVectorImpl<uint32_t> *orderedParams) {
  248. using spv::ImageOperandsMask;
  249. // SPIR-V Image Operands from least significant bit to most significant bit
  250. // Bias, Lod, Grad, ConstOffset, Offset, ConstOffsets, Sample, MinLod
  251. auto mask = ImageOperandsMask::MaskNone;
  252. orderedParams->clear();
  253. if (bias) {
  254. mask = mask | ImageOperandsMask::Bias;
  255. orderedParams->push_back(bias);
  256. }
  257. if (lod) {
  258. mask = mask | ImageOperandsMask::Lod;
  259. orderedParams->push_back(lod);
  260. }
  261. if (grad.first && grad.second) {
  262. mask = mask | ImageOperandsMask::Grad;
  263. orderedParams->push_back(grad.first);
  264. orderedParams->push_back(grad.second);
  265. }
  266. if (constOffset) {
  267. mask = mask | ImageOperandsMask::ConstOffset;
  268. orderedParams->push_back(constOffset);
  269. }
  270. if (varOffset) {
  271. mask = mask | ImageOperandsMask::Offset;
  272. requireCapability(spv::Capability::ImageGatherExtended);
  273. orderedParams->push_back(varOffset);
  274. }
  275. if (constOffsets) {
  276. mask = mask | ImageOperandsMask::ConstOffsets;
  277. orderedParams->push_back(constOffsets);
  278. }
  279. if (sample) {
  280. mask = mask | ImageOperandsMask::Sample;
  281. orderedParams->push_back(sample);
  282. }
  283. return mask;
  284. }
  285. uint32_t ModuleBuilder::createImageTexelPointer(uint32_t resultType,
  286. uint32_t imageId,
  287. uint32_t coordinate,
  288. uint32_t sample) {
  289. assert(insertPoint && "null insert point");
  290. const uint32_t id = theContext.takeNextId();
  291. instBuilder.opImageTexelPointer(resultType, id, imageId, coordinate, sample)
  292. .x();
  293. insertPoint->appendInstruction(std::move(constructSite));
  294. return id;
  295. }
  296. uint32_t ModuleBuilder::createImageSample(
  297. uint32_t texelType, uint32_t imageType, uint32_t image, uint32_t sampler,
  298. uint32_t coordinate, uint32_t compareVal, uint32_t bias, uint32_t lod,
  299. std::pair<uint32_t, uint32_t> grad, uint32_t constOffset,
  300. uint32_t varOffset, uint32_t constOffsets, uint32_t sample) {
  301. assert(insertPoint && "null insert point");
  302. // An OpSampledImage is required to do the image sampling.
  303. const uint32_t sampledImgId = theContext.takeNextId();
  304. const uint32_t sampledImgTy = getSampledImageType(imageType);
  305. instBuilder.opSampledImage(sampledImgTy, sampledImgId, image, sampler).x();
  306. insertPoint->appendInstruction(std::move(constructSite));
  307. const uint32_t texelId = theContext.takeNextId();
  308. llvm::SmallVector<uint32_t, 4> params;
  309. const auto mask = composeImageOperandsMask(
  310. bias, lod, grad, constOffset, varOffset, constOffsets, sample, &params);
  311. // If depth-comparison is needed when sampling, we use the OpImageSampleDref*
  312. // instructions.
  313. if (compareVal) {
  314. // The Lod and Grad image operands requires explicit-lod instructions.
  315. // Otherwise we use implicit-lod instructions.
  316. if (lod || (grad.first && grad.second)) {
  317. instBuilder.opImageSampleDrefExplicitLod(texelType, texelId, sampledImgId,
  318. coordinate, compareVal, mask);
  319. } else {
  320. instBuilder.opImageSampleDrefImplicitLod(
  321. texelType, texelId, sampledImgId, coordinate, compareVal,
  322. llvm::Optional<spv::ImageOperandsMask>(mask));
  323. }
  324. } else {
  325. // The Lod and Grad image operands requires explicit-lod instructions.
  326. // Otherwise we use implicit-lod instructions.
  327. if (lod || (grad.first && grad.second)) {
  328. instBuilder.opImageSampleExplicitLod(texelType, texelId, sampledImgId,
  329. coordinate, mask);
  330. } else {
  331. instBuilder.opImageSampleImplicitLod(
  332. texelType, texelId, sampledImgId, coordinate,
  333. llvm::Optional<spv::ImageOperandsMask>(mask));
  334. }
  335. }
  336. for (const auto param : params)
  337. instBuilder.idRef(param);
  338. instBuilder.x();
  339. insertPoint->appendInstruction(std::move(constructSite));
  340. return texelId;
  341. }
  342. void ModuleBuilder::createImageWrite(uint32_t imageId, uint32_t coordId,
  343. uint32_t texelId) {
  344. assert(insertPoint && "null insert point");
  345. instBuilder.opImageWrite(imageId, coordId, texelId, llvm::None).x();
  346. insertPoint->appendInstruction(std::move(constructSite));
  347. }
  348. uint32_t ModuleBuilder::createImageFetchOrRead(
  349. bool doImageFetch, uint32_t texelType, uint32_t image, uint32_t coordinate,
  350. uint32_t lod, uint32_t constOffset, uint32_t varOffset,
  351. uint32_t constOffsets, uint32_t sample) {
  352. assert(insertPoint && "null insert point");
  353. llvm::SmallVector<uint32_t, 2> params;
  354. const auto mask =
  355. llvm::Optional<spv::ImageOperandsMask>(composeImageOperandsMask(
  356. /*bias*/ 0, lod, std::make_pair(0, 0), constOffset, varOffset,
  357. constOffsets, sample, &params));
  358. const uint32_t texelId = theContext.takeNextId();
  359. if (doImageFetch)
  360. instBuilder.opImageFetch(texelType, texelId, image, coordinate, mask);
  361. else
  362. instBuilder.opImageRead(texelType, texelId, image, coordinate, mask);
  363. for (const auto param : params)
  364. instBuilder.idRef(param);
  365. instBuilder.x();
  366. insertPoint->appendInstruction(std::move(constructSite));
  367. return texelId;
  368. }
  369. uint32_t ModuleBuilder::createImageGather(
  370. uint32_t texelType, uint32_t imageType, uint32_t image, uint32_t sampler,
  371. uint32_t coordinate, uint32_t component, uint32_t compareVal,
  372. uint32_t constOffset, uint32_t varOffset, uint32_t constOffsets,
  373. uint32_t sample) {
  374. assert(insertPoint && "null insert point");
  375. // An OpSampledImage is required to do the image sampling.
  376. const uint32_t sampledImgId = theContext.takeNextId();
  377. const uint32_t sampledImgTy = getSampledImageType(imageType);
  378. instBuilder.opSampledImage(sampledImgTy, sampledImgId, image, sampler).x();
  379. insertPoint->appendInstruction(std::move(constructSite));
  380. llvm::SmallVector<uint32_t, 2> params;
  381. const auto mask =
  382. llvm::Optional<spv::ImageOperandsMask>(composeImageOperandsMask(
  383. /*bias*/ 0, /*lod*/ 0, std::make_pair(0, 0), constOffset, varOffset,
  384. constOffsets, sample, &params));
  385. const uint32_t texelId = theContext.takeNextId();
  386. if (compareVal) {
  387. // Note: OpImageDrefGather does not take the component parameter.
  388. instBuilder.opImageDrefGather(texelType, texelId, sampledImgId, coordinate,
  389. compareVal, mask);
  390. } else {
  391. instBuilder.opImageGather(texelType, texelId, sampledImgId, coordinate,
  392. component, mask);
  393. }
  394. for (const auto param : params)
  395. instBuilder.idRef(param);
  396. instBuilder.x();
  397. insertPoint->appendInstruction(std::move(constructSite));
  398. return texelId;
  399. }
  400. uint32_t ModuleBuilder::createSelect(uint32_t resultType, uint32_t condition,
  401. uint32_t trueValue, uint32_t falseValue) {
  402. assert(insertPoint && "null insert point");
  403. const uint32_t id = theContext.takeNextId();
  404. instBuilder.opSelect(resultType, id, condition, trueValue, falseValue).x();
  405. insertPoint->appendInstruction(std::move(constructSite));
  406. return id;
  407. }
  408. void ModuleBuilder::createSwitch(
  409. uint32_t mergeLabel, uint32_t selector, uint32_t defaultLabel,
  410. llvm::ArrayRef<std::pair<uint32_t, uint32_t>> target) {
  411. assert(insertPoint && "null insert point");
  412. // Create the OpSelectioMerege.
  413. instBuilder.opSelectionMerge(mergeLabel, spv::SelectionControlMask::MaskNone)
  414. .x();
  415. insertPoint->appendInstruction(std::move(constructSite));
  416. // Create the OpSwitch.
  417. instBuilder.opSwitch(selector, defaultLabel, target).x();
  418. insertPoint->appendInstruction(std::move(constructSite));
  419. }
  420. void ModuleBuilder::createKill() {
  421. assert(insertPoint && "null insert point");
  422. assert(!isCurrentBasicBlockTerminated());
  423. instBuilder.opKill().x();
  424. insertPoint->appendInstruction(std::move(constructSite));
  425. }
  426. void ModuleBuilder::createBranch(uint32_t targetLabel, uint32_t mergeBB,
  427. uint32_t continueBB,
  428. spv::LoopControlMask loopControl) {
  429. assert(insertPoint && "null insert point");
  430. if (mergeBB && continueBB) {
  431. instBuilder.opLoopMerge(mergeBB, continueBB, loopControl).x();
  432. insertPoint->appendInstruction(std::move(constructSite));
  433. }
  434. instBuilder.opBranch(targetLabel).x();
  435. insertPoint->appendInstruction(std::move(constructSite));
  436. }
  437. void ModuleBuilder::createConditionalBranch(
  438. uint32_t condition, uint32_t trueLabel, uint32_t falseLabel,
  439. uint32_t mergeLabel, uint32_t continueLabel,
  440. spv::SelectionControlMask selectionControl,
  441. spv::LoopControlMask loopControl) {
  442. assert(insertPoint && "null insert point");
  443. if (mergeLabel) {
  444. if (continueLabel) {
  445. instBuilder.opLoopMerge(mergeLabel, continueLabel, loopControl).x();
  446. insertPoint->appendInstruction(std::move(constructSite));
  447. } else {
  448. instBuilder.opSelectionMerge(mergeLabel, selectionControl).x();
  449. insertPoint->appendInstruction(std::move(constructSite));
  450. }
  451. }
  452. instBuilder.opBranchConditional(condition, trueLabel, falseLabel, {}).x();
  453. insertPoint->appendInstruction(std::move(constructSite));
  454. }
  455. void ModuleBuilder::createReturn() {
  456. assert(insertPoint && "null insert point");
  457. instBuilder.opReturn().x();
  458. insertPoint->appendInstruction(std::move(constructSite));
  459. }
  460. void ModuleBuilder::createReturnValue(uint32_t value) {
  461. assert(insertPoint && "null insert point");
  462. instBuilder.opReturnValue(value).x();
  463. insertPoint->appendInstruction(std::move(constructSite));
  464. }
  465. uint32_t ModuleBuilder::createExtInst(uint32_t resultType, uint32_t setId,
  466. uint32_t instId,
  467. llvm::ArrayRef<uint32_t> operands) {
  468. assert(insertPoint && "null insert point");
  469. uint32_t resultId = theContext.takeNextId();
  470. instBuilder.opExtInst(resultType, resultId, setId, instId, operands).x();
  471. insertPoint->appendInstruction(std::move(constructSite));
  472. return resultId;
  473. }
  474. void ModuleBuilder::createControlBarrier(uint32_t execution, uint32_t memory,
  475. uint32_t semantics) {
  476. assert(insertPoint && "null insert point");
  477. instBuilder.opControlBarrier(execution, memory, semantics).x();
  478. insertPoint->appendInstruction(std::move(constructSite));
  479. }
  480. void ModuleBuilder::addExecutionMode(uint32_t entryPointId,
  481. spv::ExecutionMode em,
  482. llvm::ArrayRef<uint32_t> params) {
  483. instBuilder.opExecutionMode(entryPointId, em);
  484. for (const auto &param : params) {
  485. instBuilder.literalInteger(param);
  486. }
  487. instBuilder.x();
  488. theModule.addExecutionMode(std::move(constructSite));
  489. }
  490. uint32_t ModuleBuilder::getGLSLExtInstSet() {
  491. if (glslExtSetId == 0) {
  492. glslExtSetId = theContext.takeNextId();
  493. theModule.addExtInstSet(glslExtSetId, "GLSL.std.450");
  494. }
  495. return glslExtSetId;
  496. }
  497. uint32_t ModuleBuilder::addStageIOVar(uint32_t type,
  498. spv::StorageClass storageClass,
  499. std::string name) {
  500. const uint32_t pointerType = getPointerType(type, storageClass);
  501. const uint32_t varId = theContext.takeNextId();
  502. instBuilder.opVariable(pointerType, varId, storageClass, llvm::None).x();
  503. theModule.addVariable(std::move(constructSite));
  504. theModule.addDebugName(varId, name);
  505. return varId;
  506. }
  507. uint32_t ModuleBuilder::addStageBuiltinVar(uint32_t type, spv::StorageClass sc,
  508. spv::BuiltIn builtin) {
  509. const uint32_t pointerType = getPointerType(type, sc);
  510. const uint32_t varId = theContext.takeNextId();
  511. instBuilder.opVariable(pointerType, varId, sc, llvm::None).x();
  512. theModule.addVariable(std::move(constructSite));
  513. // Decorate with the specified Builtin
  514. const Decoration *d = Decoration::getBuiltIn(theContext, builtin);
  515. theModule.addDecoration(d, varId);
  516. return varId;
  517. }
  518. uint32_t ModuleBuilder::addModuleVar(uint32_t type, spv::StorageClass sc,
  519. llvm::StringRef name,
  520. llvm::Optional<uint32_t> init) {
  521. assert(sc != spv::StorageClass::Function);
  522. // TODO: basically duplicated code of addFileVar()
  523. const uint32_t pointerType = getPointerType(type, sc);
  524. const uint32_t varId = theContext.takeNextId();
  525. instBuilder.opVariable(pointerType, varId, sc, init).x();
  526. theModule.addVariable(std::move(constructSite));
  527. theModule.addDebugName(varId, name);
  528. return varId;
  529. }
  530. void ModuleBuilder::decorateDSetBinding(uint32_t targetId, uint32_t setNumber,
  531. uint32_t bindingNumber) {
  532. const auto *d = Decoration::getDescriptorSet(theContext, setNumber);
  533. theModule.addDecoration(d, targetId);
  534. d = Decoration::getBinding(theContext, bindingNumber);
  535. theModule.addDecoration(d, targetId);
  536. }
  537. void ModuleBuilder::decorateLocation(uint32_t targetId, uint32_t location) {
  538. const Decoration *d =
  539. Decoration::getLocation(theContext, location, llvm::None);
  540. theModule.addDecoration(d, targetId);
  541. }
  542. void ModuleBuilder::decorate(uint32_t targetId, spv::Decoration decoration) {
  543. const Decoration *d = nullptr;
  544. switch (decoration) {
  545. case spv::Decoration::Centroid:
  546. d = Decoration::getCentroid(theContext);
  547. break;
  548. case spv::Decoration::Flat:
  549. d = Decoration::getFlat(theContext);
  550. break;
  551. case spv::Decoration::NoPerspective:
  552. d = Decoration::getNoPerspective(theContext);
  553. break;
  554. case spv::Decoration::Sample:
  555. d = Decoration::getSample(theContext);
  556. break;
  557. case spv::Decoration::Block:
  558. d = Decoration::getBlock(theContext);
  559. break;
  560. case spv::Decoration::RelaxedPrecision:
  561. d = Decoration::getRelaxedPrecision(theContext);
  562. break;
  563. }
  564. assert(d && "unimplemented decoration");
  565. theModule.addDecoration(d, targetId);
  566. }
  567. #define IMPL_GET_PRIMITIVE_TYPE(ty) \
  568. \
  569. uint32_t ModuleBuilder::get##ty##Type() { \
  570. const Type *type = Type::get##ty(theContext); \
  571. const uint32_t typeId = theContext.getResultIdForType(type); \
  572. theModule.addType(type, typeId); \
  573. return typeId; \
  574. \
  575. }
  576. IMPL_GET_PRIMITIVE_TYPE(Void)
  577. IMPL_GET_PRIMITIVE_TYPE(Bool)
  578. IMPL_GET_PRIMITIVE_TYPE(Int32)
  579. IMPL_GET_PRIMITIVE_TYPE(Uint32)
  580. IMPL_GET_PRIMITIVE_TYPE(Float32)
  581. #undef IMPL_GET_PRIMITIVE_TYPE
  582. #define IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(ty) \
  583. \
  584. uint32_t ModuleBuilder::get##ty##Type() { \
  585. requireCapability(spv::Capability::ty); \
  586. const Type *type = Type::get##ty(theContext); \
  587. const uint32_t typeId = theContext.getResultIdForType(type); \
  588. theModule.addType(type, typeId); \
  589. return typeId; \
  590. \
  591. }
  592. IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Float64)
  593. #undef IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY
  594. uint32_t ModuleBuilder::getVecType(uint32_t elemType, uint32_t elemCount) {
  595. const Type *type = nullptr;
  596. switch (elemCount) {
  597. case 2:
  598. type = Type::getVec2(theContext, elemType);
  599. break;
  600. case 3:
  601. type = Type::getVec3(theContext, elemType);
  602. break;
  603. case 4:
  604. type = Type::getVec4(theContext, elemType);
  605. break;
  606. default:
  607. assert(false && "unhandled vector size");
  608. // Error found. Return 0 as the <result-id> directly.
  609. return 0;
  610. }
  611. const uint32_t typeId = theContext.getResultIdForType(type);
  612. theModule.addType(type, typeId);
  613. return typeId;
  614. }
  615. uint32_t ModuleBuilder::getMatType(uint32_t colType, uint32_t colCount) {
  616. const Type *type = Type::getMatrix(theContext, colType, colCount);
  617. const uint32_t typeId = theContext.getResultIdForType(type);
  618. theModule.addType(type, typeId);
  619. return typeId;
  620. }
  621. uint32_t ModuleBuilder::getPointerType(uint32_t pointeeType,
  622. spv::StorageClass storageClass) {
  623. const Type *type = Type::getPointer(theContext, storageClass, pointeeType);
  624. const uint32_t typeId = theContext.getResultIdForType(type);
  625. theModule.addType(type, typeId);
  626. return typeId;
  627. }
  628. uint32_t
  629. ModuleBuilder::getStructType(llvm::ArrayRef<uint32_t> fieldTypes,
  630. llvm::StringRef structName,
  631. llvm::ArrayRef<llvm::StringRef> fieldNames,
  632. Type::DecorationSet decorations) {
  633. const Type *type = Type::getStruct(theContext, fieldTypes, decorations);
  634. bool isRegistered = false;
  635. const uint32_t typeId = theContext.getResultIdForType(type, &isRegistered);
  636. theModule.addType(type, typeId);
  637. if (!isRegistered) {
  638. theModule.addDebugName(typeId, structName);
  639. if (!fieldNames.empty()) {
  640. assert(fieldNames.size() == fieldTypes.size());
  641. for (uint32_t i = 0; i < fieldNames.size(); ++i)
  642. theModule.addDebugName(typeId, fieldNames[i],
  643. llvm::Optional<uint32_t>(i));
  644. }
  645. }
  646. return typeId;
  647. }
  648. uint32_t ModuleBuilder::getArrayType(uint32_t elemType, uint32_t count,
  649. Type::DecorationSet decorations) {
  650. const Type *type = Type::getArray(theContext, elemType, count, decorations);
  651. const uint32_t typeId = theContext.getResultIdForType(type);
  652. theModule.addType(type, typeId);
  653. return typeId;
  654. }
  655. uint32_t ModuleBuilder::getRuntimeArrayType(uint32_t elemType,
  656. Type::DecorationSet decorations) {
  657. const Type *type = Type::getRuntimeArray(theContext, elemType, decorations);
  658. const uint32_t typeId = theContext.getResultIdForType(type);
  659. theModule.addType(type, typeId);
  660. return typeId;
  661. }
  662. uint32_t ModuleBuilder::getFunctionType(uint32_t returnType,
  663. llvm::ArrayRef<uint32_t> paramTypes) {
  664. const Type *type = Type::getFunction(theContext, returnType, paramTypes);
  665. const uint32_t typeId = theContext.getResultIdForType(type);
  666. theModule.addType(type, typeId);
  667. return typeId;
  668. }
  669. uint32_t ModuleBuilder::getImageType(uint32_t sampledType, spv::Dim dim,
  670. uint32_t depth, bool isArray, uint32_t ms,
  671. uint32_t sampled,
  672. spv::ImageFormat format) {
  673. const Type *type = Type::getImage(theContext, sampledType, dim, depth,
  674. isArray, ms, sampled, format);
  675. bool isRegistered = false;
  676. const uint32_t typeId = theContext.getResultIdForType(type, &isRegistered);
  677. theModule.addType(type, typeId);
  678. switch (format) {
  679. case spv::ImageFormat::Rg32f:
  680. case spv::ImageFormat::Rg16f:
  681. case spv::ImageFormat::R11fG11fB10f:
  682. case spv::ImageFormat::R16f:
  683. case spv::ImageFormat::Rgba16:
  684. case spv::ImageFormat::Rgb10A2:
  685. case spv::ImageFormat::Rg16:
  686. case spv::ImageFormat::Rg8:
  687. case spv::ImageFormat::R16:
  688. case spv::ImageFormat::R8:
  689. case spv::ImageFormat::Rgba16Snorm:
  690. case spv::ImageFormat::Rg16Snorm:
  691. case spv::ImageFormat::Rg8Snorm:
  692. case spv::ImageFormat::R16Snorm:
  693. case spv::ImageFormat::R8Snorm:
  694. case spv::ImageFormat::Rg32i:
  695. case spv::ImageFormat::Rg16i:
  696. case spv::ImageFormat::Rg8i:
  697. case spv::ImageFormat::R16i:
  698. case spv::ImageFormat::R8i:
  699. case spv::ImageFormat::Rgb10a2ui:
  700. case spv::ImageFormat::Rg32ui:
  701. case spv::ImageFormat::Rg16ui:
  702. case spv::ImageFormat::Rg8ui:
  703. case spv::ImageFormat::R16ui:
  704. case spv::ImageFormat::R8ui:
  705. requireCapability(spv::Capability::StorageImageExtendedFormats);
  706. }
  707. if (dim == spv::Dim::Dim1D) {
  708. if (sampled == 2u) {
  709. requireCapability(spv::Capability::Image1D);
  710. } else {
  711. requireCapability(spv::Capability::Sampled1D);
  712. }
  713. }
  714. if (dim == spv::Dim::Buffer) {
  715. requireCapability(spv::Capability::SampledBuffer);
  716. }
  717. if (isArray && ms) {
  718. requireCapability(spv::Capability::ImageMSArray);
  719. }
  720. // Skip constructing the debug name if we have already done it before.
  721. if (!isRegistered) {
  722. const char *dimStr = "";
  723. switch (dim) {
  724. case spv::Dim::Dim1D:
  725. dimStr = "1d.";
  726. break;
  727. case spv::Dim::Dim2D:
  728. dimStr = "2d.";
  729. break;
  730. case spv::Dim::Dim3D:
  731. dimStr = "3d.";
  732. break;
  733. case spv::Dim::Cube:
  734. dimStr = "cube.";
  735. break;
  736. case spv::Dim::Rect:
  737. dimStr = "rect.";
  738. break;
  739. case spv::Dim::Buffer:
  740. dimStr = "buffer.";
  741. break;
  742. case spv::Dim::SubpassData:
  743. dimStr = "subpass.";
  744. break;
  745. default:
  746. break;
  747. }
  748. std::string name =
  749. std::string("type.") + dimStr + "image" + (isArray ? ".array" : "");
  750. theModule.addDebugName(typeId, name);
  751. }
  752. return typeId;
  753. }
  754. uint32_t ModuleBuilder::getSamplerType() {
  755. const Type *type = Type::getSampler(theContext);
  756. const uint32_t typeId = theContext.getResultIdForType(type);
  757. theModule.addType(type, typeId);
  758. theModule.addDebugName(typeId, "type.sampler");
  759. return typeId;
  760. }
  761. uint32_t ModuleBuilder::getSampledImageType(uint32_t imageType) {
  762. const Type *type = Type::getSampledImage(theContext, imageType);
  763. const uint32_t typeId = theContext.getResultIdForType(type);
  764. theModule.addType(type, typeId);
  765. theModule.addDebugName(typeId, "type.sampled.image");
  766. return typeId;
  767. }
  768. uint32_t ModuleBuilder::getByteAddressBufferType(bool isRW) {
  769. // Create a uint RuntimeArray with Array Stride of 4.
  770. const uint32_t uintType = getUint32Type();
  771. const auto *arrStride4 = Decoration::getArrayStride(theContext, 4u);
  772. const Type *raType =
  773. Type::getRuntimeArray(theContext, uintType, {arrStride4});
  774. const uint32_t raTypeId = theContext.getResultIdForType(raType);
  775. theModule.addType(raType, raTypeId);
  776. // Create a struct containing the runtime array as its only member.
  777. // The struct must also be decorated as BufferBlock. The offset decoration
  778. // should also be applied to the first (only) member. NonWritable decoration
  779. // should also be applied to the first member if isRW is true.
  780. llvm::SmallVector<const Decoration *, 3> typeDecs;
  781. typeDecs.push_back(Decoration::getBufferBlock(theContext));
  782. typeDecs.push_back(Decoration::getOffset(theContext, 0, 0));
  783. if (!isRW)
  784. typeDecs.push_back(Decoration::getNonWritable(theContext, 0));
  785. const Type *type = Type::getStruct(theContext, {raTypeId}, typeDecs);
  786. const uint32_t typeId = theContext.getResultIdForType(type);
  787. theModule.addType(type, typeId);
  788. theModule.addDebugName(
  789. typeId, isRW ? "type.RWByteAddressBuffer" : "type.ByteAddressBuffer");
  790. return typeId;
  791. }
  792. uint32_t ModuleBuilder::getConstantBool(bool value) {
  793. const uint32_t typeId = getBoolType();
  794. const Constant *constant = value ? Constant::getTrue(theContext, typeId)
  795. : Constant::getFalse(theContext, typeId);
  796. const uint32_t constId = theContext.getResultIdForConstant(constant);
  797. theModule.addConstant(constant, constId);
  798. return constId;
  799. }
  800. #define IMPL_GET_PRIMITIVE_CONST(builderTy, cppTy) \
  801. \
  802. uint32_t ModuleBuilder::getConstant##builderTy(cppTy value) { \
  803. const uint32_t typeId = get##builderTy##Type(); \
  804. const Constant *constant = \
  805. Constant::get##builderTy(theContext, typeId, value); \
  806. const uint32_t constId = theContext.getResultIdForConstant(constant); \
  807. theModule.addConstant(constant, constId); \
  808. return constId; \
  809. \
  810. }
  811. IMPL_GET_PRIMITIVE_CONST(Int32, int32_t)
  812. IMPL_GET_PRIMITIVE_CONST(Uint32, uint32_t)
  813. IMPL_GET_PRIMITIVE_CONST(Float32, float)
  814. IMPL_GET_PRIMITIVE_CONST(Float64, double)
  815. #undef IMPL_GET_PRIMITIVE_VALUE
  816. uint32_t
  817. ModuleBuilder::getConstantComposite(uint32_t typeId,
  818. llvm::ArrayRef<uint32_t> constituents) {
  819. const Constant *constant =
  820. Constant::getComposite(theContext, typeId, constituents);
  821. const uint32_t constId = theContext.getResultIdForConstant(constant);
  822. theModule.addConstant(constant, constId);
  823. return constId;
  824. }
  825. uint32_t ModuleBuilder::getConstantNull(uint32_t typeId) {
  826. const Constant *constant = Constant::getNull(theContext, typeId);
  827. const uint32_t constId = theContext.getResultIdForConstant(constant);
  828. theModule.addConstant(constant, constId);
  829. return constId;
  830. }
  831. BasicBlock *ModuleBuilder::getBasicBlock(uint32_t labelId) {
  832. auto it = basicBlocks.find(labelId);
  833. if (it == basicBlocks.end()) {
  834. assert(false && "invalid <label-id>");
  835. return nullptr;
  836. }
  837. return it->second.get();
  838. }
  839. } // end namespace spirv
  840. } // end namespace clang