folding_rules.cpp 113 KB


  1. // Copyright (c) 2018 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opt/folding_rules.h"
  15. #include <climits>
  16. #include <limits>
  17. #include <memory>
  18. #include <utility>
  19. #include "ir_builder.h"
  20. #include "source/latest_version_glsl_std_450_header.h"
  21. #include "source/opt/ir_context.h"
  22. namespace spvtools {
  23. namespace opt {
  24. namespace {
  // In-operand indices; each name identifies the instruction operand it
  // selects.
  constexpr uint32_t kExtractCompositeIdInIdx{0};
  constexpr uint32_t kInsertObjectIdInIdx{0};
  constexpr uint32_t kInsertCompositeIdInIdx{1};
  // Extended-instruction operands: the instruction-set id, then the
  // instruction number.
  constexpr uint32_t kExtInstSetIdInIdx{0};
  constexpr uint32_t kExtInstInstructionInIdx{1};
  // FMix(x, y, a) arguments follow the set id (0) and instruction number (1).
  constexpr uint32_t kFMixXIdInIdx{2};
  constexpr uint32_t kFMixYIdInIdx{3};
  constexpr uint32_t kFMixAIdInIdx{4};
  constexpr uint32_t kStoreObjectInIdx{1};
  34. // Some image instructions may contain an "image operands" argument.
  35. // Returns the operand index for the "image operands".
  36. // Returns -1 if the instruction does not have image operands.
  37. int32_t ImageOperandsMaskInOperandIndex(Instruction* inst) {
  38. const auto opcode = inst->opcode();
  39. switch (opcode) {
  40. case spv::Op::OpImageSampleImplicitLod:
  41. case spv::Op::OpImageSampleExplicitLod:
  42. case spv::Op::OpImageSampleProjImplicitLod:
  43. case spv::Op::OpImageSampleProjExplicitLod:
  44. case spv::Op::OpImageFetch:
  45. case spv::Op::OpImageRead:
  46. case spv::Op::OpImageSparseSampleImplicitLod:
  47. case spv::Op::OpImageSparseSampleExplicitLod:
  48. case spv::Op::OpImageSparseSampleProjImplicitLod:
  49. case spv::Op::OpImageSparseSampleProjExplicitLod:
  50. case spv::Op::OpImageSparseFetch:
  51. case spv::Op::OpImageSparseRead:
  52. return inst->NumOperands() > 4 ? 2 : -1;
  53. case spv::Op::OpImageSampleDrefImplicitLod:
  54. case spv::Op::OpImageSampleDrefExplicitLod:
  55. case spv::Op::OpImageSampleProjDrefImplicitLod:
  56. case spv::Op::OpImageSampleProjDrefExplicitLod:
  57. case spv::Op::OpImageGather:
  58. case spv::Op::OpImageDrefGather:
  59. case spv::Op::OpImageSparseSampleDrefImplicitLod:
  60. case spv::Op::OpImageSparseSampleDrefExplicitLod:
  61. case spv::Op::OpImageSparseSampleProjDrefImplicitLod:
  62. case spv::Op::OpImageSparseSampleProjDrefExplicitLod:
  63. case spv::Op::OpImageSparseGather:
  64. case spv::Op::OpImageSparseDrefGather:
  65. return inst->NumOperands() > 5 ? 3 : -1;
  66. case spv::Op::OpImageWrite:
  67. return inst->NumOperands() > 3 ? 3 : -1;
  68. default:
  69. return -1;
  70. }
  71. }
  72. // Returns the element width of |type|.
  73. uint32_t ElementWidth(const analysis::Type* type) {
  74. if (const analysis::Vector* vec_type = type->AsVector()) {
  75. return ElementWidth(vec_type->element_type());
  76. } else if (const analysis::Float* float_type = type->AsFloat()) {
  77. return float_type->width();
  78. } else {
  79. assert(type->AsInteger());
  80. return type->AsInteger()->width();
  81. }
  82. }
  83. // Returns true if |type| is Float or a vector of Float.
  84. bool HasFloatingPoint(const analysis::Type* type) {
  85. if (type->AsFloat()) {
  86. return true;
  87. } else if (const analysis::Vector* vec_type = type->AsVector()) {
  88. return vec_type->element_type()->AsFloat() != nullptr;
  89. }
  90. return false;
  91. }
  // Returns true when |val| is zero or a normal number, and false when it is
  // NaN, infinite, or subnormal.
  template <typename T>
  bool IsValidResult(T val) {
    const int category = std::fpclassify(val);
    return category != FP_NAN && category != FP_INFINITE &&
           category != FP_SUBNORMAL;
  }
  105. const analysis::Constant* ConstInput(
  106. const std::vector<const analysis::Constant*>& constants) {
  107. return constants[0] ? constants[0] : constants[1];
  108. }
  109. Instruction* NonConstInput(IRContext* context, const analysis::Constant* c,
  110. Instruction* inst) {
  111. uint32_t in_op = c ? 1u : 0u;
  112. return context->get_def_use_mgr()->GetDef(
  113. inst->GetSingleWordInOperand(in_op));
  114. }
  // Splits the 64-bit |val| into two 32-bit words, low-order word first.
  std::vector<uint32_t> ExtractInts(uint64_t val) {
    const uint32_t low_word = static_cast<uint32_t>(val);
    const uint32_t high_word = static_cast<uint32_t>(val >> 32);
    return {low_word, high_word};
  }
  121. std::vector<uint32_t> GetWordsFromScalarIntConstant(
  122. const analysis::IntConstant* c) {
  123. assert(c != nullptr);
  124. uint32_t width = c->type()->AsInteger()->width();
  125. assert(width == 8 || width == 16 || width == 32 || width == 64);
  126. if (width == 64) {
  127. uint64_t uval = static_cast<uint64_t>(c->GetU64());
  128. return ExtractInts(uval);
  129. }
  130. // Section 2.2.1 of the SPIR-V spec guarantees that all integer types
  131. // smaller than 32-bits are automatically zero or sign extended to 32-bits.
  132. return {c->GetU32BitValue()};
  133. }
  134. std::vector<uint32_t> GetWordsFromScalarFloatConstant(
  135. const analysis::FloatConstant* c) {
  136. assert(c != nullptr);
  137. uint32_t width = c->type()->AsFloat()->width();
  138. assert(width == 16 || width == 32 || width == 64);
  139. if (width == 64) {
  140. utils::FloatProxy<double> result(c->GetDouble());
  141. return result.GetWords();
  142. }
  143. // Section 2.2.1 of the SPIR-V spec guarantees that all floating-point types
  144. // smaller than 32-bits are automatically zero extended to 32-bits.
  145. return {c->GetU32BitValue()};
  146. }
  147. std::vector<uint32_t> GetWordsFromNumericScalarOrVectorConstant(
  148. analysis::ConstantManager* const_mgr, const analysis::Constant* c) {
  149. if (const auto* float_constant = c->AsFloatConstant()) {
  150. return GetWordsFromScalarFloatConstant(float_constant);
  151. } else if (const auto* int_constant = c->AsIntConstant()) {
  152. return GetWordsFromScalarIntConstant(int_constant);
  153. } else if (const auto* vec_constant = c->AsVectorConstant()) {
  154. std::vector<uint32_t> words;
  155. for (const auto* comp : vec_constant->GetComponents()) {
  156. auto comp_in_words =
  157. GetWordsFromNumericScalarOrVectorConstant(const_mgr, comp);
  158. words.insert(words.end(), comp_in_words.begin(), comp_in_words.end());
  159. }
  160. return words;
  161. }
  162. return {};
  163. }
  164. const analysis::Constant* ConvertWordsToNumericScalarOrVectorConstant(
  165. analysis::ConstantManager* const_mgr, const std::vector<uint32_t>& words,
  166. const analysis::Type* type) {
  167. if (type->AsInteger() || type->AsFloat())
  168. return const_mgr->GetConstant(type, words);
  169. if (const auto* vec_type = type->AsVector())
  170. return const_mgr->GetNumericVectorConstantWithWords(vec_type, words);
  171. return nullptr;
  172. }
  173. // Returns the negation of |c|. |c| must be a 32 or 64 bit floating point
  174. // constant.
  175. uint32_t NegateFloatingPointConstant(analysis::ConstantManager* const_mgr,
  176. const analysis::Constant* c) {
  177. assert(c);
  178. assert(c->type()->AsFloat());
  179. uint32_t width = c->type()->AsFloat()->width();
  180. assert(width == 32 || width == 64);
  181. std::vector<uint32_t> words;
  182. if (width == 64) {
  183. utils::FloatProxy<double> result(c->GetDouble() * -1.0);
  184. words = result.GetWords();
  185. } else {
  186. utils::FloatProxy<float> result(c->GetFloat() * -1.0f);
  187. words = result.GetWords();
  188. }
  189. const analysis::Constant* negated_const =
  190. const_mgr->GetConstant(c->type(), std::move(words));
  191. return const_mgr->GetDefiningInstruction(negated_const)->result_id();
  192. }
  193. // Negates the integer constant |c|. Returns the id of the defining instruction.
  194. uint32_t NegateIntegerConstant(analysis::ConstantManager* const_mgr,
  195. const analysis::Constant* c) {
  196. assert(c);
  197. assert(c->type()->AsInteger());
  198. uint32_t width = c->type()->AsInteger()->width();
  199. assert(width == 32 || width == 64);
  200. std::vector<uint32_t> words;
  201. if (width == 64) {
  202. uint64_t uval = static_cast<uint64_t>(0 - c->GetU64());
  203. words = ExtractInts(uval);
  204. } else {
  205. words.push_back(static_cast<uint32_t>(0 - c->GetU32()));
  206. }
  207. const analysis::Constant* negated_const =
  208. const_mgr->GetConstant(c->type(), std::move(words));
  209. return const_mgr->GetDefiningInstruction(negated_const)->result_id();
  210. }
  211. // Negates the vector constant |c|. Returns the id of the defining instruction.
  212. uint32_t NegateVectorConstant(analysis::ConstantManager* const_mgr,
  213. const analysis::Constant* c) {
  214. assert(const_mgr && c);
  215. assert(c->type()->AsVector());
  216. if (c->AsNullConstant()) {
  217. // 0.0 vs -0.0 shouldn't matter.
  218. return const_mgr->GetDefiningInstruction(c)->result_id();
  219. } else {
  220. const analysis::Type* component_type =
  221. c->AsVectorConstant()->component_type();
  222. std::vector<uint32_t> words;
  223. for (auto& comp : c->AsVectorConstant()->GetComponents()) {
  224. if (component_type->AsFloat()) {
  225. words.push_back(NegateFloatingPointConstant(const_mgr, comp));
  226. } else {
  227. assert(component_type->AsInteger());
  228. words.push_back(NegateIntegerConstant(const_mgr, comp));
  229. }
  230. }
  231. const analysis::Constant* negated_const =
  232. const_mgr->GetConstant(c->type(), std::move(words));
  233. return const_mgr->GetDefiningInstruction(negated_const)->result_id();
  234. }
  235. }
  236. // Negates |c|. Returns the id of the defining instruction.
  237. uint32_t NegateConstant(analysis::ConstantManager* const_mgr,
  238. const analysis::Constant* c) {
  239. if (c->type()->AsVector()) {
  240. return NegateVectorConstant(const_mgr, c);
  241. } else if (c->type()->AsFloat()) {
  242. return NegateFloatingPointConstant(const_mgr, c);
  243. } else {
  244. assert(c->type()->AsInteger());
  245. return NegateIntegerConstant(const_mgr, c);
  246. }
  247. }
  248. // Takes the reciprocal of |c|. |c|'s type must be Float or a vector of Float.
  249. // Returns 0 if the reciprocal is NaN, infinite or subnormal.
  250. uint32_t Reciprocal(analysis::ConstantManager* const_mgr,
  251. const analysis::Constant* c) {
  252. assert(const_mgr && c);
  253. assert(c->type()->AsFloat());
  254. uint32_t width = c->type()->AsFloat()->width();
  255. assert(width == 32 || width == 64);
  256. std::vector<uint32_t> words;
  257. if (c->IsZero()) {
  258. return 0;
  259. }
  260. if (width == 64) {
  261. spvtools::utils::FloatProxy<double> result(1.0 / c->GetDouble());
  262. if (!IsValidResult(result.getAsFloat())) return 0;
  263. words = result.GetWords();
  264. } else {
  265. spvtools::utils::FloatProxy<float> result(1.0f / c->GetFloat());
  266. if (!IsValidResult(result.getAsFloat())) return 0;
  267. words = result.GetWords();
  268. }
  269. const analysis::Constant* negated_const =
  270. const_mgr->GetConstant(c->type(), std::move(words));
  271. return const_mgr->GetDefiningInstruction(negated_const)->result_id();
  272. }
  273. // Replaces fdiv where second operand is constant with fmul.
  274. FoldingRule ReciprocalFDiv() {
  275. return [](IRContext* context, Instruction* inst,
  276. const std::vector<const analysis::Constant*>& constants) {
  277. assert(inst->opcode() == spv::Op::OpFDiv);
  278. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  279. const analysis::Type* type =
  280. context->get_type_mgr()->GetType(inst->type_id());
  281. if (!inst->IsFloatingPointFoldingAllowed()) return false;
  282. uint32_t width = ElementWidth(type);
  283. if (width != 32 && width != 64) return false;
  284. if (constants[1] != nullptr) {
  285. uint32_t id = 0;
  286. if (const analysis::VectorConstant* vector_const =
  287. constants[1]->AsVectorConstant()) {
  288. std::vector<uint32_t> neg_ids;
  289. for (auto& comp : vector_const->GetComponents()) {
  290. id = Reciprocal(const_mgr, comp);
  291. if (id == 0) return false;
  292. neg_ids.push_back(id);
  293. }
  294. const analysis::Constant* negated_const =
  295. const_mgr->GetConstant(constants[1]->type(), std::move(neg_ids));
  296. id = const_mgr->GetDefiningInstruction(negated_const)->result_id();
  297. } else if (constants[1]->AsFloatConstant()) {
  298. id = Reciprocal(const_mgr, constants[1]);
  299. if (id == 0) return false;
  300. } else {
  301. // Don't fold a null constant.
  302. return false;
  303. }
  304. inst->SetOpcode(spv::Op::OpFMul);
  305. inst->SetInOperands(
  306. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0u)}},
  307. {SPV_OPERAND_TYPE_ID, {id}}});
  308. return true;
  309. }
  310. return false;
  311. };
  312. }
  313. // Elides consecutive negate instructions.
  314. FoldingRule MergeNegateArithmetic() {
  315. return [](IRContext* context, Instruction* inst,
  316. const std::vector<const analysis::Constant*>& constants) {
  317. assert(inst->opcode() == spv::Op::OpFNegate ||
  318. inst->opcode() == spv::Op::OpSNegate);
  319. (void)constants;
  320. const analysis::Type* type =
  321. context->get_type_mgr()->GetType(inst->type_id());
  322. if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
  323. return false;
  324. Instruction* op_inst =
  325. context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
  326. if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
  327. return false;
  328. if (op_inst->opcode() == inst->opcode()) {
  329. // Elide negates.
  330. inst->SetOpcode(spv::Op::OpCopyObject);
  331. inst->SetInOperands(
  332. {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0u)}}});
  333. return true;
  334. }
  335. return false;
  336. };
  337. }
  338. // Merges negate into a mul or div operation if that operation contains a
  339. // constant operand.
  340. // Cases:
  341. // -(x * 2) = x * -2
  342. // -(2 * x) = x * -2
  343. // -(x / 2) = x / -2
  344. // -(2 / x) = -2 / x
  345. FoldingRule MergeNegateMulDivArithmetic() {
  346. return [](IRContext* context, Instruction* inst,
  347. const std::vector<const analysis::Constant*>& constants) {
  348. assert(inst->opcode() == spv::Op::OpFNegate ||
  349. inst->opcode() == spv::Op::OpSNegate);
  350. (void)constants;
  351. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  352. const analysis::Type* type =
  353. context->get_type_mgr()->GetType(inst->type_id());
  354. if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
  355. return false;
  356. Instruction* op_inst =
  357. context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
  358. if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
  359. return false;
  360. uint32_t width = ElementWidth(type);
  361. if (width != 32 && width != 64) return false;
  362. spv::Op opcode = op_inst->opcode();
  363. if (opcode == spv::Op::OpFMul || opcode == spv::Op::OpFDiv ||
  364. opcode == spv::Op::OpIMul || opcode == spv::Op::OpSDiv ||
  365. opcode == spv::Op::OpUDiv) {
  366. std::vector<const analysis::Constant*> op_constants =
  367. const_mgr->GetOperandConstants(op_inst);
  368. // Merge negate into mul or div if one operand is constant.
  369. if (op_constants[0] || op_constants[1]) {
  370. bool zero_is_variable = op_constants[0] == nullptr;
  371. const analysis::Constant* c = ConstInput(op_constants);
  372. uint32_t neg_id = NegateConstant(const_mgr, c);
  373. uint32_t non_const_id = zero_is_variable
  374. ? op_inst->GetSingleWordInOperand(0u)
  375. : op_inst->GetSingleWordInOperand(1u);
  376. // Change this instruction to a mul/div.
  377. inst->SetOpcode(op_inst->opcode());
  378. if (opcode == spv::Op::OpFDiv || opcode == spv::Op::OpUDiv ||
  379. opcode == spv::Op::OpSDiv) {
  380. uint32_t op0 = zero_is_variable ? non_const_id : neg_id;
  381. uint32_t op1 = zero_is_variable ? neg_id : non_const_id;
  382. inst->SetInOperands(
  383. {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
  384. } else {
  385. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
  386. {SPV_OPERAND_TYPE_ID, {neg_id}}});
  387. }
  388. return true;
  389. }
  390. }
  391. return false;
  392. };
  393. }
  394. // Merges negate into a add or sub operation if that operation contains a
  395. // constant operand.
  396. // Cases:
  397. // -(x + 2) = -2 - x
  398. // -(2 + x) = -2 - x
  399. // -(x - 2) = 2 - x
  400. // -(2 - x) = x - 2
  401. FoldingRule MergeNegateAddSubArithmetic() {
  402. return [](IRContext* context, Instruction* inst,
  403. const std::vector<const analysis::Constant*>& constants) {
  404. assert(inst->opcode() == spv::Op::OpFNegate ||
  405. inst->opcode() == spv::Op::OpSNegate);
  406. (void)constants;
  407. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  408. const analysis::Type* type =
  409. context->get_type_mgr()->GetType(inst->type_id());
  410. if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
  411. return false;
  412. Instruction* op_inst =
  413. context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
  414. if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
  415. return false;
  416. uint32_t width = ElementWidth(type);
  417. if (width != 32 && width != 64) return false;
  418. if (op_inst->opcode() == spv::Op::OpFAdd ||
  419. op_inst->opcode() == spv::Op::OpFSub ||
  420. op_inst->opcode() == spv::Op::OpIAdd ||
  421. op_inst->opcode() == spv::Op::OpISub) {
  422. std::vector<const analysis::Constant*> op_constants =
  423. const_mgr->GetOperandConstants(op_inst);
  424. if (op_constants[0] || op_constants[1]) {
  425. bool zero_is_variable = op_constants[0] == nullptr;
  426. bool is_add = (op_inst->opcode() == spv::Op::OpFAdd) ||
  427. (op_inst->opcode() == spv::Op::OpIAdd);
  428. bool swap_operands = !is_add || zero_is_variable;
  429. bool negate_const = is_add;
  430. const analysis::Constant* c = ConstInput(op_constants);
  431. uint32_t const_id = 0;
  432. if (negate_const) {
  433. const_id = NegateConstant(const_mgr, c);
  434. } else {
  435. const_id = zero_is_variable ? op_inst->GetSingleWordInOperand(1u)
  436. : op_inst->GetSingleWordInOperand(0u);
  437. }
  438. // Swap operands if necessary and make the instruction a subtraction.
  439. uint32_t op0 =
  440. zero_is_variable ? op_inst->GetSingleWordInOperand(0u) : const_id;
  441. uint32_t op1 =
  442. zero_is_variable ? const_id : op_inst->GetSingleWordInOperand(1u);
  443. if (swap_operands) std::swap(op0, op1);
  444. inst->SetOpcode(HasFloatingPoint(type) ? spv::Op::OpFSub
  445. : spv::Op::OpISub);
  446. inst->SetInOperands(
  447. {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
  448. return true;
  449. }
  450. }
  451. return false;
  452. };
  453. }
  454. // Returns true if |c| has a zero element.
  455. bool HasZero(const analysis::Constant* c) {
  456. if (c->AsNullConstant()) {
  457. return true;
  458. }
  459. if (const analysis::VectorConstant* vec_const = c->AsVectorConstant()) {
  460. for (auto& comp : vec_const->GetComponents())
  461. if (HasZero(comp)) return true;
  462. } else {
  463. assert(c->AsScalarConstant());
  464. return c->AsScalarConstant()->IsZero();
  465. }
  466. return false;
  467. }
  468. // Performs |input1| |opcode| |input2| and returns the merged constant result
  469. // id. Returns 0 if the result is not a valid value. The input types must be
  470. // Float.
  471. uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr,
  472. spv::Op opcode,
  473. const analysis::Constant* input1,
  474. const analysis::Constant* input2) {
  475. const analysis::Type* type = input1->type();
  476. assert(type->AsFloat());
  477. uint32_t width = type->AsFloat()->width();
  478. assert(width == 32 || width == 64);
  479. std::vector<uint32_t> words;
  480. #define FOLD_OP(op) \
  481. if (width == 64) { \
  482. utils::FloatProxy<double> val = \
  483. input1->GetDouble() op input2->GetDouble(); \
  484. double dval = val.getAsFloat(); \
  485. if (!IsValidResult(dval)) return 0; \
  486. words = val.GetWords(); \
  487. } else { \
  488. utils::FloatProxy<float> val = input1->GetFloat() op input2->GetFloat(); \
  489. float fval = val.getAsFloat(); \
  490. if (!IsValidResult(fval)) return 0; \
  491. words = val.GetWords(); \
  492. } \
  493. static_assert(true, "require extra semicolon")
  494. switch (opcode) {
  495. case spv::Op::OpFMul:
  496. FOLD_OP(*);
  497. break;
  498. case spv::Op::OpFDiv:
  499. if (HasZero(input2)) return 0;
  500. FOLD_OP(/);
  501. break;
  502. case spv::Op::OpFAdd:
  503. FOLD_OP(+);
  504. break;
  505. case spv::Op::OpFSub:
  506. FOLD_OP(-);
  507. break;
  508. default:
  509. assert(false && "Unexpected operation");
  510. break;
  511. }
  512. #undef FOLD_OP
  513. const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
  514. return const_mgr->GetDefiningInstruction(merged_const)->result_id();
  515. }
  516. // Performs |input1| |opcode| |input2| and returns the merged constant result
  517. // id. Returns 0 if the result is not a valid value. The input types must be
  518. // Integers.
  519. uint32_t PerformIntegerOperation(analysis::ConstantManager* const_mgr,
  520. spv::Op opcode,
  521. const analysis::Constant* input1,
  522. const analysis::Constant* input2) {
  523. assert(input1->type()->AsInteger());
  524. const analysis::Integer* type = input1->type()->AsInteger();
  525. uint32_t width = type->AsInteger()->width();
  526. assert(width == 32 || width == 64);
  527. std::vector<uint32_t> words;
  528. // Regardless of the sign of the constant, folding is performed on an unsigned
  529. // interpretation of the constant data. This avoids signed integer overflow
  530. // while folding, and works because sign is irrelevant for the IAdd, ISub and
  531. // IMul instructions.
  532. #define FOLD_OP(op) \
  533. if (width == 64) { \
  534. uint64_t val = input1->GetU64() op input2->GetU64(); \
  535. words = ExtractInts(val); \
  536. } else { \
  537. uint32_t val = input1->GetU32() op input2->GetU32(); \
  538. words.push_back(val); \
  539. } \
  540. static_assert(true, "require extra semicolon")
  541. switch (opcode) {
  542. case spv::Op::OpIMul:
  543. FOLD_OP(*);
  544. break;
  545. case spv::Op::OpSDiv:
  546. case spv::Op::OpUDiv:
  547. assert(false && "Should not merge integer division");
  548. break;
  549. case spv::Op::OpIAdd:
  550. FOLD_OP(+);
  551. break;
  552. case spv::Op::OpISub:
  553. FOLD_OP(-);
  554. break;
  555. default:
  556. assert(false && "Unexpected operation");
  557. break;
  558. }
  559. #undef FOLD_OP
  560. const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
  561. return const_mgr->GetDefiningInstruction(merged_const)->result_id();
  562. }
  563. // Performs |input1| |opcode| |input2| and returns the merged constant result
  564. // id. Returns 0 if the result is not a valid value. The input types must be
  565. // Integers, Floats or Vectors of such.
  566. uint32_t PerformOperation(analysis::ConstantManager* const_mgr, spv::Op opcode,
  567. const analysis::Constant* input1,
  568. const analysis::Constant* input2) {
  569. assert(input1 && input2);
  570. const analysis::Type* type = input1->type();
  571. std::vector<uint32_t> words;
  572. if (const analysis::Vector* vector_type = type->AsVector()) {
  573. const analysis::Type* ele_type = vector_type->element_type();
  574. for (uint32_t i = 0; i != vector_type->element_count(); ++i) {
  575. uint32_t id = 0;
  576. const analysis::Constant* input1_comp = nullptr;
  577. if (const analysis::VectorConstant* input1_vector =
  578. input1->AsVectorConstant()) {
  579. input1_comp = input1_vector->GetComponents()[i];
  580. } else {
  581. assert(input1->AsNullConstant());
  582. input1_comp = const_mgr->GetConstant(ele_type, {});
  583. }
  584. const analysis::Constant* input2_comp = nullptr;
  585. if (const analysis::VectorConstant* input2_vector =
  586. input2->AsVectorConstant()) {
  587. input2_comp = input2_vector->GetComponents()[i];
  588. } else {
  589. assert(input2->AsNullConstant());
  590. input2_comp = const_mgr->GetConstant(ele_type, {});
  591. }
  592. if (ele_type->AsFloat()) {
  593. id = PerformFloatingPointOperation(const_mgr, opcode, input1_comp,
  594. input2_comp);
  595. } else {
  596. assert(ele_type->AsInteger());
  597. id = PerformIntegerOperation(const_mgr, opcode, input1_comp,
  598. input2_comp);
  599. }
  600. if (id == 0) return 0;
  601. words.push_back(id);
  602. }
  603. const analysis::Constant* merged_const =
  604. const_mgr->GetConstant(type, words);
  605. return const_mgr->GetDefiningInstruction(merged_const)->result_id();
  606. } else if (type->AsFloat()) {
  607. return PerformFloatingPointOperation(const_mgr, opcode, input1, input2);
  608. } else {
  609. assert(type->AsInteger());
  610. return PerformIntegerOperation(const_mgr, opcode, input1, input2);
  611. }
  612. }
  613. // Merges consecutive multiplies where each contains one constant operand.
  614. // Cases:
  615. // 2 * (x * 2) = x * 4
  616. // 2 * (2 * x) = x * 4
  617. // (x * 2) * 2 = x * 4
  618. // (2 * x) * 2 = x * 4
  619. FoldingRule MergeMulMulArithmetic() {
  620. return [](IRContext* context, Instruction* inst,
  621. const std::vector<const analysis::Constant*>& constants) {
  622. assert(inst->opcode() == spv::Op::OpFMul ||
  623. inst->opcode() == spv::Op::OpIMul);
  624. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  625. const analysis::Type* type =
  626. context->get_type_mgr()->GetType(inst->type_id());
  627. if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
  628. return false;
  629. uint32_t width = ElementWidth(type);
  630. if (width != 32 && width != 64) return false;
  631. // Determine the constant input and the variable input in |inst|.
  632. const analysis::Constant* const_input1 = ConstInput(constants);
  633. if (!const_input1) return false;
  634. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  635. if (HasFloatingPoint(type) && !other_inst->IsFloatingPointFoldingAllowed())
  636. return false;
  637. if (other_inst->opcode() == inst->opcode()) {
  638. std::vector<const analysis::Constant*> other_constants =
  639. const_mgr->GetOperandConstants(other_inst);
  640. const analysis::Constant* const_input2 = ConstInput(other_constants);
  641. if (!const_input2) return false;
  642. bool other_first_is_variable = other_constants[0] == nullptr;
  643. uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
  644. const_input1, const_input2);
  645. if (merged_id == 0) return false;
  646. uint32_t non_const_id = other_first_is_variable
  647. ? other_inst->GetSingleWordInOperand(0u)
  648. : other_inst->GetSingleWordInOperand(1u);
  649. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
  650. {SPV_OPERAND_TYPE_ID, {merged_id}}});
  651. return true;
  652. }
  653. return false;
  654. };
  655. }
  656. // Merges divides into subsequent multiplies if each instruction contains one
  657. // constant operand. Does not support integer operations.
  658. // Cases:
  659. // 2 * (x / 2) = x * 1
  660. // 2 * (2 / x) = 4 / x
  661. // (x / 2) * 2 = x * 1
  662. // (2 / x) * 2 = 4 / x
  663. // (y / x) * x = y
  664. // x * (y / x) = y
  665. FoldingRule MergeMulDivArithmetic() {
  666. return [](IRContext* context, Instruction* inst,
  667. const std::vector<const analysis::Constant*>& constants) {
  668. assert(inst->opcode() == spv::Op::OpFMul);
  669. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  670. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  671. const analysis::Type* type =
  672. context->get_type_mgr()->GetType(inst->type_id());
  673. if (!inst->IsFloatingPointFoldingAllowed()) return false;
  674. uint32_t width = ElementWidth(type);
  675. if (width != 32 && width != 64) return false;
  676. for (uint32_t i = 0; i < 2; i++) {
  677. uint32_t op_id = inst->GetSingleWordInOperand(i);
  678. Instruction* op_inst = def_use_mgr->GetDef(op_id);
  679. if (op_inst->opcode() == spv::Op::OpFDiv) {
  680. if (op_inst->GetSingleWordInOperand(1) ==
  681. inst->GetSingleWordInOperand(1 - i)) {
  682. inst->SetOpcode(spv::Op::OpCopyObject);
  683. inst->SetInOperands(
  684. {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0)}}});
  685. return true;
  686. }
  687. }
  688. }
  689. const analysis::Constant* const_input1 = ConstInput(constants);
  690. if (!const_input1) return false;
  691. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  692. if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
  693. if (other_inst->opcode() == spv::Op::OpFDiv) {
  694. std::vector<const analysis::Constant*> other_constants =
  695. const_mgr->GetOperandConstants(other_inst);
  696. const analysis::Constant* const_input2 = ConstInput(other_constants);
  697. if (!const_input2 || HasZero(const_input2)) return false;
  698. bool other_first_is_variable = other_constants[0] == nullptr;
  699. // If the variable value is the second operand of the divide, multiply
  700. // the constants together. Otherwise divide the constants.
  701. uint32_t merged_id = PerformOperation(
  702. const_mgr,
  703. other_first_is_variable ? other_inst->opcode() : inst->opcode(),
  704. const_input1, const_input2);
  705. if (merged_id == 0) return false;
  706. uint32_t non_const_id = other_first_is_variable
  707. ? other_inst->GetSingleWordInOperand(0u)
  708. : other_inst->GetSingleWordInOperand(1u);
  709. // If the variable value is on the second operand of the div, then this
  710. // operation is a div. Otherwise it should be a multiply.
  711. inst->SetOpcode(other_first_is_variable ? inst->opcode()
  712. : other_inst->opcode());
  713. if (other_first_is_variable) {
  714. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
  715. {SPV_OPERAND_TYPE_ID, {merged_id}}});
  716. } else {
  717. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {merged_id}},
  718. {SPV_OPERAND_TYPE_ID, {non_const_id}}});
  719. }
  720. return true;
  721. }
  722. return false;
  723. };
  724. }
  725. // Merges multiply of constant and negation.
  726. // Cases:
  727. // (-x) * 2 = x * -2
  728. // 2 * (-x) = x * -2
  729. FoldingRule MergeMulNegateArithmetic() {
  730. return [](IRContext* context, Instruction* inst,
  731. const std::vector<const analysis::Constant*>& constants) {
  732. assert(inst->opcode() == spv::Op::OpFMul ||
  733. inst->opcode() == spv::Op::OpIMul);
  734. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  735. const analysis::Type* type =
  736. context->get_type_mgr()->GetType(inst->type_id());
  737. bool uses_float = HasFloatingPoint(type);
  738. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  739. uint32_t width = ElementWidth(type);
  740. if (width != 32 && width != 64) return false;
  741. const analysis::Constant* const_input1 = ConstInput(constants);
  742. if (!const_input1) return false;
  743. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  744. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  745. return false;
  746. if (other_inst->opcode() == spv::Op::OpFNegate ||
  747. other_inst->opcode() == spv::Op::OpSNegate) {
  748. uint32_t neg_id = NegateConstant(const_mgr, const_input1);
  749. inst->SetInOperands(
  750. {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
  751. {SPV_OPERAND_TYPE_ID, {neg_id}}});
  752. return true;
  753. }
  754. return false;
  755. };
  756. }
  757. // Merges consecutive divides if each instruction contains one constant operand.
  758. // Does not support integer division.
  759. // Cases:
  760. // 2 / (x / 2) = 4 / x
  761. // 4 / (2 / x) = 2 * x
  762. // (4 / x) / 2 = 2 / x
  763. // (x / 2) / 2 = x / 4
  764. FoldingRule MergeDivDivArithmetic() {
  765. return [](IRContext* context, Instruction* inst,
  766. const std::vector<const analysis::Constant*>& constants) {
  767. assert(inst->opcode() == spv::Op::OpFDiv);
  768. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  769. const analysis::Type* type =
  770. context->get_type_mgr()->GetType(inst->type_id());
  771. if (!inst->IsFloatingPointFoldingAllowed()) return false;
  772. uint32_t width = ElementWidth(type);
  773. if (width != 32 && width != 64) return false;
  774. const analysis::Constant* const_input1 = ConstInput(constants);
  775. if (!const_input1 || HasZero(const_input1)) return false;
  776. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  777. if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
  778. bool first_is_variable = constants[0] == nullptr;
  779. if (other_inst->opcode() == inst->opcode()) {
  780. std::vector<const analysis::Constant*> other_constants =
  781. const_mgr->GetOperandConstants(other_inst);
  782. const analysis::Constant* const_input2 = ConstInput(other_constants);
  783. if (!const_input2 || HasZero(const_input2)) return false;
  784. bool other_first_is_variable = other_constants[0] == nullptr;
  785. spv::Op merge_op = inst->opcode();
  786. if (other_first_is_variable) {
  787. // Constants magnify.
  788. merge_op = spv::Op::OpFMul;
  789. }
  790. // This is an x / (*) case. Swap the inputs. Doesn't harm multiply
  791. // because it is commutative.
  792. if (first_is_variable) std::swap(const_input1, const_input2);
  793. uint32_t merged_id =
  794. PerformOperation(const_mgr, merge_op, const_input1, const_input2);
  795. if (merged_id == 0) return false;
  796. uint32_t non_const_id = other_first_is_variable
  797. ? other_inst->GetSingleWordInOperand(0u)
  798. : other_inst->GetSingleWordInOperand(1u);
  799. spv::Op op = inst->opcode();
  800. if (!first_is_variable && !other_first_is_variable) {
  801. // Effectively div of 1/x, so change to multiply.
  802. op = spv::Op::OpFMul;
  803. }
  804. uint32_t op1 = merged_id;
  805. uint32_t op2 = non_const_id;
  806. if (first_is_variable && other_first_is_variable) std::swap(op1, op2);
  807. inst->SetOpcode(op);
  808. inst->SetInOperands(
  809. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  810. return true;
  811. }
  812. return false;
  813. };
  814. }
  815. // Fold multiplies succeeded by divides where each instruction contains a
  816. // constant operand. Does not support integer divide.
  817. // Cases:
  818. // 4 / (x * 2) = 2 / x
  819. // 4 / (2 * x) = 2 / x
  820. // (x * 4) / 2 = x * 2
  821. // (4 * x) / 2 = x * 2
  822. // (x * y) / x = y
  823. // (y * x) / x = y
  824. FoldingRule MergeDivMulArithmetic() {
  825. return [](IRContext* context, Instruction* inst,
  826. const std::vector<const analysis::Constant*>& constants) {
  827. assert(inst->opcode() == spv::Op::OpFDiv);
  828. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  829. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  830. const analysis::Type* type =
  831. context->get_type_mgr()->GetType(inst->type_id());
  832. if (!inst->IsFloatingPointFoldingAllowed()) return false;
  833. uint32_t width = ElementWidth(type);
  834. if (width != 32 && width != 64) return false;
  835. uint32_t op_id = inst->GetSingleWordInOperand(0);
  836. Instruction* op_inst = def_use_mgr->GetDef(op_id);
  837. if (op_inst->opcode() == spv::Op::OpFMul) {
  838. for (uint32_t i = 0; i < 2; i++) {
  839. if (op_inst->GetSingleWordInOperand(i) ==
  840. inst->GetSingleWordInOperand(1)) {
  841. inst->SetOpcode(spv::Op::OpCopyObject);
  842. inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
  843. {op_inst->GetSingleWordInOperand(1 - i)}}});
  844. return true;
  845. }
  846. }
  847. }
  848. const analysis::Constant* const_input1 = ConstInput(constants);
  849. if (!const_input1 || HasZero(const_input1)) return false;
  850. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  851. if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
  852. bool first_is_variable = constants[0] == nullptr;
  853. if (other_inst->opcode() == spv::Op::OpFMul) {
  854. std::vector<const analysis::Constant*> other_constants =
  855. const_mgr->GetOperandConstants(other_inst);
  856. const analysis::Constant* const_input2 = ConstInput(other_constants);
  857. if (!const_input2) return false;
  858. bool other_first_is_variable = other_constants[0] == nullptr;
  859. // This is an x / (*) case. Swap the inputs.
  860. if (first_is_variable) std::swap(const_input1, const_input2);
  861. uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
  862. const_input1, const_input2);
  863. if (merged_id == 0) return false;
  864. uint32_t non_const_id = other_first_is_variable
  865. ? other_inst->GetSingleWordInOperand(0u)
  866. : other_inst->GetSingleWordInOperand(1u);
  867. uint32_t op1 = merged_id;
  868. uint32_t op2 = non_const_id;
  869. if (first_is_variable) std::swap(op1, op2);
  870. // Convert to multiply
  871. if (first_is_variable) inst->SetOpcode(other_inst->opcode());
  872. inst->SetInOperands(
  873. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  874. return true;
  875. }
  876. return false;
  877. };
  878. }
  879. // Fold divides of a constant and a negation.
  880. // Cases:
  881. // (-x) / 2 = x / -2
  882. // 2 / (-x) = -2 / x
  883. FoldingRule MergeDivNegateArithmetic() {
  884. return [](IRContext* context, Instruction* inst,
  885. const std::vector<const analysis::Constant*>& constants) {
  886. assert(inst->opcode() == spv::Op::OpFDiv);
  887. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  888. if (!inst->IsFloatingPointFoldingAllowed()) return false;
  889. const analysis::Constant* const_input1 = ConstInput(constants);
  890. if (!const_input1) return false;
  891. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  892. if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
  893. bool first_is_variable = constants[0] == nullptr;
  894. if (other_inst->opcode() == spv::Op::OpFNegate) {
  895. uint32_t neg_id = NegateConstant(const_mgr, const_input1);
  896. if (first_is_variable) {
  897. inst->SetInOperands(
  898. {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
  899. {SPV_OPERAND_TYPE_ID, {neg_id}}});
  900. } else {
  901. inst->SetInOperands(
  902. {{SPV_OPERAND_TYPE_ID, {neg_id}},
  903. {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
  904. }
  905. return true;
  906. }
  907. return false;
  908. };
  909. }
  910. // Folds addition of a constant and a negation.
  911. // Cases:
  912. // (-x) + 2 = 2 - x
  913. // 2 + (-x) = 2 - x
  914. FoldingRule MergeAddNegateArithmetic() {
  915. return [](IRContext* context, Instruction* inst,
  916. const std::vector<const analysis::Constant*>& constants) {
  917. assert(inst->opcode() == spv::Op::OpFAdd ||
  918. inst->opcode() == spv::Op::OpIAdd);
  919. const analysis::Type* type =
  920. context->get_type_mgr()->GetType(inst->type_id());
  921. bool uses_float = HasFloatingPoint(type);
  922. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  923. const analysis::Constant* const_input1 = ConstInput(constants);
  924. if (!const_input1) return false;
  925. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  926. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  927. return false;
  928. if (other_inst->opcode() == spv::Op::OpSNegate ||
  929. other_inst->opcode() == spv::Op::OpFNegate) {
  930. inst->SetOpcode(HasFloatingPoint(type) ? spv::Op::OpFSub
  931. : spv::Op::OpISub);
  932. uint32_t const_id = constants[0] ? inst->GetSingleWordInOperand(0u)
  933. : inst->GetSingleWordInOperand(1u);
  934. inst->SetInOperands(
  935. {{SPV_OPERAND_TYPE_ID, {const_id}},
  936. {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
  937. return true;
  938. }
  939. return false;
  940. };
  941. }
  942. // Folds subtraction of a constant and a negation.
  943. // Cases:
  944. // (-x) - 2 = -2 - x
  945. // 2 - (-x) = x + 2
  946. FoldingRule MergeSubNegateArithmetic() {
  947. return [](IRContext* context, Instruction* inst,
  948. const std::vector<const analysis::Constant*>& constants) {
  949. assert(inst->opcode() == spv::Op::OpFSub ||
  950. inst->opcode() == spv::Op::OpISub);
  951. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  952. const analysis::Type* type =
  953. context->get_type_mgr()->GetType(inst->type_id());
  954. bool uses_float = HasFloatingPoint(type);
  955. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  956. uint32_t width = ElementWidth(type);
  957. if (width != 32 && width != 64) return false;
  958. const analysis::Constant* const_input1 = ConstInput(constants);
  959. if (!const_input1) return false;
  960. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  961. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  962. return false;
  963. if (other_inst->opcode() == spv::Op::OpSNegate ||
  964. other_inst->opcode() == spv::Op::OpFNegate) {
  965. uint32_t op1 = 0;
  966. uint32_t op2 = 0;
  967. spv::Op opcode = inst->opcode();
  968. if (constants[0] != nullptr) {
  969. op1 = other_inst->GetSingleWordInOperand(0u);
  970. op2 = inst->GetSingleWordInOperand(0u);
  971. opcode = HasFloatingPoint(type) ? spv::Op::OpFAdd : spv::Op::OpIAdd;
  972. } else {
  973. op1 = NegateConstant(const_mgr, const_input1);
  974. op2 = other_inst->GetSingleWordInOperand(0u);
  975. }
  976. inst->SetOpcode(opcode);
  977. inst->SetInOperands(
  978. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  979. return true;
  980. }
  981. return false;
  982. };
  983. }
  984. // Folds addition of an addition where each operation has a constant operand.
  985. // Cases:
  986. // (x + 2) + 2 = x + 4
  987. // (2 + x) + 2 = x + 4
  988. // 2 + (x + 2) = x + 4
  989. // 2 + (2 + x) = x + 4
  990. FoldingRule MergeAddAddArithmetic() {
  991. return [](IRContext* context, Instruction* inst,
  992. const std::vector<const analysis::Constant*>& constants) {
  993. assert(inst->opcode() == spv::Op::OpFAdd ||
  994. inst->opcode() == spv::Op::OpIAdd);
  995. const analysis::Type* type =
  996. context->get_type_mgr()->GetType(inst->type_id());
  997. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  998. bool uses_float = HasFloatingPoint(type);
  999. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  1000. uint32_t width = ElementWidth(type);
  1001. if (width != 32 && width != 64) return false;
  1002. const analysis::Constant* const_input1 = ConstInput(constants);
  1003. if (!const_input1) return false;
  1004. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  1005. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  1006. return false;
  1007. if (other_inst->opcode() == spv::Op::OpFAdd ||
  1008. other_inst->opcode() == spv::Op::OpIAdd) {
  1009. std::vector<const analysis::Constant*> other_constants =
  1010. const_mgr->GetOperandConstants(other_inst);
  1011. const analysis::Constant* const_input2 = ConstInput(other_constants);
  1012. if (!const_input2) return false;
  1013. Instruction* non_const_input =
  1014. NonConstInput(context, other_constants[0], other_inst);
  1015. uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
  1016. const_input1, const_input2);
  1017. if (merged_id == 0) return false;
  1018. inst->SetInOperands(
  1019. {{SPV_OPERAND_TYPE_ID, {non_const_input->result_id()}},
  1020. {SPV_OPERAND_TYPE_ID, {merged_id}}});
  1021. return true;
  1022. }
  1023. return false;
  1024. };
  1025. }
  1026. // Folds addition of a subtraction where each operation has a constant operand.
  1027. // Cases:
  1028. // (x - 2) + 2 = x + 0
  1029. // (2 - x) + 2 = 4 - x
  1030. // 2 + (x - 2) = x + 0
  1031. // 2 + (2 - x) = 4 - x
  1032. FoldingRule MergeAddSubArithmetic() {
  1033. return [](IRContext* context, Instruction* inst,
  1034. const std::vector<const analysis::Constant*>& constants) {
  1035. assert(inst->opcode() == spv::Op::OpFAdd ||
  1036. inst->opcode() == spv::Op::OpIAdd);
  1037. const analysis::Type* type =
  1038. context->get_type_mgr()->GetType(inst->type_id());
  1039. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  1040. bool uses_float = HasFloatingPoint(type);
  1041. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  1042. uint32_t width = ElementWidth(type);
  1043. if (width != 32 && width != 64) return false;
  1044. const analysis::Constant* const_input1 = ConstInput(constants);
  1045. if (!const_input1) return false;
  1046. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  1047. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  1048. return false;
  1049. if (other_inst->opcode() == spv::Op::OpFSub ||
  1050. other_inst->opcode() == spv::Op::OpISub) {
  1051. std::vector<const analysis::Constant*> other_constants =
  1052. const_mgr->GetOperandConstants(other_inst);
  1053. const analysis::Constant* const_input2 = ConstInput(other_constants);
  1054. if (!const_input2) return false;
  1055. bool first_is_variable = other_constants[0] == nullptr;
  1056. spv::Op op = inst->opcode();
  1057. uint32_t op1 = 0;
  1058. uint32_t op2 = 0;
  1059. if (first_is_variable) {
  1060. // Subtract constants. Non-constant operand is first.
  1061. op1 = other_inst->GetSingleWordInOperand(0u);
  1062. op2 = PerformOperation(const_mgr, other_inst->opcode(), const_input1,
  1063. const_input2);
  1064. } else {
  1065. // Add constants. Constant operand is first. Change the opcode.
  1066. op1 = PerformOperation(const_mgr, inst->opcode(), const_input1,
  1067. const_input2);
  1068. op2 = other_inst->GetSingleWordInOperand(1u);
  1069. op = other_inst->opcode();
  1070. }
  1071. if (op1 == 0 || op2 == 0) return false;
  1072. inst->SetOpcode(op);
  1073. inst->SetInOperands(
  1074. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  1075. return true;
  1076. }
  1077. return false;
  1078. };
  1079. }
  1080. // Folds subtraction of an addition where each operand has a constant operand.
  1081. // Cases:
  1082. // (x + 2) - 2 = x + 0
  1083. // (2 + x) - 2 = x + 0
  1084. // 2 - (x + 2) = 0 - x
  1085. // 2 - (2 + x) = 0 - x
  1086. FoldingRule MergeSubAddArithmetic() {
  1087. return [](IRContext* context, Instruction* inst,
  1088. const std::vector<const analysis::Constant*>& constants) {
  1089. assert(inst->opcode() == spv::Op::OpFSub ||
  1090. inst->opcode() == spv::Op::OpISub);
  1091. const analysis::Type* type =
  1092. context->get_type_mgr()->GetType(inst->type_id());
  1093. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  1094. bool uses_float = HasFloatingPoint(type);
  1095. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  1096. uint32_t width = ElementWidth(type);
  1097. if (width != 32 && width != 64) return false;
  1098. const analysis::Constant* const_input1 = ConstInput(constants);
  1099. if (!const_input1) return false;
  1100. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  1101. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  1102. return false;
  1103. if (other_inst->opcode() == spv::Op::OpFAdd ||
  1104. other_inst->opcode() == spv::Op::OpIAdd) {
  1105. std::vector<const analysis::Constant*> other_constants =
  1106. const_mgr->GetOperandConstants(other_inst);
  1107. const analysis::Constant* const_input2 = ConstInput(other_constants);
  1108. if (!const_input2) return false;
  1109. Instruction* non_const_input =
  1110. NonConstInput(context, other_constants[0], other_inst);
  1111. // If the first operand of the sub is not a constant, swap the constants
  1112. // so the subtraction has the correct operands.
  1113. if (constants[0] == nullptr) std::swap(const_input1, const_input2);
  1114. // Subtract the constants.
  1115. uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
  1116. const_input1, const_input2);
  1117. spv::Op op = inst->opcode();
  1118. uint32_t op1 = 0;
  1119. uint32_t op2 = 0;
  1120. if (constants[0] == nullptr) {
  1121. // Non-constant operand is first. Change the opcode.
  1122. op1 = non_const_input->result_id();
  1123. op2 = merged_id;
  1124. op = other_inst->opcode();
  1125. } else {
  1126. // Constant operand is first.
  1127. op1 = merged_id;
  1128. op2 = non_const_input->result_id();
  1129. }
  1130. if (op1 == 0 || op2 == 0) return false;
  1131. inst->SetOpcode(op);
  1132. inst->SetInOperands(
  1133. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  1134. return true;
  1135. }
  1136. return false;
  1137. };
  1138. }
  1139. // Folds subtraction of a subtraction where each operand has a constant operand.
  1140. // Cases:
  1141. // (x - 2) - 2 = x - 4
  1142. // (2 - x) - 2 = 0 - x
  1143. // 2 - (x - 2) = 4 - x
  1144. // 2 - (2 - x) = x + 0
  1145. FoldingRule MergeSubSubArithmetic() {
  1146. return [](IRContext* context, Instruction* inst,
  1147. const std::vector<const analysis::Constant*>& constants) {
  1148. assert(inst->opcode() == spv::Op::OpFSub ||
  1149. inst->opcode() == spv::Op::OpISub);
  1150. const analysis::Type* type =
  1151. context->get_type_mgr()->GetType(inst->type_id());
  1152. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  1153. bool uses_float = HasFloatingPoint(type);
  1154. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  1155. uint32_t width = ElementWidth(type);
  1156. if (width != 32 && width != 64) return false;
  1157. const analysis::Constant* const_input1 = ConstInput(constants);
  1158. if (!const_input1) return false;
  1159. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  1160. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  1161. return false;
  1162. if (other_inst->opcode() == spv::Op::OpFSub ||
  1163. other_inst->opcode() == spv::Op::OpISub) {
  1164. std::vector<const analysis::Constant*> other_constants =
  1165. const_mgr->GetOperandConstants(other_inst);
  1166. const analysis::Constant* const_input2 = ConstInput(other_constants);
  1167. if (!const_input2) return false;
  1168. Instruction* non_const_input =
  1169. NonConstInput(context, other_constants[0], other_inst);
  1170. // Merge the constants.
  1171. uint32_t merged_id = 0;
  1172. spv::Op merge_op = inst->opcode();
  1173. if (other_constants[0] == nullptr) {
  1174. merge_op = uses_float ? spv::Op::OpFAdd : spv::Op::OpIAdd;
  1175. } else if (constants[0] == nullptr) {
  1176. std::swap(const_input1, const_input2);
  1177. }
  1178. merged_id =
  1179. PerformOperation(const_mgr, merge_op, const_input1, const_input2);
  1180. if (merged_id == 0) return false;
  1181. spv::Op op = inst->opcode();
  1182. if (constants[0] != nullptr && other_constants[0] != nullptr) {
  1183. // Change the operation.
  1184. op = uses_float ? spv::Op::OpFAdd : spv::Op::OpIAdd;
  1185. }
  1186. uint32_t op1 = 0;
  1187. uint32_t op2 = 0;
  1188. if ((constants[0] == nullptr) ^ (other_constants[0] == nullptr)) {
  1189. op1 = merged_id;
  1190. op2 = non_const_input->result_id();
  1191. } else {
  1192. op1 = non_const_input->result_id();
  1193. op2 = merged_id;
  1194. }
  1195. inst->SetOpcode(op);
  1196. inst->SetInOperands(
  1197. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  1198. return true;
  1199. }
  1200. return false;
  1201. };
  1202. }
  1203. // Helper function for MergeGenericAddSubArithmetic. If |addend| and
  1204. // subtrahend of |sub| is the same, merge to copy of minuend of |sub|.
  1205. bool MergeGenericAddendSub(uint32_t addend, uint32_t sub, Instruction* inst) {
  1206. IRContext* context = inst->context();
  1207. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1208. Instruction* sub_inst = def_use_mgr->GetDef(sub);
  1209. if (sub_inst->opcode() != spv::Op::OpFSub &&
  1210. sub_inst->opcode() != spv::Op::OpISub)
  1211. return false;
  1212. if (sub_inst->opcode() == spv::Op::OpFSub &&
  1213. !sub_inst->IsFloatingPointFoldingAllowed())
  1214. return false;
  1215. if (addend != sub_inst->GetSingleWordInOperand(1)) return false;
  1216. inst->SetOpcode(spv::Op::OpCopyObject);
  1217. inst->SetInOperands(
  1218. {{SPV_OPERAND_TYPE_ID, {sub_inst->GetSingleWordInOperand(0)}}});
  1219. context->UpdateDefUse(inst);
  1220. return true;
  1221. }
  1222. // Folds addition of a subtraction where the subtrahend is equal to the
  1223. // other addend. Return a copy of the minuend. Accepts generic (const and
  1224. // non-const) operands.
  1225. // Cases:
  1226. // (a - b) + b = a
  1227. // b + (a - b) = a
  1228. FoldingRule MergeGenericAddSubArithmetic() {
  1229. return [](IRContext* context, Instruction* inst,
  1230. const std::vector<const analysis::Constant*>&) {
  1231. assert(inst->opcode() == spv::Op::OpFAdd ||
  1232. inst->opcode() == spv::Op::OpIAdd);
  1233. const analysis::Type* type =
  1234. context->get_type_mgr()->GetType(inst->type_id());
  1235. bool uses_float = HasFloatingPoint(type);
  1236. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  1237. uint32_t width = ElementWidth(type);
  1238. if (width != 32 && width != 64) return false;
  1239. uint32_t add_op0 = inst->GetSingleWordInOperand(0);
  1240. uint32_t add_op1 = inst->GetSingleWordInOperand(1);
  1241. if (MergeGenericAddendSub(add_op0, add_op1, inst)) return true;
  1242. return MergeGenericAddendSub(add_op1, add_op0, inst);
  1243. };
  1244. }
  1245. // Helper function for FactorAddMuls. If |factor0_0| is the same as |factor1_0|,
  1246. // generate |factor0_0| * (|factor0_1| + |factor1_1|).
  1247. bool FactorAddMulsOpnds(uint32_t factor0_0, uint32_t factor0_1,
  1248. uint32_t factor1_0, uint32_t factor1_1,
  1249. Instruction* inst) {
  1250. IRContext* context = inst->context();
  1251. if (factor0_0 != factor1_0) return false;
  1252. InstructionBuilder ir_builder(
  1253. context, inst,
  1254. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  1255. Instruction* new_add_inst = ir_builder.AddBinaryOp(
  1256. inst->type_id(), inst->opcode(), factor0_1, factor1_1);
  1257. inst->SetOpcode(inst->opcode() == spv::Op::OpFAdd ? spv::Op::OpFMul
  1258. : spv::Op::OpIMul);
  1259. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {factor0_0}},
  1260. {SPV_OPERAND_TYPE_ID, {new_add_inst->result_id()}}});
  1261. context->UpdateDefUse(inst);
  1262. return true;
  1263. }
  1264. // Perform the following factoring identity, handling all operand order
  1265. // combinations: (a * b) + (a * c) = a * (b + c)
  1266. FoldingRule FactorAddMuls() {
  1267. return [](IRContext* context, Instruction* inst,
  1268. const std::vector<const analysis::Constant*>&) {
  1269. assert(inst->opcode() == spv::Op::OpFAdd ||
  1270. inst->opcode() == spv::Op::OpIAdd);
  1271. const analysis::Type* type =
  1272. context->get_type_mgr()->GetType(inst->type_id());
  1273. bool uses_float = HasFloatingPoint(type);
  1274. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  1275. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1276. uint32_t add_op0 = inst->GetSingleWordInOperand(0);
  1277. Instruction* add_op0_inst = def_use_mgr->GetDef(add_op0);
  1278. if (add_op0_inst->opcode() != spv::Op::OpFMul &&
  1279. add_op0_inst->opcode() != spv::Op::OpIMul)
  1280. return false;
  1281. uint32_t add_op1 = inst->GetSingleWordInOperand(1);
  1282. Instruction* add_op1_inst = def_use_mgr->GetDef(add_op1);
  1283. if (add_op1_inst->opcode() != spv::Op::OpFMul &&
  1284. add_op1_inst->opcode() != spv::Op::OpIMul)
  1285. return false;
  1286. // Only perform this optimization if both of the muls only have one use.
  1287. // Otherwise this is a deoptimization in size and performance.
  1288. if (def_use_mgr->NumUses(add_op0_inst) > 1) return false;
  1289. if (def_use_mgr->NumUses(add_op1_inst) > 1) return false;
  1290. if (add_op0_inst->opcode() == spv::Op::OpFMul &&
  1291. (!add_op0_inst->IsFloatingPointFoldingAllowed() ||
  1292. !add_op1_inst->IsFloatingPointFoldingAllowed()))
  1293. return false;
  1294. for (int i = 0; i < 2; i++) {
  1295. for (int j = 0; j < 2; j++) {
  1296. // Check if operand i in add_op0_inst matches operand j in add_op1_inst.
  1297. if (FactorAddMulsOpnds(add_op0_inst->GetSingleWordInOperand(i),
  1298. add_op0_inst->GetSingleWordInOperand(1 - i),
  1299. add_op1_inst->GetSingleWordInOperand(j),
  1300. add_op1_inst->GetSingleWordInOperand(1 - j),
  1301. inst))
  1302. return true;
  1303. }
  1304. }
  1305. return false;
  1306. };
  1307. }
  1308. // Replaces |inst| inplace with an FMA instruction |(x*y)+a|.
  1309. void ReplaceWithFma(Instruction* inst, uint32_t x, uint32_t y, uint32_t a) {
  1310. uint32_t ext =
  1311. inst->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  1312. if (ext == 0) {
  1313. inst->context()->AddExtInstImport("GLSL.std.450");
  1314. ext = inst->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  1315. assert(ext != 0 &&
  1316. "Could not add the GLSL.std.450 extended instruction set");
  1317. }
  1318. std::vector<Operand> operands;
  1319. operands.push_back({SPV_OPERAND_TYPE_ID, {ext}});
  1320. operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {GLSLstd450Fma}});
  1321. operands.push_back({SPV_OPERAND_TYPE_ID, {x}});
  1322. operands.push_back({SPV_OPERAND_TYPE_ID, {y}});
  1323. operands.push_back({SPV_OPERAND_TYPE_ID, {a}});
  1324. inst->SetOpcode(spv::Op::OpExtInst);
  1325. inst->SetInOperands(std::move(operands));
  1326. }
  1327. // Folds a multiple and add into an Fma.
  1328. //
  1329. // Cases:
  1330. // (x * y) + a = Fma x y a
  1331. // a + (x * y) = Fma x y a
  1332. bool MergeMulAddArithmetic(IRContext* context, Instruction* inst,
  1333. const std::vector<const analysis::Constant*>&) {
  1334. assert(inst->opcode() == spv::Op::OpFAdd);
  1335. if (!inst->IsFloatingPointFoldingAllowed()) {
  1336. return false;
  1337. }
  1338. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1339. for (int i = 0; i < 2; i++) {
  1340. uint32_t op_id = inst->GetSingleWordInOperand(i);
  1341. Instruction* op_inst = def_use_mgr->GetDef(op_id);
  1342. if (op_inst->opcode() != spv::Op::OpFMul) {
  1343. continue;
  1344. }
  1345. if (!op_inst->IsFloatingPointFoldingAllowed()) {
  1346. continue;
  1347. }
  1348. uint32_t x = op_inst->GetSingleWordInOperand(0);
  1349. uint32_t y = op_inst->GetSingleWordInOperand(1);
  1350. uint32_t a = inst->GetSingleWordInOperand((i + 1) % 2);
  1351. ReplaceWithFma(inst, x, y, a);
  1352. return true;
  1353. }
  1354. return false;
  1355. }
  1356. // Replaces |sub| inplace with an FMA instruction |(x*y)+a| where |a| first gets
  1357. // negated if |negate_addition| is true, otherwise |x| gets negated.
  1358. void ReplaceWithFmaAndNegate(Instruction* sub, uint32_t x, uint32_t y,
  1359. uint32_t a, bool negate_addition) {
  1360. uint32_t ext =
  1361. sub->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  1362. if (ext == 0) {
  1363. sub->context()->AddExtInstImport("GLSL.std.450");
  1364. ext = sub->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  1365. assert(ext != 0 &&
  1366. "Could not add the GLSL.std.450 extended instruction set");
  1367. }
  1368. InstructionBuilder ir_builder(
  1369. sub->context(), sub,
  1370. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  1371. Instruction* neg = ir_builder.AddUnaryOp(sub->type_id(), spv::Op::OpFNegate,
  1372. negate_addition ? a : x);
  1373. uint32_t neg_op = neg->result_id(); // -a : -x
  1374. std::vector<Operand> operands;
  1375. operands.push_back({SPV_OPERAND_TYPE_ID, {ext}});
  1376. operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {GLSLstd450Fma}});
  1377. operands.push_back({SPV_OPERAND_TYPE_ID, {negate_addition ? x : neg_op}});
  1378. operands.push_back({SPV_OPERAND_TYPE_ID, {y}});
  1379. operands.push_back({SPV_OPERAND_TYPE_ID, {negate_addition ? neg_op : a}});
  1380. sub->SetOpcode(spv::Op::OpExtInst);
  1381. sub->SetInOperands(std::move(operands));
  1382. }
  1383. // Folds a multiply and subtract into an Fma and negation.
  1384. //
  1385. // Cases:
  1386. // (x * y) - a = Fma x y -a
  1387. // a - (x * y) = Fma -x y a
  1388. bool MergeMulSubArithmetic(IRContext* context, Instruction* sub,
  1389. const std::vector<const analysis::Constant*>&) {
  1390. assert(sub->opcode() == spv::Op::OpFSub);
  1391. if (!sub->IsFloatingPointFoldingAllowed()) {
  1392. return false;
  1393. }
  1394. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1395. for (int i = 0; i < 2; i++) {
  1396. uint32_t op_id = sub->GetSingleWordInOperand(i);
  1397. Instruction* mul = def_use_mgr->GetDef(op_id);
  1398. if (mul->opcode() != spv::Op::OpFMul) {
  1399. continue;
  1400. }
  1401. if (!mul->IsFloatingPointFoldingAllowed()) {
  1402. continue;
  1403. }
  1404. uint32_t x = mul->GetSingleWordInOperand(0);
  1405. uint32_t y = mul->GetSingleWordInOperand(1);
  1406. uint32_t a = sub->GetSingleWordInOperand((i + 1) % 2);
  1407. ReplaceWithFmaAndNegate(sub, x, y, a, i == 0);
  1408. return true;
  1409. }
  1410. return false;
  1411. }
  1412. FoldingRule IntMultipleBy1() {
  1413. return [](IRContext*, Instruction* inst,
  1414. const std::vector<const analysis::Constant*>& constants) {
  1415. assert(inst->opcode() == spv::Op::OpIMul &&
  1416. "Wrong opcode. Should be OpIMul.");
  1417. for (uint32_t i = 0; i < 2; i++) {
  1418. if (constants[i] == nullptr) {
  1419. continue;
  1420. }
  1421. const analysis::IntConstant* int_constant = constants[i]->AsIntConstant();
  1422. if (int_constant) {
  1423. uint32_t width = ElementWidth(int_constant->type());
  1424. if (width != 32 && width != 64) return false;
  1425. bool is_one = (width == 32) ? int_constant->GetU32BitValue() == 1u
  1426. : int_constant->GetU64BitValue() == 1ull;
  1427. if (is_one) {
  1428. inst->SetOpcode(spv::Op::OpCopyObject);
  1429. inst->SetInOperands(
  1430. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1 - i)}}});
  1431. return true;
  1432. }
  1433. }
  1434. }
  1435. return false;
  1436. };
  1437. }
  1438. // Returns the number of elements that the |index|th in operand in |inst|
  1439. // contributes to the result of |inst|. |inst| must be an
  1440. // OpCompositeConstructInstruction.
  1441. uint32_t GetNumOfElementsContributedByOperand(IRContext* context,
  1442. const Instruction* inst,
  1443. uint32_t index) {
  1444. assert(inst->opcode() == spv::Op::OpCompositeConstruct);
  1445. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1446. analysis::TypeManager* type_mgr = context->get_type_mgr();
  1447. analysis::Vector* result_type =
  1448. type_mgr->GetType(inst->type_id())->AsVector();
  1449. if (result_type == nullptr) {
  1450. // If the result of the OpCompositeConstruct is not a vector then every
  1451. // operands corresponds to a single element in the result.
  1452. return 1;
  1453. }
  1454. // If the result type is a vector then the operands are either scalars or
  1455. // vectors. If it is a scalar, then it corresponds to a single element. If it
  1456. // is a vector, then each element in the vector will be an element in the
  1457. // result.
  1458. uint32_t id = inst->GetSingleWordInOperand(index);
  1459. Instruction* def = def_use_mgr->GetDef(id);
  1460. analysis::Vector* type = type_mgr->GetType(def->type_id())->AsVector();
  1461. if (type == nullptr) {
  1462. return 1;
  1463. }
  1464. return type->element_count();
  1465. }
  1466. // Returns the in-operands for an OpCompositeExtract instruction that are needed
  1467. // to extract the |result_index|th element in the result of |inst| without using
  1468. // the result of |inst|. Returns the empty vector if |result_index| is
  1469. // out-of-bounds. |inst| must be an |OpCompositeConstruct| instruction.
  1470. std::vector<Operand> GetExtractOperandsForElementOfCompositeConstruct(
  1471. IRContext* context, const Instruction* inst, uint32_t result_index) {
  1472. assert(inst->opcode() == spv::Op::OpCompositeConstruct);
  1473. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1474. analysis::TypeManager* type_mgr = context->get_type_mgr();
  1475. analysis::Type* result_type = type_mgr->GetType(inst->type_id());
  1476. if (result_type->AsVector() == nullptr) {
  1477. if (result_index < inst->NumInOperands()) {
  1478. uint32_t id = inst->GetSingleWordInOperand(result_index);
  1479. return {Operand(SPV_OPERAND_TYPE_ID, {id})};
  1480. }
  1481. return {};
  1482. }
  1483. // If the result type is a vector, then vector operands are concatenated.
  1484. uint32_t total_element_count = 0;
  1485. for (uint32_t idx = 0; idx < inst->NumInOperands(); ++idx) {
  1486. uint32_t element_count =
  1487. GetNumOfElementsContributedByOperand(context, inst, idx);
  1488. total_element_count += element_count;
  1489. if (result_index < total_element_count) {
  1490. std::vector<Operand> operands;
  1491. uint32_t id = inst->GetSingleWordInOperand(idx);
  1492. Instruction* operand_def = def_use_mgr->GetDef(id);
  1493. analysis::Type* operand_type = type_mgr->GetType(operand_def->type_id());
  1494. operands.push_back({SPV_OPERAND_TYPE_ID, {id}});
  1495. if (operand_type->AsVector()) {
  1496. uint32_t start_index_of_id = total_element_count - element_count;
  1497. uint32_t index_into_id = result_index - start_index_of_id;
  1498. operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {index_into_id}});
  1499. }
  1500. return operands;
  1501. }
  1502. }
  1503. return {};
  1504. }
  1505. bool CompositeConstructFeedingExtract(
  1506. IRContext* context, Instruction* inst,
  1507. const std::vector<const analysis::Constant*>&) {
  1508. // If the input to an OpCompositeExtract is an OpCompositeConstruct,
  1509. // then we can simply use the appropriate element in the construction.
  1510. assert(inst->opcode() == spv::Op::OpCompositeExtract &&
  1511. "Wrong opcode. Should be OpCompositeExtract.");
  1512. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1513. // If there are no index operands, then this rule cannot do anything.
  1514. if (inst->NumInOperands() <= 1) {
  1515. return false;
  1516. }
  1517. uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1518. Instruction* cinst = def_use_mgr->GetDef(cid);
  1519. if (cinst->opcode() != spv::Op::OpCompositeConstruct) {
  1520. return false;
  1521. }
  1522. uint32_t index_into_result = inst->GetSingleWordInOperand(1);
  1523. std::vector<Operand> operands =
  1524. GetExtractOperandsForElementOfCompositeConstruct(context, cinst,
  1525. index_into_result);
  1526. if (operands.empty()) {
  1527. return false;
  1528. }
  1529. // Add the remaining indices for extraction.
  1530. for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
  1531. operands.push_back(
  1532. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {inst->GetSingleWordInOperand(i)}});
  1533. }
  1534. if (operands.size() == 1) {
  1535. // If there were no extra indices, then we have the final object. No need
  1536. // to extract any more.
  1537. inst->SetOpcode(spv::Op::OpCopyObject);
  1538. }
  1539. inst->SetInOperands(std::move(operands));
  1540. return true;
  1541. }
  1542. // Walks the indexes chain from |start| to |end| of an OpCompositeInsert or
  1543. // OpCompositeExtract instruction, and returns the type of the final element
  1544. // being accessed.
  1545. const analysis::Type* GetElementType(uint32_t type_id,
  1546. Instruction::iterator start,
  1547. Instruction::iterator end,
  1548. const analysis::TypeManager* type_mgr) {
  1549. const analysis::Type* type = type_mgr->GetType(type_id);
  1550. for (auto index : make_range(std::move(start), std::move(end))) {
  1551. assert(index.type == SPV_OPERAND_TYPE_LITERAL_INTEGER &&
  1552. index.words.size() == 1);
  1553. if (auto* array_type = type->AsArray()) {
  1554. type = array_type->element_type();
  1555. } else if (auto* matrix_type = type->AsMatrix()) {
  1556. type = matrix_type->element_type();
  1557. } else if (auto* struct_type = type->AsStruct()) {
  1558. type = struct_type->element_types()[index.words[0]];
  1559. } else {
  1560. type = nullptr;
  1561. }
  1562. }
  1563. return type;
  1564. }
  1565. // Returns true of |inst_1| and |inst_2| have the same indexes that will be used
  1566. // to index into a composite object, excluding the last index. The two
  1567. // instructions must have the same opcode, and be either OpCompositeExtract or
  1568. // OpCompositeInsert instructions.
  1569. bool HaveSameIndexesExceptForLast(Instruction* inst_1, Instruction* inst_2) {
  1570. assert(inst_1->opcode() == inst_2->opcode() &&
  1571. "Expecting the opcodes to be the same.");
  1572. assert((inst_1->opcode() == spv::Op::OpCompositeInsert ||
  1573. inst_1->opcode() == spv::Op::OpCompositeExtract) &&
  1574. "Instructions must be OpCompositeInsert or OpCompositeExtract.");
  1575. if (inst_1->NumInOperands() != inst_2->NumInOperands()) {
  1576. return false;
  1577. }
  1578. uint32_t first_index_position =
  1579. (inst_1->opcode() == spv::Op::OpCompositeInsert ? 2 : 1);
  1580. for (uint32_t i = first_index_position; i < inst_1->NumInOperands() - 1;
  1581. i++) {
  1582. if (inst_1->GetSingleWordInOperand(i) !=
  1583. inst_2->GetSingleWordInOperand(i)) {
  1584. return false;
  1585. }
  1586. }
  1587. return true;
  1588. }
  1589. // If the OpCompositeConstruct is simply putting back together elements that
  1590. // where extracted from the same source, we can simply reuse the source.
  1591. //
  1592. // This is a common code pattern because of the way that scalar replacement
  1593. // works.
  1594. bool CompositeExtractFeedingConstruct(
  1595. IRContext* context, Instruction* inst,
  1596. const std::vector<const analysis::Constant*>&) {
  1597. assert(inst->opcode() == spv::Op::OpCompositeConstruct &&
  1598. "Wrong opcode. Should be OpCompositeConstruct.");
  1599. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1600. uint32_t original_id = 0;
  1601. if (inst->NumInOperands() == 0) {
  1602. // The struct being constructed has no members.
  1603. return false;
  1604. }
  1605. // Check each element to make sure they are:
  1606. // - extractions
  1607. // - extracting the same position they are inserting
  1608. // - all extract from the same id.
  1609. Instruction* first_element_inst = nullptr;
  1610. for (uint32_t i = 0; i < inst->NumInOperands(); ++i) {
  1611. const uint32_t element_id = inst->GetSingleWordInOperand(i);
  1612. Instruction* element_inst = def_use_mgr->GetDef(element_id);
  1613. if (first_element_inst == nullptr) {
  1614. first_element_inst = element_inst;
  1615. }
  1616. if (element_inst->opcode() != spv::Op::OpCompositeExtract) {
  1617. return false;
  1618. }
  1619. if (!HaveSameIndexesExceptForLast(element_inst, first_element_inst)) {
  1620. return false;
  1621. }
  1622. if (element_inst->GetSingleWordInOperand(element_inst->NumInOperands() -
  1623. 1) != i) {
  1624. return false;
  1625. }
  1626. if (i == 0) {
  1627. original_id =
  1628. element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1629. } else if (original_id !=
  1630. element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx)) {
  1631. return false;
  1632. }
  1633. }
  1634. // The last check it to see that the object being extracted from is the
  1635. // correct type.
  1636. Instruction* original_inst = def_use_mgr->GetDef(original_id);
  1637. analysis::TypeManager* type_mgr = context->get_type_mgr();
  1638. const analysis::Type* original_type =
  1639. GetElementType(original_inst->type_id(), first_element_inst->begin() + 3,
  1640. first_element_inst->end() - 1, type_mgr);
  1641. if (original_type == nullptr) {
  1642. return false;
  1643. }
  1644. if (inst->type_id() != type_mgr->GetId(original_type)) {
  1645. return false;
  1646. }
  1647. if (first_element_inst->NumInOperands() == 2) {
  1648. // Simplify by using the original object.
  1649. inst->SetOpcode(spv::Op::OpCopyObject);
  1650. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}});
  1651. return true;
  1652. }
  1653. // Copies the original id and all indexes except for the last to the new
  1654. // extract instruction.
  1655. inst->SetOpcode(spv::Op::OpCompositeExtract);
  1656. inst->SetInOperands(std::vector<Operand>(first_element_inst->begin() + 2,
  1657. first_element_inst->end() - 1));
  1658. return true;
  1659. }
// Folds an OpCompositeExtract whose composite operand is an OpCompositeInsert
// by comparing the extract's index list with the insert's index list:
//  - identical lists: the extract reads exactly the inserted object, so it
//    becomes an OpCopyObject of that object;
//  - the extract's list is a strict prefix of the insert's: the extracted
//    value mixes the inserted object with the base composite, so nothing is
//    done;
//  - the insert's list is a strict prefix of the extract's: the extract reads
//    part of the inserted object, so it is rewritten to extract the remaining
//    indexes from that object;
//  - otherwise the lists diverge, so the extract reads something the insert did
//    not touch, and it is rewritten to read from the insert's composite
//    operand with the same indexes.
FoldingRule InsertFeedingExtract() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>&) {
    assert(inst->opcode() == spv::Op::OpCompositeExtract &&
           "Wrong opcode. Should be OpCompositeExtract.");
    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
    uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
    Instruction* cinst = def_use_mgr->GetDef(cid);
    if (cinst->opcode() != spv::Op::OpCompositeInsert) {
      return false;
    }
    // Find the first position where the list of insert and extract indices
    // differ, if at all. Extract in-operand |i| lines up with insert
    // in-operand |i| + 1, because the insert carries its object operand in
    // front of its composite operand.
    uint32_t i;
    for (i = 1; i < inst->NumInOperands(); ++i) {
      // Ran out of insert indexes: every insert index matched so far.
      if (i + 1 >= cinst->NumInOperands()) {
        break;
      }
      if (inst->GetSingleWordInOperand(i) !=
          cinst->GetSingleWordInOperand(i + 1)) {
        break;
      }
    }
    // We are extracting the element that was inserted.
    if (i == inst->NumInOperands() && i + 1 == cinst->NumInOperands()) {
      inst->SetOpcode(spv::Op::OpCopyObject);
      inst->SetInOperands(
          {{SPV_OPERAND_TYPE_ID,
            {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}}});
      return true;
    }
    // Extracting the value that was inserted along with values for the base
    // composite. Cannot do anything.
    if (i == inst->NumInOperands()) {
      return false;
    }
    // Extracting an element of the value that was inserted. Extract from
    // that value directly, keeping the extract indexes from position |i| on.
    if (i + 1 == cinst->NumInOperands()) {
      std::vector<Operand> operands;
      operands.push_back(
          {SPV_OPERAND_TYPE_ID,
           {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}});
      for (; i < inst->NumInOperands(); ++i) {
        operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
                            {inst->GetSingleWordInOperand(i)}});
      }
      inst->SetInOperands(std::move(operands));
      return true;
    }
    // Extracting a value that is disjoint from the element being inserted.
    // Rewrite the extract to use the composite input to the insert, keeping
    // all of the extract's indexes.
    std::vector<Operand> operands;
    operands.push_back(
        {SPV_OPERAND_TYPE_ID,
         {cinst->GetSingleWordInOperand(kInsertCompositeIdInIdx)}});
    for (i = 1; i < inst->NumInOperands(); ++i) {
      operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
                          {inst->GetSingleWordInOperand(i)}});
    }
    inst->SetInOperands(std::move(operands));
    return true;
  };
}
  1724. // When a VectorShuffle is feeding an Extract, we can extract from one of the
  1725. // operands of the VectorShuffle. We just need to adjust the index in the
  1726. // extract instruction.
  1727. FoldingRule VectorShuffleFeedingExtract() {
  1728. return [](IRContext* context, Instruction* inst,
  1729. const std::vector<const analysis::Constant*>&) {
  1730. assert(inst->opcode() == spv::Op::OpCompositeExtract &&
  1731. "Wrong opcode. Should be OpCompositeExtract.");
  1732. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1733. analysis::TypeManager* type_mgr = context->get_type_mgr();
  1734. uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1735. Instruction* cinst = def_use_mgr->GetDef(cid);
  1736. if (cinst->opcode() != spv::Op::OpVectorShuffle) {
  1737. return false;
  1738. }
  1739. // Find the size of the first vector operand of the VectorShuffle
  1740. Instruction* first_input =
  1741. def_use_mgr->GetDef(cinst->GetSingleWordInOperand(0));
  1742. analysis::Type* first_input_type =
  1743. type_mgr->GetType(first_input->type_id());
  1744. assert(first_input_type->AsVector() &&
  1745. "Input to vector shuffle should be vectors.");
  1746. uint32_t first_input_size = first_input_type->AsVector()->element_count();
  1747. // Get index of the element the vector shuffle is placing in the position
  1748. // being extracted.
  1749. uint32_t new_index =
  1750. cinst->GetSingleWordInOperand(2 + inst->GetSingleWordInOperand(1));
  1751. // Extracting an undefined value so fold this extract into an undef.
  1752. const uint32_t undef_literal_value = 0xffffffff;
  1753. if (new_index == undef_literal_value) {
  1754. inst->SetOpcode(spv::Op::OpUndef);
  1755. inst->SetInOperands({});
  1756. return true;
  1757. }
  1758. // Get the id of the of the vector the elemtent comes from, and update the
  1759. // index if needed.
  1760. uint32_t new_vector = 0;
  1761. if (new_index < first_input_size) {
  1762. new_vector = cinst->GetSingleWordInOperand(0);
  1763. } else {
  1764. new_vector = cinst->GetSingleWordInOperand(1);
  1765. new_index -= first_input_size;
  1766. }
  1767. // Update the extract instruction.
  1768. inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
  1769. inst->SetInOperand(1, {new_index});
  1770. return true;
  1771. };
  1772. }
  1773. // When an FMix with is feeding an Extract that extracts an element whose
  1774. // corresponding |a| in the FMix is 0 or 1, we can extract from one of the
  1775. // operands of the FMix.
  1776. FoldingRule FMixFeedingExtract() {
  1777. return [](IRContext* context, Instruction* inst,
  1778. const std::vector<const analysis::Constant*>&) {
  1779. assert(inst->opcode() == spv::Op::OpCompositeExtract &&
  1780. "Wrong opcode. Should be OpCompositeExtract.");
  1781. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1782. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  1783. uint32_t composite_id =
  1784. inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1785. Instruction* composite_inst = def_use_mgr->GetDef(composite_id);
  1786. if (composite_inst->opcode() != spv::Op::OpExtInst) {
  1787. return false;
  1788. }
  1789. uint32_t inst_set_id =
  1790. context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  1791. if (composite_inst->GetSingleWordInOperand(kExtInstSetIdInIdx) !=
  1792. inst_set_id ||
  1793. composite_inst->GetSingleWordInOperand(kExtInstInstructionInIdx) !=
  1794. GLSLstd450FMix) {
  1795. return false;
  1796. }
  1797. // Get the |a| for the FMix instruction.
  1798. uint32_t a_id = composite_inst->GetSingleWordInOperand(kFMixAIdInIdx);
  1799. std::unique_ptr<Instruction> a(inst->Clone(context));
  1800. a->SetInOperand(kExtractCompositeIdInIdx, {a_id});
  1801. context->get_instruction_folder().FoldInstruction(a.get());
  1802. if (a->opcode() != spv::Op::OpCopyObject) {
  1803. return false;
  1804. }
  1805. const analysis::Constant* a_const =
  1806. const_mgr->FindDeclaredConstant(a->GetSingleWordInOperand(0));
  1807. if (!a_const) {
  1808. return false;
  1809. }
  1810. bool use_x = false;
  1811. assert(a_const->type()->AsFloat());
  1812. double element_value = a_const->GetValueAsDouble();
  1813. if (element_value == 0.0) {
  1814. use_x = true;
  1815. } else if (element_value == 1.0) {
  1816. use_x = false;
  1817. } else {
  1818. return false;
  1819. }
  1820. // Get the id of the of the vector the element comes from.
  1821. uint32_t new_vector = 0;
  1822. if (use_x) {
  1823. new_vector = composite_inst->GetSingleWordInOperand(kFMixXIdInIdx);
  1824. } else {
  1825. new_vector = composite_inst->GetSingleWordInOperand(kFMixYIdInIdx);
  1826. }
  1827. // Update the extract instruction.
  1828. inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
  1829. return true;
  1830. };
  1831. }
  1832. // Returns the number of elements in the composite type |type|. Returns 0 if
  1833. // |type| is a scalar value.
  1834. uint32_t GetNumberOfElements(const analysis::Type* type) {
  1835. if (auto* vector_type = type->AsVector()) {
  1836. return vector_type->element_count();
  1837. }
  1838. if (auto* matrix_type = type->AsMatrix()) {
  1839. return matrix_type->element_count();
  1840. }
  1841. if (auto* struct_type = type->AsStruct()) {
  1842. return static_cast<uint32_t>(struct_type->element_types().size());
  1843. }
  1844. if (auto* array_type = type->AsArray()) {
  1845. return array_type->length_info().words[0];
  1846. }
  1847. return 0;
  1848. }
// Walks the chain of OpCompositeInsert instructions that starts at |inst|,
// following each insert's composite operand until a non-insert is reached.
// Collects the objects inserted at the same depth as |inst| (same indexes
// except the last). The returned map takes that last index to the id of the
// value inserted there. When an index is written more than once, the entry
// from the insert closest to |inst| (the most recent write) is kept, because
// std::map::insert never overwrites an existing key.
//
// Returns an empty map when a deeper insert (more indexes than |inst|) writes
// into an element that has not already been recorded. In that case a single
// OpCompositeConstruct could not reproduce the object.
std::map<uint32_t, uint32_t> GetInsertedValues(Instruction* inst) {
  analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr();
  std::map<uint32_t, uint32_t> values_inserted;
  Instruction* current_inst = inst;
  while (current_inst->opcode() == spv::Op::OpCompositeInsert) {
    if (current_inst->NumInOperands() > inst->NumInOperands()) {
      // This is to catch the case
      // %2 = OpCompositeInsert %m2x2int %v2int_1_0 %m2x2int_undef 0
      // %3 = OpCompositeInsert %m2x2int %int_4 %2 0 0
      // %4 = OpCompositeInsert %m2x2int %v2int_2_3 %3 1
      // In this case we cannot do a single construct to get the matrix.
      // The partial write is harmless only if a later (already visited)
      // insert replaced that whole element.
      uint32_t partially_inserted_element_index =
          current_inst->GetSingleWordInOperand(inst->NumInOperands() - 1);
      if (values_inserted.count(partially_inserted_element_index) == 0)
        return {};
    }
    if (HaveSameIndexesExceptForLast(inst, current_inst)) {
      // Keyed by the last index. insert() keeps the newer value already
      // recorded for that key.
      values_inserted.insert(
          {current_inst->GetSingleWordInOperand(current_inst->NumInOperands() -
                                                1),
           current_inst->GetSingleWordInOperand(kInsertObjectIdInIdx)});
    }
    // Step to the composite this insert was applied to.
    current_inst = def_use_mgr->GetDef(
        current_inst->GetSingleWordInOperand(kInsertCompositeIdInIdx));
  }
  return values_inserted;
}
  1879. // Returns true of there is an entry in |values_inserted| for every element of
  1880. // |Type|.
  1881. bool DoInsertedValuesCoverEntireObject(
  1882. const analysis::Type* type, std::map<uint32_t, uint32_t>& values_inserted) {
  1883. uint32_t container_size = GetNumberOfElements(type);
  1884. if (container_size != values_inserted.size()) {
  1885. return false;
  1886. }
  1887. if (values_inserted.rbegin()->first >= container_size) {
  1888. return false;
  1889. }
  1890. return true;
  1891. }
  1892. // Returns the type of the element that immediately contains the element being
  1893. // inserted by the OpCompositeInsert instruction |inst|.
  1894. const analysis::Type* GetContainerType(Instruction* inst) {
  1895. assert(inst->opcode() == spv::Op::OpCompositeInsert);
  1896. analysis::TypeManager* type_mgr = inst->context()->get_type_mgr();
  1897. return GetElementType(inst->type_id(), inst->begin() + 4, inst->end() - 1,
  1898. type_mgr);
  1899. }
  1900. // Returns an OpCompositeConstruct instruction that build an object with
  1901. // |type_id| out of the values in |values_inserted|. Each value will be
  1902. // placed at the index corresponding to the value. The new instruction will
  1903. // be placed before |insert_before|.
  1904. Instruction* BuildCompositeConstruct(
  1905. uint32_t type_id, const std::map<uint32_t, uint32_t>& values_inserted,
  1906. Instruction* insert_before) {
  1907. InstructionBuilder ir_builder(
  1908. insert_before->context(), insert_before,
  1909. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  1910. std::vector<uint32_t> ids_in_order;
  1911. for (auto it : values_inserted) {
  1912. ids_in_order.push_back(it.second);
  1913. }
  1914. Instruction* construct =
  1915. ir_builder.AddCompositeConstruct(type_id, ids_in_order);
  1916. return construct;
  1917. }
  1918. // Replaces the OpCompositeInsert |inst| that inserts |construct| into the same
  1919. // object as |inst| with final index removed. If the resulting
  1920. // OpCompositeInsert instruction would have no remaining indexes, the
  1921. // instruction is replaced with an OpCopyObject instead.
  1922. void InsertConstructedObject(Instruction* inst, const Instruction* construct) {
  1923. if (inst->NumInOperands() == 3) {
  1924. inst->SetOpcode(spv::Op::OpCopyObject);
  1925. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {construct->result_id()}}});
  1926. } else {
  1927. inst->SetInOperand(kInsertObjectIdInIdx, {construct->result_id()});
  1928. inst->RemoveOperand(inst->NumOperands() - 1);
  1929. }
  1930. }
  1931. // Replaces a series of |OpCompositeInsert| instruction that cover the entire
  1932. // object with an |OpCompositeConstruct|.
  1933. bool CompositeInsertToCompositeConstruct(
  1934. IRContext* context, Instruction* inst,
  1935. const std::vector<const analysis::Constant*>&) {
  1936. assert(inst->opcode() == spv::Op::OpCompositeInsert &&
  1937. "Wrong opcode. Should be OpCompositeInsert.");
  1938. if (inst->NumInOperands() < 3) return false;
  1939. std::map<uint32_t, uint32_t> values_inserted = GetInsertedValues(inst);
  1940. const analysis::Type* container_type = GetContainerType(inst);
  1941. if (container_type == nullptr) {
  1942. return false;
  1943. }
  1944. if (!DoInsertedValuesCoverEntireObject(container_type, values_inserted)) {
  1945. return false;
  1946. }
  1947. analysis::TypeManager* type_mgr = context->get_type_mgr();
  1948. Instruction* construct = BuildCompositeConstruct(
  1949. type_mgr->GetId(container_type), values_inserted, inst);
  1950. InsertConstructedObject(inst, construct);
  1951. return true;
  1952. }
  1953. FoldingRule RedundantPhi() {
  1954. // An OpPhi instruction where all values are the same or the result of the phi
  1955. // itself, can be replaced by the value itself.
  1956. return [](IRContext*, Instruction* inst,
  1957. const std::vector<const analysis::Constant*>&) {
  1958. assert(inst->opcode() == spv::Op::OpPhi &&
  1959. "Wrong opcode. Should be OpPhi.");
  1960. uint32_t incoming_value = 0;
  1961. for (uint32_t i = 0; i < inst->NumInOperands(); i += 2) {
  1962. uint32_t op_id = inst->GetSingleWordInOperand(i);
  1963. if (op_id == inst->result_id()) {
  1964. continue;
  1965. }
  1966. if (incoming_value == 0) {
  1967. incoming_value = op_id;
  1968. } else if (op_id != incoming_value) {
  1969. // Found two possible value. Can't simplify.
  1970. return false;
  1971. }
  1972. }
  1973. if (incoming_value == 0) {
  1974. // Code looks invalid. Don't do anything.
  1975. return false;
  1976. }
  1977. // We have a single incoming value. Simplify using that value.
  1978. inst->SetOpcode(spv::Op::OpCopyObject);
  1979. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {incoming_value}}});
  1980. return true;
  1981. };
  1982. }
  1983. FoldingRule BitCastScalarOrVector() {
  1984. return [](IRContext* context, Instruction* inst,
  1985. const std::vector<const analysis::Constant*>& constants) {
  1986. assert(inst->opcode() == spv::Op::OpBitcast && constants.size() == 1);
  1987. if (constants[0] == nullptr) return false;
  1988. const analysis::Type* type =
  1989. context->get_type_mgr()->GetType(inst->type_id());
  1990. if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
  1991. return false;
  1992. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  1993. std::vector<uint32_t> words =
  1994. GetWordsFromNumericScalarOrVectorConstant(const_mgr, constants[0]);
  1995. if (words.size() == 0) return false;
  1996. const analysis::Constant* bitcasted_constant =
  1997. ConvertWordsToNumericScalarOrVectorConstant(const_mgr, words, type);
  1998. if (!bitcasted_constant) return false;
  1999. auto new_feeder_id =
  2000. const_mgr->GetDefiningInstruction(bitcasted_constant, inst->type_id())
  2001. ->result_id();
  2002. inst->SetOpcode(spv::Op::OpCopyObject);
  2003. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {new_feeder_id}}});
  2004. return true;
  2005. };
  2006. }
  2007. FoldingRule RedundantSelect() {
  2008. // An OpSelect instruction where both values are the same or the condition is
  2009. // constant can be replaced by one of the values
  2010. return [](IRContext*, Instruction* inst,
  2011. const std::vector<const analysis::Constant*>& constants) {
  2012. assert(inst->opcode() == spv::Op::OpSelect &&
  2013. "Wrong opcode. Should be OpSelect.");
  2014. assert(inst->NumInOperands() == 3);
  2015. assert(constants.size() == 3);
  2016. uint32_t true_id = inst->GetSingleWordInOperand(1);
  2017. uint32_t false_id = inst->GetSingleWordInOperand(2);
  2018. if (true_id == false_id) {
  2019. // Both results are the same, condition doesn't matter
  2020. inst->SetOpcode(spv::Op::OpCopyObject);
  2021. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
  2022. return true;
  2023. } else if (constants[0]) {
  2024. const analysis::Type* type = constants[0]->type();
  2025. if (type->AsBool()) {
  2026. // Scalar constant value, select the corresponding value.
  2027. inst->SetOpcode(spv::Op::OpCopyObject);
  2028. if (constants[0]->AsNullConstant() ||
  2029. !constants[0]->AsBoolConstant()->value()) {
  2030. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
  2031. } else {
  2032. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
  2033. }
  2034. return true;
  2035. } else {
  2036. assert(type->AsVector());
  2037. if (constants[0]->AsNullConstant()) {
  2038. // All values come from false id.
  2039. inst->SetOpcode(spv::Op::OpCopyObject);
  2040. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
  2041. return true;
  2042. } else {
  2043. // Convert to a vector shuffle.
  2044. std::vector<Operand> ops;
  2045. ops.push_back({SPV_OPERAND_TYPE_ID, {true_id}});
  2046. ops.push_back({SPV_OPERAND_TYPE_ID, {false_id}});
  2047. const analysis::VectorConstant* vector_const =
  2048. constants[0]->AsVectorConstant();
  2049. uint32_t size =
  2050. static_cast<uint32_t>(vector_const->GetComponents().size());
  2051. for (uint32_t i = 0; i != size; ++i) {
  2052. const analysis::Constant* component =
  2053. vector_const->GetComponents()[i];
  2054. if (component->AsNullConstant() ||
  2055. !component->AsBoolConstant()->value()) {
  2056. // Selecting from the false vector which is the second input
  2057. // vector to the shuffle. Offset the index by |size|.
  2058. ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i + size}});
  2059. } else {
  2060. // Selecting from true vector which is the first input vector to
  2061. // the shuffle.
  2062. ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}});
  2063. }
  2064. }
  2065. inst->SetOpcode(spv::Op::OpVectorShuffle);
  2066. inst->SetInOperands(std::move(ops));
  2067. return true;
  2068. }
  2069. }
  2070. }
  2071. return false;
  2072. };
  2073. }
  2074. enum class FloatConstantKind { Unknown, Zero, One };
  2075. FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) {
  2076. if (constant == nullptr) {
  2077. return FloatConstantKind::Unknown;
  2078. }
  2079. assert(HasFloatingPoint(constant->type()) && "Unexpected constant type");
  2080. if (constant->AsNullConstant()) {
  2081. return FloatConstantKind::Zero;
  2082. } else if (const analysis::VectorConstant* vc =
  2083. constant->AsVectorConstant()) {
  2084. const std::vector<const analysis::Constant*>& components =
  2085. vc->GetComponents();
  2086. assert(!components.empty());
  2087. FloatConstantKind kind = getFloatConstantKind(components[0]);
  2088. for (size_t i = 1; i < components.size(); ++i) {
  2089. if (getFloatConstantKind(components[i]) != kind) {
  2090. return FloatConstantKind::Unknown;
  2091. }
  2092. }
  2093. return kind;
  2094. } else if (const analysis::FloatConstant* fc = constant->AsFloatConstant()) {
  2095. if (fc->IsZero()) return FloatConstantKind::Zero;
  2096. uint32_t width = fc->type()->AsFloat()->width();
  2097. if (width != 32 && width != 64) return FloatConstantKind::Unknown;
  2098. double value = (width == 64) ? fc->GetDoubleValue() : fc->GetFloatValue();
  2099. if (value == 0.0) {
  2100. return FloatConstantKind::Zero;
  2101. } else if (value == 1.0) {
  2102. return FloatConstantKind::One;
  2103. } else {
  2104. return FloatConstantKind::Unknown;
  2105. }
  2106. } else {
  2107. return FloatConstantKind::Unknown;
  2108. }
  2109. }
  2110. FoldingRule RedundantFAdd() {
  2111. return [](IRContext*, Instruction* inst,
  2112. const std::vector<const analysis::Constant*>& constants) {
  2113. assert(inst->opcode() == spv::Op::OpFAdd &&
  2114. "Wrong opcode. Should be OpFAdd.");
  2115. assert(constants.size() == 2);
  2116. if (!inst->IsFloatingPointFoldingAllowed()) {
  2117. return false;
  2118. }
  2119. FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
  2120. FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
  2121. if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
  2122. inst->SetOpcode(spv::Op::OpCopyObject);
  2123. inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
  2124. {inst->GetSingleWordInOperand(
  2125. kind0 == FloatConstantKind::Zero ? 1 : 0)}}});
  2126. return true;
  2127. }
  2128. return false;
  2129. };
  2130. }
  2131. FoldingRule RedundantFSub() {
  2132. return [](IRContext*, Instruction* inst,
  2133. const std::vector<const analysis::Constant*>& constants) {
  2134. assert(inst->opcode() == spv::Op::OpFSub &&
  2135. "Wrong opcode. Should be OpFSub.");
  2136. assert(constants.size() == 2);
  2137. if (!inst->IsFloatingPointFoldingAllowed()) {
  2138. return false;
  2139. }
  2140. FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
  2141. FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
  2142. if (kind0 == FloatConstantKind::Zero) {
  2143. inst->SetOpcode(spv::Op::OpFNegate);
  2144. inst->SetInOperands(
  2145. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1)}}});
  2146. return true;
  2147. }
  2148. if (kind1 == FloatConstantKind::Zero) {
  2149. inst->SetOpcode(spv::Op::OpCopyObject);
  2150. inst->SetInOperands(
  2151. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
  2152. return true;
  2153. }
  2154. return false;
  2155. };
  2156. }
  2157. FoldingRule RedundantFMul() {
  2158. return [](IRContext*, Instruction* inst,
  2159. const std::vector<const analysis::Constant*>& constants) {
  2160. assert(inst->opcode() == spv::Op::OpFMul &&
  2161. "Wrong opcode. Should be OpFMul.");
  2162. assert(constants.size() == 2);
  2163. if (!inst->IsFloatingPointFoldingAllowed()) {
  2164. return false;
  2165. }
  2166. FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
  2167. FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
  2168. if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
  2169. inst->SetOpcode(spv::Op::OpCopyObject);
  2170. inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
  2171. {inst->GetSingleWordInOperand(
  2172. kind0 == FloatConstantKind::Zero ? 0 : 1)}}});
  2173. return true;
  2174. }
  2175. if (kind0 == FloatConstantKind::One || kind1 == FloatConstantKind::One) {
  2176. inst->SetOpcode(spv::Op::OpCopyObject);
  2177. inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
  2178. {inst->GetSingleWordInOperand(
  2179. kind0 == FloatConstantKind::One ? 1 : 0)}}});
  2180. return true;
  2181. }
  2182. return false;
  2183. };
  2184. }
  2185. FoldingRule RedundantFDiv() {
  2186. return [](IRContext*, Instruction* inst,
  2187. const std::vector<const analysis::Constant*>& constants) {
  2188. assert(inst->opcode() == spv::Op::OpFDiv &&
  2189. "Wrong opcode. Should be OpFDiv.");
  2190. assert(constants.size() == 2);
  2191. if (!inst->IsFloatingPointFoldingAllowed()) {
  2192. return false;
  2193. }
  2194. FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
  2195. FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
  2196. if (kind0 == FloatConstantKind::Zero) {
  2197. inst->SetOpcode(spv::Op::OpCopyObject);
  2198. inst->SetInOperands(
  2199. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
  2200. return true;
  2201. }
  2202. if (kind1 == FloatConstantKind::One) {
  2203. inst->SetOpcode(spv::Op::OpCopyObject);
  2204. inst->SetInOperands(
  2205. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
  2206. return true;
  2207. }
  2208. return false;
  2209. };
  2210. }
  2211. FoldingRule RedundantFMix() {
  2212. return [](IRContext* context, Instruction* inst,
  2213. const std::vector<const analysis::Constant*>& constants) {
  2214. assert(inst->opcode() == spv::Op::OpExtInst &&
  2215. "Wrong opcode. Should be OpExtInst.");
  2216. if (!inst->IsFloatingPointFoldingAllowed()) {
  2217. return false;
  2218. }
  2219. uint32_t instSetId =
  2220. context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  2221. if (inst->GetSingleWordInOperand(kExtInstSetIdInIdx) == instSetId &&
  2222. inst->GetSingleWordInOperand(kExtInstInstructionInIdx) ==
  2223. GLSLstd450FMix) {
  2224. assert(constants.size() == 5);
  2225. FloatConstantKind kind4 = getFloatConstantKind(constants[4]);
  2226. if (kind4 == FloatConstantKind::Zero || kind4 == FloatConstantKind::One) {
  2227. inst->SetOpcode(spv::Op::OpCopyObject);
  2228. inst->SetInOperands(
  2229. {{SPV_OPERAND_TYPE_ID,
  2230. {inst->GetSingleWordInOperand(kind4 == FloatConstantKind::Zero
  2231. ? kFMixXIdInIdx
  2232. : kFMixYIdInIdx)}}});
  2233. return true;
  2234. }
  2235. }
  2236. return false;
  2237. };
  2238. }
  2239. // This rule handles addition of zero for integers.
  2240. FoldingRule RedundantIAdd() {
  2241. return [](IRContext* context, Instruction* inst,
  2242. const std::vector<const analysis::Constant*>& constants) {
  2243. assert(inst->opcode() == spv::Op::OpIAdd &&
  2244. "Wrong opcode. Should be OpIAdd.");
  2245. uint32_t operand = std::numeric_limits<uint32_t>::max();
  2246. const analysis::Type* operand_type = nullptr;
  2247. if (constants[0] && constants[0]->IsZero()) {
  2248. operand = inst->GetSingleWordInOperand(1);
  2249. operand_type = constants[0]->type();
  2250. } else if (constants[1] && constants[1]->IsZero()) {
  2251. operand = inst->GetSingleWordInOperand(0);
  2252. operand_type = constants[1]->type();
  2253. }
  2254. if (operand != std::numeric_limits<uint32_t>::max()) {
  2255. const analysis::Type* inst_type =
  2256. context->get_type_mgr()->GetType(inst->type_id());
  2257. if (inst_type->IsSame(operand_type)) {
  2258. inst->SetOpcode(spv::Op::OpCopyObject);
  2259. } else {
  2260. inst->SetOpcode(spv::Op::OpBitcast);
  2261. }
  2262. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {operand}}});
  2263. return true;
  2264. }
  2265. return false;
  2266. };
  2267. }
  2268. // This rule look for a dot with a constant vector containing a single 1 and
  2269. // the rest 0s. This is the same as doing an extract.
  2270. FoldingRule DotProductDoingExtract() {
  2271. return [](IRContext* context, Instruction* inst,
  2272. const std::vector<const analysis::Constant*>& constants) {
  2273. assert(inst->opcode() == spv::Op::OpDot &&
  2274. "Wrong opcode. Should be OpDot.");
  2275. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  2276. if (!inst->IsFloatingPointFoldingAllowed()) {
  2277. return false;
  2278. }
  2279. for (int i = 0; i < 2; ++i) {
  2280. if (!constants[i]) {
  2281. continue;
  2282. }
  2283. const analysis::Vector* vector_type = constants[i]->type()->AsVector();
  2284. assert(vector_type && "Inputs to OpDot must be vectors.");
  2285. const analysis::Float* element_type =
  2286. vector_type->element_type()->AsFloat();
  2287. assert(element_type && "Inputs to OpDot must be vectors of floats.");
  2288. uint32_t element_width = element_type->width();
  2289. if (element_width != 32 && element_width != 64) {
  2290. return false;
  2291. }
  2292. std::vector<const analysis::Constant*> components;
  2293. components = constants[i]->GetVectorComponents(const_mgr);
  2294. constexpr uint32_t kNotFound = std::numeric_limits<uint32_t>::max();
  2295. uint32_t component_with_one = kNotFound;
  2296. bool all_others_zero = true;
  2297. for (uint32_t j = 0; j < components.size(); ++j) {
  2298. const analysis::Constant* element = components[j];
  2299. double value =
  2300. (element_width == 32 ? element->GetFloat() : element->GetDouble());
  2301. if (value == 0.0) {
  2302. continue;
  2303. } else if (value == 1.0) {
  2304. if (component_with_one == kNotFound) {
  2305. component_with_one = j;
  2306. } else {
  2307. component_with_one = kNotFound;
  2308. break;
  2309. }
  2310. } else {
  2311. all_others_zero = false;
  2312. break;
  2313. }
  2314. }
  2315. if (!all_others_zero || component_with_one == kNotFound) {
  2316. continue;
  2317. }
  2318. std::vector<Operand> operands;
  2319. operands.push_back(
  2320. {SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1u - i)}});
  2321. operands.push_back(
  2322. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_with_one}});
  2323. inst->SetOpcode(spv::Op::OpCompositeExtract);
  2324. inst->SetInOperands(std::move(operands));
  2325. return true;
  2326. }
  2327. return false;
  2328. };
  2329. }
  2330. // If we are storing an undef, then we can remove the store.
  2331. //
  2332. // TODO: We can do something similar for OpImageWrite, but checking for volatile
  2333. // is complicated. Waiting to see if it is needed.
  2334. FoldingRule StoringUndef() {
  2335. return [](IRContext* context, Instruction* inst,
  2336. const std::vector<const analysis::Constant*>&) {
  2337. assert(inst->opcode() == spv::Op::OpStore &&
  2338. "Wrong opcode. Should be OpStore.");
  2339. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  2340. // If this is a volatile store, the store cannot be removed.
  2341. if (inst->NumInOperands() == 3) {
  2342. if (inst->GetSingleWordInOperand(2) &
  2343. uint32_t(spv::MemoryAccessMask::Volatile)) {
  2344. return false;
  2345. }
  2346. }
  2347. uint32_t object_id = inst->GetSingleWordInOperand(kStoreObjectInIdx);
  2348. Instruction* object_inst = def_use_mgr->GetDef(object_id);
  2349. if (object_inst->opcode() == spv::Op::OpUndef) {
  2350. inst->ToNop();
  2351. return true;
  2352. }
  2353. return false;
  2354. };
  2355. }
// Folds an OpVectorShuffle when one of its two vector operands is itself an
// OpVectorShuffle (the "feeding" shuffle). Operand 0 is checked first, then
// operand 1.
//
// Every component the outer shuffle takes from the feeding shuffle is traced
// back to a component of one of the feeding shuffle's own operands. The
// feeding shuffle is then replaced by that operand. The rule gives up if those
// components come from both operands of the feeding shuffle.
FoldingRule VectorShuffleFeedingShuffle() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>&) {
    assert(inst->opcode() == spv::Op::OpVectorShuffle &&
           "Wrong opcode. Should be OpVectorShuffle.");
    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
    analysis::TypeManager* type_mgr = context->get_type_mgr();
    // |op0_length| is always the component count of the outer shuffle's
    // operand 0, even if the feeder turns out to be operand 1. Component
    // literals below it select from operand 0; those at or above it select
    // from operand 1.
    Instruction* feeding_shuffle_inst =
        def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
    analysis::Vector* op0_type =
        type_mgr->GetType(feeding_shuffle_inst->type_id())->AsVector();
    uint32_t op0_length = op0_type->element_count();
    bool feeder_is_op0 = true;
    if (feeding_shuffle_inst->opcode() != spv::Op::OpVectorShuffle) {
      feeding_shuffle_inst =
          def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
      feeder_is_op0 = false;
    }
    if (feeding_shuffle_inst->opcode() != spv::Op::OpVectorShuffle) {
      return false;
    }
    // Component count of the feeding shuffle's first operand. Feeder
    // component literals at or above this select from its second operand.
    Instruction* feeder2 =
        def_use_mgr->GetDef(feeding_shuffle_inst->GetSingleWordInOperand(0));
    analysis::Vector* feeder_op0_type =
        type_mgr->GetType(feeder2->type_id())->AsVector();
    uint32_t feeder_op0_length = feeder_op0_type->element_count();
    // Id of the feeding shuffle's operand that will replace the feeding
    // shuffle. It stays 0 until a component traced through the feeder
    // resolves to one of its operands.
    uint32_t new_feeder_id = 0;
    std::vector<Operand> new_operands;
    new_operands.resize(
        2, {SPV_OPERAND_TYPE_ID, {0}});  // Place holders for vector operands.
    // Component literal used for an undefined component.
    const uint32_t undef_literal = 0xffffffff;
    for (uint32_t op = 2; op < inst->NumInOperands(); ++op) {
      uint32_t component_index = inst->GetSingleWordInOperand(op);
      // Do not interpret the undefined value literal as coming from operand 1.
      if (component_index != undef_literal &&
          feeder_is_op0 == (component_index < op0_length)) {
        // This component comes from the feeding_shuffle_inst. Update
        // |component_index| to be the index into the operand of the feeder.
        // Adjust component_index to get the index into the operands of the
        // feeding_shuffle_inst.
        if (component_index >= op0_length) {
          component_index -= op0_length;
        }
        // Replace it with the feeder's own component literal for that
        // position. This may itself be |undef_literal|.
        component_index =
            feeding_shuffle_inst->GetSingleWordInOperand(component_index + 2);
        // Check if we are using a component from the first or second operand of
        // the feeding instruction.
        if (component_index < feeder_op0_length) {
          if (new_feeder_id == 0) {
            // First time through, save the id of the operand the element comes
            // from.
            new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(0);
          } else if (new_feeder_id !=
                     feeding_shuffle_inst->GetSingleWordInOperand(0)) {
            // We need both elements of the feeding_shuffle_inst, so we cannot
            // fold.
            return false;
          }
        } else if (component_index != undef_literal) {
          if (new_feeder_id == 0) {
            // First time through, save the id of the operand the element comes
            // from.
            new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(1);
          } else if (new_feeder_id !=
                     feeding_shuffle_inst->GetSingleWordInOperand(1)) {
            // We need both elements of the feeding_shuffle_inst, so we cannot
            // fold.
            return false;
          }
          // Make the index relative to the feeder's second operand.
          component_index -= feeder_op0_length;
        }
        // When the feeder sits in operand 1 of the outer shuffle, indices into
        // it are offset by the length of operand 0.
        if (!feeder_is_op0 && component_index != undef_literal) {
          component_index += op0_length;
        }
      }
      new_operands.push_back(
          {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_index}});
    }
    // No traced component resolved to an operand of the feeding shuffle. Each
    // one either came from the other outer operand or was undefined. Use a
    // null constant of the feeding shuffle's type in its place.
    if (new_feeder_id == 0) {
      analysis::ConstantManager* const_mgr = context->get_constant_mgr();
      const analysis::Type* type =
          type_mgr->GetType(feeding_shuffle_inst->type_id());
      const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
      new_feeder_id =
          const_mgr->GetDefiningInstruction(null_const, 0)->result_id();
    }
    if (feeder_is_op0) {
      // If the size of the first vector operand changed then the indices
      // referring to the second operand need to be adjusted.
      // The original literals in |inst| (still unmodified here) identify
      // which positions referred to operand 1.
      Instruction* new_feeder_inst = def_use_mgr->GetDef(new_feeder_id);
      analysis::Type* new_feeder_type =
          type_mgr->GetType(new_feeder_inst->type_id());
      uint32_t new_op0_size = new_feeder_type->AsVector()->element_count();
      int32_t adjustment = op0_length - new_op0_size;
      if (adjustment != 0) {
        for (uint32_t i = 2; i < new_operands.size(); i++) {
          uint32_t operand = inst->GetSingleWordInOperand(i);
          if (operand >= op0_length && operand != undef_literal) {
            new_operands[i].words[0] -= adjustment;
          }
        }
      }
      new_operands[0].words[0] = new_feeder_id;
      new_operands[1] = inst->GetInOperand(1);
    } else {
      new_operands[1].words[0] = new_feeder_id;
      new_operands[0] = inst->GetInOperand(0);
    }
    inst->SetInOperands(std::move(new_operands));
    return true;
  };
}
  2468. // Removes duplicate ids from the interface list of an OpEntryPoint
  2469. // instruction.
  2470. FoldingRule RemoveRedundantOperands() {
  2471. return [](IRContext*, Instruction* inst,
  2472. const std::vector<const analysis::Constant*>&) {
  2473. assert(inst->opcode() == spv::Op::OpEntryPoint &&
  2474. "Wrong opcode. Should be OpEntryPoint.");
  2475. bool has_redundant_operand = false;
  2476. std::unordered_set<uint32_t> seen_operands;
  2477. std::vector<Operand> new_operands;
  2478. new_operands.emplace_back(inst->GetOperand(0));
  2479. new_operands.emplace_back(inst->GetOperand(1));
  2480. new_operands.emplace_back(inst->GetOperand(2));
  2481. for (uint32_t i = 3; i < inst->NumOperands(); ++i) {
  2482. if (seen_operands.insert(inst->GetSingleWordOperand(i)).second) {
  2483. new_operands.emplace_back(inst->GetOperand(i));
  2484. } else {
  2485. has_redundant_operand = true;
  2486. }
  2487. }
  2488. if (!has_redundant_operand) {
  2489. return false;
  2490. }
  2491. inst->SetInOperands(std::move(new_operands));
  2492. return true;
  2493. };
  2494. }
  2495. // If an image instruction's operand is a constant, updates the image operand
  2496. // flag from Offset to ConstOffset.
  2497. FoldingRule UpdateImageOperands() {
  2498. return [](IRContext*, Instruction* inst,
  2499. const std::vector<const analysis::Constant*>& constants) {
  2500. const auto opcode = inst->opcode();
  2501. (void)opcode;
  2502. assert((opcode == spv::Op::OpImageSampleImplicitLod ||
  2503. opcode == spv::Op::OpImageSampleExplicitLod ||
  2504. opcode == spv::Op::OpImageSampleDrefImplicitLod ||
  2505. opcode == spv::Op::OpImageSampleDrefExplicitLod ||
  2506. opcode == spv::Op::OpImageSampleProjImplicitLod ||
  2507. opcode == spv::Op::OpImageSampleProjExplicitLod ||
  2508. opcode == spv::Op::OpImageSampleProjDrefImplicitLod ||
  2509. opcode == spv::Op::OpImageSampleProjDrefExplicitLod ||
  2510. opcode == spv::Op::OpImageFetch ||
  2511. opcode == spv::Op::OpImageGather ||
  2512. opcode == spv::Op::OpImageDrefGather ||
  2513. opcode == spv::Op::OpImageRead || opcode == spv::Op::OpImageWrite ||
  2514. opcode == spv::Op::OpImageSparseSampleImplicitLod ||
  2515. opcode == spv::Op::OpImageSparseSampleExplicitLod ||
  2516. opcode == spv::Op::OpImageSparseSampleDrefImplicitLod ||
  2517. opcode == spv::Op::OpImageSparseSampleDrefExplicitLod ||
  2518. opcode == spv::Op::OpImageSparseSampleProjImplicitLod ||
  2519. opcode == spv::Op::OpImageSparseSampleProjExplicitLod ||
  2520. opcode == spv::Op::OpImageSparseSampleProjDrefImplicitLod ||
  2521. opcode == spv::Op::OpImageSparseSampleProjDrefExplicitLod ||
  2522. opcode == spv::Op::OpImageSparseFetch ||
  2523. opcode == spv::Op::OpImageSparseGather ||
  2524. opcode == spv::Op::OpImageSparseDrefGather ||
  2525. opcode == spv::Op::OpImageSparseRead) &&
  2526. "Wrong opcode. Should be an image instruction.");
  2527. int32_t operand_index = ImageOperandsMaskInOperandIndex(inst);
  2528. if (operand_index >= 0) {
  2529. auto image_operands = inst->GetSingleWordInOperand(operand_index);
  2530. if (image_operands & uint32_t(spv::ImageOperandsMask::Offset)) {
  2531. uint32_t offset_operand_index = operand_index + 1;
  2532. if (image_operands & uint32_t(spv::ImageOperandsMask::Bias))
  2533. offset_operand_index++;
  2534. if (image_operands & uint32_t(spv::ImageOperandsMask::Lod))
  2535. offset_operand_index++;
  2536. if (image_operands & uint32_t(spv::ImageOperandsMask::Grad))
  2537. offset_operand_index += 2;
  2538. assert(((image_operands &
  2539. uint32_t(spv::ImageOperandsMask::ConstOffset)) == 0) &&
  2540. "Offset and ConstOffset may not be used together");
  2541. if (offset_operand_index < inst->NumOperands()) {
  2542. if (constants[offset_operand_index]) {
  2543. image_operands =
  2544. image_operands | uint32_t(spv::ImageOperandsMask::ConstOffset);
  2545. image_operands =
  2546. image_operands & ~uint32_t(spv::ImageOperandsMask::Offset);
  2547. inst->SetInOperand(operand_index, {image_operands});
  2548. return true;
  2549. }
  2550. }
  2551. }
  2552. }
  2553. return false;
  2554. };
  2555. }
  2556. } // namespace
  2557. void FoldingRules::AddFoldingRules() {
  2558. // Add all folding rules to the list for the opcodes to which they apply.
  2559. // Note that the order in which rules are added to the list matters. If a rule
  2560. // applies to the instruction, the rest of the rules will not be attempted.
  2561. // Take that into consideration.
  2562. rules_[spv::Op::OpBitcast].push_back(BitCastScalarOrVector());
  2563. rules_[spv::Op::OpCompositeConstruct].push_back(
  2564. CompositeExtractFeedingConstruct);
  2565. rules_[spv::Op::OpCompositeExtract].push_back(InsertFeedingExtract());
  2566. rules_[spv::Op::OpCompositeExtract].push_back(
  2567. CompositeConstructFeedingExtract);
  2568. rules_[spv::Op::OpCompositeExtract].push_back(VectorShuffleFeedingExtract());
  2569. rules_[spv::Op::OpCompositeExtract].push_back(FMixFeedingExtract());
  2570. rules_[spv::Op::OpCompositeInsert].push_back(
  2571. CompositeInsertToCompositeConstruct);
  2572. rules_[spv::Op::OpDot].push_back(DotProductDoingExtract());
  2573. rules_[spv::Op::OpEntryPoint].push_back(RemoveRedundantOperands());
  2574. rules_[spv::Op::OpFAdd].push_back(RedundantFAdd());
  2575. rules_[spv::Op::OpFAdd].push_back(MergeAddNegateArithmetic());
  2576. rules_[spv::Op::OpFAdd].push_back(MergeAddAddArithmetic());
  2577. rules_[spv::Op::OpFAdd].push_back(MergeAddSubArithmetic());
  2578. rules_[spv::Op::OpFAdd].push_back(MergeGenericAddSubArithmetic());
  2579. rules_[spv::Op::OpFAdd].push_back(FactorAddMuls());
  2580. rules_[spv::Op::OpFAdd].push_back(MergeMulAddArithmetic);
  2581. rules_[spv::Op::OpFDiv].push_back(RedundantFDiv());
  2582. rules_[spv::Op::OpFDiv].push_back(ReciprocalFDiv());
  2583. rules_[spv::Op::OpFDiv].push_back(MergeDivDivArithmetic());
  2584. rules_[spv::Op::OpFDiv].push_back(MergeDivMulArithmetic());
  2585. rules_[spv::Op::OpFDiv].push_back(MergeDivNegateArithmetic());
  2586. rules_[spv::Op::OpFMul].push_back(RedundantFMul());
  2587. rules_[spv::Op::OpFMul].push_back(MergeMulMulArithmetic());
  2588. rules_[spv::Op::OpFMul].push_back(MergeMulDivArithmetic());
  2589. rules_[spv::Op::OpFMul].push_back(MergeMulNegateArithmetic());
  2590. rules_[spv::Op::OpFNegate].push_back(MergeNegateArithmetic());
  2591. rules_[spv::Op::OpFNegate].push_back(MergeNegateAddSubArithmetic());
  2592. rules_[spv::Op::OpFNegate].push_back(MergeNegateMulDivArithmetic());
  2593. rules_[spv::Op::OpFSub].push_back(RedundantFSub());
  2594. rules_[spv::Op::OpFSub].push_back(MergeSubNegateArithmetic());
  2595. rules_[spv::Op::OpFSub].push_back(MergeSubAddArithmetic());
  2596. rules_[spv::Op::OpFSub].push_back(MergeSubSubArithmetic());
  2597. rules_[spv::Op::OpFSub].push_back(MergeMulSubArithmetic);
  2598. rules_[spv::Op::OpIAdd].push_back(RedundantIAdd());
  2599. rules_[spv::Op::OpIAdd].push_back(MergeAddNegateArithmetic());
  2600. rules_[spv::Op::OpIAdd].push_back(MergeAddAddArithmetic());
  2601. rules_[spv::Op::OpIAdd].push_back(MergeAddSubArithmetic());
  2602. rules_[spv::Op::OpIAdd].push_back(MergeGenericAddSubArithmetic());
  2603. rules_[spv::Op::OpIAdd].push_back(FactorAddMuls());
  2604. rules_[spv::Op::OpIMul].push_back(IntMultipleBy1());
  2605. rules_[spv::Op::OpIMul].push_back(MergeMulMulArithmetic());
  2606. rules_[spv::Op::OpIMul].push_back(MergeMulNegateArithmetic());
  2607. rules_[spv::Op::OpISub].push_back(MergeSubNegateArithmetic());
  2608. rules_[spv::Op::OpISub].push_back(MergeSubAddArithmetic());
  2609. rules_[spv::Op::OpISub].push_back(MergeSubSubArithmetic());
  2610. rules_[spv::Op::OpPhi].push_back(RedundantPhi());
  2611. rules_[spv::Op::OpSNegate].push_back(MergeNegateArithmetic());
  2612. rules_[spv::Op::OpSNegate].push_back(MergeNegateMulDivArithmetic());
  2613. rules_[spv::Op::OpSNegate].push_back(MergeNegateAddSubArithmetic());
  2614. rules_[spv::Op::OpSelect].push_back(RedundantSelect());
  2615. rules_[spv::Op::OpStore].push_back(StoringUndef());
  2616. rules_[spv::Op::OpVectorShuffle].push_back(VectorShuffleFeedingShuffle());
  2617. rules_[spv::Op::OpImageSampleImplicitLod].push_back(UpdateImageOperands());
  2618. rules_[spv::Op::OpImageSampleExplicitLod].push_back(UpdateImageOperands());
  2619. rules_[spv::Op::OpImageSampleDrefImplicitLod].push_back(
  2620. UpdateImageOperands());
  2621. rules_[spv::Op::OpImageSampleDrefExplicitLod].push_back(
  2622. UpdateImageOperands());
  2623. rules_[spv::Op::OpImageSampleProjImplicitLod].push_back(
  2624. UpdateImageOperands());
  2625. rules_[spv::Op::OpImageSampleProjExplicitLod].push_back(
  2626. UpdateImageOperands());
  2627. rules_[spv::Op::OpImageSampleProjDrefImplicitLod].push_back(
  2628. UpdateImageOperands());
  2629. rules_[spv::Op::OpImageSampleProjDrefExplicitLod].push_back(
  2630. UpdateImageOperands());
  2631. rules_[spv::Op::OpImageFetch].push_back(UpdateImageOperands());
  2632. rules_[spv::Op::OpImageGather].push_back(UpdateImageOperands());
  2633. rules_[spv::Op::OpImageDrefGather].push_back(UpdateImageOperands());
  2634. rules_[spv::Op::OpImageRead].push_back(UpdateImageOperands());
  2635. rules_[spv::Op::OpImageWrite].push_back(UpdateImageOperands());
  2636. rules_[spv::Op::OpImageSparseSampleImplicitLod].push_back(
  2637. UpdateImageOperands());
  2638. rules_[spv::Op::OpImageSparseSampleExplicitLod].push_back(
  2639. UpdateImageOperands());
  2640. rules_[spv::Op::OpImageSparseSampleDrefImplicitLod].push_back(
  2641. UpdateImageOperands());
  2642. rules_[spv::Op::OpImageSparseSampleDrefExplicitLod].push_back(
  2643. UpdateImageOperands());
  2644. rules_[spv::Op::OpImageSparseSampleProjImplicitLod].push_back(
  2645. UpdateImageOperands());
  2646. rules_[spv::Op::OpImageSparseSampleProjExplicitLod].push_back(
  2647. UpdateImageOperands());
  2648. rules_[spv::Op::OpImageSparseSampleProjDrefImplicitLod].push_back(
  2649. UpdateImageOperands());
  2650. rules_[spv::Op::OpImageSparseSampleProjDrefExplicitLod].push_back(
  2651. UpdateImageOperands());
  2652. rules_[spv::Op::OpImageSparseFetch].push_back(UpdateImageOperands());
  2653. rules_[spv::Op::OpImageSparseGather].push_back(UpdateImageOperands());
  2654. rules_[spv::Op::OpImageSparseDrefGather].push_back(UpdateImageOperands());
  2655. rules_[spv::Op::OpImageSparseRead].push_back(UpdateImageOperands());
  2656. FeatureManager* feature_manager = context_->get_feature_mgr();
  2657. // Add rules for GLSLstd450
  2658. uint32_t ext_inst_glslstd450_id =
  2659. feature_manager->GetExtInstImportId_GLSLstd450();
  2660. if (ext_inst_glslstd450_id != 0) {
  2661. ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(
  2662. RedundantFMix());
  2663. }
  2664. }
  2665. } // namespace opt
  2666. } // namespace spvtools