const_folding_rules.cpp 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882
  1. // Copyright (c) 2018 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opt/const_folding_rules.h"
  15. #include "source/opt/ir_context.h"
  16. namespace spvtools {
  17. namespace opt {
  18. namespace {
  // Index of the composite operand in the |constants| vector handed to the
  // extract folding rule.
  constexpr uint32_t kExtractCompositeIdInIdx = 0;
  20. // Returns true if |type| is Float or a vector of Float.
  21. bool HasFloatingPoint(const analysis::Type* type) {
  22. if (type->AsFloat()) {
  23. return true;
  24. } else if (const analysis::Vector* vec_type = type->AsVector()) {
  25. return vec_type->element_type()->AsFloat() != nullptr;
  26. }
  27. return false;
  28. }
  29. // Folds an OpcompositeExtract where input is a composite constant.
  30. ConstantFoldingRule FoldExtractWithConstants() {
  31. return [](IRContext* context, Instruction* inst,
  32. const std::vector<const analysis::Constant*>& constants)
  33. -> const analysis::Constant* {
  34. const analysis::Constant* c = constants[kExtractCompositeIdInIdx];
  35. if (c == nullptr) {
  36. return nullptr;
  37. }
  38. for (uint32_t i = 1; i < inst->NumInOperands(); ++i) {
  39. uint32_t element_index = inst->GetSingleWordInOperand(i);
  40. if (c->AsNullConstant()) {
  41. // Return Null for the return type.
  42. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  43. analysis::TypeManager* type_mgr = context->get_type_mgr();
  44. return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), {});
  45. }
  46. auto cc = c->AsCompositeConstant();
  47. assert(cc != nullptr);
  48. auto components = cc->GetComponents();
  49. // Protect against invalid IR. Refuse to fold if the index is out
  50. // of bounds.
  51. if (element_index >= components.size()) return nullptr;
  52. c = components[element_index];
  53. }
  54. return c;
  55. };
  56. }
  57. ConstantFoldingRule FoldVectorShuffleWithConstants() {
  58. return [](IRContext* context, Instruction* inst,
  59. const std::vector<const analysis::Constant*>& constants)
  60. -> const analysis::Constant* {
  61. assert(inst->opcode() == SpvOpVectorShuffle);
  62. const analysis::Constant* c1 = constants[0];
  63. const analysis::Constant* c2 = constants[1];
  64. if (c1 == nullptr || c2 == nullptr) {
  65. return nullptr;
  66. }
  67. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  68. const analysis::Type* element_type = c1->type()->AsVector()->element_type();
  69. std::vector<const analysis::Constant*> c1_components;
  70. if (const analysis::VectorConstant* vec_const = c1->AsVectorConstant()) {
  71. c1_components = vec_const->GetComponents();
  72. } else {
  73. assert(c1->AsNullConstant());
  74. const analysis::Constant* element =
  75. const_mgr->GetConstant(element_type, {});
  76. c1_components.resize(c1->type()->AsVector()->element_count(), element);
  77. }
  78. std::vector<const analysis::Constant*> c2_components;
  79. if (const analysis::VectorConstant* vec_const = c2->AsVectorConstant()) {
  80. c2_components = vec_const->GetComponents();
  81. } else {
  82. assert(c2->AsNullConstant());
  83. const analysis::Constant* element =
  84. const_mgr->GetConstant(element_type, {});
  85. c2_components.resize(c2->type()->AsVector()->element_count(), element);
  86. }
  87. std::vector<uint32_t> ids;
  88. const uint32_t undef_literal_value = 0xffffffff;
  89. for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
  90. uint32_t index = inst->GetSingleWordInOperand(i);
  91. if (index == undef_literal_value) {
  92. // Don't fold shuffle with undef literal value.
  93. return nullptr;
  94. } else if (index < c1_components.size()) {
  95. Instruction* member_inst =
  96. const_mgr->GetDefiningInstruction(c1_components[index]);
  97. ids.push_back(member_inst->result_id());
  98. } else {
  99. Instruction* member_inst = const_mgr->GetDefiningInstruction(
  100. c2_components[index - c1_components.size()]);
  101. ids.push_back(member_inst->result_id());
  102. }
  103. }
  104. analysis::TypeManager* type_mgr = context->get_type_mgr();
  105. return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
  106. };
  107. }
  108. ConstantFoldingRule FoldVectorTimesScalar() {
  109. return [](IRContext* context, Instruction* inst,
  110. const std::vector<const analysis::Constant*>& constants)
  111. -> const analysis::Constant* {
  112. assert(inst->opcode() == SpvOpVectorTimesScalar);
  113. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  114. analysis::TypeManager* type_mgr = context->get_type_mgr();
  115. if (!inst->IsFloatingPointFoldingAllowed()) {
  116. if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
  117. return nullptr;
  118. }
  119. }
  120. const analysis::Constant* c1 = constants[0];
  121. const analysis::Constant* c2 = constants[1];
  122. if (c1 && c1->IsZero()) {
  123. return c1;
  124. }
  125. if (c2 && c2->IsZero()) {
  126. // Get or create the NullConstant for this type.
  127. std::vector<uint32_t> ids;
  128. return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
  129. }
  130. if (c1 == nullptr || c2 == nullptr) {
  131. return nullptr;
  132. }
  133. // Check result type.
  134. const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
  135. const analysis::Vector* vector_type = result_type->AsVector();
  136. assert(vector_type != nullptr);
  137. const analysis::Type* element_type = vector_type->element_type();
  138. assert(element_type != nullptr);
  139. const analysis::Float* float_type = element_type->AsFloat();
  140. assert(float_type != nullptr);
  141. // Check types of c1 and c2.
  142. assert(c1->type()->AsVector() == vector_type);
  143. assert(c1->type()->AsVector()->element_type() == element_type &&
  144. c2->type() == element_type);
  145. // Get a float vector that is the result of vector-times-scalar.
  146. std::vector<const analysis::Constant*> c1_components =
  147. c1->GetVectorComponents(const_mgr);
  148. std::vector<uint32_t> ids;
  149. if (float_type->width() == 32) {
  150. float scalar = c2->GetFloat();
  151. for (uint32_t i = 0; i < c1_components.size(); ++i) {
  152. utils::FloatProxy<float> result(c1_components[i]->GetFloat() * scalar);
  153. std::vector<uint32_t> words = result.GetWords();
  154. const analysis::Constant* new_elem =
  155. const_mgr->GetConstant(float_type, words);
  156. ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
  157. }
  158. return const_mgr->GetConstant(vector_type, ids);
  159. } else if (float_type->width() == 64) {
  160. double scalar = c2->GetDouble();
  161. for (uint32_t i = 0; i < c1_components.size(); ++i) {
  162. utils::FloatProxy<double> result(c1_components[i]->GetDouble() *
  163. scalar);
  164. std::vector<uint32_t> words = result.GetWords();
  165. const analysis::Constant* new_elem =
  166. const_mgr->GetConstant(float_type, words);
  167. ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
  168. }
  169. return const_mgr->GetConstant(vector_type, ids);
  170. }
  171. return nullptr;
  172. };
  173. }
  174. ConstantFoldingRule FoldCompositeWithConstants() {
  175. // Folds an OpCompositeConstruct where all of the inputs are constants to a
  176. // constant. A new constant is created if necessary.
  177. return [](IRContext* context, Instruction* inst,
  178. const std::vector<const analysis::Constant*>& constants)
  179. -> const analysis::Constant* {
  180. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  181. analysis::TypeManager* type_mgr = context->get_type_mgr();
  182. const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
  183. Instruction* type_inst =
  184. context->get_def_use_mgr()->GetDef(inst->type_id());
  185. std::vector<uint32_t> ids;
  186. for (uint32_t i = 0; i < constants.size(); ++i) {
  187. const analysis::Constant* element_const = constants[i];
  188. if (element_const == nullptr) {
  189. return nullptr;
  190. }
  191. uint32_t component_type_id = 0;
  192. if (type_inst->opcode() == SpvOpTypeStruct) {
  193. component_type_id = type_inst->GetSingleWordInOperand(i);
  194. } else if (type_inst->opcode() == SpvOpTypeArray) {
  195. component_type_id = type_inst->GetSingleWordInOperand(0);
  196. }
  197. uint32_t element_id =
  198. const_mgr->FindDeclaredConstant(element_const, component_type_id);
  199. if (element_id == 0) {
  200. return nullptr;
  201. }
  202. ids.push_back(element_id);
  203. }
  204. return const_mgr->GetConstant(new_type, ids);
  205. };
  206. }
  207. // The interface for a function that returns the result of applying a scalar
  208. // unary operation to |a|. The return value has type |result_type|, which
  209. // may differ from the type of |a| (e.g. float <-> integer conversions).
  210. using UnaryScalarFoldingRule = std::function<const analysis::Constant*(
  211. const analysis::Type* result_type, const analysis::Constant* a,
  212. analysis::ConstantManager*)>;
  213. // The interface for a function that returns the result of applying a scalar
  214. // floating-point binary operation on |a| and |b|. The return value has type
  215. // |result_type|; |a| and |b| share a type that may differ from it (comparisons).
  216. using BinaryScalarFoldingRule = std::function<const analysis::Constant*(
  217. const analysis::Type* result_type, const analysis::Constant* a,
  218. const analysis::Constant* b, analysis::ConstantManager*)>;
  219. // Returns a |ConstantFoldingRule| that folds unary floating point scalar ops
  220. // using |scalar_rule| and unary float point vectors ops by applying
  221. // |scalar_rule| to the elements of the vector. The |ConstantFoldingRule|
  222. // that is returned assumes that |constants| contains 1 entry. If they are
  223. // not |nullptr|, then their type is either |Float| or |Integer| or a |Vector|
  224. // whose element type is |Float| or |Integer|.
  225. ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) {
  226. return [scalar_rule](IRContext* context, Instruction* inst,
  227. const std::vector<const analysis::Constant*>& constants)
  228. -> const analysis::Constant* {
  229. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  230. analysis::TypeManager* type_mgr = context->get_type_mgr();
  231. const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
  232. const analysis::Vector* vector_type = result_type->AsVector();
  233. if (!inst->IsFloatingPointFoldingAllowed()) {
  234. return nullptr;
  235. }
  236. if (constants[0] == nullptr) {
  237. return nullptr;
  238. }
  239. if (vector_type != nullptr) {
  240. std::vector<const analysis::Constant*> a_components;
  241. std::vector<const analysis::Constant*> results_components;
  242. a_components = constants[0]->GetVectorComponents(const_mgr);
  243. // Fold each component of the vector.
  244. for (uint32_t i = 0; i < a_components.size(); ++i) {
  245. results_components.push_back(scalar_rule(vector_type->element_type(),
  246. a_components[i], const_mgr));
  247. if (results_components[i] == nullptr) {
  248. return nullptr;
  249. }
  250. }
  251. // Build the constant object and return it.
  252. std::vector<uint32_t> ids;
  253. for (const analysis::Constant* member : results_components) {
  254. ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
  255. }
  256. return const_mgr->GetConstant(vector_type, ids);
  257. } else {
  258. return scalar_rule(result_type, constants[0], const_mgr);
  259. }
  260. };
  261. }
  262. // Returns a |ConstantFoldingRule| that folds floating point scalars using
  263. // |scalar_rule| and vectors of floating point by applying |scalar_rule| to the
  264. // elements of the vector. The |ConstantFoldingRule| that is returned assumes
  265. // that |constants| contains 2 entries. If they are not |nullptr|, then their
  266. // type is either |Float| or a |Vector| whose element type is |Float|.
  267. ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) {
  268. return [scalar_rule](IRContext* context, Instruction* inst,
  269. const std::vector<const analysis::Constant*>& constants)
  270. -> const analysis::Constant* {
  271. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  272. analysis::TypeManager* type_mgr = context->get_type_mgr();
  273. const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
  274. const analysis::Vector* vector_type = result_type->AsVector();
  275. if (!inst->IsFloatingPointFoldingAllowed()) {
  276. return nullptr;
  277. }
  278. if (constants[0] == nullptr || constants[1] == nullptr) {
  279. return nullptr;
  280. }
  281. if (vector_type != nullptr) {
  282. std::vector<const analysis::Constant*> a_components;
  283. std::vector<const analysis::Constant*> b_components;
  284. std::vector<const analysis::Constant*> results_components;
  285. a_components = constants[0]->GetVectorComponents(const_mgr);
  286. b_components = constants[1]->GetVectorComponents(const_mgr);
  287. // Fold each component of the vector.
  288. for (uint32_t i = 0; i < a_components.size(); ++i) {
  289. results_components.push_back(scalar_rule(vector_type->element_type(),
  290. a_components[i],
  291. b_components[i], const_mgr));
  292. if (results_components[i] == nullptr) {
  293. return nullptr;
  294. }
  295. }
  296. // Build the constant object and return it.
  297. std::vector<uint32_t> ids;
  298. for (const analysis::Constant* member : results_components) {
  299. ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
  300. }
  301. return const_mgr->GetConstant(vector_type, ids);
  302. } else {
  303. return scalar_rule(result_type, constants[0], constants[1], const_mgr);
  304. }
  305. };
  306. }
  307. // This macro defines a |UnaryScalarFoldingRule| that performs float to
  308. // integer conversion.
  309. // TODO(greg-lunarg): Support for 64-bit integer types.
  310. UnaryScalarFoldingRule FoldFToIOp() {
  311. return [](const analysis::Type* result_type, const analysis::Constant* a,
  312. analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
  313. assert(result_type != nullptr && a != nullptr);
  314. const analysis::Integer* integer_type = result_type->AsInteger();
  315. const analysis::Float* float_type = a->type()->AsFloat();
  316. assert(float_type != nullptr);
  317. assert(integer_type != nullptr);
  318. if (integer_type->width() != 32) return nullptr;
  319. if (float_type->width() == 32) {
  320. float fa = a->GetFloat();
  321. uint32_t result = integer_type->IsSigned()
  322. ? static_cast<uint32_t>(static_cast<int32_t>(fa))
  323. : static_cast<uint32_t>(fa);
  324. std::vector<uint32_t> words = {result};
  325. return const_mgr->GetConstant(result_type, words);
  326. } else if (float_type->width() == 64) {
  327. double fa = a->GetDouble();
  328. uint32_t result = integer_type->IsSigned()
  329. ? static_cast<uint32_t>(static_cast<int32_t>(fa))
  330. : static_cast<uint32_t>(fa);
  331. std::vector<uint32_t> words = {result};
  332. return const_mgr->GetConstant(result_type, words);
  333. }
  334. return nullptr;
  335. };
  336. }
  337. // This function defines a |UnaryScalarFoldingRule| that performs integer to
  338. // float conversion.
  339. // TODO(greg-lunarg): Support for 64-bit integer types.
  340. UnaryScalarFoldingRule FoldIToFOp() {
  341. return [](const analysis::Type* result_type, const analysis::Constant* a,
  342. analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
  343. assert(result_type != nullptr && a != nullptr);
  344. const analysis::Integer* integer_type = a->type()->AsInteger();
  345. const analysis::Float* float_type = result_type->AsFloat();
  346. assert(float_type != nullptr);
  347. assert(integer_type != nullptr);
  348. if (integer_type->width() != 32) return nullptr;
  349. uint32_t ua = a->GetU32();
  350. if (float_type->width() == 32) {
  351. float result_val = integer_type->IsSigned()
  352. ? static_cast<float>(static_cast<int32_t>(ua))
  353. : static_cast<float>(ua);
  354. utils::FloatProxy<float> result(result_val);
  355. std::vector<uint32_t> words = {result.data()};
  356. return const_mgr->GetConstant(result_type, words);
  357. } else if (float_type->width() == 64) {
  358. double result_val = integer_type->IsSigned()
  359. ? static_cast<double>(static_cast<int32_t>(ua))
  360. : static_cast<double>(ua);
  361. utils::FloatProxy<double> result(result_val);
  362. std::vector<uint32_t> words = result.GetWords();
  363. return const_mgr->GetConstant(result_type, words);
  364. }
  365. return nullptr;
  366. };
  367. }
  368. // This defines a |UnaryScalarFoldingRule| that performs |OpQuantizeToF16|.
  369. UnaryScalarFoldingRule FoldQuantizeToF16Scalar() {
  370. return [](const analysis::Type* result_type, const analysis::Constant* a,
  371. analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
  372. assert(result_type != nullptr && a != nullptr);
  373. const analysis::Float* float_type = a->type()->AsFloat();
  374. assert(float_type != nullptr);
  375. if (float_type->width() != 32) {
  376. return nullptr;
  377. }
  378. float fa = a->GetFloat();
  379. utils::HexFloat<utils::FloatProxy<float>> orignal(fa);
  380. utils::HexFloat<utils::FloatProxy<utils::Float16>> quantized(0);
  381. utils::HexFloat<utils::FloatProxy<float>> result(0.0f);
  382. orignal.castTo(quantized, utils::round_direction::kToZero);
  383. quantized.castTo(result, utils::round_direction::kToZero);
  384. std::vector<uint32_t> words = {result.getBits()};
  385. return const_mgr->GetConstant(result_type, words);
  386. };
  387. }
  // Expands to a |BinaryScalarFoldingRule| lambda that computes "a op b" for
  // 32- and 64-bit float constants and yields nullptr for any other width.
  // |op| must be an infix operator that works on both float and double.
  // Both inputs are asserted to have type |result_type|. The locals carry an
  // "_in_macro" suffix so that an expansion inside another function does not
  // shadow that function's own names.
  #define FOLD_FPARITH_OP(op) \
    [](const analysis::Type* result_type, const analysis::Constant* a, \
       const analysis::Constant* b, \
       analysis::ConstantManager* const_mgr_in_macro) \
        -> const analysis::Constant* { \
      assert(result_type != nullptr && a != nullptr && b != nullptr); \
      assert(result_type == a->type() && result_type == b->type()); \
      const analysis::Float* float_type_in_macro = result_type->AsFloat(); \
      assert(float_type_in_macro != nullptr); \
      std::vector<uint32_t> words_in_macro; \
      if (float_type_in_macro->width() == 32) { \
        const float fa = a->GetFloat(); \
        const float fb = b->GetFloat(); \
        utils::FloatProxy<float> result_in_macro(fa op fb); \
        words_in_macro = result_in_macro.GetWords(); \
      } else if (float_type_in_macro->width() == 64) { \
        const double fa = a->GetDouble(); \
        const double fb = b->GetDouble(); \
        utils::FloatProxy<double> result_in_macro(fa op fb); \
        words_in_macro = result_in_macro.GetWords(); \
      } else { \
        return nullptr; \
      } \
      return const_mgr_in_macro->GetConstant(result_type, words_in_macro); \
    }
  414. // Define the folding rule for conversion between floating point and integer
  415. ConstantFoldingRule FoldFToI() { return FoldFPUnaryOp(FoldFToIOp()); }
  416. ConstantFoldingRule FoldIToF() { return FoldFPUnaryOp(FoldIToFOp()); }
  417. ConstantFoldingRule FoldQuantizeToF16() {
  418. return FoldFPUnaryOp(FoldQuantizeToF16Scalar());
  419. }
  420. // Define the folding rules for subtraction, addition, multiplication, and
  421. // division for floating point values.
  422. ConstantFoldingRule FoldFSub() { return FoldFPBinaryOp(FOLD_FPARITH_OP(-)); }
  423. ConstantFoldingRule FoldFAdd() { return FoldFPBinaryOp(FOLD_FPARITH_OP(+)); }
  424. ConstantFoldingRule FoldFMul() { return FoldFPBinaryOp(FOLD_FPARITH_OP(*)); }
  425. ConstantFoldingRule FoldFDiv() { return FoldFPBinaryOp(FOLD_FPARITH_OP(/)); }
  // Combines the raw outcome of a comparison with the "operands are unordered"
  // flag.
  //   |need_ordered| true : true only if the operands are ordered and the
  //                         comparison held.
  //   |need_ordered| false: true if the operands are unordered or the
  //                         comparison held.
  bool CompareFloatingPoint(bool op_result, bool op_unordered,
                            bool need_ordered) {
    return need_ordered ? (op_result && !op_unordered)
                        : (op_result || op_unordered);
  }
  // Expands to a |BinaryScalarFoldingRule| lambda that compares two float
  // constants of the same type with |op| and produces a constant of the Bool
  // |result_type|. |op| must be an infix operator that works on both float
  // and double. |ord| is passed to CompareFloatingPoint() as |need_ordered|,
  // together with whether either operand is NaN. Yields nullptr for float
  // widths other than 32 and 64.
  #define FOLD_FPCMP_OP(op, ord) \
    [](const analysis::Type* result_type, const analysis::Constant* a, \
       const analysis::Constant* b, \
       analysis::ConstantManager* const_mgr) -> const analysis::Constant* { \
      assert(result_type != nullptr && a != nullptr && b != nullptr); \
      assert(result_type->AsBool()); \
      assert(a->type() == b->type()); \
      const analysis::Float* float_type = a->type()->AsFloat(); \
      assert(float_type != nullptr); \
      bool result = false; \
      if (float_type->width() == 32) { \
        const float fa = a->GetFloat(); \
        const float fb = b->GetFloat(); \
        result = CompareFloatingPoint( \
            fa op fb, std::isnan(fa) || std::isnan(fb), ord); \
      } else if (float_type->width() == 64) { \
        const double fa = a->GetDouble(); \
        const double fb = b->GetDouble(); \
        result = CompareFloatingPoint( \
            fa op fb, std::isnan(fa) || std::isnan(fb), ord); \
      } else { \
        return nullptr; \
      } \
      std::vector<uint32_t> words = {uint32_t(result)}; \
      return const_mgr->GetConstant(result_type, words); \
    }
  464. // Define the folding rules for ordered and unordered comparison for floating
  465. // point values.
  466. ConstantFoldingRule FoldFOrdEqual() {
  467. return FoldFPBinaryOp(FOLD_FPCMP_OP(==, true));
  468. }
  469. ConstantFoldingRule FoldFUnordEqual() {
  470. return FoldFPBinaryOp(FOLD_FPCMP_OP(==, false));
  471. }
  472. ConstantFoldingRule FoldFOrdNotEqual() {
  473. return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, true));
  474. }
  475. ConstantFoldingRule FoldFUnordNotEqual() {
  476. return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, false));
  477. }
  478. ConstantFoldingRule FoldFOrdLessThan() {
  479. return FoldFPBinaryOp(FOLD_FPCMP_OP(<, true));
  480. }
  481. ConstantFoldingRule FoldFUnordLessThan() {
  482. return FoldFPBinaryOp(FOLD_FPCMP_OP(<, false));
  483. }
  484. ConstantFoldingRule FoldFOrdGreaterThan() {
  485. return FoldFPBinaryOp(FOLD_FPCMP_OP(>, true));
  486. }
  487. ConstantFoldingRule FoldFUnordGreaterThan() {
  488. return FoldFPBinaryOp(FOLD_FPCMP_OP(>, false));
  489. }
  490. ConstantFoldingRule FoldFOrdLessThanEqual() {
  491. return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, true));
  492. }
  493. ConstantFoldingRule FoldFUnordLessThanEqual() {
  494. return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, false));
  495. }
  496. ConstantFoldingRule FoldFOrdGreaterThanEqual() {
  497. return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, true));
  498. }
  499. ConstantFoldingRule FoldFUnordGreaterThanEqual() {
  500. return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, false));
  501. }
  502. // Folds an OpDot where all of the inputs are constants to a
  503. // constant. A new constant is created if necessary.
  504. ConstantFoldingRule FoldOpDotWithConstants() {
  505. return [](IRContext* context, Instruction* inst,
  506. const std::vector<const analysis::Constant*>& constants)
  507. -> const analysis::Constant* {
  508. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  509. analysis::TypeManager* type_mgr = context->get_type_mgr();
  510. const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
  511. assert(new_type->AsFloat() && "OpDot should have a float return type.");
  512. const analysis::Float* float_type = new_type->AsFloat();
  513. if (!inst->IsFloatingPointFoldingAllowed()) {
  514. return nullptr;
  515. }
  516. // If one of the operands is 0, then the result is 0.
  517. bool has_zero_operand = false;
  518. for (int i = 0; i < 2; ++i) {
  519. if (constants[i]) {
  520. if (constants[i]->AsNullConstant() ||
  521. constants[i]->AsVectorConstant()->IsZero()) {
  522. has_zero_operand = true;
  523. break;
  524. }
  525. }
  526. }
  527. if (has_zero_operand) {
  528. if (float_type->width() == 32) {
  529. utils::FloatProxy<float> result(0.0f);
  530. std::vector<uint32_t> words = result.GetWords();
  531. return const_mgr->GetConstant(float_type, words);
  532. }
  533. if (float_type->width() == 64) {
  534. utils::FloatProxy<double> result(0.0);
  535. std::vector<uint32_t> words = result.GetWords();
  536. return const_mgr->GetConstant(float_type, words);
  537. }
  538. return nullptr;
  539. }
  540. if (constants[0] == nullptr || constants[1] == nullptr) {
  541. return nullptr;
  542. }
  543. std::vector<const analysis::Constant*> a_components;
  544. std::vector<const analysis::Constant*> b_components;
  545. a_components = constants[0]->GetVectorComponents(const_mgr);
  546. b_components = constants[1]->GetVectorComponents(const_mgr);
  547. utils::FloatProxy<double> result(0.0);
  548. std::vector<uint32_t> words = result.GetWords();
  549. const analysis::Constant* result_const =
  550. const_mgr->GetConstant(float_type, words);
  551. for (uint32_t i = 0; i < a_components.size() && result_const != nullptr;
  552. ++i) {
  553. if (a_components[i] == nullptr || b_components[i] == nullptr) {
  554. return nullptr;
  555. }
  556. const analysis::Constant* component = FOLD_FPARITH_OP(*)(
  557. new_type, a_components[i], b_components[i], const_mgr);
  558. if (component == nullptr) {
  559. return nullptr;
  560. }
  561. result_const =
  562. FOLD_FPARITH_OP(+)(new_type, result_const, component, const_mgr);
  563. }
  564. return result_const;
  565. };
  566. }
  567. // This function defines a |UnaryScalarFoldingRule| that subtracts the constant
  568. // from zero.
  569. UnaryScalarFoldingRule FoldFNegateOp() {
  570. return [](const analysis::Type* result_type, const analysis::Constant* a,
  571. analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
  572. assert(result_type != nullptr && a != nullptr);
  573. assert(result_type == a->type());
  574. const analysis::Float* float_type = result_type->AsFloat();
  575. assert(float_type != nullptr);
  576. if (float_type->width() == 32) {
  577. float fa = a->GetFloat();
  578. utils::FloatProxy<float> result(-fa);
  579. std::vector<uint32_t> words = result.GetWords();
  580. return const_mgr->GetConstant(result_type, words);
  581. } else if (float_type->width() == 64) {
  582. double da = a->GetDouble();
  583. utils::FloatProxy<double> result(-da);
  584. std::vector<uint32_t> words = result.GetWords();
  585. return const_mgr->GetConstant(result_type, words);
  586. }
  587. return nullptr;
  588. };
  589. }
  590. ConstantFoldingRule FoldFNegate() { return FoldFPUnaryOp(FoldFNegateOp()); }
  591. ConstantFoldingRule FoldFClampFeedingCompare(uint32_t cmp_opcode) {
  592. return [cmp_opcode](IRContext* context, Instruction* inst,
  593. const std::vector<const analysis::Constant*>& constants)
  594. -> const analysis::Constant* {
  595. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  596. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  597. if (!inst->IsFloatingPointFoldingAllowed()) {
  598. return nullptr;
  599. }
  600. uint32_t non_const_idx = (constants[0] ? 1 : 0);
  601. uint32_t operand_id = inst->GetSingleWordInOperand(non_const_idx);
  602. Instruction* operand_inst = def_use_mgr->GetDef(operand_id);
  603. analysis::TypeManager* type_mgr = context->get_type_mgr();
  604. const analysis::Type* operand_type =
  605. type_mgr->GetType(operand_inst->type_id());
  606. if (!operand_type->AsFloat()) {
  607. return nullptr;
  608. }
  609. if (operand_type->AsFloat()->width() != 32 &&
  610. operand_type->AsFloat()->width() != 64) {
  611. return nullptr;
  612. }
  613. if (operand_inst->opcode() != SpvOpExtInst) {
  614. return nullptr;
  615. }
  616. if (operand_inst->GetSingleWordInOperand(1) != GLSLstd450FClamp) {
  617. return nullptr;
  618. }
  619. if (constants[1] == nullptr && constants[0] == nullptr) {
  620. return nullptr;
  621. }
  622. uint32_t max_id = operand_inst->GetSingleWordInOperand(4);
  623. const analysis::Constant* max_const =
  624. const_mgr->FindDeclaredConstant(max_id);
  625. uint32_t min_id = operand_inst->GetSingleWordInOperand(3);
  626. const analysis::Constant* min_const =
  627. const_mgr->FindDeclaredConstant(min_id);
  628. bool found_result = false;
  629. bool result = false;
  630. switch (cmp_opcode) {
  631. case SpvOpFOrdLessThan:
  632. case SpvOpFUnordLessThan:
  633. case SpvOpFOrdGreaterThanEqual:
  634. case SpvOpFUnordGreaterThanEqual:
  635. if (constants[0]) {
  636. if (min_const) {
  637. if (constants[0]->GetValueAsDouble() <
  638. min_const->GetValueAsDouble()) {
  639. found_result = true;
  640. result = (cmp_opcode == SpvOpFOrdLessThan ||
  641. cmp_opcode == SpvOpFUnordLessThan);
  642. }
  643. }
  644. if (max_const) {
  645. if (constants[0]->GetValueAsDouble() >=
  646. max_const->GetValueAsDouble()) {
  647. found_result = true;
  648. result = !(cmp_opcode == SpvOpFOrdLessThan ||
  649. cmp_opcode == SpvOpFUnordLessThan);
  650. }
  651. }
  652. }
  653. if (constants[1]) {
  654. if (max_const) {
  655. if (max_const->GetValueAsDouble() <
  656. constants[1]->GetValueAsDouble()) {
  657. found_result = true;
  658. result = (cmp_opcode == SpvOpFOrdLessThan ||
  659. cmp_opcode == SpvOpFUnordLessThan);
  660. }
  661. }
  662. if (min_const) {
  663. if (min_const->GetValueAsDouble() >=
  664. constants[1]->GetValueAsDouble()) {
  665. found_result = true;
  666. result = !(cmp_opcode == SpvOpFOrdLessThan ||
  667. cmp_opcode == SpvOpFUnordLessThan);
  668. }
  669. }
  670. }
  671. break;
  672. case SpvOpFOrdGreaterThan:
  673. case SpvOpFUnordGreaterThan:
  674. case SpvOpFOrdLessThanEqual:
  675. case SpvOpFUnordLessThanEqual:
  676. if (constants[0]) {
  677. if (min_const) {
  678. if (constants[0]->GetValueAsDouble() <=
  679. min_const->GetValueAsDouble()) {
  680. found_result = true;
  681. result = (cmp_opcode == SpvOpFOrdLessThanEqual ||
  682. cmp_opcode == SpvOpFUnordLessThanEqual);
  683. }
  684. }
  685. if (max_const) {
  686. if (constants[0]->GetValueAsDouble() >
  687. max_const->GetValueAsDouble()) {
  688. found_result = true;
  689. result = !(cmp_opcode == SpvOpFOrdLessThanEqual ||
  690. cmp_opcode == SpvOpFUnordLessThanEqual);
  691. }
  692. }
  693. }
  694. if (constants[1]) {
  695. if (max_const) {
  696. if (max_const->GetValueAsDouble() <=
  697. constants[1]->GetValueAsDouble()) {
  698. found_result = true;
  699. result = (cmp_opcode == SpvOpFOrdLessThanEqual ||
  700. cmp_opcode == SpvOpFUnordLessThanEqual);
  701. }
  702. }
  703. if (min_const) {
  704. if (min_const->GetValueAsDouble() >
  705. constants[1]->GetValueAsDouble()) {
  706. found_result = true;
  707. result = !(cmp_opcode == SpvOpFOrdLessThanEqual ||
  708. cmp_opcode == SpvOpFUnordLessThanEqual);
  709. }
  710. }
  711. }
  712. break;
  713. default:
  714. return nullptr;
  715. }
  716. if (!found_result) {
  717. return nullptr;
  718. }
  719. const analysis::Type* bool_type =
  720. context->get_type_mgr()->GetType(inst->type_id());
  721. const analysis::Constant* result_const =
  722. const_mgr->GetConstant(bool_type, {static_cast<uint32_t>(result)});
  723. assert(result_const);
  724. return result_const;
  725. };
  726. }
  727. } // namespace
  728. ConstantFoldingRules::ConstantFoldingRules() {
  729. // Add all folding rules to the list for the opcodes to which they apply.
  730. // Note that the order in which rules are added to the list matters. If a rule
  731. // applies to the instruction, the rest of the rules will not be attempted.
  732. // Take that into consideration.
  733. rules_[SpvOpCompositeConstruct].push_back(FoldCompositeWithConstants());
  734. rules_[SpvOpCompositeExtract].push_back(FoldExtractWithConstants());
  735. rules_[SpvOpConvertFToS].push_back(FoldFToI());
  736. rules_[SpvOpConvertFToU].push_back(FoldFToI());
  737. rules_[SpvOpConvertSToF].push_back(FoldIToF());
  738. rules_[SpvOpConvertUToF].push_back(FoldIToF());
  739. rules_[SpvOpDot].push_back(FoldOpDotWithConstants());
  740. rules_[SpvOpFAdd].push_back(FoldFAdd());
  741. rules_[SpvOpFDiv].push_back(FoldFDiv());
  742. rules_[SpvOpFMul].push_back(FoldFMul());
  743. rules_[SpvOpFSub].push_back(FoldFSub());
  744. rules_[SpvOpFOrdEqual].push_back(FoldFOrdEqual());
  745. rules_[SpvOpFUnordEqual].push_back(FoldFUnordEqual());
  746. rules_[SpvOpFOrdNotEqual].push_back(FoldFOrdNotEqual());
  747. rules_[SpvOpFUnordNotEqual].push_back(FoldFUnordNotEqual());
  748. rules_[SpvOpFOrdLessThan].push_back(FoldFOrdLessThan());
  749. rules_[SpvOpFOrdLessThan].push_back(
  750. FoldFClampFeedingCompare(SpvOpFOrdLessThan));
  751. rules_[SpvOpFUnordLessThan].push_back(FoldFUnordLessThan());
  752. rules_[SpvOpFUnordLessThan].push_back(
  753. FoldFClampFeedingCompare(SpvOpFUnordLessThan));
  754. rules_[SpvOpFOrdGreaterThan].push_back(FoldFOrdGreaterThan());
  755. rules_[SpvOpFOrdGreaterThan].push_back(
  756. FoldFClampFeedingCompare(SpvOpFOrdGreaterThan));
  757. rules_[SpvOpFUnordGreaterThan].push_back(FoldFUnordGreaterThan());
  758. rules_[SpvOpFUnordGreaterThan].push_back(
  759. FoldFClampFeedingCompare(SpvOpFUnordGreaterThan));
  760. rules_[SpvOpFOrdLessThanEqual].push_back(FoldFOrdLessThanEqual());
  761. rules_[SpvOpFOrdLessThanEqual].push_back(
  762. FoldFClampFeedingCompare(SpvOpFOrdLessThanEqual));
  763. rules_[SpvOpFUnordLessThanEqual].push_back(FoldFUnordLessThanEqual());
  764. rules_[SpvOpFUnordLessThanEqual].push_back(
  765. FoldFClampFeedingCompare(SpvOpFUnordLessThanEqual));
  766. rules_[SpvOpFOrdGreaterThanEqual].push_back(FoldFOrdGreaterThanEqual());
  767. rules_[SpvOpFOrdGreaterThanEqual].push_back(
  768. FoldFClampFeedingCompare(SpvOpFOrdGreaterThanEqual));
  769. rules_[SpvOpFUnordGreaterThanEqual].push_back(FoldFUnordGreaterThanEqual());
  770. rules_[SpvOpFUnordGreaterThanEqual].push_back(
  771. FoldFClampFeedingCompare(SpvOpFUnordGreaterThanEqual));
  772. rules_[SpvOpVectorShuffle].push_back(FoldVectorShuffleWithConstants());
  773. rules_[SpvOpVectorTimesScalar].push_back(FoldVectorTimesScalar());
  774. rules_[SpvOpFNegate].push_back(FoldFNegate());
  775. rules_[SpvOpQuantizeToF16].push_back(FoldQuantizeToF16());
  776. }
  777. } // namespace opt
  778. } // namespace spvtools