ModuleBuilder.cpp 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7
  1. //===--- ModuleBuilder.cpp - SPIR-V builder implementation ----*- C++ -*---===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. #include "clang/SPIRV/ModuleBuilder.h"
  10. #include "TypeTranslator.h"
  11. #include "spirv/unified1//spirv.hpp11"
  12. #include "clang/SPIRV/BitwiseCast.h"
  13. #include "clang/SPIRV/InstBuilder.h"
  14. namespace clang {
  15. namespace spirv {
  16. ModuleBuilder::ModuleBuilder(SPIRVContext *C, FeatureManager *features,
  17. bool reflect)
  18. : theContext(*C), featureManager(features), allowReflect(reflect),
  19. theModule(), theFunction(nullptr), insertPoint(nullptr),
  20. instBuilder(nullptr), glslExtSetId(0) {
  21. instBuilder.setConsumer([this](std::vector<uint32_t> &&words) {
  22. this->constructSite = std::move(words);
  23. });
  24. // Set the SPIR-V version if needed.
  25. if (featureManager && featureManager->getTargetEnv() == SPV_ENV_VULKAN_1_1)
  26. theModule.useVulkan1p1();
  27. }
  28. std::vector<uint32_t> ModuleBuilder::takeModule() {
  29. theModule.setBound(theContext.getNextId());
  30. std::vector<uint32_t> binary;
  31. auto ib = InstBuilder([&binary](std::vector<uint32_t> &&words) {
  32. binary.insert(binary.end(), words.begin(), words.end());
  33. });
  34. theModule.take(&ib);
  35. return binary;
  36. }
  37. uint32_t ModuleBuilder::beginFunction(uint32_t funcType, uint32_t returnType,
  38. llvm::StringRef funcName, uint32_t fId) {
  39. if (theFunction) {
  40. assert(false && "found nested function");
  41. return 0;
  42. }
  43. // If the caller doesn't supply a function <result-id>, we need to get one.
  44. if (!fId)
  45. fId = theContext.takeNextId();
  46. theFunction = llvm::make_unique<Function>(
  47. returnType, fId, spv::FunctionControlMask::MaskNone, funcType);
  48. theModule.addDebugName(fId, funcName);
  49. return fId;
  50. }
  51. uint32_t ModuleBuilder::addFnParam(uint32_t ptrType, llvm::StringRef name) {
  52. assert(theFunction && "found detached parameter");
  53. const uint32_t paramId = theContext.takeNextId();
  54. theFunction->addParameter(ptrType, paramId);
  55. theModule.addDebugName(paramId, name);
  56. return paramId;
  57. }
  58. uint32_t ModuleBuilder::addFnVar(uint32_t varType, llvm::StringRef name,
  59. llvm::Optional<uint32_t> init) {
  60. assert(theFunction && "found detached local variable");
  61. const uint32_t ptrType = getPointerType(varType, spv::StorageClass::Function);
  62. const uint32_t varId = theContext.takeNextId();
  63. theFunction->addVariable(ptrType, varId, init);
  64. theModule.addDebugName(varId, name);
  65. return varId;
  66. }
  67. bool ModuleBuilder::endFunction() {
  68. if (theFunction == nullptr) {
  69. assert(false && "no active function");
  70. return false;
  71. }
  72. // Move all basic blocks into the current function.
  73. // TODO: we should adjust the order the basic blocks according to
  74. // SPIR-V validation rules.
  75. for (auto &bb : basicBlocks) {
  76. theFunction->addBasicBlock(std::move(bb.second));
  77. }
  78. basicBlocks.clear();
  79. theModule.addFunction(std::move(theFunction));
  80. theFunction.reset(nullptr);
  81. insertPoint = nullptr;
  82. return true;
  83. }
  84. uint32_t ModuleBuilder::createBasicBlock(llvm::StringRef name) {
  85. if (theFunction == nullptr) {
  86. assert(false && "found detached basic block");
  87. return 0;
  88. }
  89. const uint32_t labelId = theContext.takeNextId();
  90. basicBlocks[labelId] = llvm::make_unique<BasicBlock>(labelId, name);
  91. return labelId;
  92. }
  93. void ModuleBuilder::addSuccessor(uint32_t successorLabel) {
  94. assert(insertPoint && "null insert point");
  95. insertPoint->addSuccessor(getBasicBlock(successorLabel));
  96. }
  97. void ModuleBuilder::setMergeTarget(uint32_t mergeLabel) {
  98. assert(insertPoint && "null insert point");
  99. insertPoint->setMergeTarget(getBasicBlock(mergeLabel));
  100. }
  101. void ModuleBuilder::setContinueTarget(uint32_t continueLabel) {
  102. assert(insertPoint && "null insert point");
  103. insertPoint->setContinueTarget(getBasicBlock(continueLabel));
  104. }
  105. void ModuleBuilder::setInsertPoint(uint32_t labelId) {
  106. insertPoint = getBasicBlock(labelId);
  107. }
  108. uint32_t
  109. ModuleBuilder::createCompositeConstruct(uint32_t resultType,
  110. llvm::ArrayRef<uint32_t> constituents) {
  111. assert(insertPoint && "null insert point");
  112. const uint32_t resultId = theContext.takeNextId();
  113. instBuilder.opCompositeConstruct(resultType, resultId, constituents).x();
  114. insertPoint->appendInstruction(std::move(constructSite));
  115. return resultId;
  116. }
  117. uint32_t
  118. ModuleBuilder::createCompositeExtract(uint32_t resultType, uint32_t composite,
  119. llvm::ArrayRef<uint32_t> indexes) {
  120. assert(insertPoint && "null insert point");
  121. const uint32_t resultId = theContext.takeNextId();
  122. instBuilder.opCompositeExtract(resultType, resultId, composite, indexes).x();
  123. insertPoint->appendInstruction(std::move(constructSite));
  124. return resultId;
  125. }
  126. uint32_t ModuleBuilder::createCompositeInsert(uint32_t resultType,
  127. uint32_t composite,
  128. llvm::ArrayRef<uint32_t> indices,
  129. uint32_t object) {
  130. assert(insertPoint && "null insert point");
  131. const uint32_t resultId = theContext.takeNextId();
  132. instBuilder
  133. .opCompositeInsert(resultType, resultId, object, composite, indices)
  134. .x();
  135. insertPoint->appendInstruction(std::move(constructSite));
  136. return resultId;
  137. }
  138. uint32_t
  139. ModuleBuilder::createVectorShuffle(uint32_t resultType, uint32_t vector1,
  140. uint32_t vector2,
  141. llvm::ArrayRef<uint32_t> selectors) {
  142. assert(insertPoint && "null insert point");
  143. const uint32_t resultId = theContext.takeNextId();
  144. instBuilder.opVectorShuffle(resultType, resultId, vector1, vector2, selectors)
  145. .x();
  146. insertPoint->appendInstruction(std::move(constructSite));
  147. return resultId;
  148. }
  149. uint32_t ModuleBuilder::createLoad(uint32_t resultType, uint32_t pointer) {
  150. assert(insertPoint && "null insert point");
  151. const uint32_t resultId = theContext.takeNextId();
  152. instBuilder.opLoad(resultType, resultId, pointer, llvm::None).x();
  153. insertPoint->appendInstruction(std::move(constructSite));
  154. return resultId;
  155. }
  156. void ModuleBuilder::createStore(uint32_t address, uint32_t value) {
  157. assert(insertPoint && "null insert point");
  158. instBuilder.opStore(address, value, llvm::None).x();
  159. insertPoint->appendInstruction(std::move(constructSite));
  160. }
  161. uint32_t ModuleBuilder::createFunctionCall(uint32_t returnType,
  162. uint32_t functionId,
  163. llvm::ArrayRef<uint32_t> params) {
  164. assert(insertPoint && "null insert point");
  165. const uint32_t id = theContext.takeNextId();
  166. instBuilder.opFunctionCall(returnType, id, functionId, params).x();
  167. insertPoint->appendInstruction(std::move(constructSite));
  168. return id;
  169. }
  170. uint32_t ModuleBuilder::createAccessChain(uint32_t resultType, uint32_t base,
  171. llvm::ArrayRef<uint32_t> indexes) {
  172. assert(insertPoint && "null insert point");
  173. const uint32_t id = theContext.takeNextId();
  174. instBuilder.opAccessChain(resultType, id, base, indexes).x();
  175. insertPoint->appendInstruction(std::move(constructSite));
  176. return id;
  177. }
  178. uint32_t ModuleBuilder::createUnaryOp(spv::Op op, uint32_t resultType,
  179. uint32_t operand) {
  180. assert(insertPoint && "null insert point");
  181. const uint32_t id = theContext.takeNextId();
  182. instBuilder.unaryOp(op, resultType, id, operand).x();
  183. insertPoint->appendInstruction(std::move(constructSite));
  184. switch (op) {
  185. case spv::Op::OpImageQuerySize:
  186. case spv::Op::OpImageQueryLevels:
  187. case spv::Op::OpImageQuerySamples:
  188. requireCapability(spv::Capability::ImageQuery);
  189. break;
  190. default:
  191. // Only checking for ImageQueries, the other Ops can be ignored.
  192. break;
  193. }
  194. return id;
  195. }
  196. uint32_t ModuleBuilder::createBinaryOp(spv::Op op, uint32_t resultType,
  197. uint32_t lhs, uint32_t rhs) {
  198. assert(insertPoint && "null insert point");
  199. const uint32_t id = theContext.takeNextId();
  200. instBuilder.binaryOp(op, resultType, id, lhs, rhs).x();
  201. insertPoint->appendInstruction(std::move(constructSite));
  202. switch (op) {
  203. case spv::Op::OpImageQueryLod:
  204. case spv::Op::OpImageQuerySizeLod:
  205. requireCapability(spv::Capability::ImageQuery);
  206. break;
  207. default:
  208. // Only checking for ImageQueries, the other Ops can be ignored.
  209. break;
  210. }
  211. return id;
  212. }
  213. uint32_t ModuleBuilder::createSpecConstantBinaryOp(spv::Op op,
  214. uint32_t resultType,
  215. uint32_t lhs, uint32_t rhs) {
  216. const uint32_t id = theContext.takeNextId();
  217. instBuilder.specConstantBinaryOp(op, resultType, id, lhs, rhs).x();
  218. theModule.addVariable(std::move(constructSite));
  219. return id;
  220. }
  221. uint32_t ModuleBuilder::createGroupNonUniformOp(spv::Op op, uint32_t resultType,
  222. uint32_t execScope) {
  223. assert(insertPoint && "null insert point");
  224. const uint32_t id = theContext.takeNextId();
  225. instBuilder.groupNonUniformOp(op, resultType, id, execScope).x();
  226. insertPoint->appendInstruction(std::move(constructSite));
  227. return id;
  228. }
  229. uint32_t ModuleBuilder::createGroupNonUniformUnaryOp(
  230. spv::Op op, uint32_t resultType, uint32_t execScope, uint32_t operand,
  231. llvm::Optional<spv::GroupOperation> groupOp) {
  232. assert(insertPoint && "null insert point");
  233. const uint32_t id = theContext.takeNextId();
  234. instBuilder
  235. .groupNonUniformUnaryOp(op, resultType, id, execScope, groupOp, operand)
  236. .x();
  237. insertPoint->appendInstruction(std::move(constructSite));
  238. return id;
  239. }
  240. uint32_t ModuleBuilder::createGroupNonUniformBinaryOp(spv::Op op,
  241. uint32_t resultType,
  242. uint32_t execScope,
  243. uint32_t operand1,
  244. uint32_t operand2) {
  245. assert(insertPoint && "null insert point");
  246. const uint32_t id = theContext.takeNextId();
  247. instBuilder
  248. .groupNonUniformBinaryOp(op, resultType, id, execScope, operand1,
  249. operand2)
  250. .x();
  251. insertPoint->appendInstruction(std::move(constructSite));
  252. return id;
  253. }
  254. uint32_t ModuleBuilder::createAtomicOp(spv::Op opcode, uint32_t resultType,
  255. uint32_t orignalValuePtr,
  256. uint32_t scopeId,
  257. uint32_t memorySemanticsId,
  258. uint32_t valueToOp) {
  259. assert(insertPoint && "null insert point");
  260. const uint32_t id = theContext.takeNextId();
  261. instBuilder
  262. .atomicOp(opcode, resultType, id, orignalValuePtr, scopeId,
  263. memorySemanticsId, valueToOp)
  264. .x();
  265. insertPoint->appendInstruction(std::move(constructSite));
  266. return id;
  267. }
  268. uint32_t ModuleBuilder::createAtomicCompareExchange(
  269. uint32_t resultType, uint32_t orignalValuePtr, uint32_t scopeId,
  270. uint32_t equalMemorySemanticsId, uint32_t unequalMemorySemanticsId,
  271. uint32_t valueToOp, uint32_t comparator) {
  272. assert(insertPoint && "null insert point");
  273. const uint32_t id = theContext.takeNextId();
  274. instBuilder.opAtomicCompareExchange(
  275. resultType, id, orignalValuePtr, scopeId, equalMemorySemanticsId,
  276. unequalMemorySemanticsId, valueToOp, comparator);
  277. instBuilder.x();
  278. insertPoint->appendInstruction(std::move(constructSite));
  279. return id;
  280. }
  281. spv::ImageOperandsMask ModuleBuilder::composeImageOperandsMask(
  282. uint32_t bias, uint32_t lod, const std::pair<uint32_t, uint32_t> &grad,
  283. uint32_t constOffset, uint32_t varOffset, uint32_t constOffsets,
  284. uint32_t sample, uint32_t minLod,
  285. llvm::SmallVectorImpl<uint32_t> *orderedParams) {
  286. using spv::ImageOperandsMask;
  287. // SPIR-V Image Operands from least significant bit to most significant bit
  288. // Bias, Lod, Grad, ConstOffset, Offset, ConstOffsets, Sample, MinLod
  289. auto mask = ImageOperandsMask::MaskNone;
  290. orderedParams->clear();
  291. if (bias) {
  292. mask = mask | ImageOperandsMask::Bias;
  293. orderedParams->push_back(bias);
  294. }
  295. if (lod) {
  296. mask = mask | ImageOperandsMask::Lod;
  297. orderedParams->push_back(lod);
  298. }
  299. if (grad.first && grad.second) {
  300. mask = mask | ImageOperandsMask::Grad;
  301. orderedParams->push_back(grad.first);
  302. orderedParams->push_back(grad.second);
  303. }
  304. if (constOffset) {
  305. mask = mask | ImageOperandsMask::ConstOffset;
  306. orderedParams->push_back(constOffset);
  307. }
  308. if (varOffset) {
  309. mask = mask | ImageOperandsMask::Offset;
  310. requireCapability(spv::Capability::ImageGatherExtended);
  311. orderedParams->push_back(varOffset);
  312. }
  313. if (constOffsets) {
  314. mask = mask | ImageOperandsMask::ConstOffsets;
  315. requireCapability(spv::Capability::ImageGatherExtended);
  316. orderedParams->push_back(constOffsets);
  317. }
  318. if (sample) {
  319. mask = mask | ImageOperandsMask::Sample;
  320. orderedParams->push_back(sample);
  321. }
  322. if (minLod) {
  323. requireCapability(spv::Capability::MinLod);
  324. mask = mask | ImageOperandsMask::MinLod;
  325. orderedParams->push_back(minLod);
  326. }
  327. return mask;
  328. }
  329. uint32_t
  330. ModuleBuilder::createImageSparseTexelsResident(uint32_t resident_code) {
  331. assert(insertPoint && "null insert point");
  332. // Result type must be a boolean
  333. const uint32_t result_type = getBoolType();
  334. const uint32_t id = theContext.takeNextId();
  335. instBuilder.opImageSparseTexelsResident(result_type, id, resident_code).x();
  336. insertPoint->appendInstruction(std::move(constructSite));
  337. return id;
  338. }
  339. uint32_t ModuleBuilder::createImageTexelPointer(uint32_t resultType,
  340. uint32_t imageId,
  341. uint32_t coordinate,
  342. uint32_t sample) {
  343. assert(insertPoint && "null insert point");
  344. const uint32_t id = theContext.takeNextId();
  345. instBuilder.opImageTexelPointer(resultType, id, imageId, coordinate, sample)
  346. .x();
  347. insertPoint->appendInstruction(std::move(constructSite));
  348. return id;
  349. }
  350. uint32_t ModuleBuilder::createImageSample(
  351. uint32_t texelType, uint32_t imageType, uint32_t image, uint32_t sampler,
  352. bool isNonUniform, uint32_t coordinate, uint32_t compareVal, uint32_t bias,
  353. uint32_t lod, std::pair<uint32_t, uint32_t> grad, uint32_t constOffset,
  354. uint32_t varOffset, uint32_t constOffsets, uint32_t sample, uint32_t minLod,
  355. uint32_t residencyCodeId) {
  356. assert(insertPoint && "null insert point");
  357. // The Lod and Grad image operands requires explicit-lod instructions.
  358. // Otherwise we use implicit-lod instructions.
  359. const bool isExplicit = lod || (grad.first && grad.second);
  360. const bool isSparse = (residencyCodeId != 0);
  361. // minLod is only valid with Implicit instructions and Grad instructions.
  362. // This means that we cannot have Lod and minLod together because Lod requires
  363. // explicit insturctions. So either lod or minLod or both must be zero.
  364. assert(lod == 0 || minLod == 0);
  365. uint32_t retType = texelType;
  366. if (isSparse) {
  367. requireCapability(spv::Capability::SparseResidency);
  368. retType = getSparseResidencyStructType(texelType);
  369. }
  370. // An OpSampledImage is required to do the image sampling.
  371. const uint32_t sampledImgId = theContext.takeNextId();
  372. const uint32_t sampledImgTy = getSampledImageType(imageType);
  373. instBuilder.opSampledImage(sampledImgTy, sampledImgId, image, sampler).x();
  374. insertPoint->appendInstruction(std::move(constructSite));
  375. if (isNonUniform) {
  376. // The sampled image will be used to access resource's memory, so we need
  377. // to decorate it with NonUniformEXT.
  378. decorateNonUniformEXT(sampledImgId);
  379. }
  380. uint32_t texelId = theContext.takeNextId();
  381. llvm::SmallVector<uint32_t, 4> params;
  382. const auto mask =
  383. composeImageOperandsMask(bias, lod, grad, constOffset, varOffset,
  384. constOffsets, sample, minLod, &params);
  385. instBuilder.opImageSample(retType, texelId, sampledImgId, coordinate,
  386. compareVal, mask, isExplicit, isSparse);
  387. for (const auto param : params)
  388. instBuilder.idRef(param);
  389. instBuilder.x();
  390. insertPoint->appendInstruction(std::move(constructSite));
  391. if (isSparse) {
  392. // Write the Residency Code
  393. const auto status = createCompositeExtract(getUint32Type(), texelId, {0});
  394. createStore(residencyCodeId, status);
  395. // Extract the real result from the struct
  396. texelId = createCompositeExtract(texelType, texelId, {1});
  397. }
  398. return texelId;
  399. }
  400. void ModuleBuilder::createImageWrite(QualType imageType, uint32_t imageId,
  401. uint32_t coordId, uint32_t texelId) {
  402. assert(insertPoint && "null insert point");
  403. requireCapability(
  404. TypeTranslator::getCapabilityForStorageImageReadWrite(imageType));
  405. instBuilder.opImageWrite(imageId, coordId, texelId, llvm::None).x();
  406. insertPoint->appendInstruction(std::move(constructSite));
  407. }
  408. uint32_t ModuleBuilder::createImageFetchOrRead(
  409. bool doImageFetch, uint32_t texelType, QualType imageType, uint32_t image,
  410. uint32_t coordinate, uint32_t lod, uint32_t constOffset, uint32_t varOffset,
  411. uint32_t constOffsets, uint32_t sample, uint32_t residencyCodeId) {
  412. assert(insertPoint && "null insert point");
  413. llvm::SmallVector<uint32_t, 2> params;
  414. const auto mask =
  415. llvm::Optional<spv::ImageOperandsMask>(composeImageOperandsMask(
  416. /*bias*/ 0, lod, std::make_pair(0, 0), constOffset, varOffset,
  417. constOffsets, sample, /*minLod*/ 0, &params));
  418. const bool isSparse = (residencyCodeId != 0);
  419. uint32_t retType = texelType;
  420. if (isSparse) {
  421. requireCapability(spv::Capability::SparseResidency);
  422. retType = getSparseResidencyStructType(texelType);
  423. }
  424. if (!doImageFetch) {
  425. requireCapability(
  426. TypeTranslator::getCapabilityForStorageImageReadWrite(imageType));
  427. }
  428. uint32_t texelId = theContext.takeNextId();
  429. instBuilder.opImageFetchRead(retType, texelId, image, coordinate, mask,
  430. doImageFetch, isSparse);
  431. for (const auto param : params)
  432. instBuilder.idRef(param);
  433. instBuilder.x();
  434. insertPoint->appendInstruction(std::move(constructSite));
  435. if (isSparse) {
  436. // Write the Residency Code
  437. const auto status = createCompositeExtract(getUint32Type(), texelId, {0});
  438. createStore(residencyCodeId, status);
  439. // Extract the real result from the struct
  440. texelId = createCompositeExtract(texelType, texelId, {1});
  441. }
  442. return texelId;
  443. }
  444. uint32_t ModuleBuilder::createImageGather(
  445. uint32_t texelType, uint32_t imageType, uint32_t image, uint32_t sampler,
  446. bool isNonUniform, uint32_t coordinate, uint32_t component,
  447. uint32_t compareVal, uint32_t constOffset, uint32_t varOffset,
  448. uint32_t constOffsets, uint32_t sample, uint32_t residencyCodeId) {
  449. assert(insertPoint && "null insert point");
  450. uint32_t sparseRetType = 0;
  451. if (residencyCodeId) {
  452. requireCapability(spv::Capability::SparseResidency);
  453. sparseRetType = getSparseResidencyStructType(texelType);
  454. }
  455. // An OpSampledImage is required to do the image sampling.
  456. const uint32_t sampledImgId = theContext.takeNextId();
  457. const uint32_t sampledImgTy = getSampledImageType(imageType);
  458. instBuilder.opSampledImage(sampledImgTy, sampledImgId, image, sampler).x();
  459. insertPoint->appendInstruction(std::move(constructSite));
  460. if (isNonUniform) {
  461. // The sampled image will be used to access resource's memory, so we need
  462. // to decorate it with NonUniformEXT.
  463. decorateNonUniformEXT(sampledImgId);
  464. }
  465. llvm::SmallVector<uint32_t, 2> params;
  466. // TODO: Update ImageGather to accept minLod if necessary.
  467. const auto mask =
  468. llvm::Optional<spv::ImageOperandsMask>(composeImageOperandsMask(
  469. /*bias*/ 0, /*lod*/ 0, std::make_pair(0, 0), constOffset, varOffset,
  470. constOffsets, sample, /*minLod*/ 0, &params));
  471. uint32_t texelId = theContext.takeNextId();
  472. if (compareVal) {
  473. if (residencyCodeId) {
  474. // Note: OpImageSparseDrefGather does not take the component parameter.
  475. instBuilder.opImageSparseDrefGather(sparseRetType, texelId, sampledImgId,
  476. coordinate, compareVal, mask);
  477. } else {
  478. // Note: OpImageDrefGather does not take the component parameter.
  479. instBuilder.opImageDrefGather(texelType, texelId, sampledImgId,
  480. coordinate, compareVal, mask);
  481. }
  482. } else {
  483. if (residencyCodeId) {
  484. instBuilder.opImageSparseGather(sparseRetType, texelId, sampledImgId,
  485. coordinate, component, mask);
  486. } else {
  487. instBuilder.opImageGather(texelType, texelId, sampledImgId, coordinate,
  488. component, mask);
  489. }
  490. }
  491. for (const auto param : params)
  492. instBuilder.idRef(param);
  493. instBuilder.x();
  494. insertPoint->appendInstruction(std::move(constructSite));
  495. if (residencyCodeId) {
  496. // Write the Residency Code
  497. const auto status = createCompositeExtract(getUint32Type(), texelId, {0});
  498. createStore(residencyCodeId, status);
  499. // Extract the real result from the struct
  500. texelId = createCompositeExtract(texelType, texelId, {1});
  501. }
  502. return texelId;
  503. }
  504. uint32_t ModuleBuilder::createSelect(uint32_t resultType, uint32_t condition,
  505. uint32_t trueValue, uint32_t falseValue) {
  506. assert(insertPoint && "null insert point");
  507. const uint32_t id = theContext.takeNextId();
  508. instBuilder.opSelect(resultType, id, condition, trueValue, falseValue).x();
  509. insertPoint->appendInstruction(std::move(constructSite));
  510. return id;
  511. }
  512. void ModuleBuilder::createSwitch(
  513. uint32_t mergeLabel, uint32_t selector, uint32_t defaultLabel,
  514. llvm::ArrayRef<std::pair<uint32_t, uint32_t>> target) {
  515. assert(insertPoint && "null insert point");
  516. // Create the OpSelectioMerege.
  517. instBuilder.opSelectionMerge(mergeLabel, spv::SelectionControlMask::MaskNone)
  518. .x();
  519. insertPoint->appendInstruction(std::move(constructSite));
  520. // Create the OpSwitch.
  521. instBuilder.opSwitch(selector, defaultLabel, target).x();
  522. insertPoint->appendInstruction(std::move(constructSite));
  523. }
  524. void ModuleBuilder::createKill() {
  525. assert(insertPoint && "null insert point");
  526. assert(!isCurrentBasicBlockTerminated());
  527. instBuilder.opKill().x();
  528. insertPoint->appendInstruction(std::move(constructSite));
  529. }
  530. void ModuleBuilder::createBranch(uint32_t targetLabel, uint32_t mergeBB,
  531. uint32_t continueBB,
  532. spv::LoopControlMask loopControl) {
  533. assert(insertPoint && "null insert point");
  534. if (mergeBB && continueBB) {
  535. instBuilder.opLoopMerge(mergeBB, continueBB, loopControl).x();
  536. insertPoint->appendInstruction(std::move(constructSite));
  537. }
  538. instBuilder.opBranch(targetLabel).x();
  539. insertPoint->appendInstruction(std::move(constructSite));
  540. }
  541. void ModuleBuilder::createConditionalBranch(
  542. uint32_t condition, uint32_t trueLabel, uint32_t falseLabel,
  543. uint32_t mergeLabel, uint32_t continueLabel,
  544. spv::SelectionControlMask selectionControl,
  545. spv::LoopControlMask loopControl) {
  546. assert(insertPoint && "null insert point");
  547. if (mergeLabel) {
  548. if (continueLabel) {
  549. instBuilder.opLoopMerge(mergeLabel, continueLabel, loopControl).x();
  550. insertPoint->appendInstruction(std::move(constructSite));
  551. } else {
  552. instBuilder.opSelectionMerge(mergeLabel, selectionControl).x();
  553. insertPoint->appendInstruction(std::move(constructSite));
  554. }
  555. }
  556. instBuilder.opBranchConditional(condition, trueLabel, falseLabel, {}).x();
  557. insertPoint->appendInstruction(std::move(constructSite));
  558. }
  559. void ModuleBuilder::createReturn() {
  560. assert(insertPoint && "null insert point");
  561. instBuilder.opReturn().x();
  562. insertPoint->appendInstruction(std::move(constructSite));
  563. }
  564. void ModuleBuilder::createReturnValue(uint32_t value) {
  565. assert(insertPoint && "null insert point");
  566. instBuilder.opReturnValue(value).x();
  567. insertPoint->appendInstruction(std::move(constructSite));
  568. }
  569. uint32_t ModuleBuilder::createExtInst(uint32_t resultType, uint32_t setId,
  570. uint32_t instId,
  571. llvm::ArrayRef<uint32_t> operands) {
  572. assert(insertPoint && "null insert point");
  573. uint32_t resultId = theContext.takeNextId();
  574. instBuilder.opExtInst(resultType, resultId, setId, instId, operands).x();
  575. insertPoint->appendInstruction(std::move(constructSite));
  576. return resultId;
  577. }
  578. void ModuleBuilder::createBarrier(uint32_t execution, uint32_t memory,
  579. uint32_t semantics) {
  580. assert(insertPoint && "null insert point");
  581. if (execution)
  582. instBuilder.opControlBarrier(execution, memory, semantics).x();
  583. else
  584. instBuilder.opMemoryBarrier(memory, semantics).x();
  585. insertPoint->appendInstruction(std::move(constructSite));
  586. }
  587. uint32_t ModuleBuilder::createBitFieldExtract(uint32_t resultType,
  588. uint32_t base, uint32_t offset,
  589. uint32_t count, bool isSigned) {
  590. assert(insertPoint && "null insert point");
  591. uint32_t resultId = theContext.takeNextId();
  592. if (isSigned)
  593. instBuilder.opBitFieldSExtract(resultType, resultId, base, offset, count);
  594. else
  595. instBuilder.opBitFieldUExtract(resultType, resultId, base, offset, count);
  596. instBuilder.x();
  597. insertPoint->appendInstruction(std::move(constructSite));
  598. return resultId;
  599. }
  600. uint32_t ModuleBuilder::createBitFieldInsert(uint32_t resultType, uint32_t base,
  601. uint32_t insert, uint32_t offset,
  602. uint32_t count) {
  603. assert(insertPoint && "null insert point");
  604. uint32_t resultId = theContext.takeNextId();
  605. instBuilder
  606. .opBitFieldInsert(resultType, resultId, base, insert, offset, count)
  607. .x();
  608. insertPoint->appendInstruction(std::move(constructSite));
  609. return resultId;
  610. }
  611. void ModuleBuilder::createEmitVertex() {
  612. assert(insertPoint && "null insert point");
  613. instBuilder.opEmitVertex().x();
  614. insertPoint->appendInstruction(std::move(constructSite));
  615. }
  616. void ModuleBuilder::createEndPrimitive() {
  617. assert(insertPoint && "null insert point");
  618. instBuilder.opEndPrimitive().x();
  619. insertPoint->appendInstruction(std::move(constructSite));
  620. }
  621. void ModuleBuilder::addExecutionMode(uint32_t entryPointId,
  622. spv::ExecutionMode em,
  623. llvm::ArrayRef<uint32_t> params) {
  624. instBuilder.opExecutionMode(entryPointId, em);
  625. for (const auto &param : params) {
  626. instBuilder.literalInteger(param);
  627. }
  628. instBuilder.x();
  629. theModule.addExecutionMode(std::move(constructSite));
  630. }
  631. void ModuleBuilder::addExtension(Extension ext, llvm::StringRef target,
  632. SourceLocation srcLoc) {
  633. assert(featureManager);
  634. featureManager->requestExtension(ext, target, srcLoc);
  635. // Do not emit OpExtension if the given extension is natively supported in the
  636. // target environment.
  637. if (featureManager->isExtensionRequiredForTargetEnv(ext))
  638. theModule.addExtension(featureManager->getExtensionName(ext));
  639. }
  640. uint32_t ModuleBuilder::getGLSLExtInstSet() {
  641. if (glslExtSetId == 0) {
  642. glslExtSetId = theContext.takeNextId();
  643. theModule.addExtInstSet(glslExtSetId, "GLSL.std.450");
  644. }
  645. return glslExtSetId;
  646. }
  647. uint32_t ModuleBuilder::addStageIOVar(uint32_t type,
  648. spv::StorageClass storageClass,
  649. std::string name) {
  650. const uint32_t pointerType = getPointerType(type, storageClass);
  651. const uint32_t varId = theContext.takeNextId();
  652. instBuilder.opVariable(pointerType, varId, storageClass, llvm::None).x();
  653. theModule.addVariable(std::move(constructSite));
  654. theModule.addDebugName(varId, name);
  655. return varId;
  656. }
  657. uint32_t ModuleBuilder::addStageBuiltinVar(uint32_t type, spv::StorageClass sc,
  658. spv::BuiltIn builtin) {
  659. const uint32_t pointerType = getPointerType(type, sc);
  660. const uint32_t varId = theContext.takeNextId();
  661. instBuilder.opVariable(pointerType, varId, sc, llvm::None).x();
  662. theModule.addVariable(std::move(constructSite));
  663. // Decorate with the specified Builtin
  664. const Decoration *d = Decoration::getBuiltIn(theContext, builtin);
  665. theModule.addDecoration(d, varId);
  666. return varId;
  667. }
  668. uint32_t ModuleBuilder::addModuleVar(uint32_t type, spv::StorageClass sc,
  669. llvm::StringRef name,
  670. llvm::Optional<uint32_t> init) {
  671. assert(sc != spv::StorageClass::Function);
  672. // TODO: basically duplicated code of addFileVar()
  673. const uint32_t pointerType = getPointerType(type, sc);
  674. const uint32_t varId = theContext.takeNextId();
  675. instBuilder.opVariable(pointerType, varId, sc, init).x();
  676. theModule.addVariable(std::move(constructSite));
  677. theModule.addDebugName(varId, name);
  678. return varId;
  679. }
  680. void ModuleBuilder::decorateDSetBinding(uint32_t targetId, uint32_t setNumber,
  681. uint32_t bindingNumber) {
  682. const auto *d = Decoration::getDescriptorSet(theContext, setNumber);
  683. theModule.addDecoration(d, targetId);
  684. d = Decoration::getBinding(theContext, bindingNumber);
  685. theModule.addDecoration(d, targetId);
  686. }
  687. void ModuleBuilder::decorateInputAttachmentIndex(uint32_t targetId,
  688. uint32_t indexNumber) {
  689. const auto *d = Decoration::getInputAttachmentIndex(theContext, indexNumber);
  690. theModule.addDecoration(d, targetId);
  691. }
  692. void ModuleBuilder::decorateCounterBufferId(uint32_t mainBufferId,
  693. uint32_t counterBufferId) {
  694. if (allowReflect) {
  695. addExtension(Extension::GOOGLE_hlsl_functionality1, "SPIR-V reflection",
  696. {});
  697. theModule.addDecoration(
  698. Decoration::getHlslCounterBufferGOOGLE(theContext, counterBufferId),
  699. mainBufferId);
  700. }
  701. }
  702. void ModuleBuilder::decorateHlslSemantic(uint32_t targetId,
  703. llvm::StringRef semantic,
  704. llvm::Optional<uint32_t> memberIdx) {
  705. if (allowReflect) {
  706. addExtension(Extension::GOOGLE_hlsl_functionality1, "SPIR-V reflection",
  707. {});
  708. theModule.addDecoration(
  709. Decoration::getHlslSemanticGOOGLE(theContext, semantic, memberIdx),
  710. targetId);
  711. }
  712. }
  713. void ModuleBuilder::decorateLocation(uint32_t targetId, uint32_t location) {
  714. const Decoration *d =
  715. Decoration::getLocation(theContext, location, llvm::None);
  716. theModule.addDecoration(d, targetId);
  717. }
  718. void ModuleBuilder::decorateIndex(uint32_t targetId, uint32_t index) {
  719. const Decoration *d = Decoration::getIndex(theContext, index);
  720. theModule.addDecoration(d, targetId);
  721. }
  722. void ModuleBuilder::decorateSpecId(uint32_t targetId, uint32_t specId) {
  723. const Decoration *d = Decoration::getSpecId(theContext, specId);
  724. theModule.addDecoration(d, targetId);
  725. }
  726. void ModuleBuilder::decorateCentroid(uint32_t targetId) {
  727. const Decoration *d = Decoration::getCentroid(theContext);
  728. theModule.addDecoration(d, targetId);
  729. }
  730. void ModuleBuilder::decorateFlat(uint32_t targetId) {
  731. const Decoration *d = Decoration::getFlat(theContext);
  732. theModule.addDecoration(d, targetId);
  733. }
  734. void ModuleBuilder::decorateNoPerspective(uint32_t targetId) {
  735. const Decoration *d = Decoration::getNoPerspective(theContext);
  736. theModule.addDecoration(d, targetId);
  737. }
  738. void ModuleBuilder::decorateSample(uint32_t targetId) {
  739. const Decoration *d = Decoration::getSample(theContext);
  740. theModule.addDecoration(d, targetId);
  741. }
  742. void ModuleBuilder::decorateBlock(uint32_t targetId) {
  743. const Decoration *d = Decoration::getBlock(theContext);
  744. theModule.addDecoration(d, targetId);
  745. }
  746. void ModuleBuilder::decorateRelaxedPrecision(uint32_t targetId) {
  747. const Decoration *d = Decoration::getRelaxedPrecision(theContext);
  748. theModule.addDecoration(d, targetId);
  749. }
  750. void ModuleBuilder::decoratePatch(uint32_t targetId) {
  751. const Decoration *d = Decoration::getPatch(theContext);
  752. theModule.addDecoration(d, targetId);
  753. }
  754. void ModuleBuilder::decorateNonUniformEXT(uint32_t targetId) {
  755. const Decoration *d = Decoration::getNonUniformEXT(theContext);
  756. theModule.addDecoration(d, targetId);
  757. }
  758. #define IMPL_GET_PRIMITIVE_TYPE(ty) \
  759. \
  760. uint32_t ModuleBuilder::get##ty##Type() { \
  761. const Type *type = Type::get##ty(theContext); \
  762. const uint32_t typeId = theContext.getResultIdForType(type); \
  763. theModule.addType(type, typeId); \
  764. return typeId; \
  765. }
  766. IMPL_GET_PRIMITIVE_TYPE(Void)
  767. IMPL_GET_PRIMITIVE_TYPE(Bool)
  768. IMPL_GET_PRIMITIVE_TYPE(Int32)
  769. IMPL_GET_PRIMITIVE_TYPE(Uint32)
  770. IMPL_GET_PRIMITIVE_TYPE(Float32)
  771. #undef IMPL_GET_PRIMITIVE_TYPE
  // Note: At the moment, the Float16 capability should not be added for Vulkan
  // 1.0. It is not a required capability, and adding the
  // SPV_AMD_gpu_shader_half_float extension does not enable this capability.
  // Any driver that supports float16 in Vulkan 1.0 should accept this extension.
  776. #define IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(ty, cap) \
  777. \
  778. uint32_t ModuleBuilder::get##ty##Type() { \
  779. if (spv::Capability::cap == spv::Capability::Float16) \
  780. addExtension(Extension::AMD_gpu_shader_half_float, "16-bit float", {}); \
  781. else \
  782. requireCapability(spv::Capability::cap); \
  783. const Type *type = Type::get##ty(theContext); \
  784. const uint32_t typeId = theContext.getResultIdForType(type); \
  785. theModule.addType(type, typeId); \
  786. return typeId; \
  787. }
  788. IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Int64, Int64)
  789. IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Uint64, Int64)
  790. IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Float64, Float64)
  791. IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Int16, Int16)
  792. IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Uint16, Int16)
  793. IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Float16, Float16)
  794. #undef IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY
  795. uint32_t ModuleBuilder::getVecType(uint32_t elemType, uint32_t elemCount) {
  796. const Type *type = nullptr;
  797. switch (elemCount) {
  798. case 2:
  799. type = Type::getVec2(theContext, elemType);
  800. break;
  801. case 3:
  802. type = Type::getVec3(theContext, elemType);
  803. break;
  804. case 4:
  805. type = Type::getVec4(theContext, elemType);
  806. break;
  807. default:
  808. assert(false && "unhandled vector size");
  809. // Error found. Return 0 as the <result-id> directly.
  810. return 0;
  811. }
  812. const uint32_t typeId = theContext.getResultIdForType(type);
  813. theModule.addType(type, typeId);
  814. return typeId;
  815. }
  816. uint32_t ModuleBuilder::getMatType(QualType elemType, uint32_t colType,
  817. uint32_t colCount,
  818. Type::DecorationSet decorations) {
  819. // NOTE: According to Item "Data rules" of SPIR-V Spec 2.16.1 "Universal
  820. // Validation Rules":
  821. // Matrix types can only be parameterized with floating-point types.
  822. //
  823. // So we need special handling of non-fp matrices. We emulate non-fp
  824. // matrices as an array of vectors.
  825. if (!elemType->isFloatingType())
  826. return getArrayType(colType, getConstantUint32(colCount), decorations);
  827. const Type *type = Type::getMatrix(theContext, colType, colCount);
  828. const uint32_t typeId = theContext.getResultIdForType(type);
  829. theModule.addType(type, typeId);
  830. return typeId;
  831. }
  832. uint32_t ModuleBuilder::getPointerType(uint32_t pointeeType,
  833. spv::StorageClass storageClass) {
  834. const Type *type = Type::getPointer(theContext, storageClass, pointeeType);
  835. const uint32_t typeId = theContext.getResultIdForType(type);
  836. theModule.addType(type, typeId);
  837. return typeId;
  838. }
  839. uint32_t
  840. ModuleBuilder::getStructType(llvm::ArrayRef<uint32_t> fieldTypes,
  841. llvm::StringRef structName,
  842. llvm::ArrayRef<llvm::StringRef> fieldNames,
  843. Type::DecorationSet decorations) {
  844. const Type *type =
  845. Type::getStruct(theContext, fieldTypes, structName, decorations);
  846. bool isRegistered = false;
  847. const uint32_t typeId = theContext.getResultIdForType(type, &isRegistered);
  848. theModule.addType(type, typeId);
  849. if (!isRegistered) {
  850. theModule.addDebugName(typeId, structName);
  851. if (!fieldNames.empty()) {
  852. assert(fieldNames.size() == fieldTypes.size());
  853. for (uint32_t i = 0; i < fieldNames.size(); ++i)
  854. theModule.addDebugName(typeId, fieldNames[i],
  855. llvm::Optional<uint32_t>(i));
  856. }
  857. }
  858. return typeId;
  859. }
  860. uint32_t ModuleBuilder::getSparseResidencyStructType(uint32_t type) {
  861. const auto uintType = getUint32Type();
  862. return getStructType({uintType, type}, "SparseResidencyStruct",
  863. {"Residency.Code", "Result.Type"});
  864. }
  865. uint32_t ModuleBuilder::getArrayType(uint32_t elemType, uint32_t count,
  866. Type::DecorationSet decorations) {
  867. const Type *type = Type::getArray(theContext, elemType, count, decorations);
  868. const uint32_t typeId = theContext.getResultIdForType(type);
  869. theModule.addType(type, typeId);
  870. return typeId;
  871. }
  872. uint32_t ModuleBuilder::getRuntimeArrayType(uint32_t elemType,
  873. Type::DecorationSet decorations) {
  874. const Type *type = Type::getRuntimeArray(theContext, elemType, decorations);
  875. const uint32_t typeId = theContext.getResultIdForType(type);
  876. theModule.addType(type, typeId);
  877. return typeId;
  878. }
  879. uint32_t ModuleBuilder::getFunctionType(uint32_t returnType,
  880. llvm::ArrayRef<uint32_t> paramTypes) {
  881. const Type *type = Type::getFunction(theContext, returnType, paramTypes);
  882. const uint32_t typeId = theContext.getResultIdForType(type);
  883. theModule.addType(type, typeId);
  884. return typeId;
  885. }
  886. uint32_t ModuleBuilder::getImageType(uint32_t sampledType, spv::Dim dim,
  887. uint32_t depth, bool isArray, uint32_t ms,
  888. uint32_t sampled,
  889. spv::ImageFormat format) {
  890. const Type *type = Type::getImage(theContext, sampledType, dim, depth,
  891. isArray, ms, sampled, format);
  892. bool isRegistered = false;
  893. const uint32_t typeId = theContext.getResultIdForType(type, &isRegistered);
  894. theModule.addType(type, typeId);
  895. switch (format) {
  896. case spv::ImageFormat::Rg32f:
  897. case spv::ImageFormat::Rg16f:
  898. case spv::ImageFormat::R11fG11fB10f:
  899. case spv::ImageFormat::R16f:
  900. case spv::ImageFormat::Rgba16:
  901. case spv::ImageFormat::Rgb10A2:
  902. case spv::ImageFormat::Rg16:
  903. case spv::ImageFormat::Rg8:
  904. case spv::ImageFormat::R16:
  905. case spv::ImageFormat::R8:
  906. case spv::ImageFormat::Rgba16Snorm:
  907. case spv::ImageFormat::Rg16Snorm:
  908. case spv::ImageFormat::Rg8Snorm:
  909. case spv::ImageFormat::R16Snorm:
  910. case spv::ImageFormat::R8Snorm:
  911. case spv::ImageFormat::Rg32i:
  912. case spv::ImageFormat::Rg16i:
  913. case spv::ImageFormat::Rg8i:
  914. case spv::ImageFormat::R16i:
  915. case spv::ImageFormat::R8i:
  916. case spv::ImageFormat::Rgb10a2ui:
  917. case spv::ImageFormat::Rg32ui:
  918. case spv::ImageFormat::Rg16ui:
  919. case spv::ImageFormat::Rg8ui:
  920. case spv::ImageFormat::R16ui:
  921. case spv::ImageFormat::R8ui:
  922. requireCapability(spv::Capability::StorageImageExtendedFormats);
  923. break;
  924. default:
  925. // Only image formats requiring extended formats are relevant. The rest just
  926. // pass through.
  927. break;
  928. }
  929. if (dim == spv::Dim::Dim1D) {
  930. if (sampled == 2u) {
  931. requireCapability(spv::Capability::Image1D);
  932. } else {
  933. requireCapability(spv::Capability::Sampled1D);
  934. }
  935. } else if (dim == spv::Dim::Buffer) {
  936. requireCapability(spv::Capability::SampledBuffer);
  937. } else if (dim == spv::Dim::SubpassData) {
  938. requireCapability(spv::Capability::InputAttachment);
  939. }
  940. if (isArray && ms) {
  941. requireCapability(spv::Capability::ImageMSArray);
  942. }
  943. // Skip constructing the debug name if we have already done it before.
  944. if (!isRegistered) {
  945. const char *dimStr = "";
  946. switch (dim) {
  947. case spv::Dim::Dim1D:
  948. dimStr = "1d.";
  949. break;
  950. case spv::Dim::Dim2D:
  951. dimStr = "2d.";
  952. break;
  953. case spv::Dim::Dim3D:
  954. dimStr = "3d.";
  955. break;
  956. case spv::Dim::Cube:
  957. dimStr = "cube.";
  958. break;
  959. case spv::Dim::Rect:
  960. dimStr = "rect.";
  961. break;
  962. case spv::Dim::Buffer:
  963. dimStr = "buffer.";
  964. break;
  965. case spv::Dim::SubpassData:
  966. dimStr = "subpass.";
  967. break;
  968. default:
  969. break;
  970. }
  971. std::string name =
  972. std::string("type.") + dimStr + "image" + (isArray ? ".array" : "");
  973. theModule.addDebugName(typeId, name);
  974. }
  975. return typeId;
  976. }
  977. uint32_t ModuleBuilder::getSamplerType() {
  978. const Type *type = Type::getSampler(theContext);
  979. const uint32_t typeId = theContext.getResultIdForType(type);
  980. theModule.addType(type, typeId);
  981. theModule.addDebugName(typeId, "type.sampler");
  982. return typeId;
  983. }
  984. uint32_t ModuleBuilder::getSampledImageType(uint32_t imageType) {
  985. const Type *type = Type::getSampledImage(theContext, imageType);
  986. const uint32_t typeId = theContext.getResultIdForType(type);
  987. theModule.addType(type, typeId);
  988. theModule.addDebugName(typeId, "type.sampled.image");
  989. return typeId;
  990. }
  991. uint32_t ModuleBuilder::getByteAddressBufferType(bool isRW) {
  992. // Create a uint RuntimeArray with Array Stride of 4.
  993. const uint32_t uintType = getUint32Type();
  994. const auto *arrStride4 = Decoration::getArrayStride(theContext, 4u);
  995. const Type *raType =
  996. Type::getRuntimeArray(theContext, uintType, {arrStride4});
  997. const uint32_t raTypeId = theContext.getResultIdForType(raType);
  998. theModule.addType(raType, raTypeId);
  999. // Create a struct containing the runtime array as its only member.
  1000. // The struct must also be decorated as BufferBlock. The offset decoration
  1001. // should also be applied to the first (only) member. NonWritable decoration
  1002. // should also be applied to the first member if isRW is true.
  1003. llvm::SmallVector<const Decoration *, 3> typeDecs;
  1004. typeDecs.push_back(Decoration::getBufferBlock(theContext));
  1005. typeDecs.push_back(Decoration::getOffset(theContext, 0, 0));
  1006. if (!isRW)
  1007. typeDecs.push_back(Decoration::getNonWritable(theContext, 0));
  1008. const Type *type = Type::getStruct(theContext, {raTypeId}, "", typeDecs);
  1009. const uint32_t typeId = theContext.getResultIdForType(type);
  1010. theModule.addType(type, typeId);
  1011. theModule.addDebugName(typeId, isRW ? "type.RWByteAddressBuffer"
  1012. : "type.ByteAddressBuffer");
  1013. return typeId;
  1014. }
  1015. uint32_t ModuleBuilder::getConstantBool(bool value, bool isSpecConst) {
  1016. if (isSpecConst) {
  1017. const uint32_t constId = theContext.takeNextId();
  1018. if (value) {
  1019. instBuilder.opSpecConstantTrue(getBoolType(), constId).x();
  1020. } else {
  1021. instBuilder.opSpecConstantFalse(getBoolType(), constId).x();
  1022. }
  1023. theModule.addVariable(std::move(constructSite));
  1024. return constId;
  1025. }
  1026. const uint32_t typeId = getBoolType();
  1027. const Constant *constant = value ? Constant::getTrue(theContext, typeId)
  1028. : Constant::getFalse(theContext, typeId);
  1029. const uint32_t constId = theContext.getResultIdForConstant(constant);
  1030. theModule.addConstant(constant, constId);
  1031. return constId;
  1032. }
  1033. #define IMPL_GET_PRIMITIVE_CONST(builderTy, cppTy) \
  1034. \
  1035. uint32_t ModuleBuilder::getConstant##builderTy(cppTy value) { \
  1036. const uint32_t typeId = get##builderTy##Type(); \
  1037. const Constant *constant = \
  1038. Constant::get##builderTy(theContext, typeId, value); \
  1039. const uint32_t constId = theContext.getResultIdForConstant(constant); \
  1040. theModule.addConstant(constant, constId); \
  1041. return constId; \
  1042. }
  1043. #define IMPL_GET_PRIMITIVE_CONST_SPEC_CONST(builderTy, cppTy) \
  1044. \
  1045. uint32_t ModuleBuilder::getConstant##builderTy(cppTy value, \
  1046. bool isSpecConst) { \
  1047. if (isSpecConst) { \
  1048. const uint32_t constId = theContext.takeNextId(); \
  1049. instBuilder \
  1050. .opSpecConstant(get##builderTy##Type(), constId, \
  1051. cast::BitwiseCast<uint32_t>(value)) \
  1052. .x(); \
  1053. theModule.addVariable(std::move(constructSite)); \
  1054. return constId; \
  1055. } \
  1056. \
  1057. const uint32_t typeId = get##builderTy##Type(); \
  1058. const Constant *constant = \
  1059. Constant::get##builderTy(theContext, typeId, value); \
  1060. const uint32_t constId = theContext.getResultIdForConstant(constant); \
  1061. theModule.addConstant(constant, constId); \
  1062. return constId; \
  1063. }
  1064. IMPL_GET_PRIMITIVE_CONST(Int16, int16_t)
  1065. IMPL_GET_PRIMITIVE_CONST_SPEC_CONST(Int32, int32_t)
  1066. IMPL_GET_PRIMITIVE_CONST(Uint16, uint16_t)
  1067. IMPL_GET_PRIMITIVE_CONST_SPEC_CONST(Uint32, uint32_t)
  1068. IMPL_GET_PRIMITIVE_CONST(Float16, int16_t)
  1069. IMPL_GET_PRIMITIVE_CONST_SPEC_CONST(Float32, float)
  1070. IMPL_GET_PRIMITIVE_CONST(Float64, double)
  1071. IMPL_GET_PRIMITIVE_CONST(Int64, int64_t)
  1072. IMPL_GET_PRIMITIVE_CONST(Uint64, uint64_t)
  1073. #undef IMPL_GET_PRIMITIVE_CONST
  1074. #undef IMPL_GET_PRIMITIVE_CONST_SPEC_CONST
  1075. uint32_t
  1076. ModuleBuilder::getConstantComposite(uint32_t typeId,
  1077. llvm::ArrayRef<uint32_t> constituents) {
  1078. const Constant *constant =
  1079. Constant::getComposite(theContext, typeId, constituents);
  1080. const uint32_t constId = theContext.getResultIdForConstant(constant);
  1081. theModule.addConstant(constant, constId);
  1082. return constId;
  1083. }
  1084. uint32_t ModuleBuilder::getConstantNull(uint32_t typeId) {
  1085. const Constant *constant = Constant::getNull(theContext, typeId);
  1086. const uint32_t constId = theContext.getResultIdForConstant(constant);
  1087. theModule.addConstant(constant, constId);
  1088. return constId;
  1089. }
  1090. void ModuleBuilder::debugLine(uint32_t file, uint32_t line, uint32_t column) {
  1091. instBuilder.opLine(file, line, column).x();
  1092. insertPoint->appendInstruction(std::move(constructSite));
  1093. }
  1094. BasicBlock *ModuleBuilder::getBasicBlock(uint32_t labelId) {
  1095. auto it = basicBlocks.find(labelId);
  1096. if (it == basicBlocks.end()) {
  1097. assert(false && "invalid <label-id>");
  1098. return nullptr;
  1099. }
  1100. return it->second.get();
  1101. }
  1102. } // end namespace spirv
  1103. } // end namespace clang