CapabilityVisitor.cpp 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476
  1. //===--- CapabilityVisitor.cpp - Capability Visitor --------------*- C++ -*-==//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. #include "CapabilityVisitor.h"
  10. #include "clang/SPIRV/SpirvBuilder.h"
  11. namespace clang {
  12. namespace spirv {
  13. void CapabilityVisitor::addCapabilityForType(const SpirvType *type,
  14. SourceLocation loc,
  15. spv::StorageClass sc) {
  16. // Defent against instructions that do not have a return type.
  17. if (!type)
  18. return;
  19. // Integer-related capabilities
  20. if (const auto *intType = dyn_cast<IntegerType>(type)) {
  21. switch (intType->getBitwidth()) {
  22. case 16: {
  23. // Usage of a 16-bit integer type.
  24. spvBuilder.requireCapability(spv::Capability::Int16);
  25. // Usage of a 16-bit integer type as stage I/O.
  26. if (sc == spv::StorageClass::Input || sc == spv::StorageClass::Output) {
  27. spvBuilder.addExtension(Extension::KHR_16bit_storage,
  28. "16-bit stage IO variables", loc);
  29. spvBuilder.requireCapability(spv::Capability::StorageInputOutput16);
  30. }
  31. break;
  32. }
  33. case 64: {
  34. spvBuilder.requireCapability(spv::Capability::Int64);
  35. break;
  36. }
  37. default:
  38. break;
  39. }
  40. }
  41. // Float-related capabilities
  42. else if (const auto *floatType = dyn_cast<FloatType>(type)) {
  43. switch (floatType->getBitwidth()) {
  44. case 16: {
  45. // Usage of a 16-bit float type.
  46. // It looks like the validator does not approve of Float16
  47. // capability even though we do use the necessary extension.
  48. // TODO: Re-enable adding Float16 capability below.
  49. // spvBuilder.requireCapability(spv::Capability::Float16);
  50. spvBuilder.addExtension(Extension::AMD_gpu_shader_half_float,
  51. "16-bit float", loc);
  52. // Usage of a 16-bit float type as stage I/O.
  53. if (sc == spv::StorageClass::Input || sc == spv::StorageClass::Output) {
  54. spvBuilder.addExtension(Extension::KHR_16bit_storage,
  55. "16-bit stage IO variables", loc);
  56. spvBuilder.requireCapability(spv::Capability::StorageInputOutput16);
  57. }
  58. break;
  59. }
  60. case 64: {
  61. spvBuilder.requireCapability(spv::Capability::Float64);
  62. break;
  63. }
  64. default:
  65. break;
  66. }
  67. }
  68. // Vectors
  69. else if (const auto *vecType = dyn_cast<VectorType>(type)) {
  70. addCapabilityForType(vecType->getElementType(), loc, sc);
  71. }
  72. // Matrices
  73. else if (const auto *matType = dyn_cast<MatrixType>(type)) {
  74. addCapabilityForType(matType->getElementType(), loc, sc);
  75. }
  76. // Arrays
  77. else if (const auto *arrType = dyn_cast<ArrayType>(type)) {
  78. addCapabilityForType(arrType->getElementType(), loc, sc);
  79. }
  80. // Runtime array of resources requires additional capability.
  81. else if (const auto *raType = dyn_cast<RuntimeArrayType>(type)) {
  82. if (SpirvType::isResourceType(raType->getElementType())) {
  83. // the elements inside the runtime array are resources
  84. spvBuilder.addExtension(Extension::EXT_descriptor_indexing,
  85. "runtime array of resources", loc);
  86. spvBuilder.requireCapability(spv::Capability::RuntimeDescriptorArrayEXT);
  87. }
  88. addCapabilityForType(raType->getElementType(), loc, sc);
  89. }
  90. // Image types
  91. else if (const auto *imageType = dyn_cast<ImageType>(type)) {
  92. switch (imageType->getDimension()) {
  93. case spv::Dim::Buffer: {
  94. spvBuilder.requireCapability(spv::Capability::SampledBuffer);
  95. if (imageType->withSampler() == ImageType::WithSampler::No) {
  96. spvBuilder.requireCapability(spv::Capability::ImageBuffer);
  97. }
  98. break;
  99. }
  100. case spv::Dim::Dim1D: {
  101. if (imageType->withSampler() == ImageType::WithSampler::No) {
  102. spvBuilder.requireCapability(spv::Capability::Image1D);
  103. } else {
  104. spvBuilder.requireCapability(spv::Capability::Sampled1D);
  105. }
  106. break;
  107. }
  108. case spv::Dim::SubpassData: {
  109. spvBuilder.requireCapability(spv::Capability::InputAttachment);
  110. break;
  111. }
  112. default:
  113. break;
  114. }
  115. switch (imageType->getImageFormat()) {
  116. case spv::ImageFormat::Rg32f:
  117. case spv::ImageFormat::Rg16f:
  118. case spv::ImageFormat::R11fG11fB10f:
  119. case spv::ImageFormat::R16f:
  120. case spv::ImageFormat::Rgba16:
  121. case spv::ImageFormat::Rgb10A2:
  122. case spv::ImageFormat::Rg16:
  123. case spv::ImageFormat::Rg8:
  124. case spv::ImageFormat::R16:
  125. case spv::ImageFormat::R8:
  126. case spv::ImageFormat::Rgba16Snorm:
  127. case spv::ImageFormat::Rg16Snorm:
  128. case spv::ImageFormat::Rg8Snorm:
  129. case spv::ImageFormat::R16Snorm:
  130. case spv::ImageFormat::R8Snorm:
  131. case spv::ImageFormat::Rg32i:
  132. case spv::ImageFormat::Rg16i:
  133. case spv::ImageFormat::Rg8i:
  134. case spv::ImageFormat::R16i:
  135. case spv::ImageFormat::R8i:
  136. case spv::ImageFormat::Rgb10a2ui:
  137. case spv::ImageFormat::Rg32ui:
  138. case spv::ImageFormat::Rg16ui:
  139. case spv::ImageFormat::Rg8ui:
  140. case spv::ImageFormat::R16ui:
  141. case spv::ImageFormat::R8ui:
  142. spvBuilder.requireCapability(
  143. spv::Capability::StorageImageExtendedFormats);
  144. break;
  145. default:
  146. // Only image formats requiring extended formats are relevant. The rest
  147. // just pass through.
  148. break;
  149. }
  150. if (imageType->isArrayedImage() && imageType->isMSImage())
  151. spvBuilder.requireCapability(spv::Capability::ImageMSArray);
  152. addCapabilityForType(imageType->getSampledType(), loc, sc);
  153. }
  154. // Sampled image type
  155. else if (const auto *sampledImageType = dyn_cast<SampledImageType>(type)) {
  156. addCapabilityForType(sampledImageType->getImageType(), loc, sc);
  157. }
  158. // Pointer type
  159. else if (const auto *ptrType = dyn_cast<SpirvPointerType>(type)) {
  160. addCapabilityForType(ptrType->getPointeeType(), loc, sc);
  161. }
  162. // Struct type
  163. else if (const auto *structType = dyn_cast<StructType>(type)) {
  164. if (SpirvType::isOrContains16BitType(structType)) {
  165. spvBuilder.addExtension(Extension::KHR_16bit_storage,
  166. "16-bit types in resource", loc);
  167. if (sc == spv::StorageClass::PushConstant) {
  168. spvBuilder.requireCapability(spv::Capability::StoragePushConstant16);
  169. } else if (structType->getInterfaceType() ==
  170. StructInterfaceType::UniformBuffer) {
  171. spvBuilder.requireCapability(spv::Capability::StorageUniform16);
  172. } else if (structType->getInterfaceType() ==
  173. StructInterfaceType::StorageBuffer) {
  174. spvBuilder.requireCapability(
  175. spv::Capability::StorageUniformBufferBlock16);
  176. }
  177. }
  178. for (auto field : structType->getFields())
  179. addCapabilityForType(field.type, loc, sc);
  180. }
  181. }
  182. bool CapabilityVisitor::visit(SpirvDecoration *decor) {
  183. const auto loc = decor->getSourceLocation();
  184. switch (decor->getDecoration()) {
  185. case spv::Decoration::Sample: {
  186. spvBuilder.requireCapability(spv::Capability::SampleRateShading, loc);
  187. break;
  188. }
  189. case spv::Decoration::NonUniformEXT: {
  190. spvBuilder.addExtension(Extension::EXT_descriptor_indexing, "NonUniformEXT",
  191. loc);
  192. spvBuilder.requireCapability(spv::Capability::ShaderNonUniformEXT);
  193. break;
  194. }
  195. // Capabilities needed for built-ins
  196. case spv::Decoration::BuiltIn: {
  197. assert(decor->getParams().size() == 1);
  198. const auto builtin = static_cast<spv::BuiltIn>(decor->getParams()[0]);
  199. switch (builtin) {
  200. case spv::BuiltIn::SampleId:
  201. case spv::BuiltIn::SamplePosition: {
  202. spvBuilder.requireCapability(spv::Capability::SampleRateShading, loc);
  203. break;
  204. }
  205. case spv::BuiltIn::SubgroupSize:
  206. case spv::BuiltIn::NumSubgroups:
  207. case spv::BuiltIn::SubgroupId:
  208. case spv::BuiltIn::SubgroupLocalInvocationId: {
  209. spvBuilder.requireCapability(spv::Capability::GroupNonUniform, loc);
  210. break;
  211. }
  212. case spv::BuiltIn::BaseVertex: {
  213. spvBuilder.addExtension(Extension::KHR_shader_draw_parameters,
  214. "BaseVertex Builtin", loc);
  215. spvBuilder.requireCapability(spv::Capability::DrawParameters);
  216. break;
  217. }
  218. case spv::BuiltIn::BaseInstance: {
  219. spvBuilder.addExtension(Extension::KHR_shader_draw_parameters,
  220. "BaseInstance Builtin", loc);
  221. spvBuilder.requireCapability(spv::Capability::DrawParameters);
  222. break;
  223. }
  224. case spv::BuiltIn::DrawIndex: {
  225. spvBuilder.addExtension(Extension::KHR_shader_draw_parameters,
  226. "DrawIndex Builtin", loc);
  227. spvBuilder.requireCapability(spv::Capability::DrawParameters);
  228. break;
  229. }
  230. case spv::BuiltIn::DeviceIndex: {
  231. spvBuilder.addExtension(Extension::KHR_device_group,
  232. "DeviceIndex Builtin", loc);
  233. spvBuilder.requireCapability(spv::Capability::DeviceGroup);
  234. break;
  235. }
  236. case spv::BuiltIn::FragStencilRefEXT: {
  237. spvBuilder.addExtension(Extension::EXT_shader_stencil_export,
  238. "SV_StencilRef", loc);
  239. spvBuilder.requireCapability(spv::Capability::StencilExportEXT);
  240. break;
  241. }
  242. case spv::BuiltIn::ViewIndex: {
  243. spvBuilder.addExtension(Extension::KHR_multiview, "SV_ViewID", loc);
  244. spvBuilder.requireCapability(spv::Capability::MultiView);
  245. break;
  246. }
  247. case spv::BuiltIn::FullyCoveredEXT: {
  248. spvBuilder.addExtension(Extension::EXT_fragment_fully_covered,
  249. "SV_InnerCoverage", loc);
  250. spvBuilder.requireCapability(spv::Capability::FragmentFullyCoveredEXT);
  251. break;
  252. }
  253. case spv::BuiltIn::PrimitiveId: {
  254. // PrimitiveID can be used as PSIn
  255. if (shaderModel == spv::ExecutionModel::Fragment)
  256. spvBuilder.requireCapability(spv::Capability::Geometry);
  257. break;
  258. }
  259. case spv::BuiltIn::Layer: {
  260. if (shaderModel == spv::ExecutionModel::Vertex ||
  261. shaderModel == spv::ExecutionModel::TessellationControl ||
  262. shaderModel == spv::ExecutionModel::TessellationEvaluation) {
  263. spvBuilder.addExtension(Extension::EXT_shader_viewport_index_layer,
  264. "SV_RenderTargetArrayIndex", loc);
  265. spvBuilder.requireCapability(
  266. spv::Capability::ShaderViewportIndexLayerEXT);
  267. } else if (shaderModel == spv::ExecutionModel::Fragment) {
  268. // SV_RenderTargetArrayIndex can be used as PSIn.
  269. spvBuilder.requireCapability(spv::Capability::Geometry);
  270. }
  271. break;
  272. }
  273. case spv::BuiltIn::ViewportIndex: {
  274. if (shaderModel == spv::ExecutionModel::Vertex ||
  275. shaderModel == spv::ExecutionModel::TessellationControl ||
  276. shaderModel == spv::ExecutionModel::TessellationEvaluation) {
  277. spvBuilder.addExtension(Extension::EXT_shader_viewport_index_layer,
  278. "SV_ViewPortArrayIndex", loc);
  279. spvBuilder.requireCapability(
  280. spv::Capability::ShaderViewportIndexLayerEXT);
  281. } else if (shaderModel == spv::ExecutionModel::Fragment ||
  282. shaderModel == spv::ExecutionModel::Geometry) {
  283. // SV_ViewportArrayIndex can be used as PSIn.
  284. spvBuilder.requireCapability(spv::Capability::MultiViewport);
  285. }
  286. break;
  287. }
  288. case spv::BuiltIn::ClipDistance: {
  289. spvBuilder.requireCapability(spv::Capability::ClipDistance);
  290. break;
  291. }
  292. case spv::BuiltIn::CullDistance: {
  293. spvBuilder.requireCapability(spv::Capability::CullDistance);
  294. break;
  295. }
  296. default:
  297. break;
  298. }
  299. break;
  300. }
  301. default:
  302. break;
  303. }
  304. return true;
  305. }
  306. spv::Capability
  307. CapabilityVisitor::getNonUniformCapability(const SpirvType *type) {
  308. if (!type)
  309. return spv::Capability::Max;
  310. if (const auto *arrayType = dyn_cast<ArrayType>(type)) {
  311. return getNonUniformCapability(arrayType->getElementType());
  312. }
  313. if (SpirvType::isTexture(type) || SpirvType::isSampler(type)) {
  314. return spv::Capability::SampledImageArrayNonUniformIndexingEXT;
  315. }
  316. if (SpirvType::isRWTexture(type)) {
  317. return spv::Capability::StorageImageArrayNonUniformIndexingEXT;
  318. }
  319. if (SpirvType::isBuffer(type)) {
  320. return spv::Capability::UniformTexelBufferArrayNonUniformIndexingEXT;
  321. }
  322. if (SpirvType::isRWBuffer(type)) {
  323. return spv::Capability::StorageTexelBufferArrayNonUniformIndexingEXT;
  324. }
  325. if (SpirvType::isSubpassInput(type) || SpirvType::isSubpassInputMS(type)) {
  326. return spv::Capability::InputAttachmentArrayNonUniformIndexingEXT;
  327. }
  328. return spv::Capability::Max;
  329. }
  330. bool CapabilityVisitor::visit(SpirvImageQuery *instr) {
  331. addCapabilityForType(instr->getResultType(), instr->getSourceLocation(),
  332. instr->getStorageClass());
  333. spvBuilder.requireCapability(spv::Capability::ImageQuery);
  334. return true;
  335. }
  336. bool CapabilityVisitor::visit(SpirvImageSparseTexelsResident *instr) {
  337. addCapabilityForType(instr->getResultType(), instr->getSourceLocation(),
  338. instr->getStorageClass());
  339. spvBuilder.requireCapability(spv::Capability::ImageGatherExtended);
  340. return true;
  341. }
  342. bool CapabilityVisitor::visit(SpirvImageOp *instr) {
  343. addCapabilityForType(instr->getResultType(), instr->getSourceLocation(),
  344. instr->getStorageClass());
  345. if (instr->hasOffset() || instr->hasConstOffsets())
  346. spvBuilder.requireCapability(spv::Capability::ImageGatherExtended);
  347. if (instr->hasMinLod())
  348. spvBuilder.requireCapability(spv::Capability::MinLod);
  349. if (instr->isSparse())
  350. spvBuilder.requireCapability(spv::Capability::SparseResidency);
  351. return true;
  352. }
  353. bool CapabilityVisitor::visitInstruction(SpirvInstruction *instr) {
  354. const SpirvType *resultType = instr->getResultType();
  355. const auto opcode = instr->getopcode();
  356. // Add result-type-specific capabilities
  357. addCapabilityForType(resultType, instr->getSourceLocation(),
  358. instr->getStorageClass());
  359. // Add NonUniform capabilities if necessary
  360. if (instr->isNonUniform()) {
  361. spvBuilder.requireCapability(getNonUniformCapability(resultType));
  362. }
  363. // Add opcode-specific capabilities
  364. switch (opcode) {
  365. case spv::Op::OpDPdxCoarse:
  366. case spv::Op::OpDPdyCoarse:
  367. case spv::Op::OpFwidthCoarse:
  368. case spv::Op::OpDPdxFine:
  369. case spv::Op::OpDPdyFine:
  370. case spv::Op::OpFwidthFine:
  371. spvBuilder.requireCapability(spv::Capability::DerivativeControl);
  372. break;
  373. case spv::Op::OpGroupNonUniformElect:
  374. spvBuilder.requireCapability(spv::Capability::GroupNonUniform);
  375. break;
  376. case spv::Op::OpGroupNonUniformAny:
  377. case spv::Op::OpGroupNonUniformAll:
  378. case spv::Op::OpGroupNonUniformAllEqual:
  379. spvBuilder.requireCapability(spv::Capability::GroupNonUniformVote);
  380. break;
  381. case spv::Op::OpGroupNonUniformBallot:
  382. case spv::Op::OpGroupNonUniformInverseBallot:
  383. case spv::Op::OpGroupNonUniformBallotBitExtract:
  384. case spv::Op::OpGroupNonUniformBallotBitCount:
  385. case spv::Op::OpGroupNonUniformBallotFindLSB:
  386. case spv::Op::OpGroupNonUniformBallotFindMSB:
  387. case spv::Op::OpGroupNonUniformBroadcast:
  388. case spv::Op::OpGroupNonUniformBroadcastFirst:
  389. spvBuilder.requireCapability(spv::Capability::GroupNonUniformBallot);
  390. break;
  391. case spv::Op::OpGroupNonUniformIAdd:
  392. case spv::Op::OpGroupNonUniformFAdd:
  393. case spv::Op::OpGroupNonUniformIMul:
  394. case spv::Op::OpGroupNonUniformFMul:
  395. case spv::Op::OpGroupNonUniformSMax:
  396. case spv::Op::OpGroupNonUniformUMax:
  397. case spv::Op::OpGroupNonUniformFMax:
  398. case spv::Op::OpGroupNonUniformSMin:
  399. case spv::Op::OpGroupNonUniformUMin:
  400. case spv::Op::OpGroupNonUniformFMin:
  401. case spv::Op::OpGroupNonUniformBitwiseAnd:
  402. case spv::Op::OpGroupNonUniformBitwiseOr:
  403. case spv::Op::OpGroupNonUniformBitwiseXor:
  404. case spv::Op::OpGroupNonUniformLogicalAnd:
  405. case spv::Op::OpGroupNonUniformLogicalOr:
  406. case spv::Op::OpGroupNonUniformLogicalXor:
  407. spvBuilder.requireCapability(spv::Capability::GroupNonUniformArithmetic);
  408. break;
  409. case spv::Op::OpGroupNonUniformQuadBroadcast:
  410. case spv::Op::OpGroupNonUniformQuadSwap:
  411. spvBuilder.requireCapability(spv::Capability::GroupNonUniformQuad);
  412. break;
  413. default:
  414. break;
  415. }
  416. return true;
  417. }
  418. bool CapabilityVisitor::visit(SpirvEntryPoint *entryPoint) {
  419. shaderModel = entryPoint->getExecModel();
  420. switch (shaderModel) {
  421. case spv::ExecutionModel::Fragment:
  422. case spv::ExecutionModel::Vertex:
  423. case spv::ExecutionModel::GLCompute:
  424. spvBuilder.requireCapability(spv::Capability::Shader);
  425. break;
  426. case spv::ExecutionModel::Geometry:
  427. spvBuilder.requireCapability(spv::Capability::Geometry);
  428. break;
  429. case spv::ExecutionModel::TessellationControl:
  430. case spv::ExecutionModel::TessellationEvaluation:
  431. spvBuilder.requireCapability(spv::Capability::Tessellation);
  432. break;
  433. default:
  434. llvm_unreachable("found unknown shader model");
  435. break;
  436. }
  437. return true;
  438. }
  439. bool CapabilityVisitor::visit(SpirvExecutionMode *execMode) {
  440. if (execMode->getExecutionMode() == spv::ExecutionMode::PostDepthCoverage) {
  441. spvBuilder.requireCapability(
  442. spv::Capability::SampleMaskPostDepthCoverage,
  443. execMode->getEntryPoint()->getSourceLocation());
  444. }
  445. return true;
  446. }
  447. } // end namespace spirv
  448. } // end namespace clang