ModuleBuilder.cpp 46 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234
  1. //===--- ModuleBuilder.cpp - SPIR-V builder implementation ----*- C++ -*---===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. #include "clang/SPIRV/ModuleBuilder.h"
  10. #include "TypeTranslator.h"
  11. #include "spirv/unified1//spirv.hpp11"
  12. #include "clang/SPIRV/BitwiseCast.h"
  13. #include "clang/SPIRV/InstBuilder.h"
  14. #include "llvm/llvm_assert/assert.h"
  15. namespace clang {
  16. namespace spirv {
  17. ModuleBuilder::ModuleBuilder(SPIRVContext *C)
  18. : theContext(*C), theModule(), theFunction(nullptr), insertPoint(nullptr),
  19. instBuilder(nullptr), glslExtSetId(0) {
  20. instBuilder.setConsumer([this](std::vector<uint32_t> &&words) {
  21. this->constructSite = std::move(words);
  22. });
  23. }
  24. std::vector<uint32_t> ModuleBuilder::takeModule() {
  25. theModule.setBound(theContext.getNextId());
  26. std::vector<uint32_t> binary;
  27. auto ib = InstBuilder([&binary](std::vector<uint32_t> &&words) {
  28. binary.insert(binary.end(), words.begin(), words.end());
  29. });
  30. theModule.take(&ib);
  31. return binary;
  32. }
  33. uint32_t ModuleBuilder::beginFunction(uint32_t funcType, uint32_t returnType,
  34. llvm::StringRef funcName, uint32_t fId) {
  35. if (theFunction) {
  36. assert(false && "found nested function");
  37. return 0;
  38. }
  39. // If the caller doesn't supply a function <result-id>, we need to get one.
  40. if (!fId)
  41. fId = theContext.takeNextId();
  42. theFunction = llvm::make_unique<Function>(
  43. returnType, fId, spv::FunctionControlMask::MaskNone, funcType);
  44. theModule.addDebugName(fId, funcName);
  45. return fId;
  46. }
  47. uint32_t ModuleBuilder::addFnParam(uint32_t ptrType, llvm::StringRef name) {
  48. assert(theFunction && "found detached parameter");
  49. const uint32_t paramId = theContext.takeNextId();
  50. theFunction->addParameter(ptrType, paramId);
  51. theModule.addDebugName(paramId, name);
  52. return paramId;
  53. }
  54. uint32_t ModuleBuilder::addFnVar(uint32_t varType, llvm::StringRef name,
  55. llvm::Optional<uint32_t> init) {
  56. assert(theFunction && "found detached local variable");
  57. const uint32_t ptrType = getPointerType(varType, spv::StorageClass::Function);
  58. const uint32_t varId = theContext.takeNextId();
  59. theFunction->addVariable(ptrType, varId, init);
  60. theModule.addDebugName(varId, name);
  61. return varId;
  62. }
  63. bool ModuleBuilder::endFunction() {
  64. if (theFunction == nullptr) {
  65. assert(false && "no active function");
  66. return false;
  67. }
  68. // Move all basic blocks into the current function.
  69. // TODO: we should adjust the order the basic blocks according to
  70. // SPIR-V validation rules.
  71. for (auto &bb : basicBlocks) {
  72. theFunction->addBasicBlock(std::move(bb.second));
  73. }
  74. basicBlocks.clear();
  75. theModule.addFunction(std::move(theFunction));
  76. theFunction.reset(nullptr);
  77. insertPoint = nullptr;
  78. return true;
  79. }
  80. uint32_t ModuleBuilder::createBasicBlock(llvm::StringRef name) {
  81. if (theFunction == nullptr) {
  82. assert(false && "found detached basic block");
  83. return 0;
  84. }
  85. const uint32_t labelId = theContext.takeNextId();
  86. basicBlocks[labelId] = llvm::make_unique<BasicBlock>(labelId, name);
  87. return labelId;
  88. }
  89. void ModuleBuilder::addSuccessor(uint32_t successorLabel) {
  90. assert(insertPoint && "null insert point");
  91. insertPoint->addSuccessor(getBasicBlock(successorLabel));
  92. }
  93. void ModuleBuilder::setMergeTarget(uint32_t mergeLabel) {
  94. assert(insertPoint && "null insert point");
  95. insertPoint->setMergeTarget(getBasicBlock(mergeLabel));
  96. }
  97. void ModuleBuilder::setContinueTarget(uint32_t continueLabel) {
  98. assert(insertPoint && "null insert point");
  99. insertPoint->setContinueTarget(getBasicBlock(continueLabel));
  100. }
  101. void ModuleBuilder::setInsertPoint(uint32_t labelId) {
  102. insertPoint = getBasicBlock(labelId);
  103. }
  104. uint32_t
  105. ModuleBuilder::createCompositeConstruct(uint32_t resultType,
  106. llvm::ArrayRef<uint32_t> constituents) {
  107. assert(insertPoint && "null insert point");
  108. const uint32_t resultId = theContext.takeNextId();
  109. instBuilder.opCompositeConstruct(resultType, resultId, constituents).x();
  110. insertPoint->appendInstruction(std::move(constructSite));
  111. return resultId;
  112. }
  113. uint32_t
  114. ModuleBuilder::createCompositeExtract(uint32_t resultType, uint32_t composite,
  115. llvm::ArrayRef<uint32_t> indexes) {
  116. assert(insertPoint && "null insert point");
  117. const uint32_t resultId = theContext.takeNextId();
  118. instBuilder.opCompositeExtract(resultType, resultId, composite, indexes).x();
  119. insertPoint->appendInstruction(std::move(constructSite));
  120. return resultId;
  121. }
  122. uint32_t ModuleBuilder::createCompositeInsert(uint32_t resultType,
  123. uint32_t composite,
  124. llvm::ArrayRef<uint32_t> indices,
  125. uint32_t object) {
  126. assert(insertPoint && "null insert point");
  127. const uint32_t resultId = theContext.takeNextId();
  128. instBuilder
  129. .opCompositeInsert(resultType, resultId, object, composite, indices)
  130. .x();
  131. insertPoint->appendInstruction(std::move(constructSite));
  132. return resultId;
  133. }
  134. uint32_t
  135. ModuleBuilder::createVectorShuffle(uint32_t resultType, uint32_t vector1,
  136. uint32_t vector2,
  137. llvm::ArrayRef<uint32_t> selectors) {
  138. assert(insertPoint && "null insert point");
  139. const uint32_t resultId = theContext.takeNextId();
  140. instBuilder.opVectorShuffle(resultType, resultId, vector1, vector2, selectors)
  141. .x();
  142. insertPoint->appendInstruction(std::move(constructSite));
  143. return resultId;
  144. }
  145. uint32_t ModuleBuilder::createLoad(uint32_t resultType, uint32_t pointer) {
  146. assert(insertPoint && "null insert point");
  147. const uint32_t resultId = theContext.takeNextId();
  148. instBuilder.opLoad(resultType, resultId, pointer, llvm::None).x();
  149. insertPoint->appendInstruction(std::move(constructSite));
  150. return resultId;
  151. }
  152. void ModuleBuilder::createStore(uint32_t address, uint32_t value) {
  153. assert(insertPoint && "null insert point");
  154. instBuilder.opStore(address, value, llvm::None).x();
  155. insertPoint->appendInstruction(std::move(constructSite));
  156. }
  157. uint32_t ModuleBuilder::createFunctionCall(uint32_t returnType,
  158. uint32_t functionId,
  159. llvm::ArrayRef<uint32_t> params) {
  160. assert(insertPoint && "null insert point");
  161. const uint32_t id = theContext.takeNextId();
  162. instBuilder.opFunctionCall(returnType, id, functionId, params).x();
  163. insertPoint->appendInstruction(std::move(constructSite));
  164. return id;
  165. }
  166. uint32_t ModuleBuilder::createAccessChain(uint32_t resultType, uint32_t base,
  167. llvm::ArrayRef<uint32_t> indexes) {
  168. assert(insertPoint && "null insert point");
  169. const uint32_t id = theContext.takeNextId();
  170. instBuilder.opAccessChain(resultType, id, base, indexes).x();
  171. insertPoint->appendInstruction(std::move(constructSite));
  172. return id;
  173. }
  174. uint32_t ModuleBuilder::createUnaryOp(spv::Op op, uint32_t resultType,
  175. uint32_t operand) {
  176. assert(insertPoint && "null insert point");
  177. const uint32_t id = theContext.takeNextId();
  178. instBuilder.unaryOp(op, resultType, id, operand).x();
  179. insertPoint->appendInstruction(std::move(constructSite));
  180. switch (op) {
  181. case spv::Op::OpImageQuerySize:
  182. case spv::Op::OpImageQueryLevels:
  183. case spv::Op::OpImageQuerySamples:
  184. requireCapability(spv::Capability::ImageQuery);
  185. break;
  186. }
  187. return id;
  188. }
  189. uint32_t ModuleBuilder::createBinaryOp(spv::Op op, uint32_t resultType,
  190. uint32_t lhs, uint32_t rhs) {
  191. assert(insertPoint && "null insert point");
  192. const uint32_t id = theContext.takeNextId();
  193. instBuilder.binaryOp(op, resultType, id, lhs, rhs).x();
  194. insertPoint->appendInstruction(std::move(constructSite));
  195. switch (op) {
  196. case spv::Op::OpImageQueryLod:
  197. case spv::Op::OpImageQuerySizeLod:
  198. requireCapability(spv::Capability::ImageQuery);
  199. break;
  200. }
  201. return id;
  202. }
  203. uint32_t ModuleBuilder::createSpecConstantBinaryOp(spv::Op op,
  204. uint32_t resultType,
  205. uint32_t lhs, uint32_t rhs) {
  206. const uint32_t id = theContext.takeNextId();
  207. instBuilder.specConstantBinaryOp(op, resultType, id, lhs, rhs).x();
  208. theModule.addVariable(std::move(constructSite));
  209. return id;
  210. }
  211. uint32_t ModuleBuilder::createGroupNonUniformOp(spv::Op op, uint32_t resultType,
  212. uint32_t execScope) {
  213. assert(insertPoint && "null insert point");
  214. const uint32_t id = theContext.takeNextId();
  215. instBuilder.groupNonUniformOp(op, resultType, id, execScope).x();
  216. insertPoint->appendInstruction(std::move(constructSite));
  217. return id;
  218. }
  219. uint32_t ModuleBuilder::createGroupNonUniformUnaryOp(
  220. spv::Op op, uint32_t resultType, uint32_t execScope, uint32_t operand,
  221. llvm::Optional<spv::GroupOperation> groupOp) {
  222. assert(insertPoint && "null insert point");
  223. const uint32_t id = theContext.takeNextId();
  224. instBuilder
  225. .groupNonUniformUnaryOp(op, resultType, id, execScope, groupOp, operand)
  226. .x();
  227. insertPoint->appendInstruction(std::move(constructSite));
  228. return id;
  229. }
  230. uint32_t ModuleBuilder::createGroupNonUniformBinaryOp(spv::Op op,
  231. uint32_t resultType,
  232. uint32_t execScope,
  233. uint32_t operand1,
  234. uint32_t operand2) {
  235. assert(insertPoint && "null insert point");
  236. const uint32_t id = theContext.takeNextId();
  237. instBuilder
  238. .groupNonUniformBinaryOp(op, resultType, id, execScope, operand1,
  239. operand2)
  240. .x();
  241. insertPoint->appendInstruction(std::move(constructSite));
  242. return id;
  243. }
  244. uint32_t ModuleBuilder::createAtomicOp(spv::Op opcode, uint32_t resultType,
  245. uint32_t orignalValuePtr,
  246. uint32_t scopeId,
  247. uint32_t memorySemanticsId,
  248. uint32_t valueToOp) {
  249. assert(insertPoint && "null insert point");
  250. const uint32_t id = theContext.takeNextId();
  251. switch (opcode) {
  252. case spv::Op::OpAtomicIAdd:
  253. instBuilder.opAtomicIAdd(resultType, id, orignalValuePtr, scopeId,
  254. memorySemanticsId, valueToOp);
  255. break;
  256. case spv::Op::OpAtomicISub:
  257. instBuilder.opAtomicISub(resultType, id, orignalValuePtr, scopeId,
  258. memorySemanticsId, valueToOp);
  259. break;
  260. case spv::Op::OpAtomicAnd:
  261. instBuilder.opAtomicAnd(resultType, id, orignalValuePtr, scopeId,
  262. memorySemanticsId, valueToOp);
  263. break;
  264. case spv::Op::OpAtomicOr:
  265. instBuilder.opAtomicOr(resultType, id, orignalValuePtr, scopeId,
  266. memorySemanticsId, valueToOp);
  267. break;
  268. case spv::Op::OpAtomicXor:
  269. instBuilder.opAtomicXor(resultType, id, orignalValuePtr, scopeId,
  270. memorySemanticsId, valueToOp);
  271. break;
  272. case spv::Op::OpAtomicUMax:
  273. instBuilder.opAtomicUMax(resultType, id, orignalValuePtr, scopeId,
  274. memorySemanticsId, valueToOp);
  275. break;
  276. case spv::Op::OpAtomicUMin:
  277. instBuilder.opAtomicUMin(resultType, id, orignalValuePtr, scopeId,
  278. memorySemanticsId, valueToOp);
  279. break;
  280. case spv::Op::OpAtomicSMax:
  281. instBuilder.opAtomicSMax(resultType, id, orignalValuePtr, scopeId,
  282. memorySemanticsId, valueToOp);
  283. break;
  284. case spv::Op::OpAtomicSMin:
  285. instBuilder.opAtomicSMin(resultType, id, orignalValuePtr, scopeId,
  286. memorySemanticsId, valueToOp);
  287. break;
  288. case spv::Op::OpAtomicExchange:
  289. instBuilder.opAtomicExchange(resultType, id, orignalValuePtr, scopeId,
  290. memorySemanticsId, valueToOp);
  291. break;
  292. default:
  293. assert(false && "unimplemented atomic opcode");
  294. }
  295. instBuilder.x();
  296. insertPoint->appendInstruction(std::move(constructSite));
  297. return id;
  298. }
  299. uint32_t ModuleBuilder::createAtomicCompareExchange(
  300. uint32_t resultType, uint32_t orignalValuePtr, uint32_t scopeId,
  301. uint32_t equalMemorySemanticsId, uint32_t unequalMemorySemanticsId,
  302. uint32_t valueToOp, uint32_t comparator) {
  303. assert(insertPoint && "null insert point");
  304. const uint32_t id = theContext.takeNextId();
  305. instBuilder.opAtomicCompareExchange(
  306. resultType, id, orignalValuePtr, scopeId, equalMemorySemanticsId,
  307. unequalMemorySemanticsId, valueToOp, comparator);
  308. instBuilder.x();
  309. insertPoint->appendInstruction(std::move(constructSite));
  310. return id;
  311. }
  312. spv::ImageOperandsMask ModuleBuilder::composeImageOperandsMask(
  313. uint32_t bias, uint32_t lod, const std::pair<uint32_t, uint32_t> &grad,
  314. uint32_t constOffset, uint32_t varOffset, uint32_t constOffsets,
  315. uint32_t sample, uint32_t minLod,
  316. llvm::SmallVectorImpl<uint32_t> *orderedParams) {
  317. using spv::ImageOperandsMask;
  318. // SPIR-V Image Operands from least significant bit to most significant bit
  319. // Bias, Lod, Grad, ConstOffset, Offset, ConstOffsets, Sample, MinLod
  320. auto mask = ImageOperandsMask::MaskNone;
  321. orderedParams->clear();
  322. if (bias) {
  323. mask = mask | ImageOperandsMask::Bias;
  324. orderedParams->push_back(bias);
  325. }
  326. if (lod) {
  327. mask = mask | ImageOperandsMask::Lod;
  328. orderedParams->push_back(lod);
  329. }
  330. if (grad.first && grad.second) {
  331. mask = mask | ImageOperandsMask::Grad;
  332. orderedParams->push_back(grad.first);
  333. orderedParams->push_back(grad.second);
  334. }
  335. if (constOffset) {
  336. mask = mask | ImageOperandsMask::ConstOffset;
  337. orderedParams->push_back(constOffset);
  338. }
  339. if (varOffset) {
  340. mask = mask | ImageOperandsMask::Offset;
  341. requireCapability(spv::Capability::ImageGatherExtended);
  342. orderedParams->push_back(varOffset);
  343. }
  344. if (constOffsets) {
  345. mask = mask | ImageOperandsMask::ConstOffsets;
  346. orderedParams->push_back(constOffsets);
  347. }
  348. if (sample) {
  349. mask = mask | ImageOperandsMask::Sample;
  350. orderedParams->push_back(sample);
  351. }
  352. if (minLod) {
  353. requireCapability(spv::Capability::MinLod);
  354. mask = mask | ImageOperandsMask::MinLod;
  355. orderedParams->push_back(minLod);
  356. }
  357. return mask;
  358. }
  359. uint32_t
  360. ModuleBuilder::createImageSparseTexelsResident(uint32_t resident_code) {
  361. assert(insertPoint && "null insert point");
  362. // Result type must be a boolean
  363. const uint32_t result_type = getBoolType();
  364. const uint32_t id = theContext.takeNextId();
  365. instBuilder.opImageSparseTexelsResident(result_type, id, resident_code).x();
  366. insertPoint->appendInstruction(std::move(constructSite));
  367. return id;
  368. }
  369. uint32_t ModuleBuilder::createImageTexelPointer(uint32_t resultType,
  370. uint32_t imageId,
  371. uint32_t coordinate,
  372. uint32_t sample) {
  373. assert(insertPoint && "null insert point");
  374. const uint32_t id = theContext.takeNextId();
  375. instBuilder.opImageTexelPointer(resultType, id, imageId, coordinate, sample)
  376. .x();
  377. insertPoint->appendInstruction(std::move(constructSite));
  378. return id;
  379. }
  380. uint32_t ModuleBuilder::createImageSample(
  381. uint32_t texelType, uint32_t imageType, uint32_t image, uint32_t sampler,
  382. uint32_t coordinate, uint32_t compareVal, uint32_t bias, uint32_t lod,
  383. std::pair<uint32_t, uint32_t> grad, uint32_t constOffset,
  384. uint32_t varOffset, uint32_t constOffsets, uint32_t sample, uint32_t minLod,
  385. uint32_t residencyCodeId) {
  386. assert(insertPoint && "null insert point");
  387. // The Lod and Grad image operands requires explicit-lod instructions.
  388. // Otherwise we use implicit-lod instructions.
  389. const bool isExplicit = lod || (grad.first && grad.second);
  390. const bool isSparse = (residencyCodeId != 0);
  391. // minLod is only valid with Implicit instructions and Grad instructions.
  392. // This means that we cannot have Lod and minLod together because Lod requires
  393. // explicit insturctions. So either lod or minLod or both must be zero.
  394. assert(lod == 0 || minLod == 0);
  395. uint32_t retType = texelType;
  396. if (isSparse) {
  397. requireCapability(spv::Capability::SparseResidency);
  398. retType = getSparseResidencyStructType(texelType);
  399. }
  400. // An OpSampledImage is required to do the image sampling.
  401. const uint32_t sampledImgId = theContext.takeNextId();
  402. const uint32_t sampledImgTy = getSampledImageType(imageType);
  403. instBuilder.opSampledImage(sampledImgTy, sampledImgId, image, sampler).x();
  404. insertPoint->appendInstruction(std::move(constructSite));
  405. uint32_t texelId = theContext.takeNextId();
  406. llvm::SmallVector<uint32_t, 4> params;
  407. const auto mask =
  408. composeImageOperandsMask(bias, lod, grad, constOffset, varOffset,
  409. constOffsets, sample, minLod, &params);
  410. instBuilder.opImageSample(retType, texelId, sampledImgId, coordinate,
  411. compareVal, mask, isExplicit, isSparse);
  412. for (const auto param : params)
  413. instBuilder.idRef(param);
  414. instBuilder.x();
  415. insertPoint->appendInstruction(std::move(constructSite));
  416. if (isSparse) {
  417. // Write the Residency Code
  418. const auto status = createCompositeExtract(getUint32Type(), texelId, {0});
  419. createStore(residencyCodeId, status);
  420. // Extract the real result from the struct
  421. texelId = createCompositeExtract(texelType, texelId, {1});
  422. }
  423. return texelId;
  424. }
  425. void ModuleBuilder::createImageWrite(QualType imageType, uint32_t imageId,
  426. uint32_t coordId, uint32_t texelId) {
  427. assert(insertPoint && "null insert point");
  428. requireCapability(
  429. TypeTranslator::getCapabilityForStorageImageReadWrite(imageType));
  430. instBuilder.opImageWrite(imageId, coordId, texelId, llvm::None).x();
  431. insertPoint->appendInstruction(std::move(constructSite));
  432. }
  433. uint32_t ModuleBuilder::createImageFetchOrRead(
  434. bool doImageFetch, uint32_t texelType, QualType imageType, uint32_t image,
  435. uint32_t coordinate, uint32_t lod, uint32_t constOffset, uint32_t varOffset,
  436. uint32_t constOffsets, uint32_t sample, uint32_t residencyCodeId) {
  437. assert(insertPoint && "null insert point");
  438. llvm::SmallVector<uint32_t, 2> params;
  439. const auto mask =
  440. llvm::Optional<spv::ImageOperandsMask>(composeImageOperandsMask(
  441. /*bias*/ 0, lod, std::make_pair(0, 0), constOffset, varOffset,
  442. constOffsets, sample, /*minLod*/ 0, &params));
  443. const bool isSparse = (residencyCodeId != 0);
  444. uint32_t retType = texelType;
  445. if (isSparse) {
  446. requireCapability(spv::Capability::SparseResidency);
  447. retType = getSparseResidencyStructType(texelType);
  448. }
  449. if (!doImageFetch) {
  450. requireCapability(
  451. TypeTranslator::getCapabilityForStorageImageReadWrite(imageType));
  452. }
  453. uint32_t texelId = theContext.takeNextId();
  454. instBuilder.opImageFetchRead(retType, texelId, image, coordinate, mask,
  455. doImageFetch, isSparse);
  456. for (const auto param : params)
  457. instBuilder.idRef(param);
  458. instBuilder.x();
  459. insertPoint->appendInstruction(std::move(constructSite));
  460. if (isSparse) {
  461. // Write the Residency Code
  462. const auto status = createCompositeExtract(getUint32Type(), texelId, {0});
  463. createStore(residencyCodeId, status);
  464. // Extract the real result from the struct
  465. texelId = createCompositeExtract(texelType, texelId, {1});
  466. }
  467. return texelId;
  468. }
  469. uint32_t ModuleBuilder::createImageGather(
  470. uint32_t texelType, uint32_t imageType, uint32_t image, uint32_t sampler,
  471. uint32_t coordinate, uint32_t component, uint32_t compareVal,
  472. uint32_t constOffset, uint32_t varOffset, uint32_t constOffsets,
  473. uint32_t sample, uint32_t residencyCodeId) {
  474. assert(insertPoint && "null insert point");
  475. uint32_t sparseRetType = 0;
  476. if (residencyCodeId) {
  477. requireCapability(spv::Capability::SparseResidency);
  478. sparseRetType = getSparseResidencyStructType(texelType);
  479. }
  480. // An OpSampledImage is required to do the image sampling.
  481. const uint32_t sampledImgId = theContext.takeNextId();
  482. const uint32_t sampledImgTy = getSampledImageType(imageType);
  483. instBuilder.opSampledImage(sampledImgTy, sampledImgId, image, sampler).x();
  484. insertPoint->appendInstruction(std::move(constructSite));
  485. llvm::SmallVector<uint32_t, 2> params;
  486. // TODO: Update ImageGather to accept minLod if necessary.
  487. const auto mask =
  488. llvm::Optional<spv::ImageOperandsMask>(composeImageOperandsMask(
  489. /*bias*/ 0, /*lod*/ 0, std::make_pair(0, 0), constOffset, varOffset,
  490. constOffsets, sample, /*minLod*/ 0, &params));
  491. uint32_t texelId = theContext.takeNextId();
  492. if (compareVal) {
  493. if (residencyCodeId) {
  494. // Note: OpImageSparseDrefGather does not take the component parameter.
  495. instBuilder.opImageSparseDrefGather(sparseRetType, texelId, sampledImgId,
  496. coordinate, compareVal, mask);
  497. } else {
  498. // Note: OpImageDrefGather does not take the component parameter.
  499. instBuilder.opImageDrefGather(texelType, texelId, sampledImgId,
  500. coordinate, compareVal, mask);
  501. }
  502. } else {
  503. if (residencyCodeId) {
  504. instBuilder.opImageSparseGather(sparseRetType, texelId, sampledImgId,
  505. coordinate, component, mask);
  506. } else {
  507. instBuilder.opImageGather(texelType, texelId, sampledImgId, coordinate,
  508. component, mask);
  509. }
  510. }
  511. for (const auto param : params)
  512. instBuilder.idRef(param);
  513. instBuilder.x();
  514. insertPoint->appendInstruction(std::move(constructSite));
  515. if (residencyCodeId) {
  516. // Write the Residency Code
  517. const auto status = createCompositeExtract(getUint32Type(), texelId, {0});
  518. createStore(residencyCodeId, status);
  519. // Extract the real result from the struct
  520. texelId = createCompositeExtract(texelType, texelId, {1});
  521. }
  522. return texelId;
  523. }
  524. uint32_t ModuleBuilder::createSelect(uint32_t resultType, uint32_t condition,
  525. uint32_t trueValue, uint32_t falseValue) {
  526. assert(insertPoint && "null insert point");
  527. const uint32_t id = theContext.takeNextId();
  528. instBuilder.opSelect(resultType, id, condition, trueValue, falseValue).x();
  529. insertPoint->appendInstruction(std::move(constructSite));
  530. return id;
  531. }
  532. void ModuleBuilder::createSwitch(
  533. uint32_t mergeLabel, uint32_t selector, uint32_t defaultLabel,
  534. llvm::ArrayRef<std::pair<uint32_t, uint32_t>> target) {
  535. assert(insertPoint && "null insert point");
  536. // Create the OpSelectioMerege.
  537. instBuilder.opSelectionMerge(mergeLabel, spv::SelectionControlMask::MaskNone)
  538. .x();
  539. insertPoint->appendInstruction(std::move(constructSite));
  540. // Create the OpSwitch.
  541. instBuilder.opSwitch(selector, defaultLabel, target).x();
  542. insertPoint->appendInstruction(std::move(constructSite));
  543. }
  544. void ModuleBuilder::createKill() {
  545. assert(insertPoint && "null insert point");
  546. assert(!isCurrentBasicBlockTerminated());
  547. instBuilder.opKill().x();
  548. insertPoint->appendInstruction(std::move(constructSite));
  549. }
  550. void ModuleBuilder::createBranch(uint32_t targetLabel, uint32_t mergeBB,
  551. uint32_t continueBB,
  552. spv::LoopControlMask loopControl) {
  553. assert(insertPoint && "null insert point");
  554. if (mergeBB && continueBB) {
  555. instBuilder.opLoopMerge(mergeBB, continueBB, loopControl).x();
  556. insertPoint->appendInstruction(std::move(constructSite));
  557. }
  558. instBuilder.opBranch(targetLabel).x();
  559. insertPoint->appendInstruction(std::move(constructSite));
  560. }
  561. void ModuleBuilder::createConditionalBranch(
  562. uint32_t condition, uint32_t trueLabel, uint32_t falseLabel,
  563. uint32_t mergeLabel, uint32_t continueLabel,
  564. spv::SelectionControlMask selectionControl,
  565. spv::LoopControlMask loopControl) {
  566. assert(insertPoint && "null insert point");
  567. if (mergeLabel) {
  568. if (continueLabel) {
  569. instBuilder.opLoopMerge(mergeLabel, continueLabel, loopControl).x();
  570. insertPoint->appendInstruction(std::move(constructSite));
  571. } else {
  572. instBuilder.opSelectionMerge(mergeLabel, selectionControl).x();
  573. insertPoint->appendInstruction(std::move(constructSite));
  574. }
  575. }
  576. instBuilder.opBranchConditional(condition, trueLabel, falseLabel, {}).x();
  577. insertPoint->appendInstruction(std::move(constructSite));
  578. }
  579. void ModuleBuilder::createReturn() {
  580. assert(insertPoint && "null insert point");
  581. instBuilder.opReturn().x();
  582. insertPoint->appendInstruction(std::move(constructSite));
  583. }
  584. void ModuleBuilder::createReturnValue(uint32_t value) {
  585. assert(insertPoint && "null insert point");
  586. instBuilder.opReturnValue(value).x();
  587. insertPoint->appendInstruction(std::move(constructSite));
  588. }
  589. uint32_t ModuleBuilder::createExtInst(uint32_t resultType, uint32_t setId,
  590. uint32_t instId,
  591. llvm::ArrayRef<uint32_t> operands) {
  592. assert(insertPoint && "null insert point");
  593. uint32_t resultId = theContext.takeNextId();
  594. instBuilder.opExtInst(resultType, resultId, setId, instId, operands).x();
  595. insertPoint->appendInstruction(std::move(constructSite));
  596. return resultId;
  597. }
  598. void ModuleBuilder::createBarrier(uint32_t execution, uint32_t memory,
  599. uint32_t semantics) {
  600. assert(insertPoint && "null insert point");
  601. if (execution)
  602. instBuilder.opControlBarrier(execution, memory, semantics).x();
  603. else
  604. instBuilder.opMemoryBarrier(memory, semantics).x();
  605. insertPoint->appendInstruction(std::move(constructSite));
  606. }
  607. uint32_t ModuleBuilder::createBitFieldExtract(uint32_t resultType,
  608. uint32_t base, uint32_t offset,
  609. uint32_t count, bool isSigned) {
  610. assert(insertPoint && "null insert point");
  611. uint32_t resultId = theContext.takeNextId();
  612. if (isSigned)
  613. instBuilder.opBitFieldSExtract(resultType, resultId, base, offset, count);
  614. else
  615. instBuilder.opBitFieldUExtract(resultType, resultId, base, offset, count);
  616. instBuilder.x();
  617. insertPoint->appendInstruction(std::move(constructSite));
  618. return resultId;
  619. }
  620. uint32_t ModuleBuilder::createBitFieldInsert(uint32_t resultType, uint32_t base,
  621. uint32_t insert, uint32_t offset,
  622. uint32_t count) {
  623. assert(insertPoint && "null insert point");
  624. uint32_t resultId = theContext.takeNextId();
  625. instBuilder
  626. .opBitFieldInsert(resultType, resultId, base, insert, offset, count)
  627. .x();
  628. insertPoint->appendInstruction(std::move(constructSite));
  629. return resultId;
  630. }
  631. void ModuleBuilder::createEmitVertex() {
  632. assert(insertPoint && "null insert point");
  633. instBuilder.opEmitVertex().x();
  634. insertPoint->appendInstruction(std::move(constructSite));
  635. }
  636. void ModuleBuilder::createEndPrimitive() {
  637. assert(insertPoint && "null insert point");
  638. instBuilder.opEndPrimitive().x();
  639. insertPoint->appendInstruction(std::move(constructSite));
  640. }
  641. void ModuleBuilder::addExecutionMode(uint32_t entryPointId,
  642. spv::ExecutionMode em,
  643. llvm::ArrayRef<uint32_t> params) {
  644. instBuilder.opExecutionMode(entryPointId, em);
  645. for (const auto &param : params) {
  646. instBuilder.literalInteger(param);
  647. }
  648. instBuilder.x();
  649. theModule.addExecutionMode(std::move(constructSite));
  650. }
  651. uint32_t ModuleBuilder::getGLSLExtInstSet() {
  652. if (glslExtSetId == 0) {
  653. glslExtSetId = theContext.takeNextId();
  654. theModule.addExtInstSet(glslExtSetId, "GLSL.std.450");
  655. }
  656. return glslExtSetId;
  657. }
  658. uint32_t ModuleBuilder::addStageIOVar(uint32_t type,
  659. spv::StorageClass storageClass,
  660. std::string name) {
  661. const uint32_t pointerType = getPointerType(type, storageClass);
  662. const uint32_t varId = theContext.takeNextId();
  663. instBuilder.opVariable(pointerType, varId, storageClass, llvm::None).x();
  664. theModule.addVariable(std::move(constructSite));
  665. theModule.addDebugName(varId, name);
  666. return varId;
  667. }
  668. uint32_t ModuleBuilder::addStageBuiltinVar(uint32_t type, spv::StorageClass sc,
  669. spv::BuiltIn builtin) {
  670. const uint32_t pointerType = getPointerType(type, sc);
  671. const uint32_t varId = theContext.takeNextId();
  672. instBuilder.opVariable(pointerType, varId, sc, llvm::None).x();
  673. theModule.addVariable(std::move(constructSite));
  674. // Decorate with the specified Builtin
  675. const Decoration *d = Decoration::getBuiltIn(theContext, builtin);
  676. theModule.addDecoration(d, varId);
  677. return varId;
  678. }
  679. uint32_t ModuleBuilder::addModuleVar(uint32_t type, spv::StorageClass sc,
  680. llvm::StringRef name,
  681. llvm::Optional<uint32_t> init) {
  682. assert(sc != spv::StorageClass::Function);
  683. // TODO: basically duplicated code of addFileVar()
  684. const uint32_t pointerType = getPointerType(type, sc);
  685. const uint32_t varId = theContext.takeNextId();
  686. instBuilder.opVariable(pointerType, varId, sc, init).x();
  687. theModule.addVariable(std::move(constructSite));
  688. theModule.addDebugName(varId, name);
  689. return varId;
  690. }
  691. void ModuleBuilder::decorateDSetBinding(uint32_t targetId, uint32_t setNumber,
  692. uint32_t bindingNumber) {
  693. const auto *d = Decoration::getDescriptorSet(theContext, setNumber);
  694. theModule.addDecoration(d, targetId);
  695. d = Decoration::getBinding(theContext, bindingNumber);
  696. theModule.addDecoration(d, targetId);
  697. }
  698. void ModuleBuilder::decorateInputAttachmentIndex(uint32_t targetId,
  699. uint32_t indexNumber) {
  700. const auto *d = Decoration::getInputAttachmentIndex(theContext, indexNumber);
  701. theModule.addDecoration(d, targetId);
  702. }
  703. void ModuleBuilder::decorateLocation(uint32_t targetId, uint32_t location) {
  704. const Decoration *d =
  705. Decoration::getLocation(theContext, location, llvm::None);
  706. theModule.addDecoration(d, targetId);
  707. }
  708. void ModuleBuilder::decorateSpecId(uint32_t targetId, uint32_t specId) {
  709. const Decoration *d = Decoration::getSpecId(theContext, specId);
  710. theModule.addDecoration(d, targetId);
  711. }
  712. void ModuleBuilder::decorate(uint32_t targetId, spv::Decoration decoration) {
  713. const Decoration *d = nullptr;
  714. switch (decoration) {
  715. case spv::Decoration::Centroid:
  716. d = Decoration::getCentroid(theContext);
  717. break;
  718. case spv::Decoration::Flat:
  719. d = Decoration::getFlat(theContext);
  720. break;
  721. case spv::Decoration::NoPerspective:
  722. d = Decoration::getNoPerspective(theContext);
  723. break;
  724. case spv::Decoration::Sample:
  725. d = Decoration::getSample(theContext);
  726. break;
  727. case spv::Decoration::Block:
  728. d = Decoration::getBlock(theContext);
  729. break;
  730. case spv::Decoration::RelaxedPrecision:
  731. d = Decoration::getRelaxedPrecision(theContext);
  732. break;
  733. case spv::Decoration::Patch:
  734. d = Decoration::getPatch(theContext);
  735. break;
  736. }
  737. assert(d && "unimplemented decoration");
  738. theModule.addDecoration(d, targetId);
  739. }
  740. #define IMPL_GET_PRIMITIVE_TYPE(ty) \
  741. \
  742. uint32_t ModuleBuilder::get##ty##Type() { \
  743. const Type *type = Type::get##ty(theContext); \
  744. const uint32_t typeId = theContext.getResultIdForType(type); \
  745. theModule.addType(type, typeId); \
  746. return typeId; \
  747. }
  748. IMPL_GET_PRIMITIVE_TYPE(Void)
  749. IMPL_GET_PRIMITIVE_TYPE(Bool)
  750. IMPL_GET_PRIMITIVE_TYPE(Int32)
  751. IMPL_GET_PRIMITIVE_TYPE(Uint32)
  752. IMPL_GET_PRIMITIVE_TYPE(Float32)
  753. #undef IMPL_GET_PRIMITIVE_TYPE
  // Note: At the moment, the Float16 capability should not be added for Vulkan
  // 1.0. It is not a required capability, and adding the
  // SPV_AMD_gpu_shader_half_float extension does not enable this capability.
  // Any driver that supports float16 in Vulkan 1.0 should accept this extension.
  758. #define IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(ty, cap) \
  759. \
  760. uint32_t ModuleBuilder::get##ty##Type() { \
  761. if (spv::Capability::cap == spv::Capability::Float16) \
  762. theModule.addExtension("SPV_AMD_gpu_shader_half_float"); \
  763. else \
  764. requireCapability(spv::Capability::cap); \
  765. const Type *type = Type::get##ty(theContext); \
  766. const uint32_t typeId = theContext.getResultIdForType(type); \
  767. theModule.addType(type, typeId); \
  768. return typeId; \
  769. }
  770. IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Int64, Int64)
  771. IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Uint64, Int64)
  772. IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Float64, Float64)
  773. IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Int16, Int16)
  774. IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Uint16, Int16)
  775. IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Float16, Float16)
  776. #undef IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY
  777. uint32_t ModuleBuilder::getVecType(uint32_t elemType, uint32_t elemCount) {
  778. const Type *type = nullptr;
  779. switch (elemCount) {
  780. case 2:
  781. type = Type::getVec2(theContext, elemType);
  782. break;
  783. case 3:
  784. type = Type::getVec3(theContext, elemType);
  785. break;
  786. case 4:
  787. type = Type::getVec4(theContext, elemType);
  788. break;
  789. default:
  790. assert(false && "unhandled vector size");
  791. // Error found. Return 0 as the <result-id> directly.
  792. return 0;
  793. }
  794. const uint32_t typeId = theContext.getResultIdForType(type);
  795. theModule.addType(type, typeId);
  796. return typeId;
  797. }
  798. uint32_t ModuleBuilder::getMatType(QualType elemType, uint32_t colType,
  799. uint32_t colCount,
  800. Type::DecorationSet decorations) {
  801. // NOTE: According to Item "Data rules" of SPIR-V Spec 2.16.1 "Universal
  802. // Validation Rules":
  803. // Matrix types can only be parameterized with floating-point types.
  804. //
  805. // So we need special handling of non-fp matrices. We emulate non-fp
  806. // matrices as an array of vectors.
  807. if (!elemType->isFloatingType())
  808. return getArrayType(colType, getConstantUint32(colCount), decorations);
  809. const Type *type = Type::getMatrix(theContext, colType, colCount);
  810. const uint32_t typeId = theContext.getResultIdForType(type);
  811. theModule.addType(type, typeId);
  812. return typeId;
  813. }
  814. uint32_t ModuleBuilder::getPointerType(uint32_t pointeeType,
  815. spv::StorageClass storageClass) {
  816. const Type *type = Type::getPointer(theContext, storageClass, pointeeType);
  817. const uint32_t typeId = theContext.getResultIdForType(type);
  818. theModule.addType(type, typeId);
  819. return typeId;
  820. }
  821. uint32_t
  822. ModuleBuilder::getStructType(llvm::ArrayRef<uint32_t> fieldTypes,
  823. llvm::StringRef structName,
  824. llvm::ArrayRef<llvm::StringRef> fieldNames,
  825. Type::DecorationSet decorations) {
  826. const Type *type = Type::getStruct(theContext, fieldTypes, decorations);
  827. bool isRegistered = false;
  828. const uint32_t typeId = theContext.getResultIdForType(type, &isRegistered);
  829. theModule.addType(type, typeId);
  830. if (!isRegistered) {
  831. theModule.addDebugName(typeId, structName);
  832. if (!fieldNames.empty()) {
  833. assert(fieldNames.size() == fieldTypes.size());
  834. for (uint32_t i = 0; i < fieldNames.size(); ++i)
  835. theModule.addDebugName(typeId, fieldNames[i],
  836. llvm::Optional<uint32_t>(i));
  837. }
  838. }
  839. return typeId;
  840. }
  841. uint32_t ModuleBuilder::getSparseResidencyStructType(uint32_t type) {
  842. const auto uintType = getUint32Type();
  843. return getStructType({uintType, type}, "SparseResidencyStruct",
  844. {"Residency.Code", "Result.Type"});
  845. }
  846. uint32_t ModuleBuilder::getArrayType(uint32_t elemType, uint32_t count,
  847. Type::DecorationSet decorations) {
  848. const Type *type = Type::getArray(theContext, elemType, count, decorations);
  849. const uint32_t typeId = theContext.getResultIdForType(type);
  850. theModule.addType(type, typeId);
  851. return typeId;
  852. }
  853. uint32_t ModuleBuilder::getRuntimeArrayType(uint32_t elemType,
  854. Type::DecorationSet decorations) {
  855. const Type *type = Type::getRuntimeArray(theContext, elemType, decorations);
  856. const uint32_t typeId = theContext.getResultIdForType(type);
  857. theModule.addType(type, typeId);
  858. return typeId;
  859. }
  860. uint32_t ModuleBuilder::getFunctionType(uint32_t returnType,
  861. llvm::ArrayRef<uint32_t> paramTypes) {
  862. const Type *type = Type::getFunction(theContext, returnType, paramTypes);
  863. const uint32_t typeId = theContext.getResultIdForType(type);
  864. theModule.addType(type, typeId);
  865. return typeId;
  866. }
  867. uint32_t ModuleBuilder::getImageType(uint32_t sampledType, spv::Dim dim,
  868. uint32_t depth, bool isArray, uint32_t ms,
  869. uint32_t sampled,
  870. spv::ImageFormat format) {
  871. const Type *type = Type::getImage(theContext, sampledType, dim, depth,
  872. isArray, ms, sampled, format);
  873. bool isRegistered = false;
  874. const uint32_t typeId = theContext.getResultIdForType(type, &isRegistered);
  875. theModule.addType(type, typeId);
  876. switch (format) {
  877. case spv::ImageFormat::Rg32f:
  878. case spv::ImageFormat::Rg16f:
  879. case spv::ImageFormat::R11fG11fB10f:
  880. case spv::ImageFormat::R16f:
  881. case spv::ImageFormat::Rgba16:
  882. case spv::ImageFormat::Rgb10A2:
  883. case spv::ImageFormat::Rg16:
  884. case spv::ImageFormat::Rg8:
  885. case spv::ImageFormat::R16:
  886. case spv::ImageFormat::R8:
  887. case spv::ImageFormat::Rgba16Snorm:
  888. case spv::ImageFormat::Rg16Snorm:
  889. case spv::ImageFormat::Rg8Snorm:
  890. case spv::ImageFormat::R16Snorm:
  891. case spv::ImageFormat::R8Snorm:
  892. case spv::ImageFormat::Rg32i:
  893. case spv::ImageFormat::Rg16i:
  894. case spv::ImageFormat::Rg8i:
  895. case spv::ImageFormat::R16i:
  896. case spv::ImageFormat::R8i:
  897. case spv::ImageFormat::Rgb10a2ui:
  898. case spv::ImageFormat::Rg32ui:
  899. case spv::ImageFormat::Rg16ui:
  900. case spv::ImageFormat::Rg8ui:
  901. case spv::ImageFormat::R16ui:
  902. case spv::ImageFormat::R8ui:
  903. requireCapability(spv::Capability::StorageImageExtendedFormats);
  904. }
  905. if (dim == spv::Dim::Dim1D) {
  906. if (sampled == 2u) {
  907. requireCapability(spv::Capability::Image1D);
  908. } else {
  909. requireCapability(spv::Capability::Sampled1D);
  910. }
  911. } else if (dim == spv::Dim::Buffer) {
  912. requireCapability(spv::Capability::SampledBuffer);
  913. } else if (dim == spv::Dim::SubpassData) {
  914. requireCapability(spv::Capability::InputAttachment);
  915. }
  916. if (isArray && ms) {
  917. requireCapability(spv::Capability::ImageMSArray);
  918. }
  919. // Skip constructing the debug name if we have already done it before.
  920. if (!isRegistered) {
  921. const char *dimStr = "";
  922. switch (dim) {
  923. case spv::Dim::Dim1D:
  924. dimStr = "1d.";
  925. break;
  926. case spv::Dim::Dim2D:
  927. dimStr = "2d.";
  928. break;
  929. case spv::Dim::Dim3D:
  930. dimStr = "3d.";
  931. break;
  932. case spv::Dim::Cube:
  933. dimStr = "cube.";
  934. break;
  935. case spv::Dim::Rect:
  936. dimStr = "rect.";
  937. break;
  938. case spv::Dim::Buffer:
  939. dimStr = "buffer.";
  940. break;
  941. case spv::Dim::SubpassData:
  942. dimStr = "subpass.";
  943. break;
  944. default:
  945. break;
  946. }
  947. std::string name =
  948. std::string("type.") + dimStr + "image" + (isArray ? ".array" : "");
  949. theModule.addDebugName(typeId, name);
  950. }
  951. return typeId;
  952. }
  953. uint32_t ModuleBuilder::getSamplerType() {
  954. const Type *type = Type::getSampler(theContext);
  955. const uint32_t typeId = theContext.getResultIdForType(type);
  956. theModule.addType(type, typeId);
  957. theModule.addDebugName(typeId, "type.sampler");
  958. return typeId;
  959. }
  960. uint32_t ModuleBuilder::getSampledImageType(uint32_t imageType) {
  961. const Type *type = Type::getSampledImage(theContext, imageType);
  962. const uint32_t typeId = theContext.getResultIdForType(type);
  963. theModule.addType(type, typeId);
  964. theModule.addDebugName(typeId, "type.sampled.image");
  965. return typeId;
  966. }
  967. uint32_t ModuleBuilder::getByteAddressBufferType(bool isRW) {
  968. // Create a uint RuntimeArray with Array Stride of 4.
  969. const uint32_t uintType = getUint32Type();
  970. const auto *arrStride4 = Decoration::getArrayStride(theContext, 4u);
  971. const Type *raType =
  972. Type::getRuntimeArray(theContext, uintType, {arrStride4});
  973. const uint32_t raTypeId = theContext.getResultIdForType(raType);
  974. theModule.addType(raType, raTypeId);
  975. // Create a struct containing the runtime array as its only member.
  976. // The struct must also be decorated as BufferBlock. The offset decoration
  977. // should also be applied to the first (only) member. NonWritable decoration
  978. // should also be applied to the first member if isRW is true.
  979. llvm::SmallVector<const Decoration *, 3> typeDecs;
  980. typeDecs.push_back(Decoration::getBufferBlock(theContext));
  981. typeDecs.push_back(Decoration::getOffset(theContext, 0, 0));
  982. if (!isRW)
  983. typeDecs.push_back(Decoration::getNonWritable(theContext, 0));
  984. const Type *type = Type::getStruct(theContext, {raTypeId}, typeDecs);
  985. const uint32_t typeId = theContext.getResultIdForType(type);
  986. theModule.addType(type, typeId);
  987. theModule.addDebugName(typeId, isRW ? "type.RWByteAddressBuffer"
  988. : "type.ByteAddressBuffer");
  989. return typeId;
  990. }
  991. uint32_t ModuleBuilder::getConstantBool(bool value, bool isSpecConst) {
  992. if (isSpecConst) {
  993. const uint32_t constId = theContext.takeNextId();
  994. if (value) {
  995. instBuilder.opSpecConstantTrue(getBoolType(), constId).x();
  996. } else {
  997. instBuilder.opSpecConstantFalse(getBoolType(), constId).x();
  998. }
  999. theModule.addVariable(std::move(constructSite));
  1000. return constId;
  1001. }
  1002. const uint32_t typeId = getBoolType();
  1003. const Constant *constant = value ? Constant::getTrue(theContext, typeId)
  1004. : Constant::getFalse(theContext, typeId);
  1005. const uint32_t constId = theContext.getResultIdForConstant(constant);
  1006. theModule.addConstant(constant, constId);
  1007. return constId;
  1008. }
  1009. #define IMPL_GET_PRIMITIVE_CONST(builderTy, cppTy) \
  1010. \
  1011. uint32_t ModuleBuilder::getConstant##builderTy(cppTy value) { \
  1012. const uint32_t typeId = get##builderTy##Type(); \
  1013. const Constant *constant = \
  1014. Constant::get##builderTy(theContext, typeId, value); \
  1015. const uint32_t constId = theContext.getResultIdForConstant(constant); \
  1016. theModule.addConstant(constant, constId); \
  1017. return constId; \
  1018. }
  1019. #define IMPL_GET_PRIMITIVE_CONST_SPEC_CONST(builderTy, cppTy) \
  1020. \
  1021. uint32_t ModuleBuilder::getConstant##builderTy(cppTy value, \
  1022. bool isSpecConst) { \
  1023. if (isSpecConst) { \
  1024. const uint32_t constId = theContext.takeNextId(); \
  1025. instBuilder \
  1026. .opSpecConstant(get##builderTy##Type(), constId, \
  1027. cast::BitwiseCast<uint32_t>(value)) \
  1028. .x(); \
  1029. theModule.addVariable(std::move(constructSite)); \
  1030. return constId; \
  1031. } \
  1032. \
  1033. const uint32_t typeId = get##builderTy##Type(); \
  1034. const Constant *constant = \
  1035. Constant::get##builderTy(theContext, typeId, value); \
  1036. const uint32_t constId = theContext.getResultIdForConstant(constant); \
  1037. theModule.addConstant(constant, constId); \
  1038. return constId; \
  1039. }
  1040. IMPL_GET_PRIMITIVE_CONST(Int16, int16_t)
  1041. IMPL_GET_PRIMITIVE_CONST_SPEC_CONST(Int32, int32_t)
  1042. IMPL_GET_PRIMITIVE_CONST(Uint16, uint16_t)
  1043. IMPL_GET_PRIMITIVE_CONST_SPEC_CONST(Uint32, uint32_t)
  1044. IMPL_GET_PRIMITIVE_CONST(Float16, int16_t)
  1045. IMPL_GET_PRIMITIVE_CONST_SPEC_CONST(Float32, float)
  1046. IMPL_GET_PRIMITIVE_CONST(Float64, double)
  1047. IMPL_GET_PRIMITIVE_CONST(Int64, int64_t)
  1048. IMPL_GET_PRIMITIVE_CONST(Uint64, uint64_t)
  1049. #undef IMPL_GET_PRIMITIVE_CONST
  1050. #undef IMPL_GET_PRIMITIVE_CONST_SPEC_CONST
  1051. uint32_t
  1052. ModuleBuilder::getConstantComposite(uint32_t typeId,
  1053. llvm::ArrayRef<uint32_t> constituents) {
  1054. const Constant *constant =
  1055. Constant::getComposite(theContext, typeId, constituents);
  1056. const uint32_t constId = theContext.getResultIdForConstant(constant);
  1057. theModule.addConstant(constant, constId);
  1058. return constId;
  1059. }
  1060. uint32_t ModuleBuilder::getConstantNull(uint32_t typeId) {
  1061. const Constant *constant = Constant::getNull(theContext, typeId);
  1062. const uint32_t constId = theContext.getResultIdForConstant(constant);
  1063. theModule.addConstant(constant, constId);
  1064. return constId;
  1065. }
  1066. BasicBlock *ModuleBuilder::getBasicBlock(uint32_t labelId) {
  1067. auto it = basicBlocks.find(labelId);
  1068. if (it == basicBlocks.end()) {
  1069. assert(false && "invalid <label-id>");
  1070. return nullptr;
  1071. }
  1072. return it->second.get();
  1073. }
  1074. } // end namespace spirv
  1075. } // end namespace clang