folding_rules.cpp 110 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
7277827792780278127822783278427852786278727882789279027912792279327942795279627972798279928002801280228032804280528062807280828092810281128122813281428152816281728182819282028212822282328242825282628272828282928302831283228332834283528362837283828392840284128422843284428452846284728482849285028512852285328542855285628572858285928602861286228632864286528662867286828692870287128722873287428752876287728782879288028812882288328842885288628872888288928902891289228932894289528962897289828992900290129022903290429052906290729082909291029112912291329142915291629172918291929202921292229232924292529262927292829292930293129322933293429352936293729382939294029412942294329442945294629472948294929502951295229532954295529562957295829592960296129622963296429652966296729682969297029712972297329742975297629772978
  1. // Copyright (c) 2018 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
#include "source/opt/folding_rules.h"

#include <climits>
#include <cmath>
#include <limits>
#include <memory>
#include <type_traits>
#include <utility>

#include "ir_builder.h"
#include "source/latest_version_glsl_std_450_header.h"
#include "source/opt/ir_context.h"
  22. namespace spvtools {
  23. namespace opt {
  24. namespace {
  25. const uint32_t kExtractCompositeIdInIdx = 0;
  26. const uint32_t kInsertObjectIdInIdx = 0;
  27. const uint32_t kInsertCompositeIdInIdx = 1;
  28. const uint32_t kExtInstSetIdInIdx = 0;
  29. const uint32_t kExtInstInstructionInIdx = 1;
  30. const uint32_t kFMixXIdInIdx = 2;
  31. const uint32_t kFMixYIdInIdx = 3;
  32. const uint32_t kFMixAIdInIdx = 4;
  33. const uint32_t kStoreObjectInIdx = 1;
  34. // Some image instructions may contain an "image operands" argument.
  35. // Returns the operand index for the "image operands".
  36. // Returns -1 if the instruction does not have image operands.
  37. int32_t ImageOperandsMaskInOperandIndex(Instruction* inst) {
  38. const auto opcode = inst->opcode();
  39. switch (opcode) {
  40. case SpvOpImageSampleImplicitLod:
  41. case SpvOpImageSampleExplicitLod:
  42. case SpvOpImageSampleProjImplicitLod:
  43. case SpvOpImageSampleProjExplicitLod:
  44. case SpvOpImageFetch:
  45. case SpvOpImageRead:
  46. case SpvOpImageSparseSampleImplicitLod:
  47. case SpvOpImageSparseSampleExplicitLod:
  48. case SpvOpImageSparseSampleProjImplicitLod:
  49. case SpvOpImageSparseSampleProjExplicitLod:
  50. case SpvOpImageSparseFetch:
  51. case SpvOpImageSparseRead:
  52. return inst->NumOperands() > 4 ? 2 : -1;
  53. case SpvOpImageSampleDrefImplicitLod:
  54. case SpvOpImageSampleDrefExplicitLod:
  55. case SpvOpImageSampleProjDrefImplicitLod:
  56. case SpvOpImageSampleProjDrefExplicitLod:
  57. case SpvOpImageGather:
  58. case SpvOpImageDrefGather:
  59. case SpvOpImageSparseSampleDrefImplicitLod:
  60. case SpvOpImageSparseSampleDrefExplicitLod:
  61. case SpvOpImageSparseSampleProjDrefImplicitLod:
  62. case SpvOpImageSparseSampleProjDrefExplicitLod:
  63. case SpvOpImageSparseGather:
  64. case SpvOpImageSparseDrefGather:
  65. return inst->NumOperands() > 5 ? 3 : -1;
  66. case SpvOpImageWrite:
  67. return inst->NumOperands() > 3 ? 3 : -1;
  68. default:
  69. return -1;
  70. }
  71. }
  72. // Returns the element width of |type|.
  73. uint32_t ElementWidth(const analysis::Type* type) {
  74. if (const analysis::Vector* vec_type = type->AsVector()) {
  75. return ElementWidth(vec_type->element_type());
  76. } else if (const analysis::Float* float_type = type->AsFloat()) {
  77. return float_type->width();
  78. } else {
  79. assert(type->AsInteger());
  80. return type->AsInteger()->width();
  81. }
  82. }
  83. // Returns true if |type| is Float or a vector of Float.
  84. bool HasFloatingPoint(const analysis::Type* type) {
  85. if (type->AsFloat()) {
  86. return true;
  87. } else if (const analysis::Vector* vec_type = type->AsVector()) {
  88. return vec_type->element_type()->AsFloat() != nullptr;
  89. }
  90. return false;
  91. }
  92. // Returns false if |val| is NaN, infinite or subnormal.
  93. template <typename T>
  94. bool IsValidResult(T val) {
  95. int classified = std::fpclassify(val);
  96. switch (classified) {
  97. case FP_NAN:
  98. case FP_INFINITE:
  99. case FP_SUBNORMAL:
  100. return false;
  101. default:
  102. return true;
  103. }
  104. }
  105. const analysis::Constant* ConstInput(
  106. const std::vector<const analysis::Constant*>& constants) {
  107. return constants[0] ? constants[0] : constants[1];
  108. }
  109. Instruction* NonConstInput(IRContext* context, const analysis::Constant* c,
  110. Instruction* inst) {
  111. uint32_t in_op = c ? 1u : 0u;
  112. return context->get_def_use_mgr()->GetDef(
  113. inst->GetSingleWordInOperand(in_op));
  114. }
  115. std::vector<uint32_t> ExtractInts(uint64_t val) {
  116. std::vector<uint32_t> words;
  117. words.push_back(static_cast<uint32_t>(val));
  118. words.push_back(static_cast<uint32_t>(val >> 32));
  119. return words;
  120. }
  121. std::vector<uint32_t> GetWordsFromScalarIntConstant(
  122. const analysis::IntConstant* c) {
  123. assert(c != nullptr);
  124. uint32_t width = c->type()->AsInteger()->width();
  125. assert(width == 32 || width == 64);
  126. if (width == 64) {
  127. uint64_t uval = static_cast<uint64_t>(c->GetU64());
  128. return ExtractInts(uval);
  129. }
  130. return {c->GetU32()};
  131. }
  132. std::vector<uint32_t> GetWordsFromScalarFloatConstant(
  133. const analysis::FloatConstant* c) {
  134. assert(c != nullptr);
  135. uint32_t width = c->type()->AsFloat()->width();
  136. assert(width == 32 || width == 64);
  137. if (width == 64) {
  138. utils::FloatProxy<double> result(c->GetDouble());
  139. return result.GetWords();
  140. }
  141. utils::FloatProxy<float> result(c->GetFloat());
  142. return result.GetWords();
  143. }
  144. std::vector<uint32_t> GetWordsFromNumericScalarOrVectorConstant(
  145. analysis::ConstantManager* const_mgr, const analysis::Constant* c) {
  146. if (const auto* float_constant = c->AsFloatConstant()) {
  147. return GetWordsFromScalarFloatConstant(float_constant);
  148. } else if (const auto* int_constant = c->AsIntConstant()) {
  149. return GetWordsFromScalarIntConstant(int_constant);
  150. } else if (const auto* vec_constant = c->AsVectorConstant()) {
  151. std::vector<uint32_t> words;
  152. for (const auto* comp : vec_constant->GetComponents()) {
  153. auto comp_in_words =
  154. GetWordsFromNumericScalarOrVectorConstant(const_mgr, comp);
  155. words.insert(words.end(), comp_in_words.begin(), comp_in_words.end());
  156. }
  157. return words;
  158. }
  159. return {};
  160. }
  161. const analysis::Constant* ConvertWordsToNumericScalarOrVectorConstant(
  162. analysis::ConstantManager* const_mgr, const std::vector<uint32_t>& words,
  163. const analysis::Type* type) {
  164. if (type->AsInteger() || type->AsFloat())
  165. return const_mgr->GetConstant(type, words);
  166. if (const auto* vec_type = type->AsVector())
  167. return const_mgr->GetNumericVectorConstantWithWords(vec_type, words);
  168. return nullptr;
  169. }
  170. // Returns the negation of |c|. |c| must be a 32 or 64 bit floating point
  171. // constant.
  172. uint32_t NegateFloatingPointConstant(analysis::ConstantManager* const_mgr,
  173. const analysis::Constant* c) {
  174. assert(c);
  175. assert(c->type()->AsFloat());
  176. uint32_t width = c->type()->AsFloat()->width();
  177. assert(width == 32 || width == 64);
  178. std::vector<uint32_t> words;
  179. if (width == 64) {
  180. utils::FloatProxy<double> result(c->GetDouble() * -1.0);
  181. words = result.GetWords();
  182. } else {
  183. utils::FloatProxy<float> result(c->GetFloat() * -1.0f);
  184. words = result.GetWords();
  185. }
  186. const analysis::Constant* negated_const =
  187. const_mgr->GetConstant(c->type(), std::move(words));
  188. return const_mgr->GetDefiningInstruction(negated_const)->result_id();
  189. }
  190. // Negates the integer constant |c|. Returns the id of the defining instruction.
  191. uint32_t NegateIntegerConstant(analysis::ConstantManager* const_mgr,
  192. const analysis::Constant* c) {
  193. assert(c);
  194. assert(c->type()->AsInteger());
  195. uint32_t width = c->type()->AsInteger()->width();
  196. assert(width == 32 || width == 64);
  197. std::vector<uint32_t> words;
  198. if (width == 64) {
  199. uint64_t uval = static_cast<uint64_t>(0 - c->GetU64());
  200. words = ExtractInts(uval);
  201. } else {
  202. words.push_back(static_cast<uint32_t>(0 - c->GetU32()));
  203. }
  204. const analysis::Constant* negated_const =
  205. const_mgr->GetConstant(c->type(), std::move(words));
  206. return const_mgr->GetDefiningInstruction(negated_const)->result_id();
  207. }
  208. // Negates the vector constant |c|. Returns the id of the defining instruction.
  209. uint32_t NegateVectorConstant(analysis::ConstantManager* const_mgr,
  210. const analysis::Constant* c) {
  211. assert(const_mgr && c);
  212. assert(c->type()->AsVector());
  213. if (c->AsNullConstant()) {
  214. // 0.0 vs -0.0 shouldn't matter.
  215. return const_mgr->GetDefiningInstruction(c)->result_id();
  216. } else {
  217. const analysis::Type* component_type =
  218. c->AsVectorConstant()->component_type();
  219. std::vector<uint32_t> words;
  220. for (auto& comp : c->AsVectorConstant()->GetComponents()) {
  221. if (component_type->AsFloat()) {
  222. words.push_back(NegateFloatingPointConstant(const_mgr, comp));
  223. } else {
  224. assert(component_type->AsInteger());
  225. words.push_back(NegateIntegerConstant(const_mgr, comp));
  226. }
  227. }
  228. const analysis::Constant* negated_const =
  229. const_mgr->GetConstant(c->type(), std::move(words));
  230. return const_mgr->GetDefiningInstruction(negated_const)->result_id();
  231. }
  232. }
  233. // Negates |c|. Returns the id of the defining instruction.
  234. uint32_t NegateConstant(analysis::ConstantManager* const_mgr,
  235. const analysis::Constant* c) {
  236. if (c->type()->AsVector()) {
  237. return NegateVectorConstant(const_mgr, c);
  238. } else if (c->type()->AsFloat()) {
  239. return NegateFloatingPointConstant(const_mgr, c);
  240. } else {
  241. assert(c->type()->AsInteger());
  242. return NegateIntegerConstant(const_mgr, c);
  243. }
  244. }
  245. // Takes the reciprocal of |c|. |c|'s type must be Float or a vector of Float.
  246. // Returns 0 if the reciprocal is NaN, infinite or subnormal.
  247. uint32_t Reciprocal(analysis::ConstantManager* const_mgr,
  248. const analysis::Constant* c) {
  249. assert(const_mgr && c);
  250. assert(c->type()->AsFloat());
  251. uint32_t width = c->type()->AsFloat()->width();
  252. assert(width == 32 || width == 64);
  253. std::vector<uint32_t> words;
  254. if (c->IsZero()) {
  255. return 0;
  256. }
  257. if (width == 64) {
  258. spvtools::utils::FloatProxy<double> result(1.0 / c->GetDouble());
  259. if (!IsValidResult(result.getAsFloat())) return 0;
  260. words = result.GetWords();
  261. } else {
  262. spvtools::utils::FloatProxy<float> result(1.0f / c->GetFloat());
  263. if (!IsValidResult(result.getAsFloat())) return 0;
  264. words = result.GetWords();
  265. }
  266. const analysis::Constant* negated_const =
  267. const_mgr->GetConstant(c->type(), std::move(words));
  268. return const_mgr->GetDefiningInstruction(negated_const)->result_id();
  269. }
  270. // Replaces fdiv where second operand is constant with fmul.
  271. FoldingRule ReciprocalFDiv() {
  272. return [](IRContext* context, Instruction* inst,
  273. const std::vector<const analysis::Constant*>& constants) {
  274. assert(inst->opcode() == SpvOpFDiv);
  275. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  276. const analysis::Type* type =
  277. context->get_type_mgr()->GetType(inst->type_id());
  278. if (!inst->IsFloatingPointFoldingAllowed()) return false;
  279. uint32_t width = ElementWidth(type);
  280. if (width != 32 && width != 64) return false;
  281. if (constants[1] != nullptr) {
  282. uint32_t id = 0;
  283. if (const analysis::VectorConstant* vector_const =
  284. constants[1]->AsVectorConstant()) {
  285. std::vector<uint32_t> neg_ids;
  286. for (auto& comp : vector_const->GetComponents()) {
  287. id = Reciprocal(const_mgr, comp);
  288. if (id == 0) return false;
  289. neg_ids.push_back(id);
  290. }
  291. const analysis::Constant* negated_const =
  292. const_mgr->GetConstant(constants[1]->type(), std::move(neg_ids));
  293. id = const_mgr->GetDefiningInstruction(negated_const)->result_id();
  294. } else if (constants[1]->AsFloatConstant()) {
  295. id = Reciprocal(const_mgr, constants[1]);
  296. if (id == 0) return false;
  297. } else {
  298. // Don't fold a null constant.
  299. return false;
  300. }
  301. inst->SetOpcode(SpvOpFMul);
  302. inst->SetInOperands(
  303. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0u)}},
  304. {SPV_OPERAND_TYPE_ID, {id}}});
  305. return true;
  306. }
  307. return false;
  308. };
  309. }
  310. // Elides consecutive negate instructions.
  311. FoldingRule MergeNegateArithmetic() {
  312. return [](IRContext* context, Instruction* inst,
  313. const std::vector<const analysis::Constant*>& constants) {
  314. assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
  315. (void)constants;
  316. const analysis::Type* type =
  317. context->get_type_mgr()->GetType(inst->type_id());
  318. if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
  319. return false;
  320. Instruction* op_inst =
  321. context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
  322. if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
  323. return false;
  324. if (op_inst->opcode() == inst->opcode()) {
  325. // Elide negates.
  326. inst->SetOpcode(SpvOpCopyObject);
  327. inst->SetInOperands(
  328. {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0u)}}});
  329. return true;
  330. }
  331. return false;
  332. };
  333. }
  334. // Merges negate into a mul or div operation if that operation contains a
  335. // constant operand.
  336. // Cases:
  337. // -(x * 2) = x * -2
  338. // -(2 * x) = x * -2
  339. // -(x / 2) = x / -2
  340. // -(2 / x) = -2 / x
  341. FoldingRule MergeNegateMulDivArithmetic() {
  342. return [](IRContext* context, Instruction* inst,
  343. const std::vector<const analysis::Constant*>& constants) {
  344. assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
  345. (void)constants;
  346. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  347. const analysis::Type* type =
  348. context->get_type_mgr()->GetType(inst->type_id());
  349. if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
  350. return false;
  351. Instruction* op_inst =
  352. context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
  353. if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
  354. return false;
  355. uint32_t width = ElementWidth(type);
  356. if (width != 32 && width != 64) return false;
  357. SpvOp opcode = op_inst->opcode();
  358. if (opcode == SpvOpFMul || opcode == SpvOpFDiv || opcode == SpvOpIMul ||
  359. opcode == SpvOpSDiv || opcode == SpvOpUDiv) {
  360. std::vector<const analysis::Constant*> op_constants =
  361. const_mgr->GetOperandConstants(op_inst);
  362. // Merge negate into mul or div if one operand is constant.
  363. if (op_constants[0] || op_constants[1]) {
  364. bool zero_is_variable = op_constants[0] == nullptr;
  365. const analysis::Constant* c = ConstInput(op_constants);
  366. uint32_t neg_id = NegateConstant(const_mgr, c);
  367. uint32_t non_const_id = zero_is_variable
  368. ? op_inst->GetSingleWordInOperand(0u)
  369. : op_inst->GetSingleWordInOperand(1u);
  370. // Change this instruction to a mul/div.
  371. inst->SetOpcode(op_inst->opcode());
  372. if (opcode == SpvOpFDiv || opcode == SpvOpUDiv || opcode == SpvOpSDiv) {
  373. uint32_t op0 = zero_is_variable ? non_const_id : neg_id;
  374. uint32_t op1 = zero_is_variable ? neg_id : non_const_id;
  375. inst->SetInOperands(
  376. {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
  377. } else {
  378. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
  379. {SPV_OPERAND_TYPE_ID, {neg_id}}});
  380. }
  381. return true;
  382. }
  383. }
  384. return false;
  385. };
  386. }
  387. // Merges negate into a add or sub operation if that operation contains a
  388. // constant operand.
  389. // Cases:
  390. // -(x + 2) = -2 - x
  391. // -(2 + x) = -2 - x
  392. // -(x - 2) = 2 - x
  393. // -(2 - x) = x - 2
  394. FoldingRule MergeNegateAddSubArithmetic() {
  395. return [](IRContext* context, Instruction* inst,
  396. const std::vector<const analysis::Constant*>& constants) {
  397. assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
  398. (void)constants;
  399. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  400. const analysis::Type* type =
  401. context->get_type_mgr()->GetType(inst->type_id());
  402. if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
  403. return false;
  404. Instruction* op_inst =
  405. context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
  406. if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
  407. return false;
  408. uint32_t width = ElementWidth(type);
  409. if (width != 32 && width != 64) return false;
  410. if (op_inst->opcode() == SpvOpFAdd || op_inst->opcode() == SpvOpFSub ||
  411. op_inst->opcode() == SpvOpIAdd || op_inst->opcode() == SpvOpISub) {
  412. std::vector<const analysis::Constant*> op_constants =
  413. const_mgr->GetOperandConstants(op_inst);
  414. if (op_constants[0] || op_constants[1]) {
  415. bool zero_is_variable = op_constants[0] == nullptr;
  416. bool is_add = (op_inst->opcode() == SpvOpFAdd) ||
  417. (op_inst->opcode() == SpvOpIAdd);
  418. bool swap_operands = !is_add || zero_is_variable;
  419. bool negate_const = is_add;
  420. const analysis::Constant* c = ConstInput(op_constants);
  421. uint32_t const_id = 0;
  422. if (negate_const) {
  423. const_id = NegateConstant(const_mgr, c);
  424. } else {
  425. const_id = zero_is_variable ? op_inst->GetSingleWordInOperand(1u)
  426. : op_inst->GetSingleWordInOperand(0u);
  427. }
  428. // Swap operands if necessary and make the instruction a subtraction.
  429. uint32_t op0 =
  430. zero_is_variable ? op_inst->GetSingleWordInOperand(0u) : const_id;
  431. uint32_t op1 =
  432. zero_is_variable ? const_id : op_inst->GetSingleWordInOperand(1u);
  433. if (swap_operands) std::swap(op0, op1);
  434. inst->SetOpcode(HasFloatingPoint(type) ? SpvOpFSub : SpvOpISub);
  435. inst->SetInOperands(
  436. {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
  437. return true;
  438. }
  439. }
  440. return false;
  441. };
  442. }
  443. // Returns true if |c| has a zero element.
  444. bool HasZero(const analysis::Constant* c) {
  445. if (c->AsNullConstant()) {
  446. return true;
  447. }
  448. if (const analysis::VectorConstant* vec_const = c->AsVectorConstant()) {
  449. for (auto& comp : vec_const->GetComponents())
  450. if (HasZero(comp)) return true;
  451. } else {
  452. assert(c->AsScalarConstant());
  453. return c->AsScalarConstant()->IsZero();
  454. }
  455. return false;
  456. }
  457. // Performs |input1| |opcode| |input2| and returns the merged constant result
  458. // id. Returns 0 if the result is not a valid value. The input types must be
  459. // Float.
  460. uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr,
  461. SpvOp opcode,
  462. const analysis::Constant* input1,
  463. const analysis::Constant* input2) {
  464. const analysis::Type* type = input1->type();
  465. assert(type->AsFloat());
  466. uint32_t width = type->AsFloat()->width();
  467. assert(width == 32 || width == 64);
  468. std::vector<uint32_t> words;
  469. #define FOLD_OP(op) \
  470. if (width == 64) { \
  471. utils::FloatProxy<double> val = \
  472. input1->GetDouble() op input2->GetDouble(); \
  473. double dval = val.getAsFloat(); \
  474. if (!IsValidResult(dval)) return 0; \
  475. words = val.GetWords(); \
  476. } else { \
  477. utils::FloatProxy<float> val = input1->GetFloat() op input2->GetFloat(); \
  478. float fval = val.getAsFloat(); \
  479. if (!IsValidResult(fval)) return 0; \
  480. words = val.GetWords(); \
  481. } \
  482. static_assert(true, "require extra semicolon")
  483. switch (opcode) {
  484. case SpvOpFMul:
  485. FOLD_OP(*);
  486. break;
  487. case SpvOpFDiv:
  488. if (HasZero(input2)) return 0;
  489. FOLD_OP(/);
  490. break;
  491. case SpvOpFAdd:
  492. FOLD_OP(+);
  493. break;
  494. case SpvOpFSub:
  495. FOLD_OP(-);
  496. break;
  497. default:
  498. assert(false && "Unexpected operation");
  499. break;
  500. }
  501. #undef FOLD_OP
  502. const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
  503. return const_mgr->GetDefiningInstruction(merged_const)->result_id();
  504. }
  505. // Performs |input1| |opcode| |input2| and returns the merged constant result
  506. // id. Returns 0 if the result is not a valid value. The input types must be
  507. // Integers.
  508. uint32_t PerformIntegerOperation(analysis::ConstantManager* const_mgr,
  509. SpvOp opcode, const analysis::Constant* input1,
  510. const analysis::Constant* input2) {
  511. assert(input1->type()->AsInteger());
  512. const analysis::Integer* type = input1->type()->AsInteger();
  513. uint32_t width = type->AsInteger()->width();
  514. assert(width == 32 || width == 64);
  515. std::vector<uint32_t> words;
  516. // Regardless of the sign of the constant, folding is performed on an unsigned
  517. // interpretation of the constant data. This avoids signed integer overflow
  518. // while folding, and works because sign is irrelevant for the IAdd, ISub and
  519. // IMul instructions.
  520. #define FOLD_OP(op) \
  521. if (width == 64) { \
  522. uint64_t val = input1->GetU64() op input2->GetU64(); \
  523. words = ExtractInts(val); \
  524. } else { \
  525. uint32_t val = input1->GetU32() op input2->GetU32(); \
  526. words.push_back(val); \
  527. } \
  528. static_assert(true, "require extra semicolon")
  529. switch (opcode) {
  530. case SpvOpIMul:
  531. FOLD_OP(*);
  532. break;
  533. case SpvOpSDiv:
  534. case SpvOpUDiv:
  535. assert(false && "Should not merge integer division");
  536. break;
  537. case SpvOpIAdd:
  538. FOLD_OP(+);
  539. break;
  540. case SpvOpISub:
  541. FOLD_OP(-);
  542. break;
  543. default:
  544. assert(false && "Unexpected operation");
  545. break;
  546. }
  547. #undef FOLD_OP
  548. const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
  549. return const_mgr->GetDefiningInstruction(merged_const)->result_id();
  550. }
  551. // Performs |input1| |opcode| |input2| and returns the merged constant result
  552. // id. Returns 0 if the result is not a valid value. The input types must be
  553. // Integers, Floats or Vectors of such.
  554. uint32_t PerformOperation(analysis::ConstantManager* const_mgr, SpvOp opcode,
  555. const analysis::Constant* input1,
  556. const analysis::Constant* input2) {
  557. assert(input1 && input2);
  558. const analysis::Type* type = input1->type();
  559. std::vector<uint32_t> words;
  560. if (const analysis::Vector* vector_type = type->AsVector()) {
  561. const analysis::Type* ele_type = vector_type->element_type();
  562. for (uint32_t i = 0; i != vector_type->element_count(); ++i) {
  563. uint32_t id = 0;
  564. const analysis::Constant* input1_comp = nullptr;
  565. if (const analysis::VectorConstant* input1_vector =
  566. input1->AsVectorConstant()) {
  567. input1_comp = input1_vector->GetComponents()[i];
  568. } else {
  569. assert(input1->AsNullConstant());
  570. input1_comp = const_mgr->GetConstant(ele_type, {});
  571. }
  572. const analysis::Constant* input2_comp = nullptr;
  573. if (const analysis::VectorConstant* input2_vector =
  574. input2->AsVectorConstant()) {
  575. input2_comp = input2_vector->GetComponents()[i];
  576. } else {
  577. assert(input2->AsNullConstant());
  578. input2_comp = const_mgr->GetConstant(ele_type, {});
  579. }
  580. if (ele_type->AsFloat()) {
  581. id = PerformFloatingPointOperation(const_mgr, opcode, input1_comp,
  582. input2_comp);
  583. } else {
  584. assert(ele_type->AsInteger());
  585. id = PerformIntegerOperation(const_mgr, opcode, input1_comp,
  586. input2_comp);
  587. }
  588. if (id == 0) return 0;
  589. words.push_back(id);
  590. }
  591. const analysis::Constant* merged_const =
  592. const_mgr->GetConstant(type, words);
  593. return const_mgr->GetDefiningInstruction(merged_const)->result_id();
  594. } else if (type->AsFloat()) {
  595. return PerformFloatingPointOperation(const_mgr, opcode, input1, input2);
  596. } else {
  597. assert(type->AsInteger());
  598. return PerformIntegerOperation(const_mgr, opcode, input1, input2);
  599. }
  600. }
  601. // Merges consecutive multiplies where each contains one constant operand.
  602. // Cases:
  603. // 2 * (x * 2) = x * 4
  604. // 2 * (2 * x) = x * 4
  605. // (x * 2) * 2 = x * 4
  606. // (2 * x) * 2 = x * 4
  607. FoldingRule MergeMulMulArithmetic() {
  608. return [](IRContext* context, Instruction* inst,
  609. const std::vector<const analysis::Constant*>& constants) {
  610. assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
  611. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  612. const analysis::Type* type =
  613. context->get_type_mgr()->GetType(inst->type_id());
  614. if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
  615. return false;
  616. uint32_t width = ElementWidth(type);
  617. if (width != 32 && width != 64) return false;
  618. // Determine the constant input and the variable input in |inst|.
  619. const analysis::Constant* const_input1 = ConstInput(constants);
  620. if (!const_input1) return false;
  621. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  622. if (HasFloatingPoint(type) && !other_inst->IsFloatingPointFoldingAllowed())
  623. return false;
  624. if (other_inst->opcode() == inst->opcode()) {
  625. std::vector<const analysis::Constant*> other_constants =
  626. const_mgr->GetOperandConstants(other_inst);
  627. const analysis::Constant* const_input2 = ConstInput(other_constants);
  628. if (!const_input2) return false;
  629. bool other_first_is_variable = other_constants[0] == nullptr;
  630. uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
  631. const_input1, const_input2);
  632. if (merged_id == 0) return false;
  633. uint32_t non_const_id = other_first_is_variable
  634. ? other_inst->GetSingleWordInOperand(0u)
  635. : other_inst->GetSingleWordInOperand(1u);
  636. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
  637. {SPV_OPERAND_TYPE_ID, {merged_id}}});
  638. return true;
  639. }
  640. return false;
  641. };
  642. }
  643. // Merges divides into subsequent multiplies if each instruction contains one
  644. // constant operand. Does not support integer operations.
  645. // Cases:
  646. // 2 * (x / 2) = x * 1
  647. // 2 * (2 / x) = 4 / x
  648. // (x / 2) * 2 = x * 1
  649. // (2 / x) * 2 = 4 / x
  650. // (y / x) * x = y
  651. // x * (y / x) = y
  652. FoldingRule MergeMulDivArithmetic() {
  653. return [](IRContext* context, Instruction* inst,
  654. const std::vector<const analysis::Constant*>& constants) {
  655. assert(inst->opcode() == SpvOpFMul);
  656. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  657. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  658. const analysis::Type* type =
  659. context->get_type_mgr()->GetType(inst->type_id());
  660. if (!inst->IsFloatingPointFoldingAllowed()) return false;
  661. uint32_t width = ElementWidth(type);
  662. if (width != 32 && width != 64) return false;
  663. for (uint32_t i = 0; i < 2; i++) {
  664. uint32_t op_id = inst->GetSingleWordInOperand(i);
  665. Instruction* op_inst = def_use_mgr->GetDef(op_id);
  666. if (op_inst->opcode() == SpvOpFDiv) {
  667. if (op_inst->GetSingleWordInOperand(1) ==
  668. inst->GetSingleWordInOperand(1 - i)) {
  669. inst->SetOpcode(SpvOpCopyObject);
  670. inst->SetInOperands(
  671. {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0)}}});
  672. return true;
  673. }
  674. }
  675. }
  676. const analysis::Constant* const_input1 = ConstInput(constants);
  677. if (!const_input1) return false;
  678. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  679. if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
  680. if (other_inst->opcode() == SpvOpFDiv) {
  681. std::vector<const analysis::Constant*> other_constants =
  682. const_mgr->GetOperandConstants(other_inst);
  683. const analysis::Constant* const_input2 = ConstInput(other_constants);
  684. if (!const_input2 || HasZero(const_input2)) return false;
  685. bool other_first_is_variable = other_constants[0] == nullptr;
  686. // If the variable value is the second operand of the divide, multiply
  687. // the constants together. Otherwise divide the constants.
  688. uint32_t merged_id = PerformOperation(
  689. const_mgr,
  690. other_first_is_variable ? other_inst->opcode() : inst->opcode(),
  691. const_input1, const_input2);
  692. if (merged_id == 0) return false;
  693. uint32_t non_const_id = other_first_is_variable
  694. ? other_inst->GetSingleWordInOperand(0u)
  695. : other_inst->GetSingleWordInOperand(1u);
  696. // If the variable value is on the second operand of the div, then this
  697. // operation is a div. Otherwise it should be a multiply.
  698. inst->SetOpcode(other_first_is_variable ? inst->opcode()
  699. : other_inst->opcode());
  700. if (other_first_is_variable) {
  701. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
  702. {SPV_OPERAND_TYPE_ID, {merged_id}}});
  703. } else {
  704. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {merged_id}},
  705. {SPV_OPERAND_TYPE_ID, {non_const_id}}});
  706. }
  707. return true;
  708. }
  709. return false;
  710. };
  711. }
  712. // Merges multiply of constant and negation.
  713. // Cases:
  714. // (-x) * 2 = x * -2
  715. // 2 * (-x) = x * -2
  716. FoldingRule MergeMulNegateArithmetic() {
  717. return [](IRContext* context, Instruction* inst,
  718. const std::vector<const analysis::Constant*>& constants) {
  719. assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
  720. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  721. const analysis::Type* type =
  722. context->get_type_mgr()->GetType(inst->type_id());
  723. bool uses_float = HasFloatingPoint(type);
  724. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  725. uint32_t width = ElementWidth(type);
  726. if (width != 32 && width != 64) return false;
  727. const analysis::Constant* const_input1 = ConstInput(constants);
  728. if (!const_input1) return false;
  729. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  730. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  731. return false;
  732. if (other_inst->opcode() == SpvOpFNegate ||
  733. other_inst->opcode() == SpvOpSNegate) {
  734. uint32_t neg_id = NegateConstant(const_mgr, const_input1);
  735. inst->SetInOperands(
  736. {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
  737. {SPV_OPERAND_TYPE_ID, {neg_id}}});
  738. return true;
  739. }
  740. return false;
  741. };
  742. }
  743. // Merges consecutive divides if each instruction contains one constant operand.
  744. // Does not support integer division.
  745. // Cases:
  746. // 2 / (x / 2) = 4 / x
  747. // 4 / (2 / x) = 2 * x
  748. // (4 / x) / 2 = 2 / x
  749. // (x / 2) / 2 = x / 4
  750. FoldingRule MergeDivDivArithmetic() {
  751. return [](IRContext* context, Instruction* inst,
  752. const std::vector<const analysis::Constant*>& constants) {
  753. assert(inst->opcode() == SpvOpFDiv);
  754. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  755. const analysis::Type* type =
  756. context->get_type_mgr()->GetType(inst->type_id());
  757. if (!inst->IsFloatingPointFoldingAllowed()) return false;
  758. uint32_t width = ElementWidth(type);
  759. if (width != 32 && width != 64) return false;
  760. const analysis::Constant* const_input1 = ConstInput(constants);
  761. if (!const_input1 || HasZero(const_input1)) return false;
  762. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  763. if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
  764. bool first_is_variable = constants[0] == nullptr;
  765. if (other_inst->opcode() == inst->opcode()) {
  766. std::vector<const analysis::Constant*> other_constants =
  767. const_mgr->GetOperandConstants(other_inst);
  768. const analysis::Constant* const_input2 = ConstInput(other_constants);
  769. if (!const_input2 || HasZero(const_input2)) return false;
  770. bool other_first_is_variable = other_constants[0] == nullptr;
  771. SpvOp merge_op = inst->opcode();
  772. if (other_first_is_variable) {
  773. // Constants magnify.
  774. merge_op = SpvOpFMul;
  775. }
  776. // This is an x / (*) case. Swap the inputs. Doesn't harm multiply
  777. // because it is commutative.
  778. if (first_is_variable) std::swap(const_input1, const_input2);
  779. uint32_t merged_id =
  780. PerformOperation(const_mgr, merge_op, const_input1, const_input2);
  781. if (merged_id == 0) return false;
  782. uint32_t non_const_id = other_first_is_variable
  783. ? other_inst->GetSingleWordInOperand(0u)
  784. : other_inst->GetSingleWordInOperand(1u);
  785. SpvOp op = inst->opcode();
  786. if (!first_is_variable && !other_first_is_variable) {
  787. // Effectively div of 1/x, so change to multiply.
  788. op = SpvOpFMul;
  789. }
  790. uint32_t op1 = merged_id;
  791. uint32_t op2 = non_const_id;
  792. if (first_is_variable && other_first_is_variable) std::swap(op1, op2);
  793. inst->SetOpcode(op);
  794. inst->SetInOperands(
  795. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  796. return true;
  797. }
  798. return false;
  799. };
  800. }
  801. // Fold multiplies succeeded by divides where each instruction contains a
  802. // constant operand. Does not support integer divide.
  803. // Cases:
  804. // 4 / (x * 2) = 2 / x
  805. // 4 / (2 * x) = 2 / x
  806. // (x * 4) / 2 = x * 2
  807. // (4 * x) / 2 = x * 2
  808. // (x * y) / x = y
  809. // (y * x) / x = y
  810. FoldingRule MergeDivMulArithmetic() {
  811. return [](IRContext* context, Instruction* inst,
  812. const std::vector<const analysis::Constant*>& constants) {
  813. assert(inst->opcode() == SpvOpFDiv);
  814. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  815. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  816. const analysis::Type* type =
  817. context->get_type_mgr()->GetType(inst->type_id());
  818. if (!inst->IsFloatingPointFoldingAllowed()) return false;
  819. uint32_t width = ElementWidth(type);
  820. if (width != 32 && width != 64) return false;
  821. uint32_t op_id = inst->GetSingleWordInOperand(0);
  822. Instruction* op_inst = def_use_mgr->GetDef(op_id);
  823. if (op_inst->opcode() == SpvOpFMul) {
  824. for (uint32_t i = 0; i < 2; i++) {
  825. if (op_inst->GetSingleWordInOperand(i) ==
  826. inst->GetSingleWordInOperand(1)) {
  827. inst->SetOpcode(SpvOpCopyObject);
  828. inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
  829. {op_inst->GetSingleWordInOperand(1 - i)}}});
  830. return true;
  831. }
  832. }
  833. }
  834. const analysis::Constant* const_input1 = ConstInput(constants);
  835. if (!const_input1 || HasZero(const_input1)) return false;
  836. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  837. if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
  838. bool first_is_variable = constants[0] == nullptr;
  839. if (other_inst->opcode() == SpvOpFMul) {
  840. std::vector<const analysis::Constant*> other_constants =
  841. const_mgr->GetOperandConstants(other_inst);
  842. const analysis::Constant* const_input2 = ConstInput(other_constants);
  843. if (!const_input2) return false;
  844. bool other_first_is_variable = other_constants[0] == nullptr;
  845. // This is an x / (*) case. Swap the inputs.
  846. if (first_is_variable) std::swap(const_input1, const_input2);
  847. uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
  848. const_input1, const_input2);
  849. if (merged_id == 0) return false;
  850. uint32_t non_const_id = other_first_is_variable
  851. ? other_inst->GetSingleWordInOperand(0u)
  852. : other_inst->GetSingleWordInOperand(1u);
  853. uint32_t op1 = merged_id;
  854. uint32_t op2 = non_const_id;
  855. if (first_is_variable) std::swap(op1, op2);
  856. // Convert to multiply
  857. if (first_is_variable) inst->SetOpcode(other_inst->opcode());
  858. inst->SetInOperands(
  859. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  860. return true;
  861. }
  862. return false;
  863. };
  864. }
  865. // Fold divides of a constant and a negation.
  866. // Cases:
  867. // (-x) / 2 = x / -2
  868. // 2 / (-x) = -2 / x
  869. FoldingRule MergeDivNegateArithmetic() {
  870. return [](IRContext* context, Instruction* inst,
  871. const std::vector<const analysis::Constant*>& constants) {
  872. assert(inst->opcode() == SpvOpFDiv);
  873. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  874. if (!inst->IsFloatingPointFoldingAllowed()) return false;
  875. const analysis::Constant* const_input1 = ConstInput(constants);
  876. if (!const_input1) return false;
  877. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  878. if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
  879. bool first_is_variable = constants[0] == nullptr;
  880. if (other_inst->opcode() == SpvOpFNegate) {
  881. uint32_t neg_id = NegateConstant(const_mgr, const_input1);
  882. if (first_is_variable) {
  883. inst->SetInOperands(
  884. {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
  885. {SPV_OPERAND_TYPE_ID, {neg_id}}});
  886. } else {
  887. inst->SetInOperands(
  888. {{SPV_OPERAND_TYPE_ID, {neg_id}},
  889. {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
  890. }
  891. return true;
  892. }
  893. return false;
  894. };
  895. }
  896. // Folds addition of a constant and a negation.
  897. // Cases:
  898. // (-x) + 2 = 2 - x
  899. // 2 + (-x) = 2 - x
  900. FoldingRule MergeAddNegateArithmetic() {
  901. return [](IRContext* context, Instruction* inst,
  902. const std::vector<const analysis::Constant*>& constants) {
  903. assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
  904. const analysis::Type* type =
  905. context->get_type_mgr()->GetType(inst->type_id());
  906. bool uses_float = HasFloatingPoint(type);
  907. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  908. const analysis::Constant* const_input1 = ConstInput(constants);
  909. if (!const_input1) return false;
  910. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  911. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  912. return false;
  913. if (other_inst->opcode() == SpvOpSNegate ||
  914. other_inst->opcode() == SpvOpFNegate) {
  915. inst->SetOpcode(HasFloatingPoint(type) ? SpvOpFSub : SpvOpISub);
  916. uint32_t const_id = constants[0] ? inst->GetSingleWordInOperand(0u)
  917. : inst->GetSingleWordInOperand(1u);
  918. inst->SetInOperands(
  919. {{SPV_OPERAND_TYPE_ID, {const_id}},
  920. {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
  921. return true;
  922. }
  923. return false;
  924. };
  925. }
  926. // Folds subtraction of a constant and a negation.
  927. // Cases:
  928. // (-x) - 2 = -2 - x
  929. // 2 - (-x) = x + 2
  930. FoldingRule MergeSubNegateArithmetic() {
  931. return [](IRContext* context, Instruction* inst,
  932. const std::vector<const analysis::Constant*>& constants) {
  933. assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
  934. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  935. const analysis::Type* type =
  936. context->get_type_mgr()->GetType(inst->type_id());
  937. bool uses_float = HasFloatingPoint(type);
  938. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  939. uint32_t width = ElementWidth(type);
  940. if (width != 32 && width != 64) return false;
  941. const analysis::Constant* const_input1 = ConstInput(constants);
  942. if (!const_input1) return false;
  943. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  944. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  945. return false;
  946. if (other_inst->opcode() == SpvOpSNegate ||
  947. other_inst->opcode() == SpvOpFNegate) {
  948. uint32_t op1 = 0;
  949. uint32_t op2 = 0;
  950. SpvOp opcode = inst->opcode();
  951. if (constants[0] != nullptr) {
  952. op1 = other_inst->GetSingleWordInOperand(0u);
  953. op2 = inst->GetSingleWordInOperand(0u);
  954. opcode = HasFloatingPoint(type) ? SpvOpFAdd : SpvOpIAdd;
  955. } else {
  956. op1 = NegateConstant(const_mgr, const_input1);
  957. op2 = other_inst->GetSingleWordInOperand(0u);
  958. }
  959. inst->SetOpcode(opcode);
  960. inst->SetInOperands(
  961. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  962. return true;
  963. }
  964. return false;
  965. };
  966. }
  967. // Folds addition of an addition where each operation has a constant operand.
  968. // Cases:
  969. // (x + 2) + 2 = x + 4
  970. // (2 + x) + 2 = x + 4
  971. // 2 + (x + 2) = x + 4
  972. // 2 + (2 + x) = x + 4
  973. FoldingRule MergeAddAddArithmetic() {
  974. return [](IRContext* context, Instruction* inst,
  975. const std::vector<const analysis::Constant*>& constants) {
  976. assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
  977. const analysis::Type* type =
  978. context->get_type_mgr()->GetType(inst->type_id());
  979. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  980. bool uses_float = HasFloatingPoint(type);
  981. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  982. uint32_t width = ElementWidth(type);
  983. if (width != 32 && width != 64) return false;
  984. const analysis::Constant* const_input1 = ConstInput(constants);
  985. if (!const_input1) return false;
  986. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  987. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  988. return false;
  989. if (other_inst->opcode() == SpvOpFAdd ||
  990. other_inst->opcode() == SpvOpIAdd) {
  991. std::vector<const analysis::Constant*> other_constants =
  992. const_mgr->GetOperandConstants(other_inst);
  993. const analysis::Constant* const_input2 = ConstInput(other_constants);
  994. if (!const_input2) return false;
  995. Instruction* non_const_input =
  996. NonConstInput(context, other_constants[0], other_inst);
  997. uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
  998. const_input1, const_input2);
  999. if (merged_id == 0) return false;
  1000. inst->SetInOperands(
  1001. {{SPV_OPERAND_TYPE_ID, {non_const_input->result_id()}},
  1002. {SPV_OPERAND_TYPE_ID, {merged_id}}});
  1003. return true;
  1004. }
  1005. return false;
  1006. };
  1007. }
  1008. // Folds addition of a subtraction where each operation has a constant operand.
  1009. // Cases:
  1010. // (x - 2) + 2 = x + 0
  1011. // (2 - x) + 2 = 4 - x
  1012. // 2 + (x - 2) = x + 0
  1013. // 2 + (2 - x) = 4 - x
  1014. FoldingRule MergeAddSubArithmetic() {
  1015. return [](IRContext* context, Instruction* inst,
  1016. const std::vector<const analysis::Constant*>& constants) {
  1017. assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
  1018. const analysis::Type* type =
  1019. context->get_type_mgr()->GetType(inst->type_id());
  1020. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  1021. bool uses_float = HasFloatingPoint(type);
  1022. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  1023. uint32_t width = ElementWidth(type);
  1024. if (width != 32 && width != 64) return false;
  1025. const analysis::Constant* const_input1 = ConstInput(constants);
  1026. if (!const_input1) return false;
  1027. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  1028. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  1029. return false;
  1030. if (other_inst->opcode() == SpvOpFSub ||
  1031. other_inst->opcode() == SpvOpISub) {
  1032. std::vector<const analysis::Constant*> other_constants =
  1033. const_mgr->GetOperandConstants(other_inst);
  1034. const analysis::Constant* const_input2 = ConstInput(other_constants);
  1035. if (!const_input2) return false;
  1036. bool first_is_variable = other_constants[0] == nullptr;
  1037. SpvOp op = inst->opcode();
  1038. uint32_t op1 = 0;
  1039. uint32_t op2 = 0;
  1040. if (first_is_variable) {
  1041. // Subtract constants. Non-constant operand is first.
  1042. op1 = other_inst->GetSingleWordInOperand(0u);
  1043. op2 = PerformOperation(const_mgr, other_inst->opcode(), const_input1,
  1044. const_input2);
  1045. } else {
  1046. // Add constants. Constant operand is first. Change the opcode.
  1047. op1 = PerformOperation(const_mgr, inst->opcode(), const_input1,
  1048. const_input2);
  1049. op2 = other_inst->GetSingleWordInOperand(1u);
  1050. op = other_inst->opcode();
  1051. }
  1052. if (op1 == 0 || op2 == 0) return false;
  1053. inst->SetOpcode(op);
  1054. inst->SetInOperands(
  1055. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  1056. return true;
  1057. }
  1058. return false;
  1059. };
  1060. }
  1061. // Folds subtraction of an addition where each operand has a constant operand.
  1062. // Cases:
  1063. // (x + 2) - 2 = x + 0
  1064. // (2 + x) - 2 = x + 0
  1065. // 2 - (x + 2) = 0 - x
  1066. // 2 - (2 + x) = 0 - x
  1067. FoldingRule MergeSubAddArithmetic() {
  1068. return [](IRContext* context, Instruction* inst,
  1069. const std::vector<const analysis::Constant*>& constants) {
  1070. assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
  1071. const analysis::Type* type =
  1072. context->get_type_mgr()->GetType(inst->type_id());
  1073. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  1074. bool uses_float = HasFloatingPoint(type);
  1075. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  1076. uint32_t width = ElementWidth(type);
  1077. if (width != 32 && width != 64) return false;
  1078. const analysis::Constant* const_input1 = ConstInput(constants);
  1079. if (!const_input1) return false;
  1080. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  1081. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  1082. return false;
  1083. if (other_inst->opcode() == SpvOpFAdd ||
  1084. other_inst->opcode() == SpvOpIAdd) {
  1085. std::vector<const analysis::Constant*> other_constants =
  1086. const_mgr->GetOperandConstants(other_inst);
  1087. const analysis::Constant* const_input2 = ConstInput(other_constants);
  1088. if (!const_input2) return false;
  1089. Instruction* non_const_input =
  1090. NonConstInput(context, other_constants[0], other_inst);
  1091. // If the first operand of the sub is not a constant, swap the constants
  1092. // so the subtraction has the correct operands.
  1093. if (constants[0] == nullptr) std::swap(const_input1, const_input2);
  1094. // Subtract the constants.
  1095. uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
  1096. const_input1, const_input2);
  1097. SpvOp op = inst->opcode();
  1098. uint32_t op1 = 0;
  1099. uint32_t op2 = 0;
  1100. if (constants[0] == nullptr) {
  1101. // Non-constant operand is first. Change the opcode.
  1102. op1 = non_const_input->result_id();
  1103. op2 = merged_id;
  1104. op = other_inst->opcode();
  1105. } else {
  1106. // Constant operand is first.
  1107. op1 = merged_id;
  1108. op2 = non_const_input->result_id();
  1109. }
  1110. if (op1 == 0 || op2 == 0) return false;
  1111. inst->SetOpcode(op);
  1112. inst->SetInOperands(
  1113. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  1114. return true;
  1115. }
  1116. return false;
  1117. };
  1118. }
  1119. // Folds subtraction of a subtraction where each operand has a constant operand.
  1120. // Cases:
  1121. // (x - 2) - 2 = x - 4
  1122. // (2 - x) - 2 = 0 - x
  1123. // 2 - (x - 2) = 4 - x
  1124. // 2 - (2 - x) = x + 0
  1125. FoldingRule MergeSubSubArithmetic() {
  1126. return [](IRContext* context, Instruction* inst,
  1127. const std::vector<const analysis::Constant*>& constants) {
  1128. assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
  1129. const analysis::Type* type =
  1130. context->get_type_mgr()->GetType(inst->type_id());
  1131. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  1132. bool uses_float = HasFloatingPoint(type);
  1133. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  1134. uint32_t width = ElementWidth(type);
  1135. if (width != 32 && width != 64) return false;
  1136. const analysis::Constant* const_input1 = ConstInput(constants);
  1137. if (!const_input1) return false;
  1138. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  1139. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  1140. return false;
  1141. if (other_inst->opcode() == SpvOpFSub ||
  1142. other_inst->opcode() == SpvOpISub) {
  1143. std::vector<const analysis::Constant*> other_constants =
  1144. const_mgr->GetOperandConstants(other_inst);
  1145. const analysis::Constant* const_input2 = ConstInput(other_constants);
  1146. if (!const_input2) return false;
  1147. Instruction* non_const_input =
  1148. NonConstInput(context, other_constants[0], other_inst);
  1149. // Merge the constants.
  1150. uint32_t merged_id = 0;
  1151. SpvOp merge_op = inst->opcode();
  1152. if (other_constants[0] == nullptr) {
  1153. merge_op = uses_float ? SpvOpFAdd : SpvOpIAdd;
  1154. } else if (constants[0] == nullptr) {
  1155. std::swap(const_input1, const_input2);
  1156. }
  1157. merged_id =
  1158. PerformOperation(const_mgr, merge_op, const_input1, const_input2);
  1159. if (merged_id == 0) return false;
  1160. SpvOp op = inst->opcode();
  1161. if (constants[0] != nullptr && other_constants[0] != nullptr) {
  1162. // Change the operation.
  1163. op = uses_float ? SpvOpFAdd : SpvOpIAdd;
  1164. }
  1165. uint32_t op1 = 0;
  1166. uint32_t op2 = 0;
  1167. if ((constants[0] == nullptr) ^ (other_constants[0] == nullptr)) {
  1168. op1 = merged_id;
  1169. op2 = non_const_input->result_id();
  1170. } else {
  1171. op1 = non_const_input->result_id();
  1172. op2 = merged_id;
  1173. }
  1174. inst->SetOpcode(op);
  1175. inst->SetInOperands(
  1176. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  1177. return true;
  1178. }
  1179. return false;
  1180. };
  1181. }
  1182. // Helper function for MergeGenericAddSubArithmetic. If |addend| and
  1183. // subtrahend of |sub| is the same, merge to copy of minuend of |sub|.
  1184. bool MergeGenericAddendSub(uint32_t addend, uint32_t sub, Instruction* inst) {
  1185. IRContext* context = inst->context();
  1186. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1187. Instruction* sub_inst = def_use_mgr->GetDef(sub);
  1188. if (sub_inst->opcode() != SpvOpFSub && sub_inst->opcode() != SpvOpISub)
  1189. return false;
  1190. if (sub_inst->opcode() == SpvOpFSub &&
  1191. !sub_inst->IsFloatingPointFoldingAllowed())
  1192. return false;
  1193. if (addend != sub_inst->GetSingleWordInOperand(1)) return false;
  1194. inst->SetOpcode(SpvOpCopyObject);
  1195. inst->SetInOperands(
  1196. {{SPV_OPERAND_TYPE_ID, {sub_inst->GetSingleWordInOperand(0)}}});
  1197. context->UpdateDefUse(inst);
  1198. return true;
  1199. }
  1200. // Folds addition of a subtraction where the subtrahend is equal to the
  1201. // other addend. Return a copy of the minuend. Accepts generic (const and
  1202. // non-const) operands.
  1203. // Cases:
  1204. // (a - b) + b = a
  1205. // b + (a - b) = a
  1206. FoldingRule MergeGenericAddSubArithmetic() {
  1207. return [](IRContext* context, Instruction* inst,
  1208. const std::vector<const analysis::Constant*>&) {
  1209. assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
  1210. const analysis::Type* type =
  1211. context->get_type_mgr()->GetType(inst->type_id());
  1212. bool uses_float = HasFloatingPoint(type);
  1213. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  1214. uint32_t width = ElementWidth(type);
  1215. if (width != 32 && width != 64) return false;
  1216. uint32_t add_op0 = inst->GetSingleWordInOperand(0);
  1217. uint32_t add_op1 = inst->GetSingleWordInOperand(1);
  1218. if (MergeGenericAddendSub(add_op0, add_op1, inst)) return true;
  1219. return MergeGenericAddendSub(add_op1, add_op0, inst);
  1220. };
  1221. }
  1222. // Helper function for FactorAddMuls. If |factor0_0| is the same as |factor1_0|,
  1223. // generate |factor0_0| * (|factor0_1| + |factor1_1|).
  1224. bool FactorAddMulsOpnds(uint32_t factor0_0, uint32_t factor0_1,
  1225. uint32_t factor1_0, uint32_t factor1_1,
  1226. Instruction* inst) {
  1227. IRContext* context = inst->context();
  1228. if (factor0_0 != factor1_0) return false;
  1229. InstructionBuilder ir_builder(
  1230. context, inst,
  1231. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  1232. Instruction* new_add_inst = ir_builder.AddBinaryOp(
  1233. inst->type_id(), inst->opcode(), factor0_1, factor1_1);
  1234. inst->SetOpcode(inst->opcode() == SpvOpFAdd ? SpvOpFMul : SpvOpIMul);
  1235. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {factor0_0}},
  1236. {SPV_OPERAND_TYPE_ID, {new_add_inst->result_id()}}});
  1237. context->UpdateDefUse(inst);
  1238. return true;
  1239. }
  1240. // Perform the following factoring identity, handling all operand order
  1241. // combinations: (a * b) + (a * c) = a * (b + c)
  1242. FoldingRule FactorAddMuls() {
  1243. return [](IRContext* context, Instruction* inst,
  1244. const std::vector<const analysis::Constant*>&) {
  1245. assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
  1246. const analysis::Type* type =
  1247. context->get_type_mgr()->GetType(inst->type_id());
  1248. bool uses_float = HasFloatingPoint(type);
  1249. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  1250. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1251. uint32_t add_op0 = inst->GetSingleWordInOperand(0);
  1252. Instruction* add_op0_inst = def_use_mgr->GetDef(add_op0);
  1253. if (add_op0_inst->opcode() != SpvOpFMul &&
  1254. add_op0_inst->opcode() != SpvOpIMul)
  1255. return false;
  1256. uint32_t add_op1 = inst->GetSingleWordInOperand(1);
  1257. Instruction* add_op1_inst = def_use_mgr->GetDef(add_op1);
  1258. if (add_op1_inst->opcode() != SpvOpFMul &&
  1259. add_op1_inst->opcode() != SpvOpIMul)
  1260. return false;
  1261. // Only perform this optimization if both of the muls only have one use.
  1262. // Otherwise this is a deoptimization in size and performance.
  1263. if (def_use_mgr->NumUses(add_op0_inst) > 1) return false;
  1264. if (def_use_mgr->NumUses(add_op1_inst) > 1) return false;
  1265. if (add_op0_inst->opcode() == SpvOpFMul &&
  1266. (!add_op0_inst->IsFloatingPointFoldingAllowed() ||
  1267. !add_op1_inst->IsFloatingPointFoldingAllowed()))
  1268. return false;
  1269. for (int i = 0; i < 2; i++) {
  1270. for (int j = 0; j < 2; j++) {
  1271. // Check if operand i in add_op0_inst matches operand j in add_op1_inst.
  1272. if (FactorAddMulsOpnds(add_op0_inst->GetSingleWordInOperand(i),
  1273. add_op0_inst->GetSingleWordInOperand(1 - i),
  1274. add_op1_inst->GetSingleWordInOperand(j),
  1275. add_op1_inst->GetSingleWordInOperand(1 - j),
  1276. inst))
  1277. return true;
  1278. }
  1279. }
  1280. return false;
  1281. };
  1282. }
  1283. // Replaces |inst| inplace with an FMA instruction |(x*y)+a|.
  1284. void ReplaceWithFma(Instruction* inst, uint32_t x, uint32_t y, uint32_t a) {
  1285. uint32_t ext =
  1286. inst->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  1287. if (ext == 0) {
  1288. inst->context()->AddExtInstImport("GLSL.std.450");
  1289. ext = inst->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  1290. assert(ext != 0 &&
  1291. "Could not add the GLSL.std.450 extended instruction set");
  1292. }
  1293. std::vector<Operand> operands;
  1294. operands.push_back({SPV_OPERAND_TYPE_ID, {ext}});
  1295. operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {GLSLstd450Fma}});
  1296. operands.push_back({SPV_OPERAND_TYPE_ID, {x}});
  1297. operands.push_back({SPV_OPERAND_TYPE_ID, {y}});
  1298. operands.push_back({SPV_OPERAND_TYPE_ID, {a}});
  1299. inst->SetOpcode(SpvOpExtInst);
  1300. inst->SetInOperands(std::move(operands));
  1301. }
  1302. // Folds a multiple and add into an Fma.
  1303. //
  1304. // Cases:
  1305. // (x * y) + a = Fma x y a
  1306. // a + (x * y) = Fma x y a
  1307. bool MergeMulAddArithmetic(IRContext* context, Instruction* inst,
  1308. const std::vector<const analysis::Constant*>&) {
  1309. assert(inst->opcode() == SpvOpFAdd);
  1310. if (!inst->IsFloatingPointFoldingAllowed()) {
  1311. return false;
  1312. }
  1313. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1314. for (int i = 0; i < 2; i++) {
  1315. uint32_t op_id = inst->GetSingleWordInOperand(i);
  1316. Instruction* op_inst = def_use_mgr->GetDef(op_id);
  1317. if (op_inst->opcode() != SpvOpFMul) {
  1318. continue;
  1319. }
  1320. if (!op_inst->IsFloatingPointFoldingAllowed()) {
  1321. continue;
  1322. }
  1323. uint32_t x = op_inst->GetSingleWordInOperand(0);
  1324. uint32_t y = op_inst->GetSingleWordInOperand(1);
  1325. uint32_t a = inst->GetSingleWordInOperand((i + 1) % 2);
  1326. ReplaceWithFma(inst, x, y, a);
  1327. return true;
  1328. }
  1329. return false;
  1330. }
  1331. // Replaces |sub| inplace with an FMA instruction |(x*y)+a| where |a| first gets
  1332. // negated if |negate_addition| is true, otherwise |x| gets negated.
  1333. void ReplaceWithFmaAndNegate(Instruction* sub, uint32_t x, uint32_t y,
  1334. uint32_t a, bool negate_addition) {
  1335. uint32_t ext =
  1336. sub->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  1337. if (ext == 0) {
  1338. sub->context()->AddExtInstImport("GLSL.std.450");
  1339. ext = sub->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  1340. assert(ext != 0 &&
  1341. "Could not add the GLSL.std.450 extended instruction set");
  1342. }
  1343. InstructionBuilder ir_builder(
  1344. sub->context(), sub,
  1345. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  1346. Instruction* neg = ir_builder.AddUnaryOp(sub->type_id(), SpvOpFNegate,
  1347. negate_addition ? a : x);
  1348. uint32_t neg_op = neg->result_id(); // -a : -x
  1349. std::vector<Operand> operands;
  1350. operands.push_back({SPV_OPERAND_TYPE_ID, {ext}});
  1351. operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {GLSLstd450Fma}});
  1352. operands.push_back({SPV_OPERAND_TYPE_ID, {negate_addition ? x : neg_op}});
  1353. operands.push_back({SPV_OPERAND_TYPE_ID, {y}});
  1354. operands.push_back({SPV_OPERAND_TYPE_ID, {negate_addition ? neg_op : a}});
  1355. sub->SetOpcode(SpvOpExtInst);
  1356. sub->SetInOperands(std::move(operands));
  1357. }
  1358. // Folds a multiply and subtract into an Fma and negation.
  1359. //
  1360. // Cases:
  1361. // (x * y) - a = Fma x y -a
  1362. // a - (x * y) = Fma -x y a
  1363. bool MergeMulSubArithmetic(IRContext* context, Instruction* sub,
  1364. const std::vector<const analysis::Constant*>&) {
  1365. assert(sub->opcode() == SpvOpFSub);
  1366. if (!sub->IsFloatingPointFoldingAllowed()) {
  1367. return false;
  1368. }
  1369. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1370. for (int i = 0; i < 2; i++) {
  1371. uint32_t op_id = sub->GetSingleWordInOperand(i);
  1372. Instruction* mul = def_use_mgr->GetDef(op_id);
  1373. if (mul->opcode() != SpvOpFMul) {
  1374. continue;
  1375. }
  1376. if (!mul->IsFloatingPointFoldingAllowed()) {
  1377. continue;
  1378. }
  1379. uint32_t x = mul->GetSingleWordInOperand(0);
  1380. uint32_t y = mul->GetSingleWordInOperand(1);
  1381. uint32_t a = sub->GetSingleWordInOperand((i + 1) % 2);
  1382. ReplaceWithFmaAndNegate(sub, x, y, a, i == 0);
  1383. return true;
  1384. }
  1385. return false;
  1386. }
  1387. FoldingRule IntMultipleBy1() {
  1388. return [](IRContext*, Instruction* inst,
  1389. const std::vector<const analysis::Constant*>& constants) {
  1390. assert(inst->opcode() == SpvOpIMul && "Wrong opcode. Should be OpIMul.");
  1391. for (uint32_t i = 0; i < 2; i++) {
  1392. if (constants[i] == nullptr) {
  1393. continue;
  1394. }
  1395. const analysis::IntConstant* int_constant = constants[i]->AsIntConstant();
  1396. if (int_constant) {
  1397. uint32_t width = ElementWidth(int_constant->type());
  1398. if (width != 32 && width != 64) return false;
  1399. bool is_one = (width == 32) ? int_constant->GetU32BitValue() == 1u
  1400. : int_constant->GetU64BitValue() == 1ull;
  1401. if (is_one) {
  1402. inst->SetOpcode(SpvOpCopyObject);
  1403. inst->SetInOperands(
  1404. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1 - i)}}});
  1405. return true;
  1406. }
  1407. }
  1408. }
  1409. return false;
  1410. };
  1411. }
  1412. // Returns the number of elements that the |index|th in operand in |inst|
  1413. // contributes to the result of |inst|. |inst| must be an
  1414. // OpCompositeConstructInstruction.
  1415. uint32_t GetNumOfElementsContributedByOperand(IRContext* context,
  1416. const Instruction* inst,
  1417. uint32_t index) {
  1418. assert(inst->opcode() == SpvOpCompositeConstruct);
  1419. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1420. analysis::TypeManager* type_mgr = context->get_type_mgr();
  1421. analysis::Vector* result_type =
  1422. type_mgr->GetType(inst->type_id())->AsVector();
  1423. if (result_type == nullptr) {
  1424. // If the result of the OpCompositeConstruct is not a vector then every
  1425. // operands corresponds to a single element in the result.
  1426. return 1;
  1427. }
  1428. // If the result type is a vector then the operands are either scalars or
  1429. // vectors. If it is a scalar, then it corresponds to a single element. If it
  1430. // is a vector, then each element in the vector will be an element in the
  1431. // result.
  1432. uint32_t id = inst->GetSingleWordInOperand(index);
  1433. Instruction* def = def_use_mgr->GetDef(id);
  1434. analysis::Vector* type = type_mgr->GetType(def->type_id())->AsVector();
  1435. if (type == nullptr) {
  1436. return 1;
  1437. }
  1438. return type->element_count();
  1439. }
  1440. // Returns the in-operands for an OpCompositeExtract instruction that are needed
  1441. // to extract the |result_index|th element in the result of |inst| without using
  1442. // the result of |inst|. Returns the empty vector if |result_index| is
  1443. // out-of-bounds. |inst| must be an |OpCompositeConstruct| instruction.
  1444. std::vector<Operand> GetExtractOperandsForElementOfCompositeConstruct(
  1445. IRContext* context, const Instruction* inst, uint32_t result_index) {
  1446. assert(inst->opcode() == SpvOpCompositeConstruct);
  1447. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1448. analysis::TypeManager* type_mgr = context->get_type_mgr();
  1449. analysis::Type* result_type = type_mgr->GetType(inst->type_id());
  1450. if (result_type->AsVector() == nullptr) {
  1451. uint32_t id = inst->GetSingleWordInOperand(result_index);
  1452. return {Operand(SPV_OPERAND_TYPE_ID, {id})};
  1453. }
  1454. // If the result type is a vector, then vector operands are concatenated.
  1455. uint32_t total_element_count = 0;
  1456. for (uint32_t idx = 0; idx < inst->NumInOperands(); ++idx) {
  1457. uint32_t element_count =
  1458. GetNumOfElementsContributedByOperand(context, inst, idx);
  1459. total_element_count += element_count;
  1460. if (result_index < total_element_count) {
  1461. std::vector<Operand> operands;
  1462. uint32_t id = inst->GetSingleWordInOperand(idx);
  1463. Instruction* operand_def = def_use_mgr->GetDef(id);
  1464. analysis::Type* operand_type = type_mgr->GetType(operand_def->type_id());
  1465. operands.push_back({SPV_OPERAND_TYPE_ID, {id}});
  1466. if (operand_type->AsVector()) {
  1467. uint32_t start_index_of_id = total_element_count - element_count;
  1468. uint32_t index_into_id = result_index - start_index_of_id;
  1469. operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {index_into_id}});
  1470. }
  1471. return operands;
  1472. }
  1473. }
  1474. return {};
  1475. }
  1476. bool CompositeConstructFeedingExtract(
  1477. IRContext* context, Instruction* inst,
  1478. const std::vector<const analysis::Constant*>&) {
  1479. // If the input to an OpCompositeExtract is an OpCompositeConstruct,
  1480. // then we can simply use the appropriate element in the construction.
  1481. assert(inst->opcode() == SpvOpCompositeExtract &&
  1482. "Wrong opcode. Should be OpCompositeExtract.");
  1483. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1484. // If there are no index operands, then this rule cannot do anything.
  1485. if (inst->NumInOperands() <= 1) {
  1486. return false;
  1487. }
  1488. uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1489. Instruction* cinst = def_use_mgr->GetDef(cid);
  1490. if (cinst->opcode() != SpvOpCompositeConstruct) {
  1491. return false;
  1492. }
  1493. uint32_t index_into_result = inst->GetSingleWordInOperand(1);
  1494. std::vector<Operand> operands =
  1495. GetExtractOperandsForElementOfCompositeConstruct(context, cinst,
  1496. index_into_result);
  1497. if (operands.empty()) {
  1498. return false;
  1499. }
  1500. // Add the remaining indices for extraction.
  1501. for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
  1502. operands.push_back(
  1503. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {inst->GetSingleWordInOperand(i)}});
  1504. }
  1505. if (operands.size() == 1) {
  1506. // If there were no extra indices, then we have the final object. No need
  1507. // to extract any more.
  1508. inst->SetOpcode(SpvOpCopyObject);
  1509. }
  1510. inst->SetInOperands(std::move(operands));
  1511. return true;
  1512. }
  1513. // Walks the indexes chain from |start| to |end| of an OpCompositeInsert or
  1514. // OpCompositeExtract instruction, and returns the type of the final element
  1515. // being accessed.
  1516. const analysis::Type* GetElementType(uint32_t type_id,
  1517. Instruction::iterator start,
  1518. Instruction::iterator end,
  1519. const analysis::TypeManager* type_mgr) {
  1520. const analysis::Type* type = type_mgr->GetType(type_id);
  1521. for (auto index : make_range(std::move(start), std::move(end))) {
  1522. assert(index.type == SPV_OPERAND_TYPE_LITERAL_INTEGER &&
  1523. index.words.size() == 1);
  1524. if (auto* array_type = type->AsArray()) {
  1525. type = array_type->element_type();
  1526. } else if (auto* matrix_type = type->AsMatrix()) {
  1527. type = matrix_type->element_type();
  1528. } else if (auto* struct_type = type->AsStruct()) {
  1529. type = struct_type->element_types()[index.words[0]];
  1530. } else {
  1531. type = nullptr;
  1532. }
  1533. }
  1534. return type;
  1535. }
  1536. // Returns true of |inst_1| and |inst_2| have the same indexes that will be used
  1537. // to index into a composite object, excluding the last index. The two
  1538. // instructions must have the same opcode, and be either OpCompositeExtract or
  1539. // OpCompositeInsert instructions.
  1540. bool HaveSameIndexesExceptForLast(Instruction* inst_1, Instruction* inst_2) {
  1541. assert(inst_1->opcode() == inst_2->opcode() &&
  1542. "Expecting the opcodes to be the same.");
  1543. assert((inst_1->opcode() == SpvOpCompositeInsert ||
  1544. inst_1->opcode() == SpvOpCompositeExtract) &&
  1545. "Instructions must be OpCompositeInsert or OpCompositeExtract.");
  1546. if (inst_1->NumInOperands() != inst_2->NumInOperands()) {
  1547. return false;
  1548. }
  1549. uint32_t first_index_position =
  1550. (inst_1->opcode() == SpvOpCompositeInsert ? 2 : 1);
  1551. for (uint32_t i = first_index_position; i < inst_1->NumInOperands() - 1;
  1552. i++) {
  1553. if (inst_1->GetSingleWordInOperand(i) !=
  1554. inst_2->GetSingleWordInOperand(i)) {
  1555. return false;
  1556. }
  1557. }
  1558. return true;
  1559. }
  1560. // If the OpCompositeConstruct is simply putting back together elements that
  1561. // where extracted from the same source, we can simply reuse the source.
  1562. //
  1563. // This is a common code pattern because of the way that scalar replacement
  1564. // works.
  1565. bool CompositeExtractFeedingConstruct(
  1566. IRContext* context, Instruction* inst,
  1567. const std::vector<const analysis::Constant*>&) {
  1568. assert(inst->opcode() == SpvOpCompositeConstruct &&
  1569. "Wrong opcode. Should be OpCompositeConstruct.");
  1570. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1571. uint32_t original_id = 0;
  1572. if (inst->NumInOperands() == 0) {
  1573. // The struct being constructed has no members.
  1574. return false;
  1575. }
  1576. // Check each element to make sure they are:
  1577. // - extractions
  1578. // - extracting the same position they are inserting
  1579. // - all extract from the same id.
  1580. Instruction* first_element_inst = nullptr;
  1581. for (uint32_t i = 0; i < inst->NumInOperands(); ++i) {
  1582. const uint32_t element_id = inst->GetSingleWordInOperand(i);
  1583. Instruction* element_inst = def_use_mgr->GetDef(element_id);
  1584. if (first_element_inst == nullptr) {
  1585. first_element_inst = element_inst;
  1586. }
  1587. if (element_inst->opcode() != SpvOpCompositeExtract) {
  1588. return false;
  1589. }
  1590. if (!HaveSameIndexesExceptForLast(element_inst, first_element_inst)) {
  1591. return false;
  1592. }
  1593. if (element_inst->GetSingleWordInOperand(element_inst->NumInOperands() -
  1594. 1) != i) {
  1595. return false;
  1596. }
  1597. if (i == 0) {
  1598. original_id =
  1599. element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1600. } else if (original_id !=
  1601. element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx)) {
  1602. return false;
  1603. }
  1604. }
  1605. // The last check it to see that the object being extracted from is the
  1606. // correct type.
  1607. Instruction* original_inst = def_use_mgr->GetDef(original_id);
  1608. analysis::TypeManager* type_mgr = context->get_type_mgr();
  1609. const analysis::Type* original_type =
  1610. GetElementType(original_inst->type_id(), first_element_inst->begin() + 3,
  1611. first_element_inst->end() - 1, type_mgr);
  1612. if (original_type == nullptr) {
  1613. return false;
  1614. }
  1615. if (inst->type_id() != type_mgr->GetId(original_type)) {
  1616. return false;
  1617. }
  1618. if (first_element_inst->NumInOperands() == 2) {
  1619. // Simplify by using the original object.
  1620. inst->SetOpcode(SpvOpCopyObject);
  1621. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}});
  1622. return true;
  1623. }
  1624. // Copies the original id and all indexes except for the last to the new
  1625. // extract instruction.
  1626. inst->SetOpcode(SpvOpCompositeExtract);
  1627. inst->SetInOperands(std::vector<Operand>(first_element_inst->begin() + 2,
  1628. first_element_inst->end() - 1));
  1629. return true;
  1630. }
  1631. FoldingRule InsertFeedingExtract() {
  1632. return [](IRContext* context, Instruction* inst,
  1633. const std::vector<const analysis::Constant*>&) {
  1634. assert(inst->opcode() == SpvOpCompositeExtract &&
  1635. "Wrong opcode. Should be OpCompositeExtract.");
  1636. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1637. uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1638. Instruction* cinst = def_use_mgr->GetDef(cid);
  1639. if (cinst->opcode() != SpvOpCompositeInsert) {
  1640. return false;
  1641. }
  1642. // Find the first position where the list of insert and extract indicies
  1643. // differ, if at all.
  1644. uint32_t i;
  1645. for (i = 1; i < inst->NumInOperands(); ++i) {
  1646. if (i + 1 >= cinst->NumInOperands()) {
  1647. break;
  1648. }
  1649. if (inst->GetSingleWordInOperand(i) !=
  1650. cinst->GetSingleWordInOperand(i + 1)) {
  1651. break;
  1652. }
  1653. }
  1654. // We are extracting the element that was inserted.
  1655. if (i == inst->NumInOperands() && i + 1 == cinst->NumInOperands()) {
  1656. inst->SetOpcode(SpvOpCopyObject);
  1657. inst->SetInOperands(
  1658. {{SPV_OPERAND_TYPE_ID,
  1659. {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}}});
  1660. return true;
  1661. }
  1662. // Extracting the value that was inserted along with values for the base
  1663. // composite. Cannot do anything.
  1664. if (i == inst->NumInOperands()) {
  1665. return false;
  1666. }
  1667. // Extracting an element of the value that was inserted. Extract from
  1668. // that value directly.
  1669. if (i + 1 == cinst->NumInOperands()) {
  1670. std::vector<Operand> operands;
  1671. operands.push_back(
  1672. {SPV_OPERAND_TYPE_ID,
  1673. {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}});
  1674. for (; i < inst->NumInOperands(); ++i) {
  1675. operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
  1676. {inst->GetSingleWordInOperand(i)}});
  1677. }
  1678. inst->SetInOperands(std::move(operands));
  1679. return true;
  1680. }
  1681. // Extracting a value that is disjoint from the element being inserted.
  1682. // Rewrite the extract to use the composite input to the insert.
  1683. std::vector<Operand> operands;
  1684. operands.push_back(
  1685. {SPV_OPERAND_TYPE_ID,
  1686. {cinst->GetSingleWordInOperand(kInsertCompositeIdInIdx)}});
  1687. for (i = 1; i < inst->NumInOperands(); ++i) {
  1688. operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
  1689. {inst->GetSingleWordInOperand(i)}});
  1690. }
  1691. inst->SetInOperands(std::move(operands));
  1692. return true;
  1693. };
  1694. }
  1695. // When a VectorShuffle is feeding an Extract, we can extract from one of the
  1696. // operands of the VectorShuffle. We just need to adjust the index in the
  1697. // extract instruction.
  1698. FoldingRule VectorShuffleFeedingExtract() {
  1699. return [](IRContext* context, Instruction* inst,
  1700. const std::vector<const analysis::Constant*>&) {
  1701. assert(inst->opcode() == SpvOpCompositeExtract &&
  1702. "Wrong opcode. Should be OpCompositeExtract.");
  1703. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1704. analysis::TypeManager* type_mgr = context->get_type_mgr();
  1705. uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1706. Instruction* cinst = def_use_mgr->GetDef(cid);
  1707. if (cinst->opcode() != SpvOpVectorShuffle) {
  1708. return false;
  1709. }
  1710. // Find the size of the first vector operand of the VectorShuffle
  1711. Instruction* first_input =
  1712. def_use_mgr->GetDef(cinst->GetSingleWordInOperand(0));
  1713. analysis::Type* first_input_type =
  1714. type_mgr->GetType(first_input->type_id());
  1715. assert(first_input_type->AsVector() &&
  1716. "Input to vector shuffle should be vectors.");
  1717. uint32_t first_input_size = first_input_type->AsVector()->element_count();
  1718. // Get index of the element the vector shuffle is placing in the position
  1719. // being extracted.
  1720. uint32_t new_index =
  1721. cinst->GetSingleWordInOperand(2 + inst->GetSingleWordInOperand(1));
  1722. // Extracting an undefined value so fold this extract into an undef.
  1723. const uint32_t undef_literal_value = 0xffffffff;
  1724. if (new_index == undef_literal_value) {
  1725. inst->SetOpcode(SpvOpUndef);
  1726. inst->SetInOperands({});
  1727. return true;
  1728. }
  1729. // Get the id of the of the vector the elemtent comes from, and update the
  1730. // index if needed.
  1731. uint32_t new_vector = 0;
  1732. if (new_index < first_input_size) {
  1733. new_vector = cinst->GetSingleWordInOperand(0);
  1734. } else {
  1735. new_vector = cinst->GetSingleWordInOperand(1);
  1736. new_index -= first_input_size;
  1737. }
  1738. // Update the extract instruction.
  1739. inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
  1740. inst->SetInOperand(1, {new_index});
  1741. return true;
  1742. };
  1743. }
  1744. // When an FMix with is feeding an Extract that extracts an element whose
  1745. // corresponding |a| in the FMix is 0 or 1, we can extract from one of the
  1746. // operands of the FMix.
  1747. FoldingRule FMixFeedingExtract() {
  1748. return [](IRContext* context, Instruction* inst,
  1749. const std::vector<const analysis::Constant*>&) {
  1750. assert(inst->opcode() == SpvOpCompositeExtract &&
  1751. "Wrong opcode. Should be OpCompositeExtract.");
  1752. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1753. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  1754. uint32_t composite_id =
  1755. inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1756. Instruction* composite_inst = def_use_mgr->GetDef(composite_id);
  1757. if (composite_inst->opcode() != SpvOpExtInst) {
  1758. return false;
  1759. }
  1760. uint32_t inst_set_id =
  1761. context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  1762. if (composite_inst->GetSingleWordInOperand(kExtInstSetIdInIdx) !=
  1763. inst_set_id ||
  1764. composite_inst->GetSingleWordInOperand(kExtInstInstructionInIdx) !=
  1765. GLSLstd450FMix) {
  1766. return false;
  1767. }
  1768. // Get the |a| for the FMix instruction.
  1769. uint32_t a_id = composite_inst->GetSingleWordInOperand(kFMixAIdInIdx);
  1770. std::unique_ptr<Instruction> a(inst->Clone(context));
  1771. a->SetInOperand(kExtractCompositeIdInIdx, {a_id});
  1772. context->get_instruction_folder().FoldInstruction(a.get());
  1773. if (a->opcode() != SpvOpCopyObject) {
  1774. return false;
  1775. }
  1776. const analysis::Constant* a_const =
  1777. const_mgr->FindDeclaredConstant(a->GetSingleWordInOperand(0));
  1778. if (!a_const) {
  1779. return false;
  1780. }
  1781. bool use_x = false;
  1782. assert(a_const->type()->AsFloat());
  1783. double element_value = a_const->GetValueAsDouble();
  1784. if (element_value == 0.0) {
  1785. use_x = true;
  1786. } else if (element_value == 1.0) {
  1787. use_x = false;
  1788. } else {
  1789. return false;
  1790. }
  1791. // Get the id of the of the vector the element comes from.
  1792. uint32_t new_vector = 0;
  1793. if (use_x) {
  1794. new_vector = composite_inst->GetSingleWordInOperand(kFMixXIdInIdx);
  1795. } else {
  1796. new_vector = composite_inst->GetSingleWordInOperand(kFMixYIdInIdx);
  1797. }
  1798. // Update the extract instruction.
  1799. inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
  1800. return true;
  1801. };
  1802. }
  1803. // Returns the number of elements in the composite type |type|. Returns 0 if
  1804. // |type| is a scalar value.
  1805. uint32_t GetNumberOfElements(const analysis::Type* type) {
  1806. if (auto* vector_type = type->AsVector()) {
  1807. return vector_type->element_count();
  1808. }
  1809. if (auto* matrix_type = type->AsMatrix()) {
  1810. return matrix_type->element_count();
  1811. }
  1812. if (auto* struct_type = type->AsStruct()) {
  1813. return static_cast<uint32_t>(struct_type->element_types().size());
  1814. }
  1815. if (auto* array_type = type->AsArray()) {
  1816. return array_type->length_info().words[0];
  1817. }
  1818. return 0;
  1819. }
  1820. // Returns a map with the set of values that were inserted into an object by
  1821. // the chain of OpCompositeInsertInstruction starting with |inst|.
  1822. // The map will map the index to the value inserted at that index.
  1823. std::map<uint32_t, uint32_t> GetInsertedValues(Instruction* inst) {
  1824. analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr();
  1825. std::map<uint32_t, uint32_t> values_inserted;
  1826. Instruction* current_inst = inst;
  1827. while (current_inst->opcode() == SpvOpCompositeInsert) {
  1828. if (current_inst->NumInOperands() > inst->NumInOperands()) {
  1829. // This is the catch the case
  1830. // %2 = OpCompositeInsert %m2x2int %v2int_1_0 %m2x2int_undef 0
  1831. // %3 = OpCompositeInsert %m2x2int %int_4 %2 0 0
  1832. // %4 = OpCompositeInsert %m2x2int %v2int_2_3 %3 1
  1833. // In this case we cannot do a single construct to get the matrix.
  1834. uint32_t partially_inserted_element_index =
  1835. current_inst->GetSingleWordInOperand(inst->NumInOperands() - 1);
  1836. if (values_inserted.count(partially_inserted_element_index) == 0)
  1837. return {};
  1838. }
  1839. if (HaveSameIndexesExceptForLast(inst, current_inst)) {
  1840. values_inserted.insert(
  1841. {current_inst->GetSingleWordInOperand(current_inst->NumInOperands() -
  1842. 1),
  1843. current_inst->GetSingleWordInOperand(kInsertObjectIdInIdx)});
  1844. }
  1845. current_inst = def_use_mgr->GetDef(
  1846. current_inst->GetSingleWordInOperand(kInsertCompositeIdInIdx));
  1847. }
  1848. return values_inserted;
  1849. }
  1850. // Returns true of there is an entry in |values_inserted| for every element of
  1851. // |Type|.
  1852. bool DoInsertedValuesCoverEntireObject(
  1853. const analysis::Type* type, std::map<uint32_t, uint32_t>& values_inserted) {
  1854. uint32_t container_size = GetNumberOfElements(type);
  1855. if (container_size != values_inserted.size()) {
  1856. return false;
  1857. }
  1858. if (values_inserted.rbegin()->first >= container_size) {
  1859. return false;
  1860. }
  1861. return true;
  1862. }
  1863. // Returns the type of the element that immediately contains the element being
  1864. // inserted by the OpCompositeInsert instruction |inst|.
  1865. const analysis::Type* GetContainerType(Instruction* inst) {
  1866. assert(inst->opcode() == SpvOpCompositeInsert);
  1867. analysis::TypeManager* type_mgr = inst->context()->get_type_mgr();
  1868. return GetElementType(inst->type_id(), inst->begin() + 4, inst->end() - 1,
  1869. type_mgr);
  1870. }
  1871. // Returns an OpCompositeConstruct instruction that build an object with
  1872. // |type_id| out of the values in |values_inserted|. Each value will be
  1873. // placed at the index corresponding to the value. The new instruction will
  1874. // be placed before |insert_before|.
  1875. Instruction* BuildCompositeConstruct(
  1876. uint32_t type_id, const std::map<uint32_t, uint32_t>& values_inserted,
  1877. Instruction* insert_before) {
  1878. InstructionBuilder ir_builder(
  1879. insert_before->context(), insert_before,
  1880. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  1881. std::vector<uint32_t> ids_in_order;
  1882. for (auto it : values_inserted) {
  1883. ids_in_order.push_back(it.second);
  1884. }
  1885. Instruction* construct =
  1886. ir_builder.AddCompositeConstruct(type_id, ids_in_order);
  1887. return construct;
  1888. }
  1889. // Replaces the OpCompositeInsert |inst| that inserts |construct| into the same
  1890. // object as |inst| with final index removed. If the resulting
  1891. // OpCompositeInsert instruction would have no remaining indexes, the
  1892. // instruction is replaced with an OpCopyObject instead.
  1893. void InsertConstructedObject(Instruction* inst, const Instruction* construct) {
  1894. if (inst->NumInOperands() == 3) {
  1895. inst->SetOpcode(SpvOpCopyObject);
  1896. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {construct->result_id()}}});
  1897. } else {
  1898. inst->SetInOperand(kInsertObjectIdInIdx, {construct->result_id()});
  1899. inst->RemoveOperand(inst->NumOperands() - 1);
  1900. }
  1901. }
  1902. // Replaces a series of |OpCompositeInsert| instruction that cover the entire
  1903. // object with an |OpCompositeConstruct|.
  1904. bool CompositeInsertToCompositeConstruct(
  1905. IRContext* context, Instruction* inst,
  1906. const std::vector<const analysis::Constant*>&) {
  1907. assert(inst->opcode() == SpvOpCompositeInsert &&
  1908. "Wrong opcode. Should be OpCompositeInsert.");
  1909. if (inst->NumInOperands() < 3) return false;
  1910. std::map<uint32_t, uint32_t> values_inserted = GetInsertedValues(inst);
  1911. const analysis::Type* container_type = GetContainerType(inst);
  1912. if (container_type == nullptr) {
  1913. return false;
  1914. }
  1915. if (!DoInsertedValuesCoverEntireObject(container_type, values_inserted)) {
  1916. return false;
  1917. }
  1918. analysis::TypeManager* type_mgr = context->get_type_mgr();
  1919. Instruction* construct = BuildCompositeConstruct(
  1920. type_mgr->GetId(container_type), values_inserted, inst);
  1921. InsertConstructedObject(inst, construct);
  1922. return true;
  1923. }
  1924. FoldingRule RedundantPhi() {
  1925. // An OpPhi instruction where all values are the same or the result of the phi
  1926. // itself, can be replaced by the value itself.
  1927. return [](IRContext*, Instruction* inst,
  1928. const std::vector<const analysis::Constant*>&) {
  1929. assert(inst->opcode() == SpvOpPhi && "Wrong opcode. Should be OpPhi.");
  1930. uint32_t incoming_value = 0;
  1931. for (uint32_t i = 0; i < inst->NumInOperands(); i += 2) {
  1932. uint32_t op_id = inst->GetSingleWordInOperand(i);
  1933. if (op_id == inst->result_id()) {
  1934. continue;
  1935. }
  1936. if (incoming_value == 0) {
  1937. incoming_value = op_id;
  1938. } else if (op_id != incoming_value) {
  1939. // Found two possible value. Can't simplify.
  1940. return false;
  1941. }
  1942. }
  1943. if (incoming_value == 0) {
  1944. // Code looks invalid. Don't do anything.
  1945. return false;
  1946. }
  1947. // We have a single incoming value. Simplify using that value.
  1948. inst->SetOpcode(SpvOpCopyObject);
  1949. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {incoming_value}}});
  1950. return true;
  1951. };
  1952. }
  1953. FoldingRule BitCastScalarOrVector() {
  1954. return [](IRContext* context, Instruction* inst,
  1955. const std::vector<const analysis::Constant*>& constants) {
  1956. assert(inst->opcode() == SpvOpBitcast && constants.size() == 1);
  1957. if (constants[0] == nullptr) return false;
  1958. const analysis::Type* type =
  1959. context->get_type_mgr()->GetType(inst->type_id());
  1960. if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
  1961. return false;
  1962. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  1963. std::vector<uint32_t> words =
  1964. GetWordsFromNumericScalarOrVectorConstant(const_mgr, constants[0]);
  1965. if (words.size() == 0) return false;
  1966. const analysis::Constant* bitcasted_constant =
  1967. ConvertWordsToNumericScalarOrVectorConstant(const_mgr, words, type);
  1968. if (!bitcasted_constant) return false;
  1969. auto new_feeder_id =
  1970. const_mgr->GetDefiningInstruction(bitcasted_constant, inst->type_id())
  1971. ->result_id();
  1972. inst->SetOpcode(SpvOpCopyObject);
  1973. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {new_feeder_id}}});
  1974. return true;
  1975. };
  1976. }
  1977. FoldingRule RedundantSelect() {
  1978. // An OpSelect instruction where both values are the same or the condition is
  1979. // constant can be replaced by one of the values
  1980. return [](IRContext*, Instruction* inst,
  1981. const std::vector<const analysis::Constant*>& constants) {
  1982. assert(inst->opcode() == SpvOpSelect &&
  1983. "Wrong opcode. Should be OpSelect.");
  1984. assert(inst->NumInOperands() == 3);
  1985. assert(constants.size() == 3);
  1986. uint32_t true_id = inst->GetSingleWordInOperand(1);
  1987. uint32_t false_id = inst->GetSingleWordInOperand(2);
  1988. if (true_id == false_id) {
  1989. // Both results are the same, condition doesn't matter
  1990. inst->SetOpcode(SpvOpCopyObject);
  1991. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
  1992. return true;
  1993. } else if (constants[0]) {
  1994. const analysis::Type* type = constants[0]->type();
  1995. if (type->AsBool()) {
  1996. // Scalar constant value, select the corresponding value.
  1997. inst->SetOpcode(SpvOpCopyObject);
  1998. if (constants[0]->AsNullConstant() ||
  1999. !constants[0]->AsBoolConstant()->value()) {
  2000. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
  2001. } else {
  2002. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
  2003. }
  2004. return true;
  2005. } else {
  2006. assert(type->AsVector());
  2007. if (constants[0]->AsNullConstant()) {
  2008. // All values come from false id.
  2009. inst->SetOpcode(SpvOpCopyObject);
  2010. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
  2011. return true;
  2012. } else {
  2013. // Convert to a vector shuffle.
  2014. std::vector<Operand> ops;
  2015. ops.push_back({SPV_OPERAND_TYPE_ID, {true_id}});
  2016. ops.push_back({SPV_OPERAND_TYPE_ID, {false_id}});
  2017. const analysis::VectorConstant* vector_const =
  2018. constants[0]->AsVectorConstant();
  2019. uint32_t size =
  2020. static_cast<uint32_t>(vector_const->GetComponents().size());
  2021. for (uint32_t i = 0; i != size; ++i) {
  2022. const analysis::Constant* component =
  2023. vector_const->GetComponents()[i];
  2024. if (component->AsNullConstant() ||
  2025. !component->AsBoolConstant()->value()) {
  2026. // Selecting from the false vector which is the second input
  2027. // vector to the shuffle. Offset the index by |size|.
  2028. ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i + size}});
  2029. } else {
  2030. // Selecting from true vector which is the first input vector to
  2031. // the shuffle.
  2032. ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}});
  2033. }
  2034. }
  2035. inst->SetOpcode(SpvOpVectorShuffle);
  2036. inst->SetInOperands(std::move(ops));
  2037. return true;
  2038. }
  2039. }
  2040. }
  2041. return false;
  2042. };
  2043. }
  2044. enum class FloatConstantKind { Unknown, Zero, One };
  2045. FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) {
  2046. if (constant == nullptr) {
  2047. return FloatConstantKind::Unknown;
  2048. }
  2049. assert(HasFloatingPoint(constant->type()) && "Unexpected constant type");
  2050. if (constant->AsNullConstant()) {
  2051. return FloatConstantKind::Zero;
  2052. } else if (const analysis::VectorConstant* vc =
  2053. constant->AsVectorConstant()) {
  2054. const std::vector<const analysis::Constant*>& components =
  2055. vc->GetComponents();
  2056. assert(!components.empty());
  2057. FloatConstantKind kind = getFloatConstantKind(components[0]);
  2058. for (size_t i = 1; i < components.size(); ++i) {
  2059. if (getFloatConstantKind(components[i]) != kind) {
  2060. return FloatConstantKind::Unknown;
  2061. }
  2062. }
  2063. return kind;
  2064. } else if (const analysis::FloatConstant* fc = constant->AsFloatConstant()) {
  2065. if (fc->IsZero()) return FloatConstantKind::Zero;
  2066. uint32_t width = fc->type()->AsFloat()->width();
  2067. if (width != 32 && width != 64) return FloatConstantKind::Unknown;
  2068. double value = (width == 64) ? fc->GetDoubleValue() : fc->GetFloatValue();
  2069. if (value == 0.0) {
  2070. return FloatConstantKind::Zero;
  2071. } else if (value == 1.0) {
  2072. return FloatConstantKind::One;
  2073. } else {
  2074. return FloatConstantKind::Unknown;
  2075. }
  2076. } else {
  2077. return FloatConstantKind::Unknown;
  2078. }
  2079. }
  2080. FoldingRule RedundantFAdd() {
  2081. return [](IRContext*, Instruction* inst,
  2082. const std::vector<const analysis::Constant*>& constants) {
  2083. assert(inst->opcode() == SpvOpFAdd && "Wrong opcode. Should be OpFAdd.");
  2084. assert(constants.size() == 2);
  2085. if (!inst->IsFloatingPointFoldingAllowed()) {
  2086. return false;
  2087. }
  2088. FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
  2089. FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
  2090. if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
  2091. inst->SetOpcode(SpvOpCopyObject);
  2092. inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
  2093. {inst->GetSingleWordInOperand(
  2094. kind0 == FloatConstantKind::Zero ? 1 : 0)}}});
  2095. return true;
  2096. }
  2097. return false;
  2098. };
  2099. }
  2100. FoldingRule RedundantFSub() {
  2101. return [](IRContext*, Instruction* inst,
  2102. const std::vector<const analysis::Constant*>& constants) {
  2103. assert(inst->opcode() == SpvOpFSub && "Wrong opcode. Should be OpFSub.");
  2104. assert(constants.size() == 2);
  2105. if (!inst->IsFloatingPointFoldingAllowed()) {
  2106. return false;
  2107. }
  2108. FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
  2109. FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
  2110. if (kind0 == FloatConstantKind::Zero) {
  2111. inst->SetOpcode(SpvOpFNegate);
  2112. inst->SetInOperands(
  2113. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1)}}});
  2114. return true;
  2115. }
  2116. if (kind1 == FloatConstantKind::Zero) {
  2117. inst->SetOpcode(SpvOpCopyObject);
  2118. inst->SetInOperands(
  2119. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
  2120. return true;
  2121. }
  2122. return false;
  2123. };
  2124. }
  2125. FoldingRule RedundantFMul() {
  2126. return [](IRContext*, Instruction* inst,
  2127. const std::vector<const analysis::Constant*>& constants) {
  2128. assert(inst->opcode() == SpvOpFMul && "Wrong opcode. Should be OpFMul.");
  2129. assert(constants.size() == 2);
  2130. if (!inst->IsFloatingPointFoldingAllowed()) {
  2131. return false;
  2132. }
  2133. FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
  2134. FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
  2135. if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
  2136. inst->SetOpcode(SpvOpCopyObject);
  2137. inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
  2138. {inst->GetSingleWordInOperand(
  2139. kind0 == FloatConstantKind::Zero ? 0 : 1)}}});
  2140. return true;
  2141. }
  2142. if (kind0 == FloatConstantKind::One || kind1 == FloatConstantKind::One) {
  2143. inst->SetOpcode(SpvOpCopyObject);
  2144. inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
  2145. {inst->GetSingleWordInOperand(
  2146. kind0 == FloatConstantKind::One ? 1 : 0)}}});
  2147. return true;
  2148. }
  2149. return false;
  2150. };
  2151. }
  2152. FoldingRule RedundantFDiv() {
  2153. return [](IRContext*, Instruction* inst,
  2154. const std::vector<const analysis::Constant*>& constants) {
  2155. assert(inst->opcode() == SpvOpFDiv && "Wrong opcode. Should be OpFDiv.");
  2156. assert(constants.size() == 2);
  2157. if (!inst->IsFloatingPointFoldingAllowed()) {
  2158. return false;
  2159. }
  2160. FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
  2161. FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
  2162. if (kind0 == FloatConstantKind::Zero) {
  2163. inst->SetOpcode(SpvOpCopyObject);
  2164. inst->SetInOperands(
  2165. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
  2166. return true;
  2167. }
  2168. if (kind1 == FloatConstantKind::One) {
  2169. inst->SetOpcode(SpvOpCopyObject);
  2170. inst->SetInOperands(
  2171. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
  2172. return true;
  2173. }
  2174. return false;
  2175. };
  2176. }
  2177. FoldingRule RedundantFMix() {
  2178. return [](IRContext* context, Instruction* inst,
  2179. const std::vector<const analysis::Constant*>& constants) {
  2180. assert(inst->opcode() == SpvOpExtInst &&
  2181. "Wrong opcode. Should be OpExtInst.");
  2182. if (!inst->IsFloatingPointFoldingAllowed()) {
  2183. return false;
  2184. }
  2185. uint32_t instSetId =
  2186. context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  2187. if (inst->GetSingleWordInOperand(kExtInstSetIdInIdx) == instSetId &&
  2188. inst->GetSingleWordInOperand(kExtInstInstructionInIdx) ==
  2189. GLSLstd450FMix) {
  2190. assert(constants.size() == 5);
  2191. FloatConstantKind kind4 = getFloatConstantKind(constants[4]);
  2192. if (kind4 == FloatConstantKind::Zero || kind4 == FloatConstantKind::One) {
  2193. inst->SetOpcode(SpvOpCopyObject);
  2194. inst->SetInOperands(
  2195. {{SPV_OPERAND_TYPE_ID,
  2196. {inst->GetSingleWordInOperand(kind4 == FloatConstantKind::Zero
  2197. ? kFMixXIdInIdx
  2198. : kFMixYIdInIdx)}}});
  2199. return true;
  2200. }
  2201. }
  2202. return false;
  2203. };
  2204. }
  2205. // This rule handles addition of zero for integers.
  2206. FoldingRule RedundantIAdd() {
  2207. return [](IRContext* context, Instruction* inst,
  2208. const std::vector<const analysis::Constant*>& constants) {
  2209. assert(inst->opcode() == SpvOpIAdd && "Wrong opcode. Should be OpIAdd.");
  2210. uint32_t operand = std::numeric_limits<uint32_t>::max();
  2211. const analysis::Type* operand_type = nullptr;
  2212. if (constants[0] && constants[0]->IsZero()) {
  2213. operand = inst->GetSingleWordInOperand(1);
  2214. operand_type = constants[0]->type();
  2215. } else if (constants[1] && constants[1]->IsZero()) {
  2216. operand = inst->GetSingleWordInOperand(0);
  2217. operand_type = constants[1]->type();
  2218. }
  2219. if (operand != std::numeric_limits<uint32_t>::max()) {
  2220. const analysis::Type* inst_type =
  2221. context->get_type_mgr()->GetType(inst->type_id());
  2222. if (inst_type->IsSame(operand_type)) {
  2223. inst->SetOpcode(SpvOpCopyObject);
  2224. } else {
  2225. inst->SetOpcode(SpvOpBitcast);
  2226. }
  2227. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {operand}}});
  2228. return true;
  2229. }
  2230. return false;
  2231. };
  2232. }
  2233. // This rule look for a dot with a constant vector containing a single 1 and
  2234. // the rest 0s. This is the same as doing an extract.
  2235. FoldingRule DotProductDoingExtract() {
  2236. return [](IRContext* context, Instruction* inst,
  2237. const std::vector<const analysis::Constant*>& constants) {
  2238. assert(inst->opcode() == SpvOpDot && "Wrong opcode. Should be OpDot.");
  2239. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  2240. if (!inst->IsFloatingPointFoldingAllowed()) {
  2241. return false;
  2242. }
  2243. for (int i = 0; i < 2; ++i) {
  2244. if (!constants[i]) {
  2245. continue;
  2246. }
  2247. const analysis::Vector* vector_type = constants[i]->type()->AsVector();
  2248. assert(vector_type && "Inputs to OpDot must be vectors.");
  2249. const analysis::Float* element_type =
  2250. vector_type->element_type()->AsFloat();
  2251. assert(element_type && "Inputs to OpDot must be vectors of floats.");
  2252. uint32_t element_width = element_type->width();
  2253. if (element_width != 32 && element_width != 64) {
  2254. return false;
  2255. }
  2256. std::vector<const analysis::Constant*> components;
  2257. components = constants[i]->GetVectorComponents(const_mgr);
  2258. const uint32_t kNotFound = std::numeric_limits<uint32_t>::max();
  2259. uint32_t component_with_one = kNotFound;
  2260. bool all_others_zero = true;
  2261. for (uint32_t j = 0; j < components.size(); ++j) {
  2262. const analysis::Constant* element = components[j];
  2263. double value =
  2264. (element_width == 32 ? element->GetFloat() : element->GetDouble());
  2265. if (value == 0.0) {
  2266. continue;
  2267. } else if (value == 1.0) {
  2268. if (component_with_one == kNotFound) {
  2269. component_with_one = j;
  2270. } else {
  2271. component_with_one = kNotFound;
  2272. break;
  2273. }
  2274. } else {
  2275. all_others_zero = false;
  2276. break;
  2277. }
  2278. }
  2279. if (!all_others_zero || component_with_one == kNotFound) {
  2280. continue;
  2281. }
  2282. std::vector<Operand> operands;
  2283. operands.push_back(
  2284. {SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1u - i)}});
  2285. operands.push_back(
  2286. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_with_one}});
  2287. inst->SetOpcode(SpvOpCompositeExtract);
  2288. inst->SetInOperands(std::move(operands));
  2289. return true;
  2290. }
  2291. return false;
  2292. };
  2293. }
  2294. // If we are storing an undef, then we can remove the store.
  2295. //
  2296. // TODO: We can do something similar for OpImageWrite, but checking for volatile
  2297. // is complicated. Waiting to see if it is needed.
  2298. FoldingRule StoringUndef() {
  2299. return [](IRContext* context, Instruction* inst,
  2300. const std::vector<const analysis::Constant*>&) {
  2301. assert(inst->opcode() == SpvOpStore && "Wrong opcode. Should be OpStore.");
  2302. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  2303. // If this is a volatile store, the store cannot be removed.
  2304. if (inst->NumInOperands() == 3) {
  2305. if (inst->GetSingleWordInOperand(2) & SpvMemoryAccessVolatileMask) {
  2306. return false;
  2307. }
  2308. }
  2309. uint32_t object_id = inst->GetSingleWordInOperand(kStoreObjectInIdx);
  2310. Instruction* object_inst = def_use_mgr->GetDef(object_id);
  2311. if (object_inst->opcode() == SpvOpUndef) {
  2312. inst->ToNop();
  2313. return true;
  2314. }
  2315. return false;
  2316. };
  2317. }
  2318. FoldingRule VectorShuffleFeedingShuffle() {
  2319. return [](IRContext* context, Instruction* inst,
  2320. const std::vector<const analysis::Constant*>&) {
  2321. assert(inst->opcode() == SpvOpVectorShuffle &&
  2322. "Wrong opcode. Should be OpVectorShuffle.");
  2323. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  2324. analysis::TypeManager* type_mgr = context->get_type_mgr();
  2325. Instruction* feeding_shuffle_inst =
  2326. def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
  2327. analysis::Vector* op0_type =
  2328. type_mgr->GetType(feeding_shuffle_inst->type_id())->AsVector();
  2329. uint32_t op0_length = op0_type->element_count();
  2330. bool feeder_is_op0 = true;
  2331. if (feeding_shuffle_inst->opcode() != SpvOpVectorShuffle) {
  2332. feeding_shuffle_inst =
  2333. def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
  2334. feeder_is_op0 = false;
  2335. }
  2336. if (feeding_shuffle_inst->opcode() != SpvOpVectorShuffle) {
  2337. return false;
  2338. }
  2339. Instruction* feeder2 =
  2340. def_use_mgr->GetDef(feeding_shuffle_inst->GetSingleWordInOperand(0));
  2341. analysis::Vector* feeder_op0_type =
  2342. type_mgr->GetType(feeder2->type_id())->AsVector();
  2343. uint32_t feeder_op0_length = feeder_op0_type->element_count();
  2344. uint32_t new_feeder_id = 0;
  2345. std::vector<Operand> new_operands;
  2346. new_operands.resize(
  2347. 2, {SPV_OPERAND_TYPE_ID, {0}}); // Place holders for vector operands.
  2348. const uint32_t undef_literal = 0xffffffff;
  2349. for (uint32_t op = 2; op < inst->NumInOperands(); ++op) {
  2350. uint32_t component_index = inst->GetSingleWordInOperand(op);
  2351. // Do not interpret the undefined value literal as coming from operand 1.
  2352. if (component_index != undef_literal &&
  2353. feeder_is_op0 == (component_index < op0_length)) {
  2354. // This component comes from the feeding_shuffle_inst. Update
  2355. // |component_index| to be the index into the operand of the feeder.
  2356. // Adjust component_index to get the index into the operands of the
  2357. // feeding_shuffle_inst.
  2358. if (component_index >= op0_length) {
  2359. component_index -= op0_length;
  2360. }
  2361. component_index =
  2362. feeding_shuffle_inst->GetSingleWordInOperand(component_index + 2);
  2363. // Check if we are using a component from the first or second operand of
  2364. // the feeding instruction.
  2365. if (component_index < feeder_op0_length) {
  2366. if (new_feeder_id == 0) {
  2367. // First time through, save the id of the operand the element comes
  2368. // from.
  2369. new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(0);
  2370. } else if (new_feeder_id !=
  2371. feeding_shuffle_inst->GetSingleWordInOperand(0)) {
  2372. // We need both elements of the feeding_shuffle_inst, so we cannot
  2373. // fold.
  2374. return false;
  2375. }
  2376. } else if (component_index != undef_literal) {
  2377. if (new_feeder_id == 0) {
  2378. // First time through, save the id of the operand the element comes
  2379. // from.
  2380. new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(1);
  2381. } else if (new_feeder_id !=
  2382. feeding_shuffle_inst->GetSingleWordInOperand(1)) {
  2383. // We need both elements of the feeding_shuffle_inst, so we cannot
  2384. // fold.
  2385. return false;
  2386. }
  2387. component_index -= feeder_op0_length;
  2388. }
  2389. if (!feeder_is_op0 && component_index != undef_literal) {
  2390. component_index += op0_length;
  2391. }
  2392. }
  2393. new_operands.push_back(
  2394. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_index}});
  2395. }
  2396. if (new_feeder_id == 0) {
  2397. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  2398. const analysis::Type* type =
  2399. type_mgr->GetType(feeding_shuffle_inst->type_id());
  2400. const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
  2401. new_feeder_id =
  2402. const_mgr->GetDefiningInstruction(null_const, 0)->result_id();
  2403. }
  2404. if (feeder_is_op0) {
  2405. // If the size of the first vector operand changed then the indices
  2406. // referring to the second operand need to be adjusted.
  2407. Instruction* new_feeder_inst = def_use_mgr->GetDef(new_feeder_id);
  2408. analysis::Type* new_feeder_type =
  2409. type_mgr->GetType(new_feeder_inst->type_id());
  2410. uint32_t new_op0_size = new_feeder_type->AsVector()->element_count();
  2411. int32_t adjustment = op0_length - new_op0_size;
  2412. if (adjustment != 0) {
  2413. for (uint32_t i = 2; i < new_operands.size(); i++) {
  2414. if (inst->GetSingleWordInOperand(i) >= op0_length) {
  2415. new_operands[i].words[0] -= adjustment;
  2416. }
  2417. }
  2418. }
  2419. new_operands[0].words[0] = new_feeder_id;
  2420. new_operands[1] = inst->GetInOperand(1);
  2421. } else {
  2422. new_operands[1].words[0] = new_feeder_id;
  2423. new_operands[0] = inst->GetInOperand(0);
  2424. }
  2425. inst->SetInOperands(std::move(new_operands));
  2426. return true;
  2427. };
  2428. }
  2429. // Removes duplicate ids from the interface list of an OpEntryPoint
  2430. // instruction.
  2431. FoldingRule RemoveRedundantOperands() {
  2432. return [](IRContext*, Instruction* inst,
  2433. const std::vector<const analysis::Constant*>&) {
  2434. assert(inst->opcode() == SpvOpEntryPoint &&
  2435. "Wrong opcode. Should be OpEntryPoint.");
  2436. bool has_redundant_operand = false;
  2437. std::unordered_set<uint32_t> seen_operands;
  2438. std::vector<Operand> new_operands;
  2439. new_operands.emplace_back(inst->GetOperand(0));
  2440. new_operands.emplace_back(inst->GetOperand(1));
  2441. new_operands.emplace_back(inst->GetOperand(2));
  2442. for (uint32_t i = 3; i < inst->NumOperands(); ++i) {
  2443. if (seen_operands.insert(inst->GetSingleWordOperand(i)).second) {
  2444. new_operands.emplace_back(inst->GetOperand(i));
  2445. } else {
  2446. has_redundant_operand = true;
  2447. }
  2448. }
  2449. if (!has_redundant_operand) {
  2450. return false;
  2451. }
  2452. inst->SetInOperands(std::move(new_operands));
  2453. return true;
  2454. };
  2455. }
  2456. // If an image instruction's operand is a constant, updates the image operand
  2457. // flag from Offset to ConstOffset.
  2458. FoldingRule UpdateImageOperands() {
  2459. return [](IRContext*, Instruction* inst,
  2460. const std::vector<const analysis::Constant*>& constants) {
  2461. const auto opcode = inst->opcode();
  2462. (void)opcode;
  2463. assert((opcode == SpvOpImageSampleImplicitLod ||
  2464. opcode == SpvOpImageSampleExplicitLod ||
  2465. opcode == SpvOpImageSampleDrefImplicitLod ||
  2466. opcode == SpvOpImageSampleDrefExplicitLod ||
  2467. opcode == SpvOpImageSampleProjImplicitLod ||
  2468. opcode == SpvOpImageSampleProjExplicitLod ||
  2469. opcode == SpvOpImageSampleProjDrefImplicitLod ||
  2470. opcode == SpvOpImageSampleProjDrefExplicitLod ||
  2471. opcode == SpvOpImageFetch || opcode == SpvOpImageGather ||
  2472. opcode == SpvOpImageDrefGather || opcode == SpvOpImageRead ||
  2473. opcode == SpvOpImageWrite ||
  2474. opcode == SpvOpImageSparseSampleImplicitLod ||
  2475. opcode == SpvOpImageSparseSampleExplicitLod ||
  2476. opcode == SpvOpImageSparseSampleDrefImplicitLod ||
  2477. opcode == SpvOpImageSparseSampleDrefExplicitLod ||
  2478. opcode == SpvOpImageSparseSampleProjImplicitLod ||
  2479. opcode == SpvOpImageSparseSampleProjExplicitLod ||
  2480. opcode == SpvOpImageSparseSampleProjDrefImplicitLod ||
  2481. opcode == SpvOpImageSparseSampleProjDrefExplicitLod ||
  2482. opcode == SpvOpImageSparseFetch ||
  2483. opcode == SpvOpImageSparseGather ||
  2484. opcode == SpvOpImageSparseDrefGather ||
  2485. opcode == SpvOpImageSparseRead) &&
  2486. "Wrong opcode. Should be an image instruction.");
  2487. int32_t operand_index = ImageOperandsMaskInOperandIndex(inst);
  2488. if (operand_index >= 0) {
  2489. auto image_operands = inst->GetSingleWordInOperand(operand_index);
  2490. if (image_operands & SpvImageOperandsOffsetMask) {
  2491. uint32_t offset_operand_index = operand_index + 1;
  2492. if (image_operands & SpvImageOperandsBiasMask) offset_operand_index++;
  2493. if (image_operands & SpvImageOperandsLodMask) offset_operand_index++;
  2494. if (image_operands & SpvImageOperandsGradMask)
  2495. offset_operand_index += 2;
  2496. assert(((image_operands & SpvImageOperandsConstOffsetMask) == 0) &&
  2497. "Offset and ConstOffset may not be used together");
  2498. if (offset_operand_index < inst->NumOperands()) {
  2499. if (constants[offset_operand_index]) {
  2500. image_operands = image_operands | SpvImageOperandsConstOffsetMask;
  2501. image_operands = image_operands & ~SpvImageOperandsOffsetMask;
  2502. inst->SetInOperand(operand_index, {image_operands});
  2503. return true;
  2504. }
  2505. }
  2506. }
  2507. }
  2508. return false;
  2509. };
  2510. }
  2511. } // namespace
  2512. void FoldingRules::AddFoldingRules() {
  2513. // Add all folding rules to the list for the opcodes to which they apply.
  2514. // Note that the order in which rules are added to the list matters. If a rule
  2515. // applies to the instruction, the rest of the rules will not be attempted.
  2516. // Take that into consideration.
  2517. rules_[SpvOpBitcast].push_back(BitCastScalarOrVector());
  2518. rules_[SpvOpCompositeConstruct].push_back(CompositeExtractFeedingConstruct);
  2519. rules_[SpvOpCompositeExtract].push_back(InsertFeedingExtract());
  2520. rules_[SpvOpCompositeExtract].push_back(CompositeConstructFeedingExtract);
  2521. rules_[SpvOpCompositeExtract].push_back(VectorShuffleFeedingExtract());
  2522. rules_[SpvOpCompositeExtract].push_back(FMixFeedingExtract());
  2523. rules_[SpvOpCompositeInsert].push_back(CompositeInsertToCompositeConstruct);
  2524. rules_[SpvOpDot].push_back(DotProductDoingExtract());
  2525. rules_[SpvOpEntryPoint].push_back(RemoveRedundantOperands());
  2526. rules_[SpvOpFAdd].push_back(RedundantFAdd());
  2527. rules_[SpvOpFAdd].push_back(MergeAddNegateArithmetic());
  2528. rules_[SpvOpFAdd].push_back(MergeAddAddArithmetic());
  2529. rules_[SpvOpFAdd].push_back(MergeAddSubArithmetic());
  2530. rules_[SpvOpFAdd].push_back(MergeGenericAddSubArithmetic());
  2531. rules_[SpvOpFAdd].push_back(FactorAddMuls());
  2532. rules_[SpvOpFAdd].push_back(MergeMulAddArithmetic);
  2533. rules_[SpvOpFDiv].push_back(RedundantFDiv());
  2534. rules_[SpvOpFDiv].push_back(ReciprocalFDiv());
  2535. rules_[SpvOpFDiv].push_back(MergeDivDivArithmetic());
  2536. rules_[SpvOpFDiv].push_back(MergeDivMulArithmetic());
  2537. rules_[SpvOpFDiv].push_back(MergeDivNegateArithmetic());
  2538. rules_[SpvOpFMul].push_back(RedundantFMul());
  2539. rules_[SpvOpFMul].push_back(MergeMulMulArithmetic());
  2540. rules_[SpvOpFMul].push_back(MergeMulDivArithmetic());
  2541. rules_[SpvOpFMul].push_back(MergeMulNegateArithmetic());
  2542. rules_[SpvOpFNegate].push_back(MergeNegateArithmetic());
  2543. rules_[SpvOpFNegate].push_back(MergeNegateAddSubArithmetic());
  2544. rules_[SpvOpFNegate].push_back(MergeNegateMulDivArithmetic());
  2545. rules_[SpvOpFSub].push_back(RedundantFSub());
  2546. rules_[SpvOpFSub].push_back(MergeSubNegateArithmetic());
  2547. rules_[SpvOpFSub].push_back(MergeSubAddArithmetic());
  2548. rules_[SpvOpFSub].push_back(MergeSubSubArithmetic());
  2549. rules_[SpvOpFSub].push_back(MergeMulSubArithmetic);
  2550. rules_[SpvOpIAdd].push_back(RedundantIAdd());
  2551. rules_[SpvOpIAdd].push_back(MergeAddNegateArithmetic());
  2552. rules_[SpvOpIAdd].push_back(MergeAddAddArithmetic());
  2553. rules_[SpvOpIAdd].push_back(MergeAddSubArithmetic());
  2554. rules_[SpvOpIAdd].push_back(MergeGenericAddSubArithmetic());
  2555. rules_[SpvOpIAdd].push_back(FactorAddMuls());
  2556. rules_[SpvOpIMul].push_back(IntMultipleBy1());
  2557. rules_[SpvOpIMul].push_back(MergeMulMulArithmetic());
  2558. rules_[SpvOpIMul].push_back(MergeMulNegateArithmetic());
  2559. rules_[SpvOpISub].push_back(MergeSubNegateArithmetic());
  2560. rules_[SpvOpISub].push_back(MergeSubAddArithmetic());
  2561. rules_[SpvOpISub].push_back(MergeSubSubArithmetic());
  2562. rules_[SpvOpPhi].push_back(RedundantPhi());
  2563. rules_[SpvOpSNegate].push_back(MergeNegateArithmetic());
  2564. rules_[SpvOpSNegate].push_back(MergeNegateMulDivArithmetic());
  2565. rules_[SpvOpSNegate].push_back(MergeNegateAddSubArithmetic());
  2566. rules_[SpvOpSelect].push_back(RedundantSelect());
  2567. rules_[SpvOpStore].push_back(StoringUndef());
  2568. rules_[SpvOpVectorShuffle].push_back(VectorShuffleFeedingShuffle());
  2569. rules_[SpvOpImageSampleImplicitLod].push_back(UpdateImageOperands());
  2570. rules_[SpvOpImageSampleExplicitLod].push_back(UpdateImageOperands());
  2571. rules_[SpvOpImageSampleDrefImplicitLod].push_back(UpdateImageOperands());
  2572. rules_[SpvOpImageSampleDrefExplicitLod].push_back(UpdateImageOperands());
  2573. rules_[SpvOpImageSampleProjImplicitLod].push_back(UpdateImageOperands());
  2574. rules_[SpvOpImageSampleProjExplicitLod].push_back(UpdateImageOperands());
  2575. rules_[SpvOpImageSampleProjDrefImplicitLod].push_back(UpdateImageOperands());
  2576. rules_[SpvOpImageSampleProjDrefExplicitLod].push_back(UpdateImageOperands());
  2577. rules_[SpvOpImageFetch].push_back(UpdateImageOperands());
  2578. rules_[SpvOpImageGather].push_back(UpdateImageOperands());
  2579. rules_[SpvOpImageDrefGather].push_back(UpdateImageOperands());
  2580. rules_[SpvOpImageRead].push_back(UpdateImageOperands());
  2581. rules_[SpvOpImageWrite].push_back(UpdateImageOperands());
  2582. rules_[SpvOpImageSparseSampleImplicitLod].push_back(UpdateImageOperands());
  2583. rules_[SpvOpImageSparseSampleExplicitLod].push_back(UpdateImageOperands());
  2584. rules_[SpvOpImageSparseSampleDrefImplicitLod].push_back(
  2585. UpdateImageOperands());
  2586. rules_[SpvOpImageSparseSampleDrefExplicitLod].push_back(
  2587. UpdateImageOperands());
  2588. rules_[SpvOpImageSparseSampleProjImplicitLod].push_back(
  2589. UpdateImageOperands());
  2590. rules_[SpvOpImageSparseSampleProjExplicitLod].push_back(
  2591. UpdateImageOperands());
  2592. rules_[SpvOpImageSparseSampleProjDrefImplicitLod].push_back(
  2593. UpdateImageOperands());
  2594. rules_[SpvOpImageSparseSampleProjDrefExplicitLod].push_back(
  2595. UpdateImageOperands());
  2596. rules_[SpvOpImageSparseFetch].push_back(UpdateImageOperands());
  2597. rules_[SpvOpImageSparseGather].push_back(UpdateImageOperands());
  2598. rules_[SpvOpImageSparseDrefGather].push_back(UpdateImageOperands());
  2599. rules_[SpvOpImageSparseRead].push_back(UpdateImageOperands());
  2600. FeatureManager* feature_manager = context_->get_feature_mgr();
  2601. // Add rules for GLSLstd450
  2602. uint32_t ext_inst_glslstd450_id =
  2603. feature_manager->GetExtInstImportId_GLSLstd450();
  2604. if (ext_inst_glslstd450_id != 0) {
  2605. ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(
  2606. RedundantFMix());
  2607. }
  2608. }
  2609. } // namespace opt
  2610. } // namespace spvtools