SpirvContext.cpp 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. //===--- SpirvContext.cpp - SPIR-V SpirvContext implementation-------------===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. #include <algorithm>
  10. #include <tuple>
  11. #include "clang/SPIRV/SpirvContext.h"
  12. namespace clang {
  13. namespace spirv {
  14. SpirvContext::SpirvContext()
  15. : allocator(), voidType(nullptr), boolType(nullptr), sintTypes({}),
  16. uintTypes({}), floatTypes({}), samplerType(nullptr),
  17. curShaderModelKind(ShaderModelKind::Invalid), majorVersion(0),
  18. minorVersion(0) {
  19. voidType = new (this) VoidType;
  20. boolType = new (this) BoolType;
  21. samplerType = new (this) SamplerType;
  22. accelerationStructureTypeNV = new (this) AccelerationStructureTypeNV;
  23. rayQueryProvisionalTypeKHR = new (this) RayQueryProvisionalTypeKHR;
  24. }
  25. inline uint32_t log2ForBitwidth(uint32_t bitwidth) {
  26. assert(bitwidth >= 16 && bitwidth <= 64 && llvm::isPowerOf2_32(bitwidth));
  27. return llvm::Log2_32(bitwidth);
  28. }
  29. const IntegerType *SpirvContext::getSIntType(uint32_t bitwidth) {
  30. auto &type = sintTypes[log2ForBitwidth(bitwidth)];
  31. if (type == nullptr) {
  32. type = new (this) IntegerType(bitwidth, true);
  33. }
  34. return type;
  35. }
  36. const IntegerType *SpirvContext::getUIntType(uint32_t bitwidth) {
  37. auto &type = uintTypes[log2ForBitwidth(bitwidth)];
  38. if (type == nullptr) {
  39. type = new (this) IntegerType(bitwidth, false);
  40. }
  41. return type;
  42. }
  43. const FloatType *SpirvContext::getFloatType(uint32_t bitwidth) {
  44. auto &type = floatTypes[log2ForBitwidth(bitwidth)];
  45. if (type == nullptr) {
  46. type = new (this) FloatType(bitwidth);
  47. }
  48. return type;
  49. }
  50. const VectorType *SpirvContext::getVectorType(const SpirvType *elemType,
  51. uint32_t count) {
  52. // We are certain this should be a scalar type. Otherwise, cast causes an
  53. // assertion failure.
  54. const ScalarType *scalarType = cast<ScalarType>(elemType);
  55. assert(count == 2 || count == 3 || count == 4);
  56. auto found = vecTypes.find(scalarType);
  57. if (found != vecTypes.end()) {
  58. auto &type = found->second[count];
  59. if (type != nullptr)
  60. return type;
  61. } else {
  62. // Make sure to initialize since std::array is "an aggregate type with the
  63. // same semantics as a struct holding a C-style array T[N]".
  64. vecTypes[scalarType] = {};
  65. }
  66. return vecTypes[scalarType][count] = new (this) VectorType(scalarType, count);
  67. }
  68. const SpirvType *SpirvContext::getMatrixType(const SpirvType *elemType,
  69. uint32_t count) {
  70. // We are certain this should be a vector type. Otherwise, cast causes an
  71. // assertion failure.
  72. const VectorType *vecType = cast<VectorType>(elemType);
  73. assert(count == 2 || count == 3 || count == 4);
  74. // In the case of non-floating-point matrices, we represent them as array of
  75. // vectors.
  76. if (!isa<FloatType>(vecType->getElementType())) {
  77. return getArrayType(elemType, count, llvm::None);
  78. }
  79. auto foundVec = matTypes.find(vecType);
  80. if (foundVec != matTypes.end()) {
  81. const auto &matVector = foundVec->second;
  82. // Create a temporary object for finding in the vector.
  83. MatrixType type(vecType, count);
  84. for (const auto *cachedType : matVector)
  85. if (type == *cachedType)
  86. return cachedType;
  87. }
  88. const auto *ptr = new (this) MatrixType(vecType, count);
  89. matTypes[vecType].push_back(ptr);
  90. return ptr;
  91. }
  92. const ImageType *SpirvContext::getImageType(const SpirvType *sampledType,
  93. spv::Dim dim,
  94. ImageType::WithDepth depth,
  95. bool arrayed, bool ms,
  96. ImageType::WithSampler sampled,
  97. spv::ImageFormat format) {
  98. // We are certain this should be a numerical type. Otherwise, cast causes an
  99. // assertion failure.
  100. const NumericalType *elemType = cast<NumericalType>(sampledType);
  101. // Create a temporary object for finding in the set.
  102. ImageType type(elemType, dim, depth, arrayed, ms, sampled, format);
  103. auto found = imageTypes.find(&type);
  104. if (found != imageTypes.end())
  105. return *found;
  106. auto inserted = imageTypes.insert(
  107. new (this) ImageType(elemType, dim, depth, arrayed, ms, sampled, format));
  108. return *(inserted.first);
  109. }
  110. const SampledImageType *
  111. SpirvContext::getSampledImageType(const ImageType *image) {
  112. auto found = sampledImageTypes.find(image);
  113. if (found != sampledImageTypes.end())
  114. return found->second;
  115. return sampledImageTypes[image] = new (this) SampledImageType(image);
  116. }
  117. const HybridSampledImageType *
  118. SpirvContext::getSampledImageType(QualType image) {
  119. return new (this) HybridSampledImageType(image);
  120. }
  121. const ArrayType *
  122. SpirvContext::getArrayType(const SpirvType *elemType, uint32_t elemCount,
  123. llvm::Optional<uint32_t> arrayStride) {
  124. ArrayType type(elemType, elemCount, arrayStride);
  125. auto found = arrayTypes.find(&type);
  126. if (found != arrayTypes.end())
  127. return *found;
  128. auto inserted =
  129. arrayTypes.insert(new (this) ArrayType(elemType, elemCount, arrayStride));
  130. // The return value is an (iterator, bool) pair. The boolean indicates whether
  131. // it was actually added as a new type.
  132. return *(inserted.first);
  133. }
  134. const RuntimeArrayType *
  135. SpirvContext::getRuntimeArrayType(const SpirvType *elemType,
  136. llvm::Optional<uint32_t> arrayStride) {
  137. RuntimeArrayType type(elemType, arrayStride);
  138. auto found = runtimeArrayTypes.find(&type);
  139. if (found != runtimeArrayTypes.end())
  140. return *found;
  141. auto inserted = runtimeArrayTypes.insert(
  142. new (this) RuntimeArrayType(elemType, arrayStride));
  143. return *(inserted.first);
  144. }
  145. const StructType *
  146. SpirvContext::getStructType(llvm::ArrayRef<StructType::FieldInfo> fields,
  147. llvm::StringRef name, bool isReadOnly,
  148. StructInterfaceType interfaceType) {
  149. // We are creating a temporary struct type here for querying whether the
  150. // same type was already created. It is a little bit costly, but we can
  151. // avoid allocating directly from the bump pointer allocator, from which
  152. // then we are unable to reclaim until the allocator itself is destroyed.
  153. StructType type(fields, name, isReadOnly, interfaceType);
  154. auto found = std::find_if(
  155. structTypes.begin(), structTypes.end(),
  156. [&type](const StructType *cachedType) { return type == *cachedType; });
  157. if (found != structTypes.end())
  158. return *found;
  159. structTypes.push_back(
  160. new (this) StructType(fields, name, isReadOnly, interfaceType));
  161. return structTypes.back();
  162. }
  163. const HybridStructType *SpirvContext::getHybridStructType(
  164. llvm::ArrayRef<HybridStructType::FieldInfo> fields, llvm::StringRef name,
  165. bool isReadOnly, StructInterfaceType interfaceType) {
  166. return new (this) HybridStructType(fields, name, isReadOnly, interfaceType);
  167. }
  168. const SpirvPointerType *SpirvContext::getPointerType(const SpirvType *pointee,
  169. spv::StorageClass sc) {
  170. auto foundPointee = pointerTypes.find(pointee);
  171. if (foundPointee != pointerTypes.end()) {
  172. auto &pointeeMap = foundPointee->second;
  173. auto foundSC = pointeeMap.find(sc);
  174. if (foundSC != pointeeMap.end())
  175. return foundSC->second;
  176. }
  177. return pointerTypes[pointee][sc] = new (this) SpirvPointerType(pointee, sc);
  178. }
  179. const HybridPointerType *SpirvContext::getPointerType(QualType pointee,
  180. spv::StorageClass sc) {
  181. return new (this) HybridPointerType(pointee, sc);
  182. }
  183. FunctionType *
  184. SpirvContext::getFunctionType(const SpirvType *ret,
  185. llvm::ArrayRef<const SpirvType *> param) {
  186. // Create a temporary object for finding in the set.
  187. FunctionType type(ret, param);
  188. auto found = functionTypes.find(&type);
  189. if (found != functionTypes.end())
  190. return *found;
  191. auto inserted = functionTypes.insert(new (this) FunctionType(ret, param));
  192. return *inserted.first;
  193. }
  194. const StructType *SpirvContext::getByteAddressBufferType(bool isWritable) {
  195. // Create a uint RuntimeArray.
  196. const auto *raType =
  197. getRuntimeArrayType(getUIntType(32), /* ArrayStride */ 4);
  198. // Create a struct containing the runtime array as its only member.
  199. return getStructType(
  200. {StructType::FieldInfo(raType, /*name*/ "", /*offset*/ 0)},
  201. isWritable ? "type.RWByteAddressBuffer" : "type.ByteAddressBuffer",
  202. !isWritable, StructInterfaceType::StorageBuffer);
  203. }
  204. const StructType *SpirvContext::getACSBufferCounterType() {
  205. // Create int32.
  206. const auto *int32Type = getSIntType(32);
  207. // Create a struct containing the integer counter as its only member.
  208. const StructType *type =
  209. getStructType({StructType::FieldInfo(int32Type, "counter", /*offset*/ 0)},
  210. "type.ACSBuffer.counter",
  211. /*isReadOnly*/ false, StructInterfaceType::StorageBuffer);
  212. return type;
  213. }
  214. } // end namespace spirv
  215. } // end namespace clang