ModuleBuilder.cpp 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273
  1. //===--- ModuleBuilder.cpp - SPIR-V builder implementation ----*- C++ -*---===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. #include "clang/SPIRV/ModuleBuilder.h"
  10. #include "TypeTranslator.h"
  11. #include "spirv/unified1/spirv.hpp11"
  12. #include "clang/SPIRV/BitwiseCast.h"
  13. #include "clang/SPIRV/InstBuilder.h"
  14. namespace clang {
  15. namespace spirv {
  16. ModuleBuilder::ModuleBuilder(SPIRVContext *C, FeatureManager *features,
  17. const SpirvCodeGenOptions &opts)
  18. : theContext(*C), featureManager(features), spirvOptions(opts),
  19. theModule(opts), theFunction(nullptr), insertPoint(nullptr),
  20. instBuilder(nullptr), glslExtSetId(0) {
  21. instBuilder.setConsumer([this](std::vector<uint32_t> &&words) {
  22. this->constructSite = std::move(words);
  23. });
  24. }
  25. std::vector<uint32_t> ModuleBuilder::takeModule() {
  26. theModule.setBound(theContext.getNextId());
  27. std::vector<uint32_t> binary;
  28. auto ib = InstBuilder([&binary](std::vector<uint32_t> &&words) {
  29. binary.insert(binary.end(), words.begin(), words.end());
  30. });
  31. theModule.take(&ib);
  32. return binary;
  33. }
  34. uint32_t ModuleBuilder::beginFunction(uint32_t funcType, uint32_t returnType,
  35. llvm::StringRef funcName, uint32_t fId) {
  36. if (theFunction) {
  37. assert(false && "found nested function");
  38. return 0;
  39. }
  40. // If the caller doesn't supply a function <result-id>, we need to get one.
  41. if (!fId)
  42. fId = theContext.takeNextId();
  43. theFunction = llvm::make_unique<Function>(
  44. returnType, fId, spv::FunctionControlMask::MaskNone, funcType);
  45. theModule.addDebugName(fId, funcName);
  46. return fId;
  47. }
  48. uint32_t ModuleBuilder::addFnParam(uint32_t ptrType, llvm::StringRef name) {
  49. assert(theFunction && "found detached parameter");
  50. const uint32_t paramId = theContext.takeNextId();
  51. theFunction->addParameter(ptrType, paramId);
  52. theModule.addDebugName(paramId, name);
  53. return paramId;
  54. }
  55. uint32_t ModuleBuilder::addFnVar(uint32_t varType, llvm::StringRef name,
  56. llvm::Optional<uint32_t> init) {
  57. assert(theFunction && "found detached local variable");
  58. const uint32_t ptrType = getPointerType(varType, spv::StorageClass::Function);
  59. const uint32_t varId = theContext.takeNextId();
  60. theFunction->addVariable(ptrType, varId, init);
  61. theModule.addDebugName(varId, name);
  62. return varId;
  63. }
  64. bool ModuleBuilder::endFunction() {
  65. if (theFunction == nullptr) {
  66. assert(false && "no active function");
  67. return false;
  68. }
  69. // Move all basic blocks into the current function.
  70. // TODO: we should adjust the order the basic blocks according to
  71. // SPIR-V validation rules.
  72. for (auto &bb : basicBlocks) {
  73. theFunction->addBasicBlock(std::move(bb.second));
  74. }
  75. basicBlocks.clear();
  76. theModule.addFunction(std::move(theFunction));
  77. theFunction.reset(nullptr);
  78. insertPoint = nullptr;
  79. return true;
  80. }
  81. uint32_t ModuleBuilder::createBasicBlock(llvm::StringRef name) {
  82. if (theFunction == nullptr) {
  83. assert(false && "found detached basic block");
  84. return 0;
  85. }
  86. const uint32_t labelId = theContext.takeNextId();
  87. basicBlocks[labelId] = llvm::make_unique<BasicBlock>(labelId, name);
  88. return labelId;
  89. }
  90. void ModuleBuilder::addSuccessor(uint32_t successorLabel) {
  91. assert(insertPoint && "null insert point");
  92. insertPoint->addSuccessor(getBasicBlock(successorLabel));
  93. }
  94. void ModuleBuilder::setMergeTarget(uint32_t mergeLabel) {
  95. assert(insertPoint && "null insert point");
  96. insertPoint->setMergeTarget(getBasicBlock(mergeLabel));
  97. }
  98. void ModuleBuilder::setContinueTarget(uint32_t continueLabel) {
  99. assert(insertPoint && "null insert point");
  100. insertPoint->setContinueTarget(getBasicBlock(continueLabel));
  101. }
  102. void ModuleBuilder::setInsertPoint(uint32_t labelId) {
  103. insertPoint = getBasicBlock(labelId);
  104. }
  105. uint32_t
  106. ModuleBuilder::createCompositeConstruct(uint32_t resultType,
  107. llvm::ArrayRef<uint32_t> constituents) {
  108. assert(insertPoint && "null insert point");
  109. const uint32_t resultId = theContext.takeNextId();
  110. instBuilder.opCompositeConstruct(resultType, resultId, constituents).x();
  111. insertPoint->appendInstruction(std::move(constructSite));
  112. return resultId;
  113. }
  114. uint32_t
  115. ModuleBuilder::createCompositeExtract(uint32_t resultType, uint32_t composite,
  116. llvm::ArrayRef<uint32_t> indexes) {
  117. assert(insertPoint && "null insert point");
  118. const uint32_t resultId = theContext.takeNextId();
  119. instBuilder.opCompositeExtract(resultType, resultId, composite, indexes).x();
  120. insertPoint->appendInstruction(std::move(constructSite));
  121. return resultId;
  122. }
  123. uint32_t ModuleBuilder::createCompositeInsert(uint32_t resultType,
  124. uint32_t composite,
  125. llvm::ArrayRef<uint32_t> indices,
  126. uint32_t object) {
  127. assert(insertPoint && "null insert point");
  128. const uint32_t resultId = theContext.takeNextId();
  129. instBuilder
  130. .opCompositeInsert(resultType, resultId, object, composite, indices)
  131. .x();
  132. insertPoint->appendInstruction(std::move(constructSite));
  133. return resultId;
  134. }
  135. uint32_t
  136. ModuleBuilder::createVectorShuffle(uint32_t resultType, uint32_t vector1,
  137. uint32_t vector2,
  138. llvm::ArrayRef<uint32_t> selectors) {
  139. assert(insertPoint && "null insert point");
  140. const uint32_t resultId = theContext.takeNextId();
  141. instBuilder.opVectorShuffle(resultType, resultId, vector1, vector2, selectors)
  142. .x();
  143. insertPoint->appendInstruction(std::move(constructSite));
  144. return resultId;
  145. }
  146. uint32_t ModuleBuilder::createLoad(uint32_t resultType, uint32_t pointer) {
  147. assert(insertPoint && "null insert point");
  148. const uint32_t resultId = theContext.takeNextId();
  149. instBuilder.opLoad(resultType, resultId, pointer, llvm::None).x();
  150. insertPoint->appendInstruction(std::move(constructSite));
  151. return resultId;
  152. }
  153. void ModuleBuilder::createStore(uint32_t address, uint32_t value) {
  154. assert(insertPoint && "null insert point");
  155. instBuilder.opStore(address, value, llvm::None).x();
  156. insertPoint->appendInstruction(std::move(constructSite));
  157. }
  158. uint32_t ModuleBuilder::createFunctionCall(uint32_t returnType,
  159. uint32_t functionId,
  160. llvm::ArrayRef<uint32_t> params) {
  161. assert(insertPoint && "null insert point");
  162. const uint32_t id = theContext.takeNextId();
  163. instBuilder.opFunctionCall(returnType, id, functionId, params).x();
  164. insertPoint->appendInstruction(std::move(constructSite));
  165. return id;
  166. }
  167. uint32_t ModuleBuilder::createAccessChain(uint32_t resultType, uint32_t base,
  168. llvm::ArrayRef<uint32_t> indexes) {
  169. assert(insertPoint && "null insert point");
  170. const uint32_t id = theContext.takeNextId();
  171. instBuilder.opAccessChain(resultType, id, base, indexes).x();
  172. insertPoint->appendInstruction(std::move(constructSite));
  173. return id;
  174. }
  175. uint32_t ModuleBuilder::createUnaryOp(spv::Op op, uint32_t resultType,
  176. uint32_t operand) {
  177. assert(insertPoint && "null insert point");
  178. const uint32_t id = theContext.takeNextId();
  179. instBuilder.unaryOp(op, resultType, id, operand).x();
  180. insertPoint->appendInstruction(std::move(constructSite));
  181. switch (op) {
  182. case spv::Op::OpImageQuerySize:
  183. case spv::Op::OpImageQueryLevels:
  184. case spv::Op::OpImageQuerySamples:
  185. requireCapability(spv::Capability::ImageQuery);
  186. break;
  187. default:
  188. // Only checking for ImageQueries, the other Ops can be ignored.
  189. break;
  190. }
  191. return id;
  192. }
  193. uint32_t ModuleBuilder::createBinaryOp(spv::Op op, uint32_t resultType,
  194. uint32_t lhs, uint32_t rhs) {
  195. assert(insertPoint && "null insert point");
  196. const uint32_t id = theContext.takeNextId();
  197. instBuilder.binaryOp(op, resultType, id, lhs, rhs).x();
  198. insertPoint->appendInstruction(std::move(constructSite));
  199. switch (op) {
  200. case spv::Op::OpImageQueryLod:
  201. case spv::Op::OpImageQuerySizeLod:
  202. requireCapability(spv::Capability::ImageQuery);
  203. break;
  204. default:
  205. // Only checking for ImageQueries, the other Ops can be ignored.
  206. break;
  207. }
  208. return id;
  209. }
  210. uint32_t ModuleBuilder::createSpecConstantBinaryOp(spv::Op op,
  211. uint32_t resultType,
  212. uint32_t lhs, uint32_t rhs) {
  213. const uint32_t id = theContext.takeNextId();
  214. instBuilder.specConstantBinaryOp(op, resultType, id, lhs, rhs).x();
  215. theModule.addVariable(std::move(constructSite));
  216. return id;
  217. }
  218. uint32_t ModuleBuilder::createGroupNonUniformOp(spv::Op op, uint32_t resultType,
  219. uint32_t execScope) {
  220. assert(insertPoint && "null insert point");
  221. const uint32_t id = theContext.takeNextId();
  222. instBuilder.groupNonUniformOp(op, resultType, id, execScope).x();
  223. insertPoint->appendInstruction(std::move(constructSite));
  224. return id;
  225. }
  226. uint32_t ModuleBuilder::createGroupNonUniformUnaryOp(
  227. spv::Op op, uint32_t resultType, uint32_t execScope, uint32_t operand,
  228. llvm::Optional<spv::GroupOperation> groupOp) {
  229. assert(insertPoint && "null insert point");
  230. const uint32_t id = theContext.takeNextId();
  231. instBuilder
  232. .groupNonUniformUnaryOp(op, resultType, id, execScope, groupOp, operand)
  233. .x();
  234. insertPoint->appendInstruction(std::move(constructSite));
  235. return id;
  236. }
  237. uint32_t ModuleBuilder::createGroupNonUniformBinaryOp(spv::Op op,
  238. uint32_t resultType,
  239. uint32_t execScope,
  240. uint32_t operand1,
  241. uint32_t operand2) {
  242. assert(insertPoint && "null insert point");
  243. const uint32_t id = theContext.takeNextId();
  244. instBuilder
  245. .groupNonUniformBinaryOp(op, resultType, id, execScope, operand1,
  246. operand2)
  247. .x();
  248. insertPoint->appendInstruction(std::move(constructSite));
  249. return id;
  250. }
  251. uint32_t ModuleBuilder::createAtomicOp(spv::Op opcode, uint32_t resultType,
  252. uint32_t orignalValuePtr,
  253. uint32_t scopeId,
  254. uint32_t memorySemanticsId,
  255. uint32_t valueToOp) {
  256. assert(insertPoint && "null insert point");
  257. const uint32_t id = theContext.takeNextId();
  258. instBuilder
  259. .atomicOp(opcode, resultType, id, orignalValuePtr, scopeId,
  260. memorySemanticsId, valueToOp)
  261. .x();
  262. insertPoint->appendInstruction(std::move(constructSite));
  263. return id;
  264. }
  265. uint32_t ModuleBuilder::createAtomicCompareExchange(
  266. uint32_t resultType, uint32_t orignalValuePtr, uint32_t scopeId,
  267. uint32_t equalMemorySemanticsId, uint32_t unequalMemorySemanticsId,
  268. uint32_t valueToOp, uint32_t comparator) {
  269. assert(insertPoint && "null insert point");
  270. const uint32_t id = theContext.takeNextId();
  271. instBuilder.opAtomicCompareExchange(
  272. resultType, id, orignalValuePtr, scopeId, equalMemorySemanticsId,
  273. unequalMemorySemanticsId, valueToOp, comparator);
  274. instBuilder.x();
  275. insertPoint->appendInstruction(std::move(constructSite));
  276. return id;
  277. }
  278. spv::ImageOperandsMask ModuleBuilder::composeImageOperandsMask(
  279. uint32_t bias, uint32_t lod, const std::pair<uint32_t, uint32_t> &grad,
  280. uint32_t constOffset, uint32_t varOffset, uint32_t constOffsets,
  281. uint32_t sample, uint32_t minLod,
  282. llvm::SmallVectorImpl<uint32_t> *orderedParams) {
  283. using spv::ImageOperandsMask;
  284. // SPIR-V Image Operands from least significant bit to most significant bit
  285. // Bias, Lod, Grad, ConstOffset, Offset, ConstOffsets, Sample, MinLod
  286. auto mask = ImageOperandsMask::MaskNone;
  287. orderedParams->clear();
  288. if (bias) {
  289. mask = mask | ImageOperandsMask::Bias;
  290. orderedParams->push_back(bias);
  291. }
  292. if (lod) {
  293. mask = mask | ImageOperandsMask::Lod;
  294. orderedParams->push_back(lod);
  295. }
  296. if (grad.first && grad.second) {
  297. mask = mask | ImageOperandsMask::Grad;
  298. orderedParams->push_back(grad.first);
  299. orderedParams->push_back(grad.second);
  300. }
  301. if (constOffset) {
  302. mask = mask | ImageOperandsMask::ConstOffset;
  303. orderedParams->push_back(constOffset);
  304. }
  305. if (varOffset) {
  306. mask = mask | ImageOperandsMask::Offset;
  307. requireCapability(spv::Capability::ImageGatherExtended);
  308. orderedParams->push_back(varOffset);
  309. }
  310. if (constOffsets) {
  311. mask = mask | ImageOperandsMask::ConstOffsets;
  312. requireCapability(spv::Capability::ImageGatherExtended);
  313. orderedParams->push_back(constOffsets);
  314. }
  315. if (sample) {
  316. mask = mask | ImageOperandsMask::Sample;
  317. orderedParams->push_back(sample);
  318. }
  319. if (minLod) {
  320. requireCapability(spv::Capability::MinLod);
  321. mask = mask | ImageOperandsMask::MinLod;
  322. orderedParams->push_back(minLod);
  323. }
  324. return mask;
  325. }
  326. uint32_t
  327. ModuleBuilder::createImageSparseTexelsResident(uint32_t resident_code) {
  328. assert(insertPoint && "null insert point");
  329. // Result type must be a boolean
  330. const uint32_t result_type = getBoolType();
  331. const uint32_t id = theContext.takeNextId();
  332. instBuilder.opImageSparseTexelsResident(result_type, id, resident_code).x();
  333. insertPoint->appendInstruction(std::move(constructSite));
  334. return id;
  335. }
  336. uint32_t ModuleBuilder::createImageTexelPointer(uint32_t resultType,
  337. uint32_t imageId,
  338. uint32_t coordinate,
  339. uint32_t sample) {
  340. assert(insertPoint && "null insert point");
  341. const uint32_t id = theContext.takeNextId();
  342. instBuilder.opImageTexelPointer(resultType, id, imageId, coordinate, sample)
  343. .x();
  344. insertPoint->appendInstruction(std::move(constructSite));
  345. return id;
  346. }
  347. uint32_t ModuleBuilder::createImageSample(
  348. uint32_t texelType, uint32_t imageType, uint32_t image, uint32_t sampler,
  349. bool isNonUniform, uint32_t coordinate, uint32_t compareVal, uint32_t bias,
  350. uint32_t lod, std::pair<uint32_t, uint32_t> grad, uint32_t constOffset,
  351. uint32_t varOffset, uint32_t constOffsets, uint32_t sample, uint32_t minLod,
  352. uint32_t residencyCodeId) {
  353. assert(insertPoint && "null insert point");
  354. // The Lod and Grad image operands requires explicit-lod instructions.
  355. // Otherwise we use implicit-lod instructions.
  356. const bool isExplicit = lod || (grad.first && grad.second);
  357. const bool isSparse = (residencyCodeId != 0);
  358. // minLod is only valid with Implicit instructions and Grad instructions.
  359. // This means that we cannot have Lod and minLod together because Lod requires
  360. // explicit insturctions. So either lod or minLod or both must be zero.
  361. assert(lod == 0 || minLod == 0);
  362. uint32_t retType = texelType;
  363. if (isSparse) {
  364. requireCapability(spv::Capability::SparseResidency);
  365. retType = getSparseResidencyStructType(texelType);
  366. }
  367. // An OpSampledImage is required to do the image sampling.
  368. const uint32_t sampledImgId = theContext.takeNextId();
  369. const uint32_t sampledImgTy = getSampledImageType(imageType);
  370. instBuilder.opSampledImage(sampledImgTy, sampledImgId, image, sampler).x();
  371. insertPoint->appendInstruction(std::move(constructSite));
  372. if (isNonUniform) {
  373. // The sampled image will be used to access resource's memory, so we need
  374. // to decorate it with NonUniformEXT.
  375. decorateNonUniformEXT(sampledImgId);
  376. }
  377. uint32_t texelId = theContext.takeNextId();
  378. llvm::SmallVector<uint32_t, 4> params;
  379. const auto mask =
  380. composeImageOperandsMask(bias, lod, grad, constOffset, varOffset,
  381. constOffsets, sample, minLod, &params);
  382. instBuilder.opImageSample(retType, texelId, sampledImgId, coordinate,
  383. compareVal, mask, isExplicit, isSparse);
  384. for (const auto param : params)
  385. instBuilder.idRef(param);
  386. instBuilder.x();
  387. insertPoint->appendInstruction(std::move(constructSite));
  388. if (isSparse) {
  389. // Write the Residency Code
  390. const auto status = createCompositeExtract(getUint32Type(), texelId, {0});
  391. createStore(residencyCodeId, status);
  392. // Extract the real result from the struct
  393. texelId = createCompositeExtract(texelType, texelId, {1});
  394. }
  395. return texelId;
  396. }
  397. void ModuleBuilder::createImageWrite(QualType imageType, uint32_t imageId,
  398. uint32_t coordId, uint32_t texelId) {
  399. assert(insertPoint && "null insert point");
  400. requireCapability(
  401. TypeTranslator::getCapabilityForStorageImageReadWrite(imageType));
  402. instBuilder.opImageWrite(imageId, coordId, texelId, llvm::None).x();
  403. insertPoint->appendInstruction(std::move(constructSite));
  404. }
  405. uint32_t ModuleBuilder::createImageFetchOrRead(
  406. bool doImageFetch, uint32_t texelType, QualType imageType, uint32_t image,
  407. uint32_t coordinate, uint32_t lod, uint32_t constOffset, uint32_t varOffset,
  408. uint32_t constOffsets, uint32_t sample, uint32_t residencyCodeId) {
  409. assert(insertPoint && "null insert point");
  410. llvm::SmallVector<uint32_t, 2> params;
  411. const auto mask =
  412. llvm::Optional<spv::ImageOperandsMask>(composeImageOperandsMask(
  413. /*bias*/ 0, lod, std::make_pair(0, 0), constOffset, varOffset,
  414. constOffsets, sample, /*minLod*/ 0, &params));
  415. const bool isSparse = (residencyCodeId != 0);
  416. uint32_t retType = texelType;
  417. if (isSparse) {
  418. requireCapability(spv::Capability::SparseResidency);
  419. retType = getSparseResidencyStructType(texelType);
  420. }
  421. if (!doImageFetch) {
  422. requireCapability(
  423. TypeTranslator::getCapabilityForStorageImageReadWrite(imageType));
  424. }
  425. uint32_t texelId = theContext.takeNextId();
  426. instBuilder.opImageFetchRead(retType, texelId, image, coordinate, mask,
  427. doImageFetch, isSparse);
  428. for (const auto param : params)
  429. instBuilder.idRef(param);
  430. instBuilder.x();
  431. insertPoint->appendInstruction(std::move(constructSite));
  432. if (isSparse) {
  433. // Write the Residency Code
  434. const auto status = createCompositeExtract(getUint32Type(), texelId, {0});
  435. createStore(residencyCodeId, status);
  436. // Extract the real result from the struct
  437. texelId = createCompositeExtract(texelType, texelId, {1});
  438. }
  439. return texelId;
  440. }
  441. uint32_t ModuleBuilder::createImageGather(
  442. uint32_t texelType, uint32_t imageType, uint32_t image, uint32_t sampler,
  443. bool isNonUniform, uint32_t coordinate, uint32_t component,
  444. uint32_t compareVal, uint32_t constOffset, uint32_t varOffset,
  445. uint32_t constOffsets, uint32_t sample, uint32_t residencyCodeId) {
  446. assert(insertPoint && "null insert point");
  447. uint32_t sparseRetType = 0;
  448. if (residencyCodeId) {
  449. requireCapability(spv::Capability::SparseResidency);
  450. sparseRetType = getSparseResidencyStructType(texelType);
  451. }
  452. // An OpSampledImage is required to do the image sampling.
  453. const uint32_t sampledImgId = theContext.takeNextId();
  454. const uint32_t sampledImgTy = getSampledImageType(imageType);
  455. instBuilder.opSampledImage(sampledImgTy, sampledImgId, image, sampler).x();
  456. insertPoint->appendInstruction(std::move(constructSite));
  457. if (isNonUniform) {
  458. // The sampled image will be used to access resource's memory, so we need
  459. // to decorate it with NonUniformEXT.
  460. decorateNonUniformEXT(sampledImgId);
  461. }
  462. llvm::SmallVector<uint32_t, 2> params;
  463. // TODO: Update ImageGather to accept minLod if necessary.
  464. const auto mask =
  465. llvm::Optional<spv::ImageOperandsMask>(composeImageOperandsMask(
  466. /*bias*/ 0, /*lod*/ 0, std::make_pair(0, 0), constOffset, varOffset,
  467. constOffsets, sample, /*minLod*/ 0, &params));
  468. uint32_t texelId = theContext.takeNextId();
  469. if (compareVal) {
  470. if (residencyCodeId) {
  471. // Note: OpImageSparseDrefGather does not take the component parameter.
  472. instBuilder.opImageSparseDrefGather(sparseRetType, texelId, sampledImgId,
  473. coordinate, compareVal, mask);
  474. } else {
  475. // Note: OpImageDrefGather does not take the component parameter.
  476. instBuilder.opImageDrefGather(texelType, texelId, sampledImgId,
  477. coordinate, compareVal, mask);
  478. }
  479. } else {
  480. if (residencyCodeId) {
  481. instBuilder.opImageSparseGather(sparseRetType, texelId, sampledImgId,
  482. coordinate, component, mask);
  483. } else {
  484. instBuilder.opImageGather(texelType, texelId, sampledImgId, coordinate,
  485. component, mask);
  486. }
  487. }
  488. for (const auto param : params)
  489. instBuilder.idRef(param);
  490. instBuilder.x();
  491. insertPoint->appendInstruction(std::move(constructSite));
  492. if (residencyCodeId) {
  493. // Write the Residency Code
  494. const auto status = createCompositeExtract(getUint32Type(), texelId, {0});
  495. createStore(residencyCodeId, status);
  496. // Extract the real result from the struct
  497. texelId = createCompositeExtract(texelType, texelId, {1});
  498. }
  499. return texelId;
  500. }
  501. uint32_t ModuleBuilder::createSelect(uint32_t resultType, uint32_t condition,
  502. uint32_t trueValue, uint32_t falseValue) {
  503. assert(insertPoint && "null insert point");
  504. const uint32_t id = theContext.takeNextId();
  505. instBuilder.opSelect(resultType, id, condition, trueValue, falseValue).x();
  506. insertPoint->appendInstruction(std::move(constructSite));
  507. return id;
  508. }
  509. void ModuleBuilder::createSwitch(
  510. uint32_t mergeLabel, uint32_t selector, uint32_t defaultLabel,
  511. llvm::ArrayRef<std::pair<uint32_t, uint32_t>> target) {
  512. assert(insertPoint && "null insert point");
  513. // Create the OpSelectioMerege.
  514. instBuilder.opSelectionMerge(mergeLabel, spv::SelectionControlMask::MaskNone)
  515. .x();
  516. insertPoint->appendInstruction(std::move(constructSite));
  517. // Create the OpSwitch.
  518. instBuilder.opSwitch(selector, defaultLabel, target).x();
  519. insertPoint->appendInstruction(std::move(constructSite));
  520. }
  521. void ModuleBuilder::createKill() {
  522. assert(insertPoint && "null insert point");
  523. assert(!isCurrentBasicBlockTerminated());
  524. instBuilder.opKill().x();
  525. insertPoint->appendInstruction(std::move(constructSite));
  526. }
  527. void ModuleBuilder::createBranch(uint32_t targetLabel, uint32_t mergeBB,
  528. uint32_t continueBB,
  529. spv::LoopControlMask loopControl) {
  530. assert(insertPoint && "null insert point");
  531. if (mergeBB && continueBB) {
  532. instBuilder.opLoopMerge(mergeBB, continueBB, loopControl).x();
  533. insertPoint->appendInstruction(std::move(constructSite));
  534. }
  535. instBuilder.opBranch(targetLabel).x();
  536. insertPoint->appendInstruction(std::move(constructSite));
  537. }
  538. void ModuleBuilder::createConditionalBranch(
  539. uint32_t condition, uint32_t trueLabel, uint32_t falseLabel,
  540. uint32_t mergeLabel, uint32_t continueLabel,
  541. spv::SelectionControlMask selectionControl,
  542. spv::LoopControlMask loopControl) {
  543. assert(insertPoint && "null insert point");
  544. if (mergeLabel) {
  545. if (continueLabel) {
  546. instBuilder.opLoopMerge(mergeLabel, continueLabel, loopControl).x();
  547. insertPoint->appendInstruction(std::move(constructSite));
  548. } else {
  549. instBuilder.opSelectionMerge(mergeLabel, selectionControl).x();
  550. insertPoint->appendInstruction(std::move(constructSite));
  551. }
  552. }
  553. instBuilder.opBranchConditional(condition, trueLabel, falseLabel, {}).x();
  554. insertPoint->appendInstruction(std::move(constructSite));
  555. }
  556. void ModuleBuilder::createReturn() {
  557. assert(insertPoint && "null insert point");
  558. instBuilder.opReturn().x();
  559. insertPoint->appendInstruction(std::move(constructSite));
  560. }
  561. void ModuleBuilder::createReturnValue(uint32_t value) {
  562. assert(insertPoint && "null insert point");
  563. instBuilder.opReturnValue(value).x();
  564. insertPoint->appendInstruction(std::move(constructSite));
  565. }
  566. uint32_t ModuleBuilder::createExtInst(uint32_t resultType, uint32_t setId,
  567. uint32_t instId,
  568. llvm::ArrayRef<uint32_t> operands) {
  569. assert(insertPoint && "null insert point");
  570. uint32_t resultId = theContext.takeNextId();
  571. instBuilder.opExtInst(resultType, resultId, setId, instId, operands).x();
  572. insertPoint->appendInstruction(std::move(constructSite));
  573. return resultId;
  574. }
  575. void ModuleBuilder::createBarrier(uint32_t execution, uint32_t memory,
  576. uint32_t semantics) {
  577. assert(insertPoint && "null insert point");
  578. if (execution)
  579. instBuilder.opControlBarrier(execution, memory, semantics).x();
  580. else
  581. instBuilder.opMemoryBarrier(memory, semantics).x();
  582. insertPoint->appendInstruction(std::move(constructSite));
  583. }
  584. uint32_t ModuleBuilder::createBitFieldExtract(uint32_t resultType,
  585. uint32_t base, uint32_t offset,
  586. uint32_t count, bool isSigned) {
  587. assert(insertPoint && "null insert point");
  588. uint32_t resultId = theContext.takeNextId();
  589. if (isSigned)
  590. instBuilder.opBitFieldSExtract(resultType, resultId, base, offset, count);
  591. else
  592. instBuilder.opBitFieldUExtract(resultType, resultId, base, offset, count);
  593. instBuilder.x();
  594. insertPoint->appendInstruction(std::move(constructSite));
  595. return resultId;
  596. }
  597. uint32_t ModuleBuilder::createBitFieldInsert(uint32_t resultType, uint32_t base,
  598. uint32_t insert, uint32_t offset,
  599. uint32_t count) {
  600. assert(insertPoint && "null insert point");
  601. uint32_t resultId = theContext.takeNextId();
  602. instBuilder
  603. .opBitFieldInsert(resultType, resultId, base, insert, offset, count)
  604. .x();
  605. insertPoint->appendInstruction(std::move(constructSite));
  606. return resultId;
  607. }
  608. void ModuleBuilder::createEmitVertex() {
  609. assert(insertPoint && "null insert point");
  610. instBuilder.opEmitVertex().x();
  611. insertPoint->appendInstruction(std::move(constructSite));
  612. }
  613. void ModuleBuilder::createEndPrimitive() {
  614. assert(insertPoint && "null insert point");
  615. instBuilder.opEndPrimitive().x();
  616. insertPoint->appendInstruction(std::move(constructSite));
  617. }
  618. void ModuleBuilder::addExecutionMode(uint32_t entryPointId,
  619. spv::ExecutionMode em,
  620. llvm::ArrayRef<uint32_t> params) {
  621. instBuilder.opExecutionMode(entryPointId, em);
  622. for (const auto &param : params) {
  623. instBuilder.literalInteger(param);
  624. }
  625. instBuilder.x();
  626. theModule.addExecutionMode(std::move(constructSite));
  627. }
  628. void ModuleBuilder::addExtension(Extension ext, llvm::StringRef target,
  629. SourceLocation srcLoc) {
  630. assert(featureManager);
  631. featureManager->requestExtension(ext, target, srcLoc);
  632. // Do not emit OpExtension if the given extension is natively supported in the
  633. // target environment.
  634. if (featureManager->isExtensionRequiredForTargetEnv(ext))
  635. theModule.addExtension(featureManager->getExtensionName(ext));
  636. }
  637. uint32_t ModuleBuilder::getGLSLExtInstSet() {
  638. if (glslExtSetId == 0) {
  639. glslExtSetId = theContext.takeNextId();
  640. theModule.addExtInstSet(glslExtSetId, "GLSL.std.450");
  641. }
  642. return glslExtSetId;
  643. }
  644. uint32_t ModuleBuilder::addStageIOVar(uint32_t type,
  645. spv::StorageClass storageClass,
  646. std::string name) {
  647. const uint32_t pointerType = getPointerType(type, storageClass);
  648. const uint32_t varId = theContext.takeNextId();
  649. instBuilder.opVariable(pointerType, varId, storageClass, llvm::None).x();
  650. theModule.addVariable(std::move(constructSite));
  651. theModule.addDebugName(varId, name);
  652. return varId;
  653. }
  654. uint32_t ModuleBuilder::addStageBuiltinVar(uint32_t type, spv::StorageClass sc,
  655. spv::BuiltIn builtin) {
  656. const uint32_t pointerType = getPointerType(type, sc);
  657. const uint32_t varId = theContext.takeNextId();
  658. instBuilder.opVariable(pointerType, varId, sc, llvm::None).x();
  659. theModule.addVariable(std::move(constructSite));
  660. // Decorate with the specified Builtin
  661. const Decoration *d = Decoration::getBuiltIn(theContext, builtin);
  662. theModule.addDecoration(d, varId);
  663. return varId;
  664. }
  665. uint32_t ModuleBuilder::addModuleVar(uint32_t type, spv::StorageClass sc,
  666. llvm::StringRef name,
  667. llvm::Optional<uint32_t> init) {
  668. assert(sc != spv::StorageClass::Function);
  669. // TODO: basically duplicated code of addFileVar()
  670. const uint32_t pointerType = getPointerType(type, sc);
  671. const uint32_t varId = theContext.takeNextId();
  672. instBuilder.opVariable(pointerType, varId, sc, init).x();
  673. theModule.addVariable(std::move(constructSite));
  674. theModule.addDebugName(varId, name);
  675. return varId;
  676. }
  677. void ModuleBuilder::decorateDSetBinding(uint32_t targetId, uint32_t setNumber,
  678. uint32_t bindingNumber) {
  679. const auto *d = Decoration::getDescriptorSet(theContext, setNumber);
  680. theModule.addDecoration(d, targetId);
  681. d = Decoration::getBinding(theContext, bindingNumber);
  682. theModule.addDecoration(d, targetId);
  683. }
  684. void ModuleBuilder::decorateInputAttachmentIndex(uint32_t targetId,
  685. uint32_t indexNumber) {
  686. const auto *d = Decoration::getInputAttachmentIndex(theContext, indexNumber);
  687. theModule.addDecoration(d, targetId);
  688. }
  689. void ModuleBuilder::decorateCounterBufferId(uint32_t mainBufferId,
  690. uint32_t counterBufferId) {
  691. if (spirvOptions.enableReflect) {
  692. addExtension(Extension::GOOGLE_hlsl_functionality1, "SPIR-V reflection",
  693. {});
  694. theModule.addDecoration(
  695. Decoration::getHlslCounterBufferGOOGLE(theContext, counterBufferId),
  696. mainBufferId);
  697. }
  698. }
  699. void ModuleBuilder::decorateHlslSemantic(uint32_t targetId,
  700. llvm::StringRef semantic,
  701. llvm::Optional<uint32_t> memberIdx) {
  702. if (spirvOptions.enableReflect) {
  703. addExtension(Extension::GOOGLE_hlsl_functionality1, "SPIR-V reflection",
  704. {});
  705. theModule.addDecoration(
  706. Decoration::getHlslSemanticGOOGLE(theContext, semantic, memberIdx),
  707. targetId);
  708. }
  709. }
  710. void ModuleBuilder::decorateLocation(uint32_t targetId, uint32_t location) {
  711. const Decoration *d =
  712. Decoration::getLocation(theContext, location, llvm::None);
  713. theModule.addDecoration(d, targetId);
  714. }
  715. void ModuleBuilder::decorateIndex(uint32_t targetId, uint32_t index) {
  716. const Decoration *d = Decoration::getIndex(theContext, index);
  717. theModule.addDecoration(d, targetId);
  718. }
  719. void ModuleBuilder::decorateSpecId(uint32_t targetId, uint32_t specId) {
  720. const Decoration *d = Decoration::getSpecId(theContext, specId);
  721. theModule.addDecoration(d, targetId);
  722. }
  723. void ModuleBuilder::decorateCentroid(uint32_t targetId) {
  724. const Decoration *d = Decoration::getCentroid(theContext);
  725. theModule.addDecoration(d, targetId);
  726. }
  727. void ModuleBuilder::decorateFlat(uint32_t targetId) {
  728. const Decoration *d = Decoration::getFlat(theContext);
  729. theModule.addDecoration(d, targetId);
  730. }
  731. void ModuleBuilder::decorateNoPerspective(uint32_t targetId) {
  732. const Decoration *d = Decoration::getNoPerspective(theContext);
  733. theModule.addDecoration(d, targetId);
  734. }
  735. void ModuleBuilder::decorateSample(uint32_t targetId) {
  736. const Decoration *d = Decoration::getSample(theContext);
  737. theModule.addDecoration(d, targetId);
  738. }
  739. void ModuleBuilder::decorateBlock(uint32_t targetId) {
  740. const Decoration *d = Decoration::getBlock(theContext);
  741. theModule.addDecoration(d, targetId);
  742. }
  743. void ModuleBuilder::decorateRelaxedPrecision(uint32_t targetId) {
  744. const Decoration *d = Decoration::getRelaxedPrecision(theContext);
  745. theModule.addDecoration(d, targetId);
  746. }
  747. void ModuleBuilder::decoratePatch(uint32_t targetId) {
  748. const Decoration *d = Decoration::getPatch(theContext);
  749. theModule.addDecoration(d, targetId);
  750. }
  751. void ModuleBuilder::decorateNonUniformEXT(uint32_t targetId) {
  752. const Decoration *d = Decoration::getNonUniformEXT(theContext);
  753. theModule.addDecoration(d, targetId);
  754. }
  755. #define IMPL_GET_PRIMITIVE_TYPE(ty) \
  756. \
  757. uint32_t ModuleBuilder::get##ty##Type() { \
  758. const Type *type = Type::get##ty(theContext); \
  759. const uint32_t typeId = theContext.getResultIdForType(type); \
  760. theModule.addType(type, typeId); \
  761. return typeId; \
  762. }
  763. IMPL_GET_PRIMITIVE_TYPE(Void)
  764. IMPL_GET_PRIMITIVE_TYPE(Bool)
  765. IMPL_GET_PRIMITIVE_TYPE(Int32)
  766. IMPL_GET_PRIMITIVE_TYPE(Uint32)
  767. IMPL_GET_PRIMITIVE_TYPE(Float32)
  768. #undef IMPL_GET_PRIMITIVE_TYPE
// Note: At the moment, the Float16 capability should not be added for Vulkan
// 1.0. It is not a required capability, and adding the
// SPV_AMD_gpu_shader_half_float extension does not enable this capability. Any
// driver that supports float16 in Vulkan 1.0 should accept this extension.
  773. #define IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(ty, cap) \
  774. \
  775. uint32_t ModuleBuilder::get##ty##Type() { \
  776. if (spv::Capability::cap == spv::Capability::Float16) \
  777. addExtension(Extension::AMD_gpu_shader_half_float, "16-bit float", {}); \
  778. else \
  779. requireCapability(spv::Capability::cap); \
  780. const Type *type = Type::get##ty(theContext); \
  781. const uint32_t typeId = theContext.getResultIdForType(type); \
  782. theModule.addType(type, typeId); \
  783. return typeId; \
  784. }
  785. IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Int64, Int64)
  786. IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Uint64, Int64)
  787. IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Float64, Float64)
  788. IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Int16, Int16)
  789. IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Uint16, Int16)
  790. IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Float16, Float16)
  791. #undef IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY
  792. uint32_t ModuleBuilder::getVecType(uint32_t elemType, uint32_t elemCount) {
  793. const Type *type = nullptr;
  794. switch (elemCount) {
  795. case 2:
  796. type = Type::getVec2(theContext, elemType);
  797. break;
  798. case 3:
  799. type = Type::getVec3(theContext, elemType);
  800. break;
  801. case 4:
  802. type = Type::getVec4(theContext, elemType);
  803. break;
  804. default:
  805. assert(false && "unhandled vector size");
  806. // Error found. Return 0 as the <result-id> directly.
  807. return 0;
  808. }
  809. const uint32_t typeId = theContext.getResultIdForType(type);
  810. theModule.addType(type, typeId);
  811. return typeId;
  812. }
  813. uint32_t ModuleBuilder::getMatType(QualType elemType, uint32_t colType,
  814. uint32_t colCount,
  815. Type::DecorationSet decorations) {
  816. // NOTE: According to Item "Data rules" of SPIR-V Spec 2.16.1 "Universal
  817. // Validation Rules":
  818. // Matrix types can only be parameterized with floating-point types.
  819. //
  820. // So we need special handling of non-fp matrices. We emulate non-fp
  821. // matrices as an array of vectors.
  822. if (!elemType->isFloatingType())
  823. return getArrayType(colType, getConstantUint32(colCount), decorations);
  824. const Type *type = Type::getMatrix(theContext, colType, colCount);
  825. const uint32_t typeId = theContext.getResultIdForType(type);
  826. theModule.addType(type, typeId);
  827. return typeId;
  828. }
  829. uint32_t ModuleBuilder::getPointerType(uint32_t pointeeType,
  830. spv::StorageClass storageClass) {
  831. const Type *type = Type::getPointer(theContext, storageClass, pointeeType);
  832. const uint32_t typeId = theContext.getResultIdForType(type);
  833. theModule.addType(type, typeId);
  834. return typeId;
  835. }
  836. uint32_t
  837. ModuleBuilder::getStructType(llvm::ArrayRef<uint32_t> fieldTypes,
  838. llvm::StringRef structName,
  839. llvm::ArrayRef<llvm::StringRef> fieldNames,
  840. Type::DecorationSet decorations) {
  841. const Type *type =
  842. Type::getStruct(theContext, fieldTypes, structName, decorations);
  843. bool isRegistered = false;
  844. const uint32_t typeId = theContext.getResultIdForType(type, &isRegistered);
  845. theModule.addType(type, typeId);
  846. if (!isRegistered) {
  847. theModule.addDebugName(typeId, structName);
  848. if (!fieldNames.empty()) {
  849. assert(fieldNames.size() == fieldTypes.size());
  850. for (uint32_t i = 0; i < fieldNames.size(); ++i)
  851. theModule.addDebugName(typeId, fieldNames[i],
  852. llvm::Optional<uint32_t>(i));
  853. }
  854. }
  855. return typeId;
  856. }
  857. uint32_t ModuleBuilder::getSparseResidencyStructType(uint32_t type) {
  858. const auto uintType = getUint32Type();
  859. return getStructType({uintType, type}, "SparseResidencyStruct",
  860. {"Residency.Code", "Result.Type"});
  861. }
  862. uint32_t ModuleBuilder::getArrayType(uint32_t elemType, uint32_t count,
  863. Type::DecorationSet decorations) {
  864. const Type *type = Type::getArray(theContext, elemType, count, decorations);
  865. const uint32_t typeId = theContext.getResultIdForType(type);
  866. theModule.addType(type, typeId);
  867. return typeId;
  868. }
  869. uint32_t ModuleBuilder::getRuntimeArrayType(uint32_t elemType,
  870. Type::DecorationSet decorations) {
  871. const Type *type = Type::getRuntimeArray(theContext, elemType, decorations);
  872. const uint32_t typeId = theContext.getResultIdForType(type);
  873. theModule.addType(type, typeId);
  874. return typeId;
  875. }
  876. uint32_t ModuleBuilder::getFunctionType(uint32_t returnType,
  877. llvm::ArrayRef<uint32_t> paramTypes) {
  878. const Type *type = Type::getFunction(theContext, returnType, paramTypes);
  879. const uint32_t typeId = theContext.getResultIdForType(type);
  880. theModule.addType(type, typeId);
  881. return typeId;
  882. }
  883. uint32_t ModuleBuilder::getImageType(uint32_t sampledType, spv::Dim dim,
  884. uint32_t depth, bool isArray, uint32_t ms,
  885. uint32_t sampled,
  886. spv::ImageFormat format) {
  887. const Type *type = Type::getImage(theContext, sampledType, dim, depth,
  888. isArray, ms, sampled, format);
  889. bool isRegistered = false;
  890. const uint32_t typeId = theContext.getResultIdForType(type, &isRegistered);
  891. theModule.addType(type, typeId);
  892. switch (format) {
  893. case spv::ImageFormat::Rg32f:
  894. case spv::ImageFormat::Rg16f:
  895. case spv::ImageFormat::R11fG11fB10f:
  896. case spv::ImageFormat::R16f:
  897. case spv::ImageFormat::Rgba16:
  898. case spv::ImageFormat::Rgb10A2:
  899. case spv::ImageFormat::Rg16:
  900. case spv::ImageFormat::Rg8:
  901. case spv::ImageFormat::R16:
  902. case spv::ImageFormat::R8:
  903. case spv::ImageFormat::Rgba16Snorm:
  904. case spv::ImageFormat::Rg16Snorm:
  905. case spv::ImageFormat::Rg8Snorm:
  906. case spv::ImageFormat::R16Snorm:
  907. case spv::ImageFormat::R8Snorm:
  908. case spv::ImageFormat::Rg32i:
  909. case spv::ImageFormat::Rg16i:
  910. case spv::ImageFormat::Rg8i:
  911. case spv::ImageFormat::R16i:
  912. case spv::ImageFormat::R8i:
  913. case spv::ImageFormat::Rgb10a2ui:
  914. case spv::ImageFormat::Rg32ui:
  915. case spv::ImageFormat::Rg16ui:
  916. case spv::ImageFormat::Rg8ui:
  917. case spv::ImageFormat::R16ui:
  918. case spv::ImageFormat::R8ui:
  919. requireCapability(spv::Capability::StorageImageExtendedFormats);
  920. break;
  921. default:
  922. // Only image formats requiring extended formats are relevant. The rest just
  923. // pass through.
  924. break;
  925. }
  926. if (dim == spv::Dim::Dim1D) {
  927. if (sampled == 2u) {
  928. requireCapability(spv::Capability::Image1D);
  929. } else {
  930. requireCapability(spv::Capability::Sampled1D);
  931. }
  932. } else if (dim == spv::Dim::Buffer) {
  933. requireCapability(spv::Capability::SampledBuffer);
  934. } else if (dim == spv::Dim::SubpassData) {
  935. requireCapability(spv::Capability::InputAttachment);
  936. }
  937. if (isArray && ms) {
  938. requireCapability(spv::Capability::ImageMSArray);
  939. }
  940. // Skip constructing the debug name if we have already done it before.
  941. if (!isRegistered) {
  942. const char *dimStr = "";
  943. switch (dim) {
  944. case spv::Dim::Dim1D:
  945. dimStr = "1d.";
  946. break;
  947. case spv::Dim::Dim2D:
  948. dimStr = "2d.";
  949. break;
  950. case spv::Dim::Dim3D:
  951. dimStr = "3d.";
  952. break;
  953. case spv::Dim::Cube:
  954. dimStr = "cube.";
  955. break;
  956. case spv::Dim::Rect:
  957. dimStr = "rect.";
  958. break;
  959. case spv::Dim::Buffer:
  960. dimStr = "buffer.";
  961. break;
  962. case spv::Dim::SubpassData:
  963. dimStr = "subpass.";
  964. break;
  965. default:
  966. break;
  967. }
  968. std::string name =
  969. std::string("type.") + dimStr + "image" + (isArray ? ".array" : "");
  970. theModule.addDebugName(typeId, name);
  971. }
  972. return typeId;
  973. }
  974. uint32_t ModuleBuilder::getSamplerType() {
  975. const Type *type = Type::getSampler(theContext);
  976. const uint32_t typeId = theContext.getResultIdForType(type);
  977. theModule.addType(type, typeId);
  978. theModule.addDebugName(typeId, "type.sampler");
  979. return typeId;
  980. }
  981. uint32_t ModuleBuilder::getSampledImageType(uint32_t imageType) {
  982. const Type *type = Type::getSampledImage(theContext, imageType);
  983. const uint32_t typeId = theContext.getResultIdForType(type);
  984. theModule.addType(type, typeId);
  985. theModule.addDebugName(typeId, "type.sampled.image");
  986. return typeId;
  987. }
  988. uint32_t ModuleBuilder::getByteAddressBufferType(bool isRW) {
  989. // Create a uint RuntimeArray with Array Stride of 4.
  990. const uint32_t uintType = getUint32Type();
  991. const auto *arrStride4 = Decoration::getArrayStride(theContext, 4u);
  992. const Type *raType =
  993. Type::getRuntimeArray(theContext, uintType, {arrStride4});
  994. const uint32_t raTypeId = theContext.getResultIdForType(raType);
  995. theModule.addType(raType, raTypeId);
  996. // Create a struct containing the runtime array as its only member.
  997. // The struct must also be decorated as BufferBlock. The offset decoration
  998. // should also be applied to the first (only) member. NonWritable decoration
  999. // should also be applied to the first member if isRW is true.
  1000. llvm::SmallVector<const Decoration *, 3> typeDecs;
  1001. typeDecs.push_back(Decoration::getBufferBlock(theContext));
  1002. typeDecs.push_back(Decoration::getOffset(theContext, 0, 0));
  1003. if (!isRW)
  1004. typeDecs.push_back(Decoration::getNonWritable(theContext, 0));
  1005. const Type *type = Type::getStruct(theContext, {raTypeId}, "", typeDecs);
  1006. const uint32_t typeId = theContext.getResultIdForType(type);
  1007. theModule.addType(type, typeId);
  1008. theModule.addDebugName(typeId, isRW ? "type.RWByteAddressBuffer"
  1009. : "type.ByteAddressBuffer");
  1010. return typeId;
  1011. }
  1012. uint32_t ModuleBuilder::getConstantBool(bool value, bool isSpecConst) {
  1013. if (isSpecConst) {
  1014. const uint32_t constId = theContext.takeNextId();
  1015. if (value) {
  1016. instBuilder.opSpecConstantTrue(getBoolType(), constId).x();
  1017. } else {
  1018. instBuilder.opSpecConstantFalse(getBoolType(), constId).x();
  1019. }
  1020. theModule.addVariable(std::move(constructSite));
  1021. return constId;
  1022. }
  1023. const uint32_t typeId = getBoolType();
  1024. const Constant *constant = value ? Constant::getTrue(theContext, typeId)
  1025. : Constant::getFalse(theContext, typeId);
  1026. const uint32_t constId = theContext.getResultIdForConstant(constant);
  1027. theModule.addConstant(constant, constId);
  1028. return constId;
  1029. }
  1030. #define IMPL_GET_PRIMITIVE_CONST(builderTy, cppTy) \
  1031. \
  1032. uint32_t ModuleBuilder::getConstant##builderTy(cppTy value) { \
  1033. const uint32_t typeId = get##builderTy##Type(); \
  1034. const Constant *constant = \
  1035. Constant::get##builderTy(theContext, typeId, value); \
  1036. const uint32_t constId = theContext.getResultIdForConstant(constant); \
  1037. theModule.addConstant(constant, constId); \
  1038. return constId; \
  1039. }
  1040. #define IMPL_GET_PRIMITIVE_CONST_SPEC_CONST(builderTy, cppTy) \
  1041. \
  1042. uint32_t ModuleBuilder::getConstant##builderTy(cppTy value, \
  1043. bool isSpecConst) { \
  1044. if (isSpecConst) { \
  1045. const uint32_t constId = theContext.takeNextId(); \
  1046. instBuilder \
  1047. .opSpecConstant(get##builderTy##Type(), constId, \
  1048. cast::BitwiseCast<uint32_t>(value)) \
  1049. .x(); \
  1050. theModule.addVariable(std::move(constructSite)); \
  1051. return constId; \
  1052. } \
  1053. \
  1054. const uint32_t typeId = get##builderTy##Type(); \
  1055. const Constant *constant = \
  1056. Constant::get##builderTy(theContext, typeId, value); \
  1057. const uint32_t constId = theContext.getResultIdForConstant(constant); \
  1058. theModule.addConstant(constant, constId); \
  1059. return constId; \
  1060. }
  1061. IMPL_GET_PRIMITIVE_CONST(Int16, int16_t)
  1062. IMPL_GET_PRIMITIVE_CONST_SPEC_CONST(Int32, int32_t)
  1063. IMPL_GET_PRIMITIVE_CONST(Uint16, uint16_t)
  1064. IMPL_GET_PRIMITIVE_CONST_SPEC_CONST(Uint32, uint32_t)
  1065. IMPL_GET_PRIMITIVE_CONST(Float16, int16_t)
  1066. IMPL_GET_PRIMITIVE_CONST_SPEC_CONST(Float32, float)
  1067. IMPL_GET_PRIMITIVE_CONST(Float64, double)
  1068. IMPL_GET_PRIMITIVE_CONST(Int64, int64_t)
  1069. IMPL_GET_PRIMITIVE_CONST(Uint64, uint64_t)
  1070. #undef IMPL_GET_PRIMITIVE_CONST
  1071. #undef IMPL_GET_PRIMITIVE_CONST_SPEC_CONST
  1072. uint32_t
  1073. ModuleBuilder::getConstantComposite(uint32_t typeId,
  1074. llvm::ArrayRef<uint32_t> constituents) {
  1075. const Constant *constant =
  1076. Constant::getComposite(theContext, typeId, constituents);
  1077. const uint32_t constId = theContext.getResultIdForConstant(constant);
  1078. theModule.addConstant(constant, constId);
  1079. return constId;
  1080. }
  1081. uint32_t ModuleBuilder::getConstantNull(uint32_t typeId) {
  1082. const Constant *constant = Constant::getNull(theContext, typeId);
  1083. const uint32_t constId = theContext.getResultIdForConstant(constant);
  1084. theModule.addConstant(constant, constId);
  1085. return constId;
  1086. }
  1087. void ModuleBuilder::debugLine(uint32_t file, uint32_t line, uint32_t column) {
  1088. instBuilder.opLine(file, line, column).x();
  1089. insertPoint->appendInstruction(std::move(constructSite));
  1090. }
  1091. BasicBlock *ModuleBuilder::getBasicBlock(uint32_t labelId) {
  1092. auto it = basicBlocks.find(labelId);
  1093. if (it == basicBlocks.end()) {
  1094. assert(false && "invalid <label-id>");
  1095. return nullptr;
  1096. }
  1097. return it->second.get();
  1098. }
  1099. } // end namespace spirv
  1100. } // end namespace clang