//===---- RawBufferMethods.cpp ---- Raw Buffer Methods ----------*- C++ -*-===//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//===----------------------------------------------------------------------===//

#include "RawBufferMethods.h"
#include "AlignmentSizeCalculator.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/CharUnits.h"
#include "clang/AST/RecordLayout.h"
#include "clang/SPIRV/AstTypeProbe.h"
#include "clang/SPIRV/SpirvBuilder.h"
#include "clang/SPIRV/SpirvInstruction.h"

namespace {

/// Rounds the given value up to the given power of 2.
inline uint32_t roundToPow2(uint32_t val, uint32_t pow2) {
  assert(pow2 != 0);
  return (val + pow2 - 1) & ~(pow2 - 1);
}

} // anonymous namespace

namespace clang {
namespace spirv {
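
// Casts |instr| from |fromType| to |toType|. Boolean conversions go through
// the regular cast path (booleans have no physical representation in a raw
// buffer); every other numerical conversion is performed with an OpBitcast.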
SpirvInstruction *
RawBufferHandler::bitCastToNumericalOrBool(SpirvInstruction *instr,
                                           QualType fromType, QualType toType,
                                           SourceLocation loc) {
  if (isSameType(astContext, fromType, toType))
    return instr;

  if (toType->isBooleanType() || fromType->isBooleanType())
    return theEmitter.castToType(instr, fromType, toType, loc);

  // Perform a bitcast
  return spvBuilder.createUnaryOp(spv::Op::OpBitcast, toType, instr, loc);
}
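
// Loads a 16-bit value that starts at bit 0 of the 32-bit word addressed by
// |index|: the word is read as a uint, truncated to its low 16 bits, and
// bitcast to |target16BitType|. |bitOffset| is advanced to 16 so that the
// next 16-bit load reads the upper half of the same word.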
SpirvInstruction *RawBufferHandler::load16BitsAtBitOffset0(
    SpirvInstruction *buffer, SpirvInstruction *&index,
    QualType target16BitType, uint32_t &bitOffset) {
  assert(bitOffset == 0);
  const auto loc = buffer->getSourceLocation();
  SpirvInstruction *result = nullptr;
  auto *constUint0 =
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0));

  // The underlying element type of the ByteAddressBuffer is uint. So we
  // need to load 32-bits at the very least.
  auto *loadPtr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, buffer,
                                               {constUint0, index}, loc);
  result = spvBuilder.createLoad(astContext.UnsignedIntTy, loadPtr, loc);

  // Only need to mask the lowest 16 bits of the loaded 32-bit uint.
  // OpUConvert can perform truncation in this case.
  result = spvBuilder.createUnaryOp(spv::Op::OpUConvert,
                                    astContext.UnsignedShortTy, result, loc);
  result = bitCastToNumericalOrBool(result, astContext.UnsignedShortTy,
                                    target16BitType, loc);
  result->setRValue();

  // Now that a 16-bit load at bit-offset 0 has been performed, the next load
  // should be done at *the same base index* at bit-offset 16.
  bitOffset = (bitOffset + 16) % 32;
  return result;
}
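
// Loads a 32-bit value from the word addressed by |index| and bitcasts it to
// |target32BitType|. |index| is advanced to the next word; |bitOffset| stays
// at 0.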
SpirvInstruction *RawBufferHandler::load32BitsAtBitOffset0(
    SpirvInstruction *buffer, SpirvInstruction *&index,
    QualType target32BitType, uint32_t &bitOffset) {
  assert(bitOffset == 0);
  const auto loc = buffer->getSourceLocation();
  SpirvInstruction *result = nullptr;

  // Only need to perform one 32-bit uint load.
  auto *constUint0 =
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0));
  auto *constUint1 =
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 1));

  auto *loadPtr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, buffer,
                                               {constUint0, index}, loc);
  result = spvBuilder.createLoad(astContext.UnsignedIntTy, loadPtr, loc);
  result = bitCastToNumericalOrBool(result, astContext.UnsignedIntTy,
                                    target32BitType, loc);
  result->setRValue();

  // Now that a 32-bit load at bit-offset 0 has been performed, the next load
  // should be done at *the next base index* at bit-offset 0.
  bitOffset = (bitOffset + 32) % 32;
  index = spvBuilder.createBinaryOp(spv::Op::OpIAdd, astContext.UnsignedIntTy,
                                    index, constUint1, loc);
  return result;
}
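
// Loads a 64-bit value starting at the word addressed by |index|: two 32-bit
// words are read, zero-extended to 64 bits, and combined (the second word
// shifted left by 32 bits) before being bitcast to |target64BitType|.
// |index| is advanced by two words.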
SpirvInstruction *RawBufferHandler::load64BitsAtBitOffset0(
    SpirvInstruction *buffer, SpirvInstruction *&index,
    QualType target64BitType, uint32_t &bitOffset) {
  assert(bitOffset == 0);
  const auto loc = buffer->getSourceLocation();
  SpirvInstruction *result = nullptr;
  SpirvInstruction *ptr = nullptr;
  auto *constUint0 =
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0));
  auto *constUint1 =
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 1));
  auto *constUint32 =
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 32));

  // Need to perform two 32-bit uint loads and construct a 64-bit value.

  // Load the first 32-bit uint (word0).
  ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, buffer,
                                     {constUint0, index}, loc);
  SpirvInstruction *word0 =
      spvBuilder.createLoad(astContext.UnsignedIntTy, ptr, loc);
  // Increment the base index
  index = spvBuilder.createBinaryOp(spv::Op::OpIAdd, astContext.UnsignedIntTy,
                                    index, constUint1, loc);
  // Load the second 32-bit uint (word1).
  ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, buffer,
                                     {constUint0, index}, loc);
  SpirvInstruction *word1 =
      spvBuilder.createLoad(astContext.UnsignedIntTy, ptr, loc);

  // Convert both word0 and word1 to 64-bit uints.
  word0 = spvBuilder.createUnaryOp(spv::Op::OpUConvert,
                                   astContext.UnsignedLongLongTy, word0, loc);
  word1 = spvBuilder.createUnaryOp(spv::Op::OpUConvert,
                                   astContext.UnsignedLongLongTy, word1, loc);

  // Shift word1 to the left by 32 bits.
  word1 = spvBuilder.createBinaryOp(spv::Op::OpShiftLeftLogical,
                                    astContext.UnsignedLongLongTy, word1,
                                    constUint32, loc);

  // BitwiseOr word0 and word1.
  result = spvBuilder.createBinaryOp(
      spv::Op::OpBitwiseOr, astContext.UnsignedLongLongTy, word0, word1, loc);
  result = bitCastToNumericalOrBool(result, astContext.UnsignedLongLongTy,
                                    target64BitType, loc);
  result->setRValue();

  // Now that a 64-bit load at bit-offset 0 has been performed, the next load
  // should be done at *the base index + 2* at bit-offset 0. The index has
  // already been incremented once. Need to increment it once more.
  bitOffset = (bitOffset + 64) % 32;
  index = spvBuilder.createBinaryOp(spv::Op::OpIAdd, astContext.UnsignedIntTy,
                                    index, constUint1, loc);
  return result;
}
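
// Loads a 16-bit value that starts at bit 16 of the 32-bit word addressed by
// |index|: the word is read, shifted right by 16, truncated to 16 bits, and
// bitcast to |target16BitType|. |index| is advanced to the next word and
// |bitOffset| wraps back to 0.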
SpirvInstruction *RawBufferHandler::load16BitsAtBitOffset16(
    SpirvInstruction *buffer, SpirvInstruction *&index,
    QualType target16BitType, uint32_t &bitOffset) {
  assert(bitOffset == 16);
  const auto loc = buffer->getSourceLocation();
  SpirvInstruction *result = nullptr;
  auto *constUint0 =
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0));
  auto *constUint1 =
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 1));
  auto *constUint16 =
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 16));

  // The underlying element type of the ByteAddressBuffer is uint. So we
  // need to load 32-bits at the very least.
  auto *ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, buffer,
                                           {constUint0, index}, loc);
  result = spvBuilder.createLoad(astContext.UnsignedIntTy, ptr, loc);
  result = spvBuilder.createBinaryOp(spv::Op::OpShiftRightLogical,
                                     astContext.UnsignedIntTy, result,
                                     constUint16, loc);
  result = spvBuilder.createUnaryOp(spv::Op::OpUConvert,
                                    astContext.UnsignedShortTy, result, loc);
  result = bitCastToNumericalOrBool(result, astContext.UnsignedShortTy,
                                    target16BitType, loc);
  result->setRValue();

  // Now that a 16-bit load at bit-offset 16 has been performed, the next load
  // should be done at *the next base index* at bit-offset 0.
  bitOffset = (bitOffset + 16) % 32;
  index = spvBuilder.createBinaryOp(spv::Op::OpIAdd, astContext.UnsignedIntTy,
                                    index, constUint1, loc);
  return result;
}
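
// Implements the templated ByteAddressBuffer.Load<T>() by recursing on the
// structure of |targetType|: scalars dispatch to the width- and offset-
// specific loaders above, while vectors, arrays, matrices, and structs load
// their elements one by one and reassemble them with OpCompositeConstruct.
// |index| and |bitOffset| are threaded through the recursion so that
// consecutive elements are read from consecutive buffer locations.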
SpirvInstruction *RawBufferHandler::processTemplatedLoadFromBuffer(
    SpirvInstruction *buffer, SpirvInstruction *&index,
    const QualType targetType, uint32_t &bitOffset) {
  const auto loc = buffer->getSourceLocation();
  SpirvInstruction *result = nullptr;
  auto *constUint0 =
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0));
  auto *constUint1 =
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 1));

  // TODO: If 8-bit types are to be supported in the future, we should also
  // add code to support bitOffset 8 and 24.
  assert(bitOffset == 0 || bitOffset == 16);

  // Scalar types
  if (isScalarType(targetType)) {
    auto loadWidth = getElementSpirvBitwidth(
        astContext, targetType, theEmitter.getSpirvOptions().enable16BitTypes);
    switch (bitOffset) {
    case 0: {
      switch (loadWidth) {
      case 16:
        return load16BitsAtBitOffset0(buffer, index, targetType, bitOffset);
        break;
      case 32:
        return load32BitsAtBitOffset0(buffer, index, targetType, bitOffset);
        break;
      case 64:
        return load64BitsAtBitOffset0(buffer, index, targetType, bitOffset);
        break;
      default:
        theEmitter.emitError(
            "templated load of ByteAddressBuffer is only implemented for "
            "16, 32, and 64-bit types",
            loc);
        return nullptr;
      }
      break;
    }
    case 16: {
      switch (loadWidth) {
      case 16:
        return load16BitsAtBitOffset16(buffer, index, targetType, bitOffset);
        break;
      case 32:
      case 64:
        theEmitter.emitError(
            "templated buffer load should not result in loading "
            "32-bit or 64-bit values at bit offset 16",
            loc);
        return nullptr;
      default:
        theEmitter.emitError(
            "templated load of ByteAddressBuffer is only implemented for "
            "16, 32, and 64-bit types",
            loc);
        return nullptr;
      }
      break;
    }
    default:
      theEmitter.emitError(
          "templated load of ByteAddressBuffer is only implemented for "
          "16, 32, and 64-bit types",
          loc);
      return nullptr;
    }
  }

  // Vector types
  {
    QualType elemType = {};
    uint32_t elemCount = 0;
    if (isVectorType(targetType, &elemType, &elemCount)) {
      llvm::SmallVector<SpirvInstruction *, 4> loadedElems;
      for (uint32_t i = 0; i < elemCount; ++i) {
        loadedElems.push_back(
            processTemplatedLoadFromBuffer(buffer, index, elemType, bitOffset));
      }
      result =
          spvBuilder.createCompositeConstruct(targetType, loadedElems, loc);
      result->setRValue();
      return result;
    }
  }

  // Array types
  {
    QualType elemType = {};
    uint32_t elemCount = 0;
    if (const auto *arrType = astContext.getAsConstantArrayType(targetType)) {
      elemCount = static_cast<uint32_t>(arrType->getSize().getZExtValue());
      elemType = arrType->getElementType();
      llvm::SmallVector<SpirvInstruction *, 4> loadedElems;
      for (uint32_t i = 0; i < elemCount; ++i) {
        loadedElems.push_back(
            processTemplatedLoadFromBuffer(buffer, index, elemType, bitOffset));
      }
      result =
          spvBuilder.createCompositeConstruct(targetType, loadedElems, loc);
      result->setRValue();
      return result;
    }
  }

  // Matrix types
  {
    QualType elemType = {};
    uint32_t numRows = 0, numCols = 0;
    if (isMxNMatrix(targetType, &elemType, &numRows, &numCols)) {
      llvm::SmallVector<SpirvInstruction *, 4> loadedElems;
      llvm::SmallVector<SpirvInstruction *, 4> loadedRows;
      for (uint32_t i = 0; i < numRows; ++i) {
        for (uint32_t j = 0; j < numCols; ++j) {
          // TODO: This is currently doing a row_major matrix load. We must
          // investigate whether we also need to implement it for column_major.
          loadedElems.push_back(processTemplatedLoadFromBuffer(
              buffer, index, elemType, bitOffset));
        }
        const auto rowType = astContext.getExtVectorType(elemType, numCols);
        loadedRows.push_back(
            spvBuilder.createCompositeConstruct(rowType, loadedElems, loc));
        loadedElems.clear();
      }
      result = spvBuilder.createCompositeConstruct(targetType, loadedRows, loc);
      result->setRValue();
      return result;
    }
  }

  // Struct types
  // The "natural" layout for structure types dictates that structs are
  // aligned like their field with the largest alignment.
  // As a result, there might exist some padding after some struct members.
  if (const auto *structType = targetType->getAs<RecordType>()) {
    const auto *decl = structType->getDecl();
    SpirvInstruction *originalIndex = index;
    uint32_t originalBitOffset = bitOffset;
    llvm::SmallVector<SpirvInstruction *, 4> loadedElems;
    uint32_t fieldOffsetInBytes = 0;
    uint32_t structAlignment = 0, structSize = 0, stride = 0;
    std::tie(structAlignment, structSize) =
        AlignmentSizeCalculator(astContext, theEmitter.getSpirvOptions())
            .getAlignmentAndSize(targetType,
                                 theEmitter.getSpirvOptions().sBufferLayoutRule,
                                 llvm::None, &stride);

    for (const auto *field : decl->fields()) {
      AlignmentSizeCalculator alignmentCalc(astContext,
                                            theEmitter.getSpirvOptions());
      uint32_t fieldSize = 0, fieldAlignment = 0;
      std::tie(fieldAlignment, fieldSize) = alignmentCalc.getAlignmentAndSize(
          field->getType(), theEmitter.getSpirvOptions().sBufferLayoutRule,
          /*isRowMajor*/ llvm::None, &stride);
      fieldOffsetInBytes = roundToPow2(fieldOffsetInBytes, fieldAlignment);
      const auto wordOffset =
          ((originalBitOffset / 8) + fieldOffsetInBytes) / 4;
      bitOffset = (((originalBitOffset / 8) + fieldOffsetInBytes) % 4) * 8;
      if (wordOffset != 0) {
        // Divide the fieldOffset by 4 to figure out how much to increment the
        // index into the buffer (increment occurs by 32-bit words since the
        // underlying type is an array of uints).
        // The remainder by four tells us the *byte offset* (then multiply by 8
        // to get bit offset).
        index = spvBuilder.createBinaryOp(
            spv::Op::OpIAdd, astContext.UnsignedIntTy, originalIndex,
            spvBuilder.getConstantInt(astContext.UnsignedIntTy,
                                      llvm::APInt(32, wordOffset)),
            loc);
      }

      loadedElems.push_back(processTemplatedLoadFromBuffer(
          buffer, index, field->getType(), bitOffset));

      fieldOffsetInBytes += fieldSize;
    }

    // After we're done with loading the entire struct, we need to update the
    // index and bitOffset (in case we are loading an array of structs).
    //
    // Example: struct alignment = 8. struct size = 34 bytes
    // (34 / 8) = 4 full words
    // (34 % 8) = 2 > 0, therefore need to move to the next aligned address
    // So the starting byte offset after loading the entire struct is:
    // 8 * (4 + 1) = 40
    assert(structAlignment != 0);
    uint32_t newByteOffset = roundToPow2(structSize, structAlignment);
    uint32_t newWordOffset = ((originalBitOffset / 8) + newByteOffset) / 4;
    bitOffset = 8 * (((originalBitOffset / 8) + newByteOffset) % 4);
    index = spvBuilder.createBinaryOp(
        spv::Op::OpIAdd, astContext.UnsignedIntTy, originalIndex,
        spvBuilder.getConstantInt(astContext.UnsignedIntTy,
                                  llvm::APInt(32, newWordOffset)),
        loc);

    result = spvBuilder.createCompositeConstruct(targetType, loadedElems, loc);
    result->setRValue();
    return result;
  }

  llvm_unreachable("templated buffer load unimplemented for type");
}
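
// Stores a 16-bit |value| into the low half of the 32-bit word addressed by
// |index|: the value is bitcast to a 16-bit uint, zero-extended to 32 bits,
// and written out (clearing the upper half). |index| is not advanced, so a
// following store16BitsAtBitOffset16 can fill the upper half of the word.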
void RawBufferHandler::store16BitsAtBitOffset0(SpirvInstruction *value,
                                               SpirvInstruction *buffer,
                                               SpirvInstruction *&index,
                                               const QualType valueType) {
  const auto loc = buffer->getSourceLocation();
  SpirvInstruction *result = nullptr;
  auto *constUint0 =
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0));

  // The underlying element type of the ByteAddressBuffer is uint. So we
  // need to store a 32-bit value.
  auto *ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, buffer,
                                           {constUint0, index}, loc);
  result = bitCastToNumericalOrBool(value, valueType,
                                    astContext.UnsignedShortTy, loc);
  result = spvBuilder.createUnaryOp(spv::Op::OpUConvert,
                                    astContext.UnsignedIntTy, result, loc);
  spvBuilder.createStore(ptr, result, loc);
}
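
// Stores a 16-bit |value| into the upper half of the 32-bit word addressed by
// |index| by loading the word, OR-ing in the value shifted left by 16 bits,
// and writing the word back. |index| is then advanced to the next word.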
void RawBufferHandler::store16BitsAtBitOffset16(SpirvInstruction *value,
                                                SpirvInstruction *buffer,
                                                SpirvInstruction *&index,
                                                const QualType valueType) {
  const auto loc = buffer->getSourceLocation();
  SpirvInstruction *result = nullptr;
  auto *constUint0 =
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0));
  auto *constUint1 =
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 1));
  auto *constUint16 =
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 16));

  // The underlying element type of the ByteAddressBuffer is uint. So we
  // need to store a 32-bit value.
  auto *ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, buffer,
                                           {constUint0, index}, loc);
  result = bitCastToNumericalOrBool(value, valueType,
                                    astContext.UnsignedShortTy, loc);
  result = spvBuilder.createUnaryOp(spv::Op::OpUConvert,
                                    astContext.UnsignedIntTy, result, loc);
  result = spvBuilder.createBinaryOp(spv::Op::OpShiftLeftLogical,
                                     astContext.UnsignedIntTy, result,
                                     constUint16, loc);
  result = spvBuilder.createBinaryOp(
      spv::Op::OpBitwiseOr, astContext.UnsignedIntTy,
      spvBuilder.createLoad(astContext.UnsignedIntTy, ptr, loc), result, loc);
  spvBuilder.createStore(ptr, result, loc);
  index = spvBuilder.createBinaryOp(spv::Op::OpIAdd, astContext.UnsignedIntTy,
                                    index, constUint1, loc);
}
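
// Stores a 32-bit |value| (bitcast to uint) into the word addressed by
// |index| and advances |index| to the next word.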
void RawBufferHandler::store32BitsAtBitOffset0(SpirvInstruction *value,
                                               SpirvInstruction *buffer,
                                               SpirvInstruction *&index,
                                               const QualType valueType) {
  const auto loc = buffer->getSourceLocation();
  auto *constUint0 =
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0));
  auto *constUint1 =
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 1));

  // The underlying element type of the ByteAddressBuffer is uint. So we
  // need to store a 32-bit value.
  auto *ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, buffer,
                                           {constUint0, index}, loc);
  value =
      bitCastToNumericalOrBool(value, valueType, astContext.UnsignedIntTy, loc);
  spvBuilder.createStore(ptr, value, loc);
  index = spvBuilder.createBinaryOp(spv::Op::OpIAdd, astContext.UnsignedIntTy,
                                    index, constUint1, loc);
}
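
// Stores a 64-bit |value| into two consecutive words starting at |index|:
// the value is bitcast to a 64-bit uint and split into its low and high
// 32-bit halves, which are written in that order. |index| is advanced by two
// words.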
void RawBufferHandler::store64BitsAtBitOffset0(SpirvInstruction *value,
                                               SpirvInstruction *buffer,
                                               SpirvInstruction *&index,
                                               const QualType valueType) {
  const auto loc = buffer->getSourceLocation();
  auto *constUint0 =
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0));
  auto *constUint1 =
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 1));
  auto *constUint32 =
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 32));

  // The underlying element type of the ByteAddressBuffer is uint. So we
  // need to store two 32-bit values.
  auto *ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, buffer,
                                           {constUint0, index}, loc);

  // First convert the 64-bit value to uint64_t. Then extract two 32-bit words
  // from it.
  value = bitCastToNumericalOrBool(value, valueType,
                                   astContext.UnsignedLongLongTy, loc);

  // Use OpUConvert to perform truncation (produces the least significant
  // bits).
  SpirvInstruction *lsb = spvBuilder.createUnaryOp(
      spv::Op::OpUConvert, astContext.UnsignedIntTy, value, loc);

  // Shift uint64_t to the right by 32 bits and truncate to get the most
  // significant bits.
  SpirvInstruction *msb = spvBuilder.createUnaryOp(
      spv::Op::OpUConvert, astContext.UnsignedIntTy,
      spvBuilder.createBinaryOp(spv::Op::OpShiftRightLogical,
                                astContext.UnsignedLongLongTy, value,
                                constUint32, loc),
      loc);

  spvBuilder.createStore(ptr, lsb, loc);
  index = spvBuilder.createBinaryOp(spv::Op::OpIAdd, astContext.UnsignedIntTy,
                                    index, constUint1, loc);
  ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, buffer,
                                     {constUint0, index}, loc);
  spvBuilder.createStore(ptr, msb, loc);
  index = spvBuilder.createBinaryOp(spv::Op::OpIAdd, astContext.UnsignedIntTy,
                                    index, constUint1, loc);
}
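
// Stores a flat sequence of scalar |values|, all of type |valueType|, into
// the buffer starting at |index|/|bitOffset|. 16-bit values are packed two
// per word (honoring an initial bit offset of 16); 32-bit and 64-bit values
// are forwarded to processTemplatedStoreToBuffer one at a time.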
void RawBufferHandler::storeArrayOfScalars(
    std::deque<SpirvInstruction *> values, SpirvInstruction *buffer,
    SpirvInstruction *&index, const QualType valueType, uint32_t &bitOffset,
    SourceLocation loc) {
  auto *constUint0 =
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0));
  auto *constUint1 =
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 1));
  auto *constUint16 =
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 16));

  const auto storeWidth = getElementSpirvBitwidth(
      astContext, valueType, theEmitter.getSpirvOptions().enable16BitTypes);
  const uint32_t elemCount = values.size();

  if (storeWidth == 16u) {
    uint32_t elemIndex = 0;
    if (bitOffset == 16) {
      // First store the first element at offset 16 of the last memory index.
      store16BitsAtBitOffset16(values[0], buffer, index, valueType);
      bitOffset = 0;
      ++elemIndex;
    }

    // Do a custom store based on the number of elements.
    for (; elemIndex < elemCount; elemIndex = elemIndex + 2) {
      // The underlying element type of the ByteAddressBuffer is uint. So we
      // need to store a 32-bit value by combining two 16-bit values.
      SpirvInstruction *word = nullptr;
      word = bitCastToNumericalOrBool(values[elemIndex], valueType,
                                      astContext.UnsignedShortTy, loc);
      // Zero-extend to 32 bits.
      word = spvBuilder.createUnaryOp(spv::Op::OpUConvert,
                                      astContext.UnsignedIntTy, word, loc);
      if (elemIndex + 1 < elemCount) {
        SpirvInstruction *msb = nullptr;
        msb = bitCastToNumericalOrBool(values[elemIndex + 1], valueType,
                                       astContext.UnsignedShortTy, loc);
        msb = spvBuilder.createUnaryOp(spv::Op::OpUConvert,
                                       astContext.UnsignedIntTy, msb, loc);
        msb = spvBuilder.createBinaryOp(spv::Op::OpShiftLeftLogical,
                                        astContext.UnsignedIntTy, msb,
                                        constUint16, loc);
        word = spvBuilder.createBinaryOp(
            spv::Op::OpBitwiseOr, astContext.UnsignedIntTy, word, msb, loc);
        // We will store two 16-bit values.
        bitOffset = (bitOffset + 32) % 32;
      } else {
        // We will store one 16-bit value.
        bitOffset = (bitOffset + 16) % 32;
      }
      auto *ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, buffer,
                                               {constUint0, index}, loc);
      spvBuilder.createStore(ptr, word, loc);
      index = spvBuilder.createBinaryOp(
          spv::Op::OpIAdd, astContext.UnsignedIntTy, index, constUint1, loc);
    }
  } else if (storeWidth == 32u || storeWidth == 64u) {
    assert(bitOffset == 0);
    for (uint32_t i = 0; i < elemCount; ++i)
      processTemplatedStoreToBuffer(values[i], buffer, index, valueType,
                                    bitOffset);
  }
}
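
// Flattens the composites in |values| in place, one level at a time: vectors,
// matrices, and constant arrays are replaced in the deque by their extracted
// elements (matrices and arrays recurse on the element type). Returns the
// scalar or struct type the values were flattened down to.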
QualType RawBufferHandler::serializeToScalarsOrStruct(
    std::deque<SpirvInstruction *> *values, QualType valueType,
    SourceLocation loc) {
  uint32_t size = values->size();

  // Vector type
  {
    QualType elemType = {};
    uint32_t elemCount = 0;
    if (isVectorType(valueType, &elemType, &elemCount)) {
      for (uint32_t i = 0; i < size; ++i) {
        for (uint32_t j = 0; j < elemCount; ++j) {
          values->push_back(spvBuilder.createCompositeExtract(
              elemType, values->front(), {j}, loc));
        }
        values->pop_front();
      }
      return elemType;
    }
  }

  // Matrix type
  {
    QualType elemType = {};
    uint32_t numRows = 0, numCols = 0;
    if (isMxNMatrix(valueType, &elemType, &numRows, &numCols)) {
      for (uint32_t i = 0; i < size; ++i) {
        for (uint32_t j = 0; j < numRows; ++j) {
          for (uint32_t k = 0; k < numCols; ++k) {
            // TODO: This is currently doing a row_major matrix store. We must
            // investigate whether we also need to implement it for
            // column_major.
            values->push_back(spvBuilder.createCompositeExtract(
                elemType, values->front(), {j, k}, loc));
          }
        }
        values->pop_front();
      }
      return serializeToScalarsOrStruct(values, elemType, loc);
    }
  }

  // Array type
  {
    if (const auto *arrType = astContext.getAsConstantArrayType(valueType)) {
      const uint32_t arrElemCount =
          static_cast<uint32_t>(arrType->getSize().getZExtValue());
      const QualType arrElemType = arrType->getElementType();
      for (uint32_t i = 0; i < size; ++i) {
        for (uint32_t j = 0; j < arrElemCount; ++j) {
          values->push_back(spvBuilder.createCompositeExtract(
              arrElemType, values->front(), {j}, loc));
        }
        values->pop_front();
      }
      return serializeToScalarsOrStruct(values, arrElemType, loc);
    }
  }

  if (isScalarType(valueType))
    return valueType;

  if (const auto *structType = valueType->getAs<RecordType>())
    return valueType;

  llvm_unreachable("unhandled type when serializing an array");
}
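
// Implements the templated RWByteAddressBuffer.Store<T>() by recursing on
// the structure of |valueType|: scalars dispatch to the width- and offset-
// specific store helpers above; vectors, matrices, and arrays are first
// serialized with serializeToScalarsOrStruct and then stored element by
// element; struct fields are stored at their naturally aligned offsets.
// |index| and |bitOffset| are threaded through so that consecutive values
// land at consecutive buffer locations.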
void RawBufferHandler::processTemplatedStoreToBuffer(SpirvInstruction *value,
                                                      SpirvInstruction *buffer,
                                                      SpirvInstruction *&index,
                                                      const QualType valueType,
                                                      uint32_t &bitOffset) {
  assert(bitOffset == 0 || bitOffset == 16);
  const auto loc = buffer->getSourceLocation();
  auto *constUint0 =
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0));
  auto *constUint1 =
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 1));

  // Scalar types
  if (isScalarType(valueType)) {
    auto storeWidth = getElementSpirvBitwidth(
        astContext, valueType, theEmitter.getSpirvOptions().enable16BitTypes);
    switch (bitOffset) {
    case 0: {
      switch (storeWidth) {
      case 16:
        store16BitsAtBitOffset0(value, buffer, index, valueType);
        return;
      case 32:
        store32BitsAtBitOffset0(value, buffer, index, valueType);
        return;
      case 64:
        store64BitsAtBitOffset0(value, buffer, index, valueType);
        return;
      default:
        theEmitter.emitError(
            "templated load of ByteAddressBuffer is only implemented for "
            "16, 32, and 64-bit types",
            loc);
        return;
      }
    }
    case 16: {
      // The only legal store at offset 16 is by a 16-bit value.
      assert(storeWidth == 16);
      store16BitsAtBitOffset16(value, buffer, index, valueType);
      return;
    }
    default:
      theEmitter.emitError(
          "templated load of ByteAddressBuffer is only implemented for "
          "16, 32, and 64-bit types",
          loc);
      return;
    }
  }

  // Vectors, Matrices, and Arrays can all be serialized and stored.
  if (isVectorType(valueType) || isMxNMatrix(valueType) ||
      isConstantArrayType(astContext, valueType)) {
    std::deque<SpirvInstruction *> elems;
    elems.push_back(value);
    auto serializedType = serializeToScalarsOrStruct(&elems, valueType, loc);
    if (isScalarType(serializedType)) {
      storeArrayOfScalars(elems, buffer, index, serializedType, bitOffset, loc);
    } else if (const auto *structType = serializedType->getAs<RecordType>()) {
      for (auto elem : elems)
        processTemplatedStoreToBuffer(elem, buffer, index, serializedType,
                                      bitOffset);
    }
    return;
  }

  // Struct types
  // The "natural" layout for structure types dictates that structs are
  // aligned like their field with the largest alignment.
  // As a result, there might exist some padding after some struct members.
  if (const auto *structType = valueType->getAs<RecordType>()) {
    const auto *decl = structType->getDecl();
    SpirvInstruction *originalIndex = index;
    const auto originalBitOffset = bitOffset;
    uint32_t fieldOffsetInBytes = 0;
    uint32_t structAlignment = 0, structSize = 0, stride = 0;
    std::tie(structAlignment, structSize) =
        AlignmentSizeCalculator(astContext, theEmitter.getSpirvOptions())
            .getAlignmentAndSize(valueType,
                                 theEmitter.getSpirvOptions().sBufferLayoutRule,
                                 llvm::None, &stride);
    uint32_t fieldIndex = 0;
    for (const auto *field : decl->fields()) {
      AlignmentSizeCalculator alignmentCalc(astContext,
                                            theEmitter.getSpirvOptions());
      uint32_t fieldSize = 0, fieldAlignment = 0;
      std::tie(fieldAlignment, fieldSize) = alignmentCalc.getAlignmentAndSize(
          field->getType(), theEmitter.getSpirvOptions().sBufferLayoutRule,
          /*isRowMajor*/ llvm::None, &stride);
      fieldOffsetInBytes = roundToPow2(fieldOffsetInBytes, fieldAlignment);
      const auto wordOffset =
          ((originalBitOffset / 8) + fieldOffsetInBytes) / 4;
      bitOffset = (((originalBitOffset / 8) + fieldOffsetInBytes) % 4) * 8;
      if (wordOffset != 0) {
        // Divide the fieldOffset by 4 to figure out how much to increment the
        // index into the buffer (increment occurs by 32-bit words since the
        // underlying type is an array of uints).
        // The remainder by four tells us the *byte offset* (then multiply by 8
        // to get bit offset).
        index = spvBuilder.createBinaryOp(
            spv::Op::OpIAdd, astContext.UnsignedIntTy, originalIndex,
            spvBuilder.getConstantInt(astContext.UnsignedIntTy,
                                      llvm::APInt(32, wordOffset)),
            loc);
      }

      processTemplatedStoreToBuffer(
          spvBuilder.createCompositeExtract(field->getType(), value,
                                            {fieldIndex}, loc),
          buffer, index, field->getType(), bitOffset);

      fieldOffsetInBytes += fieldSize;
      ++fieldIndex;
    }

    // After we're done with storing the entire struct, we need to update the
    // index (in case we are storing an array of structs).
    //
    // Example: struct alignment = 8. struct size = 34 bytes
    // (34 / 8) = 4 full words
    // (34 % 8) = 2 > 0, therefore need to move to the next aligned address
    // So the starting byte offset after storing the entire struct is:
    // 8 * (4 + 1) = 40
    assert(structAlignment != 0);
    uint32_t newByteOffset = roundToPow2(structSize, structAlignment);
    uint32_t newWordOffset = ((originalBitOffset / 8) + newByteOffset) / 4;
    bitOffset = 8 * (((originalBitOffset / 8) + newByteOffset) % 4);
    index = spvBuilder.createBinaryOp(
        spv::Op::OpIAdd, astContext.UnsignedIntTy, originalIndex,
        spvBuilder.getConstantInt(astContext.UnsignedIntTy,
                                  llvm::APInt(32, newWordOffset)),
        loc);
    return;
  }

  llvm_unreachable("templated buffer store unimplemented for type");
}

} // namespace spirv
} // namespace clang