trim_capabilities_pass.cpp 28 KB


  1. // Copyright (c) 2023 Google Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opt/trim_capabilities_pass.h"
  15. #include <algorithm>
  16. #include <array>
  17. #include <cassert>
  18. #include <functional>
  19. #include <optional>
  20. #include <queue>
  21. #include <stack>
  22. #include <unordered_map>
  23. #include <unordered_set>
  24. #include <vector>
  25. #include "source/enum_set.h"
  26. #include "source/enum_string_mapping.h"
  27. #include "source/ext_inst.h"
  28. #include "source/opt/ir_context.h"
  29. #include "source/opt/reflect.h"
  30. #include "source/spirv_target_env.h"
  31. #include "source/util/string_utils.h"
  32. namespace spvtools {
  33. namespace opt {
  34. namespace {
  35. constexpr uint32_t kOpTypeFloatSizeIndex = 0;
  36. constexpr uint32_t kOpTypePointerStorageClassIndex = 0;
  37. constexpr uint32_t kTypeArrayTypeIndex = 0;
  38. constexpr uint32_t kOpTypeScalarBitWidthIndex = 0;
  39. constexpr uint32_t kTypePointerTypeIdInIndex = 1;
  40. constexpr uint32_t kOpTypeIntSizeIndex = 0;
  41. constexpr uint32_t kOpTypeImageDimIndex = 1;
  42. constexpr uint32_t kOpTypeImageArrayedIndex = kOpTypeImageDimIndex + 2;
  43. constexpr uint32_t kOpTypeImageMSIndex = kOpTypeImageArrayedIndex + 1;
  44. constexpr uint32_t kOpTypeImageSampledIndex = kOpTypeImageMSIndex + 1;
  45. constexpr uint32_t kOpTypeImageFormatIndex = kOpTypeImageSampledIndex + 1;
  46. constexpr uint32_t kOpImageReadImageIndex = 0;
  47. constexpr uint32_t kOpImageWriteImageIndex = 0;
  48. constexpr uint32_t kOpImageSparseReadImageIndex = 0;
  49. constexpr uint32_t kOpExtInstSetInIndex = 0;
  50. constexpr uint32_t kOpExtInstInstructionInIndex = 1;
  51. constexpr uint32_t kOpExtInstImportNameInIndex = 0;
  52. // DFS visit of the type defined by `instruction`.
  53. // If `condition` is true, children of the current node are visited.
  54. // If `condition` is false, the children of the current node are ignored.
  55. template <class UnaryPredicate>
  56. static void DFSWhile(const Instruction* instruction, UnaryPredicate condition) {
  57. std::stack<uint32_t> instructions_to_visit;
  58. instructions_to_visit.push(instruction->result_id());
  59. const auto* def_use_mgr = instruction->context()->get_def_use_mgr();
  60. while (!instructions_to_visit.empty()) {
  61. const Instruction* item = def_use_mgr->GetDef(instructions_to_visit.top());
  62. instructions_to_visit.pop();
  63. if (!condition(item)) {
  64. continue;
  65. }
  66. if (item->opcode() == spv::Op::OpTypePointer) {
  67. instructions_to_visit.push(
  68. item->GetSingleWordInOperand(kTypePointerTypeIdInIndex));
  69. continue;
  70. }
  71. if (item->opcode() == spv::Op::OpTypeMatrix ||
  72. item->opcode() == spv::Op::OpTypeVector ||
  73. item->opcode() == spv::Op::OpTypeArray ||
  74. item->opcode() == spv::Op::OpTypeRuntimeArray) {
  75. instructions_to_visit.push(
  76. item->GetSingleWordInOperand(kTypeArrayTypeIndex));
  77. continue;
  78. }
  79. if (item->opcode() == spv::Op::OpTypeStruct) {
  80. item->ForEachInOperand([&instructions_to_visit](const uint32_t* op_id) {
  81. instructions_to_visit.push(*op_id);
  82. });
  83. continue;
  84. }
  85. }
  86. }
  87. // Walks the type defined by `instruction` (OpType* only).
  88. // Returns `true` if any call to `predicate` with the type/subtype returns true.
  89. template <class UnaryPredicate>
  90. static bool AnyTypeOf(const Instruction* instruction,
  91. UnaryPredicate predicate) {
  92. assert(IsTypeInst(instruction->opcode()) &&
  93. "AnyTypeOf called with a non-type instruction.");
  94. bool found_one = false;
  95. DFSWhile(instruction, [&found_one, predicate](const Instruction* node) {
  96. if (found_one || predicate(node)) {
  97. found_one = true;
  98. return false;
  99. }
  100. return true;
  101. });
  102. return found_one;
  103. }
  104. static bool is16bitType(const Instruction* instruction) {
  105. if (instruction->opcode() != spv::Op::OpTypeInt &&
  106. instruction->opcode() != spv::Op::OpTypeFloat) {
  107. return false;
  108. }
  109. return instruction->GetSingleWordInOperand(kOpTypeScalarBitWidthIndex) == 16;
  110. }
  111. static bool Has16BitCapability(const FeatureManager* feature_manager) {
  112. const CapabilitySet& capabilities = feature_manager->GetCapabilities();
  113. return capabilities.contains(spv::Capability::Float16) ||
  114. capabilities.contains(spv::Capability::Int16);
  115. }
  116. } // namespace
// ============== Begin opcode handler implementations. =======================
//
// Adding support for a new capability should only require adding a new handler,
// registering it in kOpcodeHandlers, and updating the
// kSupportedCapabilities/kUntouchableCapabilities/kForbiddenCapabilities lists.
//
// Handler names follow this convention:
// Handler_<Opcode>_<Capability>()
  125. static std::optional<spv::Capability> Handler_OpTypeFloat_Float16(
  126. const Instruction* instruction) {
  127. assert(instruction->opcode() == spv::Op::OpTypeFloat &&
  128. "This handler only support OpTypeFloat opcodes.");
  129. const uint32_t size =
  130. instruction->GetSingleWordInOperand(kOpTypeFloatSizeIndex);
  131. return size == 16 ? std::optional(spv::Capability::Float16) : std::nullopt;
  132. }
  133. static std::optional<spv::Capability> Handler_OpTypeFloat_Float64(
  134. const Instruction* instruction) {
  135. assert(instruction->opcode() == spv::Op::OpTypeFloat &&
  136. "This handler only support OpTypeFloat opcodes.");
  137. const uint32_t size =
  138. instruction->GetSingleWordInOperand(kOpTypeFloatSizeIndex);
  139. return size == 64 ? std::optional(spv::Capability::Float64) : std::nullopt;
  140. }
  141. static std::optional<spv::Capability>
  142. Handler_OpTypePointer_StorageInputOutput16(const Instruction* instruction) {
  143. assert(instruction->opcode() == spv::Op::OpTypePointer &&
  144. "This handler only support OpTypePointer opcodes.");
  145. // This capability is only required if the variable has an Input/Output
  146. // storage class.
  147. spv::StorageClass storage_class = spv::StorageClass(
  148. instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex));
  149. if (storage_class != spv::StorageClass::Input &&
  150. storage_class != spv::StorageClass::Output) {
  151. return std::nullopt;
  152. }
  153. if (!Has16BitCapability(instruction->context()->get_feature_mgr())) {
  154. return std::nullopt;
  155. }
  156. return AnyTypeOf(instruction, is16bitType)
  157. ? std::optional(spv::Capability::StorageInputOutput16)
  158. : std::nullopt;
  159. }
  160. static std::optional<spv::Capability>
  161. Handler_OpTypePointer_StoragePushConstant16(const Instruction* instruction) {
  162. assert(instruction->opcode() == spv::Op::OpTypePointer &&
  163. "This handler only support OpTypePointer opcodes.");
  164. // This capability is only required if the variable has a PushConstant storage
  165. // class.
  166. spv::StorageClass storage_class = spv::StorageClass(
  167. instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex));
  168. if (storage_class != spv::StorageClass::PushConstant) {
  169. return std::nullopt;
  170. }
  171. if (!Has16BitCapability(instruction->context()->get_feature_mgr())) {
  172. return std::nullopt;
  173. }
  174. return AnyTypeOf(instruction, is16bitType)
  175. ? std::optional(spv::Capability::StoragePushConstant16)
  176. : std::nullopt;
  177. }
  178. static std::optional<spv::Capability>
  179. Handler_OpTypePointer_StorageUniformBufferBlock16(
  180. const Instruction* instruction) {
  181. assert(instruction->opcode() == spv::Op::OpTypePointer &&
  182. "This handler only support OpTypePointer opcodes.");
  183. // This capability is only required if the variable has a Uniform storage
  184. // class.
  185. spv::StorageClass storage_class = spv::StorageClass(
  186. instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex));
  187. if (storage_class != spv::StorageClass::Uniform) {
  188. return std::nullopt;
  189. }
  190. if (!Has16BitCapability(instruction->context()->get_feature_mgr())) {
  191. return std::nullopt;
  192. }
  193. const auto* decoration_mgr = instruction->context()->get_decoration_mgr();
  194. const bool matchesCondition =
  195. AnyTypeOf(instruction, [decoration_mgr](const Instruction* item) {
  196. if (!decoration_mgr->HasDecoration(item->result_id(),
  197. spv::Decoration::BufferBlock)) {
  198. return false;
  199. }
  200. return AnyTypeOf(item, is16bitType);
  201. });
  202. return matchesCondition
  203. ? std::optional(spv::Capability::StorageUniformBufferBlock16)
  204. : std::nullopt;
  205. }
  206. static std::optional<spv::Capability>
  207. Handler_OpTypePointer_StorageBuffer16BitAccess(const Instruction* instruction) {
  208. assert(instruction->opcode() == spv::Op::OpTypePointer &&
  209. "This handler only support OpTypePointer opcodes.");
  210. // Requires StorageBuffer, ShaderRecordBufferKHR or PhysicalStorageBuffer
  211. // storage classes.
  212. spv::StorageClass storage_class = spv::StorageClass(
  213. instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex));
  214. if (storage_class != spv::StorageClass::StorageBuffer &&
  215. storage_class != spv::StorageClass::ShaderRecordBufferKHR &&
  216. storage_class != spv::StorageClass::PhysicalStorageBuffer) {
  217. return std::nullopt;
  218. }
  219. const auto* decoration_mgr = instruction->context()->get_decoration_mgr();
  220. const bool matchesCondition =
  221. AnyTypeOf(instruction, [decoration_mgr](const Instruction* item) {
  222. if (!decoration_mgr->HasDecoration(item->result_id(),
  223. spv::Decoration::Block)) {
  224. return false;
  225. }
  226. return AnyTypeOf(item, is16bitType);
  227. });
  228. return matchesCondition
  229. ? std::optional(spv::Capability::StorageBuffer16BitAccess)
  230. : std::nullopt;
  231. }
  232. static std::optional<spv::Capability> Handler_OpTypePointer_StorageUniform16(
  233. const Instruction* instruction) {
  234. assert(instruction->opcode() == spv::Op::OpTypePointer &&
  235. "This handler only support OpTypePointer opcodes.");
  236. // This capability is only required if the variable has a Uniform storage
  237. // class.
  238. spv::StorageClass storage_class = spv::StorageClass(
  239. instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex));
  240. if (storage_class != spv::StorageClass::Uniform) {
  241. return std::nullopt;
  242. }
  243. const auto* feature_manager = instruction->context()->get_feature_mgr();
  244. if (!Has16BitCapability(feature_manager)) {
  245. return std::nullopt;
  246. }
  247. const bool hasBufferBlockCapability =
  248. feature_manager->GetCapabilities().contains(
  249. spv::Capability::StorageUniformBufferBlock16);
  250. const auto* decoration_mgr = instruction->context()->get_decoration_mgr();
  251. bool found16bitType = false;
  252. DFSWhile(instruction, [decoration_mgr, hasBufferBlockCapability,
  253. &found16bitType](const Instruction* item) {
  254. if (found16bitType) {
  255. return false;
  256. }
  257. if (hasBufferBlockCapability &&
  258. decoration_mgr->HasDecoration(item->result_id(),
  259. spv::Decoration::BufferBlock)) {
  260. return false;
  261. }
  262. if (is16bitType(item)) {
  263. found16bitType = true;
  264. return false;
  265. }
  266. return true;
  267. });
  268. return found16bitType ? std::optional(spv::Capability::StorageUniform16)
  269. : std::nullopt;
  270. }
  271. static std::optional<spv::Capability> Handler_OpTypeInt_Int16(
  272. const Instruction* instruction) {
  273. assert(instruction->opcode() == spv::Op::OpTypeInt &&
  274. "This handler only support OpTypeInt opcodes.");
  275. const uint32_t size =
  276. instruction->GetSingleWordInOperand(kOpTypeIntSizeIndex);
  277. return size == 16 ? std::optional(spv::Capability::Int16) : std::nullopt;
  278. }
  279. static std::optional<spv::Capability> Handler_OpTypeInt_Int64(
  280. const Instruction* instruction) {
  281. assert(instruction->opcode() == spv::Op::OpTypeInt &&
  282. "This handler only support OpTypeInt opcodes.");
  283. const uint32_t size =
  284. instruction->GetSingleWordInOperand(kOpTypeIntSizeIndex);
  285. return size == 64 ? std::optional(spv::Capability::Int64) : std::nullopt;
  286. }
  287. static std::optional<spv::Capability> Handler_OpTypeImage_ImageMSArray(
  288. const Instruction* instruction) {
  289. assert(instruction->opcode() == spv::Op::OpTypeImage &&
  290. "This handler only support OpTypeImage opcodes.");
  291. const uint32_t arrayed =
  292. instruction->GetSingleWordInOperand(kOpTypeImageArrayedIndex);
  293. const uint32_t ms = instruction->GetSingleWordInOperand(kOpTypeImageMSIndex);
  294. const uint32_t sampled =
  295. instruction->GetSingleWordInOperand(kOpTypeImageSampledIndex);
  296. return arrayed == 1 && sampled == 2 && ms == 1
  297. ? std::optional(spv::Capability::ImageMSArray)
  298. : std::nullopt;
  299. }
  300. static std::optional<spv::Capability>
  301. Handler_OpImageRead_StorageImageReadWithoutFormat(
  302. const Instruction* instruction) {
  303. assert(instruction->opcode() == spv::Op::OpImageRead &&
  304. "This handler only support OpImageRead opcodes.");
  305. const auto* def_use_mgr = instruction->context()->get_def_use_mgr();
  306. const uint32_t image_index =
  307. instruction->GetSingleWordInOperand(kOpImageReadImageIndex);
  308. const uint32_t type_index = def_use_mgr->GetDef(image_index)->type_id();
  309. const Instruction* type = def_use_mgr->GetDef(type_index);
  310. const uint32_t dim = type->GetSingleWordInOperand(kOpTypeImageDimIndex);
  311. const uint32_t format = type->GetSingleWordInOperand(kOpTypeImageFormatIndex);
  312. // If the Image Format is Unknown and Dim is SubpassData,
  313. // StorageImageReadWithoutFormat is required.
  314. const bool is_unknown = spv::ImageFormat(format) == spv::ImageFormat::Unknown;
  315. const bool requires_capability_for_unknown =
  316. spv::Dim(dim) != spv::Dim::SubpassData;
  317. return is_unknown && requires_capability_for_unknown
  318. ? std::optional(spv::Capability::StorageImageReadWithoutFormat)
  319. : std::nullopt;
  320. }
  321. static std::optional<spv::Capability>
  322. Handler_OpImageWrite_StorageImageWriteWithoutFormat(
  323. const Instruction* instruction) {
  324. assert(instruction->opcode() == spv::Op::OpImageWrite &&
  325. "This handler only support OpImageWrite opcodes.");
  326. const auto* def_use_mgr = instruction->context()->get_def_use_mgr();
  327. const uint32_t image_index =
  328. instruction->GetSingleWordInOperand(kOpImageWriteImageIndex);
  329. const uint32_t type_index = def_use_mgr->GetDef(image_index)->type_id();
  330. // If the Image Format is Unknown, StorageImageWriteWithoutFormat is required.
  331. const Instruction* type = def_use_mgr->GetDef(type_index);
  332. const uint32_t format = type->GetSingleWordInOperand(kOpTypeImageFormatIndex);
  333. const bool is_unknown = spv::ImageFormat(format) == spv::ImageFormat::Unknown;
  334. return is_unknown
  335. ? std::optional(spv::Capability::StorageImageWriteWithoutFormat)
  336. : std::nullopt;
  337. }
  338. static std::optional<spv::Capability>
  339. Handler_OpImageSparseRead_StorageImageReadWithoutFormat(
  340. const Instruction* instruction) {
  341. assert(instruction->opcode() == spv::Op::OpImageSparseRead &&
  342. "This handler only support OpImageSparseRead opcodes.");
  343. const auto* def_use_mgr = instruction->context()->get_def_use_mgr();
  344. const uint32_t image_index =
  345. instruction->GetSingleWordInOperand(kOpImageSparseReadImageIndex);
  346. const uint32_t type_index = def_use_mgr->GetDef(image_index)->type_id();
  347. const Instruction* type = def_use_mgr->GetDef(type_index);
  348. const uint32_t format = type->GetSingleWordInOperand(kOpTypeImageFormatIndex);
  349. return spv::ImageFormat(format) == spv::ImageFormat::Unknown
  350. ? std::optional(spv::Capability::StorageImageReadWithoutFormat)
  351. : std::nullopt;
  352. }
  353. // Opcode of interest to determine capabilities requirements.
  354. constexpr std::array<std::pair<spv::Op, OpcodeHandler>, 14> kOpcodeHandlers{{
  355. // clang-format off
  356. {spv::Op::OpImageRead, Handler_OpImageRead_StorageImageReadWithoutFormat},
  357. {spv::Op::OpImageWrite, Handler_OpImageWrite_StorageImageWriteWithoutFormat},
  358. {spv::Op::OpImageSparseRead, Handler_OpImageSparseRead_StorageImageReadWithoutFormat},
  359. {spv::Op::OpTypeFloat, Handler_OpTypeFloat_Float16 },
  360. {spv::Op::OpTypeFloat, Handler_OpTypeFloat_Float64 },
  361. {spv::Op::OpTypeImage, Handler_OpTypeImage_ImageMSArray},
  362. {spv::Op::OpTypeInt, Handler_OpTypeInt_Int16 },
  363. {spv::Op::OpTypeInt, Handler_OpTypeInt_Int64 },
  364. {spv::Op::OpTypePointer, Handler_OpTypePointer_StorageInputOutput16},
  365. {spv::Op::OpTypePointer, Handler_OpTypePointer_StoragePushConstant16},
  366. {spv::Op::OpTypePointer, Handler_OpTypePointer_StorageUniform16},
  367. {spv::Op::OpTypePointer, Handler_OpTypePointer_StorageUniform16},
  368. {spv::Op::OpTypePointer, Handler_OpTypePointer_StorageUniformBufferBlock16},
  369. {spv::Op::OpTypePointer, Handler_OpTypePointer_StorageBuffer16BitAccess},
  370. // clang-format on
  371. }};
  372. // ============== End opcode handler implementations. =======================
  373. namespace {
  374. ExtensionSet getExtensionsRelatedTo(const CapabilitySet& capabilities,
  375. const AssemblyGrammar& grammar) {
  376. ExtensionSet output;
  377. const spv_operand_desc_t* desc = nullptr;
  378. for (auto capability : capabilities) {
  379. if (SPV_SUCCESS != grammar.lookupOperand(SPV_OPERAND_TYPE_CAPABILITY,
  380. static_cast<uint32_t>(capability),
  381. &desc)) {
  382. continue;
  383. }
  384. for (uint32_t i = 0; i < desc->numExtensions; ++i) {
  385. output.insert(desc->extensions[i]);
  386. }
  387. }
  388. return output;
  389. }
  390. bool hasOpcodeConflictingCapabilities(spv::Op opcode) {
  391. switch (opcode) {
  392. case spv::Op::OpBeginInvocationInterlockEXT:
  393. case spv::Op::OpEndInvocationInterlockEXT:
  394. case spv::Op::OpGroupNonUniformIAdd:
  395. case spv::Op::OpGroupNonUniformFAdd:
  396. case spv::Op::OpGroupNonUniformIMul:
  397. case spv::Op::OpGroupNonUniformFMul:
  398. case spv::Op::OpGroupNonUniformSMin:
  399. case spv::Op::OpGroupNonUniformUMin:
  400. case spv::Op::OpGroupNonUniformFMin:
  401. case spv::Op::OpGroupNonUniformSMax:
  402. case spv::Op::OpGroupNonUniformUMax:
  403. case spv::Op::OpGroupNonUniformFMax:
  404. case spv::Op::OpGroupNonUniformBitwiseAnd:
  405. case spv::Op::OpGroupNonUniformBitwiseOr:
  406. case spv::Op::OpGroupNonUniformBitwiseXor:
  407. case spv::Op::OpGroupNonUniformLogicalAnd:
  408. case spv::Op::OpGroupNonUniformLogicalOr:
  409. case spv::Op::OpGroupNonUniformLogicalXor:
  410. return true;
  411. default:
  412. return false;
  413. }
  414. }
  415. } // namespace
  416. TrimCapabilitiesPass::TrimCapabilitiesPass()
  417. : supportedCapabilities_(
  418. TrimCapabilitiesPass::kSupportedCapabilities.cbegin(),
  419. TrimCapabilitiesPass::kSupportedCapabilities.cend()),
  420. forbiddenCapabilities_(
  421. TrimCapabilitiesPass::kForbiddenCapabilities.cbegin(),
  422. TrimCapabilitiesPass::kForbiddenCapabilities.cend()),
  423. untouchableCapabilities_(
  424. TrimCapabilitiesPass::kUntouchableCapabilities.cbegin(),
  425. TrimCapabilitiesPass::kUntouchableCapabilities.cend()),
  426. opcodeHandlers_(kOpcodeHandlers.cbegin(), kOpcodeHandlers.cend()) {}
  427. void TrimCapabilitiesPass::addInstructionRequirementsForOpcode(
  428. spv::Op opcode, CapabilitySet* capabilities,
  429. ExtensionSet* extensions) const {
  430. if (hasOpcodeConflictingCapabilities(opcode)) {
  431. return;
  432. }
  433. const spv_opcode_desc_t* desc = {};
  434. auto result = context()->grammar().lookupOpcode(opcode, &desc);
  435. if (result != SPV_SUCCESS) {
  436. return;
  437. }
  438. addSupportedCapabilitiesToSet(desc, capabilities);
  439. addSupportedExtensionsToSet(desc, extensions);
  440. }
  441. void TrimCapabilitiesPass::addInstructionRequirementsForOperand(
  442. const Operand& operand, CapabilitySet* capabilities,
  443. ExtensionSet* extensions) const {
  444. // No supported capability relies on a 2+-word operand.
  445. if (operand.words.size() != 1) {
  446. return;
  447. }
  448. // No supported capability relies on a literal string operand or an ID.
  449. if (operand.type == SPV_OPERAND_TYPE_LITERAL_STRING ||
  450. operand.type == SPV_OPERAND_TYPE_ID ||
  451. operand.type == SPV_OPERAND_TYPE_RESULT_ID) {
  452. return;
  453. }
  454. // If the Vulkan memory model is declared and any instruction uses Device
  455. // scope, the VulkanMemoryModelDeviceScope capability must be declared. This
  456. // rule cannot be covered by the grammar, so must be checked explicitly.
  457. if (operand.type == SPV_OPERAND_TYPE_SCOPE_ID) {
  458. const Instruction* memory_model = context()->GetMemoryModel();
  459. if (memory_model && memory_model->GetSingleWordInOperand(1u) ==
  460. uint32_t(spv::MemoryModel::Vulkan)) {
  461. capabilities->insert(spv::Capability::VulkanMemoryModelDeviceScope);
  462. }
  463. }
  464. // case 1: Operand is a single value, can directly lookup.
  465. if (!spvOperandIsConcreteMask(operand.type)) {
  466. const spv_operand_desc_t* desc = {};
  467. auto result = context()->grammar().lookupOperand(operand.type,
  468. operand.words[0], &desc);
  469. if (result != SPV_SUCCESS) {
  470. return;
  471. }
  472. addSupportedCapabilitiesToSet(desc, capabilities);
  473. addSupportedExtensionsToSet(desc, extensions);
  474. return;
  475. }
  476. // case 2: operand can be a bitmask, we need to decompose the lookup.
  477. for (uint32_t i = 0; i < 32; i++) {
  478. const uint32_t mask = (1 << i) & operand.words[0];
  479. if (!mask) {
  480. continue;
  481. }
  482. const spv_operand_desc_t* desc = {};
  483. auto result = context()->grammar().lookupOperand(operand.type, mask, &desc);
  484. if (result != SPV_SUCCESS) {
  485. continue;
  486. }
  487. addSupportedCapabilitiesToSet(desc, capabilities);
  488. addSupportedExtensionsToSet(desc, extensions);
  489. }
  490. }
  491. void TrimCapabilitiesPass::addInstructionRequirementsForExtInst(
  492. Instruction* instruction, CapabilitySet* capabilities) const {
  493. assert(instruction->opcode() == spv::Op::OpExtInst &&
  494. "addInstructionRequirementsForExtInst must be passed an OpExtInst "
  495. "instruction");
  496. const auto* def_use_mgr = context()->get_def_use_mgr();
  497. const Instruction* extInstImport = def_use_mgr->GetDef(
  498. instruction->GetSingleWordInOperand(kOpExtInstSetInIndex));
  499. uint32_t extInstruction =
  500. instruction->GetSingleWordInOperand(kOpExtInstInstructionInIndex);
  501. const Operand& extInstSet =
  502. extInstImport->GetInOperand(kOpExtInstImportNameInIndex);
  503. spv_ext_inst_type_t instructionSet =
  504. spvExtInstImportTypeGet(extInstSet.AsString().c_str());
  505. spv_ext_inst_desc desc = {};
  506. auto result =
  507. context()->grammar().lookupExtInst(instructionSet, extInstruction, &desc);
  508. if (result != SPV_SUCCESS) {
  509. return;
  510. }
  511. addSupportedCapabilitiesToSet(desc, capabilities);
  512. }
  513. void TrimCapabilitiesPass::addInstructionRequirements(
  514. Instruction* instruction, CapabilitySet* capabilities,
  515. ExtensionSet* extensions) const {
  516. // Ignoring OpCapability and OpExtension instructions.
  517. if (instruction->opcode() == spv::Op::OpCapability ||
  518. instruction->opcode() == spv::Op::OpExtension) {
  519. return;
  520. }
  521. if (instruction->opcode() == spv::Op::OpExtInst) {
  522. addInstructionRequirementsForExtInst(instruction, capabilities);
  523. } else {
  524. addInstructionRequirementsForOpcode(instruction->opcode(), capabilities,
  525. extensions);
  526. }
  527. // Second case: one of the opcode operand is gated by a capability.
  528. const uint32_t operandCount = instruction->NumOperands();
  529. for (uint32_t i = 0; i < operandCount; i++) {
  530. addInstructionRequirementsForOperand(instruction->GetOperand(i),
  531. capabilities, extensions);
  532. }
  533. // Last case: some complex logic needs to be run to determine capabilities.
  534. auto[begin, end] = opcodeHandlers_.equal_range(instruction->opcode());
  535. for (auto it = begin; it != end; it++) {
  536. const OpcodeHandler handler = it->second;
  537. auto result = handler(instruction);
  538. if (!result.has_value()) {
  539. continue;
  540. }
  541. capabilities->insert(*result);
  542. }
  543. }
  544. void TrimCapabilitiesPass::AddExtensionsForOperand(
  545. const spv_operand_type_t type, const uint32_t value,
  546. ExtensionSet* extensions) const {
  547. const spv_operand_desc_t* desc = nullptr;
  548. spv_result_t result = context()->grammar().lookupOperand(type, value, &desc);
  549. if (result != SPV_SUCCESS) {
  550. return;
  551. }
  552. addSupportedExtensionsToSet(desc, extensions);
  553. }
  554. std::pair<CapabilitySet, ExtensionSet>
  555. TrimCapabilitiesPass::DetermineRequiredCapabilitiesAndExtensions() const {
  556. CapabilitySet required_capabilities;
  557. ExtensionSet required_extensions;
  558. get_module()->ForEachInst([&](Instruction* instruction) {
  559. addInstructionRequirements(instruction, &required_capabilities,
  560. &required_extensions);
  561. });
  562. for (auto capability : required_capabilities) {
  563. AddExtensionsForOperand(SPV_OPERAND_TYPE_CAPABILITY,
  564. static_cast<uint32_t>(capability),
  565. &required_extensions);
  566. }
  567. #if !defined(NDEBUG)
  568. // Debug only. We check the outputted required capabilities against the
  569. // supported capabilities list. The supported capabilities list is useful for
  570. // API users to quickly determine if they can use the pass or not. But this
  571. // list has to remain up-to-date with the pass code. If we can detect a
  572. // capability as required, but it's not listed, it means the list is
  573. // out-of-sync. This method is not ideal, but should cover most cases.
  574. {
  575. for (auto capability : required_capabilities) {
  576. assert(supportedCapabilities_.contains(capability) &&
  577. "Module is using a capability that is not listed as supported.");
  578. }
  579. }
  580. #endif
  581. return std::make_pair(std::move(required_capabilities),
  582. std::move(required_extensions));
  583. }
  584. Pass::Status TrimCapabilitiesPass::TrimUnrequiredCapabilities(
  585. const CapabilitySet& required_capabilities) const {
  586. const FeatureManager* feature_manager = context()->get_feature_mgr();
  587. CapabilitySet capabilities_to_trim;
  588. for (auto capability : feature_manager->GetCapabilities()) {
  589. // Some capabilities cannot be safely removed. Leaving them untouched.
  590. if (untouchableCapabilities_.contains(capability)) {
  591. continue;
  592. }
  593. // If the capability is unsupported, don't trim it.
  594. if (!supportedCapabilities_.contains(capability)) {
  595. continue;
  596. }
  597. if (required_capabilities.contains(capability)) {
  598. continue;
  599. }
  600. capabilities_to_trim.insert(capability);
  601. }
  602. for (auto capability : capabilities_to_trim) {
  603. context()->RemoveCapability(capability);
  604. }
  605. return capabilities_to_trim.size() == 0 ? Pass::Status::SuccessWithoutChange
  606. : Pass::Status::SuccessWithChange;
  607. }
  608. Pass::Status TrimCapabilitiesPass::TrimUnrequiredExtensions(
  609. const ExtensionSet& required_extensions) const {
  610. const auto supported_extensions =
  611. getExtensionsRelatedTo(supportedCapabilities_, context()->grammar());
  612. bool modified_module = false;
  613. for (auto extension : supported_extensions) {
  614. if (required_extensions.contains(extension)) {
  615. continue;
  616. }
  617. if (context()->RemoveExtension(extension)) {
  618. modified_module = true;
  619. }
  620. }
  621. return modified_module ? Pass::Status::SuccessWithChange
  622. : Pass::Status::SuccessWithoutChange;
  623. }
  624. bool TrimCapabilitiesPass::HasForbiddenCapabilities() const {
  625. // EnumSet.HasAnyOf returns `true` if the given set is empty.
  626. if (forbiddenCapabilities_.size() == 0) {
  627. return false;
  628. }
  629. const auto& capabilities = context()->get_feature_mgr()->GetCapabilities();
  630. return capabilities.HasAnyOf(forbiddenCapabilities_);
  631. }
  632. Pass::Status TrimCapabilitiesPass::Process() {
  633. if (HasForbiddenCapabilities()) {
  634. return Status::SuccessWithoutChange;
  635. }
  636. auto[required_capabilities, required_extensions] =
  637. DetermineRequiredCapabilitiesAndExtensions();
  638. Pass::Status capStatus = TrimUnrequiredCapabilities(required_capabilities);
  639. Pass::Status extStatus = TrimUnrequiredExtensions(required_extensions);
  640. return capStatus == Pass::Status::SuccessWithChange ||
  641. extStatus == Pass::Status::SuccessWithChange
  642. ? Pass::Status::SuccessWithChange
  643. : Pass::Status::SuccessWithoutChange;
  644. }
  645. } // namespace opt
  646. } // namespace spvtools