CapabilityVisitor.cpp 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644
  1. //===--- CapabilityVisitor.cpp - Capability Visitor --------------*- C++ -*-==//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. #include "CapabilityVisitor.h"
  10. #include "clang/SPIRV/SpirvBuilder.h"
  11. namespace clang {
  12. namespace spirv {
  13. void CapabilityVisitor::addExtension(Extension ext, llvm::StringRef target,
  14. SourceLocation loc) {
  15. featureManager.requestExtension(ext, target, loc);
  16. // Do not emit OpExtension if the given extension is natively supported in
  17. // the target environment.
  18. if (featureManager.isExtensionRequiredForTargetEnv(ext))
  19. spvBuilder.requireExtension(featureManager.getExtensionName(ext), loc);
  20. }
  21. void CapabilityVisitor::addCapability(spv::Capability cap, SourceLocation loc) {
  22. if (cap != spv::Capability::Max) {
  23. spvBuilder.requireCapability(cap, loc);
  24. }
  25. }
  26. void CapabilityVisitor::addCapabilityForType(const SpirvType *type,
  27. SourceLocation loc,
  28. spv::StorageClass sc) {
  29. // Defent against instructions that do not have a return type.
  30. if (!type)
  31. return;
  32. // Integer-related capabilities
  33. if (const auto *intType = dyn_cast<IntegerType>(type)) {
  34. switch (intType->getBitwidth()) {
  35. case 8: {
  36. addCapability(spv::Capability::Int8);
  37. break;
  38. }
  39. case 16: {
  40. // Usage of a 16-bit integer type.
  41. addCapability(spv::Capability::Int16);
  42. // Usage of a 16-bit integer type as stage I/O.
  43. if (sc == spv::StorageClass::Input || sc == spv::StorageClass::Output) {
  44. addExtension(Extension::KHR_16bit_storage, "16-bit stage IO variables",
  45. loc);
  46. addCapability(spv::Capability::StorageInputOutput16);
  47. }
  48. break;
  49. }
  50. case 64: {
  51. addCapability(spv::Capability::Int64);
  52. break;
  53. }
  54. default:
  55. break;
  56. }
  57. }
  58. // Float-related capabilities
  59. else if (const auto *floatType = dyn_cast<FloatType>(type)) {
  60. switch (floatType->getBitwidth()) {
  61. case 16: {
  62. // Usage of a 16-bit float type.
  63. addCapability(spv::Capability::Float16);
  64. // Usage of a 16-bit float type as stage I/O.
  65. if (sc == spv::StorageClass::Input || sc == spv::StorageClass::Output) {
  66. addExtension(Extension::KHR_16bit_storage, "16-bit stage IO variables",
  67. loc);
  68. addCapability(spv::Capability::StorageInputOutput16);
  69. }
  70. break;
  71. }
  72. case 64: {
  73. addCapability(spv::Capability::Float64);
  74. break;
  75. }
  76. default:
  77. break;
  78. }
  79. }
  80. // Vectors
  81. else if (const auto *vecType = dyn_cast<VectorType>(type)) {
  82. addCapabilityForType(vecType->getElementType(), loc, sc);
  83. }
  84. // Matrices
  85. else if (const auto *matType = dyn_cast<MatrixType>(type)) {
  86. addCapabilityForType(matType->getElementType(), loc, sc);
  87. }
  88. // Arrays
  89. else if (const auto *arrType = dyn_cast<ArrayType>(type)) {
  90. addCapabilityForType(arrType->getElementType(), loc, sc);
  91. }
  92. // Runtime array of resources requires additional capability.
  93. else if (const auto *raType = dyn_cast<RuntimeArrayType>(type)) {
  94. if (SpirvType::isResourceType(raType->getElementType())) {
  95. // the elements inside the runtime array are resources
  96. addExtension(Extension::EXT_descriptor_indexing,
  97. "runtime array of resources", loc);
  98. addCapability(spv::Capability::RuntimeDescriptorArrayEXT);
  99. }
  100. addCapabilityForType(raType->getElementType(), loc, sc);
  101. }
  102. // Image types
  103. else if (const auto *imageType = dyn_cast<ImageType>(type)) {
  104. switch (imageType->getDimension()) {
  105. case spv::Dim::Buffer: {
  106. addCapability(spv::Capability::SampledBuffer);
  107. if (imageType->withSampler() == ImageType::WithSampler::No) {
  108. addCapability(spv::Capability::ImageBuffer);
  109. }
  110. break;
  111. }
  112. case spv::Dim::Dim1D: {
  113. if (imageType->withSampler() == ImageType::WithSampler::No) {
  114. addCapability(spv::Capability::Image1D);
  115. } else {
  116. addCapability(spv::Capability::Sampled1D);
  117. }
  118. break;
  119. }
  120. case spv::Dim::SubpassData: {
  121. addCapability(spv::Capability::InputAttachment);
  122. break;
  123. }
  124. default:
  125. break;
  126. }
  127. switch (imageType->getImageFormat()) {
  128. case spv::ImageFormat::Rg32f:
  129. case spv::ImageFormat::Rg16f:
  130. case spv::ImageFormat::R11fG11fB10f:
  131. case spv::ImageFormat::R16f:
  132. case spv::ImageFormat::Rgba16:
  133. case spv::ImageFormat::Rgb10A2:
  134. case spv::ImageFormat::Rg16:
  135. case spv::ImageFormat::Rg8:
  136. case spv::ImageFormat::R16:
  137. case spv::ImageFormat::R8:
  138. case spv::ImageFormat::Rgba16Snorm:
  139. case spv::ImageFormat::Rg16Snorm:
  140. case spv::ImageFormat::Rg8Snorm:
  141. case spv::ImageFormat::R16Snorm:
  142. case spv::ImageFormat::R8Snorm:
  143. case spv::ImageFormat::Rg32i:
  144. case spv::ImageFormat::Rg16i:
  145. case spv::ImageFormat::Rg8i:
  146. case spv::ImageFormat::R16i:
  147. case spv::ImageFormat::R8i:
  148. case spv::ImageFormat::Rgb10a2ui:
  149. case spv::ImageFormat::Rg32ui:
  150. case spv::ImageFormat::Rg16ui:
  151. case spv::ImageFormat::Rg8ui:
  152. case spv::ImageFormat::R16ui:
  153. case spv::ImageFormat::R8ui:
  154. addCapability(spv::Capability::StorageImageExtendedFormats);
  155. break;
  156. default:
  157. // Only image formats requiring extended formats are relevant. The rest
  158. // just pass through.
  159. break;
  160. }
  161. if (imageType->isArrayedImage() && imageType->isMSImage())
  162. addCapability(spv::Capability::ImageMSArray);
  163. addCapabilityForType(imageType->getSampledType(), loc, sc);
  164. }
  165. // Sampled image type
  166. else if (const auto *sampledImageType = dyn_cast<SampledImageType>(type)) {
  167. addCapabilityForType(sampledImageType->getImageType(), loc, sc);
  168. }
  169. // Pointer type
  170. else if (const auto *ptrType = dyn_cast<SpirvPointerType>(type)) {
  171. addCapabilityForType(ptrType->getPointeeType(), loc, sc);
  172. }
  173. // Struct type
  174. else if (const auto *structType = dyn_cast<StructType>(type)) {
  175. if (SpirvType::isOrContainsType<NumericalType, 16>(structType)) {
  176. addExtension(Extension::KHR_16bit_storage, "16-bit types in resource",
  177. loc);
  178. if (sc == spv::StorageClass::PushConstant) {
  179. addCapability(spv::Capability::StoragePushConstant16);
  180. } else if (structType->getInterfaceType() ==
  181. StructInterfaceType::UniformBuffer) {
  182. addCapability(spv::Capability::StorageUniform16);
  183. } else if (structType->getInterfaceType() ==
  184. StructInterfaceType::StorageBuffer) {
  185. addCapability(spv::Capability::StorageUniformBufferBlock16);
  186. }
  187. }
  188. for (auto field : structType->getFields())
  189. addCapabilityForType(field.type, loc, sc);
  190. }
  191. // AccelerationStructureTypeNV type
  192. else if (isa<AccelerationStructureTypeNV>(type)) {
  193. if (featureManager.isExtensionEnabled(Extension::NV_ray_tracing)) {
  194. addCapability(spv::Capability::RayTracingNV);
  195. addExtension(Extension::NV_ray_tracing, "SPV_NV_ray_tracing", {});
  196. } else {
  197. // KHR_ray_tracing extension requires Vulkan 1.1 with VK_KHR_spirv_1_4
  198. // extention or Vulkan 1.2.
  199. featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1_SPIRV_1_4,
  200. "Raytracing", {});
  201. addCapability(spv::Capability::RayTracingKHR);
  202. addExtension(Extension::KHR_ray_tracing, "SPV_KHR_ray_tracing", {});
  203. }
  204. }
  205. // RayQueryTypeKHR type
  206. else if (isa<RayQueryTypeKHR>(type)) {
  207. addCapability(spv::Capability::RayQueryKHR);
  208. addExtension(Extension::KHR_ray_query, "SPV_KHR_ray_query", {});
  209. }
  210. }
  211. bool CapabilityVisitor::visit(SpirvDecoration *decor) {
  212. const auto loc = decor->getSourceLocation();
  213. switch (decor->getDecoration()) {
  214. case spv::Decoration::Sample: {
  215. addCapability(spv::Capability::SampleRateShading, loc);
  216. break;
  217. }
  218. case spv::Decoration::NonUniformEXT: {
  219. addExtension(Extension::EXT_descriptor_indexing, "NonUniformEXT", loc);
  220. addCapability(spv::Capability::ShaderNonUniformEXT);
  221. break;
  222. }
  223. case spv::Decoration::HlslSemanticGOOGLE:
  224. case spv::Decoration::HlslCounterBufferGOOGLE: {
  225. addExtension(Extension::GOOGLE_hlsl_functionality1, "SPIR-V reflection",
  226. loc);
  227. break;
  228. }
  229. // Capabilities needed for built-ins
  230. case spv::Decoration::BuiltIn: {
  231. assert(decor->getParams().size() == 1);
  232. const auto builtin = static_cast<spv::BuiltIn>(decor->getParams()[0]);
  233. switch (builtin) {
  234. case spv::BuiltIn::SampleId:
  235. case spv::BuiltIn::SamplePosition: {
  236. addCapability(spv::Capability::SampleRateShading, loc);
  237. break;
  238. }
  239. case spv::BuiltIn::SubgroupSize:
  240. case spv::BuiltIn::NumSubgroups:
  241. case spv::BuiltIn::SubgroupId:
  242. case spv::BuiltIn::SubgroupLocalInvocationId: {
  243. addCapability(spv::Capability::GroupNonUniform, loc);
  244. break;
  245. }
  246. case spv::BuiltIn::BaseVertex: {
  247. addExtension(Extension::KHR_shader_draw_parameters, "BaseVertex Builtin",
  248. loc);
  249. addCapability(spv::Capability::DrawParameters);
  250. break;
  251. }
  252. case spv::BuiltIn::BaseInstance: {
  253. addExtension(Extension::KHR_shader_draw_parameters,
  254. "BaseInstance Builtin", loc);
  255. addCapability(spv::Capability::DrawParameters);
  256. break;
  257. }
  258. case spv::BuiltIn::DrawIndex: {
  259. addExtension(Extension::KHR_shader_draw_parameters, "DrawIndex Builtin",
  260. loc);
  261. addCapability(spv::Capability::DrawParameters);
  262. break;
  263. }
  264. case spv::BuiltIn::DeviceIndex: {
  265. addExtension(Extension::KHR_device_group, "DeviceIndex Builtin", loc);
  266. addCapability(spv::Capability::DeviceGroup);
  267. break;
  268. }
  269. case spv::BuiltIn::FragStencilRefEXT: {
  270. addExtension(Extension::EXT_shader_stencil_export, "SV_StencilRef", loc);
  271. addCapability(spv::Capability::StencilExportEXT);
  272. break;
  273. }
  274. case spv::BuiltIn::ViewIndex: {
  275. addExtension(Extension::KHR_multiview, "SV_ViewID", loc);
  276. addCapability(spv::Capability::MultiView);
  277. break;
  278. }
  279. case spv::BuiltIn::FullyCoveredEXT: {
  280. addExtension(Extension::EXT_fragment_fully_covered, "SV_InnerCoverage",
  281. loc);
  282. addCapability(spv::Capability::FragmentFullyCoveredEXT);
  283. break;
  284. }
  285. case spv::BuiltIn::PrimitiveId: {
  286. // PrimitiveID can be used as PSIn or MSPOut.
  287. if (shaderModel == spv::ExecutionModel::Fragment ||
  288. shaderModel == spv::ExecutionModel::MeshNV)
  289. addCapability(spv::Capability::Geometry);
  290. break;
  291. }
  292. case spv::BuiltIn::Layer: {
  293. if (shaderModel == spv::ExecutionModel::Vertex ||
  294. shaderModel == spv::ExecutionModel::TessellationControl ||
  295. shaderModel == spv::ExecutionModel::TessellationEvaluation) {
  296. addExtension(Extension::EXT_shader_viewport_index_layer,
  297. "SV_RenderTargetArrayIndex", loc);
  298. addCapability(spv::Capability::ShaderViewportIndexLayerEXT);
  299. } else if (shaderModel == spv::ExecutionModel::Fragment ||
  300. shaderModel == spv::ExecutionModel::MeshNV) {
  301. // SV_RenderTargetArrayIndex can be used as PSIn or MSPOut.
  302. addCapability(spv::Capability::Geometry);
  303. }
  304. break;
  305. }
  306. case spv::BuiltIn::ViewportIndex: {
  307. if (shaderModel == spv::ExecutionModel::Vertex ||
  308. shaderModel == spv::ExecutionModel::TessellationControl ||
  309. shaderModel == spv::ExecutionModel::TessellationEvaluation) {
  310. addExtension(Extension::EXT_shader_viewport_index_layer,
  311. "SV_ViewPortArrayIndex", loc);
  312. addCapability(spv::Capability::ShaderViewportIndexLayerEXT);
  313. } else if (shaderModel == spv::ExecutionModel::Fragment ||
  314. shaderModel == spv::ExecutionModel::Geometry ||
  315. shaderModel == spv::ExecutionModel::MeshNV) {
  316. // SV_ViewportArrayIndex can be used as PSIn or GSOut or MSPOut.
  317. addCapability(spv::Capability::MultiViewport);
  318. }
  319. break;
  320. }
  321. case spv::BuiltIn::ClipDistance: {
  322. addCapability(spv::Capability::ClipDistance);
  323. break;
  324. }
  325. case spv::BuiltIn::CullDistance: {
  326. addCapability(spv::Capability::CullDistance);
  327. break;
  328. }
  329. case spv::BuiltIn::BaryCoordNoPerspAMD:
  330. case spv::BuiltIn::BaryCoordNoPerspCentroidAMD:
  331. case spv::BuiltIn::BaryCoordNoPerspSampleAMD:
  332. case spv::BuiltIn::BaryCoordSmoothAMD:
  333. case spv::BuiltIn::BaryCoordSmoothCentroidAMD:
  334. case spv::BuiltIn::BaryCoordSmoothSampleAMD:
  335. case spv::BuiltIn::BaryCoordPullModelAMD: {
  336. addExtension(Extension::AMD_shader_explicit_vertex_parameter,
  337. "SV_Barycentrics", loc);
  338. break;
  339. }
  340. case spv::BuiltIn::ShadingRateKHR:
  341. case spv::BuiltIn::PrimitiveShadingRateKHR: {
  342. addExtension(Extension::KHR_fragment_shading_rate, "SV_ShadingRate", loc);
  343. addCapability(spv::Capability::FragmentShadingRateKHR);
  344. break;
  345. }
  346. default:
  347. break;
  348. }
  349. break;
  350. }
  351. default:
  352. break;
  353. }
  354. return true;
  355. }
  356. spv::Capability
  357. CapabilityVisitor::getNonUniformCapability(const SpirvType *type) {
  358. if (!type)
  359. return spv::Capability::Max;
  360. if (const auto *arrayType = dyn_cast<ArrayType>(type)) {
  361. return getNonUniformCapability(arrayType->getElementType());
  362. }
  363. if (SpirvType::isTexture(type) || SpirvType::isSampler(type)) {
  364. return spv::Capability::SampledImageArrayNonUniformIndexingEXT;
  365. }
  366. if (SpirvType::isRWTexture(type)) {
  367. return spv::Capability::StorageImageArrayNonUniformIndexingEXT;
  368. }
  369. if (SpirvType::isBuffer(type)) {
  370. return spv::Capability::UniformTexelBufferArrayNonUniformIndexingEXT;
  371. }
  372. if (SpirvType::isRWBuffer(type)) {
  373. return spv::Capability::StorageTexelBufferArrayNonUniformIndexingEXT;
  374. }
  375. if (SpirvType::isSubpassInput(type) || SpirvType::isSubpassInputMS(type)) {
  376. return spv::Capability::InputAttachmentArrayNonUniformIndexingEXT;
  377. }
  378. return spv::Capability::Max;
  379. }
  380. bool CapabilityVisitor::visit(SpirvImageQuery *instr) {
  381. addCapabilityForType(instr->getResultType(), instr->getSourceLocation(),
  382. instr->getStorageClass());
  383. addCapability(spv::Capability::ImageQuery);
  384. return true;
  385. }
  386. bool CapabilityVisitor::visit(SpirvImageSparseTexelsResident *instr) {
  387. addCapabilityForType(instr->getResultType(), instr->getSourceLocation(),
  388. instr->getStorageClass());
  389. addCapability(spv::Capability::ImageGatherExtended);
  390. addCapability(spv::Capability::SparseResidency);
  391. return true;
  392. }
  393. bool CapabilityVisitor::visit(SpirvImageOp *instr) {
  394. addCapabilityForType(instr->getResultType(), instr->getSourceLocation(),
  395. instr->getStorageClass());
  396. if (instr->hasOffset() || instr->hasConstOffsets())
  397. addCapability(spv::Capability::ImageGatherExtended);
  398. if (instr->hasMinLod())
  399. addCapability(spv::Capability::MinLod);
  400. if (instr->isSparse())
  401. addCapability(spv::Capability::SparseResidency);
  402. return true;
  403. }
  404. bool CapabilityVisitor::visitInstruction(SpirvInstruction *instr) {
  405. const SpirvType *resultType = instr->getResultType();
  406. const auto opcode = instr->getopcode();
  407. const auto loc = instr->getSourceLocation();
  408. // Add result-type-specific capabilities
  409. addCapabilityForType(resultType, loc, instr->getStorageClass());
  410. // Add NonUniform capabilities if necessary
  411. if (instr->isNonUniform()) {
  412. addExtension(Extension::EXT_descriptor_indexing, "NonUniformEXT", loc);
  413. addCapability(spv::Capability::ShaderNonUniformEXT);
  414. addCapability(getNonUniformCapability(resultType));
  415. }
  416. // Add opcode-specific capabilities
  417. switch (opcode) {
  418. case spv::Op::OpDPdxCoarse:
  419. case spv::Op::OpDPdyCoarse:
  420. case spv::Op::OpFwidthCoarse:
  421. case spv::Op::OpDPdxFine:
  422. case spv::Op::OpDPdyFine:
  423. case spv::Op::OpFwidthFine:
  424. addCapability(spv::Capability::DerivativeControl);
  425. break;
  426. case spv::Op::OpGroupNonUniformElect:
  427. addCapability(spv::Capability::GroupNonUniform);
  428. break;
  429. case spv::Op::OpGroupNonUniformAny:
  430. case spv::Op::OpGroupNonUniformAll:
  431. case spv::Op::OpGroupNonUniformAllEqual:
  432. addCapability(spv::Capability::GroupNonUniformVote);
  433. break;
  434. case spv::Op::OpGroupNonUniformBallot:
  435. case spv::Op::OpGroupNonUniformInverseBallot:
  436. case spv::Op::OpGroupNonUniformBallotBitExtract:
  437. case spv::Op::OpGroupNonUniformBallotBitCount:
  438. case spv::Op::OpGroupNonUniformBallotFindLSB:
  439. case spv::Op::OpGroupNonUniformBallotFindMSB:
  440. case spv::Op::OpGroupNonUniformBroadcast:
  441. case spv::Op::OpGroupNonUniformBroadcastFirst:
  442. addCapability(spv::Capability::GroupNonUniformBallot);
  443. break;
  444. case spv::Op::OpGroupNonUniformShuffle:
  445. case spv::Op::OpGroupNonUniformShuffleXor:
  446. addCapability(spv::Capability::GroupNonUniformShuffle);
  447. break;
  448. case spv::Op::OpGroupNonUniformIAdd:
  449. case spv::Op::OpGroupNonUniformFAdd:
  450. case spv::Op::OpGroupNonUniformIMul:
  451. case spv::Op::OpGroupNonUniformFMul:
  452. case spv::Op::OpGroupNonUniformSMax:
  453. case spv::Op::OpGroupNonUniformUMax:
  454. case spv::Op::OpGroupNonUniformFMax:
  455. case spv::Op::OpGroupNonUniformSMin:
  456. case spv::Op::OpGroupNonUniformUMin:
  457. case spv::Op::OpGroupNonUniformFMin:
  458. case spv::Op::OpGroupNonUniformBitwiseAnd:
  459. case spv::Op::OpGroupNonUniformBitwiseOr:
  460. case spv::Op::OpGroupNonUniformBitwiseXor:
  461. case spv::Op::OpGroupNonUniformLogicalAnd:
  462. case spv::Op::OpGroupNonUniformLogicalOr:
  463. case spv::Op::OpGroupNonUniformLogicalXor:
  464. addCapability(spv::Capability::GroupNonUniformArithmetic);
  465. break;
  466. case spv::Op::OpGroupNonUniformQuadBroadcast:
  467. case spv::Op::OpGroupNonUniformQuadSwap:
  468. addCapability(spv::Capability::GroupNonUniformQuad);
  469. break;
  470. case spv::Op::OpVariable: {
  471. if (spvOptions.enableReflect &&
  472. !cast<SpirvVariable>(instr)->getHlslUserType().empty()) {
  473. addExtension(Extension::GOOGLE_user_type, "HLSL User Type", loc);
  474. addExtension(Extension::GOOGLE_hlsl_functionality1, "HLSL User Type",
  475. loc);
  476. }
  477. break;
  478. }
  479. case spv::Op::OpRayQueryInitializeKHR: {
  480. auto rayQueryInst = dyn_cast<SpirvRayQueryOpKHR>(instr);
  481. if (rayQueryInst->hasCullFlags()) {
  482. addCapability(
  483. spv::Capability::RayTraversalPrimitiveCullingKHR);
  484. }
  485. break;
  486. }
  487. default:
  488. break;
  489. }
  490. return true;
  491. }
  492. bool CapabilityVisitor::visit(SpirvEntryPoint *entryPoint) {
  493. shaderModel = entryPoint->getExecModel();
  494. switch (shaderModel) {
  495. case spv::ExecutionModel::Fragment:
  496. case spv::ExecutionModel::Vertex:
  497. case spv::ExecutionModel::GLCompute:
  498. addCapability(spv::Capability::Shader);
  499. break;
  500. case spv::ExecutionModel::Geometry:
  501. addCapability(spv::Capability::Geometry);
  502. break;
  503. case spv::ExecutionModel::TessellationControl:
  504. case spv::ExecutionModel::TessellationEvaluation:
  505. addCapability(spv::Capability::Tessellation);
  506. break;
  507. case spv::ExecutionModel::RayGenerationNV:
  508. case spv::ExecutionModel::IntersectionNV:
  509. case spv::ExecutionModel::ClosestHitNV:
  510. case spv::ExecutionModel::AnyHitNV:
  511. case spv::ExecutionModel::MissNV:
  512. case spv::ExecutionModel::CallableNV:
  513. if (featureManager.isExtensionEnabled(Extension::NV_ray_tracing)) {
  514. addCapability(spv::Capability::RayTracingNV);
  515. addExtension(Extension::NV_ray_tracing, "SPV_NV_ray_tracing", {});
  516. } else {
  517. // KHR_ray_tracing extension requires Vulkan 1.1 with VK_KHR_spirv_1_4
  518. // extention or Vulkan 1.2.
  519. featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1_SPIRV_1_4,
  520. "Raytracing", {});
  521. addCapability(spv::Capability::RayTracingKHR);
  522. addExtension(Extension::KHR_ray_tracing, "SPV_KHR_ray_tracing", {});
  523. }
  524. break;
  525. case spv::ExecutionModel::MeshNV:
  526. case spv::ExecutionModel::TaskNV:
  527. addCapability(spv::Capability::MeshShadingNV);
  528. addExtension(Extension::NV_mesh_shader, "SPV_NV_mesh_shader", {});
  529. break;
  530. default:
  531. llvm_unreachable("found unknown shader model");
  532. break;
  533. }
  534. return true;
  535. }
  536. bool CapabilityVisitor::visit(SpirvExecutionMode *execMode) {
  537. if (execMode->getExecutionMode() == spv::ExecutionMode::PostDepthCoverage) {
  538. addCapability(spv::Capability::SampleMaskPostDepthCoverage,
  539. execMode->getEntryPoint()->getSourceLocation());
  540. addExtension(Extension::KHR_post_depth_coverage,
  541. "[[vk::post_depth_coverage]]", execMode->getSourceLocation());
  542. }
  543. return true;
  544. }
  545. bool CapabilityVisitor::visit(SpirvExtInstImport *instr) {
  546. if (instr->getExtendedInstSetName() == "NonSemantic.DebugPrintf")
  547. addExtension(Extension::KHR_non_semantic_info, "DebugPrintf",
  548. /*SourceLocation*/ {});
  549. return true;
  550. }
  551. bool CapabilityVisitor::visit(SpirvExtInst *instr) {
  552. // OpExtInst using the GLSL extended instruction allows only 32-bit types by
  553. // default for interpolation instructions. The AMD_gpu_shader_half_float
  554. // extension adds support for 16-bit floating-point component types for these
  555. // instructions:
  556. // InterpolateAtCentroid, InterpolateAtSample, InterpolateAtOffset
  557. if (SpirvType::isOrContainsType<FloatType, 16>(instr->getResultType()))
  558. switch (instr->getInstruction()) {
  559. case GLSLstd450::GLSLstd450InterpolateAtCentroid:
  560. case GLSLstd450::GLSLstd450InterpolateAtSample:
  561. case GLSLstd450::GLSLstd450InterpolateAtOffset:
  562. addExtension(Extension::AMD_gpu_shader_half_float, "16-bit float",
  563. instr->getSourceLocation());
  564. default:
  565. break;
  566. }
  567. return visitInstruction(instr);
  568. }
  569. bool CapabilityVisitor::visit(SpirvAtomic *instr) {
  570. if (instr->hasValue() && SpirvType::isOrContainsType<IntegerType, 64>(
  571. instr->getValue()->getResultType())) {
  572. addCapability(spv::Capability::Int64Atomics, instr->getSourceLocation());
  573. }
  574. return true;
  575. }
  576. bool CapabilityVisitor::visit(SpirvDemoteToHelperInvocationEXT *inst) {
  577. addCapability(spv::Capability::DemoteToHelperInvocationEXT,
  578. inst->getSourceLocation());
  579. addExtension(Extension::EXT_demote_to_helper_invocation, "discard",
  580. inst->getSourceLocation());
  581. return true;
  582. }
  583. bool CapabilityVisitor::visit(SpirvReadClock *inst) {
  584. auto loc = inst->getSourceLocation();
  585. addCapabilityForType(inst->getResultType(), loc, inst->getStorageClass());
  586. addCapability(spv::Capability::ShaderClockKHR, loc);
  587. addExtension(Extension::KHR_shader_clock, "ReadClock", loc);
  588. return true;
  589. }
  590. bool CapabilityVisitor::visit(SpirvModule *, Visitor::Phase phase) {
  591. // If there are no entry-points in the module (hence shaderModel is not set),
  592. // add the Linkage capability. This allows library shader models to use
  593. // 'export' attribute on functions, and generate an "incomplete/partial"
  594. // SPIR-V binary.
  595. // ExecutionModel::Max means that no entrypoints exist, therefore we should
  596. // add the Linkage Capability.
  597. if (phase == Visitor::Phase::Done &&
  598. shaderModel == spv::ExecutionModel::Max) {
  599. addCapability(spv::Capability::Shader);
  600. addCapability(spv::Capability::Linkage);
  601. }
  602. return true;
  603. }
  604. } // end namespace spirv
  605. } // end namespace clang