folding_rules.cpp 88 KB


  1. // Copyright (c) 2018 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  #include "source/opt/folding_rules.h"

  #include <cassert>
  #include <cmath>
  #include <cstdint>
  #include <limits>
  #include <memory>
  #include <utility>
  #include <vector>

  #include "ir_builder.h"
  #include "source/latest_version_glsl_std_450_header.h"
  #include "source/opt/ir_context.h"
  21. namespace spvtools {
  22. namespace opt {
  23. namespace {
  // In-operand indices used by the folding rules in this file.
  // OpCompositeExtract: the composite being read from.
  constexpr uint32_t kExtractCompositeIdInIdx = 0;
  // OpCompositeInsert: the object being inserted.
  constexpr uint32_t kInsertObjectIdInIdx = 0;
  // OpCompositeInsert: the composite being inserted into.
  constexpr uint32_t kInsertCompositeIdInIdx = 1;
  // OpExtInst: the extended instruction set id.
  constexpr uint32_t kExtInstSetIdInIdx = 0;
  // OpExtInst: the instruction number within the extended set.
  constexpr uint32_t kExtInstInstructionInIdx = 1;
  // OpExtInst FMix(x, y, a): positions of the x, y and a arguments.
  constexpr uint32_t kFMixXIdInIdx = 2;
  constexpr uint32_t kFMixYIdInIdx = 3;
  constexpr uint32_t kFMixAIdInIdx = 4;
  // OpStore: the object being stored.
  constexpr uint32_t kStoreObjectInIdx = 1;
  33. // Returns the element width of |type|.
  34. uint32_t ElementWidth(const analysis::Type* type) {
  35. if (const analysis::Vector* vec_type = type->AsVector()) {
  36. return ElementWidth(vec_type->element_type());
  37. } else if (const analysis::Float* float_type = type->AsFloat()) {
  38. return float_type->width();
  39. } else {
  40. assert(type->AsInteger());
  41. return type->AsInteger()->width();
  42. }
  43. }
  44. // Returns true if |type| is Float or a vector of Float.
  45. bool HasFloatingPoint(const analysis::Type* type) {
  46. if (type->AsFloat()) {
  47. return true;
  48. } else if (const analysis::Vector* vec_type = type->AsVector()) {
  49. return vec_type->element_type()->AsFloat() != nullptr;
  50. }
  51. return false;
  52. }
  // Returns true when |val| is safe to materialize as a folded constant, i.e. it
  // is zero or a normal number. NaN, infinities and subnormals return false.
  template <typename T>
  bool IsValidResult(T val) {
    const int category = std::fpclassify(val);
    return category != FP_NAN && category != FP_INFINITE &&
           category != FP_SUBNORMAL;
  }
  66. const analysis::Constant* ConstInput(
  67. const std::vector<const analysis::Constant*>& constants) {
  68. return constants[0] ? constants[0] : constants[1];
  69. }
  70. Instruction* NonConstInput(IRContext* context, const analysis::Constant* c,
  71. Instruction* inst) {
  72. uint32_t in_op = c ? 1u : 0u;
  73. return context->get_def_use_mgr()->GetDef(
  74. inst->GetSingleWordInOperand(in_op));
  75. }
  76. // Returns the negation of |c|. |c| must be a 32 or 64 bit floating point
  77. // constant.
  78. uint32_t NegateFloatingPointConstant(analysis::ConstantManager* const_mgr,
  79. const analysis::Constant* c) {
  80. assert(c);
  81. assert(c->type()->AsFloat());
  82. uint32_t width = c->type()->AsFloat()->width();
  83. assert(width == 32 || width == 64);
  84. std::vector<uint32_t> words;
  85. if (width == 64) {
  86. utils::FloatProxy<double> result(c->GetDouble() * -1.0);
  87. words = result.GetWords();
  88. } else {
  89. utils::FloatProxy<float> result(c->GetFloat() * -1.0f);
  90. words = result.GetWords();
  91. }
  92. const analysis::Constant* negated_const =
  93. const_mgr->GetConstant(c->type(), std::move(words));
  94. return const_mgr->GetDefiningInstruction(negated_const)->result_id();
  95. }
  // Splits the 64-bit value |val| into two 32-bit words, least significant
  // word first (the word order used for 64-bit SPIR-V literals).
  std::vector<uint32_t> ExtractInts(uint64_t val) {
    const uint32_t low_word = static_cast<uint32_t>(val);
    const uint32_t high_word = static_cast<uint32_t>(val >> 32);
    return std::vector<uint32_t>{low_word, high_word};
  }
  102. // Negates the integer constant |c|. Returns the id of the defining instruction.
  103. uint32_t NegateIntegerConstant(analysis::ConstantManager* const_mgr,
  104. const analysis::Constant* c) {
  105. assert(c);
  106. assert(c->type()->AsInteger());
  107. uint32_t width = c->type()->AsInteger()->width();
  108. assert(width == 32 || width == 64);
  109. std::vector<uint32_t> words;
  110. if (width == 64) {
  111. uint64_t uval = static_cast<uint64_t>(0 - c->GetU64());
  112. words = ExtractInts(uval);
  113. } else {
  114. words.push_back(static_cast<uint32_t>(0 - c->GetU32()));
  115. }
  116. const analysis::Constant* negated_const =
  117. const_mgr->GetConstant(c->type(), std::move(words));
  118. return const_mgr->GetDefiningInstruction(negated_const)->result_id();
  119. }
  120. // Negates the vector constant |c|. Returns the id of the defining instruction.
  121. uint32_t NegateVectorConstant(analysis::ConstantManager* const_mgr,
  122. const analysis::Constant* c) {
  123. assert(const_mgr && c);
  124. assert(c->type()->AsVector());
  125. if (c->AsNullConstant()) {
  126. // 0.0 vs -0.0 shouldn't matter.
  127. return const_mgr->GetDefiningInstruction(c)->result_id();
  128. } else {
  129. const analysis::Type* component_type =
  130. c->AsVectorConstant()->component_type();
  131. std::vector<uint32_t> words;
  132. for (auto& comp : c->AsVectorConstant()->GetComponents()) {
  133. if (component_type->AsFloat()) {
  134. words.push_back(NegateFloatingPointConstant(const_mgr, comp));
  135. } else {
  136. assert(component_type->AsInteger());
  137. words.push_back(NegateIntegerConstant(const_mgr, comp));
  138. }
  139. }
  140. const analysis::Constant* negated_const =
  141. const_mgr->GetConstant(c->type(), std::move(words));
  142. return const_mgr->GetDefiningInstruction(negated_const)->result_id();
  143. }
  144. }
  145. // Negates |c|. Returns the id of the defining instruction.
  146. uint32_t NegateConstant(analysis::ConstantManager* const_mgr,
  147. const analysis::Constant* c) {
  148. if (c->type()->AsVector()) {
  149. return NegateVectorConstant(const_mgr, c);
  150. } else if (c->type()->AsFloat()) {
  151. return NegateFloatingPointConstant(const_mgr, c);
  152. } else {
  153. assert(c->type()->AsInteger());
  154. return NegateIntegerConstant(const_mgr, c);
  155. }
  156. }
  157. // Takes the reciprocal of |c|. |c|'s type must be Float or a vector of Float.
  158. // Returns 0 if the reciprocal is NaN, infinite or subnormal.
  159. uint32_t Reciprocal(analysis::ConstantManager* const_mgr,
  160. const analysis::Constant* c) {
  161. assert(const_mgr && c);
  162. assert(c->type()->AsFloat());
  163. uint32_t width = c->type()->AsFloat()->width();
  164. assert(width == 32 || width == 64);
  165. std::vector<uint32_t> words;
  166. if (width == 64) {
  167. spvtools::utils::FloatProxy<double> result(1.0 / c->GetDouble());
  168. if (!IsValidResult(result.getAsFloat())) return 0;
  169. words = result.GetWords();
  170. } else {
  171. spvtools::utils::FloatProxy<float> result(1.0f / c->GetFloat());
  172. if (!IsValidResult(result.getAsFloat())) return 0;
  173. words = result.GetWords();
  174. }
  175. const analysis::Constant* negated_const =
  176. const_mgr->GetConstant(c->type(), std::move(words));
  177. return const_mgr->GetDefiningInstruction(negated_const)->result_id();
  178. }
  179. // Replaces fdiv where second operand is constant with fmul.
  180. FoldingRule ReciprocalFDiv() {
  181. return [](IRContext* context, Instruction* inst,
  182. const std::vector<const analysis::Constant*>& constants) {
  183. assert(inst->opcode() == SpvOpFDiv);
  184. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  185. const analysis::Type* type =
  186. context->get_type_mgr()->GetType(inst->type_id());
  187. if (!inst->IsFloatingPointFoldingAllowed()) return false;
  188. uint32_t width = ElementWidth(type);
  189. if (width != 32 && width != 64) return false;
  190. if (constants[1] != nullptr) {
  191. uint32_t id = 0;
  192. if (const analysis::VectorConstant* vector_const =
  193. constants[1]->AsVectorConstant()) {
  194. std::vector<uint32_t> neg_ids;
  195. for (auto& comp : vector_const->GetComponents()) {
  196. id = Reciprocal(const_mgr, comp);
  197. if (id == 0) return false;
  198. neg_ids.push_back(id);
  199. }
  200. const analysis::Constant* negated_const =
  201. const_mgr->GetConstant(constants[1]->type(), std::move(neg_ids));
  202. id = const_mgr->GetDefiningInstruction(negated_const)->result_id();
  203. } else if (constants[1]->AsFloatConstant()) {
  204. id = Reciprocal(const_mgr, constants[1]);
  205. if (id == 0) return false;
  206. } else {
  207. // Don't fold a null constant.
  208. return false;
  209. }
  210. inst->SetOpcode(SpvOpFMul);
  211. inst->SetInOperands(
  212. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0u)}},
  213. {SPV_OPERAND_TYPE_ID, {id}}});
  214. return true;
  215. }
  216. return false;
  217. };
  218. }
  219. // Elides consecutive negate instructions.
  220. FoldingRule MergeNegateArithmetic() {
  221. return [](IRContext* context, Instruction* inst,
  222. const std::vector<const analysis::Constant*>& constants) {
  223. assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
  224. (void)constants;
  225. const analysis::Type* type =
  226. context->get_type_mgr()->GetType(inst->type_id());
  227. if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
  228. return false;
  229. Instruction* op_inst =
  230. context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
  231. if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
  232. return false;
  233. if (op_inst->opcode() == inst->opcode()) {
  234. // Elide negates.
  235. inst->SetOpcode(SpvOpCopyObject);
  236. inst->SetInOperands(
  237. {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0u)}}});
  238. return true;
  239. }
  240. return false;
  241. };
  242. }
  243. // Merges negate into a mul or div operation if that operation contains a
  244. // constant operand.
  245. // Cases:
  246. // -(x * 2) = x * -2
  247. // -(2 * x) = x * -2
  248. // -(x / 2) = x / -2
  249. // -(2 / x) = -2 / x
  250. FoldingRule MergeNegateMulDivArithmetic() {
  251. return [](IRContext* context, Instruction* inst,
  252. const std::vector<const analysis::Constant*>& constants) {
  253. assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
  254. (void)constants;
  255. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  256. const analysis::Type* type =
  257. context->get_type_mgr()->GetType(inst->type_id());
  258. if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
  259. return false;
  260. Instruction* op_inst =
  261. context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
  262. if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
  263. return false;
  264. uint32_t width = ElementWidth(type);
  265. if (width != 32 && width != 64) return false;
  266. SpvOp opcode = op_inst->opcode();
  267. if (opcode == SpvOpFMul || opcode == SpvOpFDiv || opcode == SpvOpIMul ||
  268. opcode == SpvOpSDiv || opcode == SpvOpUDiv) {
  269. std::vector<const analysis::Constant*> op_constants =
  270. const_mgr->GetOperandConstants(op_inst);
  271. // Merge negate into mul or div if one operand is constant.
  272. if (op_constants[0] || op_constants[1]) {
  273. bool zero_is_variable = op_constants[0] == nullptr;
  274. const analysis::Constant* c = ConstInput(op_constants);
  275. uint32_t neg_id = NegateConstant(const_mgr, c);
  276. uint32_t non_const_id = zero_is_variable
  277. ? op_inst->GetSingleWordInOperand(0u)
  278. : op_inst->GetSingleWordInOperand(1u);
  279. // Change this instruction to a mul/div.
  280. inst->SetOpcode(op_inst->opcode());
  281. if (opcode == SpvOpFDiv || opcode == SpvOpUDiv || opcode == SpvOpSDiv) {
  282. uint32_t op0 = zero_is_variable ? non_const_id : neg_id;
  283. uint32_t op1 = zero_is_variable ? neg_id : non_const_id;
  284. inst->SetInOperands(
  285. {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
  286. } else {
  287. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
  288. {SPV_OPERAND_TYPE_ID, {neg_id}}});
  289. }
  290. return true;
  291. }
  292. }
  293. return false;
  294. };
  295. }
  296. // Merges negate into a add or sub operation if that operation contains a
  297. // constant operand.
  298. // Cases:
  299. // -(x + 2) = -2 - x
  300. // -(2 + x) = -2 - x
  301. // -(x - 2) = 2 - x
  302. // -(2 - x) = x - 2
  303. FoldingRule MergeNegateAddSubArithmetic() {
  304. return [](IRContext* context, Instruction* inst,
  305. const std::vector<const analysis::Constant*>& constants) {
  306. assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
  307. (void)constants;
  308. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  309. const analysis::Type* type =
  310. context->get_type_mgr()->GetType(inst->type_id());
  311. if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
  312. return false;
  313. Instruction* op_inst =
  314. context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
  315. if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
  316. return false;
  317. uint32_t width = ElementWidth(type);
  318. if (width != 32 && width != 64) return false;
  319. if (op_inst->opcode() == SpvOpFAdd || op_inst->opcode() == SpvOpFSub ||
  320. op_inst->opcode() == SpvOpIAdd || op_inst->opcode() == SpvOpISub) {
  321. std::vector<const analysis::Constant*> op_constants =
  322. const_mgr->GetOperandConstants(op_inst);
  323. if (op_constants[0] || op_constants[1]) {
  324. bool zero_is_variable = op_constants[0] == nullptr;
  325. bool is_add = (op_inst->opcode() == SpvOpFAdd) ||
  326. (op_inst->opcode() == SpvOpIAdd);
  327. bool swap_operands = !is_add || zero_is_variable;
  328. bool negate_const = is_add;
  329. const analysis::Constant* c = ConstInput(op_constants);
  330. uint32_t const_id = 0;
  331. if (negate_const) {
  332. const_id = NegateConstant(const_mgr, c);
  333. } else {
  334. const_id = zero_is_variable ? op_inst->GetSingleWordInOperand(1u)
  335. : op_inst->GetSingleWordInOperand(0u);
  336. }
  337. // Swap operands if necessary and make the instruction a subtraction.
  338. uint32_t op0 =
  339. zero_is_variable ? op_inst->GetSingleWordInOperand(0u) : const_id;
  340. uint32_t op1 =
  341. zero_is_variable ? const_id : op_inst->GetSingleWordInOperand(1u);
  342. if (swap_operands) std::swap(op0, op1);
  343. inst->SetOpcode(HasFloatingPoint(type) ? SpvOpFSub : SpvOpISub);
  344. inst->SetInOperands(
  345. {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
  346. return true;
  347. }
  348. }
  349. return false;
  350. };
  351. }
  352. // Returns true if |c| has a zero element.
  353. bool HasZero(const analysis::Constant* c) {
  354. if (c->AsNullConstant()) {
  355. return true;
  356. }
  357. if (const analysis::VectorConstant* vec_const = c->AsVectorConstant()) {
  358. for (auto& comp : vec_const->GetComponents())
  359. if (HasZero(comp)) return true;
  360. } else {
  361. assert(c->AsScalarConstant());
  362. return c->AsScalarConstant()->IsZero();
  363. }
  364. return false;
  365. }
  366. // Performs |input1| |opcode| |input2| and returns the merged constant result
  367. // id. Returns 0 if the result is not a valid value. The input types must be
  368. // Float.
  369. uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr,
  370. SpvOp opcode,
  371. const analysis::Constant* input1,
  372. const analysis::Constant* input2) {
  373. const analysis::Type* type = input1->type();
  374. assert(type->AsFloat());
  375. uint32_t width = type->AsFloat()->width();
  376. assert(width == 32 || width == 64);
  377. std::vector<uint32_t> words;
  378. #define FOLD_OP(op) \
  379. if (width == 64) { \
  380. utils::FloatProxy<double> val = \
  381. input1->GetDouble() op input2->GetDouble(); \
  382. double dval = val.getAsFloat(); \
  383. if (!IsValidResult(dval)) return 0; \
  384. words = val.GetWords(); \
  385. } else { \
  386. utils::FloatProxy<float> val = input1->GetFloat() op input2->GetFloat(); \
  387. float fval = val.getAsFloat(); \
  388. if (!IsValidResult(fval)) return 0; \
  389. words = val.GetWords(); \
  390. }
  391. switch (opcode) {
  392. case SpvOpFMul:
  393. FOLD_OP(*);
  394. break;
  395. case SpvOpFDiv:
  396. if (HasZero(input2)) return 0;
  397. FOLD_OP(/);
  398. break;
  399. case SpvOpFAdd:
  400. FOLD_OP(+);
  401. break;
  402. case SpvOpFSub:
  403. FOLD_OP(-);
  404. break;
  405. default:
  406. assert(false && "Unexpected operation");
  407. break;
  408. }
  409. #undef FOLD_OP
  410. const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
  411. return const_mgr->GetDefiningInstruction(merged_const)->result_id();
  412. }
  413. // Performs |input1| |opcode| |input2| and returns the merged constant result
  414. // id. Returns 0 if the result is not a valid value. The input types must be
  415. // Integers.
  416. uint32_t PerformIntegerOperation(analysis::ConstantManager* const_mgr,
  417. SpvOp opcode, const analysis::Constant* input1,
  418. const analysis::Constant* input2) {
  419. assert(input1->type()->AsInteger());
  420. const analysis::Integer* type = input1->type()->AsInteger();
  421. uint32_t width = type->AsInteger()->width();
  422. assert(width == 32 || width == 64);
  423. std::vector<uint32_t> words;
  424. #define FOLD_OP(op) \
  425. if (width == 64) { \
  426. if (type->IsSigned()) { \
  427. int64_t val = input1->GetS64() op input2->GetS64(); \
  428. words = ExtractInts(static_cast<uint64_t>(val)); \
  429. } else { \
  430. uint64_t val = input1->GetU64() op input2->GetU64(); \
  431. words = ExtractInts(val); \
  432. } \
  433. } else { \
  434. if (type->IsSigned()) { \
  435. int32_t val = input1->GetS32() op input2->GetS32(); \
  436. words.push_back(static_cast<uint32_t>(val)); \
  437. } else { \
  438. uint32_t val = input1->GetU32() op input2->GetU32(); \
  439. words.push_back(val); \
  440. } \
  441. }
  442. switch (opcode) {
  443. case SpvOpIMul:
  444. FOLD_OP(*);
  445. break;
  446. case SpvOpSDiv:
  447. case SpvOpUDiv:
  448. assert(false && "Should not merge integer division");
  449. break;
  450. case SpvOpIAdd:
  451. FOLD_OP(+);
  452. break;
  453. case SpvOpISub:
  454. FOLD_OP(-);
  455. break;
  456. default:
  457. assert(false && "Unexpected operation");
  458. break;
  459. }
  460. #undef FOLD_OP
  461. const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
  462. return const_mgr->GetDefiningInstruction(merged_const)->result_id();
  463. }
  464. // Performs |input1| |opcode| |input2| and returns the merged constant result
  465. // id. Returns 0 if the result is not a valid value. The input types must be
  466. // Integers, Floats or Vectors of such.
  467. uint32_t PerformOperation(analysis::ConstantManager* const_mgr, SpvOp opcode,
  468. const analysis::Constant* input1,
  469. const analysis::Constant* input2) {
  470. assert(input1 && input2);
  471. assert(input1->type() == input2->type());
  472. const analysis::Type* type = input1->type();
  473. std::vector<uint32_t> words;
  474. if (const analysis::Vector* vector_type = type->AsVector()) {
  475. const analysis::Type* ele_type = vector_type->element_type();
  476. for (uint32_t i = 0; i != vector_type->element_count(); ++i) {
  477. uint32_t id = 0;
  478. const analysis::Constant* input1_comp = nullptr;
  479. if (const analysis::VectorConstant* input1_vector =
  480. input1->AsVectorConstant()) {
  481. input1_comp = input1_vector->GetComponents()[i];
  482. } else {
  483. assert(input1->AsNullConstant());
  484. input1_comp = const_mgr->GetConstant(ele_type, {});
  485. }
  486. const analysis::Constant* input2_comp = nullptr;
  487. if (const analysis::VectorConstant* input2_vector =
  488. input2->AsVectorConstant()) {
  489. input2_comp = input2_vector->GetComponents()[i];
  490. } else {
  491. assert(input2->AsNullConstant());
  492. input2_comp = const_mgr->GetConstant(ele_type, {});
  493. }
  494. if (ele_type->AsFloat()) {
  495. id = PerformFloatingPointOperation(const_mgr, opcode, input1_comp,
  496. input2_comp);
  497. } else {
  498. assert(ele_type->AsInteger());
  499. id = PerformIntegerOperation(const_mgr, opcode, input1_comp,
  500. input2_comp);
  501. }
  502. if (id == 0) return 0;
  503. words.push_back(id);
  504. }
  505. const analysis::Constant* merged_const =
  506. const_mgr->GetConstant(type, words);
  507. return const_mgr->GetDefiningInstruction(merged_const)->result_id();
  508. } else if (type->AsFloat()) {
  509. return PerformFloatingPointOperation(const_mgr, opcode, input1, input2);
  510. } else {
  511. assert(type->AsInteger());
  512. return PerformIntegerOperation(const_mgr, opcode, input1, input2);
  513. }
  514. }
  515. // Merges consecutive multiplies where each contains one constant operand.
  516. // Cases:
  517. // 2 * (x * 2) = x * 4
  518. // 2 * (2 * x) = x * 4
  519. // (x * 2) * 2 = x * 4
  520. // (2 * x) * 2 = x * 4
  521. FoldingRule MergeMulMulArithmetic() {
  522. return [](IRContext* context, Instruction* inst,
  523. const std::vector<const analysis::Constant*>& constants) {
  524. assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
  525. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  526. const analysis::Type* type =
  527. context->get_type_mgr()->GetType(inst->type_id());
  528. if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
  529. return false;
  530. uint32_t width = ElementWidth(type);
  531. if (width != 32 && width != 64) return false;
  532. // Determine the constant input and the variable input in |inst|.
  533. const analysis::Constant* const_input1 = ConstInput(constants);
  534. if (!const_input1) return false;
  535. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  536. if (HasFloatingPoint(type) && !other_inst->IsFloatingPointFoldingAllowed())
  537. return false;
  538. if (other_inst->opcode() == inst->opcode()) {
  539. std::vector<const analysis::Constant*> other_constants =
  540. const_mgr->GetOperandConstants(other_inst);
  541. const analysis::Constant* const_input2 = ConstInput(other_constants);
  542. if (!const_input2) return false;
  543. bool other_first_is_variable = other_constants[0] == nullptr;
  544. uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
  545. const_input1, const_input2);
  546. if (merged_id == 0) return false;
  547. uint32_t non_const_id = other_first_is_variable
  548. ? other_inst->GetSingleWordInOperand(0u)
  549. : other_inst->GetSingleWordInOperand(1u);
  550. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
  551. {SPV_OPERAND_TYPE_ID, {merged_id}}});
  552. return true;
  553. }
  554. return false;
  555. };
  556. }
  557. // Merges divides into subsequent multiplies if each instruction contains one
  558. // constant operand. Does not support integer operations.
  559. // Cases:
  560. // 2 * (x / 2) = x * 1
  561. // 2 * (2 / x) = 4 / x
  562. // (x / 2) * 2 = x * 1
  563. // (2 / x) * 2 = 4 / x
  564. // (y / x) * x = y
  565. // x * (y / x) = y
  566. FoldingRule MergeMulDivArithmetic() {
  567. return [](IRContext* context, Instruction* inst,
  568. const std::vector<const analysis::Constant*>& constants) {
  569. assert(inst->opcode() == SpvOpFMul);
  570. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  571. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  572. const analysis::Type* type =
  573. context->get_type_mgr()->GetType(inst->type_id());
  574. if (!inst->IsFloatingPointFoldingAllowed()) return false;
  575. uint32_t width = ElementWidth(type);
  576. if (width != 32 && width != 64) return false;
  577. for (uint32_t i = 0; i < 2; i++) {
  578. uint32_t op_id = inst->GetSingleWordInOperand(i);
  579. Instruction* op_inst = def_use_mgr->GetDef(op_id);
  580. if (op_inst->opcode() == SpvOpFDiv) {
  581. if (op_inst->GetSingleWordInOperand(1) ==
  582. inst->GetSingleWordInOperand(1 - i)) {
  583. inst->SetOpcode(SpvOpCopyObject);
  584. inst->SetInOperands(
  585. {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0)}}});
  586. return true;
  587. }
  588. }
  589. }
  590. const analysis::Constant* const_input1 = ConstInput(constants);
  591. if (!const_input1) return false;
  592. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  593. if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
  594. if (other_inst->opcode() == SpvOpFDiv) {
  595. std::vector<const analysis::Constant*> other_constants =
  596. const_mgr->GetOperandConstants(other_inst);
  597. const analysis::Constant* const_input2 = ConstInput(other_constants);
  598. if (!const_input2 || HasZero(const_input2)) return false;
  599. bool other_first_is_variable = other_constants[0] == nullptr;
  600. // If the variable value is the second operand of the divide, multiply
  601. // the constants together. Otherwise divide the constants.
  602. uint32_t merged_id = PerformOperation(
  603. const_mgr,
  604. other_first_is_variable ? other_inst->opcode() : inst->opcode(),
  605. const_input1, const_input2);
  606. if (merged_id == 0) return false;
  607. uint32_t non_const_id = other_first_is_variable
  608. ? other_inst->GetSingleWordInOperand(0u)
  609. : other_inst->GetSingleWordInOperand(1u);
  610. // If the variable value is on the second operand of the div, then this
  611. // operation is a div. Otherwise it should be a multiply.
  612. inst->SetOpcode(other_first_is_variable ? inst->opcode()
  613. : other_inst->opcode());
  614. if (other_first_is_variable) {
  615. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
  616. {SPV_OPERAND_TYPE_ID, {merged_id}}});
  617. } else {
  618. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {merged_id}},
  619. {SPV_OPERAND_TYPE_ID, {non_const_id}}});
  620. }
  621. return true;
  622. }
  623. return false;
  624. };
  625. }
  626. // Merges multiply of constant and negation.
  627. // Cases:
  628. // (-x) * 2 = x * -2
  629. // 2 * (-x) = x * -2
  630. FoldingRule MergeMulNegateArithmetic() {
  631. return [](IRContext* context, Instruction* inst,
  632. const std::vector<const analysis::Constant*>& constants) {
  633. assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
  634. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  635. const analysis::Type* type =
  636. context->get_type_mgr()->GetType(inst->type_id());
  637. bool uses_float = HasFloatingPoint(type);
  638. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  639. uint32_t width = ElementWidth(type);
  640. if (width != 32 && width != 64) return false;
  641. const analysis::Constant* const_input1 = ConstInput(constants);
  642. if (!const_input1) return false;
  643. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  644. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  645. return false;
  646. if (other_inst->opcode() == SpvOpFNegate ||
  647. other_inst->opcode() == SpvOpSNegate) {
  648. uint32_t neg_id = NegateConstant(const_mgr, const_input1);
  649. inst->SetInOperands(
  650. {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
  651. {SPV_OPERAND_TYPE_ID, {neg_id}}});
  652. return true;
  653. }
  654. return false;
  655. };
  656. }
  657. // Merges consecutive divides if each instruction contains one constant operand.
  658. // Does not support integer division.
  659. // Cases:
  660. // 2 / (x / 2) = 4 / x
  661. // 4 / (2 / x) = 2 * x
  662. // (4 / x) / 2 = 2 / x
  663. // (x / 2) / 2 = x / 4
  664. FoldingRule MergeDivDivArithmetic() {
  665. return [](IRContext* context, Instruction* inst,
  666. const std::vector<const analysis::Constant*>& constants) {
  667. assert(inst->opcode() == SpvOpFDiv);
  668. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  669. const analysis::Type* type =
  670. context->get_type_mgr()->GetType(inst->type_id());
  671. if (!inst->IsFloatingPointFoldingAllowed()) return false;
  672. uint32_t width = ElementWidth(type);
  673. if (width != 32 && width != 64) return false;
  674. const analysis::Constant* const_input1 = ConstInput(constants);
  675. if (!const_input1 || HasZero(const_input1)) return false;
  676. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  677. if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
  678. bool first_is_variable = constants[0] == nullptr;
  679. if (other_inst->opcode() == inst->opcode()) {
  680. std::vector<const analysis::Constant*> other_constants =
  681. const_mgr->GetOperandConstants(other_inst);
  682. const analysis::Constant* const_input2 = ConstInput(other_constants);
  683. if (!const_input2 || HasZero(const_input2)) return false;
  684. bool other_first_is_variable = other_constants[0] == nullptr;
  685. SpvOp merge_op = inst->opcode();
  686. if (other_first_is_variable) {
  687. // Constants magnify.
  688. merge_op = SpvOpFMul;
  689. }
  690. // This is an x / (*) case. Swap the inputs. Doesn't harm multiply
  691. // because it is commutative.
  692. if (first_is_variable) std::swap(const_input1, const_input2);
  693. uint32_t merged_id =
  694. PerformOperation(const_mgr, merge_op, const_input1, const_input2);
  695. if (merged_id == 0) return false;
  696. uint32_t non_const_id = other_first_is_variable
  697. ? other_inst->GetSingleWordInOperand(0u)
  698. : other_inst->GetSingleWordInOperand(1u);
  699. SpvOp op = inst->opcode();
  700. if (!first_is_variable && !other_first_is_variable) {
  701. // Effectively div of 1/x, so change to multiply.
  702. op = SpvOpFMul;
  703. }
  704. uint32_t op1 = merged_id;
  705. uint32_t op2 = non_const_id;
  706. if (first_is_variable && other_first_is_variable) std::swap(op1, op2);
  707. inst->SetOpcode(op);
  708. inst->SetInOperands(
  709. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  710. return true;
  711. }
  712. return false;
  713. };
  714. }
  715. // Fold multiplies succeeded by divides where each instruction contains a
  716. // constant operand. Does not support integer divide.
  717. // Cases:
  718. // 4 / (x * 2) = 2 / x
  719. // 4 / (2 * x) = 2 / x
  720. // (x * 4) / 2 = x * 2
  721. // (4 * x) / 2 = x * 2
  722. // (x * y) / x = y
  723. // (y * x) / x = y
  724. FoldingRule MergeDivMulArithmetic() {
  725. return [](IRContext* context, Instruction* inst,
  726. const std::vector<const analysis::Constant*>& constants) {
  727. assert(inst->opcode() == SpvOpFDiv);
  728. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  729. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  730. const analysis::Type* type =
  731. context->get_type_mgr()->GetType(inst->type_id());
  732. if (!inst->IsFloatingPointFoldingAllowed()) return false;
  733. uint32_t width = ElementWidth(type);
  734. if (width != 32 && width != 64) return false;
  735. uint32_t op_id = inst->GetSingleWordInOperand(0);
  736. Instruction* op_inst = def_use_mgr->GetDef(op_id);
  737. if (op_inst->opcode() == SpvOpFMul) {
  738. for (uint32_t i = 0; i < 2; i++) {
  739. if (op_inst->GetSingleWordInOperand(i) ==
  740. inst->GetSingleWordInOperand(1)) {
  741. inst->SetOpcode(SpvOpCopyObject);
  742. inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
  743. {op_inst->GetSingleWordInOperand(1 - i)}}});
  744. return true;
  745. }
  746. }
  747. }
  748. const analysis::Constant* const_input1 = ConstInput(constants);
  749. if (!const_input1 || HasZero(const_input1)) return false;
  750. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  751. if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
  752. bool first_is_variable = constants[0] == nullptr;
  753. if (other_inst->opcode() == SpvOpFMul) {
  754. std::vector<const analysis::Constant*> other_constants =
  755. const_mgr->GetOperandConstants(other_inst);
  756. const analysis::Constant* const_input2 = ConstInput(other_constants);
  757. if (!const_input2) return false;
  758. bool other_first_is_variable = other_constants[0] == nullptr;
  759. // This is an x / (*) case. Swap the inputs.
  760. if (first_is_variable) std::swap(const_input1, const_input2);
  761. uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
  762. const_input1, const_input2);
  763. if (merged_id == 0) return false;
  764. uint32_t non_const_id = other_first_is_variable
  765. ? other_inst->GetSingleWordInOperand(0u)
  766. : other_inst->GetSingleWordInOperand(1u);
  767. uint32_t op1 = merged_id;
  768. uint32_t op2 = non_const_id;
  769. if (first_is_variable) std::swap(op1, op2);
  770. // Convert to multiply
  771. if (first_is_variable) inst->SetOpcode(other_inst->opcode());
  772. inst->SetInOperands(
  773. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  774. return true;
  775. }
  776. return false;
  777. };
  778. }
  779. // Fold divides of a constant and a negation.
  780. // Cases:
  781. // (-x) / 2 = x / -2
  782. // 2 / (-x) = 2 / -x
  783. FoldingRule MergeDivNegateArithmetic() {
  784. return [](IRContext* context, Instruction* inst,
  785. const std::vector<const analysis::Constant*>& constants) {
  786. assert(inst->opcode() == SpvOpFDiv || inst->opcode() == SpvOpSDiv ||
  787. inst->opcode() == SpvOpUDiv);
  788. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  789. const analysis::Type* type =
  790. context->get_type_mgr()->GetType(inst->type_id());
  791. bool uses_float = HasFloatingPoint(type);
  792. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  793. uint32_t width = ElementWidth(type);
  794. if (width != 32 && width != 64) return false;
  795. const analysis::Constant* const_input1 = ConstInput(constants);
  796. if (!const_input1) return false;
  797. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  798. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  799. return false;
  800. bool first_is_variable = constants[0] == nullptr;
  801. if (other_inst->opcode() == SpvOpFNegate ||
  802. other_inst->opcode() == SpvOpSNegate) {
  803. uint32_t neg_id = NegateConstant(const_mgr, const_input1);
  804. if (first_is_variable) {
  805. inst->SetInOperands(
  806. {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
  807. {SPV_OPERAND_TYPE_ID, {neg_id}}});
  808. } else {
  809. inst->SetInOperands(
  810. {{SPV_OPERAND_TYPE_ID, {neg_id}},
  811. {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
  812. }
  813. return true;
  814. }
  815. return false;
  816. };
  817. }
  818. // Folds addition of a constant and a negation.
  819. // Cases:
  820. // (-x) + 2 = 2 - x
  821. // 2 + (-x) = 2 - x
  822. FoldingRule MergeAddNegateArithmetic() {
  823. return [](IRContext* context, Instruction* inst,
  824. const std::vector<const analysis::Constant*>& constants) {
  825. assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
  826. const analysis::Type* type =
  827. context->get_type_mgr()->GetType(inst->type_id());
  828. bool uses_float = HasFloatingPoint(type);
  829. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  830. const analysis::Constant* const_input1 = ConstInput(constants);
  831. if (!const_input1) return false;
  832. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  833. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  834. return false;
  835. if (other_inst->opcode() == SpvOpSNegate ||
  836. other_inst->opcode() == SpvOpFNegate) {
  837. inst->SetOpcode(HasFloatingPoint(type) ? SpvOpFSub : SpvOpISub);
  838. uint32_t const_id = constants[0] ? inst->GetSingleWordInOperand(0u)
  839. : inst->GetSingleWordInOperand(1u);
  840. inst->SetInOperands(
  841. {{SPV_OPERAND_TYPE_ID, {const_id}},
  842. {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
  843. return true;
  844. }
  845. return false;
  846. };
  847. }
  848. // Folds subtraction of a constant and a negation.
  849. // Cases:
  850. // (-x) - 2 = -2 - x
  851. // 2 - (-x) = x + 2
  852. FoldingRule MergeSubNegateArithmetic() {
  853. return [](IRContext* context, Instruction* inst,
  854. const std::vector<const analysis::Constant*>& constants) {
  855. assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
  856. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  857. const analysis::Type* type =
  858. context->get_type_mgr()->GetType(inst->type_id());
  859. bool uses_float = HasFloatingPoint(type);
  860. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  861. uint32_t width = ElementWidth(type);
  862. if (width != 32 && width != 64) return false;
  863. const analysis::Constant* const_input1 = ConstInput(constants);
  864. if (!const_input1) return false;
  865. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  866. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  867. return false;
  868. if (other_inst->opcode() == SpvOpSNegate ||
  869. other_inst->opcode() == SpvOpFNegate) {
  870. uint32_t op1 = 0;
  871. uint32_t op2 = 0;
  872. SpvOp opcode = inst->opcode();
  873. if (constants[0] != nullptr) {
  874. op1 = other_inst->GetSingleWordInOperand(0u);
  875. op2 = inst->GetSingleWordInOperand(0u);
  876. opcode = HasFloatingPoint(type) ? SpvOpFAdd : SpvOpIAdd;
  877. } else {
  878. op1 = NegateConstant(const_mgr, const_input1);
  879. op2 = other_inst->GetSingleWordInOperand(0u);
  880. }
  881. inst->SetOpcode(opcode);
  882. inst->SetInOperands(
  883. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  884. return true;
  885. }
  886. return false;
  887. };
  888. }
  889. // Folds addition of an addition where each operation has a constant operand.
  890. // Cases:
  891. // (x + 2) + 2 = x + 4
  892. // (2 + x) + 2 = x + 4
  893. // 2 + (x + 2) = x + 4
  894. // 2 + (2 + x) = x + 4
  895. FoldingRule MergeAddAddArithmetic() {
  896. return [](IRContext* context, Instruction* inst,
  897. const std::vector<const analysis::Constant*>& constants) {
  898. assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
  899. const analysis::Type* type =
  900. context->get_type_mgr()->GetType(inst->type_id());
  901. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  902. bool uses_float = HasFloatingPoint(type);
  903. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  904. uint32_t width = ElementWidth(type);
  905. if (width != 32 && width != 64) return false;
  906. const analysis::Constant* const_input1 = ConstInput(constants);
  907. if (!const_input1) return false;
  908. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  909. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  910. return false;
  911. if (other_inst->opcode() == SpvOpFAdd ||
  912. other_inst->opcode() == SpvOpIAdd) {
  913. std::vector<const analysis::Constant*> other_constants =
  914. const_mgr->GetOperandConstants(other_inst);
  915. const analysis::Constant* const_input2 = ConstInput(other_constants);
  916. if (!const_input2) return false;
  917. Instruction* non_const_input =
  918. NonConstInput(context, other_constants[0], other_inst);
  919. uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
  920. const_input1, const_input2);
  921. if (merged_id == 0) return false;
  922. inst->SetInOperands(
  923. {{SPV_OPERAND_TYPE_ID, {non_const_input->result_id()}},
  924. {SPV_OPERAND_TYPE_ID, {merged_id}}});
  925. return true;
  926. }
  927. return false;
  928. };
  929. }
  930. // Folds addition of a subtraction where each operation has a constant operand.
  931. // Cases:
  932. // (x - 2) + 2 = x + 0
  933. // (2 - x) + 2 = 4 - x
  934. // 2 + (x - 2) = x + 0
  935. // 2 + (2 - x) = 4 - x
  936. FoldingRule MergeAddSubArithmetic() {
  937. return [](IRContext* context, Instruction* inst,
  938. const std::vector<const analysis::Constant*>& constants) {
  939. assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
  940. const analysis::Type* type =
  941. context->get_type_mgr()->GetType(inst->type_id());
  942. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  943. bool uses_float = HasFloatingPoint(type);
  944. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  945. uint32_t width = ElementWidth(type);
  946. if (width != 32 && width != 64) return false;
  947. const analysis::Constant* const_input1 = ConstInput(constants);
  948. if (!const_input1) return false;
  949. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  950. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  951. return false;
  952. if (other_inst->opcode() == SpvOpFSub ||
  953. other_inst->opcode() == SpvOpISub) {
  954. std::vector<const analysis::Constant*> other_constants =
  955. const_mgr->GetOperandConstants(other_inst);
  956. const analysis::Constant* const_input2 = ConstInput(other_constants);
  957. if (!const_input2) return false;
  958. bool first_is_variable = other_constants[0] == nullptr;
  959. SpvOp op = inst->opcode();
  960. uint32_t op1 = 0;
  961. uint32_t op2 = 0;
  962. if (first_is_variable) {
  963. // Subtract constants. Non-constant operand is first.
  964. op1 = other_inst->GetSingleWordInOperand(0u);
  965. op2 = PerformOperation(const_mgr, other_inst->opcode(), const_input1,
  966. const_input2);
  967. } else {
  968. // Add constants. Constant operand is first. Change the opcode.
  969. op1 = PerformOperation(const_mgr, inst->opcode(), const_input1,
  970. const_input2);
  971. op2 = other_inst->GetSingleWordInOperand(1u);
  972. op = other_inst->opcode();
  973. }
  974. if (op1 == 0 || op2 == 0) return false;
  975. inst->SetOpcode(op);
  976. inst->SetInOperands(
  977. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  978. return true;
  979. }
  980. return false;
  981. };
  982. }
  983. // Folds subtraction of an addition where each operand has a constant operand.
  984. // Cases:
  985. // (x + 2) - 2 = x + 0
  986. // (2 + x) - 2 = x + 0
  987. // 2 - (x + 2) = 0 - x
  988. // 2 - (2 + x) = 0 - x
  989. FoldingRule MergeSubAddArithmetic() {
  990. return [](IRContext* context, Instruction* inst,
  991. const std::vector<const analysis::Constant*>& constants) {
  992. assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
  993. const analysis::Type* type =
  994. context->get_type_mgr()->GetType(inst->type_id());
  995. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  996. bool uses_float = HasFloatingPoint(type);
  997. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  998. uint32_t width = ElementWidth(type);
  999. if (width != 32 && width != 64) return false;
  1000. const analysis::Constant* const_input1 = ConstInput(constants);
  1001. if (!const_input1) return false;
  1002. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  1003. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  1004. return false;
  1005. if (other_inst->opcode() == SpvOpFAdd ||
  1006. other_inst->opcode() == SpvOpIAdd) {
  1007. std::vector<const analysis::Constant*> other_constants =
  1008. const_mgr->GetOperandConstants(other_inst);
  1009. const analysis::Constant* const_input2 = ConstInput(other_constants);
  1010. if (!const_input2) return false;
  1011. Instruction* non_const_input =
  1012. NonConstInput(context, other_constants[0], other_inst);
  1013. // If the first operand of the sub is not a constant, swap the constants
  1014. // so the subtraction has the correct operands.
  1015. if (constants[0] == nullptr) std::swap(const_input1, const_input2);
  1016. // Subtract the constants.
  1017. uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
  1018. const_input1, const_input2);
  1019. SpvOp op = inst->opcode();
  1020. uint32_t op1 = 0;
  1021. uint32_t op2 = 0;
  1022. if (constants[0] == nullptr) {
  1023. // Non-constant operand is first. Change the opcode.
  1024. op1 = non_const_input->result_id();
  1025. op2 = merged_id;
  1026. op = other_inst->opcode();
  1027. } else {
  1028. // Constant operand is first.
  1029. op1 = merged_id;
  1030. op2 = non_const_input->result_id();
  1031. }
  1032. if (op1 == 0 || op2 == 0) return false;
  1033. inst->SetOpcode(op);
  1034. inst->SetInOperands(
  1035. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  1036. return true;
  1037. }
  1038. return false;
  1039. };
  1040. }
  1041. // Folds subtraction of a subtraction where each operand has a constant operand.
  1042. // Cases:
  1043. // (x - 2) - 2 = x - 4
  1044. // (2 - x) - 2 = 0 - x
  1045. // 2 - (x - 2) = 4 - x
  1046. // 2 - (2 - x) = x + 0
  1047. FoldingRule MergeSubSubArithmetic() {
  1048. return [](IRContext* context, Instruction* inst,
  1049. const std::vector<const analysis::Constant*>& constants) {
  1050. assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
  1051. const analysis::Type* type =
  1052. context->get_type_mgr()->GetType(inst->type_id());
  1053. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  1054. bool uses_float = HasFloatingPoint(type);
  1055. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  1056. uint32_t width = ElementWidth(type);
  1057. if (width != 32 && width != 64) return false;
  1058. const analysis::Constant* const_input1 = ConstInput(constants);
  1059. if (!const_input1) return false;
  1060. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  1061. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  1062. return false;
  1063. if (other_inst->opcode() == SpvOpFSub ||
  1064. other_inst->opcode() == SpvOpISub) {
  1065. std::vector<const analysis::Constant*> other_constants =
  1066. const_mgr->GetOperandConstants(other_inst);
  1067. const analysis::Constant* const_input2 = ConstInput(other_constants);
  1068. if (!const_input2) return false;
  1069. Instruction* non_const_input =
  1070. NonConstInput(context, other_constants[0], other_inst);
  1071. // Merge the constants.
  1072. uint32_t merged_id = 0;
  1073. SpvOp merge_op = inst->opcode();
  1074. if (other_constants[0] == nullptr) {
  1075. merge_op = uses_float ? SpvOpFAdd : SpvOpIAdd;
  1076. } else if (constants[0] == nullptr) {
  1077. std::swap(const_input1, const_input2);
  1078. }
  1079. merged_id =
  1080. PerformOperation(const_mgr, merge_op, const_input1, const_input2);
  1081. if (merged_id == 0) return false;
  1082. SpvOp op = inst->opcode();
  1083. if (constants[0] != nullptr && other_constants[0] != nullptr) {
  1084. // Change the operation.
  1085. op = uses_float ? SpvOpFAdd : SpvOpIAdd;
  1086. }
  1087. uint32_t op1 = 0;
  1088. uint32_t op2 = 0;
  1089. if ((constants[0] == nullptr) ^ (other_constants[0] == nullptr)) {
  1090. op1 = merged_id;
  1091. op2 = non_const_input->result_id();
  1092. } else {
  1093. op1 = non_const_input->result_id();
  1094. op2 = merged_id;
  1095. }
  1096. inst->SetOpcode(op);
  1097. inst->SetInOperands(
  1098. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  1099. return true;
  1100. }
  1101. return false;
  1102. };
  1103. }
  1104. // Helper function for MergeGenericAddSubArithmetic. If |addend| and
  1105. // subtrahend of |sub| is the same, merge to copy of minuend of |sub|.
  1106. bool MergeGenericAddendSub(uint32_t addend, uint32_t sub, Instruction* inst) {
  1107. IRContext* context = inst->context();
  1108. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1109. Instruction* sub_inst = def_use_mgr->GetDef(sub);
  1110. if (sub_inst->opcode() != SpvOpFSub && sub_inst->opcode() != SpvOpISub)
  1111. return false;
  1112. if (sub_inst->opcode() == SpvOpFSub &&
  1113. !sub_inst->IsFloatingPointFoldingAllowed())
  1114. return false;
  1115. if (addend != sub_inst->GetSingleWordInOperand(1)) return false;
  1116. inst->SetOpcode(SpvOpCopyObject);
  1117. inst->SetInOperands(
  1118. {{SPV_OPERAND_TYPE_ID, {sub_inst->GetSingleWordInOperand(0)}}});
  1119. context->UpdateDefUse(inst);
  1120. return true;
  1121. }
  1122. // Folds addition of a subtraction where the subtrahend is equal to the
  1123. // other addend. Return a copy of the minuend. Accepts generic (const and
  1124. // non-const) operands.
  1125. // Cases:
  1126. // (a - b) + b = a
  1127. // b + (a - b) = a
  1128. FoldingRule MergeGenericAddSubArithmetic() {
  1129. return [](IRContext* context, Instruction* inst,
  1130. const std::vector<const analysis::Constant*>&) {
  1131. assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
  1132. const analysis::Type* type =
  1133. context->get_type_mgr()->GetType(inst->type_id());
  1134. bool uses_float = HasFloatingPoint(type);
  1135. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  1136. uint32_t width = ElementWidth(type);
  1137. if (width != 32 && width != 64) return false;
  1138. uint32_t add_op0 = inst->GetSingleWordInOperand(0);
  1139. uint32_t add_op1 = inst->GetSingleWordInOperand(1);
  1140. if (MergeGenericAddendSub(add_op0, add_op1, inst)) return true;
  1141. return MergeGenericAddendSub(add_op1, add_op0, inst);
  1142. };
  1143. }
  1144. // Helper function for FactorAddMuls. If |factor0_0| is the same as |factor1_0|,
  1145. // generate |factor0_0| * (|factor0_1| + |factor1_1|).
  1146. bool FactorAddMulsOpnds(uint32_t factor0_0, uint32_t factor0_1,
  1147. uint32_t factor1_0, uint32_t factor1_1,
  1148. Instruction* inst) {
  1149. IRContext* context = inst->context();
  1150. if (factor0_0 != factor1_0) return false;
  1151. InstructionBuilder ir_builder(
  1152. context, inst,
  1153. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  1154. Instruction* new_add_inst = ir_builder.AddBinaryOp(
  1155. inst->type_id(), inst->opcode(), factor0_1, factor1_1);
  1156. inst->SetOpcode(inst->opcode() == SpvOpFAdd ? SpvOpFMul : SpvOpIMul);
  1157. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {factor0_0}},
  1158. {SPV_OPERAND_TYPE_ID, {new_add_inst->result_id()}}});
  1159. context->UpdateDefUse(inst);
  1160. return true;
  1161. }
  1162. // Perform the following factoring identity, handling all operand order
  1163. // combinations: (a * b) + (a * c) = a * (b + c)
  1164. FoldingRule FactorAddMuls() {
  1165. return [](IRContext* context, Instruction* inst,
  1166. const std::vector<const analysis::Constant*>&) {
  1167. assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
  1168. const analysis::Type* type =
  1169. context->get_type_mgr()->GetType(inst->type_id());
  1170. bool uses_float = HasFloatingPoint(type);
  1171. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  1172. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1173. uint32_t add_op0 = inst->GetSingleWordInOperand(0);
  1174. Instruction* add_op0_inst = def_use_mgr->GetDef(add_op0);
  1175. if (add_op0_inst->opcode() != SpvOpFMul &&
  1176. add_op0_inst->opcode() != SpvOpIMul)
  1177. return false;
  1178. uint32_t add_op1 = inst->GetSingleWordInOperand(1);
  1179. Instruction* add_op1_inst = def_use_mgr->GetDef(add_op1);
  1180. if (add_op1_inst->opcode() != SpvOpFMul &&
  1181. add_op1_inst->opcode() != SpvOpIMul)
  1182. return false;
  1183. // Only perform this optimization if both of the muls only have one use.
  1184. // Otherwise this is a deoptimization in size and performance.
  1185. if (def_use_mgr->NumUses(add_op0_inst) > 1) return false;
  1186. if (def_use_mgr->NumUses(add_op1_inst) > 1) return false;
  1187. if (add_op0_inst->opcode() == SpvOpFMul &&
  1188. (!add_op0_inst->IsFloatingPointFoldingAllowed() ||
  1189. !add_op1_inst->IsFloatingPointFoldingAllowed()))
  1190. return false;
  1191. for (int i = 0; i < 2; i++) {
  1192. for (int j = 0; j < 2; j++) {
  1193. // Check if operand i in add_op0_inst matches operand j in add_op1_inst.
  1194. if (FactorAddMulsOpnds(add_op0_inst->GetSingleWordInOperand(i),
  1195. add_op0_inst->GetSingleWordInOperand(1 - i),
  1196. add_op1_inst->GetSingleWordInOperand(j),
  1197. add_op1_inst->GetSingleWordInOperand(1 - j),
  1198. inst))
  1199. return true;
  1200. }
  1201. }
  1202. return false;
  1203. };
  1204. }
  1205. FoldingRule IntMultipleBy1() {
  1206. return [](IRContext*, Instruction* inst,
  1207. const std::vector<const analysis::Constant*>& constants) {
  1208. assert(inst->opcode() == SpvOpIMul && "Wrong opcode. Should be OpIMul.");
  1209. for (uint32_t i = 0; i < 2; i++) {
  1210. if (constants[i] == nullptr) {
  1211. continue;
  1212. }
  1213. const analysis::IntConstant* int_constant = constants[i]->AsIntConstant();
  1214. if (int_constant) {
  1215. uint32_t width = ElementWidth(int_constant->type());
  1216. if (width != 32 && width != 64) return false;
  1217. bool is_one = (width == 32) ? int_constant->GetU32BitValue() == 1u
  1218. : int_constant->GetU64BitValue() == 1ull;
  1219. if (is_one) {
  1220. inst->SetOpcode(SpvOpCopyObject);
  1221. inst->SetInOperands(
  1222. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1 - i)}}});
  1223. return true;
  1224. }
  1225. }
  1226. }
  1227. return false;
  1228. };
  1229. }
  1230. FoldingRule CompositeConstructFeedingExtract() {
  1231. return [](IRContext* context, Instruction* inst,
  1232. const std::vector<const analysis::Constant*>&) {
  1233. // If the input to an OpCompositeExtract is an OpCompositeConstruct,
  1234. // then we can simply use the appropriate element in the construction.
  1235. assert(inst->opcode() == SpvOpCompositeExtract &&
  1236. "Wrong opcode. Should be OpCompositeExtract.");
  1237. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1238. analysis::TypeManager* type_mgr = context->get_type_mgr();
  1239. // If there are no index operands, then this rule cannot do anything.
  1240. if (inst->NumInOperands() <= 1) {
  1241. return false;
  1242. }
  1243. uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1244. Instruction* cinst = def_use_mgr->GetDef(cid);
  1245. if (cinst->opcode() != SpvOpCompositeConstruct) {
  1246. return false;
  1247. }
  1248. std::vector<Operand> operands;
  1249. analysis::Type* composite_type = type_mgr->GetType(cinst->type_id());
  1250. if (composite_type->AsVector() == nullptr) {
  1251. // Get the element being extracted from the OpCompositeConstruct
  1252. // Since it is not a vector, it is simple to extract the single element.
  1253. uint32_t element_index = inst->GetSingleWordInOperand(1);
  1254. uint32_t element_id = cinst->GetSingleWordInOperand(element_index);
  1255. operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
  1256. // Add the remaining indices for extraction.
  1257. for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
  1258. operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
  1259. {inst->GetSingleWordInOperand(i)}});
  1260. }
  1261. } else {
  1262. // With vectors we have to handle the case where it is concatenating
  1263. // vectors.
  1264. assert(inst->NumInOperands() == 2 &&
  1265. "Expecting a vector of scalar values.");
  1266. uint32_t element_index = inst->GetSingleWordInOperand(1);
  1267. for (uint32_t construct_index = 0;
  1268. construct_index < cinst->NumInOperands(); ++construct_index) {
  1269. uint32_t element_id = cinst->GetSingleWordInOperand(construct_index);
  1270. Instruction* element_def = def_use_mgr->GetDef(element_id);
  1271. analysis::Vector* element_type =
  1272. type_mgr->GetType(element_def->type_id())->AsVector();
  1273. if (element_type) {
  1274. uint32_t vector_size = element_type->element_count();
  1275. if (vector_size < element_index) {
  1276. // The element we want comes after this vector.
  1277. element_index -= vector_size;
  1278. } else {
  1279. // We want an element of this vector.
  1280. operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
  1281. operands.push_back(
  1282. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {element_index}});
  1283. break;
  1284. }
  1285. } else {
  1286. if (element_index == 0) {
  1287. // This is a scalar, and we this is the element we are extracting.
  1288. operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
  1289. break;
  1290. } else {
  1291. // Skip over this scalar value.
  1292. --element_index;
  1293. }
  1294. }
  1295. }
  1296. }
  1297. // If there were no extra indices, then we have the final object. No need
  1298. // to extract even more.
  1299. if (operands.size() == 1) {
  1300. inst->SetOpcode(SpvOpCopyObject);
  1301. }
  1302. inst->SetInOperands(std::move(operands));
  1303. return true;
  1304. };
  1305. }
  1306. FoldingRule CompositeExtractFeedingConstruct() {
  1307. // If the OpCompositeConstruct is simply putting back together elements that
  1308. // where extracted from the same souce, we can simlpy reuse the source.
  1309. //
  1310. // This is a common code pattern because of the way that scalar replacement
  1311. // works.
  1312. return [](IRContext* context, Instruction* inst,
  1313. const std::vector<const analysis::Constant*>&) {
  1314. assert(inst->opcode() == SpvOpCompositeConstruct &&
  1315. "Wrong opcode. Should be OpCompositeConstruct.");
  1316. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1317. uint32_t original_id = 0;
  1318. // Check each element to make sure they are:
  1319. // - extractions
  1320. // - extracting the same position they are inserting
  1321. // - all extract from the same id.
  1322. for (uint32_t i = 0; i < inst->NumInOperands(); ++i) {
  1323. uint32_t element_id = inst->GetSingleWordInOperand(i);
  1324. Instruction* element_inst = def_use_mgr->GetDef(element_id);
  1325. if (element_inst->opcode() != SpvOpCompositeExtract) {
  1326. return false;
  1327. }
  1328. if (element_inst->NumInOperands() != 2) {
  1329. return false;
  1330. }
  1331. if (element_inst->GetSingleWordInOperand(1) != i) {
  1332. return false;
  1333. }
  1334. if (i == 0) {
  1335. original_id =
  1336. element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1337. } else if (original_id != element_inst->GetSingleWordInOperand(
  1338. kExtractCompositeIdInIdx)) {
  1339. return false;
  1340. }
  1341. }
  1342. // The last check it to see that the object being extracted from is the
  1343. // correct type.
  1344. Instruction* original_inst = def_use_mgr->GetDef(original_id);
  1345. if (original_inst->type_id() != inst->type_id()) {
  1346. return false;
  1347. }
  1348. // Simplify by using the original object.
  1349. inst->SetOpcode(SpvOpCopyObject);
  1350. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}});
  1351. return true;
  1352. };
  1353. }
  1354. FoldingRule InsertFeedingExtract() {
  1355. return [](IRContext* context, Instruction* inst,
  1356. const std::vector<const analysis::Constant*>&) {
  1357. assert(inst->opcode() == SpvOpCompositeExtract &&
  1358. "Wrong opcode. Should be OpCompositeExtract.");
  1359. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1360. uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1361. Instruction* cinst = def_use_mgr->GetDef(cid);
  1362. if (cinst->opcode() != SpvOpCompositeInsert) {
  1363. return false;
  1364. }
  1365. // Find the first position where the list of insert and extract indicies
  1366. // differ, if at all.
  1367. uint32_t i;
  1368. for (i = 1; i < inst->NumInOperands(); ++i) {
  1369. if (i + 1 >= cinst->NumInOperands()) {
  1370. break;
  1371. }
  1372. if (inst->GetSingleWordInOperand(i) !=
  1373. cinst->GetSingleWordInOperand(i + 1)) {
  1374. break;
  1375. }
  1376. }
  1377. // We are extracting the element that was inserted.
  1378. if (i == inst->NumInOperands() && i + 1 == cinst->NumInOperands()) {
  1379. inst->SetOpcode(SpvOpCopyObject);
  1380. inst->SetInOperands(
  1381. {{SPV_OPERAND_TYPE_ID,
  1382. {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}}});
  1383. return true;
  1384. }
  1385. // Extracting the value that was inserted along with values for the base
  1386. // composite. Cannot do anything.
  1387. if (i == inst->NumInOperands()) {
  1388. return false;
  1389. }
  1390. // Extracting an element of the value that was inserted. Extract from
  1391. // that value directly.
  1392. if (i + 1 == cinst->NumInOperands()) {
  1393. std::vector<Operand> operands;
  1394. operands.push_back(
  1395. {SPV_OPERAND_TYPE_ID,
  1396. {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}});
  1397. for (; i < inst->NumInOperands(); ++i) {
  1398. operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
  1399. {inst->GetSingleWordInOperand(i)}});
  1400. }
  1401. inst->SetInOperands(std::move(operands));
  1402. return true;
  1403. }
  1404. // Extracting a value that is disjoint from the element being inserted.
  1405. // Rewrite the extract to use the composite input to the insert.
  1406. std::vector<Operand> operands;
  1407. operands.push_back(
  1408. {SPV_OPERAND_TYPE_ID,
  1409. {cinst->GetSingleWordInOperand(kInsertCompositeIdInIdx)}});
  1410. for (i = 1; i < inst->NumInOperands(); ++i) {
  1411. operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
  1412. {inst->GetSingleWordInOperand(i)}});
  1413. }
  1414. inst->SetInOperands(std::move(operands));
  1415. return true;
  1416. };
  1417. }
  1418. // When a VectorShuffle is feeding an Extract, we can extract from one of the
  1419. // operands of the VectorShuffle. We just need to adjust the index in the
  1420. // extract instruction.
  1421. FoldingRule VectorShuffleFeedingExtract() {
  1422. return [](IRContext* context, Instruction* inst,
  1423. const std::vector<const analysis::Constant*>&) {
  1424. assert(inst->opcode() == SpvOpCompositeExtract &&
  1425. "Wrong opcode. Should be OpCompositeExtract.");
  1426. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1427. analysis::TypeManager* type_mgr = context->get_type_mgr();
  1428. uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1429. Instruction* cinst = def_use_mgr->GetDef(cid);
  1430. if (cinst->opcode() != SpvOpVectorShuffle) {
  1431. return false;
  1432. }
  1433. // Find the size of the first vector operand of the VectorShuffle
  1434. Instruction* first_input =
  1435. def_use_mgr->GetDef(cinst->GetSingleWordInOperand(0));
  1436. analysis::Type* first_input_type =
  1437. type_mgr->GetType(first_input->type_id());
  1438. assert(first_input_type->AsVector() &&
  1439. "Input to vector shuffle should be vectors.");
  1440. uint32_t first_input_size = first_input_type->AsVector()->element_count();
  1441. // Get index of the element the vector shuffle is placing in the position
  1442. // being extracted.
  1443. uint32_t new_index =
  1444. cinst->GetSingleWordInOperand(2 + inst->GetSingleWordInOperand(1));
  1445. // Extracting an undefined value so fold this extract into an undef.
  1446. const uint32_t undef_literal_value = 0xffffffff;
  1447. if (new_index == undef_literal_value) {
  1448. inst->SetOpcode(SpvOpUndef);
  1449. inst->SetInOperands({});
  1450. return true;
  1451. }
  1452. // Get the id of the of the vector the elemtent comes from, and update the
  1453. // index if needed.
  1454. uint32_t new_vector = 0;
  1455. if (new_index < first_input_size) {
  1456. new_vector = cinst->GetSingleWordInOperand(0);
  1457. } else {
  1458. new_vector = cinst->GetSingleWordInOperand(1);
  1459. new_index -= first_input_size;
  1460. }
  1461. // Update the extract instruction.
  1462. inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
  1463. inst->SetInOperand(1, {new_index});
  1464. return true;
  1465. };
  1466. }
  1467. // When an FMix with is feeding an Extract that extracts an element whose
  1468. // corresponding |a| in the FMix is 0 or 1, we can extract from one of the
  1469. // operands of the FMix.
  1470. FoldingRule FMixFeedingExtract() {
  1471. return [](IRContext* context, Instruction* inst,
  1472. const std::vector<const analysis::Constant*>&) {
  1473. assert(inst->opcode() == SpvOpCompositeExtract &&
  1474. "Wrong opcode. Should be OpCompositeExtract.");
  1475. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1476. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  1477. uint32_t composite_id =
  1478. inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1479. Instruction* composite_inst = def_use_mgr->GetDef(composite_id);
  1480. if (composite_inst->opcode() != SpvOpExtInst) {
  1481. return false;
  1482. }
  1483. uint32_t inst_set_id =
  1484. context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  1485. if (composite_inst->GetSingleWordInOperand(kExtInstSetIdInIdx) !=
  1486. inst_set_id ||
  1487. composite_inst->GetSingleWordInOperand(kExtInstInstructionInIdx) !=
  1488. GLSLstd450FMix) {
  1489. return false;
  1490. }
  1491. // Get the |a| for the FMix instruction.
  1492. uint32_t a_id = composite_inst->GetSingleWordInOperand(kFMixAIdInIdx);
  1493. std::unique_ptr<Instruction> a(inst->Clone(context));
  1494. a->SetInOperand(kExtractCompositeIdInIdx, {a_id});
  1495. context->get_instruction_folder().FoldInstruction(a.get());
  1496. if (a->opcode() != SpvOpCopyObject) {
  1497. return false;
  1498. }
  1499. const analysis::Constant* a_const =
  1500. const_mgr->FindDeclaredConstant(a->GetSingleWordInOperand(0));
  1501. if (!a_const) {
  1502. return false;
  1503. }
  1504. bool use_x = false;
  1505. assert(a_const->type()->AsFloat());
  1506. double element_value = a_const->GetValueAsDouble();
  1507. if (element_value == 0.0) {
  1508. use_x = true;
  1509. } else if (element_value == 1.0) {
  1510. use_x = false;
  1511. } else {
  1512. return false;
  1513. }
  1514. // Get the id of the of the vector the element comes from.
  1515. uint32_t new_vector = 0;
  1516. if (use_x) {
  1517. new_vector = composite_inst->GetSingleWordInOperand(kFMixXIdInIdx);
  1518. } else {
  1519. new_vector = composite_inst->GetSingleWordInOperand(kFMixYIdInIdx);
  1520. }
  1521. // Update the extract instruction.
  1522. inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
  1523. return true;
  1524. };
  1525. }
  1526. FoldingRule RedundantPhi() {
  1527. // An OpPhi instruction where all values are the same or the result of the phi
  1528. // itself, can be replaced by the value itself.
  1529. return [](IRContext*, Instruction* inst,
  1530. const std::vector<const analysis::Constant*>&) {
  1531. assert(inst->opcode() == SpvOpPhi && "Wrong opcode. Should be OpPhi.");
  1532. uint32_t incoming_value = 0;
  1533. for (uint32_t i = 0; i < inst->NumInOperands(); i += 2) {
  1534. uint32_t op_id = inst->GetSingleWordInOperand(i);
  1535. if (op_id == inst->result_id()) {
  1536. continue;
  1537. }
  1538. if (incoming_value == 0) {
  1539. incoming_value = op_id;
  1540. } else if (op_id != incoming_value) {
  1541. // Found two possible value. Can't simplify.
  1542. return false;
  1543. }
  1544. }
  1545. if (incoming_value == 0) {
  1546. // Code looks invalid. Don't do anything.
  1547. return false;
  1548. }
  1549. // We have a single incoming value. Simplify using that value.
  1550. inst->SetOpcode(SpvOpCopyObject);
  1551. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {incoming_value}}});
  1552. return true;
  1553. };
  1554. }
  1555. FoldingRule RedundantSelect() {
  1556. // An OpSelect instruction where both values are the same or the condition is
  1557. // constant can be replaced by one of the values
  1558. return [](IRContext*, Instruction* inst,
  1559. const std::vector<const analysis::Constant*>& constants) {
  1560. assert(inst->opcode() == SpvOpSelect &&
  1561. "Wrong opcode. Should be OpSelect.");
  1562. assert(inst->NumInOperands() == 3);
  1563. assert(constants.size() == 3);
  1564. uint32_t true_id = inst->GetSingleWordInOperand(1);
  1565. uint32_t false_id = inst->GetSingleWordInOperand(2);
  1566. if (true_id == false_id) {
  1567. // Both results are the same, condition doesn't matter
  1568. inst->SetOpcode(SpvOpCopyObject);
  1569. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
  1570. return true;
  1571. } else if (constants[0]) {
  1572. const analysis::Type* type = constants[0]->type();
  1573. if (type->AsBool()) {
  1574. // Scalar constant value, select the corresponding value.
  1575. inst->SetOpcode(SpvOpCopyObject);
  1576. if (constants[0]->AsNullConstant() ||
  1577. !constants[0]->AsBoolConstant()->value()) {
  1578. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
  1579. } else {
  1580. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
  1581. }
  1582. return true;
  1583. } else {
  1584. assert(type->AsVector());
  1585. if (constants[0]->AsNullConstant()) {
  1586. // All values come from false id.
  1587. inst->SetOpcode(SpvOpCopyObject);
  1588. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
  1589. return true;
  1590. } else {
  1591. // Convert to a vector shuffle.
  1592. std::vector<Operand> ops;
  1593. ops.push_back({SPV_OPERAND_TYPE_ID, {true_id}});
  1594. ops.push_back({SPV_OPERAND_TYPE_ID, {false_id}});
  1595. const analysis::VectorConstant* vector_const =
  1596. constants[0]->AsVectorConstant();
  1597. uint32_t size =
  1598. static_cast<uint32_t>(vector_const->GetComponents().size());
  1599. for (uint32_t i = 0; i != size; ++i) {
  1600. const analysis::Constant* component =
  1601. vector_const->GetComponents()[i];
  1602. if (component->AsNullConstant() ||
  1603. !component->AsBoolConstant()->value()) {
  1604. // Selecting from the false vector which is the second input
  1605. // vector to the shuffle. Offset the index by |size|.
  1606. ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i + size}});
  1607. } else {
  1608. // Selecting from true vector which is the first input vector to
  1609. // the shuffle.
  1610. ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}});
  1611. }
  1612. }
  1613. inst->SetOpcode(SpvOpVectorShuffle);
  1614. inst->SetInOperands(std::move(ops));
  1615. return true;
  1616. }
  1617. }
  1618. }
  1619. return false;
  1620. };
  1621. }
  1622. enum class FloatConstantKind { Unknown, Zero, One };
  1623. FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) {
  1624. if (constant == nullptr) {
  1625. return FloatConstantKind::Unknown;
  1626. }
  1627. assert(HasFloatingPoint(constant->type()) && "Unexpected constant type");
  1628. if (constant->AsNullConstant()) {
  1629. return FloatConstantKind::Zero;
  1630. } else if (const analysis::VectorConstant* vc =
  1631. constant->AsVectorConstant()) {
  1632. const std::vector<const analysis::Constant*>& components =
  1633. vc->GetComponents();
  1634. assert(!components.empty());
  1635. FloatConstantKind kind = getFloatConstantKind(components[0]);
  1636. for (size_t i = 1; i < components.size(); ++i) {
  1637. if (getFloatConstantKind(components[i]) != kind) {
  1638. return FloatConstantKind::Unknown;
  1639. }
  1640. }
  1641. return kind;
  1642. } else if (const analysis::FloatConstant* fc = constant->AsFloatConstant()) {
  1643. if (fc->IsZero()) return FloatConstantKind::Zero;
  1644. uint32_t width = fc->type()->AsFloat()->width();
  1645. if (width != 32 && width != 64) return FloatConstantKind::Unknown;
  1646. double value = (width == 64) ? fc->GetDoubleValue() : fc->GetFloatValue();
  1647. if (value == 0.0) {
  1648. return FloatConstantKind::Zero;
  1649. } else if (value == 1.0) {
  1650. return FloatConstantKind::One;
  1651. } else {
  1652. return FloatConstantKind::Unknown;
  1653. }
  1654. } else {
  1655. return FloatConstantKind::Unknown;
  1656. }
  1657. }
  1658. FoldingRule RedundantFAdd() {
  1659. return [](IRContext*, Instruction* inst,
  1660. const std::vector<const analysis::Constant*>& constants) {
  1661. assert(inst->opcode() == SpvOpFAdd && "Wrong opcode. Should be OpFAdd.");
  1662. assert(constants.size() == 2);
  1663. if (!inst->IsFloatingPointFoldingAllowed()) {
  1664. return false;
  1665. }
  1666. FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
  1667. FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
  1668. if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
  1669. inst->SetOpcode(SpvOpCopyObject);
  1670. inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
  1671. {inst->GetSingleWordInOperand(
  1672. kind0 == FloatConstantKind::Zero ? 1 : 0)}}});
  1673. return true;
  1674. }
  1675. return false;
  1676. };
  1677. }
  1678. FoldingRule RedundantFSub() {
  1679. return [](IRContext*, Instruction* inst,
  1680. const std::vector<const analysis::Constant*>& constants) {
  1681. assert(inst->opcode() == SpvOpFSub && "Wrong opcode. Should be OpFSub.");
  1682. assert(constants.size() == 2);
  1683. if (!inst->IsFloatingPointFoldingAllowed()) {
  1684. return false;
  1685. }
  1686. FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
  1687. FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
  1688. if (kind0 == FloatConstantKind::Zero) {
  1689. inst->SetOpcode(SpvOpFNegate);
  1690. inst->SetInOperands(
  1691. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1)}}});
  1692. return true;
  1693. }
  1694. if (kind1 == FloatConstantKind::Zero) {
  1695. inst->SetOpcode(SpvOpCopyObject);
  1696. inst->SetInOperands(
  1697. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
  1698. return true;
  1699. }
  1700. return false;
  1701. };
  1702. }
  1703. FoldingRule RedundantFMul() {
  1704. return [](IRContext*, Instruction* inst,
  1705. const std::vector<const analysis::Constant*>& constants) {
  1706. assert(inst->opcode() == SpvOpFMul && "Wrong opcode. Should be OpFMul.");
  1707. assert(constants.size() == 2);
  1708. if (!inst->IsFloatingPointFoldingAllowed()) {
  1709. return false;
  1710. }
  1711. FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
  1712. FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
  1713. if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
  1714. inst->SetOpcode(SpvOpCopyObject);
  1715. inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
  1716. {inst->GetSingleWordInOperand(
  1717. kind0 == FloatConstantKind::Zero ? 0 : 1)}}});
  1718. return true;
  1719. }
  1720. if (kind0 == FloatConstantKind::One || kind1 == FloatConstantKind::One) {
  1721. inst->SetOpcode(SpvOpCopyObject);
  1722. inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
  1723. {inst->GetSingleWordInOperand(
  1724. kind0 == FloatConstantKind::One ? 1 : 0)}}});
  1725. return true;
  1726. }
  1727. return false;
  1728. };
  1729. }
  1730. FoldingRule RedundantFDiv() {
  1731. return [](IRContext*, Instruction* inst,
  1732. const std::vector<const analysis::Constant*>& constants) {
  1733. assert(inst->opcode() == SpvOpFDiv && "Wrong opcode. Should be OpFDiv.");
  1734. assert(constants.size() == 2);
  1735. if (!inst->IsFloatingPointFoldingAllowed()) {
  1736. return false;
  1737. }
  1738. FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
  1739. FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
  1740. if (kind0 == FloatConstantKind::Zero) {
  1741. inst->SetOpcode(SpvOpCopyObject);
  1742. inst->SetInOperands(
  1743. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
  1744. return true;
  1745. }
  1746. if (kind1 == FloatConstantKind::One) {
  1747. inst->SetOpcode(SpvOpCopyObject);
  1748. inst->SetInOperands(
  1749. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
  1750. return true;
  1751. }
  1752. return false;
  1753. };
  1754. }
  1755. FoldingRule RedundantFMix() {
  1756. return [](IRContext* context, Instruction* inst,
  1757. const std::vector<const analysis::Constant*>& constants) {
  1758. assert(inst->opcode() == SpvOpExtInst &&
  1759. "Wrong opcode. Should be OpExtInst.");
  1760. if (!inst->IsFloatingPointFoldingAllowed()) {
  1761. return false;
  1762. }
  1763. uint32_t instSetId =
  1764. context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  1765. if (inst->GetSingleWordInOperand(kExtInstSetIdInIdx) == instSetId &&
  1766. inst->GetSingleWordInOperand(kExtInstInstructionInIdx) ==
  1767. GLSLstd450FMix) {
  1768. assert(constants.size() == 5);
  1769. FloatConstantKind kind4 = getFloatConstantKind(constants[4]);
  1770. if (kind4 == FloatConstantKind::Zero || kind4 == FloatConstantKind::One) {
  1771. inst->SetOpcode(SpvOpCopyObject);
  1772. inst->SetInOperands(
  1773. {{SPV_OPERAND_TYPE_ID,
  1774. {inst->GetSingleWordInOperand(kind4 == FloatConstantKind::Zero
  1775. ? kFMixXIdInIdx
  1776. : kFMixYIdInIdx)}}});
  1777. return true;
  1778. }
  1779. }
  1780. return false;
  1781. };
  1782. }
  1783. // This rule handles addition of zero for integers.
  1784. FoldingRule RedundantIAdd() {
  1785. return [](IRContext* context, Instruction* inst,
  1786. const std::vector<const analysis::Constant*>& constants) {
  1787. assert(inst->opcode() == SpvOpIAdd && "Wrong opcode. Should be OpIAdd.");
  1788. uint32_t operand = std::numeric_limits<uint32_t>::max();
  1789. const analysis::Type* operand_type = nullptr;
  1790. if (constants[0] && constants[0]->IsZero()) {
  1791. operand = inst->GetSingleWordInOperand(1);
  1792. operand_type = constants[0]->type();
  1793. } else if (constants[1] && constants[1]->IsZero()) {
  1794. operand = inst->GetSingleWordInOperand(0);
  1795. operand_type = constants[1]->type();
  1796. }
  1797. if (operand != std::numeric_limits<uint32_t>::max()) {
  1798. const analysis::Type* inst_type =
  1799. context->get_type_mgr()->GetType(inst->type_id());
  1800. if (inst_type->IsSame(operand_type)) {
  1801. inst->SetOpcode(SpvOpCopyObject);
  1802. } else {
  1803. inst->SetOpcode(SpvOpBitcast);
  1804. }
  1805. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {operand}}});
  1806. return true;
  1807. }
  1808. return false;
  1809. };
  1810. }
  1811. // This rule look for a dot with a constant vector containing a single 1 and
  1812. // the rest 0s. This is the same as doing an extract.
  1813. FoldingRule DotProductDoingExtract() {
  1814. return [](IRContext* context, Instruction* inst,
  1815. const std::vector<const analysis::Constant*>& constants) {
  1816. assert(inst->opcode() == SpvOpDot && "Wrong opcode. Should be OpDot.");
  1817. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  1818. if (!inst->IsFloatingPointFoldingAllowed()) {
  1819. return false;
  1820. }
  1821. for (int i = 0; i < 2; ++i) {
  1822. if (!constants[i]) {
  1823. continue;
  1824. }
  1825. const analysis::Vector* vector_type = constants[i]->type()->AsVector();
  1826. assert(vector_type && "Inputs to OpDot must be vectors.");
  1827. const analysis::Float* element_type =
  1828. vector_type->element_type()->AsFloat();
  1829. assert(element_type && "Inputs to OpDot must be vectors of floats.");
  1830. uint32_t element_width = element_type->width();
  1831. if (element_width != 32 && element_width != 64) {
  1832. return false;
  1833. }
  1834. std::vector<const analysis::Constant*> components;
  1835. components = constants[i]->GetVectorComponents(const_mgr);
  1836. const uint32_t kNotFound = std::numeric_limits<uint32_t>::max();
  1837. uint32_t component_with_one = kNotFound;
  1838. bool all_others_zero = true;
  1839. for (uint32_t j = 0; j < components.size(); ++j) {
  1840. const analysis::Constant* element = components[j];
  1841. double value =
  1842. (element_width == 32 ? element->GetFloat() : element->GetDouble());
  1843. if (value == 0.0) {
  1844. continue;
  1845. } else if (value == 1.0) {
  1846. if (component_with_one == kNotFound) {
  1847. component_with_one = j;
  1848. } else {
  1849. component_with_one = kNotFound;
  1850. break;
  1851. }
  1852. } else {
  1853. all_others_zero = false;
  1854. break;
  1855. }
  1856. }
  1857. if (!all_others_zero || component_with_one == kNotFound) {
  1858. continue;
  1859. }
  1860. std::vector<Operand> operands;
  1861. operands.push_back(
  1862. {SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1u - i)}});
  1863. operands.push_back(
  1864. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_with_one}});
  1865. inst->SetOpcode(SpvOpCompositeExtract);
  1866. inst->SetInOperands(std::move(operands));
  1867. return true;
  1868. }
  1869. return false;
  1870. };
  1871. }
  1872. // If we are storing an undef, then we can remove the store.
  1873. //
  1874. // TODO: We can do something similar for OpImageWrite, but checking for volatile
  1875. // is complicated. Waiting to see if it is needed.
  1876. FoldingRule StoringUndef() {
  1877. return [](IRContext* context, Instruction* inst,
  1878. const std::vector<const analysis::Constant*>&) {
  1879. assert(inst->opcode() == SpvOpStore && "Wrong opcode. Should be OpStore.");
  1880. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1881. // If this is a volatile store, the store cannot be removed.
  1882. if (inst->NumInOperands() == 3) {
  1883. if (inst->GetSingleWordInOperand(2) & SpvMemoryAccessVolatileMask) {
  1884. return false;
  1885. }
  1886. }
  1887. uint32_t object_id = inst->GetSingleWordInOperand(kStoreObjectInIdx);
  1888. Instruction* object_inst = def_use_mgr->GetDef(object_id);
  1889. if (object_inst->opcode() == SpvOpUndef) {
  1890. inst->ToNop();
  1891. return true;
  1892. }
  1893. return false;
  1894. };
  1895. }
  1896. FoldingRule VectorShuffleFeedingShuffle() {
  1897. return [](IRContext* context, Instruction* inst,
  1898. const std::vector<const analysis::Constant*>&) {
  1899. assert(inst->opcode() == SpvOpVectorShuffle &&
  1900. "Wrong opcode. Should be OpVectorShuffle.");
  1901. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1902. analysis::TypeManager* type_mgr = context->get_type_mgr();
  1903. Instruction* feeding_shuffle_inst =
  1904. def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
  1905. analysis::Vector* op0_type =
  1906. type_mgr->GetType(feeding_shuffle_inst->type_id())->AsVector();
  1907. uint32_t op0_length = op0_type->element_count();
  1908. bool feeder_is_op0 = true;
  1909. if (feeding_shuffle_inst->opcode() != SpvOpVectorShuffle) {
  1910. feeding_shuffle_inst =
  1911. def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
  1912. feeder_is_op0 = false;
  1913. }
  1914. if (feeding_shuffle_inst->opcode() != SpvOpVectorShuffle) {
  1915. return false;
  1916. }
  1917. Instruction* feeder2 =
  1918. def_use_mgr->GetDef(feeding_shuffle_inst->GetSingleWordInOperand(0));
  1919. analysis::Vector* feeder_op0_type =
  1920. type_mgr->GetType(feeder2->type_id())->AsVector();
  1921. uint32_t feeder_op0_length = feeder_op0_type->element_count();
  1922. uint32_t new_feeder_id = 0;
  1923. std::vector<Operand> new_operands;
  1924. new_operands.resize(
  1925. 2, {SPV_OPERAND_TYPE_ID, {0}}); // Place holders for vector operands.
  1926. const uint32_t undef_literal = 0xffffffff;
  1927. for (uint32_t op = 2; op < inst->NumInOperands(); ++op) {
  1928. uint32_t component_index = inst->GetSingleWordInOperand(op);
  1929. // Do not interpret the undefined value literal as coming from operand 1.
  1930. if (component_index != undef_literal &&
  1931. feeder_is_op0 == (component_index < op0_length)) {
  1932. // This component comes from the feeding_shuffle_inst. Update
  1933. // |component_index| to be the index into the operand of the feeder.
  1934. // Adjust component_index to get the index into the operands of the
  1935. // feeding_shuffle_inst.
  1936. if (component_index >= op0_length) {
  1937. component_index -= op0_length;
  1938. }
  1939. component_index =
  1940. feeding_shuffle_inst->GetSingleWordInOperand(component_index + 2);
  1941. // Check if we are using a component from the first or second operand of
  1942. // the feeding instruction.
  1943. if (component_index < feeder_op0_length) {
  1944. if (new_feeder_id == 0) {
  1945. // First time through, save the id of the operand the element comes
  1946. // from.
  1947. new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(0);
  1948. } else if (new_feeder_id !=
  1949. feeding_shuffle_inst->GetSingleWordInOperand(0)) {
  1950. // We need both elements of the feeding_shuffle_inst, so we cannot
  1951. // fold.
  1952. return false;
  1953. }
  1954. } else {
  1955. if (new_feeder_id == 0) {
  1956. // First time through, save the id of the operand the element comes
  1957. // from.
  1958. new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(1);
  1959. } else if (new_feeder_id !=
  1960. feeding_shuffle_inst->GetSingleWordInOperand(1)) {
  1961. // We need both elements of the feeding_shuffle_inst, so we cannot
  1962. // fold.
  1963. return false;
  1964. }
  1965. component_index -= feeder_op0_length;
  1966. }
  1967. if (!feeder_is_op0) {
  1968. component_index += op0_length;
  1969. }
  1970. }
  1971. new_operands.push_back(
  1972. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_index}});
  1973. }
  1974. if (new_feeder_id == 0) {
  1975. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  1976. const analysis::Type* type =
  1977. type_mgr->GetType(feeding_shuffle_inst->type_id());
  1978. const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
  1979. new_feeder_id =
  1980. const_mgr->GetDefiningInstruction(null_const, 0)->result_id();
  1981. }
  1982. if (feeder_is_op0) {
  1983. // If the size of the first vector operand changed then the indices
  1984. // referring to the second operand need to be adjusted.
  1985. Instruction* new_feeder_inst = def_use_mgr->GetDef(new_feeder_id);
  1986. analysis::Type* new_feeder_type =
  1987. type_mgr->GetType(new_feeder_inst->type_id());
  1988. uint32_t new_op0_size = new_feeder_type->AsVector()->element_count();
  1989. int32_t adjustment = op0_length - new_op0_size;
  1990. if (adjustment != 0) {
  1991. for (uint32_t i = 2; i < new_operands.size(); i++) {
  1992. if (inst->GetSingleWordInOperand(i) >= op0_length) {
  1993. new_operands[i].words[0] -= adjustment;
  1994. }
  1995. }
  1996. }
  1997. new_operands[0].words[0] = new_feeder_id;
  1998. new_operands[1] = inst->GetInOperand(1);
  1999. } else {
  2000. new_operands[1].words[0] = new_feeder_id;
  2001. new_operands[0] = inst->GetInOperand(0);
  2002. }
  2003. inst->SetInOperands(std::move(new_operands));
  2004. return true;
  2005. };
  2006. }
  2007. // Removes duplicate ids from the interface list of an OpEntryPoint
  2008. // instruction.
  2009. FoldingRule RemoveRedundantOperands() {
  2010. return [](IRContext*, Instruction* inst,
  2011. const std::vector<const analysis::Constant*>&) {
  2012. assert(inst->opcode() == SpvOpEntryPoint &&
  2013. "Wrong opcode. Should be OpEntryPoint.");
  2014. bool has_redundant_operand = false;
  2015. std::unordered_set<uint32_t> seen_operands;
  2016. std::vector<Operand> new_operands;
  2017. new_operands.emplace_back(inst->GetOperand(0));
  2018. new_operands.emplace_back(inst->GetOperand(1));
  2019. new_operands.emplace_back(inst->GetOperand(2));
  2020. for (uint32_t i = 3; i < inst->NumOperands(); ++i) {
  2021. if (seen_operands.insert(inst->GetSingleWordOperand(i)).second) {
  2022. new_operands.emplace_back(inst->GetOperand(i));
  2023. } else {
  2024. has_redundant_operand = true;
  2025. }
  2026. }
  2027. if (!has_redundant_operand) {
  2028. return false;
  2029. }
  2030. inst->SetInOperands(std::move(new_operands));
  2031. return true;
  2032. };
  2033. }
  2034. } // namespace
// Registers the folding rules defined in this file.
// - Opcode-keyed rules are appended to |rules_[opcode]|.
// - The GLSL.std.450 FMix rule is keyed in |ext_rules_| by the pair
//   (extended-instruction import id, instruction number).
void FoldingRules::AddFoldingRules() {
  // Add all folding rules to the list for the opcodes to which they apply.
  // Note that the order in which rules are added to the list matters. If a rule
  // applies to the instruction, the rest of the rules will not be attempted.
  // Take that into consideration.
  rules_[SpvOpCompositeConstruct].push_back(CompositeExtractFeedingConstruct());
  rules_[SpvOpCompositeExtract].push_back(InsertFeedingExtract());
  rules_[SpvOpCompositeExtract].push_back(CompositeConstructFeedingExtract());
  rules_[SpvOpCompositeExtract].push_back(VectorShuffleFeedingExtract());
  // FMixFeedingExtract checks the extended-instruction set itself when it
  // runs, so it is registered regardless of whether GLSL.std.450 is imported.
  rules_[SpvOpCompositeExtract].push_back(FMixFeedingExtract());
  rules_[SpvOpDot].push_back(DotProductDoingExtract());
  rules_[SpvOpEntryPoint].push_back(RemoveRedundantOperands());
  // For each arithmetic opcode below, the Redundant* identity rule is
  // registered ahead of the Merge*/Factor* rewrites, so it is tried first.
  rules_[SpvOpFAdd].push_back(RedundantFAdd());
  rules_[SpvOpFAdd].push_back(MergeAddNegateArithmetic());
  rules_[SpvOpFAdd].push_back(MergeAddAddArithmetic());
  rules_[SpvOpFAdd].push_back(MergeAddSubArithmetic());
  rules_[SpvOpFAdd].push_back(MergeGenericAddSubArithmetic());
  rules_[SpvOpFAdd].push_back(FactorAddMuls());
  rules_[SpvOpFDiv].push_back(RedundantFDiv());
  rules_[SpvOpFDiv].push_back(ReciprocalFDiv());
  rules_[SpvOpFDiv].push_back(MergeDivDivArithmetic());
  rules_[SpvOpFDiv].push_back(MergeDivMulArithmetic());
  rules_[SpvOpFDiv].push_back(MergeDivNegateArithmetic());
  rules_[SpvOpFMul].push_back(RedundantFMul());
  rules_[SpvOpFMul].push_back(MergeMulMulArithmetic());
  rules_[SpvOpFMul].push_back(MergeMulDivArithmetic());
  rules_[SpvOpFMul].push_back(MergeMulNegateArithmetic());
  rules_[SpvOpFNegate].push_back(MergeNegateArithmetic());
  rules_[SpvOpFNegate].push_back(MergeNegateAddSubArithmetic());
  rules_[SpvOpFNegate].push_back(MergeNegateMulDivArithmetic());
  rules_[SpvOpFSub].push_back(RedundantFSub());
  rules_[SpvOpFSub].push_back(MergeSubNegateArithmetic());
  rules_[SpvOpFSub].push_back(MergeSubAddArithmetic());
  rules_[SpvOpFSub].push_back(MergeSubSubArithmetic());
  rules_[SpvOpIAdd].push_back(RedundantIAdd());
  rules_[SpvOpIAdd].push_back(MergeAddNegateArithmetic());
  rules_[SpvOpIAdd].push_back(MergeAddAddArithmetic());
  rules_[SpvOpIAdd].push_back(MergeAddSubArithmetic());
  rules_[SpvOpIAdd].push_back(MergeGenericAddSubArithmetic());
  rules_[SpvOpIAdd].push_back(FactorAddMuls());
  rules_[SpvOpIMul].push_back(IntMultipleBy1());
  rules_[SpvOpIMul].push_back(MergeMulMulArithmetic());
  rules_[SpvOpIMul].push_back(MergeMulNegateArithmetic());
  rules_[SpvOpISub].push_back(MergeSubNegateArithmetic());
  rules_[SpvOpISub].push_back(MergeSubAddArithmetic());
  rules_[SpvOpISub].push_back(MergeSubSubArithmetic());
  rules_[SpvOpPhi].push_back(RedundantPhi());
  rules_[SpvOpSDiv].push_back(MergeDivNegateArithmetic());
  rules_[SpvOpSNegate].push_back(MergeNegateArithmetic());
  rules_[SpvOpSNegate].push_back(MergeNegateMulDivArithmetic());
  rules_[SpvOpSNegate].push_back(MergeNegateAddSubArithmetic());
  rules_[SpvOpSelect].push_back(RedundantSelect());
  rules_[SpvOpStore].push_back(StoringUndef());
  rules_[SpvOpUDiv].push_back(MergeDivNegateArithmetic());
  rules_[SpvOpVectorShuffle].push_back(VectorShuffleFeedingShuffle());
  FeatureManager* feature_manager = context_->get_feature_mgr();
  // Add rules for GLSLstd450
  // These are keyed by the module's GLSL.std.450 import id. An id of 0 is
  // treated as "not imported", in which case no extended-instruction rules
  // are registered.
  uint32_t ext_inst_glslstd450_id =
      feature_manager->GetExtInstImportId_GLSLstd450();
  if (ext_inst_glslstd450_id != 0) {
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(
        RedundantFMix());
  }
}
  2099. } // namespace opt
  2100. } // namespace spvtools