SpirvContext.cpp 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531
  1. //===--- SpirvContext.cpp - SPIR-V SpirvContext implementation-------------===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. #include <algorithm>
  10. #include <tuple>
  11. #include "clang/SPIRV/SpirvContext.h"
  12. #include "clang/SPIRV/SpirvModule.h"
  13. namespace clang {
  14. namespace spirv {
  15. SpirvContext::SpirvContext()
  16. : allocator(), voidType(nullptr), boolType(nullptr), sintTypes({}),
  17. uintTypes({}), floatTypes({}), samplerType(nullptr),
  18. curShaderModelKind(ShaderModelKind::Invalid), majorVersion(0),
  19. minorVersion(0), currentLexicalScope(nullptr) {
  20. voidType = new (this) VoidType;
  21. boolType = new (this) BoolType;
  22. samplerType = new (this) SamplerType;
  23. accelerationStructureTypeNV = new (this) AccelerationStructureTypeNV;
  24. rayQueryTypeKHR = new (this) RayQueryTypeKHR;
  25. }
  26. SpirvContext::~SpirvContext() {
  27. voidType->~VoidType();
  28. boolType->~BoolType();
  29. samplerType->~SamplerType();
  30. accelerationStructureTypeNV->~AccelerationStructureTypeNV();
  31. rayQueryTypeKHR->~RayQueryTypeKHR();
  32. for (auto *sintType : sintTypes)
  33. if (sintType) // sintTypes may contain nullptr
  34. sintType->~IntegerType();
  35. for (auto *uintType : uintTypes)
  36. if (uintType) // uintTypes may contain nullptr
  37. uintType->~IntegerType();
  38. for (auto *floatType : floatTypes)
  39. if (floatType) // floatTypes may contain nullptr
  40. floatType->~FloatType();
  41. for (auto &pair : vecTypes)
  42. for (auto *vecType : pair.second)
  43. if (vecType) // vecTypes may contain nullptr
  44. vecType->~VectorType();
  45. for (auto &pair : matTypes)
  46. for (auto *matType : pair.second)
  47. matType->~MatrixType();
  48. for (auto *arrType : arrayTypes)
  49. arrType->~ArrayType();
  50. for (auto *raType : runtimeArrayTypes)
  51. raType->~RuntimeArrayType();
  52. for (auto *fnType : functionTypes)
  53. fnType->~FunctionType();
  54. for (auto *structType : structTypes)
  55. structType->~StructType();
  56. for (auto *hybridStructType : hybridStructTypes)
  57. hybridStructType->~HybridStructType();
  58. for (auto pair : sampledImageTypes)
  59. pair.second->~SampledImageType();
  60. for (auto *hybridSampledImageType : hybridSampledImageTypes)
  61. hybridSampledImageType->~HybridSampledImageType();
  62. for (auto *imgType : imageTypes)
  63. imgType->~ImageType();
  64. for (auto &pair : pointerTypes)
  65. for (auto &scPtrTypePair : pair.second)
  66. scPtrTypePair.second->~SpirvPointerType();
  67. for (auto *hybridPtrType : hybridPointerTypes)
  68. hybridPtrType->~HybridPointerType();
  69. for (auto &typePair : debugTypes)
  70. typePair.second->releaseMemory();
  71. for (auto &typePair : typeTemplates)
  72. typePair.second->releaseMemory();
  73. for (auto &typePair : typeTemplateParams)
  74. typePair.second->releaseMemory();
  75. }
  76. inline uint32_t log2ForBitwidth(uint32_t bitwidth) {
  77. assert(bitwidth >= 8 && bitwidth <= 64 && llvm::isPowerOf2_32(bitwidth));
  78. return llvm::Log2_32(bitwidth);
  79. }
  80. const IntegerType *SpirvContext::getSIntType(uint32_t bitwidth) {
  81. auto &type = sintTypes[log2ForBitwidth(bitwidth)];
  82. if (type == nullptr) {
  83. type = new (this) IntegerType(bitwidth, true);
  84. }
  85. return type;
  86. }
  87. const IntegerType *SpirvContext::getUIntType(uint32_t bitwidth) {
  88. auto &type = uintTypes[log2ForBitwidth(bitwidth)];
  89. if (type == nullptr) {
  90. type = new (this) IntegerType(bitwidth, false);
  91. }
  92. return type;
  93. }
  94. const FloatType *SpirvContext::getFloatType(uint32_t bitwidth) {
  95. auto &type = floatTypes[log2ForBitwidth(bitwidth)];
  96. if (type == nullptr) {
  97. type = new (this) FloatType(bitwidth);
  98. }
  99. return type;
  100. }
  101. const VectorType *SpirvContext::getVectorType(const SpirvType *elemType,
  102. uint32_t count) {
  103. // We are certain this should be a scalar type. Otherwise, cast causes an
  104. // assertion failure.
  105. const ScalarType *scalarType = cast<ScalarType>(elemType);
  106. assert(count == 2 || count == 3 || count == 4);
  107. auto found = vecTypes.find(scalarType);
  108. if (found != vecTypes.end()) {
  109. auto &type = found->second[count];
  110. if (type != nullptr)
  111. return type;
  112. } else {
  113. // Make sure to initialize since std::array is "an aggregate type with the
  114. // same semantics as a struct holding a C-style array T[N]".
  115. vecTypes[scalarType] = {};
  116. }
  117. return vecTypes[scalarType][count] = new (this) VectorType(scalarType, count);
  118. }
  119. const SpirvType *SpirvContext::getMatrixType(const SpirvType *elemType,
  120. uint32_t count) {
  121. // We are certain this should be a vector type. Otherwise, cast causes an
  122. // assertion failure.
  123. const VectorType *vecType = cast<VectorType>(elemType);
  124. assert(count == 2 || count == 3 || count == 4);
  125. // In the case of non-floating-point matrices, we represent them as array of
  126. // vectors.
  127. if (!isa<FloatType>(vecType->getElementType())) {
  128. return getArrayType(elemType, count, llvm::None);
  129. }
  130. auto foundVec = matTypes.find(vecType);
  131. if (foundVec != matTypes.end()) {
  132. const auto &matVector = foundVec->second;
  133. // Create a temporary object for finding in the vector.
  134. MatrixType type(vecType, count);
  135. for (const auto *cachedType : matVector)
  136. if (type == *cachedType)
  137. return cachedType;
  138. }
  139. const auto *ptr = new (this) MatrixType(vecType, count);
  140. matTypes[vecType].push_back(ptr);
  141. return ptr;
  142. }
  143. const ImageType *
  144. SpirvContext::getImageType(const ImageType *imageTypeWithUnknownFormat,
  145. spv::ImageFormat format) {
  146. return getImageType(imageTypeWithUnknownFormat->getSampledType(),
  147. imageTypeWithUnknownFormat->getDimension(),
  148. imageTypeWithUnknownFormat->getDepth(),
  149. imageTypeWithUnknownFormat->isArrayedImage(),
  150. imageTypeWithUnknownFormat->isMSImage(),
  151. imageTypeWithUnknownFormat->withSampler(), format);
  152. }
  153. const ImageType *SpirvContext::getImageType(const SpirvType *sampledType,
  154. spv::Dim dim,
  155. ImageType::WithDepth depth,
  156. bool arrayed, bool ms,
  157. ImageType::WithSampler sampled,
  158. spv::ImageFormat format) {
  159. // We are certain this should be a numerical type. Otherwise, cast causes an
  160. // assertion failure.
  161. const NumericalType *elemType = cast<NumericalType>(sampledType);
  162. // Create a temporary object for finding in the set.
  163. ImageType type(elemType, dim, depth, arrayed, ms, sampled, format);
  164. auto found = imageTypes.find(&type);
  165. if (found != imageTypes.end())
  166. return *found;
  167. auto inserted = imageTypes.insert(
  168. new (this) ImageType(elemType, dim, depth, arrayed, ms, sampled, format));
  169. return *(inserted.first);
  170. }
  171. const SampledImageType *
  172. SpirvContext::getSampledImageType(const ImageType *image) {
  173. auto found = sampledImageTypes.find(image);
  174. if (found != sampledImageTypes.end())
  175. return found->second;
  176. return sampledImageTypes[image] = new (this) SampledImageType(image);
  177. }
  178. const HybridSampledImageType *
  179. SpirvContext::getSampledImageType(QualType image) {
  180. const HybridSampledImageType *result =
  181. new (this) HybridSampledImageType(image);
  182. hybridSampledImageTypes.push_back(result);
  183. return result;
  184. }
  185. const ArrayType *
  186. SpirvContext::getArrayType(const SpirvType *elemType, uint32_t elemCount,
  187. llvm::Optional<uint32_t> arrayStride) {
  188. ArrayType type(elemType, elemCount, arrayStride);
  189. auto found = arrayTypes.find(&type);
  190. if (found != arrayTypes.end())
  191. return *found;
  192. auto inserted =
  193. arrayTypes.insert(new (this) ArrayType(elemType, elemCount, arrayStride));
  194. // The return value is an (iterator, bool) pair. The boolean indicates whether
  195. // it was actually added as a new type.
  196. return *(inserted.first);
  197. }
  198. const RuntimeArrayType *
  199. SpirvContext::getRuntimeArrayType(const SpirvType *elemType,
  200. llvm::Optional<uint32_t> arrayStride) {
  201. RuntimeArrayType type(elemType, arrayStride);
  202. auto found = runtimeArrayTypes.find(&type);
  203. if (found != runtimeArrayTypes.end())
  204. return *found;
  205. auto inserted = runtimeArrayTypes.insert(
  206. new (this) RuntimeArrayType(elemType, arrayStride));
  207. return *(inserted.first);
  208. }
  209. const StructType *
  210. SpirvContext::getStructType(llvm::ArrayRef<StructType::FieldInfo> fields,
  211. llvm::StringRef name, bool isReadOnly,
  212. StructInterfaceType interfaceType) {
  213. // We are creating a temporary struct type here for querying whether the
  214. // same type was already created. It is a little bit costly, but we can
  215. // avoid allocating directly from the bump pointer allocator, from which
  216. // then we are unable to reclaim until the allocator itself is destroyed.
  217. StructType type(fields, name, isReadOnly, interfaceType);
  218. auto found = std::find_if(
  219. structTypes.begin(), structTypes.end(),
  220. [&type](const StructType *cachedType) { return type == *cachedType; });
  221. if (found != structTypes.end())
  222. return *found;
  223. structTypes.push_back(
  224. new (this) StructType(fields, name, isReadOnly, interfaceType));
  225. return structTypes.back();
  226. }
  227. const HybridStructType *SpirvContext::getHybridStructType(
  228. llvm::ArrayRef<HybridStructType::FieldInfo> fields, llvm::StringRef name,
  229. bool isReadOnly, StructInterfaceType interfaceType) {
  230. const HybridStructType *result =
  231. new (this) HybridStructType(fields, name, isReadOnly, interfaceType);
  232. hybridStructTypes.push_back(result);
  233. return result;
  234. }
  235. const SpirvPointerType *SpirvContext::getPointerType(const SpirvType *pointee,
  236. spv::StorageClass sc) {
  237. auto foundPointee = pointerTypes.find(pointee);
  238. if (foundPointee != pointerTypes.end()) {
  239. auto &pointeeMap = foundPointee->second;
  240. auto foundSC = pointeeMap.find(sc);
  241. if (foundSC != pointeeMap.end())
  242. return foundSC->second;
  243. }
  244. return pointerTypes[pointee][sc] = new (this) SpirvPointerType(pointee, sc);
  245. }
  246. const HybridPointerType *SpirvContext::getPointerType(QualType pointee,
  247. spv::StorageClass sc) {
  248. const HybridPointerType *result = new (this) HybridPointerType(pointee, sc);
  249. hybridPointerTypes.push_back(result);
  250. return result;
  251. }
  252. FunctionType *
  253. SpirvContext::getFunctionType(const SpirvType *ret,
  254. llvm::ArrayRef<const SpirvType *> param) {
  255. // Create a temporary object for finding in the set.
  256. FunctionType type(ret, param);
  257. auto found = functionTypes.find(&type);
  258. if (found != functionTypes.end())
  259. return *found;
  260. auto inserted = functionTypes.insert(new (this) FunctionType(ret, param));
  261. return *inserted.first;
  262. }
  263. const StructType *SpirvContext::getByteAddressBufferType(bool isWritable) {
  264. // Create a uint RuntimeArray.
  265. const auto *raType =
  266. getRuntimeArrayType(getUIntType(32), /* ArrayStride */ 4);
  267. // Create a struct containing the runtime array as its only member.
  268. return getStructType(
  269. {StructType::FieldInfo(raType, /*name*/ "", /*offset*/ 0)},
  270. isWritable ? "type.RWByteAddressBuffer" : "type.ByteAddressBuffer",
  271. !isWritable, StructInterfaceType::StorageBuffer);
  272. }
  273. const StructType *SpirvContext::getACSBufferCounterType() {
  274. // Create int32.
  275. const auto *int32Type = getSIntType(32);
  276. // Create a struct containing the integer counter as its only member.
  277. const StructType *type =
  278. getStructType({StructType::FieldInfo(int32Type, "counter", /*offset*/ 0)},
  279. "type.ACSBuffer.counter",
  280. /*isReadOnly*/ false, StructInterfaceType::StorageBuffer);
  281. return type;
  282. }
  283. SpirvDebugType *SpirvContext::getDebugTypeBasic(const SpirvType *spirvType,
  284. llvm::StringRef name,
  285. SpirvConstant *size,
  286. uint32_t encoding) {
  287. // Reuse existing debug type if possible.
  288. if (debugTypes.find(spirvType) != debugTypes.end())
  289. return debugTypes[spirvType];
  290. auto *debugType = new (this) SpirvDebugTypeBasic(name, size, encoding);
  291. debugTypes[spirvType] = debugType;
  292. return debugType;
  293. }
  294. SpirvDebugType *
  295. SpirvContext::getDebugTypeMember(llvm::StringRef name, SpirvDebugType *type,
  296. SpirvDebugSource *source, uint32_t line,
  297. uint32_t column, SpirvDebugInstruction *parent,
  298. uint32_t flags, uint32_t offsetInBits,
  299. uint32_t sizeInBits, const APValue *value) {
  300. // NOTE: Do not search it in debugTypes because it would have the same
  301. // spirvType but has different parent i.e., type composite.
  302. SpirvDebugTypeMember *debugType =
  303. new (this) SpirvDebugTypeMember(name, type, source, line, column, parent,
  304. flags, offsetInBits, sizeInBits, value);
  305. return debugType;
  306. }
  307. SpirvDebugTypeComposite *SpirvContext::getDebugTypeComposite(
  308. const SpirvType *spirvType, llvm::StringRef name, SpirvDebugSource *source,
  309. uint32_t line, uint32_t column, SpirvDebugInstruction *parent,
  310. llvm::StringRef linkageName, uint32_t flags, uint32_t tag) {
  311. // Reuse existing debug type if possible.
  312. auto it = debugTypes.find(spirvType);
  313. if (it != debugTypes.end()) {
  314. assert(it->second != nullptr && isa<SpirvDebugTypeComposite>(it->second));
  315. return dyn_cast<SpirvDebugTypeComposite>(it->second);
  316. }
  317. auto *debugType = new (this) SpirvDebugTypeComposite(
  318. name, source, line, column, parent, linkageName, flags, tag);
  319. debugType->setDebugSpirvType(spirvType);
  320. debugTypes[spirvType] = debugType;
  321. return debugType;
  322. }
  323. SpirvDebugType *SpirvContext::getDebugType(const SpirvType *spirvType) {
  324. auto it = debugTypes.find(spirvType);
  325. if (it != debugTypes.end())
  326. return it->second;
  327. return nullptr;
  328. }
  329. SpirvDebugType *
  330. SpirvContext::getDebugTypeArray(const SpirvType *spirvType,
  331. SpirvDebugInstruction *elemType,
  332. llvm::ArrayRef<uint32_t> elemCount) {
  333. // Reuse existing debug type if possible.
  334. if (debugTypes.find(spirvType) != debugTypes.end())
  335. return debugTypes[spirvType];
  336. auto *eTy = dyn_cast<SpirvDebugType>(elemType);
  337. assert(eTy && "Element type must be a SpirvDebugType.");
  338. auto *debugType = new (this) SpirvDebugTypeArray(eTy, elemCount);
  339. debugTypes[spirvType] = debugType;
  340. return debugType;
  341. }
  342. SpirvDebugType *
  343. SpirvContext::getDebugTypeVector(const SpirvType *spirvType,
  344. SpirvDebugInstruction *elemType,
  345. uint32_t elemCount) {
  346. // Reuse existing debug type if possible.
  347. if (debugTypes.find(spirvType) != debugTypes.end())
  348. return debugTypes[spirvType];
  349. auto *eTy = dyn_cast<SpirvDebugType>(elemType);
  350. assert(eTy && "Element type must be a SpirvDebugType.");
  351. auto *debugType = new (this) SpirvDebugTypeVector(eTy, elemCount);
  352. debugTypes[spirvType] = debugType;
  353. return debugType;
  354. }
  355. SpirvDebugType *
  356. SpirvContext::getDebugTypeFunction(const SpirvType *spirvType, uint32_t flags,
  357. SpirvDebugType *ret,
  358. llvm::ArrayRef<SpirvDebugType *> params) {
  359. // Reuse existing debug type if possible.
  360. if (debugTypes.find(spirvType) != debugTypes.end())
  361. return debugTypes[spirvType];
  362. auto *debugType = new (this) SpirvDebugTypeFunction(flags, ret, params);
  363. debugTypes[spirvType] = debugType;
  364. return debugType;
  365. }
  366. SpirvDebugTypeTemplate *SpirvContext::createDebugTypeTemplate(
  367. const ClassTemplateSpecializationDecl *templateType,
  368. SpirvDebugInstruction *target,
  369. const llvm::SmallVector<SpirvDebugTypeTemplateParameter *, 2> &params) {
  370. auto *tempTy = getDebugTypeTemplate(templateType);
  371. if (tempTy != nullptr)
  372. return tempTy;
  373. tempTy = new (this) SpirvDebugTypeTemplate(target, params);
  374. typeTemplates[templateType] = tempTy;
  375. return tempTy;
  376. }
  377. SpirvDebugTypeTemplate *SpirvContext::getDebugTypeTemplate(
  378. const ClassTemplateSpecializationDecl *templateType) {
  379. auto it = typeTemplates.find(templateType);
  380. if (it != typeTemplates.end())
  381. return it->second;
  382. return nullptr;
  383. }
  384. SpirvDebugTypeTemplateParameter *SpirvContext::createDebugTypeTemplateParameter(
  385. const TemplateArgument *templateArg, llvm::StringRef name,
  386. SpirvDebugType *type, SpirvInstruction *value, SpirvDebugSource *source,
  387. uint32_t line, uint32_t column) {
  388. auto *param = getDebugTypeTemplateParameter(templateArg);
  389. if (param != nullptr)
  390. return param;
  391. param = new (this)
  392. SpirvDebugTypeTemplateParameter(name, type, value, source, line, column);
  393. typeTemplateParams[templateArg] = param;
  394. return param;
  395. }
  396. SpirvDebugTypeTemplateParameter *SpirvContext::getDebugTypeTemplateParameter(
  397. const TemplateArgument *templateArg) {
  398. auto it = typeTemplateParams.find(templateArg);
  399. if (it != typeTemplateParams.end())
  400. return it->second;
  401. return nullptr;
  402. }
  403. void SpirvContext::pushDebugLexicalScope(RichDebugInfo *info,
  404. SpirvDebugInstruction *scope) {
  405. assert((isa<SpirvDebugLexicalBlock>(scope) ||
  406. isa<SpirvDebugFunction>(scope) ||
  407. isa<SpirvDebugCompilationUnit>(scope) ||
  408. isa<SpirvDebugTypeComposite>(scope)) &&
  409. "Given scope is not a lexical scope");
  410. currentLexicalScope = scope;
  411. info->scopeStack.push_back(scope);
  412. }
  413. void SpirvContext::moveDebugTypesToModule(SpirvModule *module) {
  414. for (const auto &typePair : debugTypes) {
  415. module->addDebugInfo(typePair.second);
  416. if (auto *composite = dyn_cast<SpirvDebugTypeComposite>(typePair.second)) {
  417. for (auto *member : composite->getMembers()) {
  418. module->addDebugInfo(member);
  419. }
  420. }
  421. }
  422. for (const auto &typePair : typeTemplates) {
  423. module->addDebugInfo(typePair.second);
  424. }
  425. for (const auto &typePair : typeTemplateParams) {
  426. module->addDebugInfo(typePair.second);
  427. }
  428. debugTypes.clear();
  429. typeTemplates.clear();
  430. typeTemplateParams.clear();
  431. }
  432. } // end namespace spirv
  433. } // end namespace clang