folding_rules.cpp 81 KB


  1. // Copyright (c) 2018 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opt/folding_rules.h"
  15. #include <limits>
  16. #include <memory>
  17. #include <utility>
  18. #include "source/latest_version_glsl_std_450_header.h"
  19. #include "source/opt/ir_context.h"
  20. namespace spvtools {
  21. namespace opt {
  22. namespace {
  // Named in-operand indices. The operand each one refers to is indicated by
  // its name; none of them is referenced in the part of the file reviewed here.
  constexpr uint32_t kExtractCompositeIdInIdx = 0;

  constexpr uint32_t kInsertObjectIdInIdx = 0;
  constexpr uint32_t kInsertCompositeIdInIdx = 1;

  constexpr uint32_t kExtInstSetIdInIdx = 0;
  constexpr uint32_t kExtInstInstructionInIdx = 1;

  constexpr uint32_t kFMixXIdInIdx = 2;
  constexpr uint32_t kFMixYIdInIdx = 3;
  constexpr uint32_t kFMixAIdInIdx = 4;

  constexpr uint32_t kStoreObjectInIdx = 1;
  32. // Returns the element width of |type|.
  33. uint32_t ElementWidth(const analysis::Type* type) {
  34. if (const analysis::Vector* vec_type = type->AsVector()) {
  35. return ElementWidth(vec_type->element_type());
  36. } else if (const analysis::Float* float_type = type->AsFloat()) {
  37. return float_type->width();
  38. } else {
  39. assert(type->AsInteger());
  40. return type->AsInteger()->width();
  41. }
  42. }
  43. // Returns true if |type| is Float or a vector of Float.
  44. bool HasFloatingPoint(const analysis::Type* type) {
  45. if (type->AsFloat()) {
  46. return true;
  47. } else if (const analysis::Vector* vec_type = type->AsVector()) {
  48. return vec_type->element_type()->AsFloat() != nullptr;
  49. }
  50. return false;
  51. }
  // Returns true unless |val| classifies as NaN, infinite or subnormal.
  template <typename T>
  bool IsValidResult(T val) {
    const int category = std::fpclassify(val);
    const bool rejected = category == FP_NAN || category == FP_INFINITE ||
                          category == FP_SUBNORMAL;
    return !rejected;
  }
  65. const analysis::Constant* ConstInput(
  66. const std::vector<const analysis::Constant*>& constants) {
  67. return constants[0] ? constants[0] : constants[1];
  68. }
  69. Instruction* NonConstInput(IRContext* context, const analysis::Constant* c,
  70. Instruction* inst) {
  71. uint32_t in_op = c ? 1u : 0u;
  72. return context->get_def_use_mgr()->GetDef(
  73. inst->GetSingleWordInOperand(in_op));
  74. }
  75. // Returns the negation of |c|. |c| must be a 32 or 64 bit floating point
  76. // constant.
  77. uint32_t NegateFloatingPointConstant(analysis::ConstantManager* const_mgr,
  78. const analysis::Constant* c) {
  79. assert(c);
  80. assert(c->type()->AsFloat());
  81. uint32_t width = c->type()->AsFloat()->width();
  82. assert(width == 32 || width == 64);
  83. std::vector<uint32_t> words;
  84. if (width == 64) {
  85. utils::FloatProxy<double> result(c->GetDouble() * -1.0);
  86. words = result.GetWords();
  87. } else {
  88. utils::FloatProxy<float> result(c->GetFloat() * -1.0f);
  89. words = result.GetWords();
  90. }
  91. const analysis::Constant* negated_const =
  92. const_mgr->GetConstant(c->type(), std::move(words));
  93. return const_mgr->GetDefiningInstruction(negated_const)->result_id();
  94. }
  // Splits |val| into two 32-bit words: the low 32 bits first, then the high
  // 32 bits.
  std::vector<uint32_t> ExtractInts(uint64_t val) {
    const uint32_t low_word = static_cast<uint32_t>(val);
    const uint32_t high_word = static_cast<uint32_t>(val >> 32);
    return {low_word, high_word};
  }
  101. // Negates the integer constant |c|. Returns the id of the defining instruction.
  102. uint32_t NegateIntegerConstant(analysis::ConstantManager* const_mgr,
  103. const analysis::Constant* c) {
  104. assert(c);
  105. assert(c->type()->AsInteger());
  106. uint32_t width = c->type()->AsInteger()->width();
  107. assert(width == 32 || width == 64);
  108. std::vector<uint32_t> words;
  109. if (width == 64) {
  110. uint64_t uval = static_cast<uint64_t>(0 - c->GetU64());
  111. words = ExtractInts(uval);
  112. } else {
  113. words.push_back(static_cast<uint32_t>(0 - c->GetU32()));
  114. }
  115. const analysis::Constant* negated_const =
  116. const_mgr->GetConstant(c->type(), std::move(words));
  117. return const_mgr->GetDefiningInstruction(negated_const)->result_id();
  118. }
  119. // Negates the vector constant |c|. Returns the id of the defining instruction.
  120. uint32_t NegateVectorConstant(analysis::ConstantManager* const_mgr,
  121. const analysis::Constant* c) {
  122. assert(const_mgr && c);
  123. assert(c->type()->AsVector());
  124. if (c->AsNullConstant()) {
  125. // 0.0 vs -0.0 shouldn't matter.
  126. return const_mgr->GetDefiningInstruction(c)->result_id();
  127. } else {
  128. const analysis::Type* component_type =
  129. c->AsVectorConstant()->component_type();
  130. std::vector<uint32_t> words;
  131. for (auto& comp : c->AsVectorConstant()->GetComponents()) {
  132. if (component_type->AsFloat()) {
  133. words.push_back(NegateFloatingPointConstant(const_mgr, comp));
  134. } else {
  135. assert(component_type->AsInteger());
  136. words.push_back(NegateIntegerConstant(const_mgr, comp));
  137. }
  138. }
  139. const analysis::Constant* negated_const =
  140. const_mgr->GetConstant(c->type(), std::move(words));
  141. return const_mgr->GetDefiningInstruction(negated_const)->result_id();
  142. }
  143. }
  144. // Negates |c|. Returns the id of the defining instruction.
  145. uint32_t NegateConstant(analysis::ConstantManager* const_mgr,
  146. const analysis::Constant* c) {
  147. if (c->type()->AsVector()) {
  148. return NegateVectorConstant(const_mgr, c);
  149. } else if (c->type()->AsFloat()) {
  150. return NegateFloatingPointConstant(const_mgr, c);
  151. } else {
  152. assert(c->type()->AsInteger());
  153. return NegateIntegerConstant(const_mgr, c);
  154. }
  155. }
  156. // Takes the reciprocal of |c|. |c|'s type must be Float or a vector of Float.
  157. // Returns 0 if the reciprocal is NaN, infinite or subnormal.
  158. uint32_t Reciprocal(analysis::ConstantManager* const_mgr,
  159. const analysis::Constant* c) {
  160. assert(const_mgr && c);
  161. assert(c->type()->AsFloat());
  162. uint32_t width = c->type()->AsFloat()->width();
  163. assert(width == 32 || width == 64);
  164. std::vector<uint32_t> words;
  165. if (width == 64) {
  166. spvtools::utils::FloatProxy<double> result(1.0 / c->GetDouble());
  167. if (!IsValidResult(result.getAsFloat())) return 0;
  168. words = result.GetWords();
  169. } else {
  170. spvtools::utils::FloatProxy<float> result(1.0f / c->GetFloat());
  171. if (!IsValidResult(result.getAsFloat())) return 0;
  172. words = result.GetWords();
  173. }
  174. const analysis::Constant* negated_const =
  175. const_mgr->GetConstant(c->type(), std::move(words));
  176. return const_mgr->GetDefiningInstruction(negated_const)->result_id();
  177. }
  178. // Replaces fdiv where second operand is constant with fmul.
  179. FoldingRule ReciprocalFDiv() {
  180. return [](IRContext* context, Instruction* inst,
  181. const std::vector<const analysis::Constant*>& constants) {
  182. assert(inst->opcode() == SpvOpFDiv);
  183. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  184. const analysis::Type* type =
  185. context->get_type_mgr()->GetType(inst->type_id());
  186. if (!inst->IsFloatingPointFoldingAllowed()) return false;
  187. uint32_t width = ElementWidth(type);
  188. if (width != 32 && width != 64) return false;
  189. if (constants[1] != nullptr) {
  190. uint32_t id = 0;
  191. if (const analysis::VectorConstant* vector_const =
  192. constants[1]->AsVectorConstant()) {
  193. std::vector<uint32_t> neg_ids;
  194. for (auto& comp : vector_const->GetComponents()) {
  195. id = Reciprocal(const_mgr, comp);
  196. if (id == 0) return false;
  197. neg_ids.push_back(id);
  198. }
  199. const analysis::Constant* negated_const =
  200. const_mgr->GetConstant(constants[1]->type(), std::move(neg_ids));
  201. id = const_mgr->GetDefiningInstruction(negated_const)->result_id();
  202. } else if (constants[1]->AsFloatConstant()) {
  203. id = Reciprocal(const_mgr, constants[1]);
  204. if (id == 0) return false;
  205. } else {
  206. // Don't fold a null constant.
  207. return false;
  208. }
  209. inst->SetOpcode(SpvOpFMul);
  210. inst->SetInOperands(
  211. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0u)}},
  212. {SPV_OPERAND_TYPE_ID, {id}}});
  213. return true;
  214. }
  215. return false;
  216. };
  217. }
  218. // Elides consecutive negate instructions.
  219. FoldingRule MergeNegateArithmetic() {
  220. return [](IRContext* context, Instruction* inst,
  221. const std::vector<const analysis::Constant*>& constants) {
  222. assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
  223. (void)constants;
  224. const analysis::Type* type =
  225. context->get_type_mgr()->GetType(inst->type_id());
  226. if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
  227. return false;
  228. Instruction* op_inst =
  229. context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
  230. if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
  231. return false;
  232. if (op_inst->opcode() == inst->opcode()) {
  233. // Elide negates.
  234. inst->SetOpcode(SpvOpCopyObject);
  235. inst->SetInOperands(
  236. {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0u)}}});
  237. return true;
  238. }
  239. return false;
  240. };
  241. }
  242. // Merges negate into a mul or div operation if that operation contains a
  243. // constant operand.
  244. // Cases:
  245. // -(x * 2) = x * -2
  246. // -(2 * x) = x * -2
  247. // -(x / 2) = x / -2
  248. // -(2 / x) = -2 / x
  249. FoldingRule MergeNegateMulDivArithmetic() {
  250. return [](IRContext* context, Instruction* inst,
  251. const std::vector<const analysis::Constant*>& constants) {
  252. assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
  253. (void)constants;
  254. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  255. const analysis::Type* type =
  256. context->get_type_mgr()->GetType(inst->type_id());
  257. if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
  258. return false;
  259. Instruction* op_inst =
  260. context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
  261. if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
  262. return false;
  263. uint32_t width = ElementWidth(type);
  264. if (width != 32 && width != 64) return false;
  265. SpvOp opcode = op_inst->opcode();
  266. if (opcode == SpvOpFMul || opcode == SpvOpFDiv || opcode == SpvOpIMul ||
  267. opcode == SpvOpSDiv || opcode == SpvOpUDiv) {
  268. std::vector<const analysis::Constant*> op_constants =
  269. const_mgr->GetOperandConstants(op_inst);
  270. // Merge negate into mul or div if one operand is constant.
  271. if (op_constants[0] || op_constants[1]) {
  272. bool zero_is_variable = op_constants[0] == nullptr;
  273. const analysis::Constant* c = ConstInput(op_constants);
  274. uint32_t neg_id = NegateConstant(const_mgr, c);
  275. uint32_t non_const_id = zero_is_variable
  276. ? op_inst->GetSingleWordInOperand(0u)
  277. : op_inst->GetSingleWordInOperand(1u);
  278. // Change this instruction to a mul/div.
  279. inst->SetOpcode(op_inst->opcode());
  280. if (opcode == SpvOpFDiv || opcode == SpvOpUDiv || opcode == SpvOpSDiv) {
  281. uint32_t op0 = zero_is_variable ? non_const_id : neg_id;
  282. uint32_t op1 = zero_is_variable ? neg_id : non_const_id;
  283. inst->SetInOperands(
  284. {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
  285. } else {
  286. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
  287. {SPV_OPERAND_TYPE_ID, {neg_id}}});
  288. }
  289. return true;
  290. }
  291. }
  292. return false;
  293. };
  294. }
  295. // Merges negate into a add or sub operation if that operation contains a
  296. // constant operand.
  297. // Cases:
  298. // -(x + 2) = -2 - x
  299. // -(2 + x) = -2 - x
  300. // -(x - 2) = 2 - x
  301. // -(2 - x) = x - 2
  302. FoldingRule MergeNegateAddSubArithmetic() {
  303. return [](IRContext* context, Instruction* inst,
  304. const std::vector<const analysis::Constant*>& constants) {
  305. assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
  306. (void)constants;
  307. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  308. const analysis::Type* type =
  309. context->get_type_mgr()->GetType(inst->type_id());
  310. if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
  311. return false;
  312. Instruction* op_inst =
  313. context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
  314. if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
  315. return false;
  316. uint32_t width = ElementWidth(type);
  317. if (width != 32 && width != 64) return false;
  318. if (op_inst->opcode() == SpvOpFAdd || op_inst->opcode() == SpvOpFSub ||
  319. op_inst->opcode() == SpvOpIAdd || op_inst->opcode() == SpvOpISub) {
  320. std::vector<const analysis::Constant*> op_constants =
  321. const_mgr->GetOperandConstants(op_inst);
  322. if (op_constants[0] || op_constants[1]) {
  323. bool zero_is_variable = op_constants[0] == nullptr;
  324. bool is_add = (op_inst->opcode() == SpvOpFAdd) ||
  325. (op_inst->opcode() == SpvOpIAdd);
  326. bool swap_operands = !is_add || zero_is_variable;
  327. bool negate_const = is_add;
  328. const analysis::Constant* c = ConstInput(op_constants);
  329. uint32_t const_id = 0;
  330. if (negate_const) {
  331. const_id = NegateConstant(const_mgr, c);
  332. } else {
  333. const_id = zero_is_variable ? op_inst->GetSingleWordInOperand(1u)
  334. : op_inst->GetSingleWordInOperand(0u);
  335. }
  336. // Swap operands if necessary and make the instruction a subtraction.
  337. uint32_t op0 =
  338. zero_is_variable ? op_inst->GetSingleWordInOperand(0u) : const_id;
  339. uint32_t op1 =
  340. zero_is_variable ? const_id : op_inst->GetSingleWordInOperand(1u);
  341. if (swap_operands) std::swap(op0, op1);
  342. inst->SetOpcode(HasFloatingPoint(type) ? SpvOpFSub : SpvOpISub);
  343. inst->SetInOperands(
  344. {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
  345. return true;
  346. }
  347. }
  348. return false;
  349. };
  350. }
  351. // Returns true if |c| has a zero element.
  352. bool HasZero(const analysis::Constant* c) {
  353. if (c->AsNullConstant()) {
  354. return true;
  355. }
  356. if (const analysis::VectorConstant* vec_const = c->AsVectorConstant()) {
  357. for (auto& comp : vec_const->GetComponents())
  358. if (HasZero(comp)) return true;
  359. } else {
  360. assert(c->AsScalarConstant());
  361. return c->AsScalarConstant()->IsZero();
  362. }
  363. return false;
  364. }
  365. // Performs |input1| |opcode| |input2| and returns the merged constant result
  366. // id. Returns 0 if the result is not a valid value. The input types must be
  367. // Float.
  368. uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr,
  369. SpvOp opcode,
  370. const analysis::Constant* input1,
  371. const analysis::Constant* input2) {
  372. const analysis::Type* type = input1->type();
  373. assert(type->AsFloat());
  374. uint32_t width = type->AsFloat()->width();
  375. assert(width == 32 || width == 64);
  376. std::vector<uint32_t> words;
  377. #define FOLD_OP(op) \
  378. if (width == 64) { \
  379. utils::FloatProxy<double> val = \
  380. input1->GetDouble() op input2->GetDouble(); \
  381. double dval = val.getAsFloat(); \
  382. if (!IsValidResult(dval)) return 0; \
  383. words = val.GetWords(); \
  384. } else { \
  385. utils::FloatProxy<float> val = input1->GetFloat() op input2->GetFloat(); \
  386. float fval = val.getAsFloat(); \
  387. if (!IsValidResult(fval)) return 0; \
  388. words = val.GetWords(); \
  389. }
  390. switch (opcode) {
  391. case SpvOpFMul:
  392. FOLD_OP(*);
  393. break;
  394. case SpvOpFDiv:
  395. if (HasZero(input2)) return 0;
  396. FOLD_OP(/);
  397. break;
  398. case SpvOpFAdd:
  399. FOLD_OP(+);
  400. break;
  401. case SpvOpFSub:
  402. FOLD_OP(-);
  403. break;
  404. default:
  405. assert(false && "Unexpected operation");
  406. break;
  407. }
  408. #undef FOLD_OP
  409. const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
  410. return const_mgr->GetDefiningInstruction(merged_const)->result_id();
  411. }
  412. // Performs |input1| |opcode| |input2| and returns the merged constant result
  413. // id. Returns 0 if the result is not a valid value. The input types must be
  414. // Integers.
  415. uint32_t PerformIntegerOperation(analysis::ConstantManager* const_mgr,
  416. SpvOp opcode, const analysis::Constant* input1,
  417. const analysis::Constant* input2) {
  418. assert(input1->type()->AsInteger());
  419. const analysis::Integer* type = input1->type()->AsInteger();
  420. uint32_t width = type->AsInteger()->width();
  421. assert(width == 32 || width == 64);
  422. std::vector<uint32_t> words;
  423. #define FOLD_OP(op) \
  424. if (width == 64) { \
  425. if (type->IsSigned()) { \
  426. int64_t val = input1->GetS64() op input2->GetS64(); \
  427. words = ExtractInts(static_cast<uint64_t>(val)); \
  428. } else { \
  429. uint64_t val = input1->GetU64() op input2->GetU64(); \
  430. words = ExtractInts(val); \
  431. } \
  432. } else { \
  433. if (type->IsSigned()) { \
  434. int32_t val = input1->GetS32() op input2->GetS32(); \
  435. words.push_back(static_cast<uint32_t>(val)); \
  436. } else { \
  437. uint32_t val = input1->GetU32() op input2->GetU32(); \
  438. words.push_back(val); \
  439. } \
  440. }
  441. switch (opcode) {
  442. case SpvOpIMul:
  443. FOLD_OP(*);
  444. break;
  445. case SpvOpSDiv:
  446. case SpvOpUDiv:
  447. assert(false && "Should not merge integer division");
  448. break;
  449. case SpvOpIAdd:
  450. FOLD_OP(+);
  451. break;
  452. case SpvOpISub:
  453. FOLD_OP(-);
  454. break;
  455. default:
  456. assert(false && "Unexpected operation");
  457. break;
  458. }
  459. #undef FOLD_OP
  460. const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
  461. return const_mgr->GetDefiningInstruction(merged_const)->result_id();
  462. }
  463. // Performs |input1| |opcode| |input2| and returns the merged constant result
  464. // id. Returns 0 if the result is not a valid value. The input types must be
  465. // Integers, Floats or Vectors of such.
  466. uint32_t PerformOperation(analysis::ConstantManager* const_mgr, SpvOp opcode,
  467. const analysis::Constant* input1,
  468. const analysis::Constant* input2) {
  469. assert(input1 && input2);
  470. assert(input1->type() == input2->type());
  471. const analysis::Type* type = input1->type();
  472. std::vector<uint32_t> words;
  473. if (const analysis::Vector* vector_type = type->AsVector()) {
  474. const analysis::Type* ele_type = vector_type->element_type();
  475. for (uint32_t i = 0; i != vector_type->element_count(); ++i) {
  476. uint32_t id = 0;
  477. const analysis::Constant* input1_comp = nullptr;
  478. if (const analysis::VectorConstant* input1_vector =
  479. input1->AsVectorConstant()) {
  480. input1_comp = input1_vector->GetComponents()[i];
  481. } else {
  482. assert(input1->AsNullConstant());
  483. input1_comp = const_mgr->GetConstant(ele_type, {});
  484. }
  485. const analysis::Constant* input2_comp = nullptr;
  486. if (const analysis::VectorConstant* input2_vector =
  487. input2->AsVectorConstant()) {
  488. input2_comp = input2_vector->GetComponents()[i];
  489. } else {
  490. assert(input2->AsNullConstant());
  491. input2_comp = const_mgr->GetConstant(ele_type, {});
  492. }
  493. if (ele_type->AsFloat()) {
  494. id = PerformFloatingPointOperation(const_mgr, opcode, input1_comp,
  495. input2_comp);
  496. } else {
  497. assert(ele_type->AsInteger());
  498. id = PerformIntegerOperation(const_mgr, opcode, input1_comp,
  499. input2_comp);
  500. }
  501. if (id == 0) return 0;
  502. words.push_back(id);
  503. }
  504. const analysis::Constant* merged_const =
  505. const_mgr->GetConstant(type, words);
  506. return const_mgr->GetDefiningInstruction(merged_const)->result_id();
  507. } else if (type->AsFloat()) {
  508. return PerformFloatingPointOperation(const_mgr, opcode, input1, input2);
  509. } else {
  510. assert(type->AsInteger());
  511. return PerformIntegerOperation(const_mgr, opcode, input1, input2);
  512. }
  513. }
  514. // Merges consecutive multiplies where each contains one constant operand.
  515. // Cases:
  516. // 2 * (x * 2) = x * 4
  517. // 2 * (2 * x) = x * 4
  518. // (x * 2) * 2 = x * 4
  519. // (2 * x) * 2 = x * 4
  520. FoldingRule MergeMulMulArithmetic() {
  521. return [](IRContext* context, Instruction* inst,
  522. const std::vector<const analysis::Constant*>& constants) {
  523. assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
  524. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  525. const analysis::Type* type =
  526. context->get_type_mgr()->GetType(inst->type_id());
  527. if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
  528. return false;
  529. uint32_t width = ElementWidth(type);
  530. if (width != 32 && width != 64) return false;
  531. // Determine the constant input and the variable input in |inst|.
  532. const analysis::Constant* const_input1 = ConstInput(constants);
  533. if (!const_input1) return false;
  534. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  535. if (HasFloatingPoint(type) && !other_inst->IsFloatingPointFoldingAllowed())
  536. return false;
  537. if (other_inst->opcode() == inst->opcode()) {
  538. std::vector<const analysis::Constant*> other_constants =
  539. const_mgr->GetOperandConstants(other_inst);
  540. const analysis::Constant* const_input2 = ConstInput(other_constants);
  541. if (!const_input2) return false;
  542. bool other_first_is_variable = other_constants[0] == nullptr;
  543. uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
  544. const_input1, const_input2);
  545. if (merged_id == 0) return false;
  546. uint32_t non_const_id = other_first_is_variable
  547. ? other_inst->GetSingleWordInOperand(0u)
  548. : other_inst->GetSingleWordInOperand(1u);
  549. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
  550. {SPV_OPERAND_TYPE_ID, {merged_id}}});
  551. return true;
  552. }
  553. return false;
  554. };
  555. }
  556. // Merges divides into subsequent multiplies if each instruction contains one
  557. // constant operand. Does not support integer operations.
  558. // Cases:
  559. // 2 * (x / 2) = x * 1
  560. // 2 * (2 / x) = 4 / x
  561. // (x / 2) * 2 = x * 1
  562. // (2 / x) * 2 = 4 / x
  563. // (y / x) * x = y
  564. // x * (y / x) = y
  565. FoldingRule MergeMulDivArithmetic() {
  566. return [](IRContext* context, Instruction* inst,
  567. const std::vector<const analysis::Constant*>& constants) {
  568. assert(inst->opcode() == SpvOpFMul);
  569. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  570. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  571. const analysis::Type* type =
  572. context->get_type_mgr()->GetType(inst->type_id());
  573. if (!inst->IsFloatingPointFoldingAllowed()) return false;
  574. uint32_t width = ElementWidth(type);
  575. if (width != 32 && width != 64) return false;
  576. for (uint32_t i = 0; i < 2; i++) {
  577. uint32_t op_id = inst->GetSingleWordInOperand(i);
  578. Instruction* op_inst = def_use_mgr->GetDef(op_id);
  579. if (op_inst->opcode() == SpvOpFDiv) {
  580. if (op_inst->GetSingleWordInOperand(1) ==
  581. inst->GetSingleWordInOperand(1 - i)) {
  582. inst->SetOpcode(SpvOpCopyObject);
  583. inst->SetInOperands(
  584. {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0)}}});
  585. return true;
  586. }
  587. }
  588. }
  589. const analysis::Constant* const_input1 = ConstInput(constants);
  590. if (!const_input1) return false;
  591. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  592. if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
  593. if (other_inst->opcode() == SpvOpFDiv) {
  594. std::vector<const analysis::Constant*> other_constants =
  595. const_mgr->GetOperandConstants(other_inst);
  596. const analysis::Constant* const_input2 = ConstInput(other_constants);
  597. if (!const_input2 || HasZero(const_input2)) return false;
  598. bool other_first_is_variable = other_constants[0] == nullptr;
  599. // If the variable value is the second operand of the divide, multiply
  600. // the constants together. Otherwise divide the constants.
  601. uint32_t merged_id = PerformOperation(
  602. const_mgr,
  603. other_first_is_variable ? other_inst->opcode() : inst->opcode(),
  604. const_input1, const_input2);
  605. if (merged_id == 0) return false;
  606. uint32_t non_const_id = other_first_is_variable
  607. ? other_inst->GetSingleWordInOperand(0u)
  608. : other_inst->GetSingleWordInOperand(1u);
  609. // If the variable value is on the second operand of the div, then this
  610. // operation is a div. Otherwise it should be a multiply.
  611. inst->SetOpcode(other_first_is_variable ? inst->opcode()
  612. : other_inst->opcode());
  613. if (other_first_is_variable) {
  614. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
  615. {SPV_OPERAND_TYPE_ID, {merged_id}}});
  616. } else {
  617. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {merged_id}},
  618. {SPV_OPERAND_TYPE_ID, {non_const_id}}});
  619. }
  620. return true;
  621. }
  622. return false;
  623. };
  624. }
  625. // Merges multiply of constant and negation.
  626. // Cases:
  627. // (-x) * 2 = x * -2
  628. // 2 * (-x) = x * -2
  629. FoldingRule MergeMulNegateArithmetic() {
  630. return [](IRContext* context, Instruction* inst,
  631. const std::vector<const analysis::Constant*>& constants) {
  632. assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
  633. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  634. const analysis::Type* type =
  635. context->get_type_mgr()->GetType(inst->type_id());
  636. bool uses_float = HasFloatingPoint(type);
  637. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  638. uint32_t width = ElementWidth(type);
  639. if (width != 32 && width != 64) return false;
  640. const analysis::Constant* const_input1 = ConstInput(constants);
  641. if (!const_input1) return false;
  642. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  643. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  644. return false;
  645. if (other_inst->opcode() == SpvOpFNegate ||
  646. other_inst->opcode() == SpvOpSNegate) {
  647. uint32_t neg_id = NegateConstant(const_mgr, const_input1);
  648. inst->SetInOperands(
  649. {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
  650. {SPV_OPERAND_TYPE_ID, {neg_id}}});
  651. return true;
  652. }
  653. return false;
  654. };
  655. }
  656. // Merges consecutive divides if each instruction contains one constant operand.
  657. // Does not support integer division.
  658. // Cases:
  659. // 2 / (x / 2) = 4 / x
  660. // 4 / (2 / x) = 2 * x
  661. // (4 / x) / 2 = 2 / x
  662. // (x / 2) / 2 = x / 4
  663. FoldingRule MergeDivDivArithmetic() {
  664. return [](IRContext* context, Instruction* inst,
  665. const std::vector<const analysis::Constant*>& constants) {
  666. assert(inst->opcode() == SpvOpFDiv);
  667. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  668. const analysis::Type* type =
  669. context->get_type_mgr()->GetType(inst->type_id());
  670. if (!inst->IsFloatingPointFoldingAllowed()) return false;
  671. uint32_t width = ElementWidth(type);
  672. if (width != 32 && width != 64) return false;
  673. const analysis::Constant* const_input1 = ConstInput(constants);
  674. if (!const_input1 || HasZero(const_input1)) return false;
  675. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  676. if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
  677. bool first_is_variable = constants[0] == nullptr;
  678. if (other_inst->opcode() == inst->opcode()) {
  679. std::vector<const analysis::Constant*> other_constants =
  680. const_mgr->GetOperandConstants(other_inst);
  681. const analysis::Constant* const_input2 = ConstInput(other_constants);
  682. if (!const_input2 || HasZero(const_input2)) return false;
  683. bool other_first_is_variable = other_constants[0] == nullptr;
  684. SpvOp merge_op = inst->opcode();
  685. if (other_first_is_variable) {
  686. // Constants magnify.
  687. merge_op = SpvOpFMul;
  688. }
  689. // This is an x / (*) case. Swap the inputs. Doesn't harm multiply
  690. // because it is commutative.
  691. if (first_is_variable) std::swap(const_input1, const_input2);
  692. uint32_t merged_id =
  693. PerformOperation(const_mgr, merge_op, const_input1, const_input2);
  694. if (merged_id == 0) return false;
  695. uint32_t non_const_id = other_first_is_variable
  696. ? other_inst->GetSingleWordInOperand(0u)
  697. : other_inst->GetSingleWordInOperand(1u);
  698. SpvOp op = inst->opcode();
  699. if (!first_is_variable && !other_first_is_variable) {
  700. // Effectively div of 1/x, so change to multiply.
  701. op = SpvOpFMul;
  702. }
  703. uint32_t op1 = merged_id;
  704. uint32_t op2 = non_const_id;
  705. if (first_is_variable && other_first_is_variable) std::swap(op1, op2);
  706. inst->SetOpcode(op);
  707. inst->SetInOperands(
  708. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  709. return true;
  710. }
  711. return false;
  712. };
  713. }
  714. // Fold multiplies succeeded by divides where each instruction contains a
  715. // constant operand. Does not support integer divide.
  716. // Cases:
  717. // 4 / (x * 2) = 2 / x
  718. // 4 / (2 * x) = 2 / x
  719. // (x * 4) / 2 = x * 2
  720. // (4 * x) / 2 = x * 2
  721. // (x * y) / x = y
  722. // (y * x) / x = y
  723. FoldingRule MergeDivMulArithmetic() {
  724. return [](IRContext* context, Instruction* inst,
  725. const std::vector<const analysis::Constant*>& constants) {
  726. assert(inst->opcode() == SpvOpFDiv);
  727. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  728. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  729. const analysis::Type* type =
  730. context->get_type_mgr()->GetType(inst->type_id());
  731. if (!inst->IsFloatingPointFoldingAllowed()) return false;
  732. uint32_t width = ElementWidth(type);
  733. if (width != 32 && width != 64) return false;
  734. uint32_t op_id = inst->GetSingleWordInOperand(0);
  735. Instruction* op_inst = def_use_mgr->GetDef(op_id);
  736. if (op_inst->opcode() == SpvOpFMul) {
  737. for (uint32_t i = 0; i < 2; i++) {
  738. if (op_inst->GetSingleWordInOperand(i) ==
  739. inst->GetSingleWordInOperand(1)) {
  740. inst->SetOpcode(SpvOpCopyObject);
  741. inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
  742. {op_inst->GetSingleWordInOperand(1 - i)}}});
  743. return true;
  744. }
  745. }
  746. }
  747. const analysis::Constant* const_input1 = ConstInput(constants);
  748. if (!const_input1 || HasZero(const_input1)) return false;
  749. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  750. if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
  751. bool first_is_variable = constants[0] == nullptr;
  752. if (other_inst->opcode() == SpvOpFMul) {
  753. std::vector<const analysis::Constant*> other_constants =
  754. const_mgr->GetOperandConstants(other_inst);
  755. const analysis::Constant* const_input2 = ConstInput(other_constants);
  756. if (!const_input2) return false;
  757. bool other_first_is_variable = other_constants[0] == nullptr;
  758. // This is an x / (*) case. Swap the inputs.
  759. if (first_is_variable) std::swap(const_input1, const_input2);
  760. uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
  761. const_input1, const_input2);
  762. if (merged_id == 0) return false;
  763. uint32_t non_const_id = other_first_is_variable
  764. ? other_inst->GetSingleWordInOperand(0u)
  765. : other_inst->GetSingleWordInOperand(1u);
  766. uint32_t op1 = merged_id;
  767. uint32_t op2 = non_const_id;
  768. if (first_is_variable) std::swap(op1, op2);
  769. // Convert to multiply
  770. if (first_is_variable) inst->SetOpcode(other_inst->opcode());
  771. inst->SetInOperands(
  772. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  773. return true;
  774. }
  775. return false;
  776. };
  777. }
  778. // Fold divides of a constant and a negation.
  779. // Cases:
  780. // (-x) / 2 = x / -2
  781. // 2 / (-x) = 2 / -x
  782. FoldingRule MergeDivNegateArithmetic() {
  783. return [](IRContext* context, Instruction* inst,
  784. const std::vector<const analysis::Constant*>& constants) {
  785. assert(inst->opcode() == SpvOpFDiv || inst->opcode() == SpvOpSDiv ||
  786. inst->opcode() == SpvOpUDiv);
  787. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  788. const analysis::Type* type =
  789. context->get_type_mgr()->GetType(inst->type_id());
  790. bool uses_float = HasFloatingPoint(type);
  791. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  792. uint32_t width = ElementWidth(type);
  793. if (width != 32 && width != 64) return false;
  794. const analysis::Constant* const_input1 = ConstInput(constants);
  795. if (!const_input1) return false;
  796. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  797. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  798. return false;
  799. bool first_is_variable = constants[0] == nullptr;
  800. if (other_inst->opcode() == SpvOpFNegate ||
  801. other_inst->opcode() == SpvOpSNegate) {
  802. uint32_t neg_id = NegateConstant(const_mgr, const_input1);
  803. if (first_is_variable) {
  804. inst->SetInOperands(
  805. {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
  806. {SPV_OPERAND_TYPE_ID, {neg_id}}});
  807. } else {
  808. inst->SetInOperands(
  809. {{SPV_OPERAND_TYPE_ID, {neg_id}},
  810. {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
  811. }
  812. return true;
  813. }
  814. return false;
  815. };
  816. }
  817. // Folds addition of a constant and a negation.
  818. // Cases:
  819. // (-x) + 2 = 2 - x
  820. // 2 + (-x) = 2 - x
  821. FoldingRule MergeAddNegateArithmetic() {
  822. return [](IRContext* context, Instruction* inst,
  823. const std::vector<const analysis::Constant*>& constants) {
  824. assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
  825. const analysis::Type* type =
  826. context->get_type_mgr()->GetType(inst->type_id());
  827. bool uses_float = HasFloatingPoint(type);
  828. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  829. const analysis::Constant* const_input1 = ConstInput(constants);
  830. if (!const_input1) return false;
  831. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  832. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  833. return false;
  834. if (other_inst->opcode() == SpvOpSNegate ||
  835. other_inst->opcode() == SpvOpFNegate) {
  836. inst->SetOpcode(HasFloatingPoint(type) ? SpvOpFSub : SpvOpISub);
  837. uint32_t const_id = constants[0] ? inst->GetSingleWordInOperand(0u)
  838. : inst->GetSingleWordInOperand(1u);
  839. inst->SetInOperands(
  840. {{SPV_OPERAND_TYPE_ID, {const_id}},
  841. {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
  842. return true;
  843. }
  844. return false;
  845. };
  846. }
  847. // Folds subtraction of a constant and a negation.
  848. // Cases:
  849. // (-x) - 2 = -2 - x
  850. // 2 - (-x) = x + 2
  851. FoldingRule MergeSubNegateArithmetic() {
  852. return [](IRContext* context, Instruction* inst,
  853. const std::vector<const analysis::Constant*>& constants) {
  854. assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
  855. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  856. const analysis::Type* type =
  857. context->get_type_mgr()->GetType(inst->type_id());
  858. bool uses_float = HasFloatingPoint(type);
  859. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  860. uint32_t width = ElementWidth(type);
  861. if (width != 32 && width != 64) return false;
  862. const analysis::Constant* const_input1 = ConstInput(constants);
  863. if (!const_input1) return false;
  864. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  865. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  866. return false;
  867. if (other_inst->opcode() == SpvOpSNegate ||
  868. other_inst->opcode() == SpvOpFNegate) {
  869. uint32_t op1 = 0;
  870. uint32_t op2 = 0;
  871. SpvOp opcode = inst->opcode();
  872. if (constants[0] != nullptr) {
  873. op1 = other_inst->GetSingleWordInOperand(0u);
  874. op2 = inst->GetSingleWordInOperand(0u);
  875. opcode = HasFloatingPoint(type) ? SpvOpFAdd : SpvOpIAdd;
  876. } else {
  877. op1 = NegateConstant(const_mgr, const_input1);
  878. op2 = other_inst->GetSingleWordInOperand(0u);
  879. }
  880. inst->SetOpcode(opcode);
  881. inst->SetInOperands(
  882. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  883. return true;
  884. }
  885. return false;
  886. };
  887. }
  888. // Folds addition of an addition where each operation has a constant operand.
  889. // Cases:
  890. // (x + 2) + 2 = x + 4
  891. // (2 + x) + 2 = x + 4
  892. // 2 + (x + 2) = x + 4
  893. // 2 + (2 + x) = x + 4
  894. FoldingRule MergeAddAddArithmetic() {
  895. return [](IRContext* context, Instruction* inst,
  896. const std::vector<const analysis::Constant*>& constants) {
  897. assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
  898. const analysis::Type* type =
  899. context->get_type_mgr()->GetType(inst->type_id());
  900. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  901. bool uses_float = HasFloatingPoint(type);
  902. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  903. uint32_t width = ElementWidth(type);
  904. if (width != 32 && width != 64) return false;
  905. const analysis::Constant* const_input1 = ConstInput(constants);
  906. if (!const_input1) return false;
  907. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  908. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  909. return false;
  910. if (other_inst->opcode() == SpvOpFAdd ||
  911. other_inst->opcode() == SpvOpIAdd) {
  912. std::vector<const analysis::Constant*> other_constants =
  913. const_mgr->GetOperandConstants(other_inst);
  914. const analysis::Constant* const_input2 = ConstInput(other_constants);
  915. if (!const_input2) return false;
  916. Instruction* non_const_input =
  917. NonConstInput(context, other_constants[0], other_inst);
  918. uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
  919. const_input1, const_input2);
  920. if (merged_id == 0) return false;
  921. inst->SetInOperands(
  922. {{SPV_OPERAND_TYPE_ID, {non_const_input->result_id()}},
  923. {SPV_OPERAND_TYPE_ID, {merged_id}}});
  924. return true;
  925. }
  926. return false;
  927. };
  928. }
  929. // Folds addition of a subtraction where each operation has a constant operand.
  930. // Cases:
  931. // (x - 2) + 2 = x + 0
  932. // (2 - x) + 2 = 4 - x
  933. // 2 + (x - 2) = x + 0
  934. // 2 + (2 - x) = 4 - x
  935. FoldingRule MergeAddSubArithmetic() {
  936. return [](IRContext* context, Instruction* inst,
  937. const std::vector<const analysis::Constant*>& constants) {
  938. assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
  939. const analysis::Type* type =
  940. context->get_type_mgr()->GetType(inst->type_id());
  941. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  942. bool uses_float = HasFloatingPoint(type);
  943. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  944. uint32_t width = ElementWidth(type);
  945. if (width != 32 && width != 64) return false;
  946. const analysis::Constant* const_input1 = ConstInput(constants);
  947. if (!const_input1) return false;
  948. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  949. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  950. return false;
  951. if (other_inst->opcode() == SpvOpFSub ||
  952. other_inst->opcode() == SpvOpISub) {
  953. std::vector<const analysis::Constant*> other_constants =
  954. const_mgr->GetOperandConstants(other_inst);
  955. const analysis::Constant* const_input2 = ConstInput(other_constants);
  956. if (!const_input2) return false;
  957. bool first_is_variable = other_constants[0] == nullptr;
  958. SpvOp op = inst->opcode();
  959. uint32_t op1 = 0;
  960. uint32_t op2 = 0;
  961. if (first_is_variable) {
  962. // Subtract constants. Non-constant operand is first.
  963. op1 = other_inst->GetSingleWordInOperand(0u);
  964. op2 = PerformOperation(const_mgr, other_inst->opcode(), const_input1,
  965. const_input2);
  966. } else {
  967. // Add constants. Constant operand is first. Change the opcode.
  968. op1 = PerformOperation(const_mgr, inst->opcode(), const_input1,
  969. const_input2);
  970. op2 = other_inst->GetSingleWordInOperand(1u);
  971. op = other_inst->opcode();
  972. }
  973. if (op1 == 0 || op2 == 0) return false;
  974. inst->SetOpcode(op);
  975. inst->SetInOperands(
  976. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  977. return true;
  978. }
  979. return false;
  980. };
  981. }
  982. // Folds subtraction of an addition where each operand has a constant operand.
  983. // Cases:
  984. // (x + 2) - 2 = x + 0
  985. // (2 + x) - 2 = x + 0
  986. // 2 - (x + 2) = 0 - x
  987. // 2 - (2 + x) = 0 - x
  988. FoldingRule MergeSubAddArithmetic() {
  989. return [](IRContext* context, Instruction* inst,
  990. const std::vector<const analysis::Constant*>& constants) {
  991. assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
  992. const analysis::Type* type =
  993. context->get_type_mgr()->GetType(inst->type_id());
  994. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  995. bool uses_float = HasFloatingPoint(type);
  996. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  997. uint32_t width = ElementWidth(type);
  998. if (width != 32 && width != 64) return false;
  999. const analysis::Constant* const_input1 = ConstInput(constants);
  1000. if (!const_input1) return false;
  1001. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  1002. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  1003. return false;
  1004. if (other_inst->opcode() == SpvOpFAdd ||
  1005. other_inst->opcode() == SpvOpIAdd) {
  1006. std::vector<const analysis::Constant*> other_constants =
  1007. const_mgr->GetOperandConstants(other_inst);
  1008. const analysis::Constant* const_input2 = ConstInput(other_constants);
  1009. if (!const_input2) return false;
  1010. Instruction* non_const_input =
  1011. NonConstInput(context, other_constants[0], other_inst);
  1012. // If the first operand of the sub is not a constant, swap the constants
  1013. // so the subtraction has the correct operands.
  1014. if (constants[0] == nullptr) std::swap(const_input1, const_input2);
  1015. // Subtract the constants.
  1016. uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
  1017. const_input1, const_input2);
  1018. SpvOp op = inst->opcode();
  1019. uint32_t op1 = 0;
  1020. uint32_t op2 = 0;
  1021. if (constants[0] == nullptr) {
  1022. // Non-constant operand is first. Change the opcode.
  1023. op1 = non_const_input->result_id();
  1024. op2 = merged_id;
  1025. op = other_inst->opcode();
  1026. } else {
  1027. // Constant operand is first.
  1028. op1 = merged_id;
  1029. op2 = non_const_input->result_id();
  1030. }
  1031. if (op1 == 0 || op2 == 0) return false;
  1032. inst->SetOpcode(op);
  1033. inst->SetInOperands(
  1034. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  1035. return true;
  1036. }
  1037. return false;
  1038. };
  1039. }
  1040. // Folds subtraction of a subtraction where each operand has a constant operand.
  1041. // Cases:
  1042. // (x - 2) - 2 = x - 4
  1043. // (2 - x) - 2 = 0 - x
  1044. // 2 - (x - 2) = 4 - x
  1045. // 2 - (2 - x) = x + 0
  1046. FoldingRule MergeSubSubArithmetic() {
  1047. return [](IRContext* context, Instruction* inst,
  1048. const std::vector<const analysis::Constant*>& constants) {
  1049. assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
  1050. const analysis::Type* type =
  1051. context->get_type_mgr()->GetType(inst->type_id());
  1052. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  1053. bool uses_float = HasFloatingPoint(type);
  1054. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  1055. uint32_t width = ElementWidth(type);
  1056. if (width != 32 && width != 64) return false;
  1057. const analysis::Constant* const_input1 = ConstInput(constants);
  1058. if (!const_input1) return false;
  1059. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  1060. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  1061. return false;
  1062. if (other_inst->opcode() == SpvOpFSub ||
  1063. other_inst->opcode() == SpvOpISub) {
  1064. std::vector<const analysis::Constant*> other_constants =
  1065. const_mgr->GetOperandConstants(other_inst);
  1066. const analysis::Constant* const_input2 = ConstInput(other_constants);
  1067. if (!const_input2) return false;
  1068. Instruction* non_const_input =
  1069. NonConstInput(context, other_constants[0], other_inst);
  1070. // Merge the constants.
  1071. uint32_t merged_id = 0;
  1072. SpvOp merge_op = inst->opcode();
  1073. if (other_constants[0] == nullptr) {
  1074. merge_op = uses_float ? SpvOpFAdd : SpvOpIAdd;
  1075. } else if (constants[0] == nullptr) {
  1076. std::swap(const_input1, const_input2);
  1077. }
  1078. merged_id =
  1079. PerformOperation(const_mgr, merge_op, const_input1, const_input2);
  1080. if (merged_id == 0) return false;
  1081. SpvOp op = inst->opcode();
  1082. if (constants[0] != nullptr && other_constants[0] != nullptr) {
  1083. // Change the operation.
  1084. op = uses_float ? SpvOpFAdd : SpvOpIAdd;
  1085. }
  1086. uint32_t op1 = 0;
  1087. uint32_t op2 = 0;
  1088. if ((constants[0] == nullptr) ^ (other_constants[0] == nullptr)) {
  1089. op1 = merged_id;
  1090. op2 = non_const_input->result_id();
  1091. } else {
  1092. op1 = non_const_input->result_id();
  1093. op2 = merged_id;
  1094. }
  1095. inst->SetOpcode(op);
  1096. inst->SetInOperands(
  1097. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  1098. return true;
  1099. }
  1100. return false;
  1101. };
  1102. }
  1103. FoldingRule IntMultipleBy1() {
  1104. return [](IRContext*, Instruction* inst,
  1105. const std::vector<const analysis::Constant*>& constants) {
  1106. assert(inst->opcode() == SpvOpIMul && "Wrong opcode. Should be OpIMul.");
  1107. for (uint32_t i = 0; i < 2; i++) {
  1108. if (constants[i] == nullptr) {
  1109. continue;
  1110. }
  1111. const analysis::IntConstant* int_constant = constants[i]->AsIntConstant();
  1112. if (int_constant) {
  1113. uint32_t width = ElementWidth(int_constant->type());
  1114. if (width != 32 && width != 64) return false;
  1115. bool is_one = (width == 32) ? int_constant->GetU32BitValue() == 1u
  1116. : int_constant->GetU64BitValue() == 1ull;
  1117. if (is_one) {
  1118. inst->SetOpcode(SpvOpCopyObject);
  1119. inst->SetInOperands(
  1120. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1 - i)}}});
  1121. return true;
  1122. }
  1123. }
  1124. }
  1125. return false;
  1126. };
  1127. }
  1128. FoldingRule CompositeConstructFeedingExtract() {
  1129. return [](IRContext* context, Instruction* inst,
  1130. const std::vector<const analysis::Constant*>&) {
  1131. // If the input to an OpCompositeExtract is an OpCompositeConstruct,
  1132. // then we can simply use the appropriate element in the construction.
  1133. assert(inst->opcode() == SpvOpCompositeExtract &&
  1134. "Wrong opcode. Should be OpCompositeExtract.");
  1135. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1136. analysis::TypeManager* type_mgr = context->get_type_mgr();
  1137. uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1138. Instruction* cinst = def_use_mgr->GetDef(cid);
  1139. if (cinst->opcode() != SpvOpCompositeConstruct) {
  1140. return false;
  1141. }
  1142. std::vector<Operand> operands;
  1143. analysis::Type* composite_type = type_mgr->GetType(cinst->type_id());
  1144. if (composite_type->AsVector() == nullptr) {
  1145. // Get the element being extracted from the OpCompositeConstruct
  1146. // Since it is not a vector, it is simple to extract the single element.
  1147. uint32_t element_index = inst->GetSingleWordInOperand(1);
  1148. uint32_t element_id = cinst->GetSingleWordInOperand(element_index);
  1149. operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
  1150. // Add the remaining indices for extraction.
  1151. for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
  1152. operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
  1153. {inst->GetSingleWordInOperand(i)}});
  1154. }
  1155. } else {
  1156. // With vectors we have to handle the case where it is concatenating
  1157. // vectors.
  1158. assert(inst->NumInOperands() == 2 &&
  1159. "Expecting a vector of scalar values.");
  1160. uint32_t element_index = inst->GetSingleWordInOperand(1);
  1161. for (uint32_t construct_index = 0;
  1162. construct_index < cinst->NumInOperands(); ++construct_index) {
  1163. uint32_t element_id = cinst->GetSingleWordInOperand(construct_index);
  1164. Instruction* element_def = def_use_mgr->GetDef(element_id);
  1165. analysis::Vector* element_type =
  1166. type_mgr->GetType(element_def->type_id())->AsVector();
  1167. if (element_type) {
  1168. uint32_t vector_size = element_type->element_count();
  1169. if (vector_size < element_index) {
  1170. // The element we want comes after this vector.
  1171. element_index -= vector_size;
  1172. } else {
  1173. // We want an element of this vector.
  1174. operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
  1175. operands.push_back(
  1176. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {element_index}});
  1177. break;
  1178. }
  1179. } else {
  1180. if (element_index == 0) {
  1181. // This is a scalar, and we this is the element we are extracting.
  1182. operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
  1183. break;
  1184. } else {
  1185. // Skip over this scalar value.
  1186. --element_index;
  1187. }
  1188. }
  1189. }
  1190. }
  1191. // If there were no extra indices, then we have the final object. No need
  1192. // to extract even more.
  1193. if (operands.size() == 1) {
  1194. inst->SetOpcode(SpvOpCopyObject);
  1195. }
  1196. inst->SetInOperands(std::move(operands));
  1197. return true;
  1198. };
  1199. }
  1200. FoldingRule CompositeExtractFeedingConstruct() {
  1201. // If the OpCompositeConstruct is simply putting back together elements that
  1202. // where extracted from the same souce, we can simlpy reuse the source.
  1203. //
  1204. // This is a common code pattern because of the way that scalar replacement
  1205. // works.
  1206. return [](IRContext* context, Instruction* inst,
  1207. const std::vector<const analysis::Constant*>&) {
  1208. assert(inst->opcode() == SpvOpCompositeConstruct &&
  1209. "Wrong opcode. Should be OpCompositeConstruct.");
  1210. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1211. uint32_t original_id = 0;
  1212. // Check each element to make sure they are:
  1213. // - extractions
  1214. // - extracting the same position they are inserting
  1215. // - all extract from the same id.
  1216. for (uint32_t i = 0; i < inst->NumInOperands(); ++i) {
  1217. uint32_t element_id = inst->GetSingleWordInOperand(i);
  1218. Instruction* element_inst = def_use_mgr->GetDef(element_id);
  1219. if (element_inst->opcode() != SpvOpCompositeExtract) {
  1220. return false;
  1221. }
  1222. if (element_inst->NumInOperands() != 2) {
  1223. return false;
  1224. }
  1225. if (element_inst->GetSingleWordInOperand(1) != i) {
  1226. return false;
  1227. }
  1228. if (i == 0) {
  1229. original_id =
  1230. element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1231. } else if (original_id != element_inst->GetSingleWordInOperand(
  1232. kExtractCompositeIdInIdx)) {
  1233. return false;
  1234. }
  1235. }
  1236. // The last check it to see that the object being extracted from is the
  1237. // correct type.
  1238. Instruction* original_inst = def_use_mgr->GetDef(original_id);
  1239. if (original_inst->type_id() != inst->type_id()) {
  1240. return false;
  1241. }
  1242. // Simplify by using the original object.
  1243. inst->SetOpcode(SpvOpCopyObject);
  1244. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}});
  1245. return true;
  1246. };
  1247. }
  1248. FoldingRule InsertFeedingExtract() {
  1249. return [](IRContext* context, Instruction* inst,
  1250. const std::vector<const analysis::Constant*>&) {
  1251. assert(inst->opcode() == SpvOpCompositeExtract &&
  1252. "Wrong opcode. Should be OpCompositeExtract.");
  1253. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1254. uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1255. Instruction* cinst = def_use_mgr->GetDef(cid);
  1256. if (cinst->opcode() != SpvOpCompositeInsert) {
  1257. return false;
  1258. }
  1259. // Find the first position where the list of insert and extract indicies
  1260. // differ, if at all.
  1261. uint32_t i;
  1262. for (i = 1; i < inst->NumInOperands(); ++i) {
  1263. if (i + 1 >= cinst->NumInOperands()) {
  1264. break;
  1265. }
  1266. if (inst->GetSingleWordInOperand(i) !=
  1267. cinst->GetSingleWordInOperand(i + 1)) {
  1268. break;
  1269. }
  1270. }
  1271. // We are extracting the element that was inserted.
  1272. if (i == inst->NumInOperands() && i + 1 == cinst->NumInOperands()) {
  1273. inst->SetOpcode(SpvOpCopyObject);
  1274. inst->SetInOperands(
  1275. {{SPV_OPERAND_TYPE_ID,
  1276. {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}}});
  1277. return true;
  1278. }
  1279. // Extracting the value that was inserted along with values for the base
  1280. // composite. Cannot do anything.
  1281. if (i == inst->NumInOperands()) {
  1282. return false;
  1283. }
  1284. // Extracting an element of the value that was inserted. Extract from
  1285. // that value directly.
  1286. if (i + 1 == cinst->NumInOperands()) {
  1287. std::vector<Operand> operands;
  1288. operands.push_back(
  1289. {SPV_OPERAND_TYPE_ID,
  1290. {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}});
  1291. for (; i < inst->NumInOperands(); ++i) {
  1292. operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
  1293. {inst->GetSingleWordInOperand(i)}});
  1294. }
  1295. inst->SetInOperands(std::move(operands));
  1296. return true;
  1297. }
  1298. // Extracting a value that is disjoint from the element being inserted.
  1299. // Rewrite the extract to use the composite input to the insert.
  1300. std::vector<Operand> operands;
  1301. operands.push_back(
  1302. {SPV_OPERAND_TYPE_ID,
  1303. {cinst->GetSingleWordInOperand(kInsertCompositeIdInIdx)}});
  1304. for (i = 1; i < inst->NumInOperands(); ++i) {
  1305. operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
  1306. {inst->GetSingleWordInOperand(i)}});
  1307. }
  1308. inst->SetInOperands(std::move(operands));
  1309. return true;
  1310. };
  1311. }
  1312. // When a VectorShuffle is feeding an Extract, we can extract from one of the
  1313. // operands of the VectorShuffle. We just need to adjust the index in the
  1314. // extract instruction.
  1315. FoldingRule VectorShuffleFeedingExtract() {
  1316. return [](IRContext* context, Instruction* inst,
  1317. const std::vector<const analysis::Constant*>&) {
  1318. assert(inst->opcode() == SpvOpCompositeExtract &&
  1319. "Wrong opcode. Should be OpCompositeExtract.");
  1320. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1321. analysis::TypeManager* type_mgr = context->get_type_mgr();
  1322. uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1323. Instruction* cinst = def_use_mgr->GetDef(cid);
  1324. if (cinst->opcode() != SpvOpVectorShuffle) {
  1325. return false;
  1326. }
  1327. // Find the size of the first vector operand of the VectorShuffle
  1328. Instruction* first_input =
  1329. def_use_mgr->GetDef(cinst->GetSingleWordInOperand(0));
  1330. analysis::Type* first_input_type =
  1331. type_mgr->GetType(first_input->type_id());
  1332. assert(first_input_type->AsVector() &&
  1333. "Input to vector shuffle should be vectors.");
  1334. uint32_t first_input_size = first_input_type->AsVector()->element_count();
  1335. // Get index of the element the vector shuffle is placing in the position
  1336. // being extracted.
  1337. uint32_t new_index =
  1338. cinst->GetSingleWordInOperand(2 + inst->GetSingleWordInOperand(1));
  1339. // Extracting an undefined value so fold this extract into an undef.
  1340. const uint32_t undef_literal_value = 0xffffffff;
  1341. if (new_index == undef_literal_value) {
  1342. inst->SetOpcode(SpvOpUndef);
  1343. inst->SetInOperands({});
  1344. return true;
  1345. }
  1346. // Get the id of the of the vector the elemtent comes from, and update the
  1347. // index if needed.
  1348. uint32_t new_vector = 0;
  1349. if (new_index < first_input_size) {
  1350. new_vector = cinst->GetSingleWordInOperand(0);
  1351. } else {
  1352. new_vector = cinst->GetSingleWordInOperand(1);
  1353. new_index -= first_input_size;
  1354. }
  1355. // Update the extract instruction.
  1356. inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
  1357. inst->SetInOperand(1, {new_index});
  1358. return true;
  1359. };
  1360. }
  1361. // When an FMix with is feeding an Extract that extracts an element whose
  1362. // corresponding |a| in the FMix is 0 or 1, we can extract from one of the
  1363. // operands of the FMix.
  1364. FoldingRule FMixFeedingExtract() {
  1365. return [](IRContext* context, Instruction* inst,
  1366. const std::vector<const analysis::Constant*>&) {
  1367. assert(inst->opcode() == SpvOpCompositeExtract &&
  1368. "Wrong opcode. Should be OpCompositeExtract.");
  1369. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1370. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  1371. uint32_t composite_id =
  1372. inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1373. Instruction* composite_inst = def_use_mgr->GetDef(composite_id);
  1374. if (composite_inst->opcode() != SpvOpExtInst) {
  1375. return false;
  1376. }
  1377. uint32_t inst_set_id =
  1378. context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  1379. if (composite_inst->GetSingleWordInOperand(kExtInstSetIdInIdx) !=
  1380. inst_set_id ||
  1381. composite_inst->GetSingleWordInOperand(kExtInstInstructionInIdx) !=
  1382. GLSLstd450FMix) {
  1383. return false;
  1384. }
  1385. // Get the |a| for the FMix instruction.
  1386. uint32_t a_id = composite_inst->GetSingleWordInOperand(kFMixAIdInIdx);
  1387. std::unique_ptr<Instruction> a(inst->Clone(context));
  1388. a->SetInOperand(kExtractCompositeIdInIdx, {a_id});
  1389. context->get_instruction_folder().FoldInstruction(a.get());
  1390. if (a->opcode() != SpvOpCopyObject) {
  1391. return false;
  1392. }
  1393. const analysis::Constant* a_const =
  1394. const_mgr->FindDeclaredConstant(a->GetSingleWordInOperand(0));
  1395. if (!a_const) {
  1396. return false;
  1397. }
  1398. bool use_x = false;
  1399. assert(a_const->type()->AsFloat());
  1400. double element_value = a_const->GetValueAsDouble();
  1401. if (element_value == 0.0) {
  1402. use_x = true;
  1403. } else if (element_value == 1.0) {
  1404. use_x = false;
  1405. } else {
  1406. return false;
  1407. }
  1408. // Get the id of the of the vector the element comes from.
  1409. uint32_t new_vector = 0;
  1410. if (use_x) {
  1411. new_vector = composite_inst->GetSingleWordInOperand(kFMixXIdInIdx);
  1412. } else {
  1413. new_vector = composite_inst->GetSingleWordInOperand(kFMixYIdInIdx);
  1414. }
  1415. // Update the extract instruction.
  1416. inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
  1417. return true;
  1418. };
  1419. }
  1420. FoldingRule RedundantPhi() {
  1421. // An OpPhi instruction where all values are the same or the result of the phi
  1422. // itself, can be replaced by the value itself.
  1423. return [](IRContext*, Instruction* inst,
  1424. const std::vector<const analysis::Constant*>&) {
  1425. assert(inst->opcode() == SpvOpPhi && "Wrong opcode. Should be OpPhi.");
  1426. uint32_t incoming_value = 0;
  1427. for (uint32_t i = 0; i < inst->NumInOperands(); i += 2) {
  1428. uint32_t op_id = inst->GetSingleWordInOperand(i);
  1429. if (op_id == inst->result_id()) {
  1430. continue;
  1431. }
  1432. if (incoming_value == 0) {
  1433. incoming_value = op_id;
  1434. } else if (op_id != incoming_value) {
  1435. // Found two possible value. Can't simplify.
  1436. return false;
  1437. }
  1438. }
  1439. if (incoming_value == 0) {
  1440. // Code looks invalid. Don't do anything.
  1441. return false;
  1442. }
  1443. // We have a single incoming value. Simplify using that value.
  1444. inst->SetOpcode(SpvOpCopyObject);
  1445. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {incoming_value}}});
  1446. return true;
  1447. };
  1448. }
  1449. FoldingRule RedundantSelect() {
  1450. // An OpSelect instruction where both values are the same or the condition is
  1451. // constant can be replaced by one of the values
  1452. return [](IRContext*, Instruction* inst,
  1453. const std::vector<const analysis::Constant*>& constants) {
  1454. assert(inst->opcode() == SpvOpSelect &&
  1455. "Wrong opcode. Should be OpSelect.");
  1456. assert(inst->NumInOperands() == 3);
  1457. assert(constants.size() == 3);
  1458. uint32_t true_id = inst->GetSingleWordInOperand(1);
  1459. uint32_t false_id = inst->GetSingleWordInOperand(2);
  1460. if (true_id == false_id) {
  1461. // Both results are the same, condition doesn't matter
  1462. inst->SetOpcode(SpvOpCopyObject);
  1463. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
  1464. return true;
  1465. } else if (constants[0]) {
  1466. const analysis::Type* type = constants[0]->type();
  1467. if (type->AsBool()) {
  1468. // Scalar constant value, select the corresponding value.
  1469. inst->SetOpcode(SpvOpCopyObject);
  1470. if (constants[0]->AsNullConstant() ||
  1471. !constants[0]->AsBoolConstant()->value()) {
  1472. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
  1473. } else {
  1474. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
  1475. }
  1476. return true;
  1477. } else {
  1478. assert(type->AsVector());
  1479. if (constants[0]->AsNullConstant()) {
  1480. // All values come from false id.
  1481. inst->SetOpcode(SpvOpCopyObject);
  1482. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
  1483. return true;
  1484. } else {
  1485. // Convert to a vector shuffle.
  1486. std::vector<Operand> ops;
  1487. ops.push_back({SPV_OPERAND_TYPE_ID, {true_id}});
  1488. ops.push_back({SPV_OPERAND_TYPE_ID, {false_id}});
  1489. const analysis::VectorConstant* vector_const =
  1490. constants[0]->AsVectorConstant();
  1491. uint32_t size =
  1492. static_cast<uint32_t>(vector_const->GetComponents().size());
  1493. for (uint32_t i = 0; i != size; ++i) {
  1494. const analysis::Constant* component =
  1495. vector_const->GetComponents()[i];
  1496. if (component->AsNullConstant() ||
  1497. !component->AsBoolConstant()->value()) {
  1498. // Selecting from the false vector which is the second input
  1499. // vector to the shuffle. Offset the index by |size|.
  1500. ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i + size}});
  1501. } else {
  1502. // Selecting from true vector which is the first input vector to
  1503. // the shuffle.
  1504. ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}});
  1505. }
  1506. }
  1507. inst->SetOpcode(SpvOpVectorShuffle);
  1508. inst->SetInOperands(std::move(ops));
  1509. return true;
  1510. }
  1511. }
  1512. }
  1513. return false;
  1514. };
  1515. }
  1516. enum class FloatConstantKind { Unknown, Zero, One };
  1517. FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) {
  1518. if (constant == nullptr) {
  1519. return FloatConstantKind::Unknown;
  1520. }
  1521. assert(HasFloatingPoint(constant->type()) && "Unexpected constant type");
  1522. if (constant->AsNullConstant()) {
  1523. return FloatConstantKind::Zero;
  1524. } else if (const analysis::VectorConstant* vc =
  1525. constant->AsVectorConstant()) {
  1526. const std::vector<const analysis::Constant*>& components =
  1527. vc->GetComponents();
  1528. assert(!components.empty());
  1529. FloatConstantKind kind = getFloatConstantKind(components[0]);
  1530. for (size_t i = 1; i < components.size(); ++i) {
  1531. if (getFloatConstantKind(components[i]) != kind) {
  1532. return FloatConstantKind::Unknown;
  1533. }
  1534. }
  1535. return kind;
  1536. } else if (const analysis::FloatConstant* fc = constant->AsFloatConstant()) {
  1537. if (fc->IsZero()) return FloatConstantKind::Zero;
  1538. uint32_t width = fc->type()->AsFloat()->width();
  1539. if (width != 32 && width != 64) return FloatConstantKind::Unknown;
  1540. double value = (width == 64) ? fc->GetDoubleValue() : fc->GetFloatValue();
  1541. if (value == 0.0) {
  1542. return FloatConstantKind::Zero;
  1543. } else if (value == 1.0) {
  1544. return FloatConstantKind::One;
  1545. } else {
  1546. return FloatConstantKind::Unknown;
  1547. }
  1548. } else {
  1549. return FloatConstantKind::Unknown;
  1550. }
  1551. }
  1552. FoldingRule RedundantFAdd() {
  1553. return [](IRContext*, Instruction* inst,
  1554. const std::vector<const analysis::Constant*>& constants) {
  1555. assert(inst->opcode() == SpvOpFAdd && "Wrong opcode. Should be OpFAdd.");
  1556. assert(constants.size() == 2);
  1557. if (!inst->IsFloatingPointFoldingAllowed()) {
  1558. return false;
  1559. }
  1560. FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
  1561. FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
  1562. if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
  1563. inst->SetOpcode(SpvOpCopyObject);
  1564. inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
  1565. {inst->GetSingleWordInOperand(
  1566. kind0 == FloatConstantKind::Zero ? 1 : 0)}}});
  1567. return true;
  1568. }
  1569. return false;
  1570. };
  1571. }
  1572. FoldingRule RedundantFSub() {
  1573. return [](IRContext*, Instruction* inst,
  1574. const std::vector<const analysis::Constant*>& constants) {
  1575. assert(inst->opcode() == SpvOpFSub && "Wrong opcode. Should be OpFSub.");
  1576. assert(constants.size() == 2);
  1577. if (!inst->IsFloatingPointFoldingAllowed()) {
  1578. return false;
  1579. }
  1580. FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
  1581. FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
  1582. if (kind0 == FloatConstantKind::Zero) {
  1583. inst->SetOpcode(SpvOpFNegate);
  1584. inst->SetInOperands(
  1585. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1)}}});
  1586. return true;
  1587. }
  1588. if (kind1 == FloatConstantKind::Zero) {
  1589. inst->SetOpcode(SpvOpCopyObject);
  1590. inst->SetInOperands(
  1591. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
  1592. return true;
  1593. }
  1594. return false;
  1595. };
  1596. }
  1597. FoldingRule RedundantFMul() {
  1598. return [](IRContext*, Instruction* inst,
  1599. const std::vector<const analysis::Constant*>& constants) {
  1600. assert(inst->opcode() == SpvOpFMul && "Wrong opcode. Should be OpFMul.");
  1601. assert(constants.size() == 2);
  1602. if (!inst->IsFloatingPointFoldingAllowed()) {
  1603. return false;
  1604. }
  1605. FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
  1606. FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
  1607. if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
  1608. inst->SetOpcode(SpvOpCopyObject);
  1609. inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
  1610. {inst->GetSingleWordInOperand(
  1611. kind0 == FloatConstantKind::Zero ? 0 : 1)}}});
  1612. return true;
  1613. }
  1614. if (kind0 == FloatConstantKind::One || kind1 == FloatConstantKind::One) {
  1615. inst->SetOpcode(SpvOpCopyObject);
  1616. inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
  1617. {inst->GetSingleWordInOperand(
  1618. kind0 == FloatConstantKind::One ? 1 : 0)}}});
  1619. return true;
  1620. }
  1621. return false;
  1622. };
  1623. }
  1624. FoldingRule RedundantFDiv() {
  1625. return [](IRContext*, Instruction* inst,
  1626. const std::vector<const analysis::Constant*>& constants) {
  1627. assert(inst->opcode() == SpvOpFDiv && "Wrong opcode. Should be OpFDiv.");
  1628. assert(constants.size() == 2);
  1629. if (!inst->IsFloatingPointFoldingAllowed()) {
  1630. return false;
  1631. }
  1632. FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
  1633. FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
  1634. if (kind0 == FloatConstantKind::Zero) {
  1635. inst->SetOpcode(SpvOpCopyObject);
  1636. inst->SetInOperands(
  1637. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
  1638. return true;
  1639. }
  1640. if (kind1 == FloatConstantKind::One) {
  1641. inst->SetOpcode(SpvOpCopyObject);
  1642. inst->SetInOperands(
  1643. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
  1644. return true;
  1645. }
  1646. return false;
  1647. };
  1648. }
  1649. FoldingRule RedundantFMix() {
  1650. return [](IRContext* context, Instruction* inst,
  1651. const std::vector<const analysis::Constant*>& constants) {
  1652. assert(inst->opcode() == SpvOpExtInst &&
  1653. "Wrong opcode. Should be OpExtInst.");
  1654. if (!inst->IsFloatingPointFoldingAllowed()) {
  1655. return false;
  1656. }
  1657. uint32_t instSetId =
  1658. context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  1659. if (inst->GetSingleWordInOperand(kExtInstSetIdInIdx) == instSetId &&
  1660. inst->GetSingleWordInOperand(kExtInstInstructionInIdx) ==
  1661. GLSLstd450FMix) {
  1662. assert(constants.size() == 5);
  1663. FloatConstantKind kind4 = getFloatConstantKind(constants[4]);
  1664. if (kind4 == FloatConstantKind::Zero || kind4 == FloatConstantKind::One) {
  1665. inst->SetOpcode(SpvOpCopyObject);
  1666. inst->SetInOperands(
  1667. {{SPV_OPERAND_TYPE_ID,
  1668. {inst->GetSingleWordInOperand(kind4 == FloatConstantKind::Zero
  1669. ? kFMixXIdInIdx
  1670. : kFMixYIdInIdx)}}});
  1671. return true;
  1672. }
  1673. }
  1674. return false;
  1675. };
  1676. }
  1677. // This rule handles addition of zero for integers.
  1678. FoldingRule RedundantIAdd() {
  1679. return [](IRContext* context, Instruction* inst,
  1680. const std::vector<const analysis::Constant*>& constants) {
  1681. assert(inst->opcode() == SpvOpIAdd && "Wrong opcode. Should be OpIAdd.");
  1682. uint32_t operand = std::numeric_limits<uint32_t>::max();
  1683. const analysis::Type* operand_type = nullptr;
  1684. if (constants[0] && constants[0]->IsZero()) {
  1685. operand = inst->GetSingleWordInOperand(1);
  1686. operand_type = constants[0]->type();
  1687. } else if (constants[1] && constants[1]->IsZero()) {
  1688. operand = inst->GetSingleWordInOperand(0);
  1689. operand_type = constants[1]->type();
  1690. }
  1691. if (operand != std::numeric_limits<uint32_t>::max()) {
  1692. const analysis::Type* inst_type =
  1693. context->get_type_mgr()->GetType(inst->type_id());
  1694. if (inst_type->IsSame(operand_type)) {
  1695. inst->SetOpcode(SpvOpCopyObject);
  1696. } else {
  1697. inst->SetOpcode(SpvOpBitcast);
  1698. }
  1699. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {operand}}});
  1700. return true;
  1701. }
  1702. return false;
  1703. };
  1704. }
  1705. // This rule look for a dot with a constant vector containing a single 1 and
  1706. // the rest 0s. This is the same as doing an extract.
  1707. FoldingRule DotProductDoingExtract() {
  1708. return [](IRContext* context, Instruction* inst,
  1709. const std::vector<const analysis::Constant*>& constants) {
  1710. assert(inst->opcode() == SpvOpDot && "Wrong opcode. Should be OpDot.");
  1711. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  1712. if (!inst->IsFloatingPointFoldingAllowed()) {
  1713. return false;
  1714. }
  1715. for (int i = 0; i < 2; ++i) {
  1716. if (!constants[i]) {
  1717. continue;
  1718. }
  1719. const analysis::Vector* vector_type = constants[i]->type()->AsVector();
  1720. assert(vector_type && "Inputs to OpDot must be vectors.");
  1721. const analysis::Float* element_type =
  1722. vector_type->element_type()->AsFloat();
  1723. assert(element_type && "Inputs to OpDot must be vectors of floats.");
  1724. uint32_t element_width = element_type->width();
  1725. if (element_width != 32 && element_width != 64) {
  1726. return false;
  1727. }
  1728. std::vector<const analysis::Constant*> components;
  1729. components = constants[i]->GetVectorComponents(const_mgr);
  1730. const uint32_t kNotFound = std::numeric_limits<uint32_t>::max();
  1731. uint32_t component_with_one = kNotFound;
  1732. bool all_others_zero = true;
  1733. for (uint32_t j = 0; j < components.size(); ++j) {
  1734. const analysis::Constant* element = components[j];
  1735. double value =
  1736. (element_width == 32 ? element->GetFloat() : element->GetDouble());
  1737. if (value == 0.0) {
  1738. continue;
  1739. } else if (value == 1.0) {
  1740. if (component_with_one == kNotFound) {
  1741. component_with_one = j;
  1742. } else {
  1743. component_with_one = kNotFound;
  1744. break;
  1745. }
  1746. } else {
  1747. all_others_zero = false;
  1748. break;
  1749. }
  1750. }
  1751. if (!all_others_zero || component_with_one == kNotFound) {
  1752. continue;
  1753. }
  1754. std::vector<Operand> operands;
  1755. operands.push_back(
  1756. {SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1u - i)}});
  1757. operands.push_back(
  1758. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_with_one}});
  1759. inst->SetOpcode(SpvOpCompositeExtract);
  1760. inst->SetInOperands(std::move(operands));
  1761. return true;
  1762. }
  1763. return false;
  1764. };
  1765. }
  1766. // If we are storing an undef, then we can remove the store.
  1767. //
  1768. // TODO: We can do something similar for OpImageWrite, but checking for volatile
  1769. // is complicated. Waiting to see if it is needed.
  1770. FoldingRule StoringUndef() {
  1771. return [](IRContext* context, Instruction* inst,
  1772. const std::vector<const analysis::Constant*>&) {
  1773. assert(inst->opcode() == SpvOpStore && "Wrong opcode. Should be OpStore.");
  1774. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1775. // If this is a volatile store, the store cannot be removed.
  1776. if (inst->NumInOperands() == 3) {
  1777. if (inst->GetSingleWordInOperand(3) & SpvMemoryAccessVolatileMask) {
  1778. return false;
  1779. }
  1780. }
  1781. uint32_t object_id = inst->GetSingleWordInOperand(kStoreObjectInIdx);
  1782. Instruction* object_inst = def_use_mgr->GetDef(object_id);
  1783. if (object_inst->opcode() == SpvOpUndef) {
  1784. inst->ToNop();
  1785. return true;
  1786. }
  1787. return false;
  1788. };
  1789. }
  1790. FoldingRule VectorShuffleFeedingShuffle() {
  1791. return [](IRContext* context, Instruction* inst,
  1792. const std::vector<const analysis::Constant*>&) {
  1793. assert(inst->opcode() == SpvOpVectorShuffle &&
  1794. "Wrong opcode. Should be OpVectorShuffle.");
  1795. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1796. analysis::TypeManager* type_mgr = context->get_type_mgr();
  1797. Instruction* feeding_shuffle_inst =
  1798. def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
  1799. analysis::Vector* op0_type =
  1800. type_mgr->GetType(feeding_shuffle_inst->type_id())->AsVector();
  1801. uint32_t op0_length = op0_type->element_count();
  1802. bool feeder_is_op0 = true;
  1803. if (feeding_shuffle_inst->opcode() != SpvOpVectorShuffle) {
  1804. feeding_shuffle_inst =
  1805. def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
  1806. feeder_is_op0 = false;
  1807. }
  1808. if (feeding_shuffle_inst->opcode() != SpvOpVectorShuffle) {
  1809. return false;
  1810. }
  1811. Instruction* feeder2 =
  1812. def_use_mgr->GetDef(feeding_shuffle_inst->GetSingleWordInOperand(0));
  1813. analysis::Vector* feeder_op0_type =
  1814. type_mgr->GetType(feeder2->type_id())->AsVector();
  1815. uint32_t feeder_op0_length = feeder_op0_type->element_count();
  1816. uint32_t new_feeder_id = 0;
  1817. std::vector<Operand> new_operands;
  1818. new_operands.resize(
  1819. 2, {SPV_OPERAND_TYPE_ID, {0}}); // Place holders for vector operands.
  1820. const uint32_t undef_literal = 0xffffffff;
  1821. for (uint32_t op = 2; op < inst->NumInOperands(); ++op) {
  1822. uint32_t component_index = inst->GetSingleWordInOperand(op);
  1823. // Do not interpret the undefined value literal as coming from operand 1.
  1824. if (component_index != undef_literal &&
  1825. feeder_is_op0 == (component_index < op0_length)) {
  1826. // This component comes from the feeding_shuffle_inst. Update
  1827. // |component_index| to be the index into the operand of the feeder.
  1828. // Adjust component_index to get the index into the operands of the
  1829. // feeding_shuffle_inst.
  1830. if (component_index >= op0_length) {
  1831. component_index -= op0_length;
  1832. }
  1833. component_index =
  1834. feeding_shuffle_inst->GetSingleWordInOperand(component_index + 2);
  1835. // Check if we are using a component from the first or second operand of
  1836. // the feeding instruction.
  1837. if (component_index < feeder_op0_length) {
  1838. if (new_feeder_id == 0) {
  1839. // First time through, save the id of the operand the element comes
  1840. // from.
  1841. new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(0);
  1842. } else if (new_feeder_id !=
  1843. feeding_shuffle_inst->GetSingleWordInOperand(0)) {
  1844. // We need both elements of the feeding_shuffle_inst, so we cannot
  1845. // fold.
  1846. return false;
  1847. }
  1848. } else {
  1849. if (new_feeder_id == 0) {
  1850. // First time through, save the id of the operand the element comes
  1851. // from.
  1852. new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(1);
  1853. } else if (new_feeder_id !=
  1854. feeding_shuffle_inst->GetSingleWordInOperand(1)) {
  1855. // We need both elements of the feeding_shuffle_inst, so we cannot
  1856. // fold.
  1857. return false;
  1858. }
  1859. component_index -= feeder_op0_length;
  1860. }
  1861. if (!feeder_is_op0) {
  1862. component_index += op0_length;
  1863. }
  1864. }
  1865. new_operands.push_back(
  1866. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_index}});
  1867. }
  1868. if (new_feeder_id == 0) {
  1869. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  1870. const analysis::Type* type =
  1871. type_mgr->GetType(feeding_shuffle_inst->type_id());
  1872. const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
  1873. new_feeder_id =
  1874. const_mgr->GetDefiningInstruction(null_const, 0)->result_id();
  1875. }
  1876. if (feeder_is_op0) {
  1877. // If the size of the first vector operand changed then the indices
  1878. // referring to the second operand need to be adjusted.
  1879. Instruction* new_feeder_inst = def_use_mgr->GetDef(new_feeder_id);
  1880. analysis::Type* new_feeder_type =
  1881. type_mgr->GetType(new_feeder_inst->type_id());
  1882. uint32_t new_op0_size = new_feeder_type->AsVector()->element_count();
  1883. int32_t adjustment = op0_length - new_op0_size;
  1884. if (adjustment != 0) {
  1885. for (uint32_t i = 2; i < new_operands.size(); i++) {
  1886. if (inst->GetSingleWordInOperand(i) >= op0_length) {
  1887. new_operands[i].words[0] -= adjustment;
  1888. }
  1889. }
  1890. }
  1891. new_operands[0].words[0] = new_feeder_id;
  1892. new_operands[1] = inst->GetInOperand(1);
  1893. } else {
  1894. new_operands[1].words[0] = new_feeder_id;
  1895. new_operands[0] = inst->GetInOperand(0);
  1896. }
  1897. inst->SetInOperands(std::move(new_operands));
  1898. return true;
  1899. };
  1900. }
  1901. } // namespace
  1902. FoldingRules::FoldingRules() {
  1903. // Add all folding rules to the list for the opcodes to which they apply.
  1904. // Note that the order in which rules are added to the list matters. If a rule
  1905. // applies to the instruction, the rest of the rules will not be attempted.
  1906. // Take that into consideration.
  1907. rules_[SpvOpCompositeConstruct].push_back(CompositeExtractFeedingConstruct());
  1908. rules_[SpvOpCompositeExtract].push_back(InsertFeedingExtract());
  1909. rules_[SpvOpCompositeExtract].push_back(CompositeConstructFeedingExtract());
  1910. rules_[SpvOpCompositeExtract].push_back(VectorShuffleFeedingExtract());
  1911. rules_[SpvOpCompositeExtract].push_back(FMixFeedingExtract());
  1912. rules_[SpvOpDot].push_back(DotProductDoingExtract());
  1913. rules_[SpvOpExtInst].push_back(RedundantFMix());
  1914. rules_[SpvOpFAdd].push_back(RedundantFAdd());
  1915. rules_[SpvOpFAdd].push_back(MergeAddNegateArithmetic());
  1916. rules_[SpvOpFAdd].push_back(MergeAddAddArithmetic());
  1917. rules_[SpvOpFAdd].push_back(MergeAddSubArithmetic());
  1918. rules_[SpvOpFDiv].push_back(RedundantFDiv());
  1919. rules_[SpvOpFDiv].push_back(ReciprocalFDiv());
  1920. rules_[SpvOpFDiv].push_back(MergeDivDivArithmetic());
  1921. rules_[SpvOpFDiv].push_back(MergeDivMulArithmetic());
  1922. rules_[SpvOpFDiv].push_back(MergeDivNegateArithmetic());
  1923. rules_[SpvOpFMul].push_back(RedundantFMul());
  1924. rules_[SpvOpFMul].push_back(MergeMulMulArithmetic());
  1925. rules_[SpvOpFMul].push_back(MergeMulDivArithmetic());
  1926. rules_[SpvOpFMul].push_back(MergeMulNegateArithmetic());
  1927. rules_[SpvOpFNegate].push_back(MergeNegateArithmetic());
  1928. rules_[SpvOpFNegate].push_back(MergeNegateAddSubArithmetic());
  1929. rules_[SpvOpFNegate].push_back(MergeNegateMulDivArithmetic());
  1930. rules_[SpvOpFSub].push_back(RedundantFSub());
  1931. rules_[SpvOpFSub].push_back(MergeSubNegateArithmetic());
  1932. rules_[SpvOpFSub].push_back(MergeSubAddArithmetic());
  1933. rules_[SpvOpFSub].push_back(MergeSubSubArithmetic());
  1934. rules_[SpvOpIAdd].push_back(RedundantIAdd());
  1935. rules_[SpvOpIAdd].push_back(MergeAddNegateArithmetic());
  1936. rules_[SpvOpIAdd].push_back(MergeAddAddArithmetic());
  1937. rules_[SpvOpIAdd].push_back(MergeAddSubArithmetic());
  1938. rules_[SpvOpIMul].push_back(IntMultipleBy1());
  1939. rules_[SpvOpIMul].push_back(MergeMulMulArithmetic());
  1940. rules_[SpvOpIMul].push_back(MergeMulNegateArithmetic());
  1941. rules_[SpvOpISub].push_back(MergeSubNegateArithmetic());
  1942. rules_[SpvOpISub].push_back(MergeSubAddArithmetic());
  1943. rules_[SpvOpISub].push_back(MergeSubSubArithmetic());
  1944. rules_[SpvOpPhi].push_back(RedundantPhi());
  1945. rules_[SpvOpSDiv].push_back(MergeDivNegateArithmetic());
  1946. rules_[SpvOpSNegate].push_back(MergeNegateArithmetic());
  1947. rules_[SpvOpSNegate].push_back(MergeNegateMulDivArithmetic());
  1948. rules_[SpvOpSNegate].push_back(MergeNegateAddSubArithmetic());
  1949. rules_[SpvOpSelect].push_back(RedundantSelect());
  1950. rules_[SpvOpStore].push_back(StoringUndef());
  1951. rules_[SpvOpUDiv].push_back(MergeDivNegateArithmetic());
  1952. rules_[SpvOpVectorShuffle].push_back(VectorShuffleFeedingShuffle());
  1953. }
  1954. } // namespace opt
  1955. } // namespace spvtools