ModuleBuilder.cpp 49 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311
  1. //===--- ModuleBuilder.cpp - SPIR-V builder implementation ----*- C++ -*---===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. #include "clang/SPIRV/ModuleBuilder.h"
  10. #include "TypeTranslator.h"
  11. #include "spirv/unified1//spirv.hpp11"
  12. #include "clang/SPIRV/BitwiseCast.h"
  13. #include "clang/SPIRV/InstBuilder.h"
  14. namespace clang {
  15. namespace spirv {
  16. ModuleBuilder::ModuleBuilder(SPIRVContext *C, FeatureManager *features,
  17. bool reflect)
  18. : theContext(*C), featureManager(features), allowReflect(reflect),
  19. theModule(), theFunction(nullptr), insertPoint(nullptr),
  20. instBuilder(nullptr), glslExtSetId(0) {
  21. instBuilder.setConsumer([this](std::vector<uint32_t> &&words) {
  22. this->constructSite = std::move(words);
  23. });
  24. // Set the SPIR-V version if needed.
  25. if (featureManager && featureManager->getTargetEnv() == SPV_ENV_VULKAN_1_1)
  26. theModule.setVersion(0x00010300);
  27. }
  28. std::vector<uint32_t> ModuleBuilder::takeModule() {
  29. theModule.setBound(theContext.getNextId());
  30. std::vector<uint32_t> binary;
  31. auto ib = InstBuilder([&binary](std::vector<uint32_t> &&words) {
  32. binary.insert(binary.end(), words.begin(), words.end());
  33. });
  34. theModule.take(&ib);
  35. return binary;
  36. }
  37. uint32_t ModuleBuilder::beginFunction(uint32_t funcType, uint32_t returnType,
  38. llvm::StringRef funcName, uint32_t fId) {
  39. if (theFunction) {
  40. assert(false && "found nested function");
  41. return 0;
  42. }
  43. // If the caller doesn't supply a function <result-id>, we need to get one.
  44. if (!fId)
  45. fId = theContext.takeNextId();
  46. theFunction = llvm::make_unique<Function>(
  47. returnType, fId, spv::FunctionControlMask::MaskNone, funcType);
  48. theModule.addDebugName(fId, funcName);
  49. return fId;
  50. }
  51. uint32_t ModuleBuilder::addFnParam(uint32_t ptrType, llvm::StringRef name) {
  52. assert(theFunction && "found detached parameter");
  53. const uint32_t paramId = theContext.takeNextId();
  54. theFunction->addParameter(ptrType, paramId);
  55. theModule.addDebugName(paramId, name);
  56. return paramId;
  57. }
  58. uint32_t ModuleBuilder::addFnVar(uint32_t varType, llvm::StringRef name,
  59. llvm::Optional<uint32_t> init) {
  60. assert(theFunction && "found detached local variable");
  61. const uint32_t ptrType = getPointerType(varType, spv::StorageClass::Function);
  62. const uint32_t varId = theContext.takeNextId();
  63. theFunction->addVariable(ptrType, varId, init);
  64. theModule.addDebugName(varId, name);
  65. return varId;
  66. }
  67. bool ModuleBuilder::endFunction() {
  68. if (theFunction == nullptr) {
  69. assert(false && "no active function");
  70. return false;
  71. }
  72. // Move all basic blocks into the current function.
  73. // TODO: we should adjust the order the basic blocks according to
  74. // SPIR-V validation rules.
  75. for (auto &bb : basicBlocks) {
  76. theFunction->addBasicBlock(std::move(bb.second));
  77. }
  78. basicBlocks.clear();
  79. theModule.addFunction(std::move(theFunction));
  80. theFunction.reset(nullptr);
  81. insertPoint = nullptr;
  82. return true;
  83. }
  84. uint32_t ModuleBuilder::createBasicBlock(llvm::StringRef name) {
  85. if (theFunction == nullptr) {
  86. assert(false && "found detached basic block");
  87. return 0;
  88. }
  89. const uint32_t labelId = theContext.takeNextId();
  90. basicBlocks[labelId] = llvm::make_unique<BasicBlock>(labelId, name);
  91. return labelId;
  92. }
  93. void ModuleBuilder::addSuccessor(uint32_t successorLabel) {
  94. assert(insertPoint && "null insert point");
  95. insertPoint->addSuccessor(getBasicBlock(successorLabel));
  96. }
  97. void ModuleBuilder::setMergeTarget(uint32_t mergeLabel) {
  98. assert(insertPoint && "null insert point");
  99. insertPoint->setMergeTarget(getBasicBlock(mergeLabel));
  100. }
  101. void ModuleBuilder::setContinueTarget(uint32_t continueLabel) {
  102. assert(insertPoint && "null insert point");
  103. insertPoint->setContinueTarget(getBasicBlock(continueLabel));
  104. }
  105. void ModuleBuilder::setInsertPoint(uint32_t labelId) {
  106. insertPoint = getBasicBlock(labelId);
  107. }
  108. uint32_t
  109. ModuleBuilder::createCompositeConstruct(uint32_t resultType,
  110. llvm::ArrayRef<uint32_t> constituents) {
  111. assert(insertPoint && "null insert point");
  112. const uint32_t resultId = theContext.takeNextId();
  113. instBuilder.opCompositeConstruct(resultType, resultId, constituents).x();
  114. insertPoint->appendInstruction(std::move(constructSite));
  115. return resultId;
  116. }
  117. uint32_t
  118. ModuleBuilder::createCompositeExtract(uint32_t resultType, uint32_t composite,
  119. llvm::ArrayRef<uint32_t> indexes) {
  120. assert(insertPoint && "null insert point");
  121. const uint32_t resultId = theContext.takeNextId();
  122. instBuilder.opCompositeExtract(resultType, resultId, composite, indexes).x();
  123. insertPoint->appendInstruction(std::move(constructSite));
  124. return resultId;
  125. }
  126. uint32_t ModuleBuilder::createCompositeInsert(uint32_t resultType,
  127. uint32_t composite,
  128. llvm::ArrayRef<uint32_t> indices,
  129. uint32_t object) {
  130. assert(insertPoint && "null insert point");
  131. const uint32_t resultId = theContext.takeNextId();
  132. instBuilder
  133. .opCompositeInsert(resultType, resultId, object, composite, indices)
  134. .x();
  135. insertPoint->appendInstruction(std::move(constructSite));
  136. return resultId;
  137. }
  138. uint32_t
  139. ModuleBuilder::createVectorShuffle(uint32_t resultType, uint32_t vector1,
  140. uint32_t vector2,
  141. llvm::ArrayRef<uint32_t> selectors) {
  142. assert(insertPoint && "null insert point");
  143. const uint32_t resultId = theContext.takeNextId();
  144. instBuilder.opVectorShuffle(resultType, resultId, vector1, vector2, selectors)
  145. .x();
  146. insertPoint->appendInstruction(std::move(constructSite));
  147. return resultId;
  148. }
  149. uint32_t ModuleBuilder::createLoad(uint32_t resultType, uint32_t pointer) {
  150. assert(insertPoint && "null insert point");
  151. const uint32_t resultId = theContext.takeNextId();
  152. instBuilder.opLoad(resultType, resultId, pointer, llvm::None).x();
  153. insertPoint->appendInstruction(std::move(constructSite));
  154. return resultId;
  155. }
  156. void ModuleBuilder::createStore(uint32_t address, uint32_t value) {
  157. assert(insertPoint && "null insert point");
  158. instBuilder.opStore(address, value, llvm::None).x();
  159. insertPoint->appendInstruction(std::move(constructSite));
  160. }
  161. uint32_t ModuleBuilder::createFunctionCall(uint32_t returnType,
  162. uint32_t functionId,
  163. llvm::ArrayRef<uint32_t> params) {
  164. assert(insertPoint && "null insert point");
  165. const uint32_t id = theContext.takeNextId();
  166. instBuilder.opFunctionCall(returnType, id, functionId, params).x();
  167. insertPoint->appendInstruction(std::move(constructSite));
  168. return id;
  169. }
  170. uint32_t ModuleBuilder::createAccessChain(uint32_t resultType, uint32_t base,
  171. llvm::ArrayRef<uint32_t> indexes) {
  172. assert(insertPoint && "null insert point");
  173. const uint32_t id = theContext.takeNextId();
  174. instBuilder.opAccessChain(resultType, id, base, indexes).x();
  175. insertPoint->appendInstruction(std::move(constructSite));
  176. return id;
  177. }
  178. uint32_t ModuleBuilder::createUnaryOp(spv::Op op, uint32_t resultType,
  179. uint32_t operand) {
  180. assert(insertPoint && "null insert point");
  181. const uint32_t id = theContext.takeNextId();
  182. instBuilder.unaryOp(op, resultType, id, operand).x();
  183. insertPoint->appendInstruction(std::move(constructSite));
  184. switch (op) {
  185. case spv::Op::OpImageQuerySize:
  186. case spv::Op::OpImageQueryLevels:
  187. case spv::Op::OpImageQuerySamples:
  188. requireCapability(spv::Capability::ImageQuery);
  189. break;
  190. default:
  191. // Only checking for ImageQueries, the other Ops can be ignored.
  192. break;
  193. }
  194. return id;
  195. }
  196. uint32_t ModuleBuilder::createBinaryOp(spv::Op op, uint32_t resultType,
  197. uint32_t lhs, uint32_t rhs) {
  198. assert(insertPoint && "null insert point");
  199. const uint32_t id = theContext.takeNextId();
  200. instBuilder.binaryOp(op, resultType, id, lhs, rhs).x();
  201. insertPoint->appendInstruction(std::move(constructSite));
  202. switch (op) {
  203. case spv::Op::OpImageQueryLod:
  204. case spv::Op::OpImageQuerySizeLod:
  205. requireCapability(spv::Capability::ImageQuery);
  206. break;
  207. default:
  208. // Only checking for ImageQueries, the other Ops can be ignored.
  209. break;
  210. }
  211. return id;
  212. }
  213. uint32_t ModuleBuilder::createSpecConstantBinaryOp(spv::Op op,
  214. uint32_t resultType,
  215. uint32_t lhs, uint32_t rhs) {
  216. const uint32_t id = theContext.takeNextId();
  217. instBuilder.specConstantBinaryOp(op, resultType, id, lhs, rhs).x();
  218. theModule.addVariable(std::move(constructSite));
  219. return id;
  220. }
  221. uint32_t ModuleBuilder::createGroupNonUniformOp(spv::Op op, uint32_t resultType,
  222. uint32_t execScope) {
  223. assert(insertPoint && "null insert point");
  224. const uint32_t id = theContext.takeNextId();
  225. instBuilder.groupNonUniformOp(op, resultType, id, execScope).x();
  226. insertPoint->appendInstruction(std::move(constructSite));
  227. return id;
  228. }
  229. uint32_t ModuleBuilder::createGroupNonUniformUnaryOp(
  230. spv::Op op, uint32_t resultType, uint32_t execScope, uint32_t operand,
  231. llvm::Optional<spv::GroupOperation> groupOp) {
  232. assert(insertPoint && "null insert point");
  233. const uint32_t id = theContext.takeNextId();
  234. instBuilder
  235. .groupNonUniformUnaryOp(op, resultType, id, execScope, groupOp, operand)
  236. .x();
  237. insertPoint->appendInstruction(std::move(constructSite));
  238. return id;
  239. }
  240. uint32_t ModuleBuilder::createGroupNonUniformBinaryOp(spv::Op op,
  241. uint32_t resultType,
  242. uint32_t execScope,
  243. uint32_t operand1,
  244. uint32_t operand2) {
  245. assert(insertPoint && "null insert point");
  246. const uint32_t id = theContext.takeNextId();
  247. instBuilder
  248. .groupNonUniformBinaryOp(op, resultType, id, execScope, operand1,
  249. operand2)
  250. .x();
  251. insertPoint->appendInstruction(std::move(constructSite));
  252. return id;
  253. }
  254. uint32_t ModuleBuilder::createAtomicOp(spv::Op opcode, uint32_t resultType,
  255. uint32_t orignalValuePtr,
  256. uint32_t scopeId,
  257. uint32_t memorySemanticsId,
  258. uint32_t valueToOp) {
  259. assert(insertPoint && "null insert point");
  260. const uint32_t id = theContext.takeNextId();
  261. switch (opcode) {
  262. case spv::Op::OpAtomicIAdd:
  263. instBuilder.opAtomicIAdd(resultType, id, orignalValuePtr, scopeId,
  264. memorySemanticsId, valueToOp);
  265. break;
  266. case spv::Op::OpAtomicISub:
  267. instBuilder.opAtomicISub(resultType, id, orignalValuePtr, scopeId,
  268. memorySemanticsId, valueToOp);
  269. break;
  270. case spv::Op::OpAtomicAnd:
  271. instBuilder.opAtomicAnd(resultType, id, orignalValuePtr, scopeId,
  272. memorySemanticsId, valueToOp);
  273. break;
  274. case spv::Op::OpAtomicOr:
  275. instBuilder.opAtomicOr(resultType, id, orignalValuePtr, scopeId,
  276. memorySemanticsId, valueToOp);
  277. break;
  278. case spv::Op::OpAtomicXor:
  279. instBuilder.opAtomicXor(resultType, id, orignalValuePtr, scopeId,
  280. memorySemanticsId, valueToOp);
  281. break;
  282. case spv::Op::OpAtomicUMax:
  283. instBuilder.opAtomicUMax(resultType, id, orignalValuePtr, scopeId,
  284. memorySemanticsId, valueToOp);
  285. break;
  286. case spv::Op::OpAtomicUMin:
  287. instBuilder.opAtomicUMin(resultType, id, orignalValuePtr, scopeId,
  288. memorySemanticsId, valueToOp);
  289. break;
  290. case spv::Op::OpAtomicSMax:
  291. instBuilder.opAtomicSMax(resultType, id, orignalValuePtr, scopeId,
  292. memorySemanticsId, valueToOp);
  293. break;
  294. case spv::Op::OpAtomicSMin:
  295. instBuilder.opAtomicSMin(resultType, id, orignalValuePtr, scopeId,
  296. memorySemanticsId, valueToOp);
  297. break;
  298. case spv::Op::OpAtomicExchange:
  299. instBuilder.opAtomicExchange(resultType, id, orignalValuePtr, scopeId,
  300. memorySemanticsId, valueToOp);
  301. break;
  302. default:
  303. assert(false && "unimplemented atomic opcode");
  304. }
  305. instBuilder.x();
  306. insertPoint->appendInstruction(std::move(constructSite));
  307. return id;
  308. }
  309. uint32_t ModuleBuilder::createAtomicCompareExchange(
  310. uint32_t resultType, uint32_t orignalValuePtr, uint32_t scopeId,
  311. uint32_t equalMemorySemanticsId, uint32_t unequalMemorySemanticsId,
  312. uint32_t valueToOp, uint32_t comparator) {
  313. assert(insertPoint && "null insert point");
  314. const uint32_t id = theContext.takeNextId();
  315. instBuilder.opAtomicCompareExchange(
  316. resultType, id, orignalValuePtr, scopeId, equalMemorySemanticsId,
  317. unequalMemorySemanticsId, valueToOp, comparator);
  318. instBuilder.x();
  319. insertPoint->appendInstruction(std::move(constructSite));
  320. return id;
  321. }
  322. spv::ImageOperandsMask ModuleBuilder::composeImageOperandsMask(
  323. uint32_t bias, uint32_t lod, const std::pair<uint32_t, uint32_t> &grad,
  324. uint32_t constOffset, uint32_t varOffset, uint32_t constOffsets,
  325. uint32_t sample, uint32_t minLod,
  326. llvm::SmallVectorImpl<uint32_t> *orderedParams) {
  327. using spv::ImageOperandsMask;
  328. // SPIR-V Image Operands from least significant bit to most significant bit
  329. // Bias, Lod, Grad, ConstOffset, Offset, ConstOffsets, Sample, MinLod
  330. auto mask = ImageOperandsMask::MaskNone;
  331. orderedParams->clear();
  332. if (bias) {
  333. mask = mask | ImageOperandsMask::Bias;
  334. orderedParams->push_back(bias);
  335. }
  336. if (lod) {
  337. mask = mask | ImageOperandsMask::Lod;
  338. orderedParams->push_back(lod);
  339. }
  340. if (grad.first && grad.second) {
  341. mask = mask | ImageOperandsMask::Grad;
  342. orderedParams->push_back(grad.first);
  343. orderedParams->push_back(grad.second);
  344. }
  345. if (constOffset) {
  346. mask = mask | ImageOperandsMask::ConstOffset;
  347. orderedParams->push_back(constOffset);
  348. }
  349. if (varOffset) {
  350. mask = mask | ImageOperandsMask::Offset;
  351. requireCapability(spv::Capability::ImageGatherExtended);
  352. orderedParams->push_back(varOffset);
  353. }
  354. if (constOffsets) {
  355. mask = mask | ImageOperandsMask::ConstOffsets;
  356. requireCapability(spv::Capability::ImageGatherExtended);
  357. orderedParams->push_back(constOffsets);
  358. }
  359. if (sample) {
  360. mask = mask | ImageOperandsMask::Sample;
  361. orderedParams->push_back(sample);
  362. }
  363. if (minLod) {
  364. requireCapability(spv::Capability::MinLod);
  365. mask = mask | ImageOperandsMask::MinLod;
  366. orderedParams->push_back(minLod);
  367. }
  368. return mask;
  369. }
  370. uint32_t
  371. ModuleBuilder::createImageSparseTexelsResident(uint32_t resident_code) {
  372. assert(insertPoint && "null insert point");
  373. // Result type must be a boolean
  374. const uint32_t result_type = getBoolType();
  375. const uint32_t id = theContext.takeNextId();
  376. instBuilder.opImageSparseTexelsResident(result_type, id, resident_code).x();
  377. insertPoint->appendInstruction(std::move(constructSite));
  378. return id;
  379. }
  380. uint32_t ModuleBuilder::createImageTexelPointer(uint32_t resultType,
  381. uint32_t imageId,
  382. uint32_t coordinate,
  383. uint32_t sample) {
  384. assert(insertPoint && "null insert point");
  385. const uint32_t id = theContext.takeNextId();
  386. instBuilder.opImageTexelPointer(resultType, id, imageId, coordinate, sample)
  387. .x();
  388. insertPoint->appendInstruction(std::move(constructSite));
  389. return id;
  390. }
  391. uint32_t ModuleBuilder::createImageSample(
  392. uint32_t texelType, uint32_t imageType, uint32_t image, uint32_t sampler,
  393. bool isNonUniform, uint32_t coordinate, uint32_t compareVal, uint32_t bias,
  394. uint32_t lod, std::pair<uint32_t, uint32_t> grad, uint32_t constOffset,
  395. uint32_t varOffset, uint32_t constOffsets, uint32_t sample, uint32_t minLod,
  396. uint32_t residencyCodeId) {
  397. assert(insertPoint && "null insert point");
  398. // The Lod and Grad image operands requires explicit-lod instructions.
  399. // Otherwise we use implicit-lod instructions.
  400. const bool isExplicit = lod || (grad.first && grad.second);
  401. const bool isSparse = (residencyCodeId != 0);
  402. // minLod is only valid with Implicit instructions and Grad instructions.
  403. // This means that we cannot have Lod and minLod together because Lod requires
  404. // explicit insturctions. So either lod or minLod or both must be zero.
  405. assert(lod == 0 || minLod == 0);
  406. uint32_t retType = texelType;
  407. if (isSparse) {
  408. requireCapability(spv::Capability::SparseResidency);
  409. retType = getSparseResidencyStructType(texelType);
  410. }
  411. // An OpSampledImage is required to do the image sampling.
  412. const uint32_t sampledImgId = theContext.takeNextId();
  413. const uint32_t sampledImgTy = getSampledImageType(imageType);
  414. instBuilder.opSampledImage(sampledImgTy, sampledImgId, image, sampler).x();
  415. insertPoint->appendInstruction(std::move(constructSite));
  416. if (isNonUniform) {
  417. // The sampled image will be used to access resource's memory, so we need
  418. // to decorate it with NonUniformEXT.
  419. decorateNonUniformEXT(sampledImgId);
  420. }
  421. uint32_t texelId = theContext.takeNextId();
  422. llvm::SmallVector<uint32_t, 4> params;
  423. const auto mask =
  424. composeImageOperandsMask(bias, lod, grad, constOffset, varOffset,
  425. constOffsets, sample, minLod, &params);
  426. instBuilder.opImageSample(retType, texelId, sampledImgId, coordinate,
  427. compareVal, mask, isExplicit, isSparse);
  428. for (const auto param : params)
  429. instBuilder.idRef(param);
  430. instBuilder.x();
  431. insertPoint->appendInstruction(std::move(constructSite));
  432. if (isSparse) {
  433. // Write the Residency Code
  434. const auto status = createCompositeExtract(getUint32Type(), texelId, {0});
  435. createStore(residencyCodeId, status);
  436. // Extract the real result from the struct
  437. texelId = createCompositeExtract(texelType, texelId, {1});
  438. }
  439. return texelId;
  440. }
  441. void ModuleBuilder::createImageWrite(QualType imageType, uint32_t imageId,
  442. uint32_t coordId, uint32_t texelId) {
  443. assert(insertPoint && "null insert point");
  444. requireCapability(
  445. TypeTranslator::getCapabilityForStorageImageReadWrite(imageType));
  446. instBuilder.opImageWrite(imageId, coordId, texelId, llvm::None).x();
  447. insertPoint->appendInstruction(std::move(constructSite));
  448. }
  449. uint32_t ModuleBuilder::createImageFetchOrRead(
  450. bool doImageFetch, uint32_t texelType, QualType imageType, uint32_t image,
  451. uint32_t coordinate, uint32_t lod, uint32_t constOffset, uint32_t varOffset,
  452. uint32_t constOffsets, uint32_t sample, uint32_t residencyCodeId) {
  453. assert(insertPoint && "null insert point");
  454. llvm::SmallVector<uint32_t, 2> params;
  455. const auto mask =
  456. llvm::Optional<spv::ImageOperandsMask>(composeImageOperandsMask(
  457. /*bias*/ 0, lod, std::make_pair(0, 0), constOffset, varOffset,
  458. constOffsets, sample, /*minLod*/ 0, &params));
  459. const bool isSparse = (residencyCodeId != 0);
  460. uint32_t retType = texelType;
  461. if (isSparse) {
  462. requireCapability(spv::Capability::SparseResidency);
  463. retType = getSparseResidencyStructType(texelType);
  464. }
  465. if (!doImageFetch) {
  466. requireCapability(
  467. TypeTranslator::getCapabilityForStorageImageReadWrite(imageType));
  468. }
  469. uint32_t texelId = theContext.takeNextId();
  470. instBuilder.opImageFetchRead(retType, texelId, image, coordinate, mask,
  471. doImageFetch, isSparse);
  472. for (const auto param : params)
  473. instBuilder.idRef(param);
  474. instBuilder.x();
  475. insertPoint->appendInstruction(std::move(constructSite));
  476. if (isSparse) {
  477. // Write the Residency Code
  478. const auto status = createCompositeExtract(getUint32Type(), texelId, {0});
  479. createStore(residencyCodeId, status);
  480. // Extract the real result from the struct
  481. texelId = createCompositeExtract(texelType, texelId, {1});
  482. }
  483. return texelId;
  484. }
  485. uint32_t ModuleBuilder::createImageGather(
  486. uint32_t texelType, uint32_t imageType, uint32_t image, uint32_t sampler,
  487. bool isNonUniform, uint32_t coordinate, uint32_t component,
  488. uint32_t compareVal, uint32_t constOffset, uint32_t varOffset,
  489. uint32_t constOffsets, uint32_t sample, uint32_t residencyCodeId) {
  490. assert(insertPoint && "null insert point");
  491. uint32_t sparseRetType = 0;
  492. if (residencyCodeId) {
  493. requireCapability(spv::Capability::SparseResidency);
  494. sparseRetType = getSparseResidencyStructType(texelType);
  495. }
  496. // An OpSampledImage is required to do the image sampling.
  497. const uint32_t sampledImgId = theContext.takeNextId();
  498. const uint32_t sampledImgTy = getSampledImageType(imageType);
  499. instBuilder.opSampledImage(sampledImgTy, sampledImgId, image, sampler).x();
  500. insertPoint->appendInstruction(std::move(constructSite));
  501. if (isNonUniform) {
  502. // The sampled image will be used to access resource's memory, so we need
  503. // to decorate it with NonUniformEXT.
  504. decorateNonUniformEXT(sampledImgId);
  505. }
  506. llvm::SmallVector<uint32_t, 2> params;
  507. // TODO: Update ImageGather to accept minLod if necessary.
  508. const auto mask =
  509. llvm::Optional<spv::ImageOperandsMask>(composeImageOperandsMask(
  510. /*bias*/ 0, /*lod*/ 0, std::make_pair(0, 0), constOffset, varOffset,
  511. constOffsets, sample, /*minLod*/ 0, &params));
  512. uint32_t texelId = theContext.takeNextId();
  513. if (compareVal) {
  514. if (residencyCodeId) {
  515. // Note: OpImageSparseDrefGather does not take the component parameter.
  516. instBuilder.opImageSparseDrefGather(sparseRetType, texelId, sampledImgId,
  517. coordinate, compareVal, mask);
  518. } else {
  519. // Note: OpImageDrefGather does not take the component parameter.
  520. instBuilder.opImageDrefGather(texelType, texelId, sampledImgId,
  521. coordinate, compareVal, mask);
  522. }
  523. } else {
  524. if (residencyCodeId) {
  525. instBuilder.opImageSparseGather(sparseRetType, texelId, sampledImgId,
  526. coordinate, component, mask);
  527. } else {
  528. instBuilder.opImageGather(texelType, texelId, sampledImgId, coordinate,
  529. component, mask);
  530. }
  531. }
  532. for (const auto param : params)
  533. instBuilder.idRef(param);
  534. instBuilder.x();
  535. insertPoint->appendInstruction(std::move(constructSite));
  536. if (residencyCodeId) {
  537. // Write the Residency Code
  538. const auto status = createCompositeExtract(getUint32Type(), texelId, {0});
  539. createStore(residencyCodeId, status);
  540. // Extract the real result from the struct
  541. texelId = createCompositeExtract(texelType, texelId, {1});
  542. }
  543. return texelId;
  544. }
  545. uint32_t ModuleBuilder::createSelect(uint32_t resultType, uint32_t condition,
  546. uint32_t trueValue, uint32_t falseValue) {
  547. assert(insertPoint && "null insert point");
  548. const uint32_t id = theContext.takeNextId();
  549. instBuilder.opSelect(resultType, id, condition, trueValue, falseValue).x();
  550. insertPoint->appendInstruction(std::move(constructSite));
  551. return id;
  552. }
  553. void ModuleBuilder::createSwitch(
  554. uint32_t mergeLabel, uint32_t selector, uint32_t defaultLabel,
  555. llvm::ArrayRef<std::pair<uint32_t, uint32_t>> target) {
  556. assert(insertPoint && "null insert point");
  557. // Create the OpSelectioMerege.
  558. instBuilder.opSelectionMerge(mergeLabel, spv::SelectionControlMask::MaskNone)
  559. .x();
  560. insertPoint->appendInstruction(std::move(constructSite));
  561. // Create the OpSwitch.
  562. instBuilder.opSwitch(selector, defaultLabel, target).x();
  563. insertPoint->appendInstruction(std::move(constructSite));
  564. }
  565. void ModuleBuilder::createKill() {
  566. assert(insertPoint && "null insert point");
  567. assert(!isCurrentBasicBlockTerminated());
  568. instBuilder.opKill().x();
  569. insertPoint->appendInstruction(std::move(constructSite));
  570. }
  571. void ModuleBuilder::createBranch(uint32_t targetLabel, uint32_t mergeBB,
  572. uint32_t continueBB,
  573. spv::LoopControlMask loopControl) {
  574. assert(insertPoint && "null insert point");
  575. if (mergeBB && continueBB) {
  576. instBuilder.opLoopMerge(mergeBB, continueBB, loopControl).x();
  577. insertPoint->appendInstruction(std::move(constructSite));
  578. }
  579. instBuilder.opBranch(targetLabel).x();
  580. insertPoint->appendInstruction(std::move(constructSite));
  581. }
  582. void ModuleBuilder::createConditionalBranch(
  583. uint32_t condition, uint32_t trueLabel, uint32_t falseLabel,
  584. uint32_t mergeLabel, uint32_t continueLabel,
  585. spv::SelectionControlMask selectionControl,
  586. spv::LoopControlMask loopControl) {
  587. assert(insertPoint && "null insert point");
  588. if (mergeLabel) {
  589. if (continueLabel) {
  590. instBuilder.opLoopMerge(mergeLabel, continueLabel, loopControl).x();
  591. insertPoint->appendInstruction(std::move(constructSite));
  592. } else {
  593. instBuilder.opSelectionMerge(mergeLabel, selectionControl).x();
  594. insertPoint->appendInstruction(std::move(constructSite));
  595. }
  596. }
  597. instBuilder.opBranchConditional(condition, trueLabel, falseLabel, {}).x();
  598. insertPoint->appendInstruction(std::move(constructSite));
  599. }
  600. void ModuleBuilder::createReturn() {
  601. assert(insertPoint && "null insert point");
  602. instBuilder.opReturn().x();
  603. insertPoint->appendInstruction(std::move(constructSite));
  604. }
  605. void ModuleBuilder::createReturnValue(uint32_t value) {
  606. assert(insertPoint && "null insert point");
  607. instBuilder.opReturnValue(value).x();
  608. insertPoint->appendInstruction(std::move(constructSite));
  609. }
  610. uint32_t ModuleBuilder::createExtInst(uint32_t resultType, uint32_t setId,
  611. uint32_t instId,
  612. llvm::ArrayRef<uint32_t> operands) {
  613. assert(insertPoint && "null insert point");
  614. uint32_t resultId = theContext.takeNextId();
  615. instBuilder.opExtInst(resultType, resultId, setId, instId, operands).x();
  616. insertPoint->appendInstruction(std::move(constructSite));
  617. return resultId;
  618. }
  619. void ModuleBuilder::createBarrier(uint32_t execution, uint32_t memory,
  620. uint32_t semantics) {
  621. assert(insertPoint && "null insert point");
  622. if (execution)
  623. instBuilder.opControlBarrier(execution, memory, semantics).x();
  624. else
  625. instBuilder.opMemoryBarrier(memory, semantics).x();
  626. insertPoint->appendInstruction(std::move(constructSite));
  627. }
  628. uint32_t ModuleBuilder::createBitFieldExtract(uint32_t resultType,
  629. uint32_t base, uint32_t offset,
  630. uint32_t count, bool isSigned) {
  631. assert(insertPoint && "null insert point");
  632. uint32_t resultId = theContext.takeNextId();
  633. if (isSigned)
  634. instBuilder.opBitFieldSExtract(resultType, resultId, base, offset, count);
  635. else
  636. instBuilder.opBitFieldUExtract(resultType, resultId, base, offset, count);
  637. instBuilder.x();
  638. insertPoint->appendInstruction(std::move(constructSite));
  639. return resultId;
  640. }
  641. uint32_t ModuleBuilder::createBitFieldInsert(uint32_t resultType, uint32_t base,
  642. uint32_t insert, uint32_t offset,
  643. uint32_t count) {
  644. assert(insertPoint && "null insert point");
  645. uint32_t resultId = theContext.takeNextId();
  646. instBuilder
  647. .opBitFieldInsert(resultType, resultId, base, insert, offset, count)
  648. .x();
  649. insertPoint->appendInstruction(std::move(constructSite));
  650. return resultId;
  651. }
  652. void ModuleBuilder::createEmitVertex() {
  653. assert(insertPoint && "null insert point");
  654. instBuilder.opEmitVertex().x();
  655. insertPoint->appendInstruction(std::move(constructSite));
  656. }
  657. void ModuleBuilder::createEndPrimitive() {
  658. assert(insertPoint && "null insert point");
  659. instBuilder.opEndPrimitive().x();
  660. insertPoint->appendInstruction(std::move(constructSite));
  661. }
  662. void ModuleBuilder::addExecutionMode(uint32_t entryPointId,
  663. spv::ExecutionMode em,
  664. llvm::ArrayRef<uint32_t> params) {
  665. instBuilder.opExecutionMode(entryPointId, em);
  666. for (const auto &param : params) {
  667. instBuilder.literalInteger(param);
  668. }
  669. instBuilder.x();
  670. theModule.addExecutionMode(std::move(constructSite));
  671. }
  672. void ModuleBuilder::addExtension(Extension ext, llvm::StringRef target,
  673. SourceLocation srcLoc) {
  674. assert(featureManager);
  675. featureManager->requestExtension(ext, target, srcLoc);
  676. // Do not emit OpExtension if the given extension is natively supported in the
  677. // target environment.
  678. if (featureManager->isExtensionRequiredForTargetEnv(ext))
  679. theModule.addExtension(featureManager->getExtensionName(ext));
  680. }
  681. uint32_t ModuleBuilder::getGLSLExtInstSet() {
  682. if (glslExtSetId == 0) {
  683. glslExtSetId = theContext.takeNextId();
  684. theModule.addExtInstSet(glslExtSetId, "GLSL.std.450");
  685. }
  686. return glslExtSetId;
  687. }
  688. uint32_t ModuleBuilder::addStageIOVar(uint32_t type,
  689. spv::StorageClass storageClass,
  690. std::string name) {
  691. const uint32_t pointerType = getPointerType(type, storageClass);
  692. const uint32_t varId = theContext.takeNextId();
  693. instBuilder.opVariable(pointerType, varId, storageClass, llvm::None).x();
  694. theModule.addVariable(std::move(constructSite));
  695. theModule.addDebugName(varId, name);
  696. return varId;
  697. }
  698. uint32_t ModuleBuilder::addStageBuiltinVar(uint32_t type, spv::StorageClass sc,
  699. spv::BuiltIn builtin) {
  700. const uint32_t pointerType = getPointerType(type, sc);
  701. const uint32_t varId = theContext.takeNextId();
  702. instBuilder.opVariable(pointerType, varId, sc, llvm::None).x();
  703. theModule.addVariable(std::move(constructSite));
  704. // Decorate with the specified Builtin
  705. const Decoration *d = Decoration::getBuiltIn(theContext, builtin);
  706. theModule.addDecoration(d, varId);
  707. return varId;
  708. }
  709. uint32_t ModuleBuilder::addModuleVar(uint32_t type, spv::StorageClass sc,
  710. llvm::StringRef name,
  711. llvm::Optional<uint32_t> init) {
  712. assert(sc != spv::StorageClass::Function);
  713. // TODO: basically duplicated code of addFileVar()
  714. const uint32_t pointerType = getPointerType(type, sc);
  715. const uint32_t varId = theContext.takeNextId();
  716. instBuilder.opVariable(pointerType, varId, sc, init).x();
  717. theModule.addVariable(std::move(constructSite));
  718. theModule.addDebugName(varId, name);
  719. return varId;
  720. }
  721. void ModuleBuilder::decorateDSetBinding(uint32_t targetId, uint32_t setNumber,
  722. uint32_t bindingNumber) {
  723. const auto *d = Decoration::getDescriptorSet(theContext, setNumber);
  724. theModule.addDecoration(d, targetId);
  725. d = Decoration::getBinding(theContext, bindingNumber);
  726. theModule.addDecoration(d, targetId);
  727. }
  728. void ModuleBuilder::decorateInputAttachmentIndex(uint32_t targetId,
  729. uint32_t indexNumber) {
  730. const auto *d = Decoration::getInputAttachmentIndex(theContext, indexNumber);
  731. theModule.addDecoration(d, targetId);
  732. }
  733. void ModuleBuilder::decorateCounterBufferId(uint32_t mainBufferId,
  734. uint32_t counterBufferId) {
  735. if (allowReflect) {
  736. addExtension(Extension::GOOGLE_hlsl_functionality1, "SPIR-V reflection",
  737. {});
  738. theModule.addDecoration(
  739. Decoration::getHlslCounterBufferGOOGLE(theContext, counterBufferId),
  740. mainBufferId);
  741. }
  742. }
  743. void ModuleBuilder::decorateHlslSemantic(uint32_t targetId,
  744. llvm::StringRef semantic,
  745. llvm::Optional<uint32_t> memberIdx) {
  746. if (allowReflect) {
  747. addExtension(Extension::GOOGLE_hlsl_functionality1, "SPIR-V reflection",
  748. {});
  749. theModule.addDecoration(
  750. Decoration::getHlslSemanticGOOGLE(theContext, semantic, memberIdx),
  751. targetId);
  752. }
  753. }
  754. void ModuleBuilder::decorateLocation(uint32_t targetId, uint32_t location) {
  755. const Decoration *d =
  756. Decoration::getLocation(theContext, location, llvm::None);
  757. theModule.addDecoration(d, targetId);
  758. }
  759. void ModuleBuilder::decorateIndex(uint32_t targetId, uint32_t index) {
  760. const Decoration *d = Decoration::getIndex(theContext, index);
  761. theModule.addDecoration(d, targetId);
  762. }
  763. void ModuleBuilder::decorateSpecId(uint32_t targetId, uint32_t specId) {
  764. const Decoration *d = Decoration::getSpecId(theContext, specId);
  765. theModule.addDecoration(d, targetId);
  766. }
  767. void ModuleBuilder::decorateCentroid(uint32_t targetId) {
  768. const Decoration *d = Decoration::getCentroid(theContext);
  769. theModule.addDecoration(d, targetId);
  770. }
  771. void ModuleBuilder::decorateFlat(uint32_t targetId) {
  772. const Decoration *d = Decoration::getFlat(theContext);
  773. theModule.addDecoration(d, targetId);
  774. }
  775. void ModuleBuilder::decorateNoPerspective(uint32_t targetId) {
  776. const Decoration *d = Decoration::getNoPerspective(theContext);
  777. theModule.addDecoration(d, targetId);
  778. }
  779. void ModuleBuilder::decorateSample(uint32_t targetId) {
  780. const Decoration *d = Decoration::getSample(theContext);
  781. theModule.addDecoration(d, targetId);
  782. }
  783. void ModuleBuilder::decorateBlock(uint32_t targetId) {
  784. const Decoration *d = Decoration::getBlock(theContext);
  785. theModule.addDecoration(d, targetId);
  786. }
  787. void ModuleBuilder::decorateRelaxedPrecision(uint32_t targetId) {
  788. const Decoration *d = Decoration::getRelaxedPrecision(theContext);
  789. theModule.addDecoration(d, targetId);
  790. }
  791. void ModuleBuilder::decoratePatch(uint32_t targetId) {
  792. const Decoration *d = Decoration::getPatch(theContext);
  793. theModule.addDecoration(d, targetId);
  794. }
  795. void ModuleBuilder::decorateNonUniformEXT(uint32_t targetId) {
  796. const Decoration *d = Decoration::getNonUniformEXT(theContext);
  797. theModule.addDecoration(d, targetId);
  798. }
  799. #define IMPL_GET_PRIMITIVE_TYPE(ty) \
  800. \
  801. uint32_t ModuleBuilder::get##ty##Type() { \
  802. const Type *type = Type::get##ty(theContext); \
  803. const uint32_t typeId = theContext.getResultIdForType(type); \
  804. theModule.addType(type, typeId); \
  805. return typeId; \
  806. }
  807. IMPL_GET_PRIMITIVE_TYPE(Void)
  808. IMPL_GET_PRIMITIVE_TYPE(Bool)
  809. IMPL_GET_PRIMITIVE_TYPE(Int32)
  810. IMPL_GET_PRIMITIVE_TYPE(Uint32)
  811. IMPL_GET_PRIMITIVE_TYPE(Float32)
  812. #undef IMPL_GET_PRIMITIVE_TYPE
// Note: At the moment, the Float16 capability should not be added for Vulkan
// 1.0. It is not a required capability, and adding the
// SPV_AMD_gpu_shader_half_float extension does not enable this capability.
// Any driver that supports float16 in Vulkan 1.0 should accept this extension.
  817. #define IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(ty, cap) \
  818. \
  819. uint32_t ModuleBuilder::get##ty##Type() { \
  820. if (spv::Capability::cap == spv::Capability::Float16) \
  821. addExtension(Extension::AMD_gpu_shader_half_float, "16-bit float", {}); \
  822. else \
  823. requireCapability(spv::Capability::cap); \
  824. const Type *type = Type::get##ty(theContext); \
  825. const uint32_t typeId = theContext.getResultIdForType(type); \
  826. theModule.addType(type, typeId); \
  827. return typeId; \
  828. }
  829. IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Int64, Int64)
  830. IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Uint64, Int64)
  831. IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Float64, Float64)
  832. IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Int16, Int16)
  833. IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Uint16, Int16)
  834. IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Float16, Float16)
  835. #undef IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY
  836. uint32_t ModuleBuilder::getVecType(uint32_t elemType, uint32_t elemCount) {
  837. const Type *type = nullptr;
  838. switch (elemCount) {
  839. case 2:
  840. type = Type::getVec2(theContext, elemType);
  841. break;
  842. case 3:
  843. type = Type::getVec3(theContext, elemType);
  844. break;
  845. case 4:
  846. type = Type::getVec4(theContext, elemType);
  847. break;
  848. default:
  849. assert(false && "unhandled vector size");
  850. // Error found. Return 0 as the <result-id> directly.
  851. return 0;
  852. }
  853. const uint32_t typeId = theContext.getResultIdForType(type);
  854. theModule.addType(type, typeId);
  855. return typeId;
  856. }
  857. uint32_t ModuleBuilder::getMatType(QualType elemType, uint32_t colType,
  858. uint32_t colCount,
  859. Type::DecorationSet decorations) {
  860. // NOTE: According to Item "Data rules" of SPIR-V Spec 2.16.1 "Universal
  861. // Validation Rules":
  862. // Matrix types can only be parameterized with floating-point types.
  863. //
  864. // So we need special handling of non-fp matrices. We emulate non-fp
  865. // matrices as an array of vectors.
  866. if (!elemType->isFloatingType())
  867. return getArrayType(colType, getConstantUint32(colCount), decorations);
  868. const Type *type = Type::getMatrix(theContext, colType, colCount);
  869. const uint32_t typeId = theContext.getResultIdForType(type);
  870. theModule.addType(type, typeId);
  871. return typeId;
  872. }
  873. uint32_t ModuleBuilder::getPointerType(uint32_t pointeeType,
  874. spv::StorageClass storageClass) {
  875. const Type *type = Type::getPointer(theContext, storageClass, pointeeType);
  876. const uint32_t typeId = theContext.getResultIdForType(type);
  877. theModule.addType(type, typeId);
  878. return typeId;
  879. }
  880. uint32_t
  881. ModuleBuilder::getStructType(llvm::ArrayRef<uint32_t> fieldTypes,
  882. llvm::StringRef structName,
  883. llvm::ArrayRef<llvm::StringRef> fieldNames,
  884. Type::DecorationSet decorations) {
  885. const Type *type = Type::getStruct(theContext, fieldTypes, decorations);
  886. bool isRegistered = false;
  887. const uint32_t typeId = theContext.getResultIdForType(type, &isRegistered);
  888. theModule.addType(type, typeId);
  889. if (!isRegistered) {
  890. theModule.addDebugName(typeId, structName);
  891. if (!fieldNames.empty()) {
  892. assert(fieldNames.size() == fieldTypes.size());
  893. for (uint32_t i = 0; i < fieldNames.size(); ++i)
  894. theModule.addDebugName(typeId, fieldNames[i],
  895. llvm::Optional<uint32_t>(i));
  896. }
  897. }
  898. return typeId;
  899. }
  900. uint32_t ModuleBuilder::getSparseResidencyStructType(uint32_t type) {
  901. const auto uintType = getUint32Type();
  902. return getStructType({uintType, type}, "SparseResidencyStruct",
  903. {"Residency.Code", "Result.Type"});
  904. }
  905. uint32_t ModuleBuilder::getArrayType(uint32_t elemType, uint32_t count,
  906. Type::DecorationSet decorations) {
  907. const Type *type = Type::getArray(theContext, elemType, count, decorations);
  908. const uint32_t typeId = theContext.getResultIdForType(type);
  909. theModule.addType(type, typeId);
  910. return typeId;
  911. }
  912. uint32_t ModuleBuilder::getRuntimeArrayType(uint32_t elemType,
  913. Type::DecorationSet decorations) {
  914. const Type *type = Type::getRuntimeArray(theContext, elemType, decorations);
  915. const uint32_t typeId = theContext.getResultIdForType(type);
  916. theModule.addType(type, typeId);
  917. return typeId;
  918. }
  919. uint32_t ModuleBuilder::getFunctionType(uint32_t returnType,
  920. llvm::ArrayRef<uint32_t> paramTypes) {
  921. const Type *type = Type::getFunction(theContext, returnType, paramTypes);
  922. const uint32_t typeId = theContext.getResultIdForType(type);
  923. theModule.addType(type, typeId);
  924. return typeId;
  925. }
  926. uint32_t ModuleBuilder::getImageType(uint32_t sampledType, spv::Dim dim,
  927. uint32_t depth, bool isArray, uint32_t ms,
  928. uint32_t sampled,
  929. spv::ImageFormat format) {
  930. const Type *type = Type::getImage(theContext, sampledType, dim, depth,
  931. isArray, ms, sampled, format);
  932. bool isRegistered = false;
  933. const uint32_t typeId = theContext.getResultIdForType(type, &isRegistered);
  934. theModule.addType(type, typeId);
  935. switch (format) {
  936. case spv::ImageFormat::Rg32f:
  937. case spv::ImageFormat::Rg16f:
  938. case spv::ImageFormat::R11fG11fB10f:
  939. case spv::ImageFormat::R16f:
  940. case spv::ImageFormat::Rgba16:
  941. case spv::ImageFormat::Rgb10A2:
  942. case spv::ImageFormat::Rg16:
  943. case spv::ImageFormat::Rg8:
  944. case spv::ImageFormat::R16:
  945. case spv::ImageFormat::R8:
  946. case spv::ImageFormat::Rgba16Snorm:
  947. case spv::ImageFormat::Rg16Snorm:
  948. case spv::ImageFormat::Rg8Snorm:
  949. case spv::ImageFormat::R16Snorm:
  950. case spv::ImageFormat::R8Snorm:
  951. case spv::ImageFormat::Rg32i:
  952. case spv::ImageFormat::Rg16i:
  953. case spv::ImageFormat::Rg8i:
  954. case spv::ImageFormat::R16i:
  955. case spv::ImageFormat::R8i:
  956. case spv::ImageFormat::Rgb10a2ui:
  957. case spv::ImageFormat::Rg32ui:
  958. case spv::ImageFormat::Rg16ui:
  959. case spv::ImageFormat::Rg8ui:
  960. case spv::ImageFormat::R16ui:
  961. case spv::ImageFormat::R8ui:
  962. requireCapability(spv::Capability::StorageImageExtendedFormats);
  963. break;
  964. default:
  965. // Only image formats requiring extended formats are relevant. The rest just pass through.
  966. break;
  967. }
  968. if (dim == spv::Dim::Dim1D) {
  969. if (sampled == 2u) {
  970. requireCapability(spv::Capability::Image1D);
  971. } else {
  972. requireCapability(spv::Capability::Sampled1D);
  973. }
  974. } else if (dim == spv::Dim::Buffer) {
  975. requireCapability(spv::Capability::SampledBuffer);
  976. } else if (dim == spv::Dim::SubpassData) {
  977. requireCapability(spv::Capability::InputAttachment);
  978. }
  979. if (isArray && ms) {
  980. requireCapability(spv::Capability::ImageMSArray);
  981. }
  982. // Skip constructing the debug name if we have already done it before.
  983. if (!isRegistered) {
  984. const char *dimStr = "";
  985. switch (dim) {
  986. case spv::Dim::Dim1D:
  987. dimStr = "1d.";
  988. break;
  989. case spv::Dim::Dim2D:
  990. dimStr = "2d.";
  991. break;
  992. case spv::Dim::Dim3D:
  993. dimStr = "3d.";
  994. break;
  995. case spv::Dim::Cube:
  996. dimStr = "cube.";
  997. break;
  998. case spv::Dim::Rect:
  999. dimStr = "rect.";
  1000. break;
  1001. case spv::Dim::Buffer:
  1002. dimStr = "buffer.";
  1003. break;
  1004. case spv::Dim::SubpassData:
  1005. dimStr = "subpass.";
  1006. break;
  1007. default:
  1008. break;
  1009. }
  1010. std::string name =
  1011. std::string("type.") + dimStr + "image" + (isArray ? ".array" : "");
  1012. theModule.addDebugName(typeId, name);
  1013. }
  1014. return typeId;
  1015. }
  1016. uint32_t ModuleBuilder::getSamplerType() {
  1017. const Type *type = Type::getSampler(theContext);
  1018. const uint32_t typeId = theContext.getResultIdForType(type);
  1019. theModule.addType(type, typeId);
  1020. theModule.addDebugName(typeId, "type.sampler");
  1021. return typeId;
  1022. }
  1023. uint32_t ModuleBuilder::getSampledImageType(uint32_t imageType) {
  1024. const Type *type = Type::getSampledImage(theContext, imageType);
  1025. const uint32_t typeId = theContext.getResultIdForType(type);
  1026. theModule.addType(type, typeId);
  1027. theModule.addDebugName(typeId, "type.sampled.image");
  1028. return typeId;
  1029. }
  1030. uint32_t ModuleBuilder::getByteAddressBufferType(bool isRW) {
  1031. // Create a uint RuntimeArray with Array Stride of 4.
  1032. const uint32_t uintType = getUint32Type();
  1033. const auto *arrStride4 = Decoration::getArrayStride(theContext, 4u);
  1034. const Type *raType =
  1035. Type::getRuntimeArray(theContext, uintType, {arrStride4});
  1036. const uint32_t raTypeId = theContext.getResultIdForType(raType);
  1037. theModule.addType(raType, raTypeId);
  1038. // Create a struct containing the runtime array as its only member.
  1039. // The struct must also be decorated as BufferBlock. The offset decoration
  1040. // should also be applied to the first (only) member. NonWritable decoration
  1041. // should also be applied to the first member if isRW is true.
  1042. llvm::SmallVector<const Decoration *, 3> typeDecs;
  1043. typeDecs.push_back(Decoration::getBufferBlock(theContext));
  1044. typeDecs.push_back(Decoration::getOffset(theContext, 0, 0));
  1045. if (!isRW)
  1046. typeDecs.push_back(Decoration::getNonWritable(theContext, 0));
  1047. const Type *type = Type::getStruct(theContext, {raTypeId}, typeDecs);
  1048. const uint32_t typeId = theContext.getResultIdForType(type);
  1049. theModule.addType(type, typeId);
  1050. theModule.addDebugName(typeId, isRW ? "type.RWByteAddressBuffer"
  1051. : "type.ByteAddressBuffer");
  1052. return typeId;
  1053. }
  1054. uint32_t ModuleBuilder::getConstantBool(bool value, bool isSpecConst) {
  1055. if (isSpecConst) {
  1056. const uint32_t constId = theContext.takeNextId();
  1057. if (value) {
  1058. instBuilder.opSpecConstantTrue(getBoolType(), constId).x();
  1059. } else {
  1060. instBuilder.opSpecConstantFalse(getBoolType(), constId).x();
  1061. }
  1062. theModule.addVariable(std::move(constructSite));
  1063. return constId;
  1064. }
  1065. const uint32_t typeId = getBoolType();
  1066. const Constant *constant = value ? Constant::getTrue(theContext, typeId)
  1067. : Constant::getFalse(theContext, typeId);
  1068. const uint32_t constId = theContext.getResultIdForConstant(constant);
  1069. theModule.addConstant(constant, constId);
  1070. return constId;
  1071. }
  1072. #define IMPL_GET_PRIMITIVE_CONST(builderTy, cppTy) \
  1073. \
  1074. uint32_t ModuleBuilder::getConstant##builderTy(cppTy value) { \
  1075. const uint32_t typeId = get##builderTy##Type(); \
  1076. const Constant *constant = \
  1077. Constant::get##builderTy(theContext, typeId, value); \
  1078. const uint32_t constId = theContext.getResultIdForConstant(constant); \
  1079. theModule.addConstant(constant, constId); \
  1080. return constId; \
  1081. }
  1082. #define IMPL_GET_PRIMITIVE_CONST_SPEC_CONST(builderTy, cppTy) \
  1083. \
  1084. uint32_t ModuleBuilder::getConstant##builderTy(cppTy value, \
  1085. bool isSpecConst) { \
  1086. if (isSpecConst) { \
  1087. const uint32_t constId = theContext.takeNextId(); \
  1088. instBuilder \
  1089. .opSpecConstant(get##builderTy##Type(), constId, \
  1090. cast::BitwiseCast<uint32_t>(value)) \
  1091. .x(); \
  1092. theModule.addVariable(std::move(constructSite)); \
  1093. return constId; \
  1094. } \
  1095. \
  1096. const uint32_t typeId = get##builderTy##Type(); \
  1097. const Constant *constant = \
  1098. Constant::get##builderTy(theContext, typeId, value); \
  1099. const uint32_t constId = theContext.getResultIdForConstant(constant); \
  1100. theModule.addConstant(constant, constId); \
  1101. return constId; \
  1102. }
  1103. IMPL_GET_PRIMITIVE_CONST(Int16, int16_t)
  1104. IMPL_GET_PRIMITIVE_CONST_SPEC_CONST(Int32, int32_t)
  1105. IMPL_GET_PRIMITIVE_CONST(Uint16, uint16_t)
  1106. IMPL_GET_PRIMITIVE_CONST_SPEC_CONST(Uint32, uint32_t)
  1107. IMPL_GET_PRIMITIVE_CONST(Float16, int16_t)
  1108. IMPL_GET_PRIMITIVE_CONST_SPEC_CONST(Float32, float)
  1109. IMPL_GET_PRIMITIVE_CONST(Float64, double)
  1110. IMPL_GET_PRIMITIVE_CONST(Int64, int64_t)
  1111. IMPL_GET_PRIMITIVE_CONST(Uint64, uint64_t)
  1112. #undef IMPL_GET_PRIMITIVE_CONST
  1113. #undef IMPL_GET_PRIMITIVE_CONST_SPEC_CONST
  1114. uint32_t
  1115. ModuleBuilder::getConstantComposite(uint32_t typeId,
  1116. llvm::ArrayRef<uint32_t> constituents) {
  1117. const Constant *constant =
  1118. Constant::getComposite(theContext, typeId, constituents);
  1119. const uint32_t constId = theContext.getResultIdForConstant(constant);
  1120. theModule.addConstant(constant, constId);
  1121. return constId;
  1122. }
  1123. uint32_t ModuleBuilder::getConstantNull(uint32_t typeId) {
  1124. const Constant *constant = Constant::getNull(theContext, typeId);
  1125. const uint32_t constId = theContext.getResultIdForConstant(constant);
  1126. theModule.addConstant(constant, constId);
  1127. return constId;
  1128. }
  1129. BasicBlock *ModuleBuilder::getBasicBlock(uint32_t labelId) {
  1130. auto it = basicBlocks.find(labelId);
  1131. if (it == basicBlocks.end()) {
  1132. assert(false && "invalid <label-id>");
  1133. return nullptr;
  1134. }
  1135. return it->second.get();
  1136. }
  1137. } // end namespace spirv
  1138. } // end namespace clang