const_folding_rules.cpp 65 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638
  1. // Copyright (c) 2018 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opt/const_folding_rules.h"
  15. #include "source/opt/ir_context.h"
  16. namespace spvtools {
  17. namespace opt {
  18. namespace {
  // In-operand position of the composite operand of an OpCompositeExtract; the
  // literal indexes occupy the in-operands after it.
  constexpr uint32_t kExtractCompositeIdInIdx{0};
  20. // Returns a constants with the value NaN of the given type. Only works for
  21. // 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs.
  22. const analysis::Constant* GetNan(const analysis::Type* type,
  23. analysis::ConstantManager* const_mgr) {
  24. const analysis::Float* float_type = type->AsFloat();
  25. if (float_type == nullptr) {
  26. return nullptr;
  27. }
  28. switch (float_type->width()) {
  29. case 32:
  30. return const_mgr->GetFloatConst(std::numeric_limits<float>::quiet_NaN());
  31. case 64:
  32. return const_mgr->GetDoubleConst(
  33. std::numeric_limits<double>::quiet_NaN());
  34. default:
  35. return nullptr;
  36. }
  37. }
  38. // Returns a constants with the value INF of the given type. Only works for
  39. // 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs.
  40. const analysis::Constant* GetInf(const analysis::Type* type,
  41. analysis::ConstantManager* const_mgr) {
  42. const analysis::Float* float_type = type->AsFloat();
  43. if (float_type == nullptr) {
  44. return nullptr;
  45. }
  46. switch (float_type->width()) {
  47. case 32:
  48. return const_mgr->GetFloatConst(std::numeric_limits<float>::infinity());
  49. case 64:
  50. return const_mgr->GetDoubleConst(std::numeric_limits<double>::infinity());
  51. default:
  52. return nullptr;
  53. }
  54. }
  55. // Returns true if |type| is Float or a vector of Float.
  56. bool HasFloatingPoint(const analysis::Type* type) {
  57. if (type->AsFloat()) {
  58. return true;
  59. } else if (const analysis::Vector* vec_type = type->AsVector()) {
  60. return vec_type->element_type()->AsFloat() != nullptr;
  61. }
  62. return false;
  63. }
  64. // Returns a constants with the value |-val| of the given type. Only works for
  65. // 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs.
  66. const analysis::Constant* NegateFPConst(const analysis::Type* result_type,
  67. const analysis::Constant* val,
  68. analysis::ConstantManager* const_mgr) {
  69. const analysis::Float* float_type = result_type->AsFloat();
  70. assert(float_type != nullptr);
  71. if (float_type->width() == 32) {
  72. float fa = val->GetFloat();
  73. return const_mgr->GetFloatConst(-fa);
  74. } else if (float_type->width() == 64) {
  75. double da = val->GetDouble();
  76. return const_mgr->GetDoubleConst(-da);
  77. }
  78. return nullptr;
  79. }
  80. // Folds an OpcompositeExtract where input is a composite constant.
  81. ConstantFoldingRule FoldExtractWithConstants() {
  82. return [](IRContext* context, Instruction* inst,
  83. const std::vector<const analysis::Constant*>& constants)
  84. -> const analysis::Constant* {
  85. const analysis::Constant* c = constants[kExtractCompositeIdInIdx];
  86. if (c == nullptr) {
  87. return nullptr;
  88. }
  89. for (uint32_t i = 1; i < inst->NumInOperands(); ++i) {
  90. uint32_t element_index = inst->GetSingleWordInOperand(i);
  91. if (c->AsNullConstant()) {
  92. // Return Null for the return type.
  93. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  94. analysis::TypeManager* type_mgr = context->get_type_mgr();
  95. return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), {});
  96. }
  97. auto cc = c->AsCompositeConstant();
  98. assert(cc != nullptr);
  99. auto components = cc->GetComponents();
  100. // Protect against invalid IR. Refuse to fold if the index is out
  101. // of bounds.
  102. if (element_index >= components.size()) return nullptr;
  103. c = components[element_index];
  104. }
  105. return c;
  106. };
  107. }
  108. // Folds an OpcompositeInsert where input is a composite constant.
  109. ConstantFoldingRule FoldInsertWithConstants() {
  110. return [](IRContext* context, Instruction* inst,
  111. const std::vector<const analysis::Constant*>& constants)
  112. -> const analysis::Constant* {
  113. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  114. const analysis::Constant* object = constants[0];
  115. const analysis::Constant* composite = constants[1];
  116. if (object == nullptr || composite == nullptr) {
  117. return nullptr;
  118. }
  119. // If there is more than 1 index, then each additional constant used by the
  120. // index will need to be recreated to use the inserted object.
  121. std::vector<const analysis::Constant*> chain;
  122. std::vector<const analysis::Constant*> components;
  123. const analysis::Type* type = nullptr;
  124. const uint32_t final_index = (inst->NumInOperands() - 1);
  125. // Work down hierarchy of all indexes
  126. for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
  127. type = composite->type();
  128. if (composite->AsNullConstant()) {
  129. // Make new composite so it can be inserted in the index with the
  130. // non-null value
  131. const auto new_composite = const_mgr->GetNullCompositeConstant(type);
  132. // Keep track of any indexes along the way to last index
  133. if (i != final_index) {
  134. chain.push_back(new_composite);
  135. }
  136. components = new_composite->AsCompositeConstant()->GetComponents();
  137. } else {
  138. // Keep track of any indexes along the way to last index
  139. if (i != final_index) {
  140. chain.push_back(composite);
  141. }
  142. components = composite->AsCompositeConstant()->GetComponents();
  143. }
  144. const uint32_t index = inst->GetSingleWordInOperand(i);
  145. composite = components[index];
  146. }
  147. // Final index in hierarchy is inserted with new object.
  148. const uint32_t final_operand = inst->GetSingleWordInOperand(final_index);
  149. std::vector<uint32_t> ids;
  150. for (size_t i = 0; i < components.size(); i++) {
  151. const analysis::Constant* constant =
  152. (i == final_operand) ? object : components[i];
  153. Instruction* member_inst = const_mgr->GetDefiningInstruction(constant);
  154. ids.push_back(member_inst->result_id());
  155. }
  156. const analysis::Constant* new_constant = const_mgr->GetConstant(type, ids);
  157. // Work backwards up the chain and replace each index with new constant.
  158. for (size_t i = chain.size(); i > 0; i--) {
  159. // Need to insert any previous instruction into the module first.
  160. // Can't just insert in types_values_begin() because it will move above
  161. // where the types are declared.
  162. // Can't compare with location of inst because not all new added
  163. // instructions are added to types_values_
  164. auto iter = context->types_values_end();
  165. Module::inst_iterator* pos = &iter;
  166. const_mgr->BuildInstructionAndAddToModule(new_constant, pos);
  167. composite = chain[i - 1];
  168. components = composite->AsCompositeConstant()->GetComponents();
  169. type = composite->type();
  170. ids.clear();
  171. for (size_t k = 0; k < components.size(); k++) {
  172. const uint32_t index =
  173. inst->GetSingleWordInOperand(1 + static_cast<uint32_t>(i));
  174. const analysis::Constant* constant =
  175. (k == index) ? new_constant : components[k];
  176. const uint32_t constant_id =
  177. const_mgr->FindDeclaredConstant(constant, 0);
  178. ids.push_back(constant_id);
  179. }
  180. new_constant = const_mgr->GetConstant(type, ids);
  181. }
  182. // If multiple constants were created, only need to return the top index.
  183. return new_constant;
  184. };
  185. }
  186. ConstantFoldingRule FoldVectorShuffleWithConstants() {
  187. return [](IRContext* context, Instruction* inst,
  188. const std::vector<const analysis::Constant*>& constants)
  189. -> const analysis::Constant* {
  190. assert(inst->opcode() == spv::Op::OpVectorShuffle);
  191. const analysis::Constant* c1 = constants[0];
  192. const analysis::Constant* c2 = constants[1];
  193. if (c1 == nullptr || c2 == nullptr) {
  194. return nullptr;
  195. }
  196. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  197. const analysis::Type* element_type = c1->type()->AsVector()->element_type();
  198. std::vector<const analysis::Constant*> c1_components;
  199. if (const analysis::VectorConstant* vec_const = c1->AsVectorConstant()) {
  200. c1_components = vec_const->GetComponents();
  201. } else {
  202. assert(c1->AsNullConstant());
  203. const analysis::Constant* element =
  204. const_mgr->GetConstant(element_type, {});
  205. c1_components.resize(c1->type()->AsVector()->element_count(), element);
  206. }
  207. std::vector<const analysis::Constant*> c2_components;
  208. if (const analysis::VectorConstant* vec_const = c2->AsVectorConstant()) {
  209. c2_components = vec_const->GetComponents();
  210. } else {
  211. assert(c2->AsNullConstant());
  212. const analysis::Constant* element =
  213. const_mgr->GetConstant(element_type, {});
  214. c2_components.resize(c2->type()->AsVector()->element_count(), element);
  215. }
  216. std::vector<uint32_t> ids;
  217. const uint32_t undef_literal_value = 0xffffffff;
  218. for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
  219. uint32_t index = inst->GetSingleWordInOperand(i);
  220. if (index == undef_literal_value) {
  221. // Don't fold shuffle with undef literal value.
  222. return nullptr;
  223. } else if (index < c1_components.size()) {
  224. Instruction* member_inst =
  225. const_mgr->GetDefiningInstruction(c1_components[index]);
  226. ids.push_back(member_inst->result_id());
  227. } else {
  228. Instruction* member_inst = const_mgr->GetDefiningInstruction(
  229. c2_components[index - c1_components.size()]);
  230. ids.push_back(member_inst->result_id());
  231. }
  232. }
  233. analysis::TypeManager* type_mgr = context->get_type_mgr();
  234. return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
  235. };
  236. }
  237. ConstantFoldingRule FoldVectorTimesScalar() {
  238. return [](IRContext* context, Instruction* inst,
  239. const std::vector<const analysis::Constant*>& constants)
  240. -> const analysis::Constant* {
  241. assert(inst->opcode() == spv::Op::OpVectorTimesScalar);
  242. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  243. analysis::TypeManager* type_mgr = context->get_type_mgr();
  244. if (!inst->IsFloatingPointFoldingAllowed()) {
  245. if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
  246. return nullptr;
  247. }
  248. }
  249. const analysis::Constant* c1 = constants[0];
  250. const analysis::Constant* c2 = constants[1];
  251. if (c1 && c1->IsZero()) {
  252. return c1;
  253. }
  254. if (c2 && c2->IsZero()) {
  255. // Get or create the NullConstant for this type.
  256. std::vector<uint32_t> ids;
  257. return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
  258. }
  259. if (c1 == nullptr || c2 == nullptr) {
  260. return nullptr;
  261. }
  262. // Check result type.
  263. const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
  264. const analysis::Vector* vector_type = result_type->AsVector();
  265. assert(vector_type != nullptr);
  266. const analysis::Type* element_type = vector_type->element_type();
  267. assert(element_type != nullptr);
  268. const analysis::Float* float_type = element_type->AsFloat();
  269. assert(float_type != nullptr);
  270. // Check types of c1 and c2.
  271. assert(c1->type()->AsVector() == vector_type);
  272. assert(c1->type()->AsVector()->element_type() == element_type &&
  273. c2->type() == element_type);
  274. // Get a float vector that is the result of vector-times-scalar.
  275. std::vector<const analysis::Constant*> c1_components =
  276. c1->GetVectorComponents(const_mgr);
  277. std::vector<uint32_t> ids;
  278. if (float_type->width() == 32) {
  279. float scalar = c2->GetFloat();
  280. for (uint32_t i = 0; i < c1_components.size(); ++i) {
  281. utils::FloatProxy<float> result(c1_components[i]->GetFloat() * scalar);
  282. std::vector<uint32_t> words = result.GetWords();
  283. const analysis::Constant* new_elem =
  284. const_mgr->GetConstant(float_type, words);
  285. ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
  286. }
  287. return const_mgr->GetConstant(vector_type, ids);
  288. } else if (float_type->width() == 64) {
  289. double scalar = c2->GetDouble();
  290. for (uint32_t i = 0; i < c1_components.size(); ++i) {
  291. utils::FloatProxy<double> result(c1_components[i]->GetDouble() *
  292. scalar);
  293. std::vector<uint32_t> words = result.GetWords();
  294. const analysis::Constant* new_elem =
  295. const_mgr->GetConstant(float_type, words);
  296. ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
  297. }
  298. return const_mgr->GetConstant(vector_type, ids);
  299. }
  300. return nullptr;
  301. };
  302. }
  303. ConstantFoldingRule FoldVectorTimesMatrix() {
  304. return [](IRContext* context, Instruction* inst,
  305. const std::vector<const analysis::Constant*>& constants)
  306. -> const analysis::Constant* {
  307. assert(inst->opcode() == spv::Op::OpVectorTimesMatrix);
  308. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  309. analysis::TypeManager* type_mgr = context->get_type_mgr();
  310. if (!inst->IsFloatingPointFoldingAllowed()) {
  311. if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
  312. return nullptr;
  313. }
  314. }
  315. const analysis::Constant* c1 = constants[0];
  316. const analysis::Constant* c2 = constants[1];
  317. if (c1 == nullptr || c2 == nullptr) {
  318. return nullptr;
  319. }
  320. // Check result type.
  321. const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
  322. const analysis::Vector* vector_type = result_type->AsVector();
  323. assert(vector_type != nullptr);
  324. const analysis::Type* element_type = vector_type->element_type();
  325. assert(element_type != nullptr);
  326. const analysis::Float* float_type = element_type->AsFloat();
  327. assert(float_type != nullptr);
  328. // Check types of c1 and c2.
  329. assert(c1->type()->AsVector() == vector_type);
  330. assert(c1->type()->AsVector()->element_type() == element_type &&
  331. c2->type()->AsMatrix()->element_type() == vector_type);
  332. // Get a float vector that is the result of vector-times-matrix.
  333. std::vector<const analysis::Constant*> c1_components =
  334. c1->GetVectorComponents(const_mgr);
  335. std::vector<const analysis::Constant*> c2_components =
  336. c2->AsMatrixConstant()->GetComponents();
  337. uint32_t resultVectorSize = result_type->AsVector()->element_count();
  338. std::vector<uint32_t> ids;
  339. if ((c1 && c1->IsZero()) || (c2 && c2->IsZero())) {
  340. std::vector<uint32_t> words(float_type->width() / 32, 0);
  341. for (uint32_t i = 0; i < resultVectorSize; ++i) {
  342. const analysis::Constant* new_elem =
  343. const_mgr->GetConstant(float_type, words);
  344. ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
  345. }
  346. return const_mgr->GetConstant(vector_type, ids);
  347. }
  348. if (float_type->width() == 32) {
  349. for (uint32_t i = 0; i < resultVectorSize; ++i) {
  350. float result_scalar = 0.0f;
  351. const analysis::VectorConstant* c2_vec =
  352. c2_components[i]->AsVectorConstant();
  353. for (uint32_t j = 0; j < c2_vec->GetComponents().size(); ++j) {
  354. float c1_scalar = c1_components[j]->GetFloat();
  355. float c2_scalar = c2_vec->GetComponents()[j]->GetFloat();
  356. result_scalar += c1_scalar * c2_scalar;
  357. }
  358. utils::FloatProxy<float> result(result_scalar);
  359. std::vector<uint32_t> words = result.GetWords();
  360. const analysis::Constant* new_elem =
  361. const_mgr->GetConstant(float_type, words);
  362. ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
  363. }
  364. return const_mgr->GetConstant(vector_type, ids);
  365. } else if (float_type->width() == 64) {
  366. for (uint32_t i = 0; i < c2_components.size(); ++i) {
  367. double result_scalar = 0.0;
  368. const analysis::VectorConstant* c2_vec =
  369. c2_components[i]->AsVectorConstant();
  370. for (uint32_t j = 0; j < c2_vec->GetComponents().size(); ++j) {
  371. double c1_scalar = c1_components[j]->GetDouble();
  372. double c2_scalar = c2_vec->GetComponents()[j]->GetDouble();
  373. result_scalar += c1_scalar * c2_scalar;
  374. }
  375. utils::FloatProxy<double> result(result_scalar);
  376. std::vector<uint32_t> words = result.GetWords();
  377. const analysis::Constant* new_elem =
  378. const_mgr->GetConstant(float_type, words);
  379. ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
  380. }
  381. return const_mgr->GetConstant(vector_type, ids);
  382. }
  383. return nullptr;
  384. };
  385. }
  386. ConstantFoldingRule FoldMatrixTimesVector() {
  387. return [](IRContext* context, Instruction* inst,
  388. const std::vector<const analysis::Constant*>& constants)
  389. -> const analysis::Constant* {
  390. assert(inst->opcode() == spv::Op::OpMatrixTimesVector);
  391. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  392. analysis::TypeManager* type_mgr = context->get_type_mgr();
  393. if (!inst->IsFloatingPointFoldingAllowed()) {
  394. if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
  395. return nullptr;
  396. }
  397. }
  398. const analysis::Constant* c1 = constants[0];
  399. const analysis::Constant* c2 = constants[1];
  400. if (c1 == nullptr || c2 == nullptr) {
  401. return nullptr;
  402. }
  403. // Check result type.
  404. const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
  405. const analysis::Vector* vector_type = result_type->AsVector();
  406. assert(vector_type != nullptr);
  407. const analysis::Type* element_type = vector_type->element_type();
  408. assert(element_type != nullptr);
  409. const analysis::Float* float_type = element_type->AsFloat();
  410. assert(float_type != nullptr);
  411. // Check types of c1 and c2.
  412. assert(c1->type()->AsMatrix()->element_type() == vector_type);
  413. assert(c2->type()->AsVector()->element_type() == element_type);
  414. // Get a float vector that is the result of matrix-times-vector.
  415. std::vector<const analysis::Constant*> c1_components =
  416. c1->AsMatrixConstant()->GetComponents();
  417. std::vector<const analysis::Constant*> c2_components =
  418. c2->GetVectorComponents(const_mgr);
  419. uint32_t resultVectorSize = result_type->AsVector()->element_count();
  420. std::vector<uint32_t> ids;
  421. if ((c1 && c1->IsZero()) || (c2 && c2->IsZero())) {
  422. std::vector<uint32_t> words(float_type->width() / 32, 0);
  423. for (uint32_t i = 0; i < resultVectorSize; ++i) {
  424. const analysis::Constant* new_elem =
  425. const_mgr->GetConstant(float_type, words);
  426. ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
  427. }
  428. return const_mgr->GetConstant(vector_type, ids);
  429. }
  430. if (float_type->width() == 32) {
  431. for (uint32_t i = 0; i < resultVectorSize; ++i) {
  432. float result_scalar = 0.0f;
  433. for (uint32_t j = 0; j < c1_components.size(); ++j) {
  434. float c1_scalar = c1_components[j]
  435. ->AsVectorConstant()
  436. ->GetComponents()[i]
  437. ->GetFloat();
  438. float c2_scalar = c2_components[j]->GetFloat();
  439. result_scalar += c1_scalar * c2_scalar;
  440. }
  441. utils::FloatProxy<float> result(result_scalar);
  442. std::vector<uint32_t> words = result.GetWords();
  443. const analysis::Constant* new_elem =
  444. const_mgr->GetConstant(float_type, words);
  445. ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
  446. }
  447. return const_mgr->GetConstant(vector_type, ids);
  448. } else if (float_type->width() == 64) {
  449. for (uint32_t i = 0; i < resultVectorSize; ++i) {
  450. double result_scalar = 0.0;
  451. for (uint32_t j = 0; j < c1_components.size(); ++j) {
  452. double c1_scalar = c1_components[j]
  453. ->AsVectorConstant()
  454. ->GetComponents()[i]
  455. ->GetDouble();
  456. double c2_scalar = c2_components[j]->GetDouble();
  457. result_scalar += c1_scalar * c2_scalar;
  458. }
  459. utils::FloatProxy<double> result(result_scalar);
  460. std::vector<uint32_t> words = result.GetWords();
  461. const analysis::Constant* new_elem =
  462. const_mgr->GetConstant(float_type, words);
  463. ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
  464. }
  465. return const_mgr->GetConstant(vector_type, ids);
  466. }
  467. return nullptr;
  468. };
  469. }
  470. ConstantFoldingRule FoldCompositeWithConstants() {
  471. // Folds an OpCompositeConstruct where all of the inputs are constants to a
  472. // constant. A new constant is created if necessary.
  473. return [](IRContext* context, Instruction* inst,
  474. const std::vector<const analysis::Constant*>& constants)
  475. -> const analysis::Constant* {
  476. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  477. analysis::TypeManager* type_mgr = context->get_type_mgr();
  478. const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
  479. Instruction* type_inst =
  480. context->get_def_use_mgr()->GetDef(inst->type_id());
  481. std::vector<uint32_t> ids;
  482. for (uint32_t i = 0; i < constants.size(); ++i) {
  483. const analysis::Constant* element_const = constants[i];
  484. if (element_const == nullptr) {
  485. return nullptr;
  486. }
  487. uint32_t component_type_id = 0;
  488. if (type_inst->opcode() == spv::Op::OpTypeStruct) {
  489. component_type_id = type_inst->GetSingleWordInOperand(i);
  490. } else if (type_inst->opcode() == spv::Op::OpTypeArray) {
  491. component_type_id = type_inst->GetSingleWordInOperand(0);
  492. }
  493. uint32_t element_id =
  494. const_mgr->FindDeclaredConstant(element_const, component_type_id);
  495. if (element_id == 0) {
  496. return nullptr;
  497. }
  498. ids.push_back(element_id);
  499. }
  500. return const_mgr->GetConstant(new_type, ids);
  501. };
  502. }
  503. // The interface for a function that returns the result of applying a scalar
  504. // floating-point binary operation on |a| and |b|. The type of the return value
  505. // will be |type|. The input constants must also be of type |type|.
  506. using UnaryScalarFoldingRule = std::function<const analysis::Constant*(
  507. const analysis::Type* result_type, const analysis::Constant* a,
  508. analysis::ConstantManager*)>;
  509. // The interface for a function that returns the result of applying a scalar
  510. // floating-point binary operation on |a| and |b|. The type of the return value
  511. // will be |type|. The input constants must also be of type |type|.
  512. using BinaryScalarFoldingRule = std::function<const analysis::Constant*(
  513. const analysis::Type* result_type, const analysis::Constant* a,
  514. const analysis::Constant* b, analysis::ConstantManager*)>;
  515. // Returns a |ConstantFoldingRule| that folds unary floating point scalar ops
  516. // using |scalar_rule| and unary float point vectors ops by applying
  517. // |scalar_rule| to the elements of the vector. The |ConstantFoldingRule|
  518. // that is returned assumes that |constants| contains 1 entry. If they are
  519. // not |nullptr|, then their type is either |Float| or |Integer| or a |Vector|
  520. // whose element type is |Float| or |Integer|.
  521. ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) {
  522. return [scalar_rule](IRContext* context, Instruction* inst,
  523. const std::vector<const analysis::Constant*>& constants)
  524. -> const analysis::Constant* {
  525. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  526. analysis::TypeManager* type_mgr = context->get_type_mgr();
  527. const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
  528. const analysis::Vector* vector_type = result_type->AsVector();
  529. if (!inst->IsFloatingPointFoldingAllowed()) {
  530. return nullptr;
  531. }
  532. const analysis::Constant* arg =
  533. (inst->opcode() == spv::Op::OpExtInst) ? constants[1] : constants[0];
  534. if (arg == nullptr) {
  535. return nullptr;
  536. }
  537. if (vector_type != nullptr) {
  538. std::vector<const analysis::Constant*> a_components;
  539. std::vector<const analysis::Constant*> results_components;
  540. a_components = arg->GetVectorComponents(const_mgr);
  541. // Fold each component of the vector.
  542. for (uint32_t i = 0; i < a_components.size(); ++i) {
  543. results_components.push_back(scalar_rule(vector_type->element_type(),
  544. a_components[i], const_mgr));
  545. if (results_components[i] == nullptr) {
  546. return nullptr;
  547. }
  548. }
  549. // Build the constant object and return it.
  550. std::vector<uint32_t> ids;
  551. for (const analysis::Constant* member : results_components) {
  552. ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
  553. }
  554. return const_mgr->GetConstant(vector_type, ids);
  555. } else {
  556. return scalar_rule(result_type, arg, const_mgr);
  557. }
  558. };
  559. }
  560. // Returns the result of folding the constants in |constants| according the
  561. // |scalar_rule|. If |result_type| is a vector, then |scalar_rule| is applied
  562. // per component.
  563. const analysis::Constant* FoldFPBinaryOp(
  564. BinaryScalarFoldingRule scalar_rule, uint32_t result_type_id,
  565. const std::vector<const analysis::Constant*>& constants,
  566. IRContext* context) {
  567. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  568. analysis::TypeManager* type_mgr = context->get_type_mgr();
  569. const analysis::Type* result_type = type_mgr->GetType(result_type_id);
  570. const analysis::Vector* vector_type = result_type->AsVector();
  571. if (constants[0] == nullptr || constants[1] == nullptr) {
  572. return nullptr;
  573. }
  574. if (vector_type != nullptr) {
  575. std::vector<const analysis::Constant*> a_components;
  576. std::vector<const analysis::Constant*> b_components;
  577. std::vector<const analysis::Constant*> results_components;
  578. a_components = constants[0]->GetVectorComponents(const_mgr);
  579. b_components = constants[1]->GetVectorComponents(const_mgr);
  580. // Fold each component of the vector.
  581. for (uint32_t i = 0; i < a_components.size(); ++i) {
  582. results_components.push_back(scalar_rule(vector_type->element_type(),
  583. a_components[i], b_components[i],
  584. const_mgr));
  585. if (results_components[i] == nullptr) {
  586. return nullptr;
  587. }
  588. }
  589. // Build the constant object and return it.
  590. std::vector<uint32_t> ids;
  591. for (const analysis::Constant* member : results_components) {
  592. ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
  593. }
  594. return const_mgr->GetConstant(vector_type, ids);
  595. } else {
  596. return scalar_rule(result_type, constants[0], constants[1], const_mgr);
  597. }
  598. }
  599. // Returns a |ConstantFoldingRule| that folds floating point scalars using
  600. // |scalar_rule| and vectors of floating point by applying |scalar_rule| to the
  601. // elements of the vector. The |ConstantFoldingRule| that is returned assumes
  602. // that |constants| contains 2 entries. If they are not |nullptr|, then their
  603. // type is either |Float| or a |Vector| whose element type is |Float|.
  604. ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) {
  605. return [scalar_rule](IRContext* context, Instruction* inst,
  606. const std::vector<const analysis::Constant*>& constants)
  607. -> const analysis::Constant* {
  608. if (!inst->IsFloatingPointFoldingAllowed()) {
  609. return nullptr;
  610. }
  611. if (inst->opcode() == spv::Op::OpExtInst) {
  612. return FoldFPBinaryOp(scalar_rule, inst->type_id(),
  613. {constants[1], constants[2]}, context);
  614. }
  615. return FoldFPBinaryOp(scalar_rule, inst->type_id(), constants, context);
  616. };
  617. }
  618. // This macro defines a |UnaryScalarFoldingRule| that performs float to
  619. // integer conversion.
  620. // TODO(greg-lunarg): Support for 64-bit integer types.
  621. UnaryScalarFoldingRule FoldFToIOp() {
  622. return [](const analysis::Type* result_type, const analysis::Constant* a,
  623. analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
  624. assert(result_type != nullptr && a != nullptr);
  625. const analysis::Integer* integer_type = result_type->AsInteger();
  626. const analysis::Float* float_type = a->type()->AsFloat();
  627. assert(float_type != nullptr);
  628. assert(integer_type != nullptr);
  629. if (integer_type->width() != 32) return nullptr;
  630. if (float_type->width() == 32) {
  631. float fa = a->GetFloat();
  632. uint32_t result = integer_type->IsSigned()
  633. ? static_cast<uint32_t>(static_cast<int32_t>(fa))
  634. : static_cast<uint32_t>(fa);
  635. std::vector<uint32_t> words = {result};
  636. return const_mgr->GetConstant(result_type, words);
  637. } else if (float_type->width() == 64) {
  638. double fa = a->GetDouble();
  639. uint32_t result = integer_type->IsSigned()
  640. ? static_cast<uint32_t>(static_cast<int32_t>(fa))
  641. : static_cast<uint32_t>(fa);
  642. std::vector<uint32_t> words = {result};
  643. return const_mgr->GetConstant(result_type, words);
  644. }
  645. return nullptr;
  646. };
  647. }
  648. // This function defines a |UnaryScalarFoldingRule| that performs integer to
  649. // float conversion.
  650. // TODO(greg-lunarg): Support for 64-bit integer types.
  651. UnaryScalarFoldingRule FoldIToFOp() {
  652. return [](const analysis::Type* result_type, const analysis::Constant* a,
  653. analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
  654. assert(result_type != nullptr && a != nullptr);
  655. const analysis::Integer* integer_type = a->type()->AsInteger();
  656. const analysis::Float* float_type = result_type->AsFloat();
  657. assert(float_type != nullptr);
  658. assert(integer_type != nullptr);
  659. if (integer_type->width() != 32) return nullptr;
  660. uint32_t ua = a->GetU32();
  661. if (float_type->width() == 32) {
  662. float result_val = integer_type->IsSigned()
  663. ? static_cast<float>(static_cast<int32_t>(ua))
  664. : static_cast<float>(ua);
  665. utils::FloatProxy<float> result(result_val);
  666. std::vector<uint32_t> words = {result.data()};
  667. return const_mgr->GetConstant(result_type, words);
  668. } else if (float_type->width() == 64) {
  669. double result_val = integer_type->IsSigned()
  670. ? static_cast<double>(static_cast<int32_t>(ua))
  671. : static_cast<double>(ua);
  672. utils::FloatProxy<double> result(result_val);
  673. std::vector<uint32_t> words = result.GetWords();
  674. return const_mgr->GetConstant(result_type, words);
  675. }
  676. return nullptr;
  677. };
  678. }
  679. // This defines a |UnaryScalarFoldingRule| that performs |OpQuantizeToF16|.
  680. UnaryScalarFoldingRule FoldQuantizeToF16Scalar() {
  681. return [](const analysis::Type* result_type, const analysis::Constant* a,
  682. analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
  683. assert(result_type != nullptr && a != nullptr);
  684. const analysis::Float* float_type = a->type()->AsFloat();
  685. assert(float_type != nullptr);
  686. if (float_type->width() != 32) {
  687. return nullptr;
  688. }
  689. float fa = a->GetFloat();
  690. utils::HexFloat<utils::FloatProxy<float>> orignal(fa);
  691. utils::HexFloat<utils::FloatProxy<utils::Float16>> quantized(0);
  692. utils::HexFloat<utils::FloatProxy<float>> result(0.0f);
  693. orignal.castTo(quantized, utils::round_direction::kToZero);
  694. quantized.castTo(result, utils::round_direction::kToZero);
  695. std::vector<uint32_t> words = {result.getBits()};
  696. return const_mgr->GetConstant(result_type, words);
  697. };
  698. }
  // This macro expands to a |BinaryScalarFoldingRule| lambda that evaluates
  // "f1 op f2" for two floating-point scalar constants whose type is exactly
  // |result_type|. |op| must be a binary operator valid for both float and
  // double operands. 32- and 64-bit floats are supported; other widths yield
  // nullptr. The "_in_macro" suffixes presumably avoid clashing with names at
  // expansion sites (e.g. the |result_type| and |const_mgr| parameters of
  // FoldScalarFPDivide).
  #define FOLD_FPARITH_OP(op)                                                  \
    [](const analysis::Type* result_type_in_macro,                             \
       const analysis::Constant* lhs_in_macro,                                 \
       const analysis::Constant* rhs_in_macro,                                 \
       analysis::ConstantManager* const_mgr_in_macro)                          \
        -> const analysis::Constant* {                                         \
      assert(result_type_in_macro != nullptr && lhs_in_macro != nullptr &&     \
             rhs_in_macro != nullptr);                                         \
      assert(result_type_in_macro == lhs_in_macro->type() &&                   \
             result_type_in_macro == rhs_in_macro->type());                    \
      const analysis::Float* float_type_in_macro =                             \
          result_type_in_macro->AsFloat();                                     \
      assert(float_type_in_macro != nullptr);                                  \
      std::vector<uint32_t> words_in_macro;                                    \
      switch (float_type_in_macro->width()) {                                  \
        case 32:                                                               \
          words_in_macro =                                                     \
              utils::FloatProxy<float>(                                        \
                  lhs_in_macro->GetFloat() op rhs_in_macro->GetFloat())        \
                  .GetWords();                                                 \
          break;                                                               \
        case 64:                                                               \
          words_in_macro =                                                     \
              utils::FloatProxy<double>(                                       \
                  lhs_in_macro->GetDouble() op rhs_in_macro->GetDouble())      \
                  .GetWords();                                                 \
          break;                                                               \
        default:                                                               \
          return nullptr;                                                      \
      }                                                                        \
      return const_mgr_in_macro->GetConstant(result_type_in_macro,             \
                                             words_in_macro);                  \
    }
  729. // Define the folding rule for conversion between floating point and integer
  730. ConstantFoldingRule FoldFToI() { return FoldFPUnaryOp(FoldFToIOp()); }
  731. ConstantFoldingRule FoldIToF() { return FoldFPUnaryOp(FoldIToFOp()); }
  732. ConstantFoldingRule FoldQuantizeToF16() {
  733. return FoldFPUnaryOp(FoldQuantizeToF16Scalar());
  734. }
  735. // Define the folding rules for subtraction, addition, multiplication, and
  736. // division for floating point values.
  737. ConstantFoldingRule FoldFSub() { return FoldFPBinaryOp(FOLD_FPARITH_OP(-)); }
  738. ConstantFoldingRule FoldFAdd() { return FoldFPBinaryOp(FOLD_FPARITH_OP(+)); }
  739. ConstantFoldingRule FoldFMul() { return FoldFPBinaryOp(FOLD_FPARITH_OP(*)); }
  740. // Returns the constant that results from evaluating |numerator| / 0.0. Returns
  741. // |nullptr| if the result could not be evaluated.
  742. const analysis::Constant* FoldFPScalarDivideByZero(
  743. const analysis::Type* result_type, const analysis::Constant* numerator,
  744. analysis::ConstantManager* const_mgr) {
  745. if (numerator == nullptr) {
  746. return nullptr;
  747. }
  748. if (numerator->IsZero()) {
  749. return GetNan(result_type, const_mgr);
  750. }
  751. const analysis::Constant* result = GetInf(result_type, const_mgr);
  752. if (result == nullptr) {
  753. return nullptr;
  754. }
  755. if (numerator->AsFloatConstant()->GetValueAsDouble() < 0.0) {
  756. result = NegateFPConst(result_type, result, const_mgr);
  757. }
  758. return result;
  759. }
  760. // Returns the result of folding |numerator| / |denominator|. Returns |nullptr|
  761. // if it cannot be folded.
  762. const analysis::Constant* FoldScalarFPDivide(
  763. const analysis::Type* result_type, const analysis::Constant* numerator,
  764. const analysis::Constant* denominator,
  765. analysis::ConstantManager* const_mgr) {
  766. if (denominator == nullptr) {
  767. return nullptr;
  768. }
  769. if (denominator->IsZero()) {
  770. return FoldFPScalarDivideByZero(result_type, numerator, const_mgr);
  771. }
  772. const analysis::FloatConstant* denominator_float =
  773. denominator->AsFloatConstant();
  774. if (denominator_float && denominator->GetValueAsDouble() == -0.0) {
  775. const analysis::Constant* result =
  776. FoldFPScalarDivideByZero(result_type, numerator, const_mgr);
  777. if (result != nullptr)
  778. result = NegateFPConst(result_type, result, const_mgr);
  779. return result;
  780. } else {
  781. return FOLD_FPARITH_OP(/)(result_type, numerator, denominator, const_mgr);
  782. }
  783. }
  784. // Returns the constant folding rule to fold |OpFDiv| with two constants.
  785. ConstantFoldingRule FoldFDiv() { return FoldFPBinaryOp(FoldScalarFPDivide); }
  // Combines the raw outcome of a floating-point comparison (|op_result|) with
  // whether either operand was NaN (|op_unordered|).
  // With |need_ordered| (the OpFOrd* forms), the result is true only if the
  // operands are ordered and the relation holds. Otherwise (the OpFUnord*
  // forms), it is true if the operands are unordered or the relation holds.
  bool CompareFloatingPoint(bool op_result, bool op_unordered,
                            bool need_ordered) {
    return need_ordered ? (op_result && !op_unordered)
                        : (op_result || op_unordered);
  }
  // This macro expands to a |BinaryScalarFoldingRule| lambda that compares two
  // floating-point scalar constants of the same type using the relational
  // operator |op| (as "f1 op f2"), producing a boolean constant of
  // |result_type|. |ord| selects ordered (true) or unordered (false) handling
  // of NaN operands via CompareFloatingPoint. 32- and 64-bit floats are
  // supported; other widths yield nullptr.
  #define FOLD_FPCMP_OP(op, ord)                                               \
    [](const analysis::Type* result_type, const analysis::Constant* a,         \
       const analysis::Constant* b,                                            \
       analysis::ConstantManager* const_mgr) -> const analysis::Constant* {    \
      assert(result_type != nullptr && a != nullptr && b != nullptr);          \
      assert(result_type->AsBool());                                           \
      assert(a->type() == b->type());                                          \
      const analysis::Float* operand_type = a->type()->AsFloat();              \
      assert(operand_type != nullptr);                                         \
      bool relation_holds = false;                                             \
      bool unordered = false;                                                  \
      if (operand_type->width() == 32) {                                       \
        const float lhs = a->GetFloat();                                       \
        const float rhs = b->GetFloat();                                       \
        relation_holds = lhs op rhs;                                           \
        unordered = std::isnan(lhs) || std::isnan(rhs);                        \
      } else if (operand_type->width() == 64) {                                \
        const double lhs = a->GetDouble();                                     \
        const double rhs = b->GetDouble();                                     \
        relation_holds = lhs op rhs;                                           \
        unordered = std::isnan(lhs) || std::isnan(rhs);                        \
      } else {                                                                 \
        return nullptr;                                                        \
      }                                                                        \
      const bool folded =                                                      \
          CompareFloatingPoint(relation_holds, unordered, ord);                \
      const std::vector<uint32_t> words = {uint32_t(folded)};                  \
      return const_mgr->GetConstant(result_type, words);                       \
    }
  824. // Define the folding rules for ordered and unordered comparison for floating
  825. // point values.
  826. ConstantFoldingRule FoldFOrdEqual() {
  827. return FoldFPBinaryOp(FOLD_FPCMP_OP(==, true));
  828. }
  829. ConstantFoldingRule FoldFUnordEqual() {
  830. return FoldFPBinaryOp(FOLD_FPCMP_OP(==, false));
  831. }
  832. ConstantFoldingRule FoldFOrdNotEqual() {
  833. return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, true));
  834. }
  835. ConstantFoldingRule FoldFUnordNotEqual() {
  836. return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, false));
  837. }
  838. ConstantFoldingRule FoldFOrdLessThan() {
  839. return FoldFPBinaryOp(FOLD_FPCMP_OP(<, true));
  840. }
  841. ConstantFoldingRule FoldFUnordLessThan() {
  842. return FoldFPBinaryOp(FOLD_FPCMP_OP(<, false));
  843. }
  844. ConstantFoldingRule FoldFOrdGreaterThan() {
  845. return FoldFPBinaryOp(FOLD_FPCMP_OP(>, true));
  846. }
  847. ConstantFoldingRule FoldFUnordGreaterThan() {
  848. return FoldFPBinaryOp(FOLD_FPCMP_OP(>, false));
  849. }
  850. ConstantFoldingRule FoldFOrdLessThanEqual() {
  851. return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, true));
  852. }
  853. ConstantFoldingRule FoldFUnordLessThanEqual() {
  854. return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, false));
  855. }
  856. ConstantFoldingRule FoldFOrdGreaterThanEqual() {
  857. return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, true));
  858. }
  859. ConstantFoldingRule FoldFUnordGreaterThanEqual() {
  860. return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, false));
  861. }
  862. // Folds an OpDot where all of the inputs are constants to a
  863. // constant. A new constant is created if necessary.
  864. ConstantFoldingRule FoldOpDotWithConstants() {
  865. return [](IRContext* context, Instruction* inst,
  866. const std::vector<const analysis::Constant*>& constants)
  867. -> const analysis::Constant* {
  868. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  869. analysis::TypeManager* type_mgr = context->get_type_mgr();
  870. const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
  871. assert(new_type->AsFloat() && "OpDot should have a float return type.");
  872. const analysis::Float* float_type = new_type->AsFloat();
  873. if (!inst->IsFloatingPointFoldingAllowed()) {
  874. return nullptr;
  875. }
  876. // If one of the operands is 0, then the result is 0.
  877. bool has_zero_operand = false;
  878. for (int i = 0; i < 2; ++i) {
  879. if (constants[i]) {
  880. if (constants[i]->AsNullConstant() ||
  881. constants[i]->AsVectorConstant()->IsZero()) {
  882. has_zero_operand = true;
  883. break;
  884. }
  885. }
  886. }
  887. if (has_zero_operand) {
  888. if (float_type->width() == 32) {
  889. utils::FloatProxy<float> result(0.0f);
  890. std::vector<uint32_t> words = result.GetWords();
  891. return const_mgr->GetConstant(float_type, words);
  892. }
  893. if (float_type->width() == 64) {
  894. utils::FloatProxy<double> result(0.0);
  895. std::vector<uint32_t> words = result.GetWords();
  896. return const_mgr->GetConstant(float_type, words);
  897. }
  898. return nullptr;
  899. }
  900. if (constants[0] == nullptr || constants[1] == nullptr) {
  901. return nullptr;
  902. }
  903. std::vector<const analysis::Constant*> a_components;
  904. std::vector<const analysis::Constant*> b_components;
  905. a_components = constants[0]->GetVectorComponents(const_mgr);
  906. b_components = constants[1]->GetVectorComponents(const_mgr);
  907. utils::FloatProxy<double> result(0.0);
  908. std::vector<uint32_t> words = result.GetWords();
  909. const analysis::Constant* result_const =
  910. const_mgr->GetConstant(float_type, words);
  911. for (uint32_t i = 0; i < a_components.size() && result_const != nullptr;
  912. ++i) {
  913. if (a_components[i] == nullptr || b_components[i] == nullptr) {
  914. return nullptr;
  915. }
  916. const analysis::Constant* component = FOLD_FPARITH_OP(*)(
  917. new_type, a_components[i], b_components[i], const_mgr);
  918. if (component == nullptr) {
  919. return nullptr;
  920. }
  921. result_const =
  922. FOLD_FPARITH_OP(+)(new_type, result_const, component, const_mgr);
  923. }
  924. return result_const;
  925. };
  926. }
  927. // This function defines a |UnaryScalarFoldingRule| that subtracts the constant
  928. // from zero.
  929. UnaryScalarFoldingRule FoldFNegateOp() {
  930. return [](const analysis::Type* result_type, const analysis::Constant* a,
  931. analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
  932. assert(result_type != nullptr && a != nullptr);
  933. assert(result_type == a->type());
  934. return NegateFPConst(result_type, a, const_mgr);
  935. };
  936. }
  937. ConstantFoldingRule FoldFNegate() { return FoldFPUnaryOp(FoldFNegateOp()); }
  938. ConstantFoldingRule FoldFClampFeedingCompare(spv::Op cmp_opcode) {
  939. return [cmp_opcode](IRContext* context, Instruction* inst,
  940. const std::vector<const analysis::Constant*>& constants)
  941. -> const analysis::Constant* {
  942. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  943. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  944. if (!inst->IsFloatingPointFoldingAllowed()) {
  945. return nullptr;
  946. }
  947. uint32_t non_const_idx = (constants[0] ? 1 : 0);
  948. uint32_t operand_id = inst->GetSingleWordInOperand(non_const_idx);
  949. Instruction* operand_inst = def_use_mgr->GetDef(operand_id);
  950. analysis::TypeManager* type_mgr = context->get_type_mgr();
  951. const analysis::Type* operand_type =
  952. type_mgr->GetType(operand_inst->type_id());
  953. if (!operand_type->AsFloat()) {
  954. return nullptr;
  955. }
  956. if (operand_type->AsFloat()->width() != 32 &&
  957. operand_type->AsFloat()->width() != 64) {
  958. return nullptr;
  959. }
  960. if (operand_inst->opcode() != spv::Op::OpExtInst) {
  961. return nullptr;
  962. }
  963. if (operand_inst->GetSingleWordInOperand(1) != GLSLstd450FClamp) {
  964. return nullptr;
  965. }
  966. if (constants[1] == nullptr && constants[0] == nullptr) {
  967. return nullptr;
  968. }
  969. uint32_t max_id = operand_inst->GetSingleWordInOperand(4);
  970. const analysis::Constant* max_const =
  971. const_mgr->FindDeclaredConstant(max_id);
  972. uint32_t min_id = operand_inst->GetSingleWordInOperand(3);
  973. const analysis::Constant* min_const =
  974. const_mgr->FindDeclaredConstant(min_id);
  975. bool found_result = false;
  976. bool result = false;
  977. switch (cmp_opcode) {
  978. case spv::Op::OpFOrdLessThan:
  979. case spv::Op::OpFUnordLessThan:
  980. case spv::Op::OpFOrdGreaterThanEqual:
  981. case spv::Op::OpFUnordGreaterThanEqual:
  982. if (constants[0]) {
  983. if (min_const) {
  984. if (constants[0]->GetValueAsDouble() <
  985. min_const->GetValueAsDouble()) {
  986. found_result = true;
  987. result = (cmp_opcode == spv::Op::OpFOrdLessThan ||
  988. cmp_opcode == spv::Op::OpFUnordLessThan);
  989. }
  990. }
  991. if (max_const) {
  992. if (constants[0]->GetValueAsDouble() >=
  993. max_const->GetValueAsDouble()) {
  994. found_result = true;
  995. result = !(cmp_opcode == spv::Op::OpFOrdLessThan ||
  996. cmp_opcode == spv::Op::OpFUnordLessThan);
  997. }
  998. }
  999. }
  1000. if (constants[1]) {
  1001. if (max_const) {
  1002. if (max_const->GetValueAsDouble() <
  1003. constants[1]->GetValueAsDouble()) {
  1004. found_result = true;
  1005. result = (cmp_opcode == spv::Op::OpFOrdLessThan ||
  1006. cmp_opcode == spv::Op::OpFUnordLessThan);
  1007. }
  1008. }
  1009. if (min_const) {
  1010. if (min_const->GetValueAsDouble() >=
  1011. constants[1]->GetValueAsDouble()) {
  1012. found_result = true;
  1013. result = !(cmp_opcode == spv::Op::OpFOrdLessThan ||
  1014. cmp_opcode == spv::Op::OpFUnordLessThan);
  1015. }
  1016. }
  1017. }
  1018. break;
  1019. case spv::Op::OpFOrdGreaterThan:
  1020. case spv::Op::OpFUnordGreaterThan:
  1021. case spv::Op::OpFOrdLessThanEqual:
  1022. case spv::Op::OpFUnordLessThanEqual:
  1023. if (constants[0]) {
  1024. if (min_const) {
  1025. if (constants[0]->GetValueAsDouble() <=
  1026. min_const->GetValueAsDouble()) {
  1027. found_result = true;
  1028. result = (cmp_opcode == spv::Op::OpFOrdLessThanEqual ||
  1029. cmp_opcode == spv::Op::OpFUnordLessThanEqual);
  1030. }
  1031. }
  1032. if (max_const) {
  1033. if (constants[0]->GetValueAsDouble() >
  1034. max_const->GetValueAsDouble()) {
  1035. found_result = true;
  1036. result = !(cmp_opcode == spv::Op::OpFOrdLessThanEqual ||
  1037. cmp_opcode == spv::Op::OpFUnordLessThanEqual);
  1038. }
  1039. }
  1040. }
  1041. if (constants[1]) {
  1042. if (max_const) {
  1043. if (max_const->GetValueAsDouble() <=
  1044. constants[1]->GetValueAsDouble()) {
  1045. found_result = true;
  1046. result = (cmp_opcode == spv::Op::OpFOrdLessThanEqual ||
  1047. cmp_opcode == spv::Op::OpFUnordLessThanEqual);
  1048. }
  1049. }
  1050. if (min_const) {
  1051. if (min_const->GetValueAsDouble() >
  1052. constants[1]->GetValueAsDouble()) {
  1053. found_result = true;
  1054. result = !(cmp_opcode == spv::Op::OpFOrdLessThanEqual ||
  1055. cmp_opcode == spv::Op::OpFUnordLessThanEqual);
  1056. }
  1057. }
  1058. }
  1059. break;
  1060. default:
  1061. return nullptr;
  1062. }
  1063. if (!found_result) {
  1064. return nullptr;
  1065. }
  1066. const analysis::Type* bool_type =
  1067. context->get_type_mgr()->GetType(inst->type_id());
  1068. const analysis::Constant* result_const =
  1069. const_mgr->GetConstant(bool_type, {static_cast<uint32_t>(result)});
  1070. assert(result_const);
  1071. return result_const;
  1072. };
  1073. }
  1074. ConstantFoldingRule FoldFMix() {
  1075. return [](IRContext* context, Instruction* inst,
  1076. const std::vector<const analysis::Constant*>& constants)
  1077. -> const analysis::Constant* {
  1078. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  1079. assert(inst->opcode() == spv::Op::OpExtInst &&
  1080. "Expecting an extended instruction.");
  1081. assert(inst->GetSingleWordInOperand(0) ==
  1082. context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
  1083. "Expecting a GLSLstd450 extended instruction.");
  1084. assert(inst->GetSingleWordInOperand(1) == GLSLstd450FMix &&
  1085. "Expecting and FMix instruction.");
  1086. if (!inst->IsFloatingPointFoldingAllowed()) {
  1087. return nullptr;
  1088. }
  1089. // Make sure all FMix operands are constants.
  1090. for (uint32_t i = 1; i < 4; i++) {
  1091. if (constants[i] == nullptr) {
  1092. return nullptr;
  1093. }
  1094. }
  1095. const analysis::Constant* one;
  1096. bool is_vector = false;
  1097. const analysis::Type* result_type = constants[1]->type();
  1098. const analysis::Type* base_type = result_type;
  1099. if (base_type->AsVector()) {
  1100. is_vector = true;
  1101. base_type = base_type->AsVector()->element_type();
  1102. }
  1103. assert(base_type->AsFloat() != nullptr &&
  1104. "FMix is suppose to act on floats or vectors of floats.");
  1105. if (base_type->AsFloat()->width() == 32) {
  1106. one = const_mgr->GetConstant(base_type,
  1107. utils::FloatProxy<float>(1.0f).GetWords());
  1108. } else {
  1109. one = const_mgr->GetConstant(base_type,
  1110. utils::FloatProxy<double>(1.0).GetWords());
  1111. }
  1112. if (is_vector) {
  1113. uint32_t one_id = const_mgr->GetDefiningInstruction(one)->result_id();
  1114. one =
  1115. const_mgr->GetConstant(result_type, std::vector<uint32_t>(4, one_id));
  1116. }
  1117. const analysis::Constant* temp1 = FoldFPBinaryOp(
  1118. FOLD_FPARITH_OP(-), inst->type_id(), {one, constants[3]}, context);
  1119. if (temp1 == nullptr) {
  1120. return nullptr;
  1121. }
  1122. const analysis::Constant* temp2 = FoldFPBinaryOp(
  1123. FOLD_FPARITH_OP(*), inst->type_id(), {constants[1], temp1}, context);
  1124. if (temp2 == nullptr) {
  1125. return nullptr;
  1126. }
  1127. const analysis::Constant* temp3 =
  1128. FoldFPBinaryOp(FOLD_FPARITH_OP(*), inst->type_id(),
  1129. {constants[2], constants[3]}, context);
  1130. if (temp3 == nullptr) {
  1131. return nullptr;
  1132. }
  1133. return FoldFPBinaryOp(FOLD_FPARITH_OP(+), inst->type_id(), {temp2, temp3},
  1134. context);
  1135. };
  1136. }
  1137. const analysis::Constant* FoldMin(const analysis::Type* result_type,
  1138. const analysis::Constant* a,
  1139. const analysis::Constant* b,
  1140. analysis::ConstantManager*) {
  1141. if (const analysis::Integer* int_type = result_type->AsInteger()) {
  1142. if (int_type->width() == 32) {
  1143. if (int_type->IsSigned()) {
  1144. int32_t va = a->GetS32();
  1145. int32_t vb = b->GetS32();
  1146. return (va < vb ? a : b);
  1147. } else {
  1148. uint32_t va = a->GetU32();
  1149. uint32_t vb = b->GetU32();
  1150. return (va < vb ? a : b);
  1151. }
  1152. } else if (int_type->width() == 64) {
  1153. if (int_type->IsSigned()) {
  1154. int64_t va = a->GetS64();
  1155. int64_t vb = b->GetS64();
  1156. return (va < vb ? a : b);
  1157. } else {
  1158. uint64_t va = a->GetU64();
  1159. uint64_t vb = b->GetU64();
  1160. return (va < vb ? a : b);
  1161. }
  1162. }
  1163. } else if (const analysis::Float* float_type = result_type->AsFloat()) {
  1164. if (float_type->width() == 32) {
  1165. float va = a->GetFloat();
  1166. float vb = b->GetFloat();
  1167. return (va < vb ? a : b);
  1168. } else if (float_type->width() == 64) {
  1169. double va = a->GetDouble();
  1170. double vb = b->GetDouble();
  1171. return (va < vb ? a : b);
  1172. }
  1173. }
  1174. return nullptr;
  1175. }
  1176. const analysis::Constant* FoldMax(const analysis::Type* result_type,
  1177. const analysis::Constant* a,
  1178. const analysis::Constant* b,
  1179. analysis::ConstantManager*) {
  1180. if (const analysis::Integer* int_type = result_type->AsInteger()) {
  1181. if (int_type->width() == 32) {
  1182. if (int_type->IsSigned()) {
  1183. int32_t va = a->GetS32();
  1184. int32_t vb = b->GetS32();
  1185. return (va > vb ? a : b);
  1186. } else {
  1187. uint32_t va = a->GetU32();
  1188. uint32_t vb = b->GetU32();
  1189. return (va > vb ? a : b);
  1190. }
  1191. } else if (int_type->width() == 64) {
  1192. if (int_type->IsSigned()) {
  1193. int64_t va = a->GetS64();
  1194. int64_t vb = b->GetS64();
  1195. return (va > vb ? a : b);
  1196. } else {
  1197. uint64_t va = a->GetU64();
  1198. uint64_t vb = b->GetU64();
  1199. return (va > vb ? a : b);
  1200. }
  1201. }
  1202. } else if (const analysis::Float* float_type = result_type->AsFloat()) {
  1203. if (float_type->width() == 32) {
  1204. float va = a->GetFloat();
  1205. float vb = b->GetFloat();
  1206. return (va > vb ? a : b);
  1207. } else if (float_type->width() == 64) {
  1208. double va = a->GetDouble();
  1209. double vb = b->GetDouble();
  1210. return (va > vb ? a : b);
  1211. }
  1212. }
  1213. return nullptr;
  1214. }
  1215. // Fold an clamp instruction when all three operands are constant.
  1216. const analysis::Constant* FoldClamp1(
  1217. IRContext* context, Instruction* inst,
  1218. const std::vector<const analysis::Constant*>& constants) {
  1219. assert(inst->opcode() == spv::Op::OpExtInst &&
  1220. "Expecting an extended instruction.");
  1221. assert(inst->GetSingleWordInOperand(0) ==
  1222. context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
  1223. "Expecting a GLSLstd450 extended instruction.");
  1224. // Make sure all Clamp operands are constants.
  1225. for (uint32_t i = 1; i < 4; i++) {
  1226. if (constants[i] == nullptr) {
  1227. return nullptr;
  1228. }
  1229. }
  1230. const analysis::Constant* temp = FoldFPBinaryOp(
  1231. FoldMax, inst->type_id(), {constants[1], constants[2]}, context);
  1232. if (temp == nullptr) {
  1233. return nullptr;
  1234. }
  1235. return FoldFPBinaryOp(FoldMin, inst->type_id(), {temp, constants[3]},
  1236. context);
  1237. }
  1238. // Fold a clamp instruction when |x <= min_val|.
  1239. const analysis::Constant* FoldClamp2(
  1240. IRContext* context, Instruction* inst,
  1241. const std::vector<const analysis::Constant*>& constants) {
  1242. assert(inst->opcode() == spv::Op::OpExtInst &&
  1243. "Expecting an extended instruction.");
  1244. assert(inst->GetSingleWordInOperand(0) ==
  1245. context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
  1246. "Expecting a GLSLstd450 extended instruction.");
  1247. const analysis::Constant* x = constants[1];
  1248. const analysis::Constant* min_val = constants[2];
  1249. if (x == nullptr || min_val == nullptr) {
  1250. return nullptr;
  1251. }
  1252. const analysis::Constant* temp =
  1253. FoldFPBinaryOp(FoldMax, inst->type_id(), {x, min_val}, context);
  1254. if (temp == min_val) {
  1255. // We can assume that |min_val| is less than |max_val|. Therefore, if the
  1256. // result of the max operation is |min_val|, we know the result of the min
  1257. // operation, even if |max_val| is not a constant.
  1258. return min_val;
  1259. }
  1260. return nullptr;
  1261. }
  1262. // Fold a clamp instruction when |x >= max_val|.
  1263. const analysis::Constant* FoldClamp3(
  1264. IRContext* context, Instruction* inst,
  1265. const std::vector<const analysis::Constant*>& constants) {
  1266. assert(inst->opcode() == spv::Op::OpExtInst &&
  1267. "Expecting an extended instruction.");
  1268. assert(inst->GetSingleWordInOperand(0) ==
  1269. context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
  1270. "Expecting a GLSLstd450 extended instruction.");
  1271. const analysis::Constant* x = constants[1];
  1272. const analysis::Constant* max_val = constants[3];
  1273. if (x == nullptr || max_val == nullptr) {
  1274. return nullptr;
  1275. }
  1276. const analysis::Constant* temp =
  1277. FoldFPBinaryOp(FoldMin, inst->type_id(), {x, max_val}, context);
  1278. if (temp == max_val) {
  1279. // We can assume that |min_val| is less than |max_val|. Therefore, if the
  1280. // result of the max operation is |min_val|, we know the result of the min
  1281. // operation, even if |max_val| is not a constant.
  1282. return max_val;
  1283. }
  1284. return nullptr;
  1285. }
  1286. UnaryScalarFoldingRule FoldFTranscendentalUnary(double (*fp)(double)) {
  1287. return
  1288. [fp](const analysis::Type* result_type, const analysis::Constant* a,
  1289. analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
  1290. assert(result_type != nullptr && a != nullptr);
  1291. const analysis::Float* float_type = a->type()->AsFloat();
  1292. assert(float_type != nullptr);
  1293. assert(float_type == result_type->AsFloat());
  1294. if (float_type->width() == 32) {
  1295. float fa = a->GetFloat();
  1296. float res = static_cast<float>(fp(fa));
  1297. utils::FloatProxy<float> result(res);
  1298. std::vector<uint32_t> words = result.GetWords();
  1299. return const_mgr->GetConstant(result_type, words);
  1300. } else if (float_type->width() == 64) {
  1301. double fa = a->GetDouble();
  1302. double res = fp(fa);
  1303. utils::FloatProxy<double> result(res);
  1304. std::vector<uint32_t> words = result.GetWords();
  1305. return const_mgr->GetConstant(result_type, words);
  1306. }
  1307. return nullptr;
  1308. };
  1309. }
  1310. BinaryScalarFoldingRule FoldFTranscendentalBinary(double (*fp)(double,
  1311. double)) {
  1312. return
  1313. [fp](const analysis::Type* result_type, const analysis::Constant* a,
  1314. const analysis::Constant* b,
  1315. analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
  1316. assert(result_type != nullptr && a != nullptr);
  1317. const analysis::Float* float_type = a->type()->AsFloat();
  1318. assert(float_type != nullptr);
  1319. assert(float_type == result_type->AsFloat());
  1320. assert(float_type == b->type()->AsFloat());
  1321. if (float_type->width() == 32) {
  1322. float fa = a->GetFloat();
  1323. float fb = b->GetFloat();
  1324. float res = static_cast<float>(fp(fa, fb));
  1325. utils::FloatProxy<float> result(res);
  1326. std::vector<uint32_t> words = result.GetWords();
  1327. return const_mgr->GetConstant(result_type, words);
  1328. } else if (float_type->width() == 64) {
  1329. double fa = a->GetDouble();
  1330. double fb = b->GetDouble();
  1331. double res = fp(fa, fb);
  1332. utils::FloatProxy<double> result(res);
  1333. std::vector<uint32_t> words = result.GetWords();
  1334. return const_mgr->GetConstant(result_type, words);
  1335. }
  1336. return nullptr;
  1337. };
  1338. }
  1339. } // namespace
  1340. void ConstantFoldingRules::AddFoldingRules() {
  1341. // Add all folding rules to the list for the opcodes to which they apply.
  1342. // Note that the order in which rules are added to the list matters. If a rule
  1343. // applies to the instruction, the rest of the rules will not be attempted.
  1344. // Take that into consideration.
  1345. rules_[spv::Op::OpCompositeConstruct].push_back(FoldCompositeWithConstants());
  1346. rules_[spv::Op::OpCompositeExtract].push_back(FoldExtractWithConstants());
  1347. rules_[spv::Op::OpCompositeInsert].push_back(FoldInsertWithConstants());
  1348. rules_[spv::Op::OpConvertFToS].push_back(FoldFToI());
  1349. rules_[spv::Op::OpConvertFToU].push_back(FoldFToI());
  1350. rules_[spv::Op::OpConvertSToF].push_back(FoldIToF());
  1351. rules_[spv::Op::OpConvertUToF].push_back(FoldIToF());
  1352. rules_[spv::Op::OpDot].push_back(FoldOpDotWithConstants());
  1353. rules_[spv::Op::OpFAdd].push_back(FoldFAdd());
  1354. rules_[spv::Op::OpFDiv].push_back(FoldFDiv());
  1355. rules_[spv::Op::OpFMul].push_back(FoldFMul());
  1356. rules_[spv::Op::OpFSub].push_back(FoldFSub());
  1357. rules_[spv::Op::OpFOrdEqual].push_back(FoldFOrdEqual());
  1358. rules_[spv::Op::OpFUnordEqual].push_back(FoldFUnordEqual());
  1359. rules_[spv::Op::OpFOrdNotEqual].push_back(FoldFOrdNotEqual());
  1360. rules_[spv::Op::OpFUnordNotEqual].push_back(FoldFUnordNotEqual());
  1361. rules_[spv::Op::OpFOrdLessThan].push_back(FoldFOrdLessThan());
  1362. rules_[spv::Op::OpFOrdLessThan].push_back(
  1363. FoldFClampFeedingCompare(spv::Op::OpFOrdLessThan));
  1364. rules_[spv::Op::OpFUnordLessThan].push_back(FoldFUnordLessThan());
  1365. rules_[spv::Op::OpFUnordLessThan].push_back(
  1366. FoldFClampFeedingCompare(spv::Op::OpFUnordLessThan));
  1367. rules_[spv::Op::OpFOrdGreaterThan].push_back(FoldFOrdGreaterThan());
  1368. rules_[spv::Op::OpFOrdGreaterThan].push_back(
  1369. FoldFClampFeedingCompare(spv::Op::OpFOrdGreaterThan));
  1370. rules_[spv::Op::OpFUnordGreaterThan].push_back(FoldFUnordGreaterThan());
  1371. rules_[spv::Op::OpFUnordGreaterThan].push_back(
  1372. FoldFClampFeedingCompare(spv::Op::OpFUnordGreaterThan));
  1373. rules_[spv::Op::OpFOrdLessThanEqual].push_back(FoldFOrdLessThanEqual());
  1374. rules_[spv::Op::OpFOrdLessThanEqual].push_back(
  1375. FoldFClampFeedingCompare(spv::Op::OpFOrdLessThanEqual));
  1376. rules_[spv::Op::OpFUnordLessThanEqual].push_back(FoldFUnordLessThanEqual());
  1377. rules_[spv::Op::OpFUnordLessThanEqual].push_back(
  1378. FoldFClampFeedingCompare(spv::Op::OpFUnordLessThanEqual));
  1379. rules_[spv::Op::OpFOrdGreaterThanEqual].push_back(FoldFOrdGreaterThanEqual());
  1380. rules_[spv::Op::OpFOrdGreaterThanEqual].push_back(
  1381. FoldFClampFeedingCompare(spv::Op::OpFOrdGreaterThanEqual));
  1382. rules_[spv::Op::OpFUnordGreaterThanEqual].push_back(
  1383. FoldFUnordGreaterThanEqual());
  1384. rules_[spv::Op::OpFUnordGreaterThanEqual].push_back(
  1385. FoldFClampFeedingCompare(spv::Op::OpFUnordGreaterThanEqual));
  1386. rules_[spv::Op::OpVectorShuffle].push_back(FoldVectorShuffleWithConstants());
  1387. rules_[spv::Op::OpVectorTimesScalar].push_back(FoldVectorTimesScalar());
  1388. rules_[spv::Op::OpVectorTimesMatrix].push_back(FoldVectorTimesMatrix());
  1389. rules_[spv::Op::OpMatrixTimesVector].push_back(FoldMatrixTimesVector());
  1390. rules_[spv::Op::OpFNegate].push_back(FoldFNegate());
  1391. rules_[spv::Op::OpQuantizeToF16].push_back(FoldQuantizeToF16());
  1392. // Add rules for GLSLstd450
  1393. FeatureManager* feature_manager = context_->get_feature_mgr();
  1394. uint32_t ext_inst_glslstd450_id =
  1395. feature_manager->GetExtInstImportId_GLSLstd450();
  1396. if (ext_inst_glslstd450_id != 0) {
  1397. ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(FoldFMix());
  1398. ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SMin}].push_back(
  1399. FoldFPBinaryOp(FoldMin));
  1400. ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UMin}].push_back(
  1401. FoldFPBinaryOp(FoldMin));
  1402. ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMin}].push_back(
  1403. FoldFPBinaryOp(FoldMin));
  1404. ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SMax}].push_back(
  1405. FoldFPBinaryOp(FoldMax));
  1406. ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UMax}].push_back(
  1407. FoldFPBinaryOp(FoldMax));
  1408. ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMax}].push_back(
  1409. FoldFPBinaryOp(FoldMax));
  1410. ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back(
  1411. FoldClamp1);
  1412. ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back(
  1413. FoldClamp2);
  1414. ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back(
  1415. FoldClamp3);
  1416. ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back(
  1417. FoldClamp1);
  1418. ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back(
  1419. FoldClamp2);
  1420. ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back(
  1421. FoldClamp3);
  1422. ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
  1423. FoldClamp1);
  1424. ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
  1425. FoldClamp2);
  1426. ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
  1427. FoldClamp3);
  1428. ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Sin}].push_back(
  1429. FoldFPUnaryOp(FoldFTranscendentalUnary(std::sin)));
  1430. ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Cos}].push_back(
  1431. FoldFPUnaryOp(FoldFTranscendentalUnary(std::cos)));
  1432. ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Tan}].push_back(
  1433. FoldFPUnaryOp(FoldFTranscendentalUnary(std::tan)));
  1434. ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Asin}].push_back(
  1435. FoldFPUnaryOp(FoldFTranscendentalUnary(std::asin)));
  1436. ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Acos}].push_back(
  1437. FoldFPUnaryOp(FoldFTranscendentalUnary(std::acos)));
  1438. ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Atan}].push_back(
  1439. FoldFPUnaryOp(FoldFTranscendentalUnary(std::atan)));
  1440. ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp}].push_back(
  1441. FoldFPUnaryOp(FoldFTranscendentalUnary(std::exp)));
  1442. ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log}].push_back(
  1443. FoldFPUnaryOp(FoldFTranscendentalUnary(std::log)));
  1444. #ifdef __ANDROID__
  1445. // Android NDK r15c targeting ABI 15 doesn't have full support for C++11
  1446. // (no std::exp2/log2). ::exp2 is available from C99 but ::log2 isn't
  1447. // available up until ABI 18 so we use a shim
  1448. auto log2_shim = [](double v) -> double { return log(v) / log(2.0); };
  1449. ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp2}].push_back(
  1450. FoldFPUnaryOp(FoldFTranscendentalUnary(::exp2)));
  1451. ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log2}].push_back(
  1452. FoldFPUnaryOp(FoldFTranscendentalUnary(log2_shim)));
  1453. #else
  1454. ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp2}].push_back(
  1455. FoldFPUnaryOp(FoldFTranscendentalUnary(std::exp2)));
  1456. ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log2}].push_back(
  1457. FoldFPUnaryOp(FoldFTranscendentalUnary(std::log2)));
  1458. #endif
  1459. ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Sqrt}].push_back(
  1460. FoldFPUnaryOp(FoldFTranscendentalUnary(std::sqrt)));
  1461. ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Atan2}].push_back(
  1462. FoldFPBinaryOp(FoldFTranscendentalBinary(std::atan2)));
  1463. ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Pow}].push_back(
  1464. FoldFPBinaryOp(FoldFTranscendentalBinary(std::pow)));
  1465. }
  1466. }
  1467. } // namespace opt
  1468. } // namespace spvtools