CapabilityVisitor.cpp 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590
  1. //===--- CapabilityVisitor.cpp - Capability Visitor --------------*- C++ -*-==//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. #include "CapabilityVisitor.h"
  10. #include "clang/SPIRV/SpirvBuilder.h"
  11. namespace clang {
  12. namespace spirv {
  13. void CapabilityVisitor::addExtension(Extension ext, llvm::StringRef target,
  14. SourceLocation loc) {
  15. featureManager.requestExtension(ext, target, loc);
  16. // Do not emit OpExtension if the given extension is natively supported in
  17. // the target environment.
  18. if (featureManager.isExtensionRequiredForTargetEnv(ext))
  19. spvBuilder.requireExtension(featureManager.getExtensionName(ext), loc);
  20. }
  21. void CapabilityVisitor::addCapability(spv::Capability cap, SourceLocation loc) {
  22. if (cap != spv::Capability::Max) {
  23. spvBuilder.requireCapability(cap, loc);
  24. }
  25. }
  26. void CapabilityVisitor::addCapabilityForType(const SpirvType *type,
  27. SourceLocation loc,
  28. spv::StorageClass sc) {
  29. // Defent against instructions that do not have a return type.
  30. if (!type)
  31. return;
  32. // Integer-related capabilities
  33. if (const auto *intType = dyn_cast<IntegerType>(type)) {
  34. switch (intType->getBitwidth()) {
  35. case 16: {
  36. // Usage of a 16-bit integer type.
  37. addCapability(spv::Capability::Int16);
  38. // Usage of a 16-bit integer type as stage I/O.
  39. if (sc == spv::StorageClass::Input || sc == spv::StorageClass::Output) {
  40. addExtension(Extension::KHR_16bit_storage, "16-bit stage IO variables",
  41. loc);
  42. addCapability(spv::Capability::StorageInputOutput16);
  43. }
  44. break;
  45. }
  46. case 64: {
  47. addCapability(spv::Capability::Int64);
  48. break;
  49. }
  50. default:
  51. break;
  52. }
  53. }
  54. // Float-related capabilities
  55. else if (const auto *floatType = dyn_cast<FloatType>(type)) {
  56. switch (floatType->getBitwidth()) {
  57. case 16: {
  58. // Usage of a 16-bit float type.
  59. addCapability(spv::Capability::Float16);
  60. // Usage of a 16-bit float type as stage I/O.
  61. if (sc == spv::StorageClass::Input || sc == spv::StorageClass::Output) {
  62. addExtension(Extension::KHR_16bit_storage, "16-bit stage IO variables",
  63. loc);
  64. addCapability(spv::Capability::StorageInputOutput16);
  65. }
  66. break;
  67. }
  68. case 64: {
  69. addCapability(spv::Capability::Float64);
  70. break;
  71. }
  72. default:
  73. break;
  74. }
  75. }
  76. // Vectors
  77. else if (const auto *vecType = dyn_cast<VectorType>(type)) {
  78. addCapabilityForType(vecType->getElementType(), loc, sc);
  79. }
  80. // Matrices
  81. else if (const auto *matType = dyn_cast<MatrixType>(type)) {
  82. addCapabilityForType(matType->getElementType(), loc, sc);
  83. }
  84. // Arrays
  85. else if (const auto *arrType = dyn_cast<ArrayType>(type)) {
  86. addCapabilityForType(arrType->getElementType(), loc, sc);
  87. }
  88. // Runtime array of resources requires additional capability.
  89. else if (const auto *raType = dyn_cast<RuntimeArrayType>(type)) {
  90. if (SpirvType::isResourceType(raType->getElementType())) {
  91. // the elements inside the runtime array are resources
  92. addExtension(Extension::EXT_descriptor_indexing,
  93. "runtime array of resources", loc);
  94. addCapability(spv::Capability::RuntimeDescriptorArrayEXT);
  95. }
  96. addCapabilityForType(raType->getElementType(), loc, sc);
  97. }
  98. // Image types
  99. else if (const auto *imageType = dyn_cast<ImageType>(type)) {
  100. switch (imageType->getDimension()) {
  101. case spv::Dim::Buffer: {
  102. addCapability(spv::Capability::SampledBuffer);
  103. if (imageType->withSampler() == ImageType::WithSampler::No) {
  104. addCapability(spv::Capability::ImageBuffer);
  105. }
  106. break;
  107. }
  108. case spv::Dim::Dim1D: {
  109. if (imageType->withSampler() == ImageType::WithSampler::No) {
  110. addCapability(spv::Capability::Image1D);
  111. } else {
  112. addCapability(spv::Capability::Sampled1D);
  113. }
  114. break;
  115. }
  116. case spv::Dim::SubpassData: {
  117. addCapability(spv::Capability::InputAttachment);
  118. break;
  119. }
  120. default:
  121. break;
  122. }
  123. switch (imageType->getImageFormat()) {
  124. case spv::ImageFormat::Rg32f:
  125. case spv::ImageFormat::Rg16f:
  126. case spv::ImageFormat::R11fG11fB10f:
  127. case spv::ImageFormat::R16f:
  128. case spv::ImageFormat::Rgba16:
  129. case spv::ImageFormat::Rgb10A2:
  130. case spv::ImageFormat::Rg16:
  131. case spv::ImageFormat::Rg8:
  132. case spv::ImageFormat::R16:
  133. case spv::ImageFormat::R8:
  134. case spv::ImageFormat::Rgba16Snorm:
  135. case spv::ImageFormat::Rg16Snorm:
  136. case spv::ImageFormat::Rg8Snorm:
  137. case spv::ImageFormat::R16Snorm:
  138. case spv::ImageFormat::R8Snorm:
  139. case spv::ImageFormat::Rg32i:
  140. case spv::ImageFormat::Rg16i:
  141. case spv::ImageFormat::Rg8i:
  142. case spv::ImageFormat::R16i:
  143. case spv::ImageFormat::R8i:
  144. case spv::ImageFormat::Rgb10a2ui:
  145. case spv::ImageFormat::Rg32ui:
  146. case spv::ImageFormat::Rg16ui:
  147. case spv::ImageFormat::Rg8ui:
  148. case spv::ImageFormat::R16ui:
  149. case spv::ImageFormat::R8ui:
  150. addCapability(spv::Capability::StorageImageExtendedFormats);
  151. break;
  152. default:
  153. // Only image formats requiring extended formats are relevant. The rest
  154. // just pass through.
  155. break;
  156. }
  157. if (imageType->isArrayedImage() && imageType->isMSImage())
  158. addCapability(spv::Capability::ImageMSArray);
  159. addCapabilityForType(imageType->getSampledType(), loc, sc);
  160. }
  161. // Sampled image type
  162. else if (const auto *sampledImageType = dyn_cast<SampledImageType>(type)) {
  163. addCapabilityForType(sampledImageType->getImageType(), loc, sc);
  164. }
  165. // Pointer type
  166. else if (const auto *ptrType = dyn_cast<SpirvPointerType>(type)) {
  167. addCapabilityForType(ptrType->getPointeeType(), loc, sc);
  168. }
  169. // Struct type
  170. else if (const auto *structType = dyn_cast<StructType>(type)) {
  171. if (SpirvType::isOrContainsType<NumericalType, 16>(structType)) {
  172. addExtension(Extension::KHR_16bit_storage, "16-bit types in resource",
  173. loc);
  174. if (sc == spv::StorageClass::PushConstant) {
  175. addCapability(spv::Capability::StoragePushConstant16);
  176. } else if (structType->getInterfaceType() ==
  177. StructInterfaceType::UniformBuffer) {
  178. addCapability(spv::Capability::StorageUniform16);
  179. } else if (structType->getInterfaceType() ==
  180. StructInterfaceType::StorageBuffer) {
  181. addCapability(spv::Capability::StorageUniformBufferBlock16);
  182. }
  183. }
  184. for (auto field : structType->getFields())
  185. addCapabilityForType(field.type, loc, sc);
  186. }
  187. //
  188. else if (const auto *rayQueryType =
  189. dyn_cast<RayQueryProvisionalTypeKHR>(type)) {
  190. addCapability(spv::Capability::RayQueryProvisionalKHR);
  191. addExtension(Extension::KHR_ray_query, "SPV_KHR_ray_query", {});
  192. }
  193. }
  194. bool CapabilityVisitor::visit(SpirvDecoration *decor) {
  195. const auto loc = decor->getSourceLocation();
  196. switch (decor->getDecoration()) {
  197. case spv::Decoration::Sample: {
  198. addCapability(spv::Capability::SampleRateShading, loc);
  199. break;
  200. }
  201. case spv::Decoration::NonUniformEXT: {
  202. addExtension(Extension::EXT_descriptor_indexing, "NonUniformEXT", loc);
  203. addCapability(spv::Capability::ShaderNonUniformEXT);
  204. break;
  205. }
  206. case spv::Decoration::HlslSemanticGOOGLE:
  207. case spv::Decoration::HlslCounterBufferGOOGLE: {
  208. addExtension(Extension::GOOGLE_hlsl_functionality1, "SPIR-V reflection",
  209. loc);
  210. break;
  211. }
  212. // Capabilities needed for built-ins
  213. case spv::Decoration::BuiltIn: {
  214. assert(decor->getParams().size() == 1);
  215. const auto builtin = static_cast<spv::BuiltIn>(decor->getParams()[0]);
  216. switch (builtin) {
  217. case spv::BuiltIn::SampleId:
  218. case spv::BuiltIn::SamplePosition: {
  219. addCapability(spv::Capability::SampleRateShading, loc);
  220. break;
  221. }
  222. case spv::BuiltIn::SubgroupSize:
  223. case spv::BuiltIn::NumSubgroups:
  224. case spv::BuiltIn::SubgroupId:
  225. case spv::BuiltIn::SubgroupLocalInvocationId: {
  226. addCapability(spv::Capability::GroupNonUniform, loc);
  227. break;
  228. }
  229. case spv::BuiltIn::BaseVertex: {
  230. addExtension(Extension::KHR_shader_draw_parameters, "BaseVertex Builtin",
  231. loc);
  232. addCapability(spv::Capability::DrawParameters);
  233. break;
  234. }
  235. case spv::BuiltIn::BaseInstance: {
  236. addExtension(Extension::KHR_shader_draw_parameters,
  237. "BaseInstance Builtin", loc);
  238. addCapability(spv::Capability::DrawParameters);
  239. break;
  240. }
  241. case spv::BuiltIn::DrawIndex: {
  242. addExtension(Extension::KHR_shader_draw_parameters, "DrawIndex Builtin",
  243. loc);
  244. addCapability(spv::Capability::DrawParameters);
  245. break;
  246. }
  247. case spv::BuiltIn::DeviceIndex: {
  248. addExtension(Extension::KHR_device_group, "DeviceIndex Builtin", loc);
  249. addCapability(spv::Capability::DeviceGroup);
  250. break;
  251. }
  252. case spv::BuiltIn::FragStencilRefEXT: {
  253. addExtension(Extension::EXT_shader_stencil_export, "SV_StencilRef", loc);
  254. addCapability(spv::Capability::StencilExportEXT);
  255. break;
  256. }
  257. case spv::BuiltIn::ViewIndex: {
  258. addExtension(Extension::KHR_multiview, "SV_ViewID", loc);
  259. addCapability(spv::Capability::MultiView);
  260. break;
  261. }
  262. case spv::BuiltIn::FullyCoveredEXT: {
  263. addExtension(Extension::EXT_fragment_fully_covered, "SV_InnerCoverage",
  264. loc);
  265. addCapability(spv::Capability::FragmentFullyCoveredEXT);
  266. break;
  267. }
  268. case spv::BuiltIn::PrimitiveId: {
  269. // PrimitiveID can be used as PSIn or MSPOut.
  270. if (shaderModel == spv::ExecutionModel::Fragment ||
  271. shaderModel == spv::ExecutionModel::MeshNV)
  272. addCapability(spv::Capability::Geometry);
  273. break;
  274. }
  275. case spv::BuiltIn::Layer: {
  276. if (shaderModel == spv::ExecutionModel::Vertex ||
  277. shaderModel == spv::ExecutionModel::TessellationControl ||
  278. shaderModel == spv::ExecutionModel::TessellationEvaluation) {
  279. addExtension(Extension::EXT_shader_viewport_index_layer,
  280. "SV_RenderTargetArrayIndex", loc);
  281. addCapability(spv::Capability::ShaderViewportIndexLayerEXT);
  282. } else if (shaderModel == spv::ExecutionModel::Fragment ||
  283. shaderModel == spv::ExecutionModel::MeshNV) {
  284. // SV_RenderTargetArrayIndex can be used as PSIn or MSPOut.
  285. addCapability(spv::Capability::Geometry);
  286. }
  287. break;
  288. }
  289. case spv::BuiltIn::ViewportIndex: {
  290. if (shaderModel == spv::ExecutionModel::Vertex ||
  291. shaderModel == spv::ExecutionModel::TessellationControl ||
  292. shaderModel == spv::ExecutionModel::TessellationEvaluation) {
  293. addExtension(Extension::EXT_shader_viewport_index_layer,
  294. "SV_ViewPortArrayIndex", loc);
  295. addCapability(spv::Capability::ShaderViewportIndexLayerEXT);
  296. } else if (shaderModel == spv::ExecutionModel::Fragment ||
  297. shaderModel == spv::ExecutionModel::Geometry ||
  298. shaderModel == spv::ExecutionModel::MeshNV) {
  299. // SV_ViewportArrayIndex can be used as PSIn or GSOut or MSPOut.
  300. addCapability(spv::Capability::MultiViewport);
  301. }
  302. break;
  303. }
  304. case spv::BuiltIn::ClipDistance: {
  305. addCapability(spv::Capability::ClipDistance);
  306. break;
  307. }
  308. case spv::BuiltIn::CullDistance: {
  309. addCapability(spv::Capability::CullDistance);
  310. break;
  311. }
  312. case spv::BuiltIn::BaryCoordNoPerspAMD:
  313. case spv::BuiltIn::BaryCoordNoPerspCentroidAMD:
  314. case spv::BuiltIn::BaryCoordNoPerspSampleAMD:
  315. case spv::BuiltIn::BaryCoordSmoothAMD:
  316. case spv::BuiltIn::BaryCoordSmoothCentroidAMD:
  317. case spv::BuiltIn::BaryCoordSmoothSampleAMD:
  318. case spv::BuiltIn::BaryCoordPullModelAMD: {
  319. addExtension(Extension::AMD_shader_explicit_vertex_parameter,
  320. "SV_Barycentrics", loc);
  321. break;
  322. }
  323. case spv::BuiltIn::FragSizeEXT: {
  324. addExtension(Extension::EXT_fragment_invocation_density, "SV_ShadingRate",
  325. loc);
  326. addCapability(spv::Capability::FragmentDensityEXT);
  327. break;
  328. }
  329. default:
  330. break;
  331. }
  332. break;
  333. }
  334. default:
  335. break;
  336. }
  337. return true;
  338. }
  339. spv::Capability
  340. CapabilityVisitor::getNonUniformCapability(const SpirvType *type) {
  341. if (!type)
  342. return spv::Capability::Max;
  343. if (const auto *arrayType = dyn_cast<ArrayType>(type)) {
  344. return getNonUniformCapability(arrayType->getElementType());
  345. }
  346. if (SpirvType::isTexture(type) || SpirvType::isSampler(type)) {
  347. return spv::Capability::SampledImageArrayNonUniformIndexingEXT;
  348. }
  349. if (SpirvType::isRWTexture(type)) {
  350. return spv::Capability::StorageImageArrayNonUniformIndexingEXT;
  351. }
  352. if (SpirvType::isBuffer(type)) {
  353. return spv::Capability::UniformTexelBufferArrayNonUniformIndexingEXT;
  354. }
  355. if (SpirvType::isRWBuffer(type)) {
  356. return spv::Capability::StorageTexelBufferArrayNonUniformIndexingEXT;
  357. }
  358. if (SpirvType::isSubpassInput(type) || SpirvType::isSubpassInputMS(type)) {
  359. return spv::Capability::InputAttachmentArrayNonUniformIndexingEXT;
  360. }
  361. return spv::Capability::Max;
  362. }
  363. bool CapabilityVisitor::visit(SpirvImageQuery *instr) {
  364. addCapabilityForType(instr->getResultType(), instr->getSourceLocation(),
  365. instr->getStorageClass());
  366. addCapability(spv::Capability::ImageQuery);
  367. return true;
  368. }
  369. bool CapabilityVisitor::visit(SpirvImageSparseTexelsResident *instr) {
  370. addCapabilityForType(instr->getResultType(), instr->getSourceLocation(),
  371. instr->getStorageClass());
  372. addCapability(spv::Capability::ImageGatherExtended);
  373. addCapability(spv::Capability::SparseResidency);
  374. return true;
  375. }
  376. bool CapabilityVisitor::visit(SpirvImageOp *instr) {
  377. addCapabilityForType(instr->getResultType(), instr->getSourceLocation(),
  378. instr->getStorageClass());
  379. if (instr->hasOffset() || instr->hasConstOffsets())
  380. addCapability(spv::Capability::ImageGatherExtended);
  381. if (instr->hasMinLod())
  382. addCapability(spv::Capability::MinLod);
  383. if (instr->isSparse())
  384. addCapability(spv::Capability::SparseResidency);
  385. return true;
  386. }
  387. bool CapabilityVisitor::visitInstruction(SpirvInstruction *instr) {
  388. const SpirvType *resultType = instr->getResultType();
  389. const auto opcode = instr->getopcode();
  390. const auto loc = instr->getSourceLocation();
  391. // Add result-type-specific capabilities
  392. addCapabilityForType(resultType, loc, instr->getStorageClass());
  393. // Add NonUniform capabilities if necessary
  394. if (instr->isNonUniform()) {
  395. addExtension(Extension::EXT_descriptor_indexing, "NonUniformEXT", loc);
  396. addCapability(spv::Capability::ShaderNonUniformEXT);
  397. addCapability(getNonUniformCapability(resultType));
  398. }
  399. // Add opcode-specific capabilities
  400. switch (opcode) {
  401. case spv::Op::OpDPdxCoarse:
  402. case spv::Op::OpDPdyCoarse:
  403. case spv::Op::OpFwidthCoarse:
  404. case spv::Op::OpDPdxFine:
  405. case spv::Op::OpDPdyFine:
  406. case spv::Op::OpFwidthFine:
  407. addCapability(spv::Capability::DerivativeControl);
  408. break;
  409. case spv::Op::OpGroupNonUniformElect:
  410. addCapability(spv::Capability::GroupNonUniform);
  411. break;
  412. case spv::Op::OpGroupNonUniformAny:
  413. case spv::Op::OpGroupNonUniformAll:
  414. case spv::Op::OpGroupNonUniformAllEqual:
  415. addCapability(spv::Capability::GroupNonUniformVote);
  416. break;
  417. case spv::Op::OpGroupNonUniformBallot:
  418. case spv::Op::OpGroupNonUniformInverseBallot:
  419. case spv::Op::OpGroupNonUniformBallotBitExtract:
  420. case spv::Op::OpGroupNonUniformBallotBitCount:
  421. case spv::Op::OpGroupNonUniformBallotFindLSB:
  422. case spv::Op::OpGroupNonUniformBallotFindMSB:
  423. case spv::Op::OpGroupNonUniformBroadcast:
  424. case spv::Op::OpGroupNonUniformBroadcastFirst:
  425. addCapability(spv::Capability::GroupNonUniformBallot);
  426. break;
  427. case spv::Op::OpGroupNonUniformIAdd:
  428. case spv::Op::OpGroupNonUniformFAdd:
  429. case spv::Op::OpGroupNonUniformIMul:
  430. case spv::Op::OpGroupNonUniformFMul:
  431. case spv::Op::OpGroupNonUniformSMax:
  432. case spv::Op::OpGroupNonUniformUMax:
  433. case spv::Op::OpGroupNonUniformFMax:
  434. case spv::Op::OpGroupNonUniformSMin:
  435. case spv::Op::OpGroupNonUniformUMin:
  436. case spv::Op::OpGroupNonUniformFMin:
  437. case spv::Op::OpGroupNonUniformBitwiseAnd:
  438. case spv::Op::OpGroupNonUniformBitwiseOr:
  439. case spv::Op::OpGroupNonUniformBitwiseXor:
  440. case spv::Op::OpGroupNonUniformLogicalAnd:
  441. case spv::Op::OpGroupNonUniformLogicalOr:
  442. case spv::Op::OpGroupNonUniformLogicalXor:
  443. addCapability(spv::Capability::GroupNonUniformArithmetic);
  444. break;
  445. case spv::Op::OpGroupNonUniformQuadBroadcast:
  446. case spv::Op::OpGroupNonUniformQuadSwap:
  447. addCapability(spv::Capability::GroupNonUniformQuad);
  448. break;
  449. case spv::Op::OpVariable: {
  450. if (spvOptions.enableReflect &&
  451. !cast<SpirvVariable>(instr)->getHlslUserType().empty()) {
  452. addExtension(Extension::GOOGLE_user_type, "HLSL User Type", loc);
  453. addExtension(Extension::GOOGLE_hlsl_functionality1, "HLSL User Type",
  454. loc);
  455. }
  456. break;
  457. }
  458. case spv::Op::OpRayQueryInitializeKHR: {
  459. auto rayQueryInst = dyn_cast<SpirvRayQueryOpKHR>(instr);
  460. if (rayQueryInst->hasCullFlags()) {
  461. addCapability(
  462. spv::Capability::RayTraversalPrimitiveCullingProvisionalKHR);
  463. }
  464. break;
  465. }
  466. default:
  467. break;
  468. }
  469. return true;
  470. }
  471. bool CapabilityVisitor::visit(SpirvEntryPoint *entryPoint) {
  472. shaderModel = entryPoint->getExecModel();
  473. switch (shaderModel) {
  474. case spv::ExecutionModel::Fragment:
  475. case spv::ExecutionModel::Vertex:
  476. case spv::ExecutionModel::GLCompute:
  477. addCapability(spv::Capability::Shader);
  478. break;
  479. case spv::ExecutionModel::Geometry:
  480. addCapability(spv::Capability::Geometry);
  481. break;
  482. case spv::ExecutionModel::TessellationControl:
  483. case spv::ExecutionModel::TessellationEvaluation:
  484. addCapability(spv::Capability::Tessellation);
  485. break;
  486. case spv::ExecutionModel::RayGenerationNV:
  487. case spv::ExecutionModel::IntersectionNV:
  488. case spv::ExecutionModel::ClosestHitNV:
  489. case spv::ExecutionModel::AnyHitNV:
  490. case spv::ExecutionModel::MissNV:
  491. case spv::ExecutionModel::CallableNV:
  492. if (featureManager.isExtensionEnabled("SPV_NV_ray_tracing")) {
  493. addCapability(spv::Capability::RayTracingNV);
  494. addExtension(Extension::NV_ray_tracing, "SPV_NV_ray_tracing", {});
  495. } else {
  496. // KHR_ray_tracing extension requires SPIR-V 1.4/Vulkan 1.2
  497. featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_2, "Raytracing", {});
  498. addCapability(spv::Capability::RayTracingProvisionalKHR);
  499. addExtension(Extension::KHR_ray_tracing, "SPV_KHR_ray_tracing", {});
  500. }
  501. break;
  502. case spv::ExecutionModel::MeshNV:
  503. case spv::ExecutionModel::TaskNV:
  504. addCapability(spv::Capability::MeshShadingNV);
  505. addExtension(Extension::NV_mesh_shader, "SPV_NV_mesh_shader", {});
  506. break;
  507. default:
  508. llvm_unreachable("found unknown shader model");
  509. break;
  510. }
  511. return true;
  512. }
  513. bool CapabilityVisitor::visit(SpirvExecutionMode *execMode) {
  514. if (execMode->getExecutionMode() == spv::ExecutionMode::PostDepthCoverage) {
  515. addCapability(spv::Capability::SampleMaskPostDepthCoverage,
  516. execMode->getEntryPoint()->getSourceLocation());
  517. addExtension(Extension::KHR_post_depth_coverage,
  518. "[[vk::post_depth_coverage]]", execMode->getSourceLocation());
  519. }
  520. return true;
  521. }
  522. bool CapabilityVisitor::visit(SpirvExtInstImport *instr) {
  523. if (instr->getExtendedInstSetName() == "NonSemantic.DebugPrintf")
  524. addExtension(Extension::KHR_non_semantic_info, "DebugPrintf",
  525. /*SourceLocation*/ {});
  526. return true;
  527. }
  528. bool CapabilityVisitor::visit(SpirvExtInst *instr) {
  529. // OpExtInst using the GLSL extended instruction allows only 32-bit types by
  530. // default for interpolation instructions. The AMD_gpu_shader_half_float
  531. // extension adds support for 16-bit floating-point component types for these
  532. // instructions:
  533. // InterpolateAtCentroid, InterpolateAtSample, InterpolateAtOffset
  534. if (SpirvType::isOrContainsType<FloatType, 16>(instr->getResultType()))
  535. switch (instr->getInstruction()) {
  536. case GLSLstd450::GLSLstd450InterpolateAtCentroid:
  537. case GLSLstd450::GLSLstd450InterpolateAtSample:
  538. case GLSLstd450::GLSLstd450InterpolateAtOffset:
  539. addExtension(Extension::AMD_gpu_shader_half_float, "16-bit float",
  540. instr->getSourceLocation());
  541. default:
  542. break;
  543. }
  544. return visitInstruction(instr);
  545. }
  546. bool CapabilityVisitor::visit(SpirvDemoteToHelperInvocationEXT *inst) {
  547. addCapability(spv::Capability::DemoteToHelperInvocationEXT,
  548. inst->getSourceLocation());
  549. addExtension(Extension::EXT_demote_to_helper_invocation, "discard",
  550. inst->getSourceLocation());
  551. return true;
  552. }
  553. } // end namespace spirv
  554. } // end namespace clang