const_folding_rules.cpp 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853
  1. // Copyright (c) 2018 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opt/const_folding_rules.h"
  15. #include "source/opt/ir_context.h"
  16. namespace spvtools {
  17. namespace opt {
  18. namespace {
  19. const uint32_t kExtractCompositeIdInIdx = 0;
  20. // Returns true if |type| is Float or a vector of Float.
  21. bool HasFloatingPoint(const analysis::Type* type) {
  22. if (type->AsFloat()) {
  23. return true;
  24. } else if (const analysis::Vector* vec_type = type->AsVector()) {
  25. return vec_type->element_type()->AsFloat() != nullptr;
  26. }
  27. return false;
  28. }
  29. // Folds an OpcompositeExtract where input is a composite constant.
  30. ConstantFoldingRule FoldExtractWithConstants() {
  31. return [](IRContext* context, Instruction* inst,
  32. const std::vector<const analysis::Constant*>& constants)
  33. -> const analysis::Constant* {
  34. const analysis::Constant* c = constants[kExtractCompositeIdInIdx];
  35. if (c == nullptr) {
  36. return nullptr;
  37. }
  38. for (uint32_t i = 1; i < inst->NumInOperands(); ++i) {
  39. uint32_t element_index = inst->GetSingleWordInOperand(i);
  40. if (c->AsNullConstant()) {
  41. // Return Null for the return type.
  42. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  43. analysis::TypeManager* type_mgr = context->get_type_mgr();
  44. return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), {});
  45. }
  46. auto cc = c->AsCompositeConstant();
  47. assert(cc != nullptr);
  48. auto components = cc->GetComponents();
  49. c = components[element_index];
  50. }
  51. return c;
  52. };
  53. }
  54. ConstantFoldingRule FoldVectorShuffleWithConstants() {
  55. return [](IRContext* context, Instruction* inst,
  56. const std::vector<const analysis::Constant*>& constants)
  57. -> const analysis::Constant* {
  58. assert(inst->opcode() == SpvOpVectorShuffle);
  59. const analysis::Constant* c1 = constants[0];
  60. const analysis::Constant* c2 = constants[1];
  61. if (c1 == nullptr || c2 == nullptr) {
  62. return nullptr;
  63. }
  64. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  65. const analysis::Type* element_type = c1->type()->AsVector()->element_type();
  66. std::vector<const analysis::Constant*> c1_components;
  67. if (const analysis::VectorConstant* vec_const = c1->AsVectorConstant()) {
  68. c1_components = vec_const->GetComponents();
  69. } else {
  70. assert(c1->AsNullConstant());
  71. const analysis::Constant* element =
  72. const_mgr->GetConstant(element_type, {});
  73. c1_components.resize(c1->type()->AsVector()->element_count(), element);
  74. }
  75. std::vector<const analysis::Constant*> c2_components;
  76. if (const analysis::VectorConstant* vec_const = c2->AsVectorConstant()) {
  77. c2_components = vec_const->GetComponents();
  78. } else {
  79. assert(c2->AsNullConstant());
  80. const analysis::Constant* element =
  81. const_mgr->GetConstant(element_type, {});
  82. c2_components.resize(c2->type()->AsVector()->element_count(), element);
  83. }
  84. std::vector<uint32_t> ids;
  85. const uint32_t undef_literal_value = 0xffffffff;
  86. for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
  87. uint32_t index = inst->GetSingleWordInOperand(i);
  88. if (index == undef_literal_value) {
  89. // Don't fold shuffle with undef literal value.
  90. return nullptr;
  91. } else if (index < c1_components.size()) {
  92. Instruction* member_inst =
  93. const_mgr->GetDefiningInstruction(c1_components[index]);
  94. ids.push_back(member_inst->result_id());
  95. } else {
  96. Instruction* member_inst = const_mgr->GetDefiningInstruction(
  97. c2_components[index - c1_components.size()]);
  98. ids.push_back(member_inst->result_id());
  99. }
  100. }
  101. analysis::TypeManager* type_mgr = context->get_type_mgr();
  102. return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
  103. };
  104. }
  105. ConstantFoldingRule FoldVectorTimesScalar() {
  106. return [](IRContext* context, Instruction* inst,
  107. const std::vector<const analysis::Constant*>& constants)
  108. -> const analysis::Constant* {
  109. assert(inst->opcode() == SpvOpVectorTimesScalar);
  110. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  111. analysis::TypeManager* type_mgr = context->get_type_mgr();
  112. if (!inst->IsFloatingPointFoldingAllowed()) {
  113. if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
  114. return nullptr;
  115. }
  116. }
  117. const analysis::Constant* c1 = constants[0];
  118. const analysis::Constant* c2 = constants[1];
  119. if (c1 && c1->IsZero()) {
  120. return c1;
  121. }
  122. if (c2 && c2->IsZero()) {
  123. // Get or create the NullConstant for this type.
  124. std::vector<uint32_t> ids;
  125. return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
  126. }
  127. if (c1 == nullptr || c2 == nullptr) {
  128. return nullptr;
  129. }
  130. // Check result type.
  131. const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
  132. const analysis::Vector* vector_type = result_type->AsVector();
  133. assert(vector_type != nullptr);
  134. const analysis::Type* element_type = vector_type->element_type();
  135. assert(element_type != nullptr);
  136. const analysis::Float* float_type = element_type->AsFloat();
  137. assert(float_type != nullptr);
  138. // Check types of c1 and c2.
  139. assert(c1->type()->AsVector() == vector_type);
  140. assert(c1->type()->AsVector()->element_type() == element_type &&
  141. c2->type() == element_type);
  142. // Get a float vector that is the result of vector-times-scalar.
  143. std::vector<const analysis::Constant*> c1_components =
  144. c1->GetVectorComponents(const_mgr);
  145. std::vector<uint32_t> ids;
  146. if (float_type->width() == 32) {
  147. float scalar = c2->GetFloat();
  148. for (uint32_t i = 0; i < c1_components.size(); ++i) {
  149. utils::FloatProxy<float> result(c1_components[i]->GetFloat() * scalar);
  150. std::vector<uint32_t> words = result.GetWords();
  151. const analysis::Constant* new_elem =
  152. const_mgr->GetConstant(float_type, words);
  153. ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
  154. }
  155. return const_mgr->GetConstant(vector_type, ids);
  156. } else if (float_type->width() == 64) {
  157. double scalar = c2->GetDouble();
  158. for (uint32_t i = 0; i < c1_components.size(); ++i) {
  159. utils::FloatProxy<double> result(c1_components[i]->GetDouble() *
  160. scalar);
  161. std::vector<uint32_t> words = result.GetWords();
  162. const analysis::Constant* new_elem =
  163. const_mgr->GetConstant(float_type, words);
  164. ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
  165. }
  166. return const_mgr->GetConstant(vector_type, ids);
  167. }
  168. return nullptr;
  169. };
  170. }
  171. ConstantFoldingRule FoldCompositeWithConstants() {
  172. // Folds an OpCompositeConstruct where all of the inputs are constants to a
  173. // constant. A new constant is created if necessary.
  174. return [](IRContext* context, Instruction* inst,
  175. const std::vector<const analysis::Constant*>& constants)
  176. -> const analysis::Constant* {
  177. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  178. analysis::TypeManager* type_mgr = context->get_type_mgr();
  179. const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
  180. Instruction* type_inst =
  181. context->get_def_use_mgr()->GetDef(inst->type_id());
  182. std::vector<uint32_t> ids;
  183. for (uint32_t i = 0; i < constants.size(); ++i) {
  184. const analysis::Constant* element_const = constants[i];
  185. if (element_const == nullptr) {
  186. return nullptr;
  187. }
  188. uint32_t component_type_id = 0;
  189. if (type_inst->opcode() == SpvOpTypeStruct) {
  190. component_type_id = type_inst->GetSingleWordInOperand(i);
  191. } else if (type_inst->opcode() == SpvOpTypeArray) {
  192. component_type_id = type_inst->GetSingleWordInOperand(0);
  193. }
  194. uint32_t element_id =
  195. const_mgr->FindDeclaredConstant(element_const, component_type_id);
  196. if (element_id == 0) {
  197. return nullptr;
  198. }
  199. ids.push_back(element_id);
  200. }
  201. return const_mgr->GetConstant(new_type, ids);
  202. };
  203. }
// The interface for a function that returns the result of applying a scalar
// unary operation to the constant |a|. The type of the return value will be
// |result_type|; the type of |a| may differ from it (the conversion rules in
// this file take a float and produce an integer, or the reverse). A rule
// returns |nullptr| when it cannot fold its input.
using UnaryScalarFoldingRule = std::function<const analysis::Constant*(
    const analysis::Type* result_type, const analysis::Constant* a,
    analysis::ConstantManager*)>;
// The interface for a function that returns the result of applying a scalar
// floating-point binary operation on |a| and |b|. The type of the return value
// will be |result_type|. |a| and |b| must have the same type as each other;
// the arithmetic rules in this file also require that type to be
// |result_type|, while the comparison rules require |result_type| to be Bool.
// A rule returns |nullptr| when it cannot fold its inputs.
using BinaryScalarFoldingRule = std::function<const analysis::Constant*(
    const analysis::Type* result_type, const analysis::Constant* a,
    const analysis::Constant* b, analysis::ConstantManager*)>;
  216. // Returns a |ConstantFoldingRule| that folds unary floating point scalar ops
  217. // using |scalar_rule| and unary float point vectors ops by applying
  218. // |scalar_rule| to the elements of the vector. The |ConstantFoldingRule|
  219. // that is returned assumes that |constants| contains 1 entry. If they are
  220. // not |nullptr|, then their type is either |Float| or |Integer| or a |Vector|
  221. // whose element type is |Float| or |Integer|.
  222. ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) {
  223. return [scalar_rule](IRContext* context, Instruction* inst,
  224. const std::vector<const analysis::Constant*>& constants)
  225. -> const analysis::Constant* {
  226. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  227. analysis::TypeManager* type_mgr = context->get_type_mgr();
  228. const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
  229. const analysis::Vector* vector_type = result_type->AsVector();
  230. if (!inst->IsFloatingPointFoldingAllowed()) {
  231. return nullptr;
  232. }
  233. if (constants[0] == nullptr) {
  234. return nullptr;
  235. }
  236. if (vector_type != nullptr) {
  237. std::vector<const analysis::Constant*> a_components;
  238. std::vector<const analysis::Constant*> results_components;
  239. a_components = constants[0]->GetVectorComponents(const_mgr);
  240. // Fold each component of the vector.
  241. for (uint32_t i = 0; i < a_components.size(); ++i) {
  242. results_components.push_back(scalar_rule(vector_type->element_type(),
  243. a_components[i], const_mgr));
  244. if (results_components[i] == nullptr) {
  245. return nullptr;
  246. }
  247. }
  248. // Build the constant object and return it.
  249. std::vector<uint32_t> ids;
  250. for (const analysis::Constant* member : results_components) {
  251. ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
  252. }
  253. return const_mgr->GetConstant(vector_type, ids);
  254. } else {
  255. return scalar_rule(result_type, constants[0], const_mgr);
  256. }
  257. };
  258. }
  259. // Returns a |ConstantFoldingRule| that folds floating point scalars using
  260. // |scalar_rule| and vectors of floating point by applying |scalar_rule| to the
  261. // elements of the vector. The |ConstantFoldingRule| that is returned assumes
  262. // that |constants| contains 2 entries. If they are not |nullptr|, then their
  263. // type is either |Float| or a |Vector| whose element type is |Float|.
  264. ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) {
  265. return [scalar_rule](IRContext* context, Instruction* inst,
  266. const std::vector<const analysis::Constant*>& constants)
  267. -> const analysis::Constant* {
  268. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  269. analysis::TypeManager* type_mgr = context->get_type_mgr();
  270. const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
  271. const analysis::Vector* vector_type = result_type->AsVector();
  272. if (!inst->IsFloatingPointFoldingAllowed()) {
  273. return nullptr;
  274. }
  275. if (constants[0] == nullptr || constants[1] == nullptr) {
  276. return nullptr;
  277. }
  278. if (vector_type != nullptr) {
  279. std::vector<const analysis::Constant*> a_components;
  280. std::vector<const analysis::Constant*> b_components;
  281. std::vector<const analysis::Constant*> results_components;
  282. a_components = constants[0]->GetVectorComponents(const_mgr);
  283. b_components = constants[1]->GetVectorComponents(const_mgr);
  284. // Fold each component of the vector.
  285. for (uint32_t i = 0; i < a_components.size(); ++i) {
  286. results_components.push_back(scalar_rule(vector_type->element_type(),
  287. a_components[i],
  288. b_components[i], const_mgr));
  289. if (results_components[i] == nullptr) {
  290. return nullptr;
  291. }
  292. }
  293. // Build the constant object and return it.
  294. std::vector<uint32_t> ids;
  295. for (const analysis::Constant* member : results_components) {
  296. ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
  297. }
  298. return const_mgr->GetConstant(vector_type, ids);
  299. } else {
  300. return scalar_rule(result_type, constants[0], constants[1], const_mgr);
  301. }
  302. };
  303. }
  304. // This macro defines a |UnaryScalarFoldingRule| that performs float to
  305. // integer conversion.
  306. // TODO(greg-lunarg): Support for 64-bit integer types.
  307. UnaryScalarFoldingRule FoldFToIOp() {
  308. return [](const analysis::Type* result_type, const analysis::Constant* a,
  309. analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
  310. assert(result_type != nullptr && a != nullptr);
  311. const analysis::Integer* integer_type = result_type->AsInteger();
  312. const analysis::Float* float_type = a->type()->AsFloat();
  313. assert(float_type != nullptr);
  314. assert(integer_type != nullptr);
  315. if (integer_type->width() != 32) return nullptr;
  316. if (float_type->width() == 32) {
  317. float fa = a->GetFloat();
  318. uint32_t result = integer_type->IsSigned()
  319. ? static_cast<uint32_t>(static_cast<int32_t>(fa))
  320. : static_cast<uint32_t>(fa);
  321. std::vector<uint32_t> words = {result};
  322. return const_mgr->GetConstant(result_type, words);
  323. } else if (float_type->width() == 64) {
  324. double fa = a->GetDouble();
  325. uint32_t result = integer_type->IsSigned()
  326. ? static_cast<uint32_t>(static_cast<int32_t>(fa))
  327. : static_cast<uint32_t>(fa);
  328. std::vector<uint32_t> words = {result};
  329. return const_mgr->GetConstant(result_type, words);
  330. }
  331. return nullptr;
  332. };
  333. }
  334. // This function defines a |UnaryScalarFoldingRule| that performs integer to
  335. // float conversion.
  336. // TODO(greg-lunarg): Support for 64-bit integer types.
  337. UnaryScalarFoldingRule FoldIToFOp() {
  338. return [](const analysis::Type* result_type, const analysis::Constant* a,
  339. analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
  340. assert(result_type != nullptr && a != nullptr);
  341. const analysis::Integer* integer_type = a->type()->AsInteger();
  342. const analysis::Float* float_type = result_type->AsFloat();
  343. assert(float_type != nullptr);
  344. assert(integer_type != nullptr);
  345. if (integer_type->width() != 32) return nullptr;
  346. uint32_t ua = a->GetU32();
  347. if (float_type->width() == 32) {
  348. float result_val = integer_type->IsSigned()
  349. ? static_cast<float>(static_cast<int32_t>(ua))
  350. : static_cast<float>(ua);
  351. utils::FloatProxy<float> result(result_val);
  352. std::vector<uint32_t> words = {result.data()};
  353. return const_mgr->GetConstant(result_type, words);
  354. } else if (float_type->width() == 64) {
  355. double result_val = integer_type->IsSigned()
  356. ? static_cast<double>(static_cast<int32_t>(ua))
  357. : static_cast<double>(ua);
  358. utils::FloatProxy<double> result(result_val);
  359. std::vector<uint32_t> words = result.GetWords();
  360. return const_mgr->GetConstant(result_type, words);
  361. }
  362. return nullptr;
  363. };
  364. }
  365. // This macro defines a |BinaryScalarFoldingRule| that applies |op|. The
  366. // operator |op| must work for both float and double, and use syntax "f1 op f2".
  367. #define FOLD_FPARITH_OP(op) \
  368. [](const analysis::Type* result_type, const analysis::Constant* a, \
  369. const analysis::Constant* b, \
  370. analysis::ConstantManager* const_mgr_in_macro) \
  371. -> const analysis::Constant* { \
  372. assert(result_type != nullptr && a != nullptr && b != nullptr); \
  373. assert(result_type == a->type() && result_type == b->type()); \
  374. const analysis::Float* float_type_in_macro = result_type->AsFloat(); \
  375. assert(float_type_in_macro != nullptr); \
  376. if (float_type_in_macro->width() == 32) { \
  377. float fa = a->GetFloat(); \
  378. float fb = b->GetFloat(); \
  379. utils::FloatProxy<float> result_in_macro(fa op fb); \
  380. std::vector<uint32_t> words_in_macro = result_in_macro.GetWords(); \
  381. return const_mgr_in_macro->GetConstant(result_type, words_in_macro); \
  382. } else if (float_type_in_macro->width() == 64) { \
  383. double fa = a->GetDouble(); \
  384. double fb = b->GetDouble(); \
  385. utils::FloatProxy<double> result_in_macro(fa op fb); \
  386. std::vector<uint32_t> words_in_macro = result_in_macro.GetWords(); \
  387. return const_mgr_in_macro->GetConstant(result_type, words_in_macro); \
  388. } \
  389. return nullptr; \
  390. }
  391. // Define the folding rule for conversion between floating point and integer
  392. ConstantFoldingRule FoldFToI() { return FoldFPUnaryOp(FoldFToIOp()); }
  393. ConstantFoldingRule FoldIToF() { return FoldFPUnaryOp(FoldIToFOp()); }
  394. // Define the folding rules for subtraction, addition, multiplication, and
  395. // division for floating point values.
  396. ConstantFoldingRule FoldFSub() { return FoldFPBinaryOp(FOLD_FPARITH_OP(-)); }
  397. ConstantFoldingRule FoldFAdd() { return FoldFPBinaryOp(FOLD_FPARITH_OP(+)); }
  398. ConstantFoldingRule FoldFMul() { return FoldFPBinaryOp(FOLD_FPARITH_OP(*)); }
  399. ConstantFoldingRule FoldFDiv() { return FoldFPBinaryOp(FOLD_FPARITH_OP(/)); }
  400. bool CompareFloatingPoint(bool op_result, bool op_unordered,
  401. bool need_ordered) {
  402. if (need_ordered) {
  403. // operands are ordered and Operand 1 is |op| Operand 2
  404. return !op_unordered && op_result;
  405. } else {
  406. // operands are unordered or Operand 1 is |op| Operand 2
  407. return op_unordered || op_result;
  408. }
  409. }
  410. // This macro defines a |BinaryScalarFoldingRule| that applies |op|. The
  411. // operator |op| must work for both float and double, and use syntax "f1 op f2".
  412. #define FOLD_FPCMP_OP(op, ord) \
  413. [](const analysis::Type* result_type, const analysis::Constant* a, \
  414. const analysis::Constant* b, \
  415. analysis::ConstantManager* const_mgr) -> const analysis::Constant* { \
  416. assert(result_type != nullptr && a != nullptr && b != nullptr); \
  417. assert(result_type->AsBool()); \
  418. assert(a->type() == b->type()); \
  419. const analysis::Float* float_type = a->type()->AsFloat(); \
  420. assert(float_type != nullptr); \
  421. if (float_type->width() == 32) { \
  422. float fa = a->GetFloat(); \
  423. float fb = b->GetFloat(); \
  424. bool result = CompareFloatingPoint( \
  425. fa op fb, std::isnan(fa) || std::isnan(fb), ord); \
  426. std::vector<uint32_t> words = {uint32_t(result)}; \
  427. return const_mgr->GetConstant(result_type, words); \
  428. } else if (float_type->width() == 64) { \
  429. double fa = a->GetDouble(); \
  430. double fb = b->GetDouble(); \
  431. bool result = CompareFloatingPoint( \
  432. fa op fb, std::isnan(fa) || std::isnan(fb), ord); \
  433. std::vector<uint32_t> words = {uint32_t(result)}; \
  434. return const_mgr->GetConstant(result_type, words); \
  435. } \
  436. return nullptr; \
  437. }
  438. // Define the folding rules for ordered and unordered comparison for floating
  439. // point values.
  440. ConstantFoldingRule FoldFOrdEqual() {
  441. return FoldFPBinaryOp(FOLD_FPCMP_OP(==, true));
  442. }
  443. ConstantFoldingRule FoldFUnordEqual() {
  444. return FoldFPBinaryOp(FOLD_FPCMP_OP(==, false));
  445. }
  446. ConstantFoldingRule FoldFOrdNotEqual() {
  447. return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, true));
  448. }
  449. ConstantFoldingRule FoldFUnordNotEqual() {
  450. return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, false));
  451. }
  452. ConstantFoldingRule FoldFOrdLessThan() {
  453. return FoldFPBinaryOp(FOLD_FPCMP_OP(<, true));
  454. }
  455. ConstantFoldingRule FoldFUnordLessThan() {
  456. return FoldFPBinaryOp(FOLD_FPCMP_OP(<, false));
  457. }
  458. ConstantFoldingRule FoldFOrdGreaterThan() {
  459. return FoldFPBinaryOp(FOLD_FPCMP_OP(>, true));
  460. }
  461. ConstantFoldingRule FoldFUnordGreaterThan() {
  462. return FoldFPBinaryOp(FOLD_FPCMP_OP(>, false));
  463. }
  464. ConstantFoldingRule FoldFOrdLessThanEqual() {
  465. return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, true));
  466. }
  467. ConstantFoldingRule FoldFUnordLessThanEqual() {
  468. return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, false));
  469. }
  470. ConstantFoldingRule FoldFOrdGreaterThanEqual() {
  471. return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, true));
  472. }
  473. ConstantFoldingRule FoldFUnordGreaterThanEqual() {
  474. return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, false));
  475. }
  476. // Folds an OpDot where all of the inputs are constants to a
  477. // constant. A new constant is created if necessary.
  478. ConstantFoldingRule FoldOpDotWithConstants() {
  479. return [](IRContext* context, Instruction* inst,
  480. const std::vector<const analysis::Constant*>& constants)
  481. -> const analysis::Constant* {
  482. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  483. analysis::TypeManager* type_mgr = context->get_type_mgr();
  484. const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
  485. assert(new_type->AsFloat() && "OpDot should have a float return type.");
  486. const analysis::Float* float_type = new_type->AsFloat();
  487. if (!inst->IsFloatingPointFoldingAllowed()) {
  488. return nullptr;
  489. }
  490. // If one of the operands is 0, then the result is 0.
  491. bool has_zero_operand = false;
  492. for (int i = 0; i < 2; ++i) {
  493. if (constants[i]) {
  494. if (constants[i]->AsNullConstant() ||
  495. constants[i]->AsVectorConstant()->IsZero()) {
  496. has_zero_operand = true;
  497. break;
  498. }
  499. }
  500. }
  501. if (has_zero_operand) {
  502. if (float_type->width() == 32) {
  503. utils::FloatProxy<float> result(0.0f);
  504. std::vector<uint32_t> words = result.GetWords();
  505. return const_mgr->GetConstant(float_type, words);
  506. }
  507. if (float_type->width() == 64) {
  508. utils::FloatProxy<double> result(0.0);
  509. std::vector<uint32_t> words = result.GetWords();
  510. return const_mgr->GetConstant(float_type, words);
  511. }
  512. return nullptr;
  513. }
  514. if (constants[0] == nullptr || constants[1] == nullptr) {
  515. return nullptr;
  516. }
  517. std::vector<const analysis::Constant*> a_components;
  518. std::vector<const analysis::Constant*> b_components;
  519. a_components = constants[0]->GetVectorComponents(const_mgr);
  520. b_components = constants[1]->GetVectorComponents(const_mgr);
  521. utils::FloatProxy<double> result(0.0);
  522. std::vector<uint32_t> words = result.GetWords();
  523. const analysis::Constant* result_const =
  524. const_mgr->GetConstant(float_type, words);
  525. for (uint32_t i = 0; i < a_components.size() && result_const != nullptr;
  526. ++i) {
  527. if (a_components[i] == nullptr || b_components[i] == nullptr) {
  528. return nullptr;
  529. }
  530. const analysis::Constant* component = FOLD_FPARITH_OP(*)(
  531. new_type, a_components[i], b_components[i], const_mgr);
  532. if (component == nullptr) {
  533. return nullptr;
  534. }
  535. result_const =
  536. FOLD_FPARITH_OP(+)(new_type, result_const, component, const_mgr);
  537. }
  538. return result_const;
  539. };
  540. }
  541. // This function defines a |UnaryScalarFoldingRule| that subtracts the constant
  542. // from zero.
  543. UnaryScalarFoldingRule FoldFNegateOp() {
  544. return [](const analysis::Type* result_type, const analysis::Constant* a,
  545. analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
  546. assert(result_type != nullptr && a != nullptr);
  547. assert(result_type == a->type());
  548. const analysis::Float* float_type = result_type->AsFloat();
  549. assert(float_type != nullptr);
  550. if (float_type->width() == 32) {
  551. float fa = a->GetFloat();
  552. utils::FloatProxy<float> result(-fa);
  553. std::vector<uint32_t> words = result.GetWords();
  554. return const_mgr->GetConstant(result_type, words);
  555. } else if (float_type->width() == 64) {
  556. double da = a->GetDouble();
  557. utils::FloatProxy<double> result(-da);
  558. std::vector<uint32_t> words = result.GetWords();
  559. return const_mgr->GetConstant(result_type, words);
  560. }
  561. return nullptr;
  562. };
  563. }
  564. ConstantFoldingRule FoldFNegate() { return FoldFPUnaryOp(FoldFNegateOp()); }
  565. ConstantFoldingRule FoldFClampFeedingCompare(uint32_t cmp_opcode) {
  566. return [cmp_opcode](IRContext* context, Instruction* inst,
  567. const std::vector<const analysis::Constant*>& constants)
  568. -> const analysis::Constant* {
  569. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  570. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  571. if (!inst->IsFloatingPointFoldingAllowed()) {
  572. return nullptr;
  573. }
  574. uint32_t non_const_idx = (constants[0] ? 1 : 0);
  575. uint32_t operand_id = inst->GetSingleWordInOperand(non_const_idx);
  576. Instruction* operand_inst = def_use_mgr->GetDef(operand_id);
  577. analysis::TypeManager* type_mgr = context->get_type_mgr();
  578. const analysis::Type* operand_type =
  579. type_mgr->GetType(operand_inst->type_id());
  580. if (!operand_type->AsFloat()) {
  581. return nullptr;
  582. }
  583. if (operand_type->AsFloat()->width() != 32 &&
  584. operand_type->AsFloat()->width() != 64) {
  585. return nullptr;
  586. }
  587. if (operand_inst->opcode() != SpvOpExtInst) {
  588. return nullptr;
  589. }
  590. if (operand_inst->GetSingleWordInOperand(1) != GLSLstd450FClamp) {
  591. return nullptr;
  592. }
  593. if (constants[1] == nullptr && constants[0] == nullptr) {
  594. return nullptr;
  595. }
  596. uint32_t max_id = operand_inst->GetSingleWordInOperand(4);
  597. const analysis::Constant* max_const =
  598. const_mgr->FindDeclaredConstant(max_id);
  599. uint32_t min_id = operand_inst->GetSingleWordInOperand(3);
  600. const analysis::Constant* min_const =
  601. const_mgr->FindDeclaredConstant(min_id);
  602. bool found_result = false;
  603. bool result = false;
  604. switch (cmp_opcode) {
  605. case SpvOpFOrdLessThan:
  606. case SpvOpFUnordLessThan:
  607. case SpvOpFOrdGreaterThanEqual:
  608. case SpvOpFUnordGreaterThanEqual:
  609. if (constants[0]) {
  610. if (min_const) {
  611. if (constants[0]->GetValueAsDouble() <
  612. min_const->GetValueAsDouble()) {
  613. found_result = true;
  614. result = (cmp_opcode == SpvOpFOrdLessThan ||
  615. cmp_opcode == SpvOpFUnordLessThan);
  616. }
  617. }
  618. if (max_const) {
  619. if (constants[0]->GetValueAsDouble() >=
  620. max_const->GetValueAsDouble()) {
  621. found_result = true;
  622. result = !(cmp_opcode == SpvOpFOrdLessThan ||
  623. cmp_opcode == SpvOpFUnordLessThan);
  624. }
  625. }
  626. }
  627. if (constants[1]) {
  628. if (max_const) {
  629. if (max_const->GetValueAsDouble() <
  630. constants[1]->GetValueAsDouble()) {
  631. found_result = true;
  632. result = (cmp_opcode == SpvOpFOrdLessThan ||
  633. cmp_opcode == SpvOpFUnordLessThan);
  634. }
  635. }
  636. if (min_const) {
  637. if (min_const->GetValueAsDouble() >=
  638. constants[1]->GetValueAsDouble()) {
  639. found_result = true;
  640. result = !(cmp_opcode == SpvOpFOrdLessThan ||
  641. cmp_opcode == SpvOpFUnordLessThan);
  642. }
  643. }
  644. }
  645. break;
  646. case SpvOpFOrdGreaterThan:
  647. case SpvOpFUnordGreaterThan:
  648. case SpvOpFOrdLessThanEqual:
  649. case SpvOpFUnordLessThanEqual:
  650. if (constants[0]) {
  651. if (min_const) {
  652. if (constants[0]->GetValueAsDouble() <=
  653. min_const->GetValueAsDouble()) {
  654. found_result = true;
  655. result = (cmp_opcode == SpvOpFOrdLessThanEqual ||
  656. cmp_opcode == SpvOpFUnordLessThanEqual);
  657. }
  658. }
  659. if (max_const) {
  660. if (constants[0]->GetValueAsDouble() >
  661. max_const->GetValueAsDouble()) {
  662. found_result = true;
  663. result = !(cmp_opcode == SpvOpFOrdLessThanEqual ||
  664. cmp_opcode == SpvOpFUnordLessThanEqual);
  665. }
  666. }
  667. }
  668. if (constants[1]) {
  669. if (max_const) {
  670. if (max_const->GetValueAsDouble() <=
  671. constants[1]->GetValueAsDouble()) {
  672. found_result = true;
  673. result = (cmp_opcode == SpvOpFOrdLessThanEqual ||
  674. cmp_opcode == SpvOpFUnordLessThanEqual);
  675. }
  676. }
  677. if (min_const) {
  678. if (min_const->GetValueAsDouble() >
  679. constants[1]->GetValueAsDouble()) {
  680. found_result = true;
  681. result = !(cmp_opcode == SpvOpFOrdLessThanEqual ||
  682. cmp_opcode == SpvOpFUnordLessThanEqual);
  683. }
  684. }
  685. }
  686. break;
  687. default:
  688. return nullptr;
  689. }
  690. if (!found_result) {
  691. return nullptr;
  692. }
  693. const analysis::Type* bool_type =
  694. context->get_type_mgr()->GetType(inst->type_id());
  695. const analysis::Constant* result_const =
  696. const_mgr->GetConstant(bool_type, {static_cast<uint32_t>(result)});
  697. assert(result_const);
  698. return result_const;
  699. };
  700. }
  701. } // namespace
  702. ConstantFoldingRules::ConstantFoldingRules() {
  703. // Add all folding rules to the list for the opcodes to which they apply.
  704. // Note that the order in which rules are added to the list matters. If a rule
  705. // applies to the instruction, the rest of the rules will not be attempted.
  706. // Take that into consideration.
  707. rules_[SpvOpCompositeConstruct].push_back(FoldCompositeWithConstants());
  708. rules_[SpvOpCompositeExtract].push_back(FoldExtractWithConstants());
  709. rules_[SpvOpConvertFToS].push_back(FoldFToI());
  710. rules_[SpvOpConvertFToU].push_back(FoldFToI());
  711. rules_[SpvOpConvertSToF].push_back(FoldIToF());
  712. rules_[SpvOpConvertUToF].push_back(FoldIToF());
  713. rules_[SpvOpDot].push_back(FoldOpDotWithConstants());
  714. rules_[SpvOpFAdd].push_back(FoldFAdd());
  715. rules_[SpvOpFDiv].push_back(FoldFDiv());
  716. rules_[SpvOpFMul].push_back(FoldFMul());
  717. rules_[SpvOpFSub].push_back(FoldFSub());
  718. rules_[SpvOpFOrdEqual].push_back(FoldFOrdEqual());
  719. rules_[SpvOpFUnordEqual].push_back(FoldFUnordEqual());
  720. rules_[SpvOpFOrdNotEqual].push_back(FoldFOrdNotEqual());
  721. rules_[SpvOpFUnordNotEqual].push_back(FoldFUnordNotEqual());
  722. rules_[SpvOpFOrdLessThan].push_back(FoldFOrdLessThan());
  723. rules_[SpvOpFOrdLessThan].push_back(
  724. FoldFClampFeedingCompare(SpvOpFOrdLessThan));
  725. rules_[SpvOpFUnordLessThan].push_back(FoldFUnordLessThan());
  726. rules_[SpvOpFUnordLessThan].push_back(
  727. FoldFClampFeedingCompare(SpvOpFUnordLessThan));
  728. rules_[SpvOpFOrdGreaterThan].push_back(FoldFOrdGreaterThan());
  729. rules_[SpvOpFOrdGreaterThan].push_back(
  730. FoldFClampFeedingCompare(SpvOpFOrdGreaterThan));
  731. rules_[SpvOpFUnordGreaterThan].push_back(FoldFUnordGreaterThan());
  732. rules_[SpvOpFUnordGreaterThan].push_back(
  733. FoldFClampFeedingCompare(SpvOpFUnordGreaterThan));
  734. rules_[SpvOpFOrdLessThanEqual].push_back(FoldFOrdLessThanEqual());
  735. rules_[SpvOpFOrdLessThanEqual].push_back(
  736. FoldFClampFeedingCompare(SpvOpFOrdLessThanEqual));
  737. rules_[SpvOpFUnordLessThanEqual].push_back(FoldFUnordLessThanEqual());
  738. rules_[SpvOpFUnordLessThanEqual].push_back(
  739. FoldFClampFeedingCompare(SpvOpFUnordLessThanEqual));
  740. rules_[SpvOpFOrdGreaterThanEqual].push_back(FoldFOrdGreaterThanEqual());
  741. rules_[SpvOpFOrdGreaterThanEqual].push_back(
  742. FoldFClampFeedingCompare(SpvOpFOrdGreaterThanEqual));
  743. rules_[SpvOpFUnordGreaterThanEqual].push_back(FoldFUnordGreaterThanEqual());
  744. rules_[SpvOpFUnordGreaterThanEqual].push_back(
  745. FoldFClampFeedingCompare(SpvOpFUnordGreaterThanEqual));
  746. rules_[SpvOpVectorShuffle].push_back(FoldVectorShuffleWithConstants());
  747. rules_[SpvOpVectorTimesScalar].push_back(FoldVectorTimesScalar());
  748. rules_[SpvOpFNegate].push_back(FoldFNegate());
  749. }
  750. } // namespace opt
  751. } // namespace spvtools