2
0

trim_capabilities_pass.h 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. // Copyright (c) 2023 Google Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #ifndef SOURCE_OPT_TRIM_CAPABILITIES_PASS_H_
  15. #define SOURCE_OPT_TRIM_CAPABILITIES_PASS_H_
  16. #include <algorithm>
  17. #include <array>
  18. #include <functional>
  19. #include <optional>
  20. #include <unordered_map>
  21. #include <unordered_set>
  22. #include "source/enum_set.h"
  23. #include "source/extensions.h"
  24. #include "source/opt/ir_context.h"
  25. #include "source/opt/module.h"
  26. #include "source/opt/pass.h"
  27. #include "source/spirv_target_env.h"
  28. namespace spvtools {
  29. namespace opt {
  30. // This is required for NDK build. The unordered_set/unordered_map
  31. // implementation don't work with class enums.
  32. struct ClassEnumHash {
  33. std::size_t operator()(spv::Capability value) const {
  34. using StoringType = typename std::underlying_type_t<spv::Capability>;
  35. return std::hash<StoringType>{}(static_cast<StoringType>(value));
  36. }
  37. std::size_t operator()(spv::Op value) const {
  38. using StoringType = typename std::underlying_type_t<spv::Op>;
  39. return std::hash<StoringType>{}(static_cast<StoringType>(value));
  40. }
  41. };
  42. // An opcode handler is a function which, given an instruction, returns either
  43. // the required capability, or nothing.
  44. // Each handler checks one case for a capability requirement.
  45. //
  46. // Example:
  47. // - `OpTypeImage` can have operand `A` operand which requires capability 1
  48. // - `OpTypeImage` can also have operand `B` which requires capability 2.
  49. // -> We have 2 handlers: `Handler_OpTypeImage_1` and
  50. // `Handler_OpTypeImage_2`.
  51. using OpcodeHandler =
  52. std::optional<spv::Capability> (*)(const Instruction* instruction);
  53. // This pass tried to remove superfluous capabilities declared in the module.
  54. // - If all the capabilities listed by an extension are removed, the extension
  55. // is also trimmed.
  56. // - If the module countains any capability listed in `kForbiddenCapabilities`,
  57. // the module is left untouched.
  58. // - No capabilities listed in `kUntouchableCapabilities` are trimmed, even when
  59. // not used.
  60. // - Only capabilitied listed in `kSupportedCapabilities` are supported.
  61. // - If the module contains unsupported capabilities, results might be
  62. // incorrect.
  63. class TrimCapabilitiesPass : public Pass {
  64. private:
  65. // All the capabilities supported by this optimization pass. If your module
  66. // contains unsupported instruction, the pass could yield bad results.
  67. static constexpr std::array kSupportedCapabilities{
  68. // clang-format off
  69. spv::Capability::ComputeDerivativeGroupLinearKHR,
  70. spv::Capability::ComputeDerivativeGroupQuadsKHR,
  71. spv::Capability::Float16,
  72. spv::Capability::Float64,
  73. spv::Capability::FragmentShaderPixelInterlockEXT,
  74. spv::Capability::FragmentShaderSampleInterlockEXT,
  75. spv::Capability::FragmentShaderShadingRateInterlockEXT,
  76. spv::Capability::GroupNonUniform,
  77. spv::Capability::GroupNonUniformArithmetic,
  78. spv::Capability::GroupNonUniformClustered,
  79. spv::Capability::GroupNonUniformPartitionedNV,
  80. spv::Capability::GroupNonUniformVote,
  81. spv::Capability::Groups,
  82. spv::Capability::ImageMSArray,
  83. spv::Capability::Int16,
  84. spv::Capability::Int64,
  85. spv::Capability::InterpolationFunction,
  86. spv::Capability::Linkage,
  87. spv::Capability::MinLod,
  88. spv::Capability::PhysicalStorageBufferAddresses,
  89. spv::Capability::RayQueryKHR,
  90. spv::Capability::RayTracingKHR,
  91. spv::Capability::RayTraversalPrimitiveCullingKHR,
  92. spv::Capability::Shader,
  93. spv::Capability::ShaderClockKHR,
  94. spv::Capability::StorageBuffer16BitAccess,
  95. spv::Capability::StorageImageReadWithoutFormat,
  96. spv::Capability::StorageImageWriteWithoutFormat,
  97. spv::Capability::StorageInputOutput16,
  98. spv::Capability::StoragePushConstant16,
  99. spv::Capability::StorageUniform16,
  100. spv::Capability::StorageUniformBufferBlock16,
  101. spv::Capability::VulkanMemoryModelDeviceScope,
  102. // clang-format on
  103. };
  104. // Those capabilities disable all transformation of the module.
  105. static constexpr std::array kForbiddenCapabilities{
  106. spv::Capability::Linkage,
  107. };
  108. // Those capabilities are never removed from a module because we cannot
  109. // guess from the SPIR-V only if they are required or not.
  110. static constexpr std::array kUntouchableCapabilities{
  111. spv::Capability::Shader,
  112. };
  113. public:
  114. TrimCapabilitiesPass();
  115. TrimCapabilitiesPass(const TrimCapabilitiesPass&) = delete;
  116. TrimCapabilitiesPass(TrimCapabilitiesPass&&) = delete;
  117. private:
  118. // Inserts every capability listed by `descriptor` this pass supports into
  119. // `output`. Expects a Descriptor like `spv_opcode_desc_t` or
  120. // `spv_operand_desc_t`.
  121. template <class Descriptor>
  122. inline void addSupportedCapabilitiesToSet(const Descriptor* const descriptor,
  123. CapabilitySet* output) const {
  124. const uint32_t capabilityCount = descriptor->numCapabilities;
  125. for (uint32_t i = 0; i < capabilityCount; ++i) {
  126. const auto capability = descriptor->capabilities[i];
  127. if (supportedCapabilities_.contains(capability)) {
  128. output->insert(capability);
  129. }
  130. }
  131. }
  132. // Inserts every extension listed by `descriptor` required by the module into
  133. // `output`. Expects a Descriptor like `spv_opcode_desc_t` or
  134. // `spv_operand_desc_t`.
  135. template <class Descriptor>
  136. inline void addSupportedExtensionsToSet(const Descriptor* const descriptor,
  137. ExtensionSet* output) const {
  138. if (descriptor->minVersion <=
  139. spvVersionForTargetEnv(context()->GetTargetEnv())) {
  140. return;
  141. }
  142. output->insert(descriptor->extensions,
  143. descriptor->extensions + descriptor->numExtensions);
  144. }
  145. void addInstructionRequirementsForOpcode(spv::Op opcode,
  146. CapabilitySet* capabilities,
  147. ExtensionSet* extensions) const;
  148. void addInstructionRequirementsForOperand(const Operand& operand,
  149. CapabilitySet* capabilities,
  150. ExtensionSet* extensions) const;
  151. void addInstructionRequirementsForExtInst(Instruction* instruction,
  152. CapabilitySet* capabilities) const;
  153. // Given an `instruction`, determines the capabilities it requires, and output
  154. // them in `capabilities`. The returned capabilities form a subset of
  155. // kSupportedCapabilities.
  156. void addInstructionRequirements(Instruction* instruction,
  157. CapabilitySet* capabilities,
  158. ExtensionSet* extensions) const;
  159. // Given an operand `type` and `value`, adds the extensions it would require
  160. // to `extensions`.
  161. void AddExtensionsForOperand(const spv_operand_type_t type,
  162. const uint32_t value,
  163. ExtensionSet* extensions) const;
  164. // Returns the list of required capabilities and extensions for the module.
  165. // The returned capabilities form a subset of kSupportedCapabilities.
  166. std::pair<CapabilitySet, ExtensionSet>
  167. DetermineRequiredCapabilitiesAndExtensions() const;
  168. // Trims capabilities not listed in `required_capabilities` if possible.
  169. // Returns whether or not the module was modified.
  170. Pass::Status TrimUnrequiredCapabilities(
  171. const CapabilitySet& required_capabilities) const;
  172. // Trims extensions not listed in `required_extensions` if supported by this
  173. // pass. An extensions is considered supported as soon as one capability this
  174. // pass support requires it.
  175. Pass::Status TrimUnrequiredExtensions(
  176. const ExtensionSet& required_extensions) const;
  177. // Returns if the analyzed module contains any forbidden capability.
  178. bool HasForbiddenCapabilities() const;
  179. public:
  180. const char* name() const override { return "trim-capabilities"; }
  181. Status Process() override;
  182. private:
  183. const CapabilitySet supportedCapabilities_;
  184. const CapabilitySet forbiddenCapabilities_;
  185. const CapabilitySet untouchableCapabilities_;
  186. const std::unordered_multimap<spv::Op, OpcodeHandler, ClassEnumHash>
  187. opcodeHandlers_;
  188. };
  189. } // namespace opt
  190. } // namespace spvtools
#endif  // SOURCE_OPT_TRIM_CAPABILITIES_PASS_H_