RawBufferMethods.cpp 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769
  1. //===---- RawBufferMethods.cpp ---- Raw Buffer Methods ----------*- C++ -*-===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //===----------------------------------------------------------------------===//
  8. #include "RawBufferMethods.h"
  9. #include "AlignmentSizeCalculator.h"
  10. #include "clang/AST/ASTContext.h"
  11. #include "clang/AST/CharUnits.h"
  12. #include "clang/AST/RecordLayout.h"
  13. #include "clang/SPIRV/AstTypeProbe.h"
  14. #include "clang/SPIRV/SpirvBuilder.h"
  15. #include "clang/SPIRV/SpirvInstruction.h"
  16. namespace {
  /// Rounds \p val up to the nearest multiple of \p pow2.
  ///
  /// \param val  the value to round up.
  /// \param pow2 the alignment; must be a non-zero power of 2, because the
  ///             rounding is done by masking off the low bits.
  /// \returns the smallest multiple of \p pow2 that is >= \p val (modulo
  ///          32-bit wrap-around).
  inline uint32_t roundToPow2(uint32_t val, uint32_t pow2) {
    // The mask trick below silently yields garbage for anything that is not a
    // power of 2, so check the full precondition rather than only non-zero.
    assert(pow2 != 0 && (pow2 & (pow2 - 1)) == 0 &&
           "alignment must be a non-zero power of 2");
    const uint32_t lowBits = pow2 - 1;
    return (val + lowBits) & ~lowBits;
  }
  22. } // anonymous namespace
  23. namespace clang {
  24. namespace spirv {
  25. SpirvInstruction *
  26. RawBufferHandler::bitCastToNumericalOrBool(SpirvInstruction *instr,
  27. QualType fromType, QualType toType,
  28. SourceLocation loc) {
  29. if (isSameType(astContext, fromType, toType))
  30. return instr;
  31. if (toType->isBooleanType() || fromType->isBooleanType())
  32. return theEmitter.castToType(instr, fromType, toType, loc);
  33. // Perform a bitcast
  34. return spvBuilder.createUnaryOp(spv::Op::OpBitcast, toType, instr, loc);
  35. }
  36. SpirvInstruction *RawBufferHandler::load16BitsAtBitOffset0(
  37. SpirvInstruction *buffer, SpirvInstruction *&index,
  38. QualType target16BitType, uint32_t &bitOffset) {
  39. assert(bitOffset == 0);
  40. const auto loc = buffer->getSourceLocation();
  41. SpirvInstruction *result = nullptr;
  42. auto *constUint0 =
  43. spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0));
  44. // The underlying element type of the ByteAddressBuffer is uint. So we
  45. // need to load 32-bits at the very least.
  46. auto *loadPtr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, buffer,
  47. {constUint0, index}, loc);
  48. result = spvBuilder.createLoad(astContext.UnsignedIntTy, loadPtr, loc);
  49. // Only need to mask the lowest 16 bits of the loaded 32-bit uint.
  50. // OpUConvert can perform truncation in this case.
  51. result = spvBuilder.createUnaryOp(spv::Op::OpUConvert,
  52. astContext.UnsignedShortTy, result, loc);
  53. result = bitCastToNumericalOrBool(result, astContext.UnsignedShortTy,
  54. target16BitType, loc);
  55. result->setRValue();
  56. // Now that a 16-bit load at bit-offset 0 has been performed, the next load
  57. // should be done at *the same base index* at bit-offset 16.
  58. bitOffset = (bitOffset + 16) % 32;
  59. return result;
  60. }
  61. SpirvInstruction *RawBufferHandler::load32BitsAtBitOffset0(
  62. SpirvInstruction *buffer, SpirvInstruction *&index,
  63. QualType target32BitType, uint32_t &bitOffset) {
  64. assert(bitOffset == 0);
  65. const auto loc = buffer->getSourceLocation();
  66. SpirvInstruction *result = nullptr;
  67. // Only need to perform one 32-bit uint load.
  68. auto *constUint0 =
  69. spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0));
  70. auto *constUint1 =
  71. spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 1));
  72. auto *loadPtr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, buffer,
  73. {constUint0, index}, loc);
  74. result = spvBuilder.createLoad(astContext.UnsignedIntTy, loadPtr, loc);
  75. result = bitCastToNumericalOrBool(result, astContext.UnsignedIntTy,
  76. target32BitType, loc);
  77. result->setRValue();
  78. // Now that a 32-bit load at bit-offset 0 has been performed, the next load
  79. // should be done at *the next base index* at bit-offset 0.
  80. bitOffset = (bitOffset + 32) % 32;
  81. index = spvBuilder.createBinaryOp(spv::Op::OpIAdd, astContext.UnsignedIntTy,
  82. index, constUint1, loc);
  83. return result;
  84. }
  85. SpirvInstruction *RawBufferHandler::load64BitsAtBitOffset0(
  86. SpirvInstruction *buffer, SpirvInstruction *&index,
  87. QualType target64BitType, uint32_t &bitOffset) {
  88. assert(bitOffset == 0);
  89. const auto loc = buffer->getSourceLocation();
  90. SpirvInstruction *result = nullptr;
  91. SpirvInstruction *ptr = nullptr;
  92. auto *constUint0 =
  93. spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0));
  94. auto *constUint1 =
  95. spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 1));
  96. auto *constUint32 =
  97. spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 32));
  98. // Need to perform two 32-bit uint loads and construct a 64-bit value.
  99. // Load the first 32-bit uint (word0).
  100. ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, buffer,
  101. {constUint0, index}, loc);
  102. SpirvInstruction *word0 =
  103. spvBuilder.createLoad(astContext.UnsignedIntTy, ptr, loc);
  104. // Increment the base index
  105. index = spvBuilder.createBinaryOp(spv::Op::OpIAdd, astContext.UnsignedIntTy,
  106. index, constUint1, loc);
  107. // Load the second 32-bit uint (word1).
  108. ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, buffer,
  109. {constUint0, index}, loc);
  110. SpirvInstruction *word1 =
  111. spvBuilder.createLoad(astContext.UnsignedIntTy, ptr, loc);
  112. // Convert both word0 and word1 to 64-bit uints.
  113. word0 = spvBuilder.createUnaryOp(spv::Op::OpUConvert,
  114. astContext.UnsignedLongLongTy, word0, loc);
  115. word1 = spvBuilder.createUnaryOp(spv::Op::OpUConvert,
  116. astContext.UnsignedLongLongTy, word1, loc);
  117. // Shift word1 to the left by 32 bits.
  118. word1 = spvBuilder.createBinaryOp(spv::Op::OpShiftLeftLogical,
  119. astContext.UnsignedLongLongTy, word1,
  120. constUint32, loc);
  121. // BitwiseOr word0 and word1.
  122. result = spvBuilder.createBinaryOp(
  123. spv::Op::OpBitwiseOr, astContext.UnsignedLongLongTy, word0, word1, loc);
  124. result = bitCastToNumericalOrBool(result, astContext.UnsignedLongLongTy,
  125. target64BitType, loc);
  126. result->setRValue();
  127. // Now that a 64-bit load at bit-offset 0 has been performed, the next load
  128. // should be done at *the base index + 2* at bit-offset 0. The index has
  129. // already been incremented once. Need to increment it once more.
  130. bitOffset = (bitOffset + 64) % 32;
  131. index = spvBuilder.createBinaryOp(spv::Op::OpIAdd, astContext.UnsignedIntTy,
  132. index, constUint1, loc);
  133. return result;
  134. }
  135. SpirvInstruction *RawBufferHandler::load16BitsAtBitOffset16(
  136. SpirvInstruction *buffer, SpirvInstruction *&index,
  137. QualType target16BitType, uint32_t &bitOffset) {
  138. assert(bitOffset == 16);
  139. const auto loc = buffer->getSourceLocation();
  140. SpirvInstruction *result = nullptr;
  141. auto *constUint0 =
  142. spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0));
  143. auto *constUint1 =
  144. spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 1));
  145. auto *constUint16 =
  146. spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 16));
  147. // The underlying element type of the ByteAddressBuffer is uint. So we
  148. // need to load 32-bits at the very least.
  149. auto *ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, buffer,
  150. {constUint0, index}, loc);
  151. result = spvBuilder.createLoad(astContext.UnsignedIntTy, ptr, loc);
  152. result = spvBuilder.createBinaryOp(spv::Op::OpShiftRightLogical,
  153. astContext.UnsignedIntTy, result,
  154. constUint16, loc);
  155. result = spvBuilder.createUnaryOp(spv::Op::OpUConvert,
  156. astContext.UnsignedShortTy, result, loc);
  157. result = bitCastToNumericalOrBool(result, astContext.UnsignedShortTy,
  158. target16BitType, loc);
  159. result->setRValue();
  160. // Now that a 16-bit load at bit-offset 16 has been performed, the next load
  161. // should be done at *the next base index* at bit-offset 0.
  162. bitOffset = (bitOffset + 16) % 32;
  163. index = spvBuilder.createBinaryOp(spv::Op::OpIAdd, astContext.UnsignedIntTy,
  164. index, constUint1, loc);
  165. return result;
  166. }
  167. SpirvInstruction *RawBufferHandler::processTemplatedLoadFromBuffer(
  168. SpirvInstruction *buffer, SpirvInstruction *&index,
  169. const QualType targetType, uint32_t &bitOffset) {
  170. const auto loc = buffer->getSourceLocation();
  171. SpirvInstruction *result = nullptr;
  172. // TODO: If 8-bit types are to be supported in the future, we should also
  173. // add code to support bitOffset 8 and 24.
  174. assert(bitOffset == 0 || bitOffset == 16);
  175. // Scalar types
  176. if (isScalarType(targetType)) {
  177. auto loadWidth = getElementSpirvBitwidth(
  178. astContext, targetType, theEmitter.getSpirvOptions().enable16BitTypes);
  179. switch (bitOffset) {
  180. case 0: {
  181. switch (loadWidth) {
  182. case 16:
  183. return load16BitsAtBitOffset0(buffer, index, targetType, bitOffset);
  184. break;
  185. case 32:
  186. return load32BitsAtBitOffset0(buffer, index, targetType, bitOffset);
  187. break;
  188. case 64:
  189. return load64BitsAtBitOffset0(buffer, index, targetType, bitOffset);
  190. break;
  191. default:
  192. theEmitter.emitError(
  193. "templated load of ByteAddressBuffer is only implemented for "
  194. "16, 32, and 64-bit types",
  195. loc);
  196. return nullptr;
  197. }
  198. break;
  199. }
  200. case 16: {
  201. switch (loadWidth) {
  202. case 16:
  203. return load16BitsAtBitOffset16(buffer, index, targetType, bitOffset);
  204. break;
  205. case 32:
  206. case 64:
  207. theEmitter.emitError(
  208. "templated buffer load should not result in loading "
  209. "32-bit or 64-bit values at bit offset 16",
  210. loc);
  211. return nullptr;
  212. default:
  213. theEmitter.emitError(
  214. "templated load of ByteAddressBuffer is only implemented for "
  215. "16, 32, and 64-bit types",
  216. loc);
  217. return nullptr;
  218. }
  219. break;
  220. }
  221. default:
  222. theEmitter.emitError(
  223. "templated load of ByteAddressBuffer is only implemented for "
  224. "16, 32, and 64-bit types",
  225. loc);
  226. return nullptr;
  227. }
  228. }
  229. // Vector types
  230. {
  231. QualType elemType = {};
  232. uint32_t elemCount = 0;
  233. if (isVectorType(targetType, &elemType, &elemCount)) {
  234. llvm::SmallVector<SpirvInstruction *, 4> loadedElems;
  235. for (uint32_t i = 0; i < elemCount; ++i) {
  236. loadedElems.push_back(
  237. processTemplatedLoadFromBuffer(buffer, index, elemType, bitOffset));
  238. }
  239. result =
  240. spvBuilder.createCompositeConstruct(targetType, loadedElems, loc);
  241. result->setRValue();
  242. return result;
  243. }
  244. }
  245. // Array types
  246. {
  247. QualType elemType = {};
  248. uint32_t elemCount = 0;
  249. if (const auto *arrType = astContext.getAsConstantArrayType(targetType)) {
  250. elemCount = static_cast<uint32_t>(arrType->getSize().getZExtValue());
  251. elemType = arrType->getElementType();
  252. llvm::SmallVector<SpirvInstruction *, 4> loadedElems;
  253. for (uint32_t i = 0; i < elemCount; ++i) {
  254. loadedElems.push_back(
  255. processTemplatedLoadFromBuffer(buffer, index, elemType, bitOffset));
  256. }
  257. result =
  258. spvBuilder.createCompositeConstruct(targetType, loadedElems, loc);
  259. result->setRValue();
  260. return result;
  261. }
  262. }
  263. // Matrix types
  264. {
  265. QualType elemType = {};
  266. uint32_t numRows = 0, numCols = 0;
  267. if (isMxNMatrix(targetType, &elemType, &numRows, &numCols)) {
  268. llvm::SmallVector<SpirvInstruction *, 4> loadedElems;
  269. llvm::SmallVector<SpirvInstruction *, 4> loadedRows;
  270. for (uint32_t i = 0; i < numRows; ++i) {
  271. for (uint32_t j = 0; j < numCols; ++j) {
  272. // TODO: This is currently doing a row_major matrix load. We must
  273. // investigate whether we also need to implement it for column_major.
  274. loadedElems.push_back(processTemplatedLoadFromBuffer(
  275. buffer, index, elemType, bitOffset));
  276. }
  277. const auto rowType = astContext.getExtVectorType(elemType, numCols);
  278. loadedRows.push_back(
  279. spvBuilder.createCompositeConstruct(rowType, loadedElems, loc));
  280. loadedElems.clear();
  281. }
  282. result = spvBuilder.createCompositeConstruct(targetType, loadedRows, loc);
  283. result->setRValue();
  284. return result;
  285. }
  286. }
  287. // Struct types
  288. // The "natural" layout for structure types dictates that structs are
  289. // aligned like their field with the largest alignment.
  290. // As a result, there might exist some padding after some struct members.
  291. if (const auto *structType = targetType->getAs<RecordType>()) {
  292. const auto *decl = structType->getDecl();
  293. SpirvInstruction *originalIndex = index;
  294. uint32_t originalBitOffset = bitOffset;
  295. llvm::SmallVector<SpirvInstruction *, 4> loadedElems;
  296. uint32_t fieldOffsetInBytes = 0;
  297. uint32_t structAlignment = 0, structSize = 0, stride = 0;
  298. std::tie(structAlignment, structSize) =
  299. AlignmentSizeCalculator(astContext, theEmitter.getSpirvOptions())
  300. .getAlignmentAndSize(targetType,
  301. theEmitter.getSpirvOptions().sBufferLayoutRule,
  302. llvm::None, &stride);
  303. for (const auto *field : decl->fields()) {
  304. AlignmentSizeCalculator alignmentCalc(astContext,
  305. theEmitter.getSpirvOptions());
  306. uint32_t fieldSize = 0, fieldAlignment = 0;
  307. std::tie(fieldAlignment, fieldSize) = alignmentCalc.getAlignmentAndSize(
  308. field->getType(), theEmitter.getSpirvOptions().sBufferLayoutRule,
  309. /*isRowMajor*/ llvm::None, &stride);
  310. fieldOffsetInBytes = roundToPow2(fieldOffsetInBytes, fieldAlignment);
  311. const auto wordOffset =
  312. ((originalBitOffset / 8) + fieldOffsetInBytes) / 4;
  313. bitOffset = (((originalBitOffset / 8) + fieldOffsetInBytes) % 4) * 8;
  314. if (wordOffset != 0) {
  315. // Divide the fieldOffset by 4 to figure out how much to increment the
  316. // index into the buffer (increment occurs by 32-bit words since the
  317. // underlying type is an array of uints).
  318. // The remainder by four tells us the *byte offset* (then multiply by 8
  319. // to get bit offset).
  320. index = spvBuilder.createBinaryOp(
  321. spv::Op::OpIAdd, astContext.UnsignedIntTy, originalIndex,
  322. spvBuilder.getConstantInt(astContext.UnsignedIntTy,
  323. llvm::APInt(32, wordOffset)),
  324. loc);
  325. }
  326. loadedElems.push_back(processTemplatedLoadFromBuffer(
  327. buffer, index, field->getType(), bitOffset));
  328. fieldOffsetInBytes += fieldSize;
  329. }
  330. // After we're done with loading the entire struct, we need to update the
  331. // index and bitOffset (in case we are loading an array of structs).
  332. //
  333. // Example: struct alignment = 8. struct size = 34 bytes
  334. // (34 / 8) = 4 full words
  335. // (34 % 8) = 2 > 0, therefore need to move to the next aligned address
  336. // So the starting byte offset after loading the entire struct is:
  337. // 8 * (4 + 1) = 40
  338. assert(structAlignment != 0);
  339. uint32_t newByteOffset = roundToPow2(structSize, structAlignment);
  340. uint32_t newWordOffset = ((originalBitOffset / 8) + newByteOffset) / 4;
  341. bitOffset = 8 * (((originalBitOffset / 8) + newByteOffset) % 4);
  342. index = spvBuilder.createBinaryOp(
  343. spv::Op::OpIAdd, astContext.UnsignedIntTy, originalIndex,
  344. spvBuilder.getConstantInt(astContext.UnsignedIntTy,
  345. llvm::APInt(32, newWordOffset)),
  346. loc);
  347. result = spvBuilder.createCompositeConstruct(targetType, loadedElems, loc);
  348. result->setRValue();
  349. return result;
  350. }
  351. llvm_unreachable("templated buffer load unimplemented for type");
  352. }
  353. void RawBufferHandler::store16BitsAtBitOffset0(SpirvInstruction *value,
  354. SpirvInstruction *buffer,
  355. SpirvInstruction *&index,
  356. const QualType valueType) {
  357. const auto loc = buffer->getSourceLocation();
  358. SpirvInstruction *result = nullptr;
  359. auto *constUint0 =
  360. spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0));
  361. // The underlying element type of the ByteAddressBuffer is uint. So we
  362. // need to store a 32-bit value.
  363. auto *ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, buffer,
  364. {constUint0, index}, loc);
  365. result = bitCastToNumericalOrBool(value, valueType,
  366. astContext.UnsignedShortTy, loc);
  367. result = spvBuilder.createUnaryOp(spv::Op::OpUConvert,
  368. astContext.UnsignedIntTy, result, loc);
  369. spvBuilder.createStore(ptr, result, loc);
  370. }
  371. void RawBufferHandler::store16BitsAtBitOffset16(SpirvInstruction *value,
  372. SpirvInstruction *buffer,
  373. SpirvInstruction *&index,
  374. const QualType valueType) {
  375. const auto loc = buffer->getSourceLocation();
  376. SpirvInstruction *result = nullptr;
  377. auto *constUint0 =
  378. spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0));
  379. auto *constUint1 =
  380. spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 1));
  381. auto *constUint16 =
  382. spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 16));
  383. // The underlying element type of the ByteAddressBuffer is uint. So we
  384. // need to store a 32-bit value.
  385. auto *ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, buffer,
  386. {constUint0, index}, loc);
  387. result = bitCastToNumericalOrBool(value, valueType,
  388. astContext.UnsignedShortTy, loc);
  389. result = spvBuilder.createUnaryOp(spv::Op::OpUConvert,
  390. astContext.UnsignedIntTy, result, loc);
  391. result = spvBuilder.createBinaryOp(spv::Op::OpShiftLeftLogical,
  392. astContext.UnsignedIntTy, result,
  393. constUint16, loc);
  394. result = spvBuilder.createBinaryOp(
  395. spv::Op::OpBitwiseOr, astContext.UnsignedIntTy,
  396. spvBuilder.createLoad(astContext.UnsignedIntTy, ptr, loc), result, loc);
  397. spvBuilder.createStore(ptr, result, loc);
  398. index = spvBuilder.createBinaryOp(spv::Op::OpIAdd, astContext.UnsignedIntTy,
  399. index, constUint1, loc);
  400. }
  401. void RawBufferHandler::store32BitsAtBitOffset0(SpirvInstruction *value,
  402. SpirvInstruction *buffer,
  403. SpirvInstruction *&index,
  404. const QualType valueType) {
  405. const auto loc = buffer->getSourceLocation();
  406. auto *constUint0 =
  407. spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0));
  408. auto *constUint1 =
  409. spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 1));
  410. // The underlying element type of the ByteAddressBuffer is uint. So we
  411. // need to store a 32-bit value.
  412. auto *ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, buffer,
  413. {constUint0, index}, loc);
  414. value =
  415. bitCastToNumericalOrBool(value, valueType, astContext.UnsignedIntTy, loc);
  416. spvBuilder.createStore(ptr, value, loc);
  417. index = spvBuilder.createBinaryOp(spv::Op::OpIAdd, astContext.UnsignedIntTy,
  418. index, constUint1, loc);
  419. }
  420. void RawBufferHandler::store64BitsAtBitOffset0(SpirvInstruction *value,
  421. SpirvInstruction *buffer,
  422. SpirvInstruction *&index,
  423. const QualType valueType) {
  424. const auto loc = buffer->getSourceLocation();
  425. auto *constUint0 =
  426. spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0));
  427. auto *constUint1 =
  428. spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 1));
  429. auto *constUint32 =
  430. spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 32));
  431. // The underlying element type of the ByteAddressBuffer is uint. So we
  432. // need to store two 32-bit values.
  433. auto *ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, buffer,
  434. {constUint0, index}, loc);
  435. // First convert the 64-bit value to uint64_t. Then extract two 32-bit words
  436. // from it.
  437. value = bitCastToNumericalOrBool(value, valueType,
  438. astContext.UnsignedLongLongTy, loc);
  439. // Use OpUConvert to perform truncation (produces the least significant bits).
  440. SpirvInstruction *lsb = spvBuilder.createUnaryOp(
  441. spv::Op::OpUConvert, astContext.UnsignedIntTy, value, loc);
  442. // Shift uint64_t to the right by 32 bits and truncate to get the most
  443. // significant bits.
  444. SpirvInstruction *msb = spvBuilder.createUnaryOp(
  445. spv::Op::OpUConvert, astContext.UnsignedIntTy,
  446. spvBuilder.createBinaryOp(spv::Op::OpShiftRightLogical,
  447. astContext.UnsignedLongLongTy, value,
  448. constUint32, loc),
  449. loc);
  450. spvBuilder.createStore(ptr, lsb, loc);
  451. index = spvBuilder.createBinaryOp(spv::Op::OpIAdd, astContext.UnsignedIntTy,
  452. index, constUint1, loc);
  453. ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, buffer,
  454. {constUint0, index}, loc);
  455. spvBuilder.createStore(ptr, msb, loc);
  456. index = spvBuilder.createBinaryOp(spv::Op::OpIAdd, astContext.UnsignedIntTy,
  457. index, constUint1, loc);
  458. }
  459. void RawBufferHandler::storeArrayOfScalars(
  460. std::deque<SpirvInstruction *> values, SpirvInstruction *buffer,
  461. SpirvInstruction *&index, const QualType valueType, uint32_t &bitOffset,
  462. SourceLocation loc) {
  463. auto *constUint0 =
  464. spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0));
  465. auto *constUint1 =
  466. spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 1));
  467. auto *constUint16 =
  468. spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 16));
  469. const auto storeWidth = getElementSpirvBitwidth(
  470. astContext, valueType, theEmitter.getSpirvOptions().enable16BitTypes);
  471. const uint32_t elemCount = values.size();
  472. if (storeWidth == 16u) {
  473. uint32_t elemIndex = 0;
  474. if (bitOffset == 16) {
  475. // First store the first element at offset 16 of the last memory index.
  476. store16BitsAtBitOffset16(values[0], buffer, index, valueType);
  477. bitOffset = 0;
  478. ++elemIndex;
  479. }
  480. // Do a custom store based on the number of elements.
  481. for (; elemIndex < elemCount; elemIndex = elemIndex + 2) {
  482. // The underlying element type of the ByteAddressBuffer is uint. So we
  483. // need to store a 32-bit value by combining two 16-bit values.
  484. SpirvInstruction *word = nullptr;
  485. word = bitCastToNumericalOrBool(values[elemIndex], valueType,
  486. astContext.UnsignedShortTy, loc);
  487. // Zero-extend to 32 bits.
  488. word = spvBuilder.createUnaryOp(spv::Op::OpUConvert,
  489. astContext.UnsignedIntTy, word, loc);
  490. if (elemIndex + 1 < elemCount) {
  491. SpirvInstruction *msb = nullptr;
  492. msb = bitCastToNumericalOrBool(values[elemIndex + 1], valueType,
  493. astContext.UnsignedShortTy, loc);
  494. msb = spvBuilder.createUnaryOp(spv::Op::OpUConvert,
  495. astContext.UnsignedIntTy, msb, loc);
  496. msb = spvBuilder.createBinaryOp(spv::Op::OpShiftLeftLogical,
  497. astContext.UnsignedIntTy, msb,
  498. constUint16, loc);
  499. word = spvBuilder.createBinaryOp(
  500. spv::Op::OpBitwiseOr, astContext.UnsignedIntTy, word, msb, loc);
  501. // We will store two 16-bit values.
  502. bitOffset = (bitOffset + 32) % 32;
  503. } else {
  504. // We will store one 16-bit value.
  505. bitOffset = (bitOffset + 16) % 32;
  506. }
  507. auto *ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, buffer,
  508. {constUint0, index}, loc);
  509. spvBuilder.createStore(ptr, word, loc);
  510. index = spvBuilder.createBinaryOp(
  511. spv::Op::OpIAdd, astContext.UnsignedIntTy, index, constUint1, loc);
  512. }
  513. } else if (storeWidth == 32u || storeWidth == 64u) {
  514. assert(bitOffset == 0);
  515. for (uint32_t i = 0; i < elemCount; ++i)
  516. processTemplatedStoreToBuffer(values[i], buffer, index, valueType, bitOffset);
  517. }
  518. }
  519. QualType RawBufferHandler::serializeToScalarsOrStruct(
  520. std::deque<SpirvInstruction *> *values, QualType valueType,
  521. SourceLocation loc) {
  522. uint32_t size = values->size();
  523. // Vector type
  524. {
  525. QualType elemType = {};
  526. uint32_t elemCount = 0;
  527. if (isVectorType(valueType, &elemType, &elemCount)) {
  528. for (uint32_t i = 0; i < size; ++i) {
  529. for (uint32_t j = 0; j < elemCount; ++j) {
  530. values->push_back(spvBuilder.createCompositeExtract(
  531. elemType, values->front(), {j}, loc));
  532. }
  533. values->pop_front();
  534. }
  535. return elemType;
  536. }
  537. }
  538. // Matrix type
  539. {
  540. QualType elemType = {};
  541. uint32_t numRows = 0, numCols = 0;
  542. if (isMxNMatrix(valueType, &elemType, &numRows, &numCols)) {
  543. for (uint32_t i = 0; i < size; ++i) {
  544. for (uint32_t j = 0; j < numRows; ++j) {
  545. for (uint32_t k = 0; k < numCols; ++k) {
  546. // TODO: This is currently doing a row_major matrix store. We must
  547. // investigate whether we also need to implement it for
  548. // column_major.
  549. values->push_back(spvBuilder.createCompositeExtract(
  550. elemType, values->front(), {j, k}, loc));
  551. }
  552. }
  553. values->pop_front();
  554. }
  555. return serializeToScalarsOrStruct(values, elemType, loc);
  556. }
  557. }
  558. // Array type
  559. {
  560. if (const auto *arrType = astContext.getAsConstantArrayType(valueType)) {
  561. const uint32_t arrElemCount =
  562. static_cast<uint32_t>(arrType->getSize().getZExtValue());
  563. const QualType arrElemType = arrType->getElementType();
  564. for (uint32_t i = 0; i < size; ++i) {
  565. for (uint32_t j = 0; j < arrElemCount; ++j) {
  566. values->push_back(spvBuilder.createCompositeExtract(
  567. arrElemType, values->front(), {j}, loc));
  568. }
  569. values->pop_front();
  570. }
  571. return serializeToScalarsOrStruct(values, arrElemType, loc);
  572. }
  573. }
  574. if (isScalarType(valueType))
  575. return valueType;
  576. if (const auto *structType = valueType->getAs<RecordType>())
  577. return valueType;
  578. llvm_unreachable("unhandled type when serializing an array");
  579. }
  580. void RawBufferHandler::processTemplatedStoreToBuffer(SpirvInstruction *value,
  581. SpirvInstruction *buffer,
  582. SpirvInstruction *&index,
  583. const QualType valueType,
  584. uint32_t &bitOffset) {
  585. assert(bitOffset == 0 || bitOffset == 16);
  586. const auto loc = buffer->getSourceLocation();
  587. // Scalar types
  588. if (isScalarType(valueType)) {
  589. auto storeWidth = getElementSpirvBitwidth(
  590. astContext, valueType, theEmitter.getSpirvOptions().enable16BitTypes);
  591. switch (bitOffset) {
  592. case 0: {
  593. switch (storeWidth) {
  594. case 16:
  595. store16BitsAtBitOffset0(value, buffer, index, valueType);
  596. return;
  597. case 32:
  598. store32BitsAtBitOffset0(value, buffer, index, valueType);
  599. return;
  600. case 64:
  601. store64BitsAtBitOffset0(value, buffer, index, valueType);
  602. return;
  603. default:
  604. theEmitter.emitError(
  605. "templated load of ByteAddressBuffer is only implemented for "
  606. "16, 32, and 64-bit types",
  607. loc);
  608. return;
  609. }
  610. }
  611. case 16: {
  612. // The only legal store at offset 16 is by a 16-bit value.
  613. assert(storeWidth == 16);
  614. store16BitsAtBitOffset16(value, buffer, index, valueType);
  615. return;
  616. }
  617. default:
  618. theEmitter.emitError(
  619. "templated load of ByteAddressBuffer is only implemented for "
  620. "16, 32, and 64-bit types",
  621. loc);
  622. return;
  623. }
  624. }
  625. // Vectors, Matrices, and Arrays can all be serialized and stored.
  626. if (isVectorType(valueType) || isMxNMatrix(valueType) ||
  627. isConstantArrayType(astContext, valueType)) {
  628. std::deque<SpirvInstruction *> elems;
  629. elems.push_back(value);
  630. auto serializedType = serializeToScalarsOrStruct(&elems, valueType, loc);
  631. if (isScalarType(serializedType)) {
  632. storeArrayOfScalars(elems, buffer, index, serializedType, bitOffset, loc);
  633. } else if (const auto *structType = serializedType->getAs<RecordType>()) {
  634. for (auto elem : elems)
  635. processTemplatedStoreToBuffer(elem, buffer, index, serializedType,
  636. bitOffset);
  637. }
  638. return;
  639. }
  640. // Struct types
  641. // The "natural" layout for structure types dictates that structs are
  642. // aligned like their field with the largest alignment.
  643. // As a result, there might exist some padding after some struct members.
  644. if (const auto *structType = valueType->getAs<RecordType>()) {
  645. const auto *decl = structType->getDecl();
  646. SpirvInstruction *originalIndex = index;
  647. const auto originalBitOffset = bitOffset;
  648. uint32_t fieldOffsetInBytes = 0;
  649. uint32_t structAlignment = 0, structSize = 0, stride = 0;
  650. std::tie(structAlignment, structSize) =
  651. AlignmentSizeCalculator(astContext, theEmitter.getSpirvOptions())
  652. .getAlignmentAndSize(valueType,
  653. theEmitter.getSpirvOptions().sBufferLayoutRule,
  654. llvm::None, &stride);
  655. uint32_t fieldIndex = 0;
  656. for (const auto *field : decl->fields()) {
  657. AlignmentSizeCalculator alignmentCalc(astContext,
  658. theEmitter.getSpirvOptions());
  659. uint32_t fieldSize = 0, fieldAlignment = 0;
  660. std::tie(fieldAlignment, fieldSize) = alignmentCalc.getAlignmentAndSize(
  661. field->getType(), theEmitter.getSpirvOptions().sBufferLayoutRule,
  662. /*isRowMajor*/ llvm::None, &stride);
  663. fieldOffsetInBytes = roundToPow2(fieldOffsetInBytes, fieldAlignment);
  664. const auto wordOffset =
  665. ((originalBitOffset / 8) + fieldOffsetInBytes) / 4;
  666. bitOffset = (((originalBitOffset / 8) + fieldOffsetInBytes) % 4) * 8;
  667. if (wordOffset != 0) {
  668. // Divide the fieldOffset by 4 to figure out how much to increment the
  669. // index into the buffer (increment occurs by 32-bit words since the
  670. // underlying type is an array of uints).
  671. // The remainder by four tells us the *byte offset* (then multiply by 8
  672. // to get bit offset).
  673. index = spvBuilder.createBinaryOp(
  674. spv::Op::OpIAdd, astContext.UnsignedIntTy, originalIndex,
  675. spvBuilder.getConstantInt(astContext.UnsignedIntTy,
  676. llvm::APInt(32, wordOffset)),
  677. loc);
  678. }
  679. processTemplatedStoreToBuffer(
  680. spvBuilder.createCompositeExtract(field->getType(), value,
  681. {fieldIndex}, loc),
  682. buffer, index, field->getType(), bitOffset);
  683. fieldOffsetInBytes += fieldSize;
  684. ++fieldIndex;
  685. }
  686. // After we're done with storing the entire struct, we need to update the
  687. // index (in case we are stroring an array of structs).
  688. //
  689. // Example: struct alignment = 8. struct size = 34 bytes
  690. // (34 / 8) = 4 full words
  691. // (34 % 8) = 2 > 0, therefore need to move to the next aligned address
  692. // So the starting byte offset after loading the entire struct is:
  693. // 8 * (4 + 1) = 40
  694. assert(structAlignment != 0);
  695. uint32_t newByteOffset = roundToPow2(structSize, structAlignment);
  696. uint32_t newWordOffset = ((originalBitOffset / 8) + newByteOffset) / 4;
  697. bitOffset = 8 * (((originalBitOffset / 8) + newByteOffset) % 4);
  698. index = spvBuilder.createBinaryOp(
  699. spv::Op::OpIAdd, astContext.UnsignedIntTy, originalIndex,
  700. spvBuilder.getConstantInt(astContext.UnsignedIntTy,
  701. llvm::APInt(32, newWordOffset)),
  702. loc);
  703. return;
  704. }
  705. llvm_unreachable("templated buffer store unimplemented for type");
  706. }
  707. } // namespace spirv
  708. } // namespace clang