folding_rules.cpp 110 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982
  1. // Copyright (c) 2018 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opt/folding_rules.h"
  15. #include <climits>
  16. #include <limits>
  17. #include <memory>
  18. #include <utility>
  19. #include "ir_builder.h"
  20. #include "source/latest_version_glsl_std_450_header.h"
  21. #include "source/opt/ir_context.h"
  22. namespace spvtools {
  23. namespace opt {
  24. namespace {
  // In-operand indices used by the folding rules in this file.
  // OpCompositeExtract: the composite being read.
  constexpr uint32_t kExtractCompositeIdInIdx = 0;
  // OpCompositeInsert: the inserted object, then the composite it goes into.
  constexpr uint32_t kInsertObjectIdInIdx = 0;
  constexpr uint32_t kInsertCompositeIdInIdx = 1;
  // OpExtInst: the extended instruction set id, then the instruction number.
  constexpr uint32_t kExtInstSetIdInIdx = 0;
  constexpr uint32_t kExtInstInstructionInIdx = 1;
  // FMix arguments of an OpExtInst (they follow the set id and instruction).
  constexpr uint32_t kFMixXIdInIdx = 2;
  constexpr uint32_t kFMixYIdInIdx = 3;
  constexpr uint32_t kFMixAIdInIdx = 4;
  // OpStore: the stored object (in-operand 0 is the pointer).
  constexpr uint32_t kStoreObjectInIdx = 1;
  34. // Some image instructions may contain an "image operands" argument.
  35. // Returns the operand index for the "image operands".
  36. // Returns -1 if the instruction does not have image operands.
  37. int32_t ImageOperandsMaskInOperandIndex(Instruction* inst) {
  38. const auto opcode = inst->opcode();
  39. switch (opcode) {
  40. case SpvOpImageSampleImplicitLod:
  41. case SpvOpImageSampleExplicitLod:
  42. case SpvOpImageSampleProjImplicitLod:
  43. case SpvOpImageSampleProjExplicitLod:
  44. case SpvOpImageFetch:
  45. case SpvOpImageRead:
  46. case SpvOpImageSparseSampleImplicitLod:
  47. case SpvOpImageSparseSampleExplicitLod:
  48. case SpvOpImageSparseSampleProjImplicitLod:
  49. case SpvOpImageSparseSampleProjExplicitLod:
  50. case SpvOpImageSparseFetch:
  51. case SpvOpImageSparseRead:
  52. return inst->NumOperands() > 4 ? 2 : -1;
  53. case SpvOpImageSampleDrefImplicitLod:
  54. case SpvOpImageSampleDrefExplicitLod:
  55. case SpvOpImageSampleProjDrefImplicitLod:
  56. case SpvOpImageSampleProjDrefExplicitLod:
  57. case SpvOpImageGather:
  58. case SpvOpImageDrefGather:
  59. case SpvOpImageSparseSampleDrefImplicitLod:
  60. case SpvOpImageSparseSampleDrefExplicitLod:
  61. case SpvOpImageSparseSampleProjDrefImplicitLod:
  62. case SpvOpImageSparseSampleProjDrefExplicitLod:
  63. case SpvOpImageSparseGather:
  64. case SpvOpImageSparseDrefGather:
  65. return inst->NumOperands() > 5 ? 3 : -1;
  66. case SpvOpImageWrite:
  67. return inst->NumOperands() > 3 ? 3 : -1;
  68. default:
  69. return -1;
  70. }
  71. }
  72. // Returns the element width of |type|.
  73. uint32_t ElementWidth(const analysis::Type* type) {
  74. if (const analysis::Vector* vec_type = type->AsVector()) {
  75. return ElementWidth(vec_type->element_type());
  76. } else if (const analysis::Float* float_type = type->AsFloat()) {
  77. return float_type->width();
  78. } else {
  79. assert(type->AsInteger());
  80. return type->AsInteger()->width();
  81. }
  82. }
  83. // Returns true if |type| is Float or a vector of Float.
  84. bool HasFloatingPoint(const analysis::Type* type) {
  85. if (type->AsFloat()) {
  86. return true;
  87. } else if (const analysis::Vector* vec_type = type->AsVector()) {
  88. return vec_type->element_type()->AsFloat() != nullptr;
  89. }
  90. return false;
  91. }
  // Returns false if |val| is NaN, infinite or subnormal, and true otherwise
  // (zero and normal values). Folding code in this file uses it to refuse
  // results it should not materialize as constants.
  template <typename T>
  bool IsValidResult(T val) {
    const int category = std::fpclassify(val);
    return category != FP_NAN && category != FP_INFINITE &&
           category != FP_SUBNORMAL;
  }
  105. const analysis::Constant* ConstInput(
  106. const std::vector<const analysis::Constant*>& constants) {
  107. return constants[0] ? constants[0] : constants[1];
  108. }
  109. Instruction* NonConstInput(IRContext* context, const analysis::Constant* c,
  110. Instruction* inst) {
  111. uint32_t in_op = c ? 1u : 0u;
  112. return context->get_def_use_mgr()->GetDef(
  113. inst->GetSingleWordInOperand(in_op));
  114. }
  // Splits |val| into two 32-bit words, least-significant word first.
  std::vector<uint32_t> ExtractInts(uint64_t val) {
    const uint32_t low_word = static_cast<uint32_t>(val);
    const uint32_t high_word = static_cast<uint32_t>(val >> 32);
    return {low_word, high_word};
  }
  121. std::vector<uint32_t> GetWordsFromScalarIntConstant(
  122. const analysis::IntConstant* c) {
  123. assert(c != nullptr);
  124. uint32_t width = c->type()->AsInteger()->width();
  125. assert(width == 8 || width == 16 || width == 32 || width == 64);
  126. if (width == 64) {
  127. uint64_t uval = static_cast<uint64_t>(c->GetU64());
  128. return ExtractInts(uval);
  129. }
  130. // Section 2.2.1 of the SPIR-V spec guarantees that all integer types
  131. // smaller than 32-bits are automatically zero or sign extended to 32-bits.
  132. return {c->GetU32BitValue()};
  133. }
  134. std::vector<uint32_t> GetWordsFromScalarFloatConstant(
  135. const analysis::FloatConstant* c) {
  136. assert(c != nullptr);
  137. uint32_t width = c->type()->AsFloat()->width();
  138. assert(width == 16 || width == 32 || width == 64);
  139. if (width == 64) {
  140. utils::FloatProxy<double> result(c->GetDouble());
  141. return result.GetWords();
  142. }
  143. // Section 2.2.1 of the SPIR-V spec guarantees that all floating-point types
  144. // smaller than 32-bits are automatically zero extended to 32-bits.
  145. return {c->GetU32BitValue()};
  146. }
  147. std::vector<uint32_t> GetWordsFromNumericScalarOrVectorConstant(
  148. analysis::ConstantManager* const_mgr, const analysis::Constant* c) {
  149. if (const auto* float_constant = c->AsFloatConstant()) {
  150. return GetWordsFromScalarFloatConstant(float_constant);
  151. } else if (const auto* int_constant = c->AsIntConstant()) {
  152. return GetWordsFromScalarIntConstant(int_constant);
  153. } else if (const auto* vec_constant = c->AsVectorConstant()) {
  154. std::vector<uint32_t> words;
  155. for (const auto* comp : vec_constant->GetComponents()) {
  156. auto comp_in_words =
  157. GetWordsFromNumericScalarOrVectorConstant(const_mgr, comp);
  158. words.insert(words.end(), comp_in_words.begin(), comp_in_words.end());
  159. }
  160. return words;
  161. }
  162. return {};
  163. }
  164. const analysis::Constant* ConvertWordsToNumericScalarOrVectorConstant(
  165. analysis::ConstantManager* const_mgr, const std::vector<uint32_t>& words,
  166. const analysis::Type* type) {
  167. if (type->AsInteger() || type->AsFloat())
  168. return const_mgr->GetConstant(type, words);
  169. if (const auto* vec_type = type->AsVector())
  170. return const_mgr->GetNumericVectorConstantWithWords(vec_type, words);
  171. return nullptr;
  172. }
  173. // Returns the negation of |c|. |c| must be a 32 or 64 bit floating point
  174. // constant.
  175. uint32_t NegateFloatingPointConstant(analysis::ConstantManager* const_mgr,
  176. const analysis::Constant* c) {
  177. assert(c);
  178. assert(c->type()->AsFloat());
  179. uint32_t width = c->type()->AsFloat()->width();
  180. assert(width == 32 || width == 64);
  181. std::vector<uint32_t> words;
  182. if (width == 64) {
  183. utils::FloatProxy<double> result(c->GetDouble() * -1.0);
  184. words = result.GetWords();
  185. } else {
  186. utils::FloatProxy<float> result(c->GetFloat() * -1.0f);
  187. words = result.GetWords();
  188. }
  189. const analysis::Constant* negated_const =
  190. const_mgr->GetConstant(c->type(), std::move(words));
  191. return const_mgr->GetDefiningInstruction(negated_const)->result_id();
  192. }
  193. // Negates the integer constant |c|. Returns the id of the defining instruction.
  194. uint32_t NegateIntegerConstant(analysis::ConstantManager* const_mgr,
  195. const analysis::Constant* c) {
  196. assert(c);
  197. assert(c->type()->AsInteger());
  198. uint32_t width = c->type()->AsInteger()->width();
  199. assert(width == 32 || width == 64);
  200. std::vector<uint32_t> words;
  201. if (width == 64) {
  202. uint64_t uval = static_cast<uint64_t>(0 - c->GetU64());
  203. words = ExtractInts(uval);
  204. } else {
  205. words.push_back(static_cast<uint32_t>(0 - c->GetU32()));
  206. }
  207. const analysis::Constant* negated_const =
  208. const_mgr->GetConstant(c->type(), std::move(words));
  209. return const_mgr->GetDefiningInstruction(negated_const)->result_id();
  210. }
  211. // Negates the vector constant |c|. Returns the id of the defining instruction.
  212. uint32_t NegateVectorConstant(analysis::ConstantManager* const_mgr,
  213. const analysis::Constant* c) {
  214. assert(const_mgr && c);
  215. assert(c->type()->AsVector());
  216. if (c->AsNullConstant()) {
  217. // 0.0 vs -0.0 shouldn't matter.
  218. return const_mgr->GetDefiningInstruction(c)->result_id();
  219. } else {
  220. const analysis::Type* component_type =
  221. c->AsVectorConstant()->component_type();
  222. std::vector<uint32_t> words;
  223. for (auto& comp : c->AsVectorConstant()->GetComponents()) {
  224. if (component_type->AsFloat()) {
  225. words.push_back(NegateFloatingPointConstant(const_mgr, comp));
  226. } else {
  227. assert(component_type->AsInteger());
  228. words.push_back(NegateIntegerConstant(const_mgr, comp));
  229. }
  230. }
  231. const analysis::Constant* negated_const =
  232. const_mgr->GetConstant(c->type(), std::move(words));
  233. return const_mgr->GetDefiningInstruction(negated_const)->result_id();
  234. }
  235. }
  236. // Negates |c|. Returns the id of the defining instruction.
  237. uint32_t NegateConstant(analysis::ConstantManager* const_mgr,
  238. const analysis::Constant* c) {
  239. if (c->type()->AsVector()) {
  240. return NegateVectorConstant(const_mgr, c);
  241. } else if (c->type()->AsFloat()) {
  242. return NegateFloatingPointConstant(const_mgr, c);
  243. } else {
  244. assert(c->type()->AsInteger());
  245. return NegateIntegerConstant(const_mgr, c);
  246. }
  247. }
  248. // Takes the reciprocal of |c|. |c|'s type must be Float or a vector of Float.
  249. // Returns 0 if the reciprocal is NaN, infinite or subnormal.
  250. uint32_t Reciprocal(analysis::ConstantManager* const_mgr,
  251. const analysis::Constant* c) {
  252. assert(const_mgr && c);
  253. assert(c->type()->AsFloat());
  254. uint32_t width = c->type()->AsFloat()->width();
  255. assert(width == 32 || width == 64);
  256. std::vector<uint32_t> words;
  257. if (c->IsZero()) {
  258. return 0;
  259. }
  260. if (width == 64) {
  261. spvtools::utils::FloatProxy<double> result(1.0 / c->GetDouble());
  262. if (!IsValidResult(result.getAsFloat())) return 0;
  263. words = result.GetWords();
  264. } else {
  265. spvtools::utils::FloatProxy<float> result(1.0f / c->GetFloat());
  266. if (!IsValidResult(result.getAsFloat())) return 0;
  267. words = result.GetWords();
  268. }
  269. const analysis::Constant* negated_const =
  270. const_mgr->GetConstant(c->type(), std::move(words));
  271. return const_mgr->GetDefiningInstruction(negated_const)->result_id();
  272. }
  273. // Replaces fdiv where second operand is constant with fmul.
  274. FoldingRule ReciprocalFDiv() {
  275. return [](IRContext* context, Instruction* inst,
  276. const std::vector<const analysis::Constant*>& constants) {
  277. assert(inst->opcode() == SpvOpFDiv);
  278. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  279. const analysis::Type* type =
  280. context->get_type_mgr()->GetType(inst->type_id());
  281. if (!inst->IsFloatingPointFoldingAllowed()) return false;
  282. uint32_t width = ElementWidth(type);
  283. if (width != 32 && width != 64) return false;
  284. if (constants[1] != nullptr) {
  285. uint32_t id = 0;
  286. if (const analysis::VectorConstant* vector_const =
  287. constants[1]->AsVectorConstant()) {
  288. std::vector<uint32_t> neg_ids;
  289. for (auto& comp : vector_const->GetComponents()) {
  290. id = Reciprocal(const_mgr, comp);
  291. if (id == 0) return false;
  292. neg_ids.push_back(id);
  293. }
  294. const analysis::Constant* negated_const =
  295. const_mgr->GetConstant(constants[1]->type(), std::move(neg_ids));
  296. id = const_mgr->GetDefiningInstruction(negated_const)->result_id();
  297. } else if (constants[1]->AsFloatConstant()) {
  298. id = Reciprocal(const_mgr, constants[1]);
  299. if (id == 0) return false;
  300. } else {
  301. // Don't fold a null constant.
  302. return false;
  303. }
  304. inst->SetOpcode(SpvOpFMul);
  305. inst->SetInOperands(
  306. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0u)}},
  307. {SPV_OPERAND_TYPE_ID, {id}}});
  308. return true;
  309. }
  310. return false;
  311. };
  312. }
  313. // Elides consecutive negate instructions.
  314. FoldingRule MergeNegateArithmetic() {
  315. return [](IRContext* context, Instruction* inst,
  316. const std::vector<const analysis::Constant*>& constants) {
  317. assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
  318. (void)constants;
  319. const analysis::Type* type =
  320. context->get_type_mgr()->GetType(inst->type_id());
  321. if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
  322. return false;
  323. Instruction* op_inst =
  324. context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
  325. if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
  326. return false;
  327. if (op_inst->opcode() == inst->opcode()) {
  328. // Elide negates.
  329. inst->SetOpcode(SpvOpCopyObject);
  330. inst->SetInOperands(
  331. {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0u)}}});
  332. return true;
  333. }
  334. return false;
  335. };
  336. }
  337. // Merges negate into a mul or div operation if that operation contains a
  338. // constant operand.
  339. // Cases:
  340. // -(x * 2) = x * -2
  341. // -(2 * x) = x * -2
  342. // -(x / 2) = x / -2
  343. // -(2 / x) = -2 / x
  344. FoldingRule MergeNegateMulDivArithmetic() {
  345. return [](IRContext* context, Instruction* inst,
  346. const std::vector<const analysis::Constant*>& constants) {
  347. assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
  348. (void)constants;
  349. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  350. const analysis::Type* type =
  351. context->get_type_mgr()->GetType(inst->type_id());
  352. if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
  353. return false;
  354. Instruction* op_inst =
  355. context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
  356. if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
  357. return false;
  358. uint32_t width = ElementWidth(type);
  359. if (width != 32 && width != 64) return false;
  360. SpvOp opcode = op_inst->opcode();
  361. if (opcode == SpvOpFMul || opcode == SpvOpFDiv || opcode == SpvOpIMul ||
  362. opcode == SpvOpSDiv || opcode == SpvOpUDiv) {
  363. std::vector<const analysis::Constant*> op_constants =
  364. const_mgr->GetOperandConstants(op_inst);
  365. // Merge negate into mul or div if one operand is constant.
  366. if (op_constants[0] || op_constants[1]) {
  367. bool zero_is_variable = op_constants[0] == nullptr;
  368. const analysis::Constant* c = ConstInput(op_constants);
  369. uint32_t neg_id = NegateConstant(const_mgr, c);
  370. uint32_t non_const_id = zero_is_variable
  371. ? op_inst->GetSingleWordInOperand(0u)
  372. : op_inst->GetSingleWordInOperand(1u);
  373. // Change this instruction to a mul/div.
  374. inst->SetOpcode(op_inst->opcode());
  375. if (opcode == SpvOpFDiv || opcode == SpvOpUDiv || opcode == SpvOpSDiv) {
  376. uint32_t op0 = zero_is_variable ? non_const_id : neg_id;
  377. uint32_t op1 = zero_is_variable ? neg_id : non_const_id;
  378. inst->SetInOperands(
  379. {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
  380. } else {
  381. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
  382. {SPV_OPERAND_TYPE_ID, {neg_id}}});
  383. }
  384. return true;
  385. }
  386. }
  387. return false;
  388. };
  389. }
  390. // Merges negate into a add or sub operation if that operation contains a
  391. // constant operand.
  392. // Cases:
  393. // -(x + 2) = -2 - x
  394. // -(2 + x) = -2 - x
  395. // -(x - 2) = 2 - x
  396. // -(2 - x) = x - 2
  397. FoldingRule MergeNegateAddSubArithmetic() {
  398. return [](IRContext* context, Instruction* inst,
  399. const std::vector<const analysis::Constant*>& constants) {
  400. assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
  401. (void)constants;
  402. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  403. const analysis::Type* type =
  404. context->get_type_mgr()->GetType(inst->type_id());
  405. if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
  406. return false;
  407. Instruction* op_inst =
  408. context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
  409. if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
  410. return false;
  411. uint32_t width = ElementWidth(type);
  412. if (width != 32 && width != 64) return false;
  413. if (op_inst->opcode() == SpvOpFAdd || op_inst->opcode() == SpvOpFSub ||
  414. op_inst->opcode() == SpvOpIAdd || op_inst->opcode() == SpvOpISub) {
  415. std::vector<const analysis::Constant*> op_constants =
  416. const_mgr->GetOperandConstants(op_inst);
  417. if (op_constants[0] || op_constants[1]) {
  418. bool zero_is_variable = op_constants[0] == nullptr;
  419. bool is_add = (op_inst->opcode() == SpvOpFAdd) ||
  420. (op_inst->opcode() == SpvOpIAdd);
  421. bool swap_operands = !is_add || zero_is_variable;
  422. bool negate_const = is_add;
  423. const analysis::Constant* c = ConstInput(op_constants);
  424. uint32_t const_id = 0;
  425. if (negate_const) {
  426. const_id = NegateConstant(const_mgr, c);
  427. } else {
  428. const_id = zero_is_variable ? op_inst->GetSingleWordInOperand(1u)
  429. : op_inst->GetSingleWordInOperand(0u);
  430. }
  431. // Swap operands if necessary and make the instruction a subtraction.
  432. uint32_t op0 =
  433. zero_is_variable ? op_inst->GetSingleWordInOperand(0u) : const_id;
  434. uint32_t op1 =
  435. zero_is_variable ? const_id : op_inst->GetSingleWordInOperand(1u);
  436. if (swap_operands) std::swap(op0, op1);
  437. inst->SetOpcode(HasFloatingPoint(type) ? SpvOpFSub : SpvOpISub);
  438. inst->SetInOperands(
  439. {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
  440. return true;
  441. }
  442. }
  443. return false;
  444. };
  445. }
  446. // Returns true if |c| has a zero element.
  447. bool HasZero(const analysis::Constant* c) {
  448. if (c->AsNullConstant()) {
  449. return true;
  450. }
  451. if (const analysis::VectorConstant* vec_const = c->AsVectorConstant()) {
  452. for (auto& comp : vec_const->GetComponents())
  453. if (HasZero(comp)) return true;
  454. } else {
  455. assert(c->AsScalarConstant());
  456. return c->AsScalarConstant()->IsZero();
  457. }
  458. return false;
  459. }
  460. // Performs |input1| |opcode| |input2| and returns the merged constant result
  461. // id. Returns 0 if the result is not a valid value. The input types must be
  462. // Float.
  463. uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr,
  464. SpvOp opcode,
  465. const analysis::Constant* input1,
  466. const analysis::Constant* input2) {
  467. const analysis::Type* type = input1->type();
  468. assert(type->AsFloat());
  469. uint32_t width = type->AsFloat()->width();
  470. assert(width == 32 || width == 64);
  471. std::vector<uint32_t> words;
  472. #define FOLD_OP(op) \
  473. if (width == 64) { \
  474. utils::FloatProxy<double> val = \
  475. input1->GetDouble() op input2->GetDouble(); \
  476. double dval = val.getAsFloat(); \
  477. if (!IsValidResult(dval)) return 0; \
  478. words = val.GetWords(); \
  479. } else { \
  480. utils::FloatProxy<float> val = input1->GetFloat() op input2->GetFloat(); \
  481. float fval = val.getAsFloat(); \
  482. if (!IsValidResult(fval)) return 0; \
  483. words = val.GetWords(); \
  484. } \
  485. static_assert(true, "require extra semicolon")
  486. switch (opcode) {
  487. case SpvOpFMul:
  488. FOLD_OP(*);
  489. break;
  490. case SpvOpFDiv:
  491. if (HasZero(input2)) return 0;
  492. FOLD_OP(/);
  493. break;
  494. case SpvOpFAdd:
  495. FOLD_OP(+);
  496. break;
  497. case SpvOpFSub:
  498. FOLD_OP(-);
  499. break;
  500. default:
  501. assert(false && "Unexpected operation");
  502. break;
  503. }
  504. #undef FOLD_OP
  505. const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
  506. return const_mgr->GetDefiningInstruction(merged_const)->result_id();
  507. }
  508. // Performs |input1| |opcode| |input2| and returns the merged constant result
  509. // id. Returns 0 if the result is not a valid value. The input types must be
  510. // Integers.
  511. uint32_t PerformIntegerOperation(analysis::ConstantManager* const_mgr,
  512. SpvOp opcode, const analysis::Constant* input1,
  513. const analysis::Constant* input2) {
  514. assert(input1->type()->AsInteger());
  515. const analysis::Integer* type = input1->type()->AsInteger();
  516. uint32_t width = type->AsInteger()->width();
  517. assert(width == 32 || width == 64);
  518. std::vector<uint32_t> words;
  519. // Regardless of the sign of the constant, folding is performed on an unsigned
  520. // interpretation of the constant data. This avoids signed integer overflow
  521. // while folding, and works because sign is irrelevant for the IAdd, ISub and
  522. // IMul instructions.
  523. #define FOLD_OP(op) \
  524. if (width == 64) { \
  525. uint64_t val = input1->GetU64() op input2->GetU64(); \
  526. words = ExtractInts(val); \
  527. } else { \
  528. uint32_t val = input1->GetU32() op input2->GetU32(); \
  529. words.push_back(val); \
  530. } \
  531. static_assert(true, "require extra semicolon")
  532. switch (opcode) {
  533. case SpvOpIMul:
  534. FOLD_OP(*);
  535. break;
  536. case SpvOpSDiv:
  537. case SpvOpUDiv:
  538. assert(false && "Should not merge integer division");
  539. break;
  540. case SpvOpIAdd:
  541. FOLD_OP(+);
  542. break;
  543. case SpvOpISub:
  544. FOLD_OP(-);
  545. break;
  546. default:
  547. assert(false && "Unexpected operation");
  548. break;
  549. }
  550. #undef FOLD_OP
  551. const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
  552. return const_mgr->GetDefiningInstruction(merged_const)->result_id();
  553. }
  554. // Performs |input1| |opcode| |input2| and returns the merged constant result
  555. // id. Returns 0 if the result is not a valid value. The input types must be
  556. // Integers, Floats or Vectors of such.
  557. uint32_t PerformOperation(analysis::ConstantManager* const_mgr, SpvOp opcode,
  558. const analysis::Constant* input1,
  559. const analysis::Constant* input2) {
  560. assert(input1 && input2);
  561. const analysis::Type* type = input1->type();
  562. std::vector<uint32_t> words;
  563. if (const analysis::Vector* vector_type = type->AsVector()) {
  564. const analysis::Type* ele_type = vector_type->element_type();
  565. for (uint32_t i = 0; i != vector_type->element_count(); ++i) {
  566. uint32_t id = 0;
  567. const analysis::Constant* input1_comp = nullptr;
  568. if (const analysis::VectorConstant* input1_vector =
  569. input1->AsVectorConstant()) {
  570. input1_comp = input1_vector->GetComponents()[i];
  571. } else {
  572. assert(input1->AsNullConstant());
  573. input1_comp = const_mgr->GetConstant(ele_type, {});
  574. }
  575. const analysis::Constant* input2_comp = nullptr;
  576. if (const analysis::VectorConstant* input2_vector =
  577. input2->AsVectorConstant()) {
  578. input2_comp = input2_vector->GetComponents()[i];
  579. } else {
  580. assert(input2->AsNullConstant());
  581. input2_comp = const_mgr->GetConstant(ele_type, {});
  582. }
  583. if (ele_type->AsFloat()) {
  584. id = PerformFloatingPointOperation(const_mgr, opcode, input1_comp,
  585. input2_comp);
  586. } else {
  587. assert(ele_type->AsInteger());
  588. id = PerformIntegerOperation(const_mgr, opcode, input1_comp,
  589. input2_comp);
  590. }
  591. if (id == 0) return 0;
  592. words.push_back(id);
  593. }
  594. const analysis::Constant* merged_const =
  595. const_mgr->GetConstant(type, words);
  596. return const_mgr->GetDefiningInstruction(merged_const)->result_id();
  597. } else if (type->AsFloat()) {
  598. return PerformFloatingPointOperation(const_mgr, opcode, input1, input2);
  599. } else {
  600. assert(type->AsInteger());
  601. return PerformIntegerOperation(const_mgr, opcode, input1, input2);
  602. }
  603. }
  604. // Merges consecutive multiplies where each contains one constant operand.
  605. // Cases:
  606. // 2 * (x * 2) = x * 4
  607. // 2 * (2 * x) = x * 4
  608. // (x * 2) * 2 = x * 4
  609. // (2 * x) * 2 = x * 4
  610. FoldingRule MergeMulMulArithmetic() {
  611. return [](IRContext* context, Instruction* inst,
  612. const std::vector<const analysis::Constant*>& constants) {
  613. assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
  614. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  615. const analysis::Type* type =
  616. context->get_type_mgr()->GetType(inst->type_id());
  617. if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
  618. return false;
  619. uint32_t width = ElementWidth(type);
  620. if (width != 32 && width != 64) return false;
  621. // Determine the constant input and the variable input in |inst|.
  622. const analysis::Constant* const_input1 = ConstInput(constants);
  623. if (!const_input1) return false;
  624. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  625. if (HasFloatingPoint(type) && !other_inst->IsFloatingPointFoldingAllowed())
  626. return false;
  627. if (other_inst->opcode() == inst->opcode()) {
  628. std::vector<const analysis::Constant*> other_constants =
  629. const_mgr->GetOperandConstants(other_inst);
  630. const analysis::Constant* const_input2 = ConstInput(other_constants);
  631. if (!const_input2) return false;
  632. bool other_first_is_variable = other_constants[0] == nullptr;
  633. uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
  634. const_input1, const_input2);
  635. if (merged_id == 0) return false;
  636. uint32_t non_const_id = other_first_is_variable
  637. ? other_inst->GetSingleWordInOperand(0u)
  638. : other_inst->GetSingleWordInOperand(1u);
  639. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
  640. {SPV_OPERAND_TYPE_ID, {merged_id}}});
  641. return true;
  642. }
  643. return false;
  644. };
  645. }
  646. // Merges divides into subsequent multiplies if each instruction contains one
  647. // constant operand. Does not support integer operations.
  648. // Cases:
  649. // 2 * (x / 2) = x * 1
  650. // 2 * (2 / x) = 4 / x
  651. // (x / 2) * 2 = x * 1
  652. // (2 / x) * 2 = 4 / x
  653. // (y / x) * x = y
  654. // x * (y / x) = y
  655. FoldingRule MergeMulDivArithmetic() {
  656. return [](IRContext* context, Instruction* inst,
  657. const std::vector<const analysis::Constant*>& constants) {
  658. assert(inst->opcode() == SpvOpFMul);
  659. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  660. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  661. const analysis::Type* type =
  662. context->get_type_mgr()->GetType(inst->type_id());
  663. if (!inst->IsFloatingPointFoldingAllowed()) return false;
  664. uint32_t width = ElementWidth(type);
  665. if (width != 32 && width != 64) return false;
  666. for (uint32_t i = 0; i < 2; i++) {
  667. uint32_t op_id = inst->GetSingleWordInOperand(i);
  668. Instruction* op_inst = def_use_mgr->GetDef(op_id);
  669. if (op_inst->opcode() == SpvOpFDiv) {
  670. if (op_inst->GetSingleWordInOperand(1) ==
  671. inst->GetSingleWordInOperand(1 - i)) {
  672. inst->SetOpcode(SpvOpCopyObject);
  673. inst->SetInOperands(
  674. {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0)}}});
  675. return true;
  676. }
  677. }
  678. }
  679. const analysis::Constant* const_input1 = ConstInput(constants);
  680. if (!const_input1) return false;
  681. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  682. if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
  683. if (other_inst->opcode() == SpvOpFDiv) {
  684. std::vector<const analysis::Constant*> other_constants =
  685. const_mgr->GetOperandConstants(other_inst);
  686. const analysis::Constant* const_input2 = ConstInput(other_constants);
  687. if (!const_input2 || HasZero(const_input2)) return false;
  688. bool other_first_is_variable = other_constants[0] == nullptr;
  689. // If the variable value is the second operand of the divide, multiply
  690. // the constants together. Otherwise divide the constants.
  691. uint32_t merged_id = PerformOperation(
  692. const_mgr,
  693. other_first_is_variable ? other_inst->opcode() : inst->opcode(),
  694. const_input1, const_input2);
  695. if (merged_id == 0) return false;
  696. uint32_t non_const_id = other_first_is_variable
  697. ? other_inst->GetSingleWordInOperand(0u)
  698. : other_inst->GetSingleWordInOperand(1u);
  699. // If the variable value is on the second operand of the div, then this
  700. // operation is a div. Otherwise it should be a multiply.
  701. inst->SetOpcode(other_first_is_variable ? inst->opcode()
  702. : other_inst->opcode());
  703. if (other_first_is_variable) {
  704. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
  705. {SPV_OPERAND_TYPE_ID, {merged_id}}});
  706. } else {
  707. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {merged_id}},
  708. {SPV_OPERAND_TYPE_ID, {non_const_id}}});
  709. }
  710. return true;
  711. }
  712. return false;
  713. };
  714. }
  715. // Merges multiply of constant and negation.
  716. // Cases:
  717. // (-x) * 2 = x * -2
  718. // 2 * (-x) = x * -2
  719. FoldingRule MergeMulNegateArithmetic() {
  720. return [](IRContext* context, Instruction* inst,
  721. const std::vector<const analysis::Constant*>& constants) {
  722. assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
  723. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  724. const analysis::Type* type =
  725. context->get_type_mgr()->GetType(inst->type_id());
  726. bool uses_float = HasFloatingPoint(type);
  727. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  728. uint32_t width = ElementWidth(type);
  729. if (width != 32 && width != 64) return false;
  730. const analysis::Constant* const_input1 = ConstInput(constants);
  731. if (!const_input1) return false;
  732. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  733. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  734. return false;
  735. if (other_inst->opcode() == SpvOpFNegate ||
  736. other_inst->opcode() == SpvOpSNegate) {
  737. uint32_t neg_id = NegateConstant(const_mgr, const_input1);
  738. inst->SetInOperands(
  739. {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
  740. {SPV_OPERAND_TYPE_ID, {neg_id}}});
  741. return true;
  742. }
  743. return false;
  744. };
  745. }
  746. // Merges consecutive divides if each instruction contains one constant operand.
  747. // Does not support integer division.
  748. // Cases:
  749. // 2 / (x / 2) = 4 / x
  750. // 4 / (2 / x) = 2 * x
  751. // (4 / x) / 2 = 2 / x
  752. // (x / 2) / 2 = x / 4
  753. FoldingRule MergeDivDivArithmetic() {
  754. return [](IRContext* context, Instruction* inst,
  755. const std::vector<const analysis::Constant*>& constants) {
  756. assert(inst->opcode() == SpvOpFDiv);
  757. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  758. const analysis::Type* type =
  759. context->get_type_mgr()->GetType(inst->type_id());
  760. if (!inst->IsFloatingPointFoldingAllowed()) return false;
  761. uint32_t width = ElementWidth(type);
  762. if (width != 32 && width != 64) return false;
  763. const analysis::Constant* const_input1 = ConstInput(constants);
  764. if (!const_input1 || HasZero(const_input1)) return false;
  765. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  766. if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
  767. bool first_is_variable = constants[0] == nullptr;
  768. if (other_inst->opcode() == inst->opcode()) {
  769. std::vector<const analysis::Constant*> other_constants =
  770. const_mgr->GetOperandConstants(other_inst);
  771. const analysis::Constant* const_input2 = ConstInput(other_constants);
  772. if (!const_input2 || HasZero(const_input2)) return false;
  773. bool other_first_is_variable = other_constants[0] == nullptr;
  774. SpvOp merge_op = inst->opcode();
  775. if (other_first_is_variable) {
  776. // Constants magnify.
  777. merge_op = SpvOpFMul;
  778. }
  779. // This is an x / (*) case. Swap the inputs. Doesn't harm multiply
  780. // because it is commutative.
  781. if (first_is_variable) std::swap(const_input1, const_input2);
  782. uint32_t merged_id =
  783. PerformOperation(const_mgr, merge_op, const_input1, const_input2);
  784. if (merged_id == 0) return false;
  785. uint32_t non_const_id = other_first_is_variable
  786. ? other_inst->GetSingleWordInOperand(0u)
  787. : other_inst->GetSingleWordInOperand(1u);
  788. SpvOp op = inst->opcode();
  789. if (!first_is_variable && !other_first_is_variable) {
  790. // Effectively div of 1/x, so change to multiply.
  791. op = SpvOpFMul;
  792. }
  793. uint32_t op1 = merged_id;
  794. uint32_t op2 = non_const_id;
  795. if (first_is_variable && other_first_is_variable) std::swap(op1, op2);
  796. inst->SetOpcode(op);
  797. inst->SetInOperands(
  798. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  799. return true;
  800. }
  801. return false;
  802. };
  803. }
  804. // Fold multiplies succeeded by divides where each instruction contains a
  805. // constant operand. Does not support integer divide.
  806. // Cases:
  807. // 4 / (x * 2) = 2 / x
  808. // 4 / (2 * x) = 2 / x
  809. // (x * 4) / 2 = x * 2
  810. // (4 * x) / 2 = x * 2
  811. // (x * y) / x = y
  812. // (y * x) / x = y
  813. FoldingRule MergeDivMulArithmetic() {
  814. return [](IRContext* context, Instruction* inst,
  815. const std::vector<const analysis::Constant*>& constants) {
  816. assert(inst->opcode() == SpvOpFDiv);
  817. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  818. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  819. const analysis::Type* type =
  820. context->get_type_mgr()->GetType(inst->type_id());
  821. if (!inst->IsFloatingPointFoldingAllowed()) return false;
  822. uint32_t width = ElementWidth(type);
  823. if (width != 32 && width != 64) return false;
  824. uint32_t op_id = inst->GetSingleWordInOperand(0);
  825. Instruction* op_inst = def_use_mgr->GetDef(op_id);
  826. if (op_inst->opcode() == SpvOpFMul) {
  827. for (uint32_t i = 0; i < 2; i++) {
  828. if (op_inst->GetSingleWordInOperand(i) ==
  829. inst->GetSingleWordInOperand(1)) {
  830. inst->SetOpcode(SpvOpCopyObject);
  831. inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
  832. {op_inst->GetSingleWordInOperand(1 - i)}}});
  833. return true;
  834. }
  835. }
  836. }
  837. const analysis::Constant* const_input1 = ConstInput(constants);
  838. if (!const_input1 || HasZero(const_input1)) return false;
  839. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  840. if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
  841. bool first_is_variable = constants[0] == nullptr;
  842. if (other_inst->opcode() == SpvOpFMul) {
  843. std::vector<const analysis::Constant*> other_constants =
  844. const_mgr->GetOperandConstants(other_inst);
  845. const analysis::Constant* const_input2 = ConstInput(other_constants);
  846. if (!const_input2) return false;
  847. bool other_first_is_variable = other_constants[0] == nullptr;
  848. // This is an x / (*) case. Swap the inputs.
  849. if (first_is_variable) std::swap(const_input1, const_input2);
  850. uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
  851. const_input1, const_input2);
  852. if (merged_id == 0) return false;
  853. uint32_t non_const_id = other_first_is_variable
  854. ? other_inst->GetSingleWordInOperand(0u)
  855. : other_inst->GetSingleWordInOperand(1u);
  856. uint32_t op1 = merged_id;
  857. uint32_t op2 = non_const_id;
  858. if (first_is_variable) std::swap(op1, op2);
  859. // Convert to multiply
  860. if (first_is_variable) inst->SetOpcode(other_inst->opcode());
  861. inst->SetInOperands(
  862. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  863. return true;
  864. }
  865. return false;
  866. };
  867. }
  868. // Fold divides of a constant and a negation.
  869. // Cases:
  870. // (-x) / 2 = x / -2
  871. // 2 / (-x) = -2 / x
  872. FoldingRule MergeDivNegateArithmetic() {
  873. return [](IRContext* context, Instruction* inst,
  874. const std::vector<const analysis::Constant*>& constants) {
  875. assert(inst->opcode() == SpvOpFDiv);
  876. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  877. if (!inst->IsFloatingPointFoldingAllowed()) return false;
  878. const analysis::Constant* const_input1 = ConstInput(constants);
  879. if (!const_input1) return false;
  880. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  881. if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
  882. bool first_is_variable = constants[0] == nullptr;
  883. if (other_inst->opcode() == SpvOpFNegate) {
  884. uint32_t neg_id = NegateConstant(const_mgr, const_input1);
  885. if (first_is_variable) {
  886. inst->SetInOperands(
  887. {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
  888. {SPV_OPERAND_TYPE_ID, {neg_id}}});
  889. } else {
  890. inst->SetInOperands(
  891. {{SPV_OPERAND_TYPE_ID, {neg_id}},
  892. {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
  893. }
  894. return true;
  895. }
  896. return false;
  897. };
  898. }
  899. // Folds addition of a constant and a negation.
  900. // Cases:
  901. // (-x) + 2 = 2 - x
  902. // 2 + (-x) = 2 - x
  903. FoldingRule MergeAddNegateArithmetic() {
  904. return [](IRContext* context, Instruction* inst,
  905. const std::vector<const analysis::Constant*>& constants) {
  906. assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
  907. const analysis::Type* type =
  908. context->get_type_mgr()->GetType(inst->type_id());
  909. bool uses_float = HasFloatingPoint(type);
  910. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  911. const analysis::Constant* const_input1 = ConstInput(constants);
  912. if (!const_input1) return false;
  913. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  914. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  915. return false;
  916. if (other_inst->opcode() == SpvOpSNegate ||
  917. other_inst->opcode() == SpvOpFNegate) {
  918. inst->SetOpcode(HasFloatingPoint(type) ? SpvOpFSub : SpvOpISub);
  919. uint32_t const_id = constants[0] ? inst->GetSingleWordInOperand(0u)
  920. : inst->GetSingleWordInOperand(1u);
  921. inst->SetInOperands(
  922. {{SPV_OPERAND_TYPE_ID, {const_id}},
  923. {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
  924. return true;
  925. }
  926. return false;
  927. };
  928. }
  929. // Folds subtraction of a constant and a negation.
  930. // Cases:
  931. // (-x) - 2 = -2 - x
  932. // 2 - (-x) = x + 2
  933. FoldingRule MergeSubNegateArithmetic() {
  934. return [](IRContext* context, Instruction* inst,
  935. const std::vector<const analysis::Constant*>& constants) {
  936. assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
  937. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  938. const analysis::Type* type =
  939. context->get_type_mgr()->GetType(inst->type_id());
  940. bool uses_float = HasFloatingPoint(type);
  941. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  942. uint32_t width = ElementWidth(type);
  943. if (width != 32 && width != 64) return false;
  944. const analysis::Constant* const_input1 = ConstInput(constants);
  945. if (!const_input1) return false;
  946. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  947. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  948. return false;
  949. if (other_inst->opcode() == SpvOpSNegate ||
  950. other_inst->opcode() == SpvOpFNegate) {
  951. uint32_t op1 = 0;
  952. uint32_t op2 = 0;
  953. SpvOp opcode = inst->opcode();
  954. if (constants[0] != nullptr) {
  955. op1 = other_inst->GetSingleWordInOperand(0u);
  956. op2 = inst->GetSingleWordInOperand(0u);
  957. opcode = HasFloatingPoint(type) ? SpvOpFAdd : SpvOpIAdd;
  958. } else {
  959. op1 = NegateConstant(const_mgr, const_input1);
  960. op2 = other_inst->GetSingleWordInOperand(0u);
  961. }
  962. inst->SetOpcode(opcode);
  963. inst->SetInOperands(
  964. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  965. return true;
  966. }
  967. return false;
  968. };
  969. }
  970. // Folds addition of an addition where each operation has a constant operand.
  971. // Cases:
  972. // (x + 2) + 2 = x + 4
  973. // (2 + x) + 2 = x + 4
  974. // 2 + (x + 2) = x + 4
  975. // 2 + (2 + x) = x + 4
  976. FoldingRule MergeAddAddArithmetic() {
  977. return [](IRContext* context, Instruction* inst,
  978. const std::vector<const analysis::Constant*>& constants) {
  979. assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
  980. const analysis::Type* type =
  981. context->get_type_mgr()->GetType(inst->type_id());
  982. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  983. bool uses_float = HasFloatingPoint(type);
  984. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  985. uint32_t width = ElementWidth(type);
  986. if (width != 32 && width != 64) return false;
  987. const analysis::Constant* const_input1 = ConstInput(constants);
  988. if (!const_input1) return false;
  989. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  990. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  991. return false;
  992. if (other_inst->opcode() == SpvOpFAdd ||
  993. other_inst->opcode() == SpvOpIAdd) {
  994. std::vector<const analysis::Constant*> other_constants =
  995. const_mgr->GetOperandConstants(other_inst);
  996. const analysis::Constant* const_input2 = ConstInput(other_constants);
  997. if (!const_input2) return false;
  998. Instruction* non_const_input =
  999. NonConstInput(context, other_constants[0], other_inst);
  1000. uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
  1001. const_input1, const_input2);
  1002. if (merged_id == 0) return false;
  1003. inst->SetInOperands(
  1004. {{SPV_OPERAND_TYPE_ID, {non_const_input->result_id()}},
  1005. {SPV_OPERAND_TYPE_ID, {merged_id}}});
  1006. return true;
  1007. }
  1008. return false;
  1009. };
  1010. }
  1011. // Folds addition of a subtraction where each operation has a constant operand.
  1012. // Cases:
  1013. // (x - 2) + 2 = x + 0
  1014. // (2 - x) + 2 = 4 - x
  1015. // 2 + (x - 2) = x + 0
  1016. // 2 + (2 - x) = 4 - x
  1017. FoldingRule MergeAddSubArithmetic() {
  1018. return [](IRContext* context, Instruction* inst,
  1019. const std::vector<const analysis::Constant*>& constants) {
  1020. assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
  1021. const analysis::Type* type =
  1022. context->get_type_mgr()->GetType(inst->type_id());
  1023. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  1024. bool uses_float = HasFloatingPoint(type);
  1025. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  1026. uint32_t width = ElementWidth(type);
  1027. if (width != 32 && width != 64) return false;
  1028. const analysis::Constant* const_input1 = ConstInput(constants);
  1029. if (!const_input1) return false;
  1030. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  1031. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  1032. return false;
  1033. if (other_inst->opcode() == SpvOpFSub ||
  1034. other_inst->opcode() == SpvOpISub) {
  1035. std::vector<const analysis::Constant*> other_constants =
  1036. const_mgr->GetOperandConstants(other_inst);
  1037. const analysis::Constant* const_input2 = ConstInput(other_constants);
  1038. if (!const_input2) return false;
  1039. bool first_is_variable = other_constants[0] == nullptr;
  1040. SpvOp op = inst->opcode();
  1041. uint32_t op1 = 0;
  1042. uint32_t op2 = 0;
  1043. if (first_is_variable) {
  1044. // Subtract constants. Non-constant operand is first.
  1045. op1 = other_inst->GetSingleWordInOperand(0u);
  1046. op2 = PerformOperation(const_mgr, other_inst->opcode(), const_input1,
  1047. const_input2);
  1048. } else {
  1049. // Add constants. Constant operand is first. Change the opcode.
  1050. op1 = PerformOperation(const_mgr, inst->opcode(), const_input1,
  1051. const_input2);
  1052. op2 = other_inst->GetSingleWordInOperand(1u);
  1053. op = other_inst->opcode();
  1054. }
  1055. if (op1 == 0 || op2 == 0) return false;
  1056. inst->SetOpcode(op);
  1057. inst->SetInOperands(
  1058. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  1059. return true;
  1060. }
  1061. return false;
  1062. };
  1063. }
  1064. // Folds subtraction of an addition where each operand has a constant operand.
  1065. // Cases:
  1066. // (x + 2) - 2 = x + 0
  1067. // (2 + x) - 2 = x + 0
  1068. // 2 - (x + 2) = 0 - x
  1069. // 2 - (2 + x) = 0 - x
  1070. FoldingRule MergeSubAddArithmetic() {
  1071. return [](IRContext* context, Instruction* inst,
  1072. const std::vector<const analysis::Constant*>& constants) {
  1073. assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
  1074. const analysis::Type* type =
  1075. context->get_type_mgr()->GetType(inst->type_id());
  1076. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  1077. bool uses_float = HasFloatingPoint(type);
  1078. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  1079. uint32_t width = ElementWidth(type);
  1080. if (width != 32 && width != 64) return false;
  1081. const analysis::Constant* const_input1 = ConstInput(constants);
  1082. if (!const_input1) return false;
  1083. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  1084. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  1085. return false;
  1086. if (other_inst->opcode() == SpvOpFAdd ||
  1087. other_inst->opcode() == SpvOpIAdd) {
  1088. std::vector<const analysis::Constant*> other_constants =
  1089. const_mgr->GetOperandConstants(other_inst);
  1090. const analysis::Constant* const_input2 = ConstInput(other_constants);
  1091. if (!const_input2) return false;
  1092. Instruction* non_const_input =
  1093. NonConstInput(context, other_constants[0], other_inst);
  1094. // If the first operand of the sub is not a constant, swap the constants
  1095. // so the subtraction has the correct operands.
  1096. if (constants[0] == nullptr) std::swap(const_input1, const_input2);
  1097. // Subtract the constants.
  1098. uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
  1099. const_input1, const_input2);
  1100. SpvOp op = inst->opcode();
  1101. uint32_t op1 = 0;
  1102. uint32_t op2 = 0;
  1103. if (constants[0] == nullptr) {
  1104. // Non-constant operand is first. Change the opcode.
  1105. op1 = non_const_input->result_id();
  1106. op2 = merged_id;
  1107. op = other_inst->opcode();
  1108. } else {
  1109. // Constant operand is first.
  1110. op1 = merged_id;
  1111. op2 = non_const_input->result_id();
  1112. }
  1113. if (op1 == 0 || op2 == 0) return false;
  1114. inst->SetOpcode(op);
  1115. inst->SetInOperands(
  1116. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  1117. return true;
  1118. }
  1119. return false;
  1120. };
  1121. }
  1122. // Folds subtraction of a subtraction where each operand has a constant operand.
  1123. // Cases:
  1124. // (x - 2) - 2 = x - 4
  1125. // (2 - x) - 2 = 0 - x
  1126. // 2 - (x - 2) = 4 - x
  1127. // 2 - (2 - x) = x + 0
  1128. FoldingRule MergeSubSubArithmetic() {
  1129. return [](IRContext* context, Instruction* inst,
  1130. const std::vector<const analysis::Constant*>& constants) {
  1131. assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
  1132. const analysis::Type* type =
  1133. context->get_type_mgr()->GetType(inst->type_id());
  1134. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  1135. bool uses_float = HasFloatingPoint(type);
  1136. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  1137. uint32_t width = ElementWidth(type);
  1138. if (width != 32 && width != 64) return false;
  1139. const analysis::Constant* const_input1 = ConstInput(constants);
  1140. if (!const_input1) return false;
  1141. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  1142. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  1143. return false;
  1144. if (other_inst->opcode() == SpvOpFSub ||
  1145. other_inst->opcode() == SpvOpISub) {
  1146. std::vector<const analysis::Constant*> other_constants =
  1147. const_mgr->GetOperandConstants(other_inst);
  1148. const analysis::Constant* const_input2 = ConstInput(other_constants);
  1149. if (!const_input2) return false;
  1150. Instruction* non_const_input =
  1151. NonConstInput(context, other_constants[0], other_inst);
  1152. // Merge the constants.
  1153. uint32_t merged_id = 0;
  1154. SpvOp merge_op = inst->opcode();
  1155. if (other_constants[0] == nullptr) {
  1156. merge_op = uses_float ? SpvOpFAdd : SpvOpIAdd;
  1157. } else if (constants[0] == nullptr) {
  1158. std::swap(const_input1, const_input2);
  1159. }
  1160. merged_id =
  1161. PerformOperation(const_mgr, merge_op, const_input1, const_input2);
  1162. if (merged_id == 0) return false;
  1163. SpvOp op = inst->opcode();
  1164. if (constants[0] != nullptr && other_constants[0] != nullptr) {
  1165. // Change the operation.
  1166. op = uses_float ? SpvOpFAdd : SpvOpIAdd;
  1167. }
  1168. uint32_t op1 = 0;
  1169. uint32_t op2 = 0;
  1170. if ((constants[0] == nullptr) ^ (other_constants[0] == nullptr)) {
  1171. op1 = merged_id;
  1172. op2 = non_const_input->result_id();
  1173. } else {
  1174. op1 = non_const_input->result_id();
  1175. op2 = merged_id;
  1176. }
  1177. inst->SetOpcode(op);
  1178. inst->SetInOperands(
  1179. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  1180. return true;
  1181. }
  1182. return false;
  1183. };
  1184. }
  1185. // Helper function for MergeGenericAddSubArithmetic. If |addend| and
  1186. // subtrahend of |sub| is the same, merge to copy of minuend of |sub|.
  1187. bool MergeGenericAddendSub(uint32_t addend, uint32_t sub, Instruction* inst) {
  1188. IRContext* context = inst->context();
  1189. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1190. Instruction* sub_inst = def_use_mgr->GetDef(sub);
  1191. if (sub_inst->opcode() != SpvOpFSub && sub_inst->opcode() != SpvOpISub)
  1192. return false;
  1193. if (sub_inst->opcode() == SpvOpFSub &&
  1194. !sub_inst->IsFloatingPointFoldingAllowed())
  1195. return false;
  1196. if (addend != sub_inst->GetSingleWordInOperand(1)) return false;
  1197. inst->SetOpcode(SpvOpCopyObject);
  1198. inst->SetInOperands(
  1199. {{SPV_OPERAND_TYPE_ID, {sub_inst->GetSingleWordInOperand(0)}}});
  1200. context->UpdateDefUse(inst);
  1201. return true;
  1202. }
  1203. // Folds addition of a subtraction where the subtrahend is equal to the
  1204. // other addend. Return a copy of the minuend. Accepts generic (const and
  1205. // non-const) operands.
  1206. // Cases:
  1207. // (a - b) + b = a
  1208. // b + (a - b) = a
  1209. FoldingRule MergeGenericAddSubArithmetic() {
  1210. return [](IRContext* context, Instruction* inst,
  1211. const std::vector<const analysis::Constant*>&) {
  1212. assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
  1213. const analysis::Type* type =
  1214. context->get_type_mgr()->GetType(inst->type_id());
  1215. bool uses_float = HasFloatingPoint(type);
  1216. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  1217. uint32_t width = ElementWidth(type);
  1218. if (width != 32 && width != 64) return false;
  1219. uint32_t add_op0 = inst->GetSingleWordInOperand(0);
  1220. uint32_t add_op1 = inst->GetSingleWordInOperand(1);
  1221. if (MergeGenericAddendSub(add_op0, add_op1, inst)) return true;
  1222. return MergeGenericAddendSub(add_op1, add_op0, inst);
  1223. };
  1224. }
  1225. // Helper function for FactorAddMuls. If |factor0_0| is the same as |factor1_0|,
  1226. // generate |factor0_0| * (|factor0_1| + |factor1_1|).
  1227. bool FactorAddMulsOpnds(uint32_t factor0_0, uint32_t factor0_1,
  1228. uint32_t factor1_0, uint32_t factor1_1,
  1229. Instruction* inst) {
  1230. IRContext* context = inst->context();
  1231. if (factor0_0 != factor1_0) return false;
  1232. InstructionBuilder ir_builder(
  1233. context, inst,
  1234. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  1235. Instruction* new_add_inst = ir_builder.AddBinaryOp(
  1236. inst->type_id(), inst->opcode(), factor0_1, factor1_1);
  1237. inst->SetOpcode(inst->opcode() == SpvOpFAdd ? SpvOpFMul : SpvOpIMul);
  1238. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {factor0_0}},
  1239. {SPV_OPERAND_TYPE_ID, {new_add_inst->result_id()}}});
  1240. context->UpdateDefUse(inst);
  1241. return true;
  1242. }
  1243. // Perform the following factoring identity, handling all operand order
  1244. // combinations: (a * b) + (a * c) = a * (b + c)
  1245. FoldingRule FactorAddMuls() {
  1246. return [](IRContext* context, Instruction* inst,
  1247. const std::vector<const analysis::Constant*>&) {
  1248. assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
  1249. const analysis::Type* type =
  1250. context->get_type_mgr()->GetType(inst->type_id());
  1251. bool uses_float = HasFloatingPoint(type);
  1252. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  1253. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1254. uint32_t add_op0 = inst->GetSingleWordInOperand(0);
  1255. Instruction* add_op0_inst = def_use_mgr->GetDef(add_op0);
  1256. if (add_op0_inst->opcode() != SpvOpFMul &&
  1257. add_op0_inst->opcode() != SpvOpIMul)
  1258. return false;
  1259. uint32_t add_op1 = inst->GetSingleWordInOperand(1);
  1260. Instruction* add_op1_inst = def_use_mgr->GetDef(add_op1);
  1261. if (add_op1_inst->opcode() != SpvOpFMul &&
  1262. add_op1_inst->opcode() != SpvOpIMul)
  1263. return false;
  1264. // Only perform this optimization if both of the muls only have one use.
  1265. // Otherwise this is a deoptimization in size and performance.
  1266. if (def_use_mgr->NumUses(add_op0_inst) > 1) return false;
  1267. if (def_use_mgr->NumUses(add_op1_inst) > 1) return false;
  1268. if (add_op0_inst->opcode() == SpvOpFMul &&
  1269. (!add_op0_inst->IsFloatingPointFoldingAllowed() ||
  1270. !add_op1_inst->IsFloatingPointFoldingAllowed()))
  1271. return false;
  1272. for (int i = 0; i < 2; i++) {
  1273. for (int j = 0; j < 2; j++) {
  1274. // Check if operand i in add_op0_inst matches operand j in add_op1_inst.
  1275. if (FactorAddMulsOpnds(add_op0_inst->GetSingleWordInOperand(i),
  1276. add_op0_inst->GetSingleWordInOperand(1 - i),
  1277. add_op1_inst->GetSingleWordInOperand(j),
  1278. add_op1_inst->GetSingleWordInOperand(1 - j),
  1279. inst))
  1280. return true;
  1281. }
  1282. }
  1283. return false;
  1284. };
  1285. }
  1286. // Replaces |inst| inplace with an FMA instruction |(x*y)+a|.
  1287. void ReplaceWithFma(Instruction* inst, uint32_t x, uint32_t y, uint32_t a) {
  1288. uint32_t ext =
  1289. inst->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  1290. if (ext == 0) {
  1291. inst->context()->AddExtInstImport("GLSL.std.450");
  1292. ext = inst->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  1293. assert(ext != 0 &&
  1294. "Could not add the GLSL.std.450 extended instruction set");
  1295. }
  1296. std::vector<Operand> operands;
  1297. operands.push_back({SPV_OPERAND_TYPE_ID, {ext}});
  1298. operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {GLSLstd450Fma}});
  1299. operands.push_back({SPV_OPERAND_TYPE_ID, {x}});
  1300. operands.push_back({SPV_OPERAND_TYPE_ID, {y}});
  1301. operands.push_back({SPV_OPERAND_TYPE_ID, {a}});
  1302. inst->SetOpcode(SpvOpExtInst);
  1303. inst->SetInOperands(std::move(operands));
  1304. }
  1305. // Folds a multiple and add into an Fma.
  1306. //
  1307. // Cases:
  1308. // (x * y) + a = Fma x y a
  1309. // a + (x * y) = Fma x y a
  1310. bool MergeMulAddArithmetic(IRContext* context, Instruction* inst,
  1311. const std::vector<const analysis::Constant*>&) {
  1312. assert(inst->opcode() == SpvOpFAdd);
  1313. if (!inst->IsFloatingPointFoldingAllowed()) {
  1314. return false;
  1315. }
  1316. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1317. for (int i = 0; i < 2; i++) {
  1318. uint32_t op_id = inst->GetSingleWordInOperand(i);
  1319. Instruction* op_inst = def_use_mgr->GetDef(op_id);
  1320. if (op_inst->opcode() != SpvOpFMul) {
  1321. continue;
  1322. }
  1323. if (!op_inst->IsFloatingPointFoldingAllowed()) {
  1324. continue;
  1325. }
  1326. uint32_t x = op_inst->GetSingleWordInOperand(0);
  1327. uint32_t y = op_inst->GetSingleWordInOperand(1);
  1328. uint32_t a = inst->GetSingleWordInOperand((i + 1) % 2);
  1329. ReplaceWithFma(inst, x, y, a);
  1330. return true;
  1331. }
  1332. return false;
  1333. }
  1334. // Replaces |sub| inplace with an FMA instruction |(x*y)+a| where |a| first gets
  1335. // negated if |negate_addition| is true, otherwise |x| gets negated.
  1336. void ReplaceWithFmaAndNegate(Instruction* sub, uint32_t x, uint32_t y,
  1337. uint32_t a, bool negate_addition) {
  1338. uint32_t ext =
  1339. sub->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  1340. if (ext == 0) {
  1341. sub->context()->AddExtInstImport("GLSL.std.450");
  1342. ext = sub->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  1343. assert(ext != 0 &&
  1344. "Could not add the GLSL.std.450 extended instruction set");
  1345. }
  1346. InstructionBuilder ir_builder(
  1347. sub->context(), sub,
  1348. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  1349. Instruction* neg = ir_builder.AddUnaryOp(sub->type_id(), SpvOpFNegate,
  1350. negate_addition ? a : x);
  1351. uint32_t neg_op = neg->result_id(); // -a : -x
  1352. std::vector<Operand> operands;
  1353. operands.push_back({SPV_OPERAND_TYPE_ID, {ext}});
  1354. operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {GLSLstd450Fma}});
  1355. operands.push_back({SPV_OPERAND_TYPE_ID, {negate_addition ? x : neg_op}});
  1356. operands.push_back({SPV_OPERAND_TYPE_ID, {y}});
  1357. operands.push_back({SPV_OPERAND_TYPE_ID, {negate_addition ? neg_op : a}});
  1358. sub->SetOpcode(SpvOpExtInst);
  1359. sub->SetInOperands(std::move(operands));
  1360. }
  1361. // Folds a multiply and subtract into an Fma and negation.
  1362. //
  1363. // Cases:
  1364. // (x * y) - a = Fma x y -a
  1365. // a - (x * y) = Fma -x y a
  1366. bool MergeMulSubArithmetic(IRContext* context, Instruction* sub,
  1367. const std::vector<const analysis::Constant*>&) {
  1368. assert(sub->opcode() == SpvOpFSub);
  1369. if (!sub->IsFloatingPointFoldingAllowed()) {
  1370. return false;
  1371. }
  1372. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1373. for (int i = 0; i < 2; i++) {
  1374. uint32_t op_id = sub->GetSingleWordInOperand(i);
  1375. Instruction* mul = def_use_mgr->GetDef(op_id);
  1376. if (mul->opcode() != SpvOpFMul) {
  1377. continue;
  1378. }
  1379. if (!mul->IsFloatingPointFoldingAllowed()) {
  1380. continue;
  1381. }
  1382. uint32_t x = mul->GetSingleWordInOperand(0);
  1383. uint32_t y = mul->GetSingleWordInOperand(1);
  1384. uint32_t a = sub->GetSingleWordInOperand((i + 1) % 2);
  1385. ReplaceWithFmaAndNegate(sub, x, y, a, i == 0);
  1386. return true;
  1387. }
  1388. return false;
  1389. }
  1390. FoldingRule IntMultipleBy1() {
  1391. return [](IRContext*, Instruction* inst,
  1392. const std::vector<const analysis::Constant*>& constants) {
  1393. assert(inst->opcode() == SpvOpIMul && "Wrong opcode. Should be OpIMul.");
  1394. for (uint32_t i = 0; i < 2; i++) {
  1395. if (constants[i] == nullptr) {
  1396. continue;
  1397. }
  1398. const analysis::IntConstant* int_constant = constants[i]->AsIntConstant();
  1399. if (int_constant) {
  1400. uint32_t width = ElementWidth(int_constant->type());
  1401. if (width != 32 && width != 64) return false;
  1402. bool is_one = (width == 32) ? int_constant->GetU32BitValue() == 1u
  1403. : int_constant->GetU64BitValue() == 1ull;
  1404. if (is_one) {
  1405. inst->SetOpcode(SpvOpCopyObject);
  1406. inst->SetInOperands(
  1407. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1 - i)}}});
  1408. return true;
  1409. }
  1410. }
  1411. }
  1412. return false;
  1413. };
  1414. }
  1415. // Returns the number of elements that the |index|th in operand in |inst|
  1416. // contributes to the result of |inst|. |inst| must be an
  1417. // OpCompositeConstructInstruction.
  1418. uint32_t GetNumOfElementsContributedByOperand(IRContext* context,
  1419. const Instruction* inst,
  1420. uint32_t index) {
  1421. assert(inst->opcode() == SpvOpCompositeConstruct);
  1422. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1423. analysis::TypeManager* type_mgr = context->get_type_mgr();
  1424. analysis::Vector* result_type =
  1425. type_mgr->GetType(inst->type_id())->AsVector();
  1426. if (result_type == nullptr) {
  1427. // If the result of the OpCompositeConstruct is not a vector then every
  1428. // operands corresponds to a single element in the result.
  1429. return 1;
  1430. }
  1431. // If the result type is a vector then the operands are either scalars or
  1432. // vectors. If it is a scalar, then it corresponds to a single element. If it
  1433. // is a vector, then each element in the vector will be an element in the
  1434. // result.
  1435. uint32_t id = inst->GetSingleWordInOperand(index);
  1436. Instruction* def = def_use_mgr->GetDef(id);
  1437. analysis::Vector* type = type_mgr->GetType(def->type_id())->AsVector();
  1438. if (type == nullptr) {
  1439. return 1;
  1440. }
  1441. return type->element_count();
  1442. }
  1443. // Returns the in-operands for an OpCompositeExtract instruction that are needed
  1444. // to extract the |result_index|th element in the result of |inst| without using
  1445. // the result of |inst|. Returns the empty vector if |result_index| is
  1446. // out-of-bounds. |inst| must be an |OpCompositeConstruct| instruction.
  1447. std::vector<Operand> GetExtractOperandsForElementOfCompositeConstruct(
  1448. IRContext* context, const Instruction* inst, uint32_t result_index) {
  1449. assert(inst->opcode() == SpvOpCompositeConstruct);
  1450. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1451. analysis::TypeManager* type_mgr = context->get_type_mgr();
  1452. analysis::Type* result_type = type_mgr->GetType(inst->type_id());
  1453. if (result_type->AsVector() == nullptr) {
  1454. uint32_t id = inst->GetSingleWordInOperand(result_index);
  1455. return {Operand(SPV_OPERAND_TYPE_ID, {id})};
  1456. }
  1457. // If the result type is a vector, then vector operands are concatenated.
  1458. uint32_t total_element_count = 0;
  1459. for (uint32_t idx = 0; idx < inst->NumInOperands(); ++idx) {
  1460. uint32_t element_count =
  1461. GetNumOfElementsContributedByOperand(context, inst, idx);
  1462. total_element_count += element_count;
  1463. if (result_index < total_element_count) {
  1464. std::vector<Operand> operands;
  1465. uint32_t id = inst->GetSingleWordInOperand(idx);
  1466. Instruction* operand_def = def_use_mgr->GetDef(id);
  1467. analysis::Type* operand_type = type_mgr->GetType(operand_def->type_id());
  1468. operands.push_back({SPV_OPERAND_TYPE_ID, {id}});
  1469. if (operand_type->AsVector()) {
  1470. uint32_t start_index_of_id = total_element_count - element_count;
  1471. uint32_t index_into_id = result_index - start_index_of_id;
  1472. operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {index_into_id}});
  1473. }
  1474. return operands;
  1475. }
  1476. }
  1477. return {};
  1478. }
  1479. bool CompositeConstructFeedingExtract(
  1480. IRContext* context, Instruction* inst,
  1481. const std::vector<const analysis::Constant*>&) {
  1482. // If the input to an OpCompositeExtract is an OpCompositeConstruct,
  1483. // then we can simply use the appropriate element in the construction.
  1484. assert(inst->opcode() == SpvOpCompositeExtract &&
  1485. "Wrong opcode. Should be OpCompositeExtract.");
  1486. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1487. // If there are no index operands, then this rule cannot do anything.
  1488. if (inst->NumInOperands() <= 1) {
  1489. return false;
  1490. }
  1491. uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1492. Instruction* cinst = def_use_mgr->GetDef(cid);
  1493. if (cinst->opcode() != SpvOpCompositeConstruct) {
  1494. return false;
  1495. }
  1496. uint32_t index_into_result = inst->GetSingleWordInOperand(1);
  1497. std::vector<Operand> operands =
  1498. GetExtractOperandsForElementOfCompositeConstruct(context, cinst,
  1499. index_into_result);
  1500. if (operands.empty()) {
  1501. return false;
  1502. }
  1503. // Add the remaining indices for extraction.
  1504. for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
  1505. operands.push_back(
  1506. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {inst->GetSingleWordInOperand(i)}});
  1507. }
  1508. if (operands.size() == 1) {
  1509. // If there were no extra indices, then we have the final object. No need
  1510. // to extract any more.
  1511. inst->SetOpcode(SpvOpCopyObject);
  1512. }
  1513. inst->SetInOperands(std::move(operands));
  1514. return true;
  1515. }
  1516. // Walks the indexes chain from |start| to |end| of an OpCompositeInsert or
  1517. // OpCompositeExtract instruction, and returns the type of the final element
  1518. // being accessed.
  1519. const analysis::Type* GetElementType(uint32_t type_id,
  1520. Instruction::iterator start,
  1521. Instruction::iterator end,
  1522. const analysis::TypeManager* type_mgr) {
  1523. const analysis::Type* type = type_mgr->GetType(type_id);
  1524. for (auto index : make_range(std::move(start), std::move(end))) {
  1525. assert(index.type == SPV_OPERAND_TYPE_LITERAL_INTEGER &&
  1526. index.words.size() == 1);
  1527. if (auto* array_type = type->AsArray()) {
  1528. type = array_type->element_type();
  1529. } else if (auto* matrix_type = type->AsMatrix()) {
  1530. type = matrix_type->element_type();
  1531. } else if (auto* struct_type = type->AsStruct()) {
  1532. type = struct_type->element_types()[index.words[0]];
  1533. } else {
  1534. type = nullptr;
  1535. }
  1536. }
  1537. return type;
  1538. }
  1539. // Returns true of |inst_1| and |inst_2| have the same indexes that will be used
  1540. // to index into a composite object, excluding the last index. The two
  1541. // instructions must have the same opcode, and be either OpCompositeExtract or
  1542. // OpCompositeInsert instructions.
  1543. bool HaveSameIndexesExceptForLast(Instruction* inst_1, Instruction* inst_2) {
  1544. assert(inst_1->opcode() == inst_2->opcode() &&
  1545. "Expecting the opcodes to be the same.");
  1546. assert((inst_1->opcode() == SpvOpCompositeInsert ||
  1547. inst_1->opcode() == SpvOpCompositeExtract) &&
  1548. "Instructions must be OpCompositeInsert or OpCompositeExtract.");
  1549. if (inst_1->NumInOperands() != inst_2->NumInOperands()) {
  1550. return false;
  1551. }
  1552. uint32_t first_index_position =
  1553. (inst_1->opcode() == SpvOpCompositeInsert ? 2 : 1);
  1554. for (uint32_t i = first_index_position; i < inst_1->NumInOperands() - 1;
  1555. i++) {
  1556. if (inst_1->GetSingleWordInOperand(i) !=
  1557. inst_2->GetSingleWordInOperand(i)) {
  1558. return false;
  1559. }
  1560. }
  1561. return true;
  1562. }
  1563. // If the OpCompositeConstruct is simply putting back together elements that
  1564. // where extracted from the same source, we can simply reuse the source.
  1565. //
  1566. // This is a common code pattern because of the way that scalar replacement
  1567. // works.
  1568. bool CompositeExtractFeedingConstruct(
  1569. IRContext* context, Instruction* inst,
  1570. const std::vector<const analysis::Constant*>&) {
  1571. assert(inst->opcode() == SpvOpCompositeConstruct &&
  1572. "Wrong opcode. Should be OpCompositeConstruct.");
  1573. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1574. uint32_t original_id = 0;
  1575. if (inst->NumInOperands() == 0) {
  1576. // The struct being constructed has no members.
  1577. return false;
  1578. }
  1579. // Check each element to make sure they are:
  1580. // - extractions
  1581. // - extracting the same position they are inserting
  1582. // - all extract from the same id.
  1583. Instruction* first_element_inst = nullptr;
  1584. for (uint32_t i = 0; i < inst->NumInOperands(); ++i) {
  1585. const uint32_t element_id = inst->GetSingleWordInOperand(i);
  1586. Instruction* element_inst = def_use_mgr->GetDef(element_id);
  1587. if (first_element_inst == nullptr) {
  1588. first_element_inst = element_inst;
  1589. }
  1590. if (element_inst->opcode() != SpvOpCompositeExtract) {
  1591. return false;
  1592. }
  1593. if (!HaveSameIndexesExceptForLast(element_inst, first_element_inst)) {
  1594. return false;
  1595. }
  1596. if (element_inst->GetSingleWordInOperand(element_inst->NumInOperands() -
  1597. 1) != i) {
  1598. return false;
  1599. }
  1600. if (i == 0) {
  1601. original_id =
  1602. element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1603. } else if (original_id !=
  1604. element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx)) {
  1605. return false;
  1606. }
  1607. }
  1608. // The last check it to see that the object being extracted from is the
  1609. // correct type.
  1610. Instruction* original_inst = def_use_mgr->GetDef(original_id);
  1611. analysis::TypeManager* type_mgr = context->get_type_mgr();
  1612. const analysis::Type* original_type =
  1613. GetElementType(original_inst->type_id(), first_element_inst->begin() + 3,
  1614. first_element_inst->end() - 1, type_mgr);
  1615. if (original_type == nullptr) {
  1616. return false;
  1617. }
  1618. if (inst->type_id() != type_mgr->GetId(original_type)) {
  1619. return false;
  1620. }
  1621. if (first_element_inst->NumInOperands() == 2) {
  1622. // Simplify by using the original object.
  1623. inst->SetOpcode(SpvOpCopyObject);
  1624. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}});
  1625. return true;
  1626. }
  1627. // Copies the original id and all indexes except for the last to the new
  1628. // extract instruction.
  1629. inst->SetOpcode(SpvOpCompositeExtract);
  1630. inst->SetInOperands(std::vector<Operand>(first_element_inst->begin() + 2,
  1631. first_element_inst->end() - 1));
  1632. return true;
  1633. }
  1634. FoldingRule InsertFeedingExtract() {
  1635. return [](IRContext* context, Instruction* inst,
  1636. const std::vector<const analysis::Constant*>&) {
  1637. assert(inst->opcode() == SpvOpCompositeExtract &&
  1638. "Wrong opcode. Should be OpCompositeExtract.");
  1639. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1640. uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1641. Instruction* cinst = def_use_mgr->GetDef(cid);
  1642. if (cinst->opcode() != SpvOpCompositeInsert) {
  1643. return false;
  1644. }
  1645. // Find the first position where the list of insert and extract indicies
  1646. // differ, if at all.
  1647. uint32_t i;
  1648. for (i = 1; i < inst->NumInOperands(); ++i) {
  1649. if (i + 1 >= cinst->NumInOperands()) {
  1650. break;
  1651. }
  1652. if (inst->GetSingleWordInOperand(i) !=
  1653. cinst->GetSingleWordInOperand(i + 1)) {
  1654. break;
  1655. }
  1656. }
  1657. // We are extracting the element that was inserted.
  1658. if (i == inst->NumInOperands() && i + 1 == cinst->NumInOperands()) {
  1659. inst->SetOpcode(SpvOpCopyObject);
  1660. inst->SetInOperands(
  1661. {{SPV_OPERAND_TYPE_ID,
  1662. {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}}});
  1663. return true;
  1664. }
  1665. // Extracting the value that was inserted along with values for the base
  1666. // composite. Cannot do anything.
  1667. if (i == inst->NumInOperands()) {
  1668. return false;
  1669. }
  1670. // Extracting an element of the value that was inserted. Extract from
  1671. // that value directly.
  1672. if (i + 1 == cinst->NumInOperands()) {
  1673. std::vector<Operand> operands;
  1674. operands.push_back(
  1675. {SPV_OPERAND_TYPE_ID,
  1676. {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}});
  1677. for (; i < inst->NumInOperands(); ++i) {
  1678. operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
  1679. {inst->GetSingleWordInOperand(i)}});
  1680. }
  1681. inst->SetInOperands(std::move(operands));
  1682. return true;
  1683. }
  1684. // Extracting a value that is disjoint from the element being inserted.
  1685. // Rewrite the extract to use the composite input to the insert.
  1686. std::vector<Operand> operands;
  1687. operands.push_back(
  1688. {SPV_OPERAND_TYPE_ID,
  1689. {cinst->GetSingleWordInOperand(kInsertCompositeIdInIdx)}});
  1690. for (i = 1; i < inst->NumInOperands(); ++i) {
  1691. operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
  1692. {inst->GetSingleWordInOperand(i)}});
  1693. }
  1694. inst->SetInOperands(std::move(operands));
  1695. return true;
  1696. };
  1697. }
  1698. // When a VectorShuffle is feeding an Extract, we can extract from one of the
  1699. // operands of the VectorShuffle. We just need to adjust the index in the
  1700. // extract instruction.
  1701. FoldingRule VectorShuffleFeedingExtract() {
  1702. return [](IRContext* context, Instruction* inst,
  1703. const std::vector<const analysis::Constant*>&) {
  1704. assert(inst->opcode() == SpvOpCompositeExtract &&
  1705. "Wrong opcode. Should be OpCompositeExtract.");
  1706. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1707. analysis::TypeManager* type_mgr = context->get_type_mgr();
  1708. uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1709. Instruction* cinst = def_use_mgr->GetDef(cid);
  1710. if (cinst->opcode() != SpvOpVectorShuffle) {
  1711. return false;
  1712. }
  1713. // Find the size of the first vector operand of the VectorShuffle
  1714. Instruction* first_input =
  1715. def_use_mgr->GetDef(cinst->GetSingleWordInOperand(0));
  1716. analysis::Type* first_input_type =
  1717. type_mgr->GetType(first_input->type_id());
  1718. assert(first_input_type->AsVector() &&
  1719. "Input to vector shuffle should be vectors.");
  1720. uint32_t first_input_size = first_input_type->AsVector()->element_count();
  1721. // Get index of the element the vector shuffle is placing in the position
  1722. // being extracted.
  1723. uint32_t new_index =
  1724. cinst->GetSingleWordInOperand(2 + inst->GetSingleWordInOperand(1));
  1725. // Extracting an undefined value so fold this extract into an undef.
  1726. const uint32_t undef_literal_value = 0xffffffff;
  1727. if (new_index == undef_literal_value) {
  1728. inst->SetOpcode(SpvOpUndef);
  1729. inst->SetInOperands({});
  1730. return true;
  1731. }
  1732. // Get the id of the of the vector the elemtent comes from, and update the
  1733. // index if needed.
  1734. uint32_t new_vector = 0;
  1735. if (new_index < first_input_size) {
  1736. new_vector = cinst->GetSingleWordInOperand(0);
  1737. } else {
  1738. new_vector = cinst->GetSingleWordInOperand(1);
  1739. new_index -= first_input_size;
  1740. }
  1741. // Update the extract instruction.
  1742. inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
  1743. inst->SetInOperand(1, {new_index});
  1744. return true;
  1745. };
  1746. }
  1747. // When an FMix with is feeding an Extract that extracts an element whose
  1748. // corresponding |a| in the FMix is 0 or 1, we can extract from one of the
  1749. // operands of the FMix.
  1750. FoldingRule FMixFeedingExtract() {
  1751. return [](IRContext* context, Instruction* inst,
  1752. const std::vector<const analysis::Constant*>&) {
  1753. assert(inst->opcode() == SpvOpCompositeExtract &&
  1754. "Wrong opcode. Should be OpCompositeExtract.");
  1755. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1756. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  1757. uint32_t composite_id =
  1758. inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1759. Instruction* composite_inst = def_use_mgr->GetDef(composite_id);
  1760. if (composite_inst->opcode() != SpvOpExtInst) {
  1761. return false;
  1762. }
  1763. uint32_t inst_set_id =
  1764. context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  1765. if (composite_inst->GetSingleWordInOperand(kExtInstSetIdInIdx) !=
  1766. inst_set_id ||
  1767. composite_inst->GetSingleWordInOperand(kExtInstInstructionInIdx) !=
  1768. GLSLstd450FMix) {
  1769. return false;
  1770. }
  1771. // Get the |a| for the FMix instruction.
  1772. uint32_t a_id = composite_inst->GetSingleWordInOperand(kFMixAIdInIdx);
  1773. std::unique_ptr<Instruction> a(inst->Clone(context));
  1774. a->SetInOperand(kExtractCompositeIdInIdx, {a_id});
  1775. context->get_instruction_folder().FoldInstruction(a.get());
  1776. if (a->opcode() != SpvOpCopyObject) {
  1777. return false;
  1778. }
  1779. const analysis::Constant* a_const =
  1780. const_mgr->FindDeclaredConstant(a->GetSingleWordInOperand(0));
  1781. if (!a_const) {
  1782. return false;
  1783. }
  1784. bool use_x = false;
  1785. assert(a_const->type()->AsFloat());
  1786. double element_value = a_const->GetValueAsDouble();
  1787. if (element_value == 0.0) {
  1788. use_x = true;
  1789. } else if (element_value == 1.0) {
  1790. use_x = false;
  1791. } else {
  1792. return false;
  1793. }
  1794. // Get the id of the of the vector the element comes from.
  1795. uint32_t new_vector = 0;
  1796. if (use_x) {
  1797. new_vector = composite_inst->GetSingleWordInOperand(kFMixXIdInIdx);
  1798. } else {
  1799. new_vector = composite_inst->GetSingleWordInOperand(kFMixYIdInIdx);
  1800. }
  1801. // Update the extract instruction.
  1802. inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
  1803. return true;
  1804. };
  1805. }
  1806. // Returns the number of elements in the composite type |type|. Returns 0 if
  1807. // |type| is a scalar value.
  1808. uint32_t GetNumberOfElements(const analysis::Type* type) {
  1809. if (auto* vector_type = type->AsVector()) {
  1810. return vector_type->element_count();
  1811. }
  1812. if (auto* matrix_type = type->AsMatrix()) {
  1813. return matrix_type->element_count();
  1814. }
  1815. if (auto* struct_type = type->AsStruct()) {
  1816. return static_cast<uint32_t>(struct_type->element_types().size());
  1817. }
  1818. if (auto* array_type = type->AsArray()) {
  1819. return array_type->length_info().words[0];
  1820. }
  1821. return 0;
  1822. }
  1823. // Returns a map with the set of values that were inserted into an object by
  1824. // the chain of OpCompositeInsertInstruction starting with |inst|.
  1825. // The map will map the index to the value inserted at that index.
  1826. std::map<uint32_t, uint32_t> GetInsertedValues(Instruction* inst) {
  1827. analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr();
  1828. std::map<uint32_t, uint32_t> values_inserted;
  1829. Instruction* current_inst = inst;
  1830. while (current_inst->opcode() == SpvOpCompositeInsert) {
  1831. if (current_inst->NumInOperands() > inst->NumInOperands()) {
  1832. // This is the catch the case
  1833. // %2 = OpCompositeInsert %m2x2int %v2int_1_0 %m2x2int_undef 0
  1834. // %3 = OpCompositeInsert %m2x2int %int_4 %2 0 0
  1835. // %4 = OpCompositeInsert %m2x2int %v2int_2_3 %3 1
  1836. // In this case we cannot do a single construct to get the matrix.
  1837. uint32_t partially_inserted_element_index =
  1838. current_inst->GetSingleWordInOperand(inst->NumInOperands() - 1);
  1839. if (values_inserted.count(partially_inserted_element_index) == 0)
  1840. return {};
  1841. }
  1842. if (HaveSameIndexesExceptForLast(inst, current_inst)) {
  1843. values_inserted.insert(
  1844. {current_inst->GetSingleWordInOperand(current_inst->NumInOperands() -
  1845. 1),
  1846. current_inst->GetSingleWordInOperand(kInsertObjectIdInIdx)});
  1847. }
  1848. current_inst = def_use_mgr->GetDef(
  1849. current_inst->GetSingleWordInOperand(kInsertCompositeIdInIdx));
  1850. }
  1851. return values_inserted;
  1852. }
  1853. // Returns true of there is an entry in |values_inserted| for every element of
  1854. // |Type|.
  1855. bool DoInsertedValuesCoverEntireObject(
  1856. const analysis::Type* type, std::map<uint32_t, uint32_t>& values_inserted) {
  1857. uint32_t container_size = GetNumberOfElements(type);
  1858. if (container_size != values_inserted.size()) {
  1859. return false;
  1860. }
  1861. if (values_inserted.rbegin()->first >= container_size) {
  1862. return false;
  1863. }
  1864. return true;
  1865. }
  1866. // Returns the type of the element that immediately contains the element being
  1867. // inserted by the OpCompositeInsert instruction |inst|.
  1868. const analysis::Type* GetContainerType(Instruction* inst) {
  1869. assert(inst->opcode() == SpvOpCompositeInsert);
  1870. analysis::TypeManager* type_mgr = inst->context()->get_type_mgr();
  1871. return GetElementType(inst->type_id(), inst->begin() + 4, inst->end() - 1,
  1872. type_mgr);
  1873. }
  1874. // Returns an OpCompositeConstruct instruction that build an object with
  1875. // |type_id| out of the values in |values_inserted|. Each value will be
  1876. // placed at the index corresponding to the value. The new instruction will
  1877. // be placed before |insert_before|.
  1878. Instruction* BuildCompositeConstruct(
  1879. uint32_t type_id, const std::map<uint32_t, uint32_t>& values_inserted,
  1880. Instruction* insert_before) {
  1881. InstructionBuilder ir_builder(
  1882. insert_before->context(), insert_before,
  1883. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  1884. std::vector<uint32_t> ids_in_order;
  1885. for (auto it : values_inserted) {
  1886. ids_in_order.push_back(it.second);
  1887. }
  1888. Instruction* construct =
  1889. ir_builder.AddCompositeConstruct(type_id, ids_in_order);
  1890. return construct;
  1891. }
  1892. // Replaces the OpCompositeInsert |inst| that inserts |construct| into the same
  1893. // object as |inst| with final index removed. If the resulting
  1894. // OpCompositeInsert instruction would have no remaining indexes, the
  1895. // instruction is replaced with an OpCopyObject instead.
  1896. void InsertConstructedObject(Instruction* inst, const Instruction* construct) {
  1897. if (inst->NumInOperands() == 3) {
  1898. inst->SetOpcode(SpvOpCopyObject);
  1899. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {construct->result_id()}}});
  1900. } else {
  1901. inst->SetInOperand(kInsertObjectIdInIdx, {construct->result_id()});
  1902. inst->RemoveOperand(inst->NumOperands() - 1);
  1903. }
  1904. }
  1905. // Replaces a series of |OpCompositeInsert| instruction that cover the entire
  1906. // object with an |OpCompositeConstruct|.
  1907. bool CompositeInsertToCompositeConstruct(
  1908. IRContext* context, Instruction* inst,
  1909. const std::vector<const analysis::Constant*>&) {
  1910. assert(inst->opcode() == SpvOpCompositeInsert &&
  1911. "Wrong opcode. Should be OpCompositeInsert.");
  1912. if (inst->NumInOperands() < 3) return false;
  1913. std::map<uint32_t, uint32_t> values_inserted = GetInsertedValues(inst);
  1914. const analysis::Type* container_type = GetContainerType(inst);
  1915. if (container_type == nullptr) {
  1916. return false;
  1917. }
  1918. if (!DoInsertedValuesCoverEntireObject(container_type, values_inserted)) {
  1919. return false;
  1920. }
  1921. analysis::TypeManager* type_mgr = context->get_type_mgr();
  1922. Instruction* construct = BuildCompositeConstruct(
  1923. type_mgr->GetId(container_type), values_inserted, inst);
  1924. InsertConstructedObject(inst, construct);
  1925. return true;
  1926. }
  1927. FoldingRule RedundantPhi() {
  1928. // An OpPhi instruction where all values are the same or the result of the phi
  1929. // itself, can be replaced by the value itself.
  1930. return [](IRContext*, Instruction* inst,
  1931. const std::vector<const analysis::Constant*>&) {
  1932. assert(inst->opcode() == SpvOpPhi && "Wrong opcode. Should be OpPhi.");
  1933. uint32_t incoming_value = 0;
  1934. for (uint32_t i = 0; i < inst->NumInOperands(); i += 2) {
  1935. uint32_t op_id = inst->GetSingleWordInOperand(i);
  1936. if (op_id == inst->result_id()) {
  1937. continue;
  1938. }
  1939. if (incoming_value == 0) {
  1940. incoming_value = op_id;
  1941. } else if (op_id != incoming_value) {
  1942. // Found two possible value. Can't simplify.
  1943. return false;
  1944. }
  1945. }
  1946. if (incoming_value == 0) {
  1947. // Code looks invalid. Don't do anything.
  1948. return false;
  1949. }
  1950. // We have a single incoming value. Simplify using that value.
  1951. inst->SetOpcode(SpvOpCopyObject);
  1952. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {incoming_value}}});
  1953. return true;
  1954. };
  1955. }
  1956. FoldingRule BitCastScalarOrVector() {
  1957. return [](IRContext* context, Instruction* inst,
  1958. const std::vector<const analysis::Constant*>& constants) {
  1959. assert(inst->opcode() == SpvOpBitcast && constants.size() == 1);
  1960. if (constants[0] == nullptr) return false;
  1961. const analysis::Type* type =
  1962. context->get_type_mgr()->GetType(inst->type_id());
  1963. if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
  1964. return false;
  1965. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  1966. std::vector<uint32_t> words =
  1967. GetWordsFromNumericScalarOrVectorConstant(const_mgr, constants[0]);
  1968. if (words.size() == 0) return false;
  1969. const analysis::Constant* bitcasted_constant =
  1970. ConvertWordsToNumericScalarOrVectorConstant(const_mgr, words, type);
  1971. if (!bitcasted_constant) return false;
  1972. auto new_feeder_id =
  1973. const_mgr->GetDefiningInstruction(bitcasted_constant, inst->type_id())
  1974. ->result_id();
  1975. inst->SetOpcode(SpvOpCopyObject);
  1976. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {new_feeder_id}}});
  1977. return true;
  1978. };
  1979. }
  1980. FoldingRule RedundantSelect() {
  1981. // An OpSelect instruction where both values are the same or the condition is
  1982. // constant can be replaced by one of the values
  1983. return [](IRContext*, Instruction* inst,
  1984. const std::vector<const analysis::Constant*>& constants) {
  1985. assert(inst->opcode() == SpvOpSelect &&
  1986. "Wrong opcode. Should be OpSelect.");
  1987. assert(inst->NumInOperands() == 3);
  1988. assert(constants.size() == 3);
  1989. uint32_t true_id = inst->GetSingleWordInOperand(1);
  1990. uint32_t false_id = inst->GetSingleWordInOperand(2);
  1991. if (true_id == false_id) {
  1992. // Both results are the same, condition doesn't matter
  1993. inst->SetOpcode(SpvOpCopyObject);
  1994. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
  1995. return true;
  1996. } else if (constants[0]) {
  1997. const analysis::Type* type = constants[0]->type();
  1998. if (type->AsBool()) {
  1999. // Scalar constant value, select the corresponding value.
  2000. inst->SetOpcode(SpvOpCopyObject);
  2001. if (constants[0]->AsNullConstant() ||
  2002. !constants[0]->AsBoolConstant()->value()) {
  2003. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
  2004. } else {
  2005. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
  2006. }
  2007. return true;
  2008. } else {
  2009. assert(type->AsVector());
  2010. if (constants[0]->AsNullConstant()) {
  2011. // All values come from false id.
  2012. inst->SetOpcode(SpvOpCopyObject);
  2013. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
  2014. return true;
  2015. } else {
  2016. // Convert to a vector shuffle.
  2017. std::vector<Operand> ops;
  2018. ops.push_back({SPV_OPERAND_TYPE_ID, {true_id}});
  2019. ops.push_back({SPV_OPERAND_TYPE_ID, {false_id}});
  2020. const analysis::VectorConstant* vector_const =
  2021. constants[0]->AsVectorConstant();
  2022. uint32_t size =
  2023. static_cast<uint32_t>(vector_const->GetComponents().size());
  2024. for (uint32_t i = 0; i != size; ++i) {
  2025. const analysis::Constant* component =
  2026. vector_const->GetComponents()[i];
  2027. if (component->AsNullConstant() ||
  2028. !component->AsBoolConstant()->value()) {
  2029. // Selecting from the false vector which is the second input
  2030. // vector to the shuffle. Offset the index by |size|.
  2031. ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i + size}});
  2032. } else {
  2033. // Selecting from true vector which is the first input vector to
  2034. // the shuffle.
  2035. ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}});
  2036. }
  2037. }
  2038. inst->SetOpcode(SpvOpVectorShuffle);
  2039. inst->SetInOperands(std::move(ops));
  2040. return true;
  2041. }
  2042. }
  2043. }
  2044. return false;
  2045. };
  2046. }
  2047. enum class FloatConstantKind { Unknown, Zero, One };
  2048. FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) {
  2049. if (constant == nullptr) {
  2050. return FloatConstantKind::Unknown;
  2051. }
  2052. assert(HasFloatingPoint(constant->type()) && "Unexpected constant type");
  2053. if (constant->AsNullConstant()) {
  2054. return FloatConstantKind::Zero;
  2055. } else if (const analysis::VectorConstant* vc =
  2056. constant->AsVectorConstant()) {
  2057. const std::vector<const analysis::Constant*>& components =
  2058. vc->GetComponents();
  2059. assert(!components.empty());
  2060. FloatConstantKind kind = getFloatConstantKind(components[0]);
  2061. for (size_t i = 1; i < components.size(); ++i) {
  2062. if (getFloatConstantKind(components[i]) != kind) {
  2063. return FloatConstantKind::Unknown;
  2064. }
  2065. }
  2066. return kind;
  2067. } else if (const analysis::FloatConstant* fc = constant->AsFloatConstant()) {
  2068. if (fc->IsZero()) return FloatConstantKind::Zero;
  2069. uint32_t width = fc->type()->AsFloat()->width();
  2070. if (width != 32 && width != 64) return FloatConstantKind::Unknown;
  2071. double value = (width == 64) ? fc->GetDoubleValue() : fc->GetFloatValue();
  2072. if (value == 0.0) {
  2073. return FloatConstantKind::Zero;
  2074. } else if (value == 1.0) {
  2075. return FloatConstantKind::One;
  2076. } else {
  2077. return FloatConstantKind::Unknown;
  2078. }
  2079. } else {
  2080. return FloatConstantKind::Unknown;
  2081. }
  2082. }
  2083. FoldingRule RedundantFAdd() {
  2084. return [](IRContext*, Instruction* inst,
  2085. const std::vector<const analysis::Constant*>& constants) {
  2086. assert(inst->opcode() == SpvOpFAdd && "Wrong opcode. Should be OpFAdd.");
  2087. assert(constants.size() == 2);
  2088. if (!inst->IsFloatingPointFoldingAllowed()) {
  2089. return false;
  2090. }
  2091. FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
  2092. FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
  2093. if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
  2094. inst->SetOpcode(SpvOpCopyObject);
  2095. inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
  2096. {inst->GetSingleWordInOperand(
  2097. kind0 == FloatConstantKind::Zero ? 1 : 0)}}});
  2098. return true;
  2099. }
  2100. return false;
  2101. };
  2102. }
  2103. FoldingRule RedundantFSub() {
  2104. return [](IRContext*, Instruction* inst,
  2105. const std::vector<const analysis::Constant*>& constants) {
  2106. assert(inst->opcode() == SpvOpFSub && "Wrong opcode. Should be OpFSub.");
  2107. assert(constants.size() == 2);
  2108. if (!inst->IsFloatingPointFoldingAllowed()) {
  2109. return false;
  2110. }
  2111. FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
  2112. FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
  2113. if (kind0 == FloatConstantKind::Zero) {
  2114. inst->SetOpcode(SpvOpFNegate);
  2115. inst->SetInOperands(
  2116. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1)}}});
  2117. return true;
  2118. }
  2119. if (kind1 == FloatConstantKind::Zero) {
  2120. inst->SetOpcode(SpvOpCopyObject);
  2121. inst->SetInOperands(
  2122. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
  2123. return true;
  2124. }
  2125. return false;
  2126. };
  2127. }
  2128. FoldingRule RedundantFMul() {
  2129. return [](IRContext*, Instruction* inst,
  2130. const std::vector<const analysis::Constant*>& constants) {
  2131. assert(inst->opcode() == SpvOpFMul && "Wrong opcode. Should be OpFMul.");
  2132. assert(constants.size() == 2);
  2133. if (!inst->IsFloatingPointFoldingAllowed()) {
  2134. return false;
  2135. }
  2136. FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
  2137. FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
  2138. if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
  2139. inst->SetOpcode(SpvOpCopyObject);
  2140. inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
  2141. {inst->GetSingleWordInOperand(
  2142. kind0 == FloatConstantKind::Zero ? 0 : 1)}}});
  2143. return true;
  2144. }
  2145. if (kind0 == FloatConstantKind::One || kind1 == FloatConstantKind::One) {
  2146. inst->SetOpcode(SpvOpCopyObject);
  2147. inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
  2148. {inst->GetSingleWordInOperand(
  2149. kind0 == FloatConstantKind::One ? 1 : 0)}}});
  2150. return true;
  2151. }
  2152. return false;
  2153. };
  2154. }
  2155. FoldingRule RedundantFDiv() {
  2156. return [](IRContext*, Instruction* inst,
  2157. const std::vector<const analysis::Constant*>& constants) {
  2158. assert(inst->opcode() == SpvOpFDiv && "Wrong opcode. Should be OpFDiv.");
  2159. assert(constants.size() == 2);
  2160. if (!inst->IsFloatingPointFoldingAllowed()) {
  2161. return false;
  2162. }
  2163. FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
  2164. FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
  2165. if (kind0 == FloatConstantKind::Zero) {
  2166. inst->SetOpcode(SpvOpCopyObject);
  2167. inst->SetInOperands(
  2168. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
  2169. return true;
  2170. }
  2171. if (kind1 == FloatConstantKind::One) {
  2172. inst->SetOpcode(SpvOpCopyObject);
  2173. inst->SetInOperands(
  2174. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
  2175. return true;
  2176. }
  2177. return false;
  2178. };
  2179. }
  2180. FoldingRule RedundantFMix() {
  2181. return [](IRContext* context, Instruction* inst,
  2182. const std::vector<const analysis::Constant*>& constants) {
  2183. assert(inst->opcode() == SpvOpExtInst &&
  2184. "Wrong opcode. Should be OpExtInst.");
  2185. if (!inst->IsFloatingPointFoldingAllowed()) {
  2186. return false;
  2187. }
  2188. uint32_t instSetId =
  2189. context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  2190. if (inst->GetSingleWordInOperand(kExtInstSetIdInIdx) == instSetId &&
  2191. inst->GetSingleWordInOperand(kExtInstInstructionInIdx) ==
  2192. GLSLstd450FMix) {
  2193. assert(constants.size() == 5);
  2194. FloatConstantKind kind4 = getFloatConstantKind(constants[4]);
  2195. if (kind4 == FloatConstantKind::Zero || kind4 == FloatConstantKind::One) {
  2196. inst->SetOpcode(SpvOpCopyObject);
  2197. inst->SetInOperands(
  2198. {{SPV_OPERAND_TYPE_ID,
  2199. {inst->GetSingleWordInOperand(kind4 == FloatConstantKind::Zero
  2200. ? kFMixXIdInIdx
  2201. : kFMixYIdInIdx)}}});
  2202. return true;
  2203. }
  2204. }
  2205. return false;
  2206. };
  2207. }
  2208. // This rule handles addition of zero for integers.
  2209. FoldingRule RedundantIAdd() {
  2210. return [](IRContext* context, Instruction* inst,
  2211. const std::vector<const analysis::Constant*>& constants) {
  2212. assert(inst->opcode() == SpvOpIAdd && "Wrong opcode. Should be OpIAdd.");
  2213. uint32_t operand = std::numeric_limits<uint32_t>::max();
  2214. const analysis::Type* operand_type = nullptr;
  2215. if (constants[0] && constants[0]->IsZero()) {
  2216. operand = inst->GetSingleWordInOperand(1);
  2217. operand_type = constants[0]->type();
  2218. } else if (constants[1] && constants[1]->IsZero()) {
  2219. operand = inst->GetSingleWordInOperand(0);
  2220. operand_type = constants[1]->type();
  2221. }
  2222. if (operand != std::numeric_limits<uint32_t>::max()) {
  2223. const analysis::Type* inst_type =
  2224. context->get_type_mgr()->GetType(inst->type_id());
  2225. if (inst_type->IsSame(operand_type)) {
  2226. inst->SetOpcode(SpvOpCopyObject);
  2227. } else {
  2228. inst->SetOpcode(SpvOpBitcast);
  2229. }
  2230. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {operand}}});
  2231. return true;
  2232. }
  2233. return false;
  2234. };
  2235. }
  2236. // This rule look for a dot with a constant vector containing a single 1 and
  2237. // the rest 0s. This is the same as doing an extract.
  2238. FoldingRule DotProductDoingExtract() {
  2239. return [](IRContext* context, Instruction* inst,
  2240. const std::vector<const analysis::Constant*>& constants) {
  2241. assert(inst->opcode() == SpvOpDot && "Wrong opcode. Should be OpDot.");
  2242. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  2243. if (!inst->IsFloatingPointFoldingAllowed()) {
  2244. return false;
  2245. }
  2246. for (int i = 0; i < 2; ++i) {
  2247. if (!constants[i]) {
  2248. continue;
  2249. }
  2250. const analysis::Vector* vector_type = constants[i]->type()->AsVector();
  2251. assert(vector_type && "Inputs to OpDot must be vectors.");
  2252. const analysis::Float* element_type =
  2253. vector_type->element_type()->AsFloat();
  2254. assert(element_type && "Inputs to OpDot must be vectors of floats.");
  2255. uint32_t element_width = element_type->width();
  2256. if (element_width != 32 && element_width != 64) {
  2257. return false;
  2258. }
  2259. std::vector<const analysis::Constant*> components;
  2260. components = constants[i]->GetVectorComponents(const_mgr);
  2261. const uint32_t kNotFound = std::numeric_limits<uint32_t>::max();
  2262. uint32_t component_with_one = kNotFound;
  2263. bool all_others_zero = true;
  2264. for (uint32_t j = 0; j < components.size(); ++j) {
  2265. const analysis::Constant* element = components[j];
  2266. double value =
  2267. (element_width == 32 ? element->GetFloat() : element->GetDouble());
  2268. if (value == 0.0) {
  2269. continue;
  2270. } else if (value == 1.0) {
  2271. if (component_with_one == kNotFound) {
  2272. component_with_one = j;
  2273. } else {
  2274. component_with_one = kNotFound;
  2275. break;
  2276. }
  2277. } else {
  2278. all_others_zero = false;
  2279. break;
  2280. }
  2281. }
  2282. if (!all_others_zero || component_with_one == kNotFound) {
  2283. continue;
  2284. }
  2285. std::vector<Operand> operands;
  2286. operands.push_back(
  2287. {SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1u - i)}});
  2288. operands.push_back(
  2289. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_with_one}});
  2290. inst->SetOpcode(SpvOpCompositeExtract);
  2291. inst->SetInOperands(std::move(operands));
  2292. return true;
  2293. }
  2294. return false;
  2295. };
  2296. }
  2297. // If we are storing an undef, then we can remove the store.
  2298. //
  2299. // TODO: We can do something similar for OpImageWrite, but checking for volatile
  2300. // is complicated. Waiting to see if it is needed.
  2301. FoldingRule StoringUndef() {
  2302. return [](IRContext* context, Instruction* inst,
  2303. const std::vector<const analysis::Constant*>&) {
  2304. assert(inst->opcode() == SpvOpStore && "Wrong opcode. Should be OpStore.");
  2305. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  2306. // If this is a volatile store, the store cannot be removed.
  2307. if (inst->NumInOperands() == 3) {
  2308. if (inst->GetSingleWordInOperand(2) & SpvMemoryAccessVolatileMask) {
  2309. return false;
  2310. }
  2311. }
  2312. uint32_t object_id = inst->GetSingleWordInOperand(kStoreObjectInIdx);
  2313. Instruction* object_inst = def_use_mgr->GetDef(object_id);
  2314. if (object_inst->opcode() == SpvOpUndef) {
  2315. inst->ToNop();
  2316. return true;
  2317. }
  2318. return false;
  2319. };
  2320. }
// Folds an OpVectorShuffle whose operand 0 (or, if that is not a shuffle,
// operand 1) is itself an OpVectorShuffle. Afterwards the outer shuffle reads
// directly from one input of that feeding shuffle instead of from the
// feeding shuffle. The fold is abandoned if the selected components come from
// both inputs of the feeding shuffle (two distinct ids). If none of them come
// from either input (e.g. all are undef), the feeding operand is replaced by a
// null constant of the feeding shuffle's type.
FoldingRule VectorShuffleFeedingShuffle() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>&) {
    assert(inst->opcode() == SpvOpVectorShuffle &&
           "Wrong opcode. Should be OpVectorShuffle.");
    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
    analysis::TypeManager* type_mgr = context->get_type_mgr();
    // The width of operand 0 is needed in every case: shuffle indices at or
    // above it address operand 1.
    Instruction* feeding_shuffle_inst =
        def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
    analysis::Vector* op0_type =
        type_mgr->GetType(feeding_shuffle_inst->type_id())->AsVector();
    uint32_t op0_length = op0_type->element_count();
    // Prefer operand 0 as the feeder; fall back to operand 1.
    bool feeder_is_op0 = true;
    if (feeding_shuffle_inst->opcode() != SpvOpVectorShuffle) {
      feeding_shuffle_inst =
          def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
      feeder_is_op0 = false;
    }
    if (feeding_shuffle_inst->opcode() != SpvOpVectorShuffle) {
      return false;
    }
    // Width of the feeding shuffle's first input. Feeder indices at or above
    // it address the feeder's second input.
    Instruction* feeder2 =
        def_use_mgr->GetDef(feeding_shuffle_inst->GetSingleWordInOperand(0));
    analysis::Vector* feeder_op0_type =
        type_mgr->GetType(feeder2->type_id())->AsVector();
    uint32_t feeder_op0_length = feeder_op0_type->element_count();
    // Id of the feeder input that will replace the feeding shuffle; 0 means
    // "not chosen yet".
    uint32_t new_feeder_id = 0;
    std::vector<Operand> new_operands;
    new_operands.resize(
        2, {SPV_OPERAND_TYPE_ID, {0}}); // Place holders for vector operands.
    // Literal 0xFFFFFFFF in a shuffle selects an undefined component.
    const uint32_t undef_literal = 0xffffffff;
    for (uint32_t op = 2; op < inst->NumInOperands(); ++op) {
      uint32_t component_index = inst->GetSingleWordInOperand(op);
      // Do not interpret the undefined value literal as coming from operand 1.
      if (component_index != undef_literal &&
          feeder_is_op0 == (component_index < op0_length)) {
        // This component comes from the feeding_shuffle_inst. Update
        // |component_index| to be the index into the operand of the feeder.
        // Adjust component_index to get the index into the operands of the
        // feeding_shuffle_inst.
        if (component_index >= op0_length) {
          component_index -= op0_length;
        }
        // The feeder's own selector literals start at in-operand 2.
        component_index =
            feeding_shuffle_inst->GetSingleWordInOperand(component_index + 2);
        // Check if we are using a component from the first or second operand of
        // the feeding instruction.
        if (component_index < feeder_op0_length) {
          if (new_feeder_id == 0) {
            // First time through, save the id of the operand the element comes
            // from.
            new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(0);
          } else if (new_feeder_id !=
                     feeding_shuffle_inst->GetSingleWordInOperand(0)) {
            // We need both elements of the feeding_shuffle_inst, so we cannot
            // fold.
            return false;
          }
        } else if (component_index != undef_literal) {
          if (new_feeder_id == 0) {
            // First time through, save the id of the operand the element comes
            // from.
            new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(1);
          } else if (new_feeder_id !=
                     feeding_shuffle_inst->GetSingleWordInOperand(1)) {
            // We need both elements of the feeding_shuffle_inst, so we cannot
            // fold.
            return false;
          }
          // Rebase to an index into the feeder's second input.
          component_index -= feeder_op0_length;
        }
        // When the feeder sits in operand 1, its components are addressed
        // after the |op0_length| components of operand 0.
        if (!feeder_is_op0 && component_index != undef_literal) {
          component_index += op0_length;
        }
      }
      new_operands.push_back(
          {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_index}});
    }
    // No component was taken from either feeder input; substitute a null
    // constant of the feeding shuffle's type so operand slots stay valid.
    if (new_feeder_id == 0) {
      analysis::ConstantManager* const_mgr = context->get_constant_mgr();
      const analysis::Type* type =
          type_mgr->GetType(feeding_shuffle_inst->type_id());
      const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
      new_feeder_id =
          const_mgr->GetDefiningInstruction(null_const, 0)->result_id();
    }
    if (feeder_is_op0) {
      // If the size of the first vector operand changed then the indices
      // referring to the second operand need to be adjusted.
      Instruction* new_feeder_inst = def_use_mgr->GetDef(new_feeder_id);
      analysis::Type* new_feeder_type =
          type_mgr->GetType(new_feeder_inst->type_id());
      uint32_t new_op0_size = new_feeder_type->AsVector()->element_count();
      // May be negative (new operand 0 is wider); the unsigned subtraction
      // below then wraps to the intended addition.
      int32_t adjustment = op0_length - new_op0_size;
      if (adjustment != 0) {
        // Decide using the ORIGINAL literal: only indices that addressed
        // operand 1 (and are not undef) are shifted.
        for (uint32_t i = 2; i < new_operands.size(); i++) {
          uint32_t operand = inst->GetSingleWordInOperand(i);
          if (operand >= op0_length && operand != undef_literal) {
            new_operands[i].words[0] -= adjustment;
          }
        }
      }
      new_operands[0].words[0] = new_feeder_id;
      new_operands[1] = inst->GetInOperand(1);
    } else {
      new_operands[1].words[0] = new_feeder_id;
      new_operands[0] = inst->GetInOperand(0);
    }
    inst->SetInOperands(std::move(new_operands));
    return true;
  };
}
  2433. // Removes duplicate ids from the interface list of an OpEntryPoint
  2434. // instruction.
  2435. FoldingRule RemoveRedundantOperands() {
  2436. return [](IRContext*, Instruction* inst,
  2437. const std::vector<const analysis::Constant*>&) {
  2438. assert(inst->opcode() == SpvOpEntryPoint &&
  2439. "Wrong opcode. Should be OpEntryPoint.");
  2440. bool has_redundant_operand = false;
  2441. std::unordered_set<uint32_t> seen_operands;
  2442. std::vector<Operand> new_operands;
  2443. new_operands.emplace_back(inst->GetOperand(0));
  2444. new_operands.emplace_back(inst->GetOperand(1));
  2445. new_operands.emplace_back(inst->GetOperand(2));
  2446. for (uint32_t i = 3; i < inst->NumOperands(); ++i) {
  2447. if (seen_operands.insert(inst->GetSingleWordOperand(i)).second) {
  2448. new_operands.emplace_back(inst->GetOperand(i));
  2449. } else {
  2450. has_redundant_operand = true;
  2451. }
  2452. }
  2453. if (!has_redundant_operand) {
  2454. return false;
  2455. }
  2456. inst->SetInOperands(std::move(new_operands));
  2457. return true;
  2458. };
  2459. }
  2460. // If an image instruction's operand is a constant, updates the image operand
  2461. // flag from Offset to ConstOffset.
  2462. FoldingRule UpdateImageOperands() {
  2463. return [](IRContext*, Instruction* inst,
  2464. const std::vector<const analysis::Constant*>& constants) {
  2465. const auto opcode = inst->opcode();
  2466. (void)opcode;
  2467. assert((opcode == SpvOpImageSampleImplicitLod ||
  2468. opcode == SpvOpImageSampleExplicitLod ||
  2469. opcode == SpvOpImageSampleDrefImplicitLod ||
  2470. opcode == SpvOpImageSampleDrefExplicitLod ||
  2471. opcode == SpvOpImageSampleProjImplicitLod ||
  2472. opcode == SpvOpImageSampleProjExplicitLod ||
  2473. opcode == SpvOpImageSampleProjDrefImplicitLod ||
  2474. opcode == SpvOpImageSampleProjDrefExplicitLod ||
  2475. opcode == SpvOpImageFetch || opcode == SpvOpImageGather ||
  2476. opcode == SpvOpImageDrefGather || opcode == SpvOpImageRead ||
  2477. opcode == SpvOpImageWrite ||
  2478. opcode == SpvOpImageSparseSampleImplicitLod ||
  2479. opcode == SpvOpImageSparseSampleExplicitLod ||
  2480. opcode == SpvOpImageSparseSampleDrefImplicitLod ||
  2481. opcode == SpvOpImageSparseSampleDrefExplicitLod ||
  2482. opcode == SpvOpImageSparseSampleProjImplicitLod ||
  2483. opcode == SpvOpImageSparseSampleProjExplicitLod ||
  2484. opcode == SpvOpImageSparseSampleProjDrefImplicitLod ||
  2485. opcode == SpvOpImageSparseSampleProjDrefExplicitLod ||
  2486. opcode == SpvOpImageSparseFetch ||
  2487. opcode == SpvOpImageSparseGather ||
  2488. opcode == SpvOpImageSparseDrefGather ||
  2489. opcode == SpvOpImageSparseRead) &&
  2490. "Wrong opcode. Should be an image instruction.");
  2491. int32_t operand_index = ImageOperandsMaskInOperandIndex(inst);
  2492. if (operand_index >= 0) {
  2493. auto image_operands = inst->GetSingleWordInOperand(operand_index);
  2494. if (image_operands & SpvImageOperandsOffsetMask) {
  2495. uint32_t offset_operand_index = operand_index + 1;
  2496. if (image_operands & SpvImageOperandsBiasMask) offset_operand_index++;
  2497. if (image_operands & SpvImageOperandsLodMask) offset_operand_index++;
  2498. if (image_operands & SpvImageOperandsGradMask)
  2499. offset_operand_index += 2;
  2500. assert(((image_operands & SpvImageOperandsConstOffsetMask) == 0) &&
  2501. "Offset and ConstOffset may not be used together");
  2502. if (offset_operand_index < inst->NumOperands()) {
  2503. if (constants[offset_operand_index]) {
  2504. image_operands = image_operands | SpvImageOperandsConstOffsetMask;
  2505. image_operands = image_operands & ~SpvImageOperandsOffsetMask;
  2506. inst->SetInOperand(operand_index, {image_operands});
  2507. return true;
  2508. }
  2509. }
  2510. }
  2511. }
  2512. return false;
  2513. };
  2514. }
  2515. } // namespace
  2516. void FoldingRules::AddFoldingRules() {
  2517. // Add all folding rules to the list for the opcodes to which they apply.
  2518. // Note that the order in which rules are added to the list matters. If a rule
  2519. // applies to the instruction, the rest of the rules will not be attempted.
  2520. // Take that into consideration.
  2521. rules_[SpvOpBitcast].push_back(BitCastScalarOrVector());
  2522. rules_[SpvOpCompositeConstruct].push_back(CompositeExtractFeedingConstruct);
  2523. rules_[SpvOpCompositeExtract].push_back(InsertFeedingExtract());
  2524. rules_[SpvOpCompositeExtract].push_back(CompositeConstructFeedingExtract);
  2525. rules_[SpvOpCompositeExtract].push_back(VectorShuffleFeedingExtract());
  2526. rules_[SpvOpCompositeExtract].push_back(FMixFeedingExtract());
  2527. rules_[SpvOpCompositeInsert].push_back(CompositeInsertToCompositeConstruct);
  2528. rules_[SpvOpDot].push_back(DotProductDoingExtract());
  2529. rules_[SpvOpEntryPoint].push_back(RemoveRedundantOperands());
  2530. rules_[SpvOpFAdd].push_back(RedundantFAdd());
  2531. rules_[SpvOpFAdd].push_back(MergeAddNegateArithmetic());
  2532. rules_[SpvOpFAdd].push_back(MergeAddAddArithmetic());
  2533. rules_[SpvOpFAdd].push_back(MergeAddSubArithmetic());
  2534. rules_[SpvOpFAdd].push_back(MergeGenericAddSubArithmetic());
  2535. rules_[SpvOpFAdd].push_back(FactorAddMuls());
  2536. rules_[SpvOpFAdd].push_back(MergeMulAddArithmetic);
  2537. rules_[SpvOpFDiv].push_back(RedundantFDiv());
  2538. rules_[SpvOpFDiv].push_back(ReciprocalFDiv());
  2539. rules_[SpvOpFDiv].push_back(MergeDivDivArithmetic());
  2540. rules_[SpvOpFDiv].push_back(MergeDivMulArithmetic());
  2541. rules_[SpvOpFDiv].push_back(MergeDivNegateArithmetic());
  2542. rules_[SpvOpFMul].push_back(RedundantFMul());
  2543. rules_[SpvOpFMul].push_back(MergeMulMulArithmetic());
  2544. rules_[SpvOpFMul].push_back(MergeMulDivArithmetic());
  2545. rules_[SpvOpFMul].push_back(MergeMulNegateArithmetic());
  2546. rules_[SpvOpFNegate].push_back(MergeNegateArithmetic());
  2547. rules_[SpvOpFNegate].push_back(MergeNegateAddSubArithmetic());
  2548. rules_[SpvOpFNegate].push_back(MergeNegateMulDivArithmetic());
  2549. rules_[SpvOpFSub].push_back(RedundantFSub());
  2550. rules_[SpvOpFSub].push_back(MergeSubNegateArithmetic());
  2551. rules_[SpvOpFSub].push_back(MergeSubAddArithmetic());
  2552. rules_[SpvOpFSub].push_back(MergeSubSubArithmetic());
  2553. rules_[SpvOpFSub].push_back(MergeMulSubArithmetic);
  2554. rules_[SpvOpIAdd].push_back(RedundantIAdd());
  2555. rules_[SpvOpIAdd].push_back(MergeAddNegateArithmetic());
  2556. rules_[SpvOpIAdd].push_back(MergeAddAddArithmetic());
  2557. rules_[SpvOpIAdd].push_back(MergeAddSubArithmetic());
  2558. rules_[SpvOpIAdd].push_back(MergeGenericAddSubArithmetic());
  2559. rules_[SpvOpIAdd].push_back(FactorAddMuls());
  2560. rules_[SpvOpIMul].push_back(IntMultipleBy1());
  2561. rules_[SpvOpIMul].push_back(MergeMulMulArithmetic());
  2562. rules_[SpvOpIMul].push_back(MergeMulNegateArithmetic());
  2563. rules_[SpvOpISub].push_back(MergeSubNegateArithmetic());
  2564. rules_[SpvOpISub].push_back(MergeSubAddArithmetic());
  2565. rules_[SpvOpISub].push_back(MergeSubSubArithmetic());
  2566. rules_[SpvOpPhi].push_back(RedundantPhi());
  2567. rules_[SpvOpSNegate].push_back(MergeNegateArithmetic());
  2568. rules_[SpvOpSNegate].push_back(MergeNegateMulDivArithmetic());
  2569. rules_[SpvOpSNegate].push_back(MergeNegateAddSubArithmetic());
  2570. rules_[SpvOpSelect].push_back(RedundantSelect());
  2571. rules_[SpvOpStore].push_back(StoringUndef());
  2572. rules_[SpvOpVectorShuffle].push_back(VectorShuffleFeedingShuffle());
  2573. rules_[SpvOpImageSampleImplicitLod].push_back(UpdateImageOperands());
  2574. rules_[SpvOpImageSampleExplicitLod].push_back(UpdateImageOperands());
  2575. rules_[SpvOpImageSampleDrefImplicitLod].push_back(UpdateImageOperands());
  2576. rules_[SpvOpImageSampleDrefExplicitLod].push_back(UpdateImageOperands());
  2577. rules_[SpvOpImageSampleProjImplicitLod].push_back(UpdateImageOperands());
  2578. rules_[SpvOpImageSampleProjExplicitLod].push_back(UpdateImageOperands());
  2579. rules_[SpvOpImageSampleProjDrefImplicitLod].push_back(UpdateImageOperands());
  2580. rules_[SpvOpImageSampleProjDrefExplicitLod].push_back(UpdateImageOperands());
  2581. rules_[SpvOpImageFetch].push_back(UpdateImageOperands());
  2582. rules_[SpvOpImageGather].push_back(UpdateImageOperands());
  2583. rules_[SpvOpImageDrefGather].push_back(UpdateImageOperands());
  2584. rules_[SpvOpImageRead].push_back(UpdateImageOperands());
  2585. rules_[SpvOpImageWrite].push_back(UpdateImageOperands());
  2586. rules_[SpvOpImageSparseSampleImplicitLod].push_back(UpdateImageOperands());
  2587. rules_[SpvOpImageSparseSampleExplicitLod].push_back(UpdateImageOperands());
  2588. rules_[SpvOpImageSparseSampleDrefImplicitLod].push_back(
  2589. UpdateImageOperands());
  2590. rules_[SpvOpImageSparseSampleDrefExplicitLod].push_back(
  2591. UpdateImageOperands());
  2592. rules_[SpvOpImageSparseSampleProjImplicitLod].push_back(
  2593. UpdateImageOperands());
  2594. rules_[SpvOpImageSparseSampleProjExplicitLod].push_back(
  2595. UpdateImageOperands());
  2596. rules_[SpvOpImageSparseSampleProjDrefImplicitLod].push_back(
  2597. UpdateImageOperands());
  2598. rules_[SpvOpImageSparseSampleProjDrefExplicitLod].push_back(
  2599. UpdateImageOperands());
  2600. rules_[SpvOpImageSparseFetch].push_back(UpdateImageOperands());
  2601. rules_[SpvOpImageSparseGather].push_back(UpdateImageOperands());
  2602. rules_[SpvOpImageSparseDrefGather].push_back(UpdateImageOperands());
  2603. rules_[SpvOpImageSparseRead].push_back(UpdateImageOperands());
  2604. FeatureManager* feature_manager = context_->get_feature_mgr();
  2605. // Add rules for GLSLstd450
  2606. uint32_t ext_inst_glslstd450_id =
  2607. feature_manager->GetExtInstImportId_GLSLstd450();
  2608. if (ext_inst_glslstd450_id != 0) {
  2609. ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(
  2610. RedundantFMix());
  2611. }
  2612. }
  2613. } // namespace opt
  2614. } // namespace spvtools