folding_rules.cpp 94 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533
  1. // Copyright (c) 2018 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opt/folding_rules.h"
  15. #include <limits>
  16. #include <memory>
  17. #include <utility>
  18. #include "ir_builder.h"
  19. #include "source/latest_version_glsl_std_450_header.h"
  20. #include "source/opt/ir_context.h"
  21. namespace spvtools {
  22. namespace opt {
  23. namespace {
  // Named in-operand positions. Each constant is named after the instruction
  // kind and the operand it refers to.

  // Composite extract / insert.
  constexpr uint32_t kExtractCompositeIdInIdx = 0;
  constexpr uint32_t kInsertObjectIdInIdx = 0;
  constexpr uint32_t kInsertCompositeIdInIdx = 1;

  // Extended instructions.
  constexpr uint32_t kExtInstSetIdInIdx = 0;
  constexpr uint32_t kExtInstInstructionInIdx = 1;

  // FMix extended instruction: x, y and a operands.
  constexpr uint32_t kFMixXIdInIdx = 2;
  constexpr uint32_t kFMixYIdInIdx = 3;
  constexpr uint32_t kFMixAIdInIdx = 4;

  // Store.
  constexpr uint32_t kStoreObjectInIdx = 1;
  33. // Some image instructions may contain an "image operands" argument.
  34. // Returns the operand index for the "image operands".
  35. // Returns -1 if the instruction does not have image operands.
  36. int32_t ImageOperandsMaskInOperandIndex(Instruction* inst) {
  37. const auto opcode = inst->opcode();
  38. switch (opcode) {
  39. case SpvOpImageSampleImplicitLod:
  40. case SpvOpImageSampleExplicitLod:
  41. case SpvOpImageSampleProjImplicitLod:
  42. case SpvOpImageSampleProjExplicitLod:
  43. case SpvOpImageFetch:
  44. case SpvOpImageRead:
  45. case SpvOpImageSparseSampleImplicitLod:
  46. case SpvOpImageSparseSampleExplicitLod:
  47. case SpvOpImageSparseSampleProjImplicitLod:
  48. case SpvOpImageSparseSampleProjExplicitLod:
  49. case SpvOpImageSparseFetch:
  50. case SpvOpImageSparseRead:
  51. return inst->NumOperands() > 4 ? 2 : -1;
  52. case SpvOpImageSampleDrefImplicitLod:
  53. case SpvOpImageSampleDrefExplicitLod:
  54. case SpvOpImageSampleProjDrefImplicitLod:
  55. case SpvOpImageSampleProjDrefExplicitLod:
  56. case SpvOpImageGather:
  57. case SpvOpImageDrefGather:
  58. case SpvOpImageSparseSampleDrefImplicitLod:
  59. case SpvOpImageSparseSampleDrefExplicitLod:
  60. case SpvOpImageSparseSampleProjDrefImplicitLod:
  61. case SpvOpImageSparseSampleProjDrefExplicitLod:
  62. case SpvOpImageSparseGather:
  63. case SpvOpImageSparseDrefGather:
  64. return inst->NumOperands() > 5 ? 3 : -1;
  65. case SpvOpImageWrite:
  66. return inst->NumOperands() > 3 ? 3 : -1;
  67. default:
  68. return -1;
  69. }
  70. }
  71. // Returns the element width of |type|.
  72. uint32_t ElementWidth(const analysis::Type* type) {
  73. if (const analysis::Vector* vec_type = type->AsVector()) {
  74. return ElementWidth(vec_type->element_type());
  75. } else if (const analysis::Float* float_type = type->AsFloat()) {
  76. return float_type->width();
  77. } else {
  78. assert(type->AsInteger());
  79. return type->AsInteger()->width();
  80. }
  81. }
  82. // Returns true if |type| is Float or a vector of Float.
  83. bool HasFloatingPoint(const analysis::Type* type) {
  84. if (type->AsFloat()) {
  85. return true;
  86. } else if (const analysis::Vector* vec_type = type->AsVector()) {
  87. return vec_type->element_type()->AsFloat() != nullptr;
  88. }
  89. return false;
  90. }
  // Returns true unless |val| classifies as NaN, infinite or subnormal.
  template <typename T>
  bool IsValidResult(T val) {
    const int category = std::fpclassify(val);
    const bool rejected = category == FP_NAN || category == FP_INFINITE ||
                          category == FP_SUBNORMAL;
    return !rejected;
  }
  104. const analysis::Constant* ConstInput(
  105. const std::vector<const analysis::Constant*>& constants) {
  106. return constants[0] ? constants[0] : constants[1];
  107. }
  108. Instruction* NonConstInput(IRContext* context, const analysis::Constant* c,
  109. Instruction* inst) {
  110. uint32_t in_op = c ? 1u : 0u;
  111. return context->get_def_use_mgr()->GetDef(
  112. inst->GetSingleWordInOperand(in_op));
  113. }
  114. // Returns the negation of |c|. |c| must be a 32 or 64 bit floating point
  115. // constant.
  116. uint32_t NegateFloatingPointConstant(analysis::ConstantManager* const_mgr,
  117. const analysis::Constant* c) {
  118. assert(c);
  119. assert(c->type()->AsFloat());
  120. uint32_t width = c->type()->AsFloat()->width();
  121. assert(width == 32 || width == 64);
  122. std::vector<uint32_t> words;
  123. if (width == 64) {
  124. utils::FloatProxy<double> result(c->GetDouble() * -1.0);
  125. words = result.GetWords();
  126. } else {
  127. utils::FloatProxy<float> result(c->GetFloat() * -1.0f);
  128. words = result.GetWords();
  129. }
  130. const analysis::Constant* negated_const =
  131. const_mgr->GetConstant(c->type(), std::move(words));
  132. return const_mgr->GetDefiningInstruction(negated_const)->result_id();
  133. }
  // Splits |val| into two 32-bit words: the low 32 bits first, then the high
  // 32 bits.
  std::vector<uint32_t> ExtractInts(uint64_t val) {
    const uint32_t low_word = static_cast<uint32_t>(val);
    const uint32_t high_word = static_cast<uint32_t>(val >> 32);
    return {low_word, high_word};
  }
  140. // Negates the integer constant |c|. Returns the id of the defining instruction.
  141. uint32_t NegateIntegerConstant(analysis::ConstantManager* const_mgr,
  142. const analysis::Constant* c) {
  143. assert(c);
  144. assert(c->type()->AsInteger());
  145. uint32_t width = c->type()->AsInteger()->width();
  146. assert(width == 32 || width == 64);
  147. std::vector<uint32_t> words;
  148. if (width == 64) {
  149. uint64_t uval = static_cast<uint64_t>(0 - c->GetU64());
  150. words = ExtractInts(uval);
  151. } else {
  152. words.push_back(static_cast<uint32_t>(0 - c->GetU32()));
  153. }
  154. const analysis::Constant* negated_const =
  155. const_mgr->GetConstant(c->type(), std::move(words));
  156. return const_mgr->GetDefiningInstruction(negated_const)->result_id();
  157. }
  158. // Negates the vector constant |c|. Returns the id of the defining instruction.
  159. uint32_t NegateVectorConstant(analysis::ConstantManager* const_mgr,
  160. const analysis::Constant* c) {
  161. assert(const_mgr && c);
  162. assert(c->type()->AsVector());
  163. if (c->AsNullConstant()) {
  164. // 0.0 vs -0.0 shouldn't matter.
  165. return const_mgr->GetDefiningInstruction(c)->result_id();
  166. } else {
  167. const analysis::Type* component_type =
  168. c->AsVectorConstant()->component_type();
  169. std::vector<uint32_t> words;
  170. for (auto& comp : c->AsVectorConstant()->GetComponents()) {
  171. if (component_type->AsFloat()) {
  172. words.push_back(NegateFloatingPointConstant(const_mgr, comp));
  173. } else {
  174. assert(component_type->AsInteger());
  175. words.push_back(NegateIntegerConstant(const_mgr, comp));
  176. }
  177. }
  178. const analysis::Constant* negated_const =
  179. const_mgr->GetConstant(c->type(), std::move(words));
  180. return const_mgr->GetDefiningInstruction(negated_const)->result_id();
  181. }
  182. }
  183. // Negates |c|. Returns the id of the defining instruction.
  184. uint32_t NegateConstant(analysis::ConstantManager* const_mgr,
  185. const analysis::Constant* c) {
  186. if (c->type()->AsVector()) {
  187. return NegateVectorConstant(const_mgr, c);
  188. } else if (c->type()->AsFloat()) {
  189. return NegateFloatingPointConstant(const_mgr, c);
  190. } else {
  191. assert(c->type()->AsInteger());
  192. return NegateIntegerConstant(const_mgr, c);
  193. }
  194. }
  195. // Takes the reciprocal of |c|. |c|'s type must be Float or a vector of Float.
  196. // Returns 0 if the reciprocal is NaN, infinite or subnormal.
  197. uint32_t Reciprocal(analysis::ConstantManager* const_mgr,
  198. const analysis::Constant* c) {
  199. assert(const_mgr && c);
  200. assert(c->type()->AsFloat());
  201. uint32_t width = c->type()->AsFloat()->width();
  202. assert(width == 32 || width == 64);
  203. std::vector<uint32_t> words;
  204. if (width == 64) {
  205. spvtools::utils::FloatProxy<double> result(1.0 / c->GetDouble());
  206. if (!IsValidResult(result.getAsFloat())) return 0;
  207. words = result.GetWords();
  208. } else {
  209. spvtools::utils::FloatProxy<float> result(1.0f / c->GetFloat());
  210. if (!IsValidResult(result.getAsFloat())) return 0;
  211. words = result.GetWords();
  212. }
  213. const analysis::Constant* negated_const =
  214. const_mgr->GetConstant(c->type(), std::move(words));
  215. return const_mgr->GetDefiningInstruction(negated_const)->result_id();
  216. }
  217. // Replaces fdiv where second operand is constant with fmul.
  218. FoldingRule ReciprocalFDiv() {
  219. return [](IRContext* context, Instruction* inst,
  220. const std::vector<const analysis::Constant*>& constants) {
  221. assert(inst->opcode() == SpvOpFDiv);
  222. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  223. const analysis::Type* type =
  224. context->get_type_mgr()->GetType(inst->type_id());
  225. if (!inst->IsFloatingPointFoldingAllowed()) return false;
  226. uint32_t width = ElementWidth(type);
  227. if (width != 32 && width != 64) return false;
  228. if (constants[1] != nullptr) {
  229. uint32_t id = 0;
  230. if (const analysis::VectorConstant* vector_const =
  231. constants[1]->AsVectorConstant()) {
  232. std::vector<uint32_t> neg_ids;
  233. for (auto& comp : vector_const->GetComponents()) {
  234. id = Reciprocal(const_mgr, comp);
  235. if (id == 0) return false;
  236. neg_ids.push_back(id);
  237. }
  238. const analysis::Constant* negated_const =
  239. const_mgr->GetConstant(constants[1]->type(), std::move(neg_ids));
  240. id = const_mgr->GetDefiningInstruction(negated_const)->result_id();
  241. } else if (constants[1]->AsFloatConstant()) {
  242. id = Reciprocal(const_mgr, constants[1]);
  243. if (id == 0) return false;
  244. } else {
  245. // Don't fold a null constant.
  246. return false;
  247. }
  248. inst->SetOpcode(SpvOpFMul);
  249. inst->SetInOperands(
  250. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0u)}},
  251. {SPV_OPERAND_TYPE_ID, {id}}});
  252. return true;
  253. }
  254. return false;
  255. };
  256. }
  257. // Elides consecutive negate instructions.
  258. FoldingRule MergeNegateArithmetic() {
  259. return [](IRContext* context, Instruction* inst,
  260. const std::vector<const analysis::Constant*>& constants) {
  261. assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
  262. (void)constants;
  263. const analysis::Type* type =
  264. context->get_type_mgr()->GetType(inst->type_id());
  265. if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
  266. return false;
  267. Instruction* op_inst =
  268. context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
  269. if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
  270. return false;
  271. if (op_inst->opcode() == inst->opcode()) {
  272. // Elide negates.
  273. inst->SetOpcode(SpvOpCopyObject);
  274. inst->SetInOperands(
  275. {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0u)}}});
  276. return true;
  277. }
  278. return false;
  279. };
  280. }
  281. // Merges negate into a mul or div operation if that operation contains a
  282. // constant operand.
  283. // Cases:
  284. // -(x * 2) = x * -2
  285. // -(2 * x) = x * -2
  286. // -(x / 2) = x / -2
  287. // -(2 / x) = -2 / x
  288. FoldingRule MergeNegateMulDivArithmetic() {
  289. return [](IRContext* context, Instruction* inst,
  290. const std::vector<const analysis::Constant*>& constants) {
  291. assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
  292. (void)constants;
  293. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  294. const analysis::Type* type =
  295. context->get_type_mgr()->GetType(inst->type_id());
  296. if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
  297. return false;
  298. Instruction* op_inst =
  299. context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
  300. if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
  301. return false;
  302. uint32_t width = ElementWidth(type);
  303. if (width != 32 && width != 64) return false;
  304. SpvOp opcode = op_inst->opcode();
  305. if (opcode == SpvOpFMul || opcode == SpvOpFDiv || opcode == SpvOpIMul ||
  306. opcode == SpvOpSDiv || opcode == SpvOpUDiv) {
  307. std::vector<const analysis::Constant*> op_constants =
  308. const_mgr->GetOperandConstants(op_inst);
  309. // Merge negate into mul or div if one operand is constant.
  310. if (op_constants[0] || op_constants[1]) {
  311. bool zero_is_variable = op_constants[0] == nullptr;
  312. const analysis::Constant* c = ConstInput(op_constants);
  313. uint32_t neg_id = NegateConstant(const_mgr, c);
  314. uint32_t non_const_id = zero_is_variable
  315. ? op_inst->GetSingleWordInOperand(0u)
  316. : op_inst->GetSingleWordInOperand(1u);
  317. // Change this instruction to a mul/div.
  318. inst->SetOpcode(op_inst->opcode());
  319. if (opcode == SpvOpFDiv || opcode == SpvOpUDiv || opcode == SpvOpSDiv) {
  320. uint32_t op0 = zero_is_variable ? non_const_id : neg_id;
  321. uint32_t op1 = zero_is_variable ? neg_id : non_const_id;
  322. inst->SetInOperands(
  323. {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
  324. } else {
  325. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
  326. {SPV_OPERAND_TYPE_ID, {neg_id}}});
  327. }
  328. return true;
  329. }
  330. }
  331. return false;
  332. };
  333. }
  334. // Merges negate into a add or sub operation if that operation contains a
  335. // constant operand.
  336. // Cases:
  337. // -(x + 2) = -2 - x
  338. // -(2 + x) = -2 - x
  339. // -(x - 2) = 2 - x
  340. // -(2 - x) = x - 2
  341. FoldingRule MergeNegateAddSubArithmetic() {
  342. return [](IRContext* context, Instruction* inst,
  343. const std::vector<const analysis::Constant*>& constants) {
  344. assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
  345. (void)constants;
  346. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  347. const analysis::Type* type =
  348. context->get_type_mgr()->GetType(inst->type_id());
  349. if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
  350. return false;
  351. Instruction* op_inst =
  352. context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
  353. if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
  354. return false;
  355. uint32_t width = ElementWidth(type);
  356. if (width != 32 && width != 64) return false;
  357. if (op_inst->opcode() == SpvOpFAdd || op_inst->opcode() == SpvOpFSub ||
  358. op_inst->opcode() == SpvOpIAdd || op_inst->opcode() == SpvOpISub) {
  359. std::vector<const analysis::Constant*> op_constants =
  360. const_mgr->GetOperandConstants(op_inst);
  361. if (op_constants[0] || op_constants[1]) {
  362. bool zero_is_variable = op_constants[0] == nullptr;
  363. bool is_add = (op_inst->opcode() == SpvOpFAdd) ||
  364. (op_inst->opcode() == SpvOpIAdd);
  365. bool swap_operands = !is_add || zero_is_variable;
  366. bool negate_const = is_add;
  367. const analysis::Constant* c = ConstInput(op_constants);
  368. uint32_t const_id = 0;
  369. if (negate_const) {
  370. const_id = NegateConstant(const_mgr, c);
  371. } else {
  372. const_id = zero_is_variable ? op_inst->GetSingleWordInOperand(1u)
  373. : op_inst->GetSingleWordInOperand(0u);
  374. }
  375. // Swap operands if necessary and make the instruction a subtraction.
  376. uint32_t op0 =
  377. zero_is_variable ? op_inst->GetSingleWordInOperand(0u) : const_id;
  378. uint32_t op1 =
  379. zero_is_variable ? const_id : op_inst->GetSingleWordInOperand(1u);
  380. if (swap_operands) std::swap(op0, op1);
  381. inst->SetOpcode(HasFloatingPoint(type) ? SpvOpFSub : SpvOpISub);
  382. inst->SetInOperands(
  383. {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
  384. return true;
  385. }
  386. }
  387. return false;
  388. };
  389. }
  390. // Returns true if |c| has a zero element.
  391. bool HasZero(const analysis::Constant* c) {
  392. if (c->AsNullConstant()) {
  393. return true;
  394. }
  395. if (const analysis::VectorConstant* vec_const = c->AsVectorConstant()) {
  396. for (auto& comp : vec_const->GetComponents())
  397. if (HasZero(comp)) return true;
  398. } else {
  399. assert(c->AsScalarConstant());
  400. return c->AsScalarConstant()->IsZero();
  401. }
  402. return false;
  403. }
  404. // Performs |input1| |opcode| |input2| and returns the merged constant result
  405. // id. Returns 0 if the result is not a valid value. The input types must be
  406. // Float.
  407. uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr,
  408. SpvOp opcode,
  409. const analysis::Constant* input1,
  410. const analysis::Constant* input2) {
  411. const analysis::Type* type = input1->type();
  412. assert(type->AsFloat());
  413. uint32_t width = type->AsFloat()->width();
  414. assert(width == 32 || width == 64);
  415. std::vector<uint32_t> words;
  416. #define FOLD_OP(op) \
  417. if (width == 64) { \
  418. utils::FloatProxy<double> val = \
  419. input1->GetDouble() op input2->GetDouble(); \
  420. double dval = val.getAsFloat(); \
  421. if (!IsValidResult(dval)) return 0; \
  422. words = val.GetWords(); \
  423. } else { \
  424. utils::FloatProxy<float> val = input1->GetFloat() op input2->GetFloat(); \
  425. float fval = val.getAsFloat(); \
  426. if (!IsValidResult(fval)) return 0; \
  427. words = val.GetWords(); \
  428. }
  429. switch (opcode) {
  430. case SpvOpFMul:
  431. FOLD_OP(*);
  432. break;
  433. case SpvOpFDiv:
  434. if (HasZero(input2)) return 0;
  435. FOLD_OP(/);
  436. break;
  437. case SpvOpFAdd:
  438. FOLD_OP(+);
  439. break;
  440. case SpvOpFSub:
  441. FOLD_OP(-);
  442. break;
  443. default:
  444. assert(false && "Unexpected operation");
  445. break;
  446. }
  447. #undef FOLD_OP
  448. const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
  449. return const_mgr->GetDefiningInstruction(merged_const)->result_id();
  450. }
  451. // Performs |input1| |opcode| |input2| and returns the merged constant result
  452. // id. Returns 0 if the result is not a valid value. The input types must be
  453. // Integers.
  454. uint32_t PerformIntegerOperation(analysis::ConstantManager* const_mgr,
  455. SpvOp opcode, const analysis::Constant* input1,
  456. const analysis::Constant* input2) {
  457. assert(input1->type()->AsInteger());
  458. const analysis::Integer* type = input1->type()->AsInteger();
  459. uint32_t width = type->AsInteger()->width();
  460. assert(width == 32 || width == 64);
  461. std::vector<uint32_t> words;
  462. #define FOLD_OP(op) \
  463. if (width == 64) { \
  464. if (type->IsSigned()) { \
  465. int64_t val = input1->GetS64() op input2->GetS64(); \
  466. words = ExtractInts(static_cast<uint64_t>(val)); \
  467. } else { \
  468. uint64_t val = input1->GetU64() op input2->GetU64(); \
  469. words = ExtractInts(val); \
  470. } \
  471. } else { \
  472. if (type->IsSigned()) { \
  473. int32_t val = input1->GetS32() op input2->GetS32(); \
  474. words.push_back(static_cast<uint32_t>(val)); \
  475. } else { \
  476. uint32_t val = input1->GetU32() op input2->GetU32(); \
  477. words.push_back(val); \
  478. } \
  479. }
  480. switch (opcode) {
  481. case SpvOpIMul:
  482. FOLD_OP(*);
  483. break;
  484. case SpvOpSDiv:
  485. case SpvOpUDiv:
  486. assert(false && "Should not merge integer division");
  487. break;
  488. case SpvOpIAdd:
  489. FOLD_OP(+);
  490. break;
  491. case SpvOpISub:
  492. FOLD_OP(-);
  493. break;
  494. default:
  495. assert(false && "Unexpected operation");
  496. break;
  497. }
  498. #undef FOLD_OP
  499. const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
  500. return const_mgr->GetDefiningInstruction(merged_const)->result_id();
  501. }
  502. // Performs |input1| |opcode| |input2| and returns the merged constant result
  503. // id. Returns 0 if the result is not a valid value. The input types must be
  504. // Integers, Floats or Vectors of such.
  505. uint32_t PerformOperation(analysis::ConstantManager* const_mgr, SpvOp opcode,
  506. const analysis::Constant* input1,
  507. const analysis::Constant* input2) {
  508. assert(input1 && input2);
  509. const analysis::Type* type = input1->type();
  510. std::vector<uint32_t> words;
  511. if (const analysis::Vector* vector_type = type->AsVector()) {
  512. const analysis::Type* ele_type = vector_type->element_type();
  513. for (uint32_t i = 0; i != vector_type->element_count(); ++i) {
  514. uint32_t id = 0;
  515. const analysis::Constant* input1_comp = nullptr;
  516. if (const analysis::VectorConstant* input1_vector =
  517. input1->AsVectorConstant()) {
  518. input1_comp = input1_vector->GetComponents()[i];
  519. } else {
  520. assert(input1->AsNullConstant());
  521. input1_comp = const_mgr->GetConstant(ele_type, {});
  522. }
  523. const analysis::Constant* input2_comp = nullptr;
  524. if (const analysis::VectorConstant* input2_vector =
  525. input2->AsVectorConstant()) {
  526. input2_comp = input2_vector->GetComponents()[i];
  527. } else {
  528. assert(input2->AsNullConstant());
  529. input2_comp = const_mgr->GetConstant(ele_type, {});
  530. }
  531. if (ele_type->AsFloat()) {
  532. id = PerformFloatingPointOperation(const_mgr, opcode, input1_comp,
  533. input2_comp);
  534. } else {
  535. assert(ele_type->AsInteger());
  536. id = PerformIntegerOperation(const_mgr, opcode, input1_comp,
  537. input2_comp);
  538. }
  539. if (id == 0) return 0;
  540. words.push_back(id);
  541. }
  542. const analysis::Constant* merged_const =
  543. const_mgr->GetConstant(type, words);
  544. return const_mgr->GetDefiningInstruction(merged_const)->result_id();
  545. } else if (type->AsFloat()) {
  546. return PerformFloatingPointOperation(const_mgr, opcode, input1, input2);
  547. } else {
  548. assert(type->AsInteger());
  549. return PerformIntegerOperation(const_mgr, opcode, input1, input2);
  550. }
  551. }
  552. // Merges consecutive multiplies where each contains one constant operand.
  553. // Cases:
  554. // 2 * (x * 2) = x * 4
  555. // 2 * (2 * x) = x * 4
  556. // (x * 2) * 2 = x * 4
  557. // (2 * x) * 2 = x * 4
  558. FoldingRule MergeMulMulArithmetic() {
  559. return [](IRContext* context, Instruction* inst,
  560. const std::vector<const analysis::Constant*>& constants) {
  561. assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
  562. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  563. const analysis::Type* type =
  564. context->get_type_mgr()->GetType(inst->type_id());
  565. if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
  566. return false;
  567. uint32_t width = ElementWidth(type);
  568. if (width != 32 && width != 64) return false;
  569. // Determine the constant input and the variable input in |inst|.
  570. const analysis::Constant* const_input1 = ConstInput(constants);
  571. if (!const_input1) return false;
  572. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  573. if (HasFloatingPoint(type) && !other_inst->IsFloatingPointFoldingAllowed())
  574. return false;
  575. if (other_inst->opcode() == inst->opcode()) {
  576. std::vector<const analysis::Constant*> other_constants =
  577. const_mgr->GetOperandConstants(other_inst);
  578. const analysis::Constant* const_input2 = ConstInput(other_constants);
  579. if (!const_input2) return false;
  580. bool other_first_is_variable = other_constants[0] == nullptr;
  581. uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
  582. const_input1, const_input2);
  583. if (merged_id == 0) return false;
  584. uint32_t non_const_id = other_first_is_variable
  585. ? other_inst->GetSingleWordInOperand(0u)
  586. : other_inst->GetSingleWordInOperand(1u);
  587. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
  588. {SPV_OPERAND_TYPE_ID, {merged_id}}});
  589. return true;
  590. }
  591. return false;
  592. };
  593. }
  594. // Merges divides into subsequent multiplies if each instruction contains one
  595. // constant operand. Does not support integer operations.
  596. // Cases:
  597. // 2 * (x / 2) = x * 1
  598. // 2 * (2 / x) = 4 / x
  599. // (x / 2) * 2 = x * 1
  600. // (2 / x) * 2 = 4 / x
  601. // (y / x) * x = y
  602. // x * (y / x) = y
  603. FoldingRule MergeMulDivArithmetic() {
  604. return [](IRContext* context, Instruction* inst,
  605. const std::vector<const analysis::Constant*>& constants) {
  606. assert(inst->opcode() == SpvOpFMul);
  607. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  608. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  609. const analysis::Type* type =
  610. context->get_type_mgr()->GetType(inst->type_id());
  611. if (!inst->IsFloatingPointFoldingAllowed()) return false;
  612. uint32_t width = ElementWidth(type);
  613. if (width != 32 && width != 64) return false;
  614. for (uint32_t i = 0; i < 2; i++) {
  615. uint32_t op_id = inst->GetSingleWordInOperand(i);
  616. Instruction* op_inst = def_use_mgr->GetDef(op_id);
  617. if (op_inst->opcode() == SpvOpFDiv) {
  618. if (op_inst->GetSingleWordInOperand(1) ==
  619. inst->GetSingleWordInOperand(1 - i)) {
  620. inst->SetOpcode(SpvOpCopyObject);
  621. inst->SetInOperands(
  622. {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0)}}});
  623. return true;
  624. }
  625. }
  626. }
  627. const analysis::Constant* const_input1 = ConstInput(constants);
  628. if (!const_input1) return false;
  629. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  630. if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
  631. if (other_inst->opcode() == SpvOpFDiv) {
  632. std::vector<const analysis::Constant*> other_constants =
  633. const_mgr->GetOperandConstants(other_inst);
  634. const analysis::Constant* const_input2 = ConstInput(other_constants);
  635. if (!const_input2 || HasZero(const_input2)) return false;
  636. bool other_first_is_variable = other_constants[0] == nullptr;
  637. // If the variable value is the second operand of the divide, multiply
  638. // the constants together. Otherwise divide the constants.
  639. uint32_t merged_id = PerformOperation(
  640. const_mgr,
  641. other_first_is_variable ? other_inst->opcode() : inst->opcode(),
  642. const_input1, const_input2);
  643. if (merged_id == 0) return false;
  644. uint32_t non_const_id = other_first_is_variable
  645. ? other_inst->GetSingleWordInOperand(0u)
  646. : other_inst->GetSingleWordInOperand(1u);
  647. // If the variable value is on the second operand of the div, then this
  648. // operation is a div. Otherwise it should be a multiply.
  649. inst->SetOpcode(other_first_is_variable ? inst->opcode()
  650. : other_inst->opcode());
  651. if (other_first_is_variable) {
  652. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
  653. {SPV_OPERAND_TYPE_ID, {merged_id}}});
  654. } else {
  655. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {merged_id}},
  656. {SPV_OPERAND_TYPE_ID, {non_const_id}}});
  657. }
  658. return true;
  659. }
  660. return false;
  661. };
  662. }
  663. // Merges multiply of constant and negation.
  664. // Cases:
  665. // (-x) * 2 = x * -2
  666. // 2 * (-x) = x * -2
  667. FoldingRule MergeMulNegateArithmetic() {
  668. return [](IRContext* context, Instruction* inst,
  669. const std::vector<const analysis::Constant*>& constants) {
  670. assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
  671. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  672. const analysis::Type* type =
  673. context->get_type_mgr()->GetType(inst->type_id());
  674. bool uses_float = HasFloatingPoint(type);
  675. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  676. uint32_t width = ElementWidth(type);
  677. if (width != 32 && width != 64) return false;
  678. const analysis::Constant* const_input1 = ConstInput(constants);
  679. if (!const_input1) return false;
  680. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  681. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  682. return false;
  683. if (other_inst->opcode() == SpvOpFNegate ||
  684. other_inst->opcode() == SpvOpSNegate) {
  685. uint32_t neg_id = NegateConstant(const_mgr, const_input1);
  686. inst->SetInOperands(
  687. {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
  688. {SPV_OPERAND_TYPE_ID, {neg_id}}});
  689. return true;
  690. }
  691. return false;
  692. };
  693. }
  694. // Merges consecutive divides if each instruction contains one constant operand.
  695. // Does not support integer division.
  696. // Cases:
  697. // 2 / (x / 2) = 4 / x
  698. // 4 / (2 / x) = 2 * x
  699. // (4 / x) / 2 = 2 / x
  700. // (x / 2) / 2 = x / 4
  701. FoldingRule MergeDivDivArithmetic() {
  702. return [](IRContext* context, Instruction* inst,
  703. const std::vector<const analysis::Constant*>& constants) {
  704. assert(inst->opcode() == SpvOpFDiv);
  705. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  706. const analysis::Type* type =
  707. context->get_type_mgr()->GetType(inst->type_id());
  708. if (!inst->IsFloatingPointFoldingAllowed()) return false;
  709. uint32_t width = ElementWidth(type);
  710. if (width != 32 && width != 64) return false;
  711. const analysis::Constant* const_input1 = ConstInput(constants);
  712. if (!const_input1 || HasZero(const_input1)) return false;
  713. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  714. if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
  715. bool first_is_variable = constants[0] == nullptr;
  716. if (other_inst->opcode() == inst->opcode()) {
  717. std::vector<const analysis::Constant*> other_constants =
  718. const_mgr->GetOperandConstants(other_inst);
  719. const analysis::Constant* const_input2 = ConstInput(other_constants);
  720. if (!const_input2 || HasZero(const_input2)) return false;
  721. bool other_first_is_variable = other_constants[0] == nullptr;
  722. SpvOp merge_op = inst->opcode();
  723. if (other_first_is_variable) {
  724. // Constants magnify.
  725. merge_op = SpvOpFMul;
  726. }
  727. // This is an x / (*) case. Swap the inputs. Doesn't harm multiply
  728. // because it is commutative.
  729. if (first_is_variable) std::swap(const_input1, const_input2);
  730. uint32_t merged_id =
  731. PerformOperation(const_mgr, merge_op, const_input1, const_input2);
  732. if (merged_id == 0) return false;
  733. uint32_t non_const_id = other_first_is_variable
  734. ? other_inst->GetSingleWordInOperand(0u)
  735. : other_inst->GetSingleWordInOperand(1u);
  736. SpvOp op = inst->opcode();
  737. if (!first_is_variable && !other_first_is_variable) {
  738. // Effectively div of 1/x, so change to multiply.
  739. op = SpvOpFMul;
  740. }
  741. uint32_t op1 = merged_id;
  742. uint32_t op2 = non_const_id;
  743. if (first_is_variable && other_first_is_variable) std::swap(op1, op2);
  744. inst->SetOpcode(op);
  745. inst->SetInOperands(
  746. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  747. return true;
  748. }
  749. return false;
  750. };
  751. }
  752. // Fold multiplies succeeded by divides where each instruction contains a
  753. // constant operand. Does not support integer divide.
  754. // Cases:
  755. // 4 / (x * 2) = 2 / x
  756. // 4 / (2 * x) = 2 / x
  757. // (x * 4) / 2 = x * 2
  758. // (4 * x) / 2 = x * 2
  759. // (x * y) / x = y
  760. // (y * x) / x = y
  761. FoldingRule MergeDivMulArithmetic() {
  762. return [](IRContext* context, Instruction* inst,
  763. const std::vector<const analysis::Constant*>& constants) {
  764. assert(inst->opcode() == SpvOpFDiv);
  765. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  766. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  767. const analysis::Type* type =
  768. context->get_type_mgr()->GetType(inst->type_id());
  769. if (!inst->IsFloatingPointFoldingAllowed()) return false;
  770. uint32_t width = ElementWidth(type);
  771. if (width != 32 && width != 64) return false;
  772. uint32_t op_id = inst->GetSingleWordInOperand(0);
  773. Instruction* op_inst = def_use_mgr->GetDef(op_id);
  774. if (op_inst->opcode() == SpvOpFMul) {
  775. for (uint32_t i = 0; i < 2; i++) {
  776. if (op_inst->GetSingleWordInOperand(i) ==
  777. inst->GetSingleWordInOperand(1)) {
  778. inst->SetOpcode(SpvOpCopyObject);
  779. inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
  780. {op_inst->GetSingleWordInOperand(1 - i)}}});
  781. return true;
  782. }
  783. }
  784. }
  785. const analysis::Constant* const_input1 = ConstInput(constants);
  786. if (!const_input1 || HasZero(const_input1)) return false;
  787. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  788. if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
  789. bool first_is_variable = constants[0] == nullptr;
  790. if (other_inst->opcode() == SpvOpFMul) {
  791. std::vector<const analysis::Constant*> other_constants =
  792. const_mgr->GetOperandConstants(other_inst);
  793. const analysis::Constant* const_input2 = ConstInput(other_constants);
  794. if (!const_input2) return false;
  795. bool other_first_is_variable = other_constants[0] == nullptr;
  796. // This is an x / (*) case. Swap the inputs.
  797. if (first_is_variable) std::swap(const_input1, const_input2);
  798. uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
  799. const_input1, const_input2);
  800. if (merged_id == 0) return false;
  801. uint32_t non_const_id = other_first_is_variable
  802. ? other_inst->GetSingleWordInOperand(0u)
  803. : other_inst->GetSingleWordInOperand(1u);
  804. uint32_t op1 = merged_id;
  805. uint32_t op2 = non_const_id;
  806. if (first_is_variable) std::swap(op1, op2);
  807. // Convert to multiply
  808. if (first_is_variable) inst->SetOpcode(other_inst->opcode());
  809. inst->SetInOperands(
  810. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  811. return true;
  812. }
  813. return false;
  814. };
  815. }
  816. // Fold divides of a constant and a negation.
  817. // Cases:
  818. // (-x) / 2 = x / -2
  819. // 2 / (-x) = 2 / -x
  820. FoldingRule MergeDivNegateArithmetic() {
  821. return [](IRContext* context, Instruction* inst,
  822. const std::vector<const analysis::Constant*>& constants) {
  823. assert(inst->opcode() == SpvOpFDiv || inst->opcode() == SpvOpSDiv ||
  824. inst->opcode() == SpvOpUDiv);
  825. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  826. const analysis::Type* type =
  827. context->get_type_mgr()->GetType(inst->type_id());
  828. bool uses_float = HasFloatingPoint(type);
  829. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  830. uint32_t width = ElementWidth(type);
  831. if (width != 32 && width != 64) return false;
  832. const analysis::Constant* const_input1 = ConstInput(constants);
  833. if (!const_input1) return false;
  834. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  835. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  836. return false;
  837. bool first_is_variable = constants[0] == nullptr;
  838. if (other_inst->opcode() == SpvOpFNegate ||
  839. other_inst->opcode() == SpvOpSNegate) {
  840. uint32_t neg_id = NegateConstant(const_mgr, const_input1);
  841. if (first_is_variable) {
  842. inst->SetInOperands(
  843. {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
  844. {SPV_OPERAND_TYPE_ID, {neg_id}}});
  845. } else {
  846. inst->SetInOperands(
  847. {{SPV_OPERAND_TYPE_ID, {neg_id}},
  848. {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
  849. }
  850. return true;
  851. }
  852. return false;
  853. };
  854. }
  855. // Folds addition of a constant and a negation.
  856. // Cases:
  857. // (-x) + 2 = 2 - x
  858. // 2 + (-x) = 2 - x
  859. FoldingRule MergeAddNegateArithmetic() {
  860. return [](IRContext* context, Instruction* inst,
  861. const std::vector<const analysis::Constant*>& constants) {
  862. assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
  863. const analysis::Type* type =
  864. context->get_type_mgr()->GetType(inst->type_id());
  865. bool uses_float = HasFloatingPoint(type);
  866. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  867. const analysis::Constant* const_input1 = ConstInput(constants);
  868. if (!const_input1) return false;
  869. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  870. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  871. return false;
  872. if (other_inst->opcode() == SpvOpSNegate ||
  873. other_inst->opcode() == SpvOpFNegate) {
  874. inst->SetOpcode(HasFloatingPoint(type) ? SpvOpFSub : SpvOpISub);
  875. uint32_t const_id = constants[0] ? inst->GetSingleWordInOperand(0u)
  876. : inst->GetSingleWordInOperand(1u);
  877. inst->SetInOperands(
  878. {{SPV_OPERAND_TYPE_ID, {const_id}},
  879. {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
  880. return true;
  881. }
  882. return false;
  883. };
  884. }
  885. // Folds subtraction of a constant and a negation.
  886. // Cases:
  887. // (-x) - 2 = -2 - x
  888. // 2 - (-x) = x + 2
  889. FoldingRule MergeSubNegateArithmetic() {
  890. return [](IRContext* context, Instruction* inst,
  891. const std::vector<const analysis::Constant*>& constants) {
  892. assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
  893. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  894. const analysis::Type* type =
  895. context->get_type_mgr()->GetType(inst->type_id());
  896. bool uses_float = HasFloatingPoint(type);
  897. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  898. uint32_t width = ElementWidth(type);
  899. if (width != 32 && width != 64) return false;
  900. const analysis::Constant* const_input1 = ConstInput(constants);
  901. if (!const_input1) return false;
  902. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  903. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  904. return false;
  905. if (other_inst->opcode() == SpvOpSNegate ||
  906. other_inst->opcode() == SpvOpFNegate) {
  907. uint32_t op1 = 0;
  908. uint32_t op2 = 0;
  909. SpvOp opcode = inst->opcode();
  910. if (constants[0] != nullptr) {
  911. op1 = other_inst->GetSingleWordInOperand(0u);
  912. op2 = inst->GetSingleWordInOperand(0u);
  913. opcode = HasFloatingPoint(type) ? SpvOpFAdd : SpvOpIAdd;
  914. } else {
  915. op1 = NegateConstant(const_mgr, const_input1);
  916. op2 = other_inst->GetSingleWordInOperand(0u);
  917. }
  918. inst->SetOpcode(opcode);
  919. inst->SetInOperands(
  920. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  921. return true;
  922. }
  923. return false;
  924. };
  925. }
  926. // Folds addition of an addition where each operation has a constant operand.
  927. // Cases:
  928. // (x + 2) + 2 = x + 4
  929. // (2 + x) + 2 = x + 4
  930. // 2 + (x + 2) = x + 4
  931. // 2 + (2 + x) = x + 4
  932. FoldingRule MergeAddAddArithmetic() {
  933. return [](IRContext* context, Instruction* inst,
  934. const std::vector<const analysis::Constant*>& constants) {
  935. assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
  936. const analysis::Type* type =
  937. context->get_type_mgr()->GetType(inst->type_id());
  938. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  939. bool uses_float = HasFloatingPoint(type);
  940. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  941. uint32_t width = ElementWidth(type);
  942. if (width != 32 && width != 64) return false;
  943. const analysis::Constant* const_input1 = ConstInput(constants);
  944. if (!const_input1) return false;
  945. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  946. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  947. return false;
  948. if (other_inst->opcode() == SpvOpFAdd ||
  949. other_inst->opcode() == SpvOpIAdd) {
  950. std::vector<const analysis::Constant*> other_constants =
  951. const_mgr->GetOperandConstants(other_inst);
  952. const analysis::Constant* const_input2 = ConstInput(other_constants);
  953. if (!const_input2) return false;
  954. Instruction* non_const_input =
  955. NonConstInput(context, other_constants[0], other_inst);
  956. uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
  957. const_input1, const_input2);
  958. if (merged_id == 0) return false;
  959. inst->SetInOperands(
  960. {{SPV_OPERAND_TYPE_ID, {non_const_input->result_id()}},
  961. {SPV_OPERAND_TYPE_ID, {merged_id}}});
  962. return true;
  963. }
  964. return false;
  965. };
  966. }
  967. // Folds addition of a subtraction where each operation has a constant operand.
  968. // Cases:
  969. // (x - 2) + 2 = x + 0
  970. // (2 - x) + 2 = 4 - x
  971. // 2 + (x - 2) = x + 0
  972. // 2 + (2 - x) = 4 - x
  973. FoldingRule MergeAddSubArithmetic() {
  974. return [](IRContext* context, Instruction* inst,
  975. const std::vector<const analysis::Constant*>& constants) {
  976. assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
  977. const analysis::Type* type =
  978. context->get_type_mgr()->GetType(inst->type_id());
  979. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  980. bool uses_float = HasFloatingPoint(type);
  981. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  982. uint32_t width = ElementWidth(type);
  983. if (width != 32 && width != 64) return false;
  984. const analysis::Constant* const_input1 = ConstInput(constants);
  985. if (!const_input1) return false;
  986. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  987. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  988. return false;
  989. if (other_inst->opcode() == SpvOpFSub ||
  990. other_inst->opcode() == SpvOpISub) {
  991. std::vector<const analysis::Constant*> other_constants =
  992. const_mgr->GetOperandConstants(other_inst);
  993. const analysis::Constant* const_input2 = ConstInput(other_constants);
  994. if (!const_input2) return false;
  995. bool first_is_variable = other_constants[0] == nullptr;
  996. SpvOp op = inst->opcode();
  997. uint32_t op1 = 0;
  998. uint32_t op2 = 0;
  999. if (first_is_variable) {
  1000. // Subtract constants. Non-constant operand is first.
  1001. op1 = other_inst->GetSingleWordInOperand(0u);
  1002. op2 = PerformOperation(const_mgr, other_inst->opcode(), const_input1,
  1003. const_input2);
  1004. } else {
  1005. // Add constants. Constant operand is first. Change the opcode.
  1006. op1 = PerformOperation(const_mgr, inst->opcode(), const_input1,
  1007. const_input2);
  1008. op2 = other_inst->GetSingleWordInOperand(1u);
  1009. op = other_inst->opcode();
  1010. }
  1011. if (op1 == 0 || op2 == 0) return false;
  1012. inst->SetOpcode(op);
  1013. inst->SetInOperands(
  1014. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  1015. return true;
  1016. }
  1017. return false;
  1018. };
  1019. }
  1020. // Folds subtraction of an addition where each operand has a constant operand.
  1021. // Cases:
  1022. // (x + 2) - 2 = x + 0
  1023. // (2 + x) - 2 = x + 0
  1024. // 2 - (x + 2) = 0 - x
  1025. // 2 - (2 + x) = 0 - x
  1026. FoldingRule MergeSubAddArithmetic() {
  1027. return [](IRContext* context, Instruction* inst,
  1028. const std::vector<const analysis::Constant*>& constants) {
  1029. assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
  1030. const analysis::Type* type =
  1031. context->get_type_mgr()->GetType(inst->type_id());
  1032. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  1033. bool uses_float = HasFloatingPoint(type);
  1034. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  1035. uint32_t width = ElementWidth(type);
  1036. if (width != 32 && width != 64) return false;
  1037. const analysis::Constant* const_input1 = ConstInput(constants);
  1038. if (!const_input1) return false;
  1039. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  1040. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  1041. return false;
  1042. if (other_inst->opcode() == SpvOpFAdd ||
  1043. other_inst->opcode() == SpvOpIAdd) {
  1044. std::vector<const analysis::Constant*> other_constants =
  1045. const_mgr->GetOperandConstants(other_inst);
  1046. const analysis::Constant* const_input2 = ConstInput(other_constants);
  1047. if (!const_input2) return false;
  1048. Instruction* non_const_input =
  1049. NonConstInput(context, other_constants[0], other_inst);
  1050. // If the first operand of the sub is not a constant, swap the constants
  1051. // so the subtraction has the correct operands.
  1052. if (constants[0] == nullptr) std::swap(const_input1, const_input2);
  1053. // Subtract the constants.
  1054. uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
  1055. const_input1, const_input2);
  1056. SpvOp op = inst->opcode();
  1057. uint32_t op1 = 0;
  1058. uint32_t op2 = 0;
  1059. if (constants[0] == nullptr) {
  1060. // Non-constant operand is first. Change the opcode.
  1061. op1 = non_const_input->result_id();
  1062. op2 = merged_id;
  1063. op = other_inst->opcode();
  1064. } else {
  1065. // Constant operand is first.
  1066. op1 = merged_id;
  1067. op2 = non_const_input->result_id();
  1068. }
  1069. if (op1 == 0 || op2 == 0) return false;
  1070. inst->SetOpcode(op);
  1071. inst->SetInOperands(
  1072. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  1073. return true;
  1074. }
  1075. return false;
  1076. };
  1077. }
  1078. // Folds subtraction of a subtraction where each operand has a constant operand.
  1079. // Cases:
  1080. // (x - 2) - 2 = x - 4
  1081. // (2 - x) - 2 = 0 - x
  1082. // 2 - (x - 2) = 4 - x
  1083. // 2 - (2 - x) = x + 0
  1084. FoldingRule MergeSubSubArithmetic() {
  1085. return [](IRContext* context, Instruction* inst,
  1086. const std::vector<const analysis::Constant*>& constants) {
  1087. assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
  1088. const analysis::Type* type =
  1089. context->get_type_mgr()->GetType(inst->type_id());
  1090. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  1091. bool uses_float = HasFloatingPoint(type);
  1092. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  1093. uint32_t width = ElementWidth(type);
  1094. if (width != 32 && width != 64) return false;
  1095. const analysis::Constant* const_input1 = ConstInput(constants);
  1096. if (!const_input1) return false;
  1097. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  1098. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  1099. return false;
  1100. if (other_inst->opcode() == SpvOpFSub ||
  1101. other_inst->opcode() == SpvOpISub) {
  1102. std::vector<const analysis::Constant*> other_constants =
  1103. const_mgr->GetOperandConstants(other_inst);
  1104. const analysis::Constant* const_input2 = ConstInput(other_constants);
  1105. if (!const_input2) return false;
  1106. Instruction* non_const_input =
  1107. NonConstInput(context, other_constants[0], other_inst);
  1108. // Merge the constants.
  1109. uint32_t merged_id = 0;
  1110. SpvOp merge_op = inst->opcode();
  1111. if (other_constants[0] == nullptr) {
  1112. merge_op = uses_float ? SpvOpFAdd : SpvOpIAdd;
  1113. } else if (constants[0] == nullptr) {
  1114. std::swap(const_input1, const_input2);
  1115. }
  1116. merged_id =
  1117. PerformOperation(const_mgr, merge_op, const_input1, const_input2);
  1118. if (merged_id == 0) return false;
  1119. SpvOp op = inst->opcode();
  1120. if (constants[0] != nullptr && other_constants[0] != nullptr) {
  1121. // Change the operation.
  1122. op = uses_float ? SpvOpFAdd : SpvOpIAdd;
  1123. }
  1124. uint32_t op1 = 0;
  1125. uint32_t op2 = 0;
  1126. if ((constants[0] == nullptr) ^ (other_constants[0] == nullptr)) {
  1127. op1 = merged_id;
  1128. op2 = non_const_input->result_id();
  1129. } else {
  1130. op1 = non_const_input->result_id();
  1131. op2 = merged_id;
  1132. }
  1133. inst->SetOpcode(op);
  1134. inst->SetInOperands(
  1135. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  1136. return true;
  1137. }
  1138. return false;
  1139. };
  1140. }
  1141. // Helper function for MergeGenericAddSubArithmetic. If |addend| and
  1142. // subtrahend of |sub| is the same, merge to copy of minuend of |sub|.
  1143. bool MergeGenericAddendSub(uint32_t addend, uint32_t sub, Instruction* inst) {
  1144. IRContext* context = inst->context();
  1145. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1146. Instruction* sub_inst = def_use_mgr->GetDef(sub);
  1147. if (sub_inst->opcode() != SpvOpFSub && sub_inst->opcode() != SpvOpISub)
  1148. return false;
  1149. if (sub_inst->opcode() == SpvOpFSub &&
  1150. !sub_inst->IsFloatingPointFoldingAllowed())
  1151. return false;
  1152. if (addend != sub_inst->GetSingleWordInOperand(1)) return false;
  1153. inst->SetOpcode(SpvOpCopyObject);
  1154. inst->SetInOperands(
  1155. {{SPV_OPERAND_TYPE_ID, {sub_inst->GetSingleWordInOperand(0)}}});
  1156. context->UpdateDefUse(inst);
  1157. return true;
  1158. }
  1159. // Folds addition of a subtraction where the subtrahend is equal to the
  1160. // other addend. Return a copy of the minuend. Accepts generic (const and
  1161. // non-const) operands.
  1162. // Cases:
  1163. // (a - b) + b = a
  1164. // b + (a - b) = a
  1165. FoldingRule MergeGenericAddSubArithmetic() {
  1166. return [](IRContext* context, Instruction* inst,
  1167. const std::vector<const analysis::Constant*>&) {
  1168. assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
  1169. const analysis::Type* type =
  1170. context->get_type_mgr()->GetType(inst->type_id());
  1171. bool uses_float = HasFloatingPoint(type);
  1172. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  1173. uint32_t width = ElementWidth(type);
  1174. if (width != 32 && width != 64) return false;
  1175. uint32_t add_op0 = inst->GetSingleWordInOperand(0);
  1176. uint32_t add_op1 = inst->GetSingleWordInOperand(1);
  1177. if (MergeGenericAddendSub(add_op0, add_op1, inst)) return true;
  1178. return MergeGenericAddendSub(add_op1, add_op0, inst);
  1179. };
  1180. }
  1181. // Helper function for FactorAddMuls. If |factor0_0| is the same as |factor1_0|,
  1182. // generate |factor0_0| * (|factor0_1| + |factor1_1|).
  1183. bool FactorAddMulsOpnds(uint32_t factor0_0, uint32_t factor0_1,
  1184. uint32_t factor1_0, uint32_t factor1_1,
  1185. Instruction* inst) {
  1186. IRContext* context = inst->context();
  1187. if (factor0_0 != factor1_0) return false;
  1188. InstructionBuilder ir_builder(
  1189. context, inst,
  1190. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  1191. Instruction* new_add_inst = ir_builder.AddBinaryOp(
  1192. inst->type_id(), inst->opcode(), factor0_1, factor1_1);
  1193. inst->SetOpcode(inst->opcode() == SpvOpFAdd ? SpvOpFMul : SpvOpIMul);
  1194. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {factor0_0}},
  1195. {SPV_OPERAND_TYPE_ID, {new_add_inst->result_id()}}});
  1196. context->UpdateDefUse(inst);
  1197. return true;
  1198. }
  1199. // Perform the following factoring identity, handling all operand order
  1200. // combinations: (a * b) + (a * c) = a * (b + c)
  1201. FoldingRule FactorAddMuls() {
  1202. return [](IRContext* context, Instruction* inst,
  1203. const std::vector<const analysis::Constant*>&) {
  1204. assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
  1205. const analysis::Type* type =
  1206. context->get_type_mgr()->GetType(inst->type_id());
  1207. bool uses_float = HasFloatingPoint(type);
  1208. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  1209. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1210. uint32_t add_op0 = inst->GetSingleWordInOperand(0);
  1211. Instruction* add_op0_inst = def_use_mgr->GetDef(add_op0);
  1212. if (add_op0_inst->opcode() != SpvOpFMul &&
  1213. add_op0_inst->opcode() != SpvOpIMul)
  1214. return false;
  1215. uint32_t add_op1 = inst->GetSingleWordInOperand(1);
  1216. Instruction* add_op1_inst = def_use_mgr->GetDef(add_op1);
  1217. if (add_op1_inst->opcode() != SpvOpFMul &&
  1218. add_op1_inst->opcode() != SpvOpIMul)
  1219. return false;
  1220. // Only perform this optimization if both of the muls only have one use.
  1221. // Otherwise this is a deoptimization in size and performance.
  1222. if (def_use_mgr->NumUses(add_op0_inst) > 1) return false;
  1223. if (def_use_mgr->NumUses(add_op1_inst) > 1) return false;
  1224. if (add_op0_inst->opcode() == SpvOpFMul &&
  1225. (!add_op0_inst->IsFloatingPointFoldingAllowed() ||
  1226. !add_op1_inst->IsFloatingPointFoldingAllowed()))
  1227. return false;
  1228. for (int i = 0; i < 2; i++) {
  1229. for (int j = 0; j < 2; j++) {
  1230. // Check if operand i in add_op0_inst matches operand j in add_op1_inst.
  1231. if (FactorAddMulsOpnds(add_op0_inst->GetSingleWordInOperand(i),
  1232. add_op0_inst->GetSingleWordInOperand(1 - i),
  1233. add_op1_inst->GetSingleWordInOperand(j),
  1234. add_op1_inst->GetSingleWordInOperand(1 - j),
  1235. inst))
  1236. return true;
  1237. }
  1238. }
  1239. return false;
  1240. };
  1241. }
  1242. FoldingRule IntMultipleBy1() {
  1243. return [](IRContext*, Instruction* inst,
  1244. const std::vector<const analysis::Constant*>& constants) {
  1245. assert(inst->opcode() == SpvOpIMul && "Wrong opcode. Should be OpIMul.");
  1246. for (uint32_t i = 0; i < 2; i++) {
  1247. if (constants[i] == nullptr) {
  1248. continue;
  1249. }
  1250. const analysis::IntConstant* int_constant = constants[i]->AsIntConstant();
  1251. if (int_constant) {
  1252. uint32_t width = ElementWidth(int_constant->type());
  1253. if (width != 32 && width != 64) return false;
  1254. bool is_one = (width == 32) ? int_constant->GetU32BitValue() == 1u
  1255. : int_constant->GetU64BitValue() == 1ull;
  1256. if (is_one) {
  1257. inst->SetOpcode(SpvOpCopyObject);
  1258. inst->SetInOperands(
  1259. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1 - i)}}});
  1260. return true;
  1261. }
  1262. }
  1263. }
  1264. return false;
  1265. };
  1266. }
  1267. FoldingRule CompositeConstructFeedingExtract() {
  1268. return [](IRContext* context, Instruction* inst,
  1269. const std::vector<const analysis::Constant*>&) {
  1270. // If the input to an OpCompositeExtract is an OpCompositeConstruct,
  1271. // then we can simply use the appropriate element in the construction.
  1272. assert(inst->opcode() == SpvOpCompositeExtract &&
  1273. "Wrong opcode. Should be OpCompositeExtract.");
  1274. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1275. analysis::TypeManager* type_mgr = context->get_type_mgr();
  1276. // If there are no index operands, then this rule cannot do anything.
  1277. if (inst->NumInOperands() <= 1) {
  1278. return false;
  1279. }
  1280. uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1281. Instruction* cinst = def_use_mgr->GetDef(cid);
  1282. if (cinst->opcode() != SpvOpCompositeConstruct) {
  1283. return false;
  1284. }
  1285. std::vector<Operand> operands;
  1286. analysis::Type* composite_type = type_mgr->GetType(cinst->type_id());
  1287. if (composite_type->AsVector() == nullptr) {
  1288. // Get the element being extracted from the OpCompositeConstruct
  1289. // Since it is not a vector, it is simple to extract the single element.
  1290. uint32_t element_index = inst->GetSingleWordInOperand(1);
  1291. uint32_t element_id = cinst->GetSingleWordInOperand(element_index);
  1292. operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
  1293. // Add the remaining indices for extraction.
  1294. for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
  1295. operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
  1296. {inst->GetSingleWordInOperand(i)}});
  1297. }
  1298. } else {
  1299. // With vectors we have to handle the case where it is concatenating
  1300. // vectors.
  1301. assert(inst->NumInOperands() == 2 &&
  1302. "Expecting a vector of scalar values.");
  1303. uint32_t element_index = inst->GetSingleWordInOperand(1);
  1304. for (uint32_t construct_index = 0;
  1305. construct_index < cinst->NumInOperands(); ++construct_index) {
  1306. uint32_t element_id = cinst->GetSingleWordInOperand(construct_index);
  1307. Instruction* element_def = def_use_mgr->GetDef(element_id);
  1308. analysis::Vector* element_type =
  1309. type_mgr->GetType(element_def->type_id())->AsVector();
  1310. if (element_type) {
  1311. uint32_t vector_size = element_type->element_count();
  1312. if (vector_size < element_index) {
  1313. // The element we want comes after this vector.
  1314. element_index -= vector_size;
  1315. } else {
  1316. // We want an element of this vector.
  1317. operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
  1318. operands.push_back(
  1319. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {element_index}});
  1320. break;
  1321. }
  1322. } else {
  1323. if (element_index == 0) {
  1324. // This is a scalar, and we this is the element we are extracting.
  1325. operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
  1326. break;
  1327. } else {
  1328. // Skip over this scalar value.
  1329. --element_index;
  1330. }
  1331. }
  1332. }
  1333. }
  1334. // If there were no extra indices, then we have the final object. No need
  1335. // to extract even more.
  1336. if (operands.size() == 1) {
  1337. inst->SetOpcode(SpvOpCopyObject);
  1338. }
  1339. inst->SetInOperands(std::move(operands));
  1340. return true;
  1341. };
  1342. }
  1343. FoldingRule CompositeExtractFeedingConstruct() {
  1344. // If the OpCompositeConstruct is simply putting back together elements that
  1345. // where extracted from the same souce, we can simlpy reuse the source.
  1346. //
  1347. // This is a common code pattern because of the way that scalar replacement
  1348. // works.
  1349. return [](IRContext* context, Instruction* inst,
  1350. const std::vector<const analysis::Constant*>&) {
  1351. assert(inst->opcode() == SpvOpCompositeConstruct &&
  1352. "Wrong opcode. Should be OpCompositeConstruct.");
  1353. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1354. uint32_t original_id = 0;
  1355. // Check each element to make sure they are:
  1356. // - extractions
  1357. // - extracting the same position they are inserting
  1358. // - all extract from the same id.
  1359. for (uint32_t i = 0; i < inst->NumInOperands(); ++i) {
  1360. uint32_t element_id = inst->GetSingleWordInOperand(i);
  1361. Instruction* element_inst = def_use_mgr->GetDef(element_id);
  1362. if (element_inst->opcode() != SpvOpCompositeExtract) {
  1363. return false;
  1364. }
  1365. if (element_inst->NumInOperands() != 2) {
  1366. return false;
  1367. }
  1368. if (element_inst->GetSingleWordInOperand(1) != i) {
  1369. return false;
  1370. }
  1371. if (i == 0) {
  1372. original_id =
  1373. element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1374. } else if (original_id != element_inst->GetSingleWordInOperand(
  1375. kExtractCompositeIdInIdx)) {
  1376. return false;
  1377. }
  1378. }
  1379. // The last check it to see that the object being extracted from is the
  1380. // correct type.
  1381. Instruction* original_inst = def_use_mgr->GetDef(original_id);
  1382. if (original_inst->type_id() != inst->type_id()) {
  1383. return false;
  1384. }
  1385. // Simplify by using the original object.
  1386. inst->SetOpcode(SpvOpCopyObject);
  1387. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}});
  1388. return true;
  1389. };
  1390. }
  1391. FoldingRule InsertFeedingExtract() {
  1392. return [](IRContext* context, Instruction* inst,
  1393. const std::vector<const analysis::Constant*>&) {
  1394. assert(inst->opcode() == SpvOpCompositeExtract &&
  1395. "Wrong opcode. Should be OpCompositeExtract.");
  1396. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1397. uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1398. Instruction* cinst = def_use_mgr->GetDef(cid);
  1399. if (cinst->opcode() != SpvOpCompositeInsert) {
  1400. return false;
  1401. }
  1402. // Find the first position where the list of insert and extract indicies
  1403. // differ, if at all.
  1404. uint32_t i;
  1405. for (i = 1; i < inst->NumInOperands(); ++i) {
  1406. if (i + 1 >= cinst->NumInOperands()) {
  1407. break;
  1408. }
  1409. if (inst->GetSingleWordInOperand(i) !=
  1410. cinst->GetSingleWordInOperand(i + 1)) {
  1411. break;
  1412. }
  1413. }
  1414. // We are extracting the element that was inserted.
  1415. if (i == inst->NumInOperands() && i + 1 == cinst->NumInOperands()) {
  1416. inst->SetOpcode(SpvOpCopyObject);
  1417. inst->SetInOperands(
  1418. {{SPV_OPERAND_TYPE_ID,
  1419. {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}}});
  1420. return true;
  1421. }
  1422. // Extracting the value that was inserted along with values for the base
  1423. // composite. Cannot do anything.
  1424. if (i == inst->NumInOperands()) {
  1425. return false;
  1426. }
  1427. // Extracting an element of the value that was inserted. Extract from
  1428. // that value directly.
  1429. if (i + 1 == cinst->NumInOperands()) {
  1430. std::vector<Operand> operands;
  1431. operands.push_back(
  1432. {SPV_OPERAND_TYPE_ID,
  1433. {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}});
  1434. for (; i < inst->NumInOperands(); ++i) {
  1435. operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
  1436. {inst->GetSingleWordInOperand(i)}});
  1437. }
  1438. inst->SetInOperands(std::move(operands));
  1439. return true;
  1440. }
  1441. // Extracting a value that is disjoint from the element being inserted.
  1442. // Rewrite the extract to use the composite input to the insert.
  1443. std::vector<Operand> operands;
  1444. operands.push_back(
  1445. {SPV_OPERAND_TYPE_ID,
  1446. {cinst->GetSingleWordInOperand(kInsertCompositeIdInIdx)}});
  1447. for (i = 1; i < inst->NumInOperands(); ++i) {
  1448. operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
  1449. {inst->GetSingleWordInOperand(i)}});
  1450. }
  1451. inst->SetInOperands(std::move(operands));
  1452. return true;
  1453. };
  1454. }
  1455. // When a VectorShuffle is feeding an Extract, we can extract from one of the
  1456. // operands of the VectorShuffle. We just need to adjust the index in the
  1457. // extract instruction.
  1458. FoldingRule VectorShuffleFeedingExtract() {
  1459. return [](IRContext* context, Instruction* inst,
  1460. const std::vector<const analysis::Constant*>&) {
  1461. assert(inst->opcode() == SpvOpCompositeExtract &&
  1462. "Wrong opcode. Should be OpCompositeExtract.");
  1463. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1464. analysis::TypeManager* type_mgr = context->get_type_mgr();
  1465. uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1466. Instruction* cinst = def_use_mgr->GetDef(cid);
  1467. if (cinst->opcode() != SpvOpVectorShuffle) {
  1468. return false;
  1469. }
  1470. // Find the size of the first vector operand of the VectorShuffle
  1471. Instruction* first_input =
  1472. def_use_mgr->GetDef(cinst->GetSingleWordInOperand(0));
  1473. analysis::Type* first_input_type =
  1474. type_mgr->GetType(first_input->type_id());
  1475. assert(first_input_type->AsVector() &&
  1476. "Input to vector shuffle should be vectors.");
  1477. uint32_t first_input_size = first_input_type->AsVector()->element_count();
  1478. // Get index of the element the vector shuffle is placing in the position
  1479. // being extracted.
  1480. uint32_t new_index =
  1481. cinst->GetSingleWordInOperand(2 + inst->GetSingleWordInOperand(1));
  1482. // Extracting an undefined value so fold this extract into an undef.
  1483. const uint32_t undef_literal_value = 0xffffffff;
  1484. if (new_index == undef_literal_value) {
  1485. inst->SetOpcode(SpvOpUndef);
  1486. inst->SetInOperands({});
  1487. return true;
  1488. }
  1489. // Get the id of the of the vector the elemtent comes from, and update the
  1490. // index if needed.
  1491. uint32_t new_vector = 0;
  1492. if (new_index < first_input_size) {
  1493. new_vector = cinst->GetSingleWordInOperand(0);
  1494. } else {
  1495. new_vector = cinst->GetSingleWordInOperand(1);
  1496. new_index -= first_input_size;
  1497. }
  1498. // Update the extract instruction.
  1499. inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
  1500. inst->SetInOperand(1, {new_index});
  1501. return true;
  1502. };
  1503. }
  1504. // When an FMix with is feeding an Extract that extracts an element whose
  1505. // corresponding |a| in the FMix is 0 or 1, we can extract from one of the
  1506. // operands of the FMix.
  1507. FoldingRule FMixFeedingExtract() {
  1508. return [](IRContext* context, Instruction* inst,
  1509. const std::vector<const analysis::Constant*>&) {
  1510. assert(inst->opcode() == SpvOpCompositeExtract &&
  1511. "Wrong opcode. Should be OpCompositeExtract.");
  1512. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1513. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  1514. uint32_t composite_id =
  1515. inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1516. Instruction* composite_inst = def_use_mgr->GetDef(composite_id);
  1517. if (composite_inst->opcode() != SpvOpExtInst) {
  1518. return false;
  1519. }
  1520. uint32_t inst_set_id =
  1521. context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  1522. if (composite_inst->GetSingleWordInOperand(kExtInstSetIdInIdx) !=
  1523. inst_set_id ||
  1524. composite_inst->GetSingleWordInOperand(kExtInstInstructionInIdx) !=
  1525. GLSLstd450FMix) {
  1526. return false;
  1527. }
  1528. // Get the |a| for the FMix instruction.
  1529. uint32_t a_id = composite_inst->GetSingleWordInOperand(kFMixAIdInIdx);
  1530. std::unique_ptr<Instruction> a(inst->Clone(context));
  1531. a->SetInOperand(kExtractCompositeIdInIdx, {a_id});
  1532. context->get_instruction_folder().FoldInstruction(a.get());
  1533. if (a->opcode() != SpvOpCopyObject) {
  1534. return false;
  1535. }
  1536. const analysis::Constant* a_const =
  1537. const_mgr->FindDeclaredConstant(a->GetSingleWordInOperand(0));
  1538. if (!a_const) {
  1539. return false;
  1540. }
  1541. bool use_x = false;
  1542. assert(a_const->type()->AsFloat());
  1543. double element_value = a_const->GetValueAsDouble();
  1544. if (element_value == 0.0) {
  1545. use_x = true;
  1546. } else if (element_value == 1.0) {
  1547. use_x = false;
  1548. } else {
  1549. return false;
  1550. }
  1551. // Get the id of the of the vector the element comes from.
  1552. uint32_t new_vector = 0;
  1553. if (use_x) {
  1554. new_vector = composite_inst->GetSingleWordInOperand(kFMixXIdInIdx);
  1555. } else {
  1556. new_vector = composite_inst->GetSingleWordInOperand(kFMixYIdInIdx);
  1557. }
  1558. // Update the extract instruction.
  1559. inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
  1560. return true;
  1561. };
  1562. }
  1563. FoldingRule RedundantPhi() {
  1564. // An OpPhi instruction where all values are the same or the result of the phi
  1565. // itself, can be replaced by the value itself.
  1566. return [](IRContext*, Instruction* inst,
  1567. const std::vector<const analysis::Constant*>&) {
  1568. assert(inst->opcode() == SpvOpPhi && "Wrong opcode. Should be OpPhi.");
  1569. uint32_t incoming_value = 0;
  1570. for (uint32_t i = 0; i < inst->NumInOperands(); i += 2) {
  1571. uint32_t op_id = inst->GetSingleWordInOperand(i);
  1572. if (op_id == inst->result_id()) {
  1573. continue;
  1574. }
  1575. if (incoming_value == 0) {
  1576. incoming_value = op_id;
  1577. } else if (op_id != incoming_value) {
  1578. // Found two possible value. Can't simplify.
  1579. return false;
  1580. }
  1581. }
  1582. if (incoming_value == 0) {
  1583. // Code looks invalid. Don't do anything.
  1584. return false;
  1585. }
  1586. // We have a single incoming value. Simplify using that value.
  1587. inst->SetOpcode(SpvOpCopyObject);
  1588. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {incoming_value}}});
  1589. return true;
  1590. };
  1591. }
  1592. FoldingRule RedundantSelect() {
  1593. // An OpSelect instruction where both values are the same or the condition is
  1594. // constant can be replaced by one of the values
  1595. return [](IRContext*, Instruction* inst,
  1596. const std::vector<const analysis::Constant*>& constants) {
  1597. assert(inst->opcode() == SpvOpSelect &&
  1598. "Wrong opcode. Should be OpSelect.");
  1599. assert(inst->NumInOperands() == 3);
  1600. assert(constants.size() == 3);
  1601. uint32_t true_id = inst->GetSingleWordInOperand(1);
  1602. uint32_t false_id = inst->GetSingleWordInOperand(2);
  1603. if (true_id == false_id) {
  1604. // Both results are the same, condition doesn't matter
  1605. inst->SetOpcode(SpvOpCopyObject);
  1606. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
  1607. return true;
  1608. } else if (constants[0]) {
  1609. const analysis::Type* type = constants[0]->type();
  1610. if (type->AsBool()) {
  1611. // Scalar constant value, select the corresponding value.
  1612. inst->SetOpcode(SpvOpCopyObject);
  1613. if (constants[0]->AsNullConstant() ||
  1614. !constants[0]->AsBoolConstant()->value()) {
  1615. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
  1616. } else {
  1617. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
  1618. }
  1619. return true;
  1620. } else {
  1621. assert(type->AsVector());
  1622. if (constants[0]->AsNullConstant()) {
  1623. // All values come from false id.
  1624. inst->SetOpcode(SpvOpCopyObject);
  1625. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
  1626. return true;
  1627. } else {
  1628. // Convert to a vector shuffle.
  1629. std::vector<Operand> ops;
  1630. ops.push_back({SPV_OPERAND_TYPE_ID, {true_id}});
  1631. ops.push_back({SPV_OPERAND_TYPE_ID, {false_id}});
  1632. const analysis::VectorConstant* vector_const =
  1633. constants[0]->AsVectorConstant();
  1634. uint32_t size =
  1635. static_cast<uint32_t>(vector_const->GetComponents().size());
  1636. for (uint32_t i = 0; i != size; ++i) {
  1637. const analysis::Constant* component =
  1638. vector_const->GetComponents()[i];
  1639. if (component->AsNullConstant() ||
  1640. !component->AsBoolConstant()->value()) {
  1641. // Selecting from the false vector which is the second input
  1642. // vector to the shuffle. Offset the index by |size|.
  1643. ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i + size}});
  1644. } else {
  1645. // Selecting from true vector which is the first input vector to
  1646. // the shuffle.
  1647. ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}});
  1648. }
  1649. }
  1650. inst->SetOpcode(SpvOpVectorShuffle);
  1651. inst->SetInOperands(std::move(ops));
  1652. return true;
  1653. }
  1654. }
  1655. }
  1656. return false;
  1657. };
  1658. }
  1659. enum class FloatConstantKind { Unknown, Zero, One };
  1660. FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) {
  1661. if (constant == nullptr) {
  1662. return FloatConstantKind::Unknown;
  1663. }
  1664. assert(HasFloatingPoint(constant->type()) && "Unexpected constant type");
  1665. if (constant->AsNullConstant()) {
  1666. return FloatConstantKind::Zero;
  1667. } else if (const analysis::VectorConstant* vc =
  1668. constant->AsVectorConstant()) {
  1669. const std::vector<const analysis::Constant*>& components =
  1670. vc->GetComponents();
  1671. assert(!components.empty());
  1672. FloatConstantKind kind = getFloatConstantKind(components[0]);
  1673. for (size_t i = 1; i < components.size(); ++i) {
  1674. if (getFloatConstantKind(components[i]) != kind) {
  1675. return FloatConstantKind::Unknown;
  1676. }
  1677. }
  1678. return kind;
  1679. } else if (const analysis::FloatConstant* fc = constant->AsFloatConstant()) {
  1680. if (fc->IsZero()) return FloatConstantKind::Zero;
  1681. uint32_t width = fc->type()->AsFloat()->width();
  1682. if (width != 32 && width != 64) return FloatConstantKind::Unknown;
  1683. double value = (width == 64) ? fc->GetDoubleValue() : fc->GetFloatValue();
  1684. if (value == 0.0) {
  1685. return FloatConstantKind::Zero;
  1686. } else if (value == 1.0) {
  1687. return FloatConstantKind::One;
  1688. } else {
  1689. return FloatConstantKind::Unknown;
  1690. }
  1691. } else {
  1692. return FloatConstantKind::Unknown;
  1693. }
  1694. }
  1695. FoldingRule RedundantFAdd() {
  1696. return [](IRContext*, Instruction* inst,
  1697. const std::vector<const analysis::Constant*>& constants) {
  1698. assert(inst->opcode() == SpvOpFAdd && "Wrong opcode. Should be OpFAdd.");
  1699. assert(constants.size() == 2);
  1700. if (!inst->IsFloatingPointFoldingAllowed()) {
  1701. return false;
  1702. }
  1703. FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
  1704. FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
  1705. if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
  1706. inst->SetOpcode(SpvOpCopyObject);
  1707. inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
  1708. {inst->GetSingleWordInOperand(
  1709. kind0 == FloatConstantKind::Zero ? 1 : 0)}}});
  1710. return true;
  1711. }
  1712. return false;
  1713. };
  1714. }
  1715. FoldingRule RedundantFSub() {
  1716. return [](IRContext*, Instruction* inst,
  1717. const std::vector<const analysis::Constant*>& constants) {
  1718. assert(inst->opcode() == SpvOpFSub && "Wrong opcode. Should be OpFSub.");
  1719. assert(constants.size() == 2);
  1720. if (!inst->IsFloatingPointFoldingAllowed()) {
  1721. return false;
  1722. }
  1723. FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
  1724. FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
  1725. if (kind0 == FloatConstantKind::Zero) {
  1726. inst->SetOpcode(SpvOpFNegate);
  1727. inst->SetInOperands(
  1728. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1)}}});
  1729. return true;
  1730. }
  1731. if (kind1 == FloatConstantKind::Zero) {
  1732. inst->SetOpcode(SpvOpCopyObject);
  1733. inst->SetInOperands(
  1734. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
  1735. return true;
  1736. }
  1737. return false;
  1738. };
  1739. }
  1740. FoldingRule RedundantFMul() {
  1741. return [](IRContext*, Instruction* inst,
  1742. const std::vector<const analysis::Constant*>& constants) {
  1743. assert(inst->opcode() == SpvOpFMul && "Wrong opcode. Should be OpFMul.");
  1744. assert(constants.size() == 2);
  1745. if (!inst->IsFloatingPointFoldingAllowed()) {
  1746. return false;
  1747. }
  1748. FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
  1749. FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
  1750. if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
  1751. inst->SetOpcode(SpvOpCopyObject);
  1752. inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
  1753. {inst->GetSingleWordInOperand(
  1754. kind0 == FloatConstantKind::Zero ? 0 : 1)}}});
  1755. return true;
  1756. }
  1757. if (kind0 == FloatConstantKind::One || kind1 == FloatConstantKind::One) {
  1758. inst->SetOpcode(SpvOpCopyObject);
  1759. inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
  1760. {inst->GetSingleWordInOperand(
  1761. kind0 == FloatConstantKind::One ? 1 : 0)}}});
  1762. return true;
  1763. }
  1764. return false;
  1765. };
  1766. }
  1767. FoldingRule RedundantFDiv() {
  1768. return [](IRContext*, Instruction* inst,
  1769. const std::vector<const analysis::Constant*>& constants) {
  1770. assert(inst->opcode() == SpvOpFDiv && "Wrong opcode. Should be OpFDiv.");
  1771. assert(constants.size() == 2);
  1772. if (!inst->IsFloatingPointFoldingAllowed()) {
  1773. return false;
  1774. }
  1775. FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
  1776. FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
  1777. if (kind0 == FloatConstantKind::Zero) {
  1778. inst->SetOpcode(SpvOpCopyObject);
  1779. inst->SetInOperands(
  1780. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
  1781. return true;
  1782. }
  1783. if (kind1 == FloatConstantKind::One) {
  1784. inst->SetOpcode(SpvOpCopyObject);
  1785. inst->SetInOperands(
  1786. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
  1787. return true;
  1788. }
  1789. return false;
  1790. };
  1791. }
  1792. FoldingRule RedundantFMix() {
  1793. return [](IRContext* context, Instruction* inst,
  1794. const std::vector<const analysis::Constant*>& constants) {
  1795. assert(inst->opcode() == SpvOpExtInst &&
  1796. "Wrong opcode. Should be OpExtInst.");
  1797. if (!inst->IsFloatingPointFoldingAllowed()) {
  1798. return false;
  1799. }
  1800. uint32_t instSetId =
  1801. context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  1802. if (inst->GetSingleWordInOperand(kExtInstSetIdInIdx) == instSetId &&
  1803. inst->GetSingleWordInOperand(kExtInstInstructionInIdx) ==
  1804. GLSLstd450FMix) {
  1805. assert(constants.size() == 5);
  1806. FloatConstantKind kind4 = getFloatConstantKind(constants[4]);
  1807. if (kind4 == FloatConstantKind::Zero || kind4 == FloatConstantKind::One) {
  1808. inst->SetOpcode(SpvOpCopyObject);
  1809. inst->SetInOperands(
  1810. {{SPV_OPERAND_TYPE_ID,
  1811. {inst->GetSingleWordInOperand(kind4 == FloatConstantKind::Zero
  1812. ? kFMixXIdInIdx
  1813. : kFMixYIdInIdx)}}});
  1814. return true;
  1815. }
  1816. }
  1817. return false;
  1818. };
  1819. }
  1820. // This rule handles addition of zero for integers.
  1821. FoldingRule RedundantIAdd() {
  1822. return [](IRContext* context, Instruction* inst,
  1823. const std::vector<const analysis::Constant*>& constants) {
  1824. assert(inst->opcode() == SpvOpIAdd && "Wrong opcode. Should be OpIAdd.");
  1825. uint32_t operand = std::numeric_limits<uint32_t>::max();
  1826. const analysis::Type* operand_type = nullptr;
  1827. if (constants[0] && constants[0]->IsZero()) {
  1828. operand = inst->GetSingleWordInOperand(1);
  1829. operand_type = constants[0]->type();
  1830. } else if (constants[1] && constants[1]->IsZero()) {
  1831. operand = inst->GetSingleWordInOperand(0);
  1832. operand_type = constants[1]->type();
  1833. }
  1834. if (operand != std::numeric_limits<uint32_t>::max()) {
  1835. const analysis::Type* inst_type =
  1836. context->get_type_mgr()->GetType(inst->type_id());
  1837. if (inst_type->IsSame(operand_type)) {
  1838. inst->SetOpcode(SpvOpCopyObject);
  1839. } else {
  1840. inst->SetOpcode(SpvOpBitcast);
  1841. }
  1842. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {operand}}});
  1843. return true;
  1844. }
  1845. return false;
  1846. };
  1847. }
  1848. // This rule look for a dot with a constant vector containing a single 1 and
  1849. // the rest 0s. This is the same as doing an extract.
  1850. FoldingRule DotProductDoingExtract() {
  1851. return [](IRContext* context, Instruction* inst,
  1852. const std::vector<const analysis::Constant*>& constants) {
  1853. assert(inst->opcode() == SpvOpDot && "Wrong opcode. Should be OpDot.");
  1854. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  1855. if (!inst->IsFloatingPointFoldingAllowed()) {
  1856. return false;
  1857. }
  1858. for (int i = 0; i < 2; ++i) {
  1859. if (!constants[i]) {
  1860. continue;
  1861. }
  1862. const analysis::Vector* vector_type = constants[i]->type()->AsVector();
  1863. assert(vector_type && "Inputs to OpDot must be vectors.");
  1864. const analysis::Float* element_type =
  1865. vector_type->element_type()->AsFloat();
  1866. assert(element_type && "Inputs to OpDot must be vectors of floats.");
  1867. uint32_t element_width = element_type->width();
  1868. if (element_width != 32 && element_width != 64) {
  1869. return false;
  1870. }
  1871. std::vector<const analysis::Constant*> components;
  1872. components = constants[i]->GetVectorComponents(const_mgr);
  1873. const uint32_t kNotFound = std::numeric_limits<uint32_t>::max();
  1874. uint32_t component_with_one = kNotFound;
  1875. bool all_others_zero = true;
  1876. for (uint32_t j = 0; j < components.size(); ++j) {
  1877. const analysis::Constant* element = components[j];
  1878. double value =
  1879. (element_width == 32 ? element->GetFloat() : element->GetDouble());
  1880. if (value == 0.0) {
  1881. continue;
  1882. } else if (value == 1.0) {
  1883. if (component_with_one == kNotFound) {
  1884. component_with_one = j;
  1885. } else {
  1886. component_with_one = kNotFound;
  1887. break;
  1888. }
  1889. } else {
  1890. all_others_zero = false;
  1891. break;
  1892. }
  1893. }
  1894. if (!all_others_zero || component_with_one == kNotFound) {
  1895. continue;
  1896. }
  1897. std::vector<Operand> operands;
  1898. operands.push_back(
  1899. {SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1u - i)}});
  1900. operands.push_back(
  1901. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_with_one}});
  1902. inst->SetOpcode(SpvOpCompositeExtract);
  1903. inst->SetInOperands(std::move(operands));
  1904. return true;
  1905. }
  1906. return false;
  1907. };
  1908. }
  1909. // If we are storing an undef, then we can remove the store.
  1910. //
  1911. // TODO: We can do something similar for OpImageWrite, but checking for volatile
  1912. // is complicated. Waiting to see if it is needed.
  1913. FoldingRule StoringUndef() {
  1914. return [](IRContext* context, Instruction* inst,
  1915. const std::vector<const analysis::Constant*>&) {
  1916. assert(inst->opcode() == SpvOpStore && "Wrong opcode. Should be OpStore.");
  1917. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1918. // If this is a volatile store, the store cannot be removed.
  1919. if (inst->NumInOperands() == 3) {
  1920. if (inst->GetSingleWordInOperand(2) & SpvMemoryAccessVolatileMask) {
  1921. return false;
  1922. }
  1923. }
  1924. uint32_t object_id = inst->GetSingleWordInOperand(kStoreObjectInIdx);
  1925. Instruction* object_inst = def_use_mgr->GetDef(object_id);
  1926. if (object_inst->opcode() == SpvOpUndef) {
  1927. inst->ToNop();
  1928. return true;
  1929. }
  1930. return false;
  1931. };
  1932. }
  1933. FoldingRule VectorShuffleFeedingShuffle() {
  1934. return [](IRContext* context, Instruction* inst,
  1935. const std::vector<const analysis::Constant*>&) {
  1936. assert(inst->opcode() == SpvOpVectorShuffle &&
  1937. "Wrong opcode. Should be OpVectorShuffle.");
  1938. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1939. analysis::TypeManager* type_mgr = context->get_type_mgr();
  1940. Instruction* feeding_shuffle_inst =
  1941. def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
  1942. analysis::Vector* op0_type =
  1943. type_mgr->GetType(feeding_shuffle_inst->type_id())->AsVector();
  1944. uint32_t op0_length = op0_type->element_count();
  1945. bool feeder_is_op0 = true;
  1946. if (feeding_shuffle_inst->opcode() != SpvOpVectorShuffle) {
  1947. feeding_shuffle_inst =
  1948. def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
  1949. feeder_is_op0 = false;
  1950. }
  1951. if (feeding_shuffle_inst->opcode() != SpvOpVectorShuffle) {
  1952. return false;
  1953. }
  1954. Instruction* feeder2 =
  1955. def_use_mgr->GetDef(feeding_shuffle_inst->GetSingleWordInOperand(0));
  1956. analysis::Vector* feeder_op0_type =
  1957. type_mgr->GetType(feeder2->type_id())->AsVector();
  1958. uint32_t feeder_op0_length = feeder_op0_type->element_count();
  1959. uint32_t new_feeder_id = 0;
  1960. std::vector<Operand> new_operands;
  1961. new_operands.resize(
  1962. 2, {SPV_OPERAND_TYPE_ID, {0}}); // Place holders for vector operands.
  1963. const uint32_t undef_literal = 0xffffffff;
  1964. for (uint32_t op = 2; op < inst->NumInOperands(); ++op) {
  1965. uint32_t component_index = inst->GetSingleWordInOperand(op);
  1966. // Do not interpret the undefined value literal as coming from operand 1.
  1967. if (component_index != undef_literal &&
  1968. feeder_is_op0 == (component_index < op0_length)) {
  1969. // This component comes from the feeding_shuffle_inst. Update
  1970. // |component_index| to be the index into the operand of the feeder.
  1971. // Adjust component_index to get the index into the operands of the
  1972. // feeding_shuffle_inst.
  1973. if (component_index >= op0_length) {
  1974. component_index -= op0_length;
  1975. }
  1976. component_index =
  1977. feeding_shuffle_inst->GetSingleWordInOperand(component_index + 2);
  1978. // Check if we are using a component from the first or second operand of
  1979. // the feeding instruction.
  1980. if (component_index < feeder_op0_length) {
  1981. if (new_feeder_id == 0) {
  1982. // First time through, save the id of the operand the element comes
  1983. // from.
  1984. new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(0);
  1985. } else if (new_feeder_id !=
  1986. feeding_shuffle_inst->GetSingleWordInOperand(0)) {
  1987. // We need both elements of the feeding_shuffle_inst, so we cannot
  1988. // fold.
  1989. return false;
  1990. }
  1991. } else {
  1992. if (new_feeder_id == 0) {
  1993. // First time through, save the id of the operand the element comes
  1994. // from.
  1995. new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(1);
  1996. } else if (new_feeder_id !=
  1997. feeding_shuffle_inst->GetSingleWordInOperand(1)) {
  1998. // We need both elements of the feeding_shuffle_inst, so we cannot
  1999. // fold.
  2000. return false;
  2001. }
  2002. component_index -= feeder_op0_length;
  2003. }
  2004. if (!feeder_is_op0) {
  2005. component_index += op0_length;
  2006. }
  2007. }
  2008. new_operands.push_back(
  2009. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_index}});
  2010. }
  2011. if (new_feeder_id == 0) {
  2012. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  2013. const analysis::Type* type =
  2014. type_mgr->GetType(feeding_shuffle_inst->type_id());
  2015. const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
  2016. new_feeder_id =
  2017. const_mgr->GetDefiningInstruction(null_const, 0)->result_id();
  2018. }
  2019. if (feeder_is_op0) {
  2020. // If the size of the first vector operand changed then the indices
  2021. // referring to the second operand need to be adjusted.
  2022. Instruction* new_feeder_inst = def_use_mgr->GetDef(new_feeder_id);
  2023. analysis::Type* new_feeder_type =
  2024. type_mgr->GetType(new_feeder_inst->type_id());
  2025. uint32_t new_op0_size = new_feeder_type->AsVector()->element_count();
  2026. int32_t adjustment = op0_length - new_op0_size;
  2027. if (adjustment != 0) {
  2028. for (uint32_t i = 2; i < new_operands.size(); i++) {
  2029. if (inst->GetSingleWordInOperand(i) >= op0_length) {
  2030. new_operands[i].words[0] -= adjustment;
  2031. }
  2032. }
  2033. }
  2034. new_operands[0].words[0] = new_feeder_id;
  2035. new_operands[1] = inst->GetInOperand(1);
  2036. } else {
  2037. new_operands[1].words[0] = new_feeder_id;
  2038. new_operands[0] = inst->GetInOperand(0);
  2039. }
  2040. inst->SetInOperands(std::move(new_operands));
  2041. return true;
  2042. };
  2043. }
  2044. // Removes duplicate ids from the interface list of an OpEntryPoint
  2045. // instruction.
  2046. FoldingRule RemoveRedundantOperands() {
  2047. return [](IRContext*, Instruction* inst,
  2048. const std::vector<const analysis::Constant*>&) {
  2049. assert(inst->opcode() == SpvOpEntryPoint &&
  2050. "Wrong opcode. Should be OpEntryPoint.");
  2051. bool has_redundant_operand = false;
  2052. std::unordered_set<uint32_t> seen_operands;
  2053. std::vector<Operand> new_operands;
  2054. new_operands.emplace_back(inst->GetOperand(0));
  2055. new_operands.emplace_back(inst->GetOperand(1));
  2056. new_operands.emplace_back(inst->GetOperand(2));
  2057. for (uint32_t i = 3; i < inst->NumOperands(); ++i) {
  2058. if (seen_operands.insert(inst->GetSingleWordOperand(i)).second) {
  2059. new_operands.emplace_back(inst->GetOperand(i));
  2060. } else {
  2061. has_redundant_operand = true;
  2062. }
  2063. }
  2064. if (!has_redundant_operand) {
  2065. return false;
  2066. }
  2067. inst->SetInOperands(std::move(new_operands));
  2068. return true;
  2069. };
  2070. }
  2071. // If an image instruction's operand is a constant, updates the image operand
  2072. // flag from Offset to ConstOffset.
  2073. FoldingRule UpdateImageOperands() {
  2074. return [](IRContext*, Instruction* inst,
  2075. const std::vector<const analysis::Constant*>& constants) {
  2076. const auto opcode = inst->opcode();
  2077. (void)opcode;
  2078. assert((opcode == SpvOpImageSampleImplicitLod ||
  2079. opcode == SpvOpImageSampleExplicitLod ||
  2080. opcode == SpvOpImageSampleDrefImplicitLod ||
  2081. opcode == SpvOpImageSampleDrefExplicitLod ||
  2082. opcode == SpvOpImageSampleProjImplicitLod ||
  2083. opcode == SpvOpImageSampleProjExplicitLod ||
  2084. opcode == SpvOpImageSampleProjDrefImplicitLod ||
  2085. opcode == SpvOpImageSampleProjDrefExplicitLod ||
  2086. opcode == SpvOpImageFetch || opcode == SpvOpImageGather ||
  2087. opcode == SpvOpImageDrefGather || opcode == SpvOpImageRead ||
  2088. opcode == SpvOpImageWrite ||
  2089. opcode == SpvOpImageSparseSampleImplicitLod ||
  2090. opcode == SpvOpImageSparseSampleExplicitLod ||
  2091. opcode == SpvOpImageSparseSampleDrefImplicitLod ||
  2092. opcode == SpvOpImageSparseSampleDrefExplicitLod ||
  2093. opcode == SpvOpImageSparseSampleProjImplicitLod ||
  2094. opcode == SpvOpImageSparseSampleProjExplicitLod ||
  2095. opcode == SpvOpImageSparseSampleProjDrefImplicitLod ||
  2096. opcode == SpvOpImageSparseSampleProjDrefExplicitLod ||
  2097. opcode == SpvOpImageSparseFetch ||
  2098. opcode == SpvOpImageSparseGather ||
  2099. opcode == SpvOpImageSparseDrefGather ||
  2100. opcode == SpvOpImageSparseRead) &&
  2101. "Wrong opcode. Should be an image instruction.");
  2102. int32_t operand_index = ImageOperandsMaskInOperandIndex(inst);
  2103. if (operand_index >= 0) {
  2104. auto image_operands = inst->GetSingleWordInOperand(operand_index);
  2105. if (image_operands & SpvImageOperandsOffsetMask) {
  2106. uint32_t offset_operand_index = operand_index + 1;
  2107. if (image_operands & SpvImageOperandsBiasMask) offset_operand_index++;
  2108. if (image_operands & SpvImageOperandsLodMask) offset_operand_index++;
  2109. if (image_operands & SpvImageOperandsGradMask)
  2110. offset_operand_index += 2;
  2111. assert(((image_operands & SpvImageOperandsConstOffsetMask) == 0) &&
  2112. "Offset and ConstOffset may not be used together");
  2113. if (offset_operand_index < inst->NumOperands()) {
  2114. if (constants[offset_operand_index]) {
  2115. image_operands = image_operands | SpvImageOperandsConstOffsetMask;
  2116. image_operands = image_operands & ~SpvImageOperandsOffsetMask;
  2117. inst->SetInOperand(operand_index, {image_operands});
  2118. return true;
  2119. }
  2120. }
  2121. }
  2122. }
  2123. return false;
  2124. };
  2125. }
  2126. } // namespace
  2127. void FoldingRules::AddFoldingRules() {
  2128. // Add all folding rules to the list for the opcodes to which they apply.
  2129. // Note that the order in which rules are added to the list matters. If a rule
  2130. // applies to the instruction, the rest of the rules will not be attempted.
  2131. // Take that into consideration.
  2132. rules_[SpvOpCompositeConstruct].push_back(CompositeExtractFeedingConstruct());
  2133. rules_[SpvOpCompositeExtract].push_back(InsertFeedingExtract());
  2134. rules_[SpvOpCompositeExtract].push_back(CompositeConstructFeedingExtract());
  2135. rules_[SpvOpCompositeExtract].push_back(VectorShuffleFeedingExtract());
  2136. rules_[SpvOpCompositeExtract].push_back(FMixFeedingExtract());
  2137. rules_[SpvOpDot].push_back(DotProductDoingExtract());
  2138. rules_[SpvOpEntryPoint].push_back(RemoveRedundantOperands());
  2139. rules_[SpvOpFAdd].push_back(RedundantFAdd());
  2140. rules_[SpvOpFAdd].push_back(MergeAddNegateArithmetic());
  2141. rules_[SpvOpFAdd].push_back(MergeAddAddArithmetic());
  2142. rules_[SpvOpFAdd].push_back(MergeAddSubArithmetic());
  2143. rules_[SpvOpFAdd].push_back(MergeGenericAddSubArithmetic());
  2144. rules_[SpvOpFAdd].push_back(FactorAddMuls());
  2145. rules_[SpvOpFDiv].push_back(RedundantFDiv());
  2146. rules_[SpvOpFDiv].push_back(ReciprocalFDiv());
  2147. rules_[SpvOpFDiv].push_back(MergeDivDivArithmetic());
  2148. rules_[SpvOpFDiv].push_back(MergeDivMulArithmetic());
  2149. rules_[SpvOpFDiv].push_back(MergeDivNegateArithmetic());
  2150. rules_[SpvOpFMul].push_back(RedundantFMul());
  2151. rules_[SpvOpFMul].push_back(MergeMulMulArithmetic());
  2152. rules_[SpvOpFMul].push_back(MergeMulDivArithmetic());
  2153. rules_[SpvOpFMul].push_back(MergeMulNegateArithmetic());
  2154. rules_[SpvOpFNegate].push_back(MergeNegateArithmetic());
  2155. rules_[SpvOpFNegate].push_back(MergeNegateAddSubArithmetic());
  2156. rules_[SpvOpFNegate].push_back(MergeNegateMulDivArithmetic());
  2157. rules_[SpvOpFSub].push_back(RedundantFSub());
  2158. rules_[SpvOpFSub].push_back(MergeSubNegateArithmetic());
  2159. rules_[SpvOpFSub].push_back(MergeSubAddArithmetic());
  2160. rules_[SpvOpFSub].push_back(MergeSubSubArithmetic());
  2161. rules_[SpvOpIAdd].push_back(RedundantIAdd());
  2162. rules_[SpvOpIAdd].push_back(MergeAddNegateArithmetic());
  2163. rules_[SpvOpIAdd].push_back(MergeAddAddArithmetic());
  2164. rules_[SpvOpIAdd].push_back(MergeAddSubArithmetic());
  2165. rules_[SpvOpIAdd].push_back(MergeGenericAddSubArithmetic());
  2166. rules_[SpvOpIAdd].push_back(FactorAddMuls());
  2167. rules_[SpvOpIMul].push_back(IntMultipleBy1());
  2168. rules_[SpvOpIMul].push_back(MergeMulMulArithmetic());
  2169. rules_[SpvOpIMul].push_back(MergeMulNegateArithmetic());
  2170. rules_[SpvOpISub].push_back(MergeSubNegateArithmetic());
  2171. rules_[SpvOpISub].push_back(MergeSubAddArithmetic());
  2172. rules_[SpvOpISub].push_back(MergeSubSubArithmetic());
  2173. rules_[SpvOpPhi].push_back(RedundantPhi());
  2174. rules_[SpvOpSDiv].push_back(MergeDivNegateArithmetic());
  2175. rules_[SpvOpSNegate].push_back(MergeNegateArithmetic());
  2176. rules_[SpvOpSNegate].push_back(MergeNegateMulDivArithmetic());
  2177. rules_[SpvOpSNegate].push_back(MergeNegateAddSubArithmetic());
  2178. rules_[SpvOpSelect].push_back(RedundantSelect());
  2179. rules_[SpvOpStore].push_back(StoringUndef());
  2180. rules_[SpvOpUDiv].push_back(MergeDivNegateArithmetic());
  2181. rules_[SpvOpVectorShuffle].push_back(VectorShuffleFeedingShuffle());
  2182. rules_[SpvOpImageSampleImplicitLod].push_back(UpdateImageOperands());
  2183. rules_[SpvOpImageSampleExplicitLod].push_back(UpdateImageOperands());
  2184. rules_[SpvOpImageSampleDrefImplicitLod].push_back(UpdateImageOperands());
  2185. rules_[SpvOpImageSampleDrefExplicitLod].push_back(UpdateImageOperands());
  2186. rules_[SpvOpImageSampleProjImplicitLod].push_back(UpdateImageOperands());
  2187. rules_[SpvOpImageSampleProjExplicitLod].push_back(UpdateImageOperands());
  2188. rules_[SpvOpImageSampleProjDrefImplicitLod].push_back(UpdateImageOperands());
  2189. rules_[SpvOpImageSampleProjDrefExplicitLod].push_back(UpdateImageOperands());
  2190. rules_[SpvOpImageFetch].push_back(UpdateImageOperands());
  2191. rules_[SpvOpImageGather].push_back(UpdateImageOperands());
  2192. rules_[SpvOpImageDrefGather].push_back(UpdateImageOperands());
  2193. rules_[SpvOpImageRead].push_back(UpdateImageOperands());
  2194. rules_[SpvOpImageWrite].push_back(UpdateImageOperands());
  2195. rules_[SpvOpImageSparseSampleImplicitLod].push_back(UpdateImageOperands());
  2196. rules_[SpvOpImageSparseSampleExplicitLod].push_back(UpdateImageOperands());
  2197. rules_[SpvOpImageSparseSampleDrefImplicitLod].push_back(
  2198. UpdateImageOperands());
  2199. rules_[SpvOpImageSparseSampleDrefExplicitLod].push_back(
  2200. UpdateImageOperands());
  2201. rules_[SpvOpImageSparseSampleProjImplicitLod].push_back(
  2202. UpdateImageOperands());
  2203. rules_[SpvOpImageSparseSampleProjExplicitLod].push_back(
  2204. UpdateImageOperands());
  2205. rules_[SpvOpImageSparseSampleProjDrefImplicitLod].push_back(
  2206. UpdateImageOperands());
  2207. rules_[SpvOpImageSparseSampleProjDrefExplicitLod].push_back(
  2208. UpdateImageOperands());
  2209. rules_[SpvOpImageSparseFetch].push_back(UpdateImageOperands());
  2210. rules_[SpvOpImageSparseGather].push_back(UpdateImageOperands());
  2211. rules_[SpvOpImageSparseDrefGather].push_back(UpdateImageOperands());
  2212. rules_[SpvOpImageSparseRead].push_back(UpdateImageOperands());
  2213. FeatureManager* feature_manager = context_->get_feature_mgr();
  2214. // Add rules for GLSLstd450
  2215. uint32_t ext_inst_glslstd450_id =
  2216. feature_manager->GetExtInstImportId_GLSLstd450();
  2217. if (ext_inst_glslstd450_id != 0) {
  2218. ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(
  2219. RedundantFMix());
  2220. }
  2221. }
  2222. } // namespace opt
  2223. } // namespace spvtools