folding_rules.cpp 94 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534
  1. // Copyright (c) 2018 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opt/folding_rules.h"
  15. #include <limits>
  16. #include <memory>
  17. #include <utility>
  18. #include "ir_builder.h"
  19. #include "source/latest_version_glsl_std_450_header.h"
  20. #include "source/opt/ir_context.h"
  21. namespace spvtools {
  22. namespace opt {
  23. namespace {
  // Fixed operand positions used by the folding rules in this file. The "InIdx"
  // suffix suggests these are in-operand indices (i.e. positions as seen by
  // accessors such as Instruction::GetSingleWordInOperand) -- NOTE(review):
  // confirm against the use sites.
  constexpr uint32_t kExtractCompositeIdInIdx = 0;
  constexpr uint32_t kInsertObjectIdInIdx = 0;
  constexpr uint32_t kInsertCompositeIdInIdx = 1;
  constexpr uint32_t kExtInstSetIdInIdx = 0;
  constexpr uint32_t kExtInstInstructionInIdx = 1;
  constexpr uint32_t kFMixXIdInIdx = 2;
  constexpr uint32_t kFMixYIdInIdx = 3;
  constexpr uint32_t kFMixAIdInIdx = 4;
  constexpr uint32_t kStoreObjectInIdx = 1;
  33. // Some image instructions may contain an "image operands" argument.
  34. // Returns the operand index for the "image operands".
  35. // Returns -1 if the instruction does not have image operands.
  36. int32_t ImageOperandsMaskInOperandIndex(Instruction* inst) {
  37. const auto opcode = inst->opcode();
  38. switch (opcode) {
  39. case SpvOpImageSampleImplicitLod:
  40. case SpvOpImageSampleExplicitLod:
  41. case SpvOpImageSampleProjImplicitLod:
  42. case SpvOpImageSampleProjExplicitLod:
  43. case SpvOpImageFetch:
  44. case SpvOpImageRead:
  45. case SpvOpImageSparseSampleImplicitLod:
  46. case SpvOpImageSparseSampleExplicitLod:
  47. case SpvOpImageSparseSampleProjImplicitLod:
  48. case SpvOpImageSparseSampleProjExplicitLod:
  49. case SpvOpImageSparseFetch:
  50. case SpvOpImageSparseRead:
  51. return inst->NumOperands() > 4 ? 2 : -1;
  52. case SpvOpImageSampleDrefImplicitLod:
  53. case SpvOpImageSampleDrefExplicitLod:
  54. case SpvOpImageSampleProjDrefImplicitLod:
  55. case SpvOpImageSampleProjDrefExplicitLod:
  56. case SpvOpImageGather:
  57. case SpvOpImageDrefGather:
  58. case SpvOpImageSparseSampleDrefImplicitLod:
  59. case SpvOpImageSparseSampleDrefExplicitLod:
  60. case SpvOpImageSparseSampleProjDrefImplicitLod:
  61. case SpvOpImageSparseSampleProjDrefExplicitLod:
  62. case SpvOpImageSparseGather:
  63. case SpvOpImageSparseDrefGather:
  64. return inst->NumOperands() > 5 ? 3 : -1;
  65. case SpvOpImageWrite:
  66. return inst->NumOperands() > 3 ? 3 : -1;
  67. default:
  68. return -1;
  69. }
  70. }
  71. // Returns the element width of |type|.
  72. uint32_t ElementWidth(const analysis::Type* type) {
  73. if (const analysis::Vector* vec_type = type->AsVector()) {
  74. return ElementWidth(vec_type->element_type());
  75. } else if (const analysis::Float* float_type = type->AsFloat()) {
  76. return float_type->width();
  77. } else {
  78. assert(type->AsInteger());
  79. return type->AsInteger()->width();
  80. }
  81. }
  82. // Returns true if |type| is Float or a vector of Float.
  83. bool HasFloatingPoint(const analysis::Type* type) {
  84. if (type->AsFloat()) {
  85. return true;
  86. } else if (const analysis::Vector* vec_type = type->AsVector()) {
  87. return vec_type->element_type()->AsFloat() != nullptr;
  88. }
  89. return false;
  90. }
  // Returns true unless |val| classifies as NaN, infinite or subnormal.
  template <typename T>
  bool IsValidResult(T val) {
    const int category = std::fpclassify(val);
    const bool rejected = category == FP_NAN || category == FP_INFINITE ||
                          category == FP_SUBNORMAL;
    return !rejected;
  }
  104. const analysis::Constant* ConstInput(
  105. const std::vector<const analysis::Constant*>& constants) {
  106. return constants[0] ? constants[0] : constants[1];
  107. }
  108. Instruction* NonConstInput(IRContext* context, const analysis::Constant* c,
  109. Instruction* inst) {
  110. uint32_t in_op = c ? 1u : 0u;
  111. return context->get_def_use_mgr()->GetDef(
  112. inst->GetSingleWordInOperand(in_op));
  113. }
  114. // Returns the negation of |c|. |c| must be a 32 or 64 bit floating point
  115. // constant.
  116. uint32_t NegateFloatingPointConstant(analysis::ConstantManager* const_mgr,
  117. const analysis::Constant* c) {
  118. assert(c);
  119. assert(c->type()->AsFloat());
  120. uint32_t width = c->type()->AsFloat()->width();
  121. assert(width == 32 || width == 64);
  122. std::vector<uint32_t> words;
  123. if (width == 64) {
  124. utils::FloatProxy<double> result(c->GetDouble() * -1.0);
  125. words = result.GetWords();
  126. } else {
  127. utils::FloatProxy<float> result(c->GetFloat() * -1.0f);
  128. words = result.GetWords();
  129. }
  130. const analysis::Constant* negated_const =
  131. const_mgr->GetConstant(c->type(), std::move(words));
  132. return const_mgr->GetDefiningInstruction(negated_const)->result_id();
  133. }
  // Splits the 64-bit |val| into two 32-bit words, low-order word first.
  std::vector<uint32_t> ExtractInts(uint64_t val) {
    const uint32_t low_word = static_cast<uint32_t>(val);
    const uint32_t high_word = static_cast<uint32_t>(val >> 32);
    return {low_word, high_word};
  }
  140. // Negates the integer constant |c|. Returns the id of the defining instruction.
  141. uint32_t NegateIntegerConstant(analysis::ConstantManager* const_mgr,
  142. const analysis::Constant* c) {
  143. assert(c);
  144. assert(c->type()->AsInteger());
  145. uint32_t width = c->type()->AsInteger()->width();
  146. assert(width == 32 || width == 64);
  147. std::vector<uint32_t> words;
  148. if (width == 64) {
  149. uint64_t uval = static_cast<uint64_t>(0 - c->GetU64());
  150. words = ExtractInts(uval);
  151. } else {
  152. words.push_back(static_cast<uint32_t>(0 - c->GetU32()));
  153. }
  154. const analysis::Constant* negated_const =
  155. const_mgr->GetConstant(c->type(), std::move(words));
  156. return const_mgr->GetDefiningInstruction(negated_const)->result_id();
  157. }
  158. // Negates the vector constant |c|. Returns the id of the defining instruction.
  159. uint32_t NegateVectorConstant(analysis::ConstantManager* const_mgr,
  160. const analysis::Constant* c) {
  161. assert(const_mgr && c);
  162. assert(c->type()->AsVector());
  163. if (c->AsNullConstant()) {
  164. // 0.0 vs -0.0 shouldn't matter.
  165. return const_mgr->GetDefiningInstruction(c)->result_id();
  166. } else {
  167. const analysis::Type* component_type =
  168. c->AsVectorConstant()->component_type();
  169. std::vector<uint32_t> words;
  170. for (auto& comp : c->AsVectorConstant()->GetComponents()) {
  171. if (component_type->AsFloat()) {
  172. words.push_back(NegateFloatingPointConstant(const_mgr, comp));
  173. } else {
  174. assert(component_type->AsInteger());
  175. words.push_back(NegateIntegerConstant(const_mgr, comp));
  176. }
  177. }
  178. const analysis::Constant* negated_const =
  179. const_mgr->GetConstant(c->type(), std::move(words));
  180. return const_mgr->GetDefiningInstruction(negated_const)->result_id();
  181. }
  182. }
  183. // Negates |c|. Returns the id of the defining instruction.
  184. uint32_t NegateConstant(analysis::ConstantManager* const_mgr,
  185. const analysis::Constant* c) {
  186. if (c->type()->AsVector()) {
  187. return NegateVectorConstant(const_mgr, c);
  188. } else if (c->type()->AsFloat()) {
  189. return NegateFloatingPointConstant(const_mgr, c);
  190. } else {
  191. assert(c->type()->AsInteger());
  192. return NegateIntegerConstant(const_mgr, c);
  193. }
  194. }
  195. // Takes the reciprocal of |c|. |c|'s type must be Float or a vector of Float.
  196. // Returns 0 if the reciprocal is NaN, infinite or subnormal.
  197. uint32_t Reciprocal(analysis::ConstantManager* const_mgr,
  198. const analysis::Constant* c) {
  199. assert(const_mgr && c);
  200. assert(c->type()->AsFloat());
  201. uint32_t width = c->type()->AsFloat()->width();
  202. assert(width == 32 || width == 64);
  203. std::vector<uint32_t> words;
  204. if (width == 64) {
  205. spvtools::utils::FloatProxy<double> result(1.0 / c->GetDouble());
  206. if (!IsValidResult(result.getAsFloat())) return 0;
  207. words = result.GetWords();
  208. } else {
  209. spvtools::utils::FloatProxy<float> result(1.0f / c->GetFloat());
  210. if (!IsValidResult(result.getAsFloat())) return 0;
  211. words = result.GetWords();
  212. }
  213. const analysis::Constant* negated_const =
  214. const_mgr->GetConstant(c->type(), std::move(words));
  215. return const_mgr->GetDefiningInstruction(negated_const)->result_id();
  216. }
  217. // Replaces fdiv where second operand is constant with fmul.
  218. FoldingRule ReciprocalFDiv() {
  219. return [](IRContext* context, Instruction* inst,
  220. const std::vector<const analysis::Constant*>& constants) {
  221. assert(inst->opcode() == SpvOpFDiv);
  222. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  223. const analysis::Type* type =
  224. context->get_type_mgr()->GetType(inst->type_id());
  225. if (!inst->IsFloatingPointFoldingAllowed()) return false;
  226. uint32_t width = ElementWidth(type);
  227. if (width != 32 && width != 64) return false;
  228. if (constants[1] != nullptr) {
  229. uint32_t id = 0;
  230. if (const analysis::VectorConstant* vector_const =
  231. constants[1]->AsVectorConstant()) {
  232. std::vector<uint32_t> neg_ids;
  233. for (auto& comp : vector_const->GetComponents()) {
  234. id = Reciprocal(const_mgr, comp);
  235. if (id == 0) return false;
  236. neg_ids.push_back(id);
  237. }
  238. const analysis::Constant* negated_const =
  239. const_mgr->GetConstant(constants[1]->type(), std::move(neg_ids));
  240. id = const_mgr->GetDefiningInstruction(negated_const)->result_id();
  241. } else if (constants[1]->AsFloatConstant()) {
  242. id = Reciprocal(const_mgr, constants[1]);
  243. if (id == 0) return false;
  244. } else {
  245. // Don't fold a null constant.
  246. return false;
  247. }
  248. inst->SetOpcode(SpvOpFMul);
  249. inst->SetInOperands(
  250. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0u)}},
  251. {SPV_OPERAND_TYPE_ID, {id}}});
  252. return true;
  253. }
  254. return false;
  255. };
  256. }
  257. // Elides consecutive negate instructions.
  258. FoldingRule MergeNegateArithmetic() {
  259. return [](IRContext* context, Instruction* inst,
  260. const std::vector<const analysis::Constant*>& constants) {
  261. assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
  262. (void)constants;
  263. const analysis::Type* type =
  264. context->get_type_mgr()->GetType(inst->type_id());
  265. if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
  266. return false;
  267. Instruction* op_inst =
  268. context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
  269. if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
  270. return false;
  271. if (op_inst->opcode() == inst->opcode()) {
  272. // Elide negates.
  273. inst->SetOpcode(SpvOpCopyObject);
  274. inst->SetInOperands(
  275. {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0u)}}});
  276. return true;
  277. }
  278. return false;
  279. };
  280. }
  281. // Merges negate into a mul or div operation if that operation contains a
  282. // constant operand.
  283. // Cases:
  284. // -(x * 2) = x * -2
  285. // -(2 * x) = x * -2
  286. // -(x / 2) = x / -2
  287. // -(2 / x) = -2 / x
  288. FoldingRule MergeNegateMulDivArithmetic() {
  289. return [](IRContext* context, Instruction* inst,
  290. const std::vector<const analysis::Constant*>& constants) {
  291. assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
  292. (void)constants;
  293. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  294. const analysis::Type* type =
  295. context->get_type_mgr()->GetType(inst->type_id());
  296. if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
  297. return false;
  298. Instruction* op_inst =
  299. context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
  300. if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
  301. return false;
  302. uint32_t width = ElementWidth(type);
  303. if (width != 32 && width != 64) return false;
  304. SpvOp opcode = op_inst->opcode();
  305. if (opcode == SpvOpFMul || opcode == SpvOpFDiv || opcode == SpvOpIMul ||
  306. opcode == SpvOpSDiv || opcode == SpvOpUDiv) {
  307. std::vector<const analysis::Constant*> op_constants =
  308. const_mgr->GetOperandConstants(op_inst);
  309. // Merge negate into mul or div if one operand is constant.
  310. if (op_constants[0] || op_constants[1]) {
  311. bool zero_is_variable = op_constants[0] == nullptr;
  312. const analysis::Constant* c = ConstInput(op_constants);
  313. uint32_t neg_id = NegateConstant(const_mgr, c);
  314. uint32_t non_const_id = zero_is_variable
  315. ? op_inst->GetSingleWordInOperand(0u)
  316. : op_inst->GetSingleWordInOperand(1u);
  317. // Change this instruction to a mul/div.
  318. inst->SetOpcode(op_inst->opcode());
  319. if (opcode == SpvOpFDiv || opcode == SpvOpUDiv || opcode == SpvOpSDiv) {
  320. uint32_t op0 = zero_is_variable ? non_const_id : neg_id;
  321. uint32_t op1 = zero_is_variable ? neg_id : non_const_id;
  322. inst->SetInOperands(
  323. {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
  324. } else {
  325. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
  326. {SPV_OPERAND_TYPE_ID, {neg_id}}});
  327. }
  328. return true;
  329. }
  330. }
  331. return false;
  332. };
  333. }
  334. // Merges negate into a add or sub operation if that operation contains a
  335. // constant operand.
  336. // Cases:
  337. // -(x + 2) = -2 - x
  338. // -(2 + x) = -2 - x
  339. // -(x - 2) = 2 - x
  340. // -(2 - x) = x - 2
  341. FoldingRule MergeNegateAddSubArithmetic() {
  342. return [](IRContext* context, Instruction* inst,
  343. const std::vector<const analysis::Constant*>& constants) {
  344. assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
  345. (void)constants;
  346. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  347. const analysis::Type* type =
  348. context->get_type_mgr()->GetType(inst->type_id());
  349. if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
  350. return false;
  351. Instruction* op_inst =
  352. context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
  353. if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
  354. return false;
  355. uint32_t width = ElementWidth(type);
  356. if (width != 32 && width != 64) return false;
  357. if (op_inst->opcode() == SpvOpFAdd || op_inst->opcode() == SpvOpFSub ||
  358. op_inst->opcode() == SpvOpIAdd || op_inst->opcode() == SpvOpISub) {
  359. std::vector<const analysis::Constant*> op_constants =
  360. const_mgr->GetOperandConstants(op_inst);
  361. if (op_constants[0] || op_constants[1]) {
  362. bool zero_is_variable = op_constants[0] == nullptr;
  363. bool is_add = (op_inst->opcode() == SpvOpFAdd) ||
  364. (op_inst->opcode() == SpvOpIAdd);
  365. bool swap_operands = !is_add || zero_is_variable;
  366. bool negate_const = is_add;
  367. const analysis::Constant* c = ConstInput(op_constants);
  368. uint32_t const_id = 0;
  369. if (negate_const) {
  370. const_id = NegateConstant(const_mgr, c);
  371. } else {
  372. const_id = zero_is_variable ? op_inst->GetSingleWordInOperand(1u)
  373. : op_inst->GetSingleWordInOperand(0u);
  374. }
  375. // Swap operands if necessary and make the instruction a subtraction.
  376. uint32_t op0 =
  377. zero_is_variable ? op_inst->GetSingleWordInOperand(0u) : const_id;
  378. uint32_t op1 =
  379. zero_is_variable ? const_id : op_inst->GetSingleWordInOperand(1u);
  380. if (swap_operands) std::swap(op0, op1);
  381. inst->SetOpcode(HasFloatingPoint(type) ? SpvOpFSub : SpvOpISub);
  382. inst->SetInOperands(
  383. {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
  384. return true;
  385. }
  386. }
  387. return false;
  388. };
  389. }
  390. // Returns true if |c| has a zero element.
  391. bool HasZero(const analysis::Constant* c) {
  392. if (c->AsNullConstant()) {
  393. return true;
  394. }
  395. if (const analysis::VectorConstant* vec_const = c->AsVectorConstant()) {
  396. for (auto& comp : vec_const->GetComponents())
  397. if (HasZero(comp)) return true;
  398. } else {
  399. assert(c->AsScalarConstant());
  400. return c->AsScalarConstant()->IsZero();
  401. }
  402. return false;
  403. }
  404. // Performs |input1| |opcode| |input2| and returns the merged constant result
  405. // id. Returns 0 if the result is not a valid value. The input types must be
  406. // Float.
  407. uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr,
  408. SpvOp opcode,
  409. const analysis::Constant* input1,
  410. const analysis::Constant* input2) {
  411. const analysis::Type* type = input1->type();
  412. assert(type->AsFloat());
  413. uint32_t width = type->AsFloat()->width();
  414. assert(width == 32 || width == 64);
  415. std::vector<uint32_t> words;
  416. #define FOLD_OP(op) \
  417. if (width == 64) { \
  418. utils::FloatProxy<double> val = \
  419. input1->GetDouble() op input2->GetDouble(); \
  420. double dval = val.getAsFloat(); \
  421. if (!IsValidResult(dval)) return 0; \
  422. words = val.GetWords(); \
  423. } else { \
  424. utils::FloatProxy<float> val = input1->GetFloat() op input2->GetFloat(); \
  425. float fval = val.getAsFloat(); \
  426. if (!IsValidResult(fval)) return 0; \
  427. words = val.GetWords(); \
  428. }
  429. switch (opcode) {
  430. case SpvOpFMul:
  431. FOLD_OP(*);
  432. break;
  433. case SpvOpFDiv:
  434. if (HasZero(input2)) return 0;
  435. FOLD_OP(/);
  436. break;
  437. case SpvOpFAdd:
  438. FOLD_OP(+);
  439. break;
  440. case SpvOpFSub:
  441. FOLD_OP(-);
  442. break;
  443. default:
  444. assert(false && "Unexpected operation");
  445. break;
  446. }
  447. #undef FOLD_OP
  448. const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
  449. return const_mgr->GetDefiningInstruction(merged_const)->result_id();
  450. }
  451. // Performs |input1| |opcode| |input2| and returns the merged constant result
  452. // id. Returns 0 if the result is not a valid value. The input types must be
  453. // Integers.
  454. uint32_t PerformIntegerOperation(analysis::ConstantManager* const_mgr,
  455. SpvOp opcode, const analysis::Constant* input1,
  456. const analysis::Constant* input2) {
  457. assert(input1->type()->AsInteger());
  458. const analysis::Integer* type = input1->type()->AsInteger();
  459. uint32_t width = type->AsInteger()->width();
  460. assert(width == 32 || width == 64);
  461. std::vector<uint32_t> words;
  462. #define FOLD_OP(op) \
  463. if (width == 64) { \
  464. if (type->IsSigned()) { \
  465. int64_t val = input1->GetS64() op input2->GetS64(); \
  466. words = ExtractInts(static_cast<uint64_t>(val)); \
  467. } else { \
  468. uint64_t val = input1->GetU64() op input2->GetU64(); \
  469. words = ExtractInts(val); \
  470. } \
  471. } else { \
  472. if (type->IsSigned()) { \
  473. int32_t val = input1->GetS32() op input2->GetS32(); \
  474. words.push_back(static_cast<uint32_t>(val)); \
  475. } else { \
  476. uint32_t val = input1->GetU32() op input2->GetU32(); \
  477. words.push_back(val); \
  478. } \
  479. }
  480. switch (opcode) {
  481. case SpvOpIMul:
  482. FOLD_OP(*);
  483. break;
  484. case SpvOpSDiv:
  485. case SpvOpUDiv:
  486. assert(false && "Should not merge integer division");
  487. break;
  488. case SpvOpIAdd:
  489. FOLD_OP(+);
  490. break;
  491. case SpvOpISub:
  492. FOLD_OP(-);
  493. break;
  494. default:
  495. assert(false && "Unexpected operation");
  496. break;
  497. }
  498. #undef FOLD_OP
  499. const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
  500. return const_mgr->GetDefiningInstruction(merged_const)->result_id();
  501. }
  502. // Performs |input1| |opcode| |input2| and returns the merged constant result
  503. // id. Returns 0 if the result is not a valid value. The input types must be
  504. // Integers, Floats or Vectors of such.
  505. uint32_t PerformOperation(analysis::ConstantManager* const_mgr, SpvOp opcode,
  506. const analysis::Constant* input1,
  507. const analysis::Constant* input2) {
  508. assert(input1 && input2);
  509. assert(input1->type() == input2->type());
  510. const analysis::Type* type = input1->type();
  511. std::vector<uint32_t> words;
  512. if (const analysis::Vector* vector_type = type->AsVector()) {
  513. const analysis::Type* ele_type = vector_type->element_type();
  514. for (uint32_t i = 0; i != vector_type->element_count(); ++i) {
  515. uint32_t id = 0;
  516. const analysis::Constant* input1_comp = nullptr;
  517. if (const analysis::VectorConstant* input1_vector =
  518. input1->AsVectorConstant()) {
  519. input1_comp = input1_vector->GetComponents()[i];
  520. } else {
  521. assert(input1->AsNullConstant());
  522. input1_comp = const_mgr->GetConstant(ele_type, {});
  523. }
  524. const analysis::Constant* input2_comp = nullptr;
  525. if (const analysis::VectorConstant* input2_vector =
  526. input2->AsVectorConstant()) {
  527. input2_comp = input2_vector->GetComponents()[i];
  528. } else {
  529. assert(input2->AsNullConstant());
  530. input2_comp = const_mgr->GetConstant(ele_type, {});
  531. }
  532. if (ele_type->AsFloat()) {
  533. id = PerformFloatingPointOperation(const_mgr, opcode, input1_comp,
  534. input2_comp);
  535. } else {
  536. assert(ele_type->AsInteger());
  537. id = PerformIntegerOperation(const_mgr, opcode, input1_comp,
  538. input2_comp);
  539. }
  540. if (id == 0) return 0;
  541. words.push_back(id);
  542. }
  543. const analysis::Constant* merged_const =
  544. const_mgr->GetConstant(type, words);
  545. return const_mgr->GetDefiningInstruction(merged_const)->result_id();
  546. } else if (type->AsFloat()) {
  547. return PerformFloatingPointOperation(const_mgr, opcode, input1, input2);
  548. } else {
  549. assert(type->AsInteger());
  550. return PerformIntegerOperation(const_mgr, opcode, input1, input2);
  551. }
  552. }
  553. // Merges consecutive multiplies where each contains one constant operand.
  554. // Cases:
  555. // 2 * (x * 2) = x * 4
  556. // 2 * (2 * x) = x * 4
  557. // (x * 2) * 2 = x * 4
  558. // (2 * x) * 2 = x * 4
  559. FoldingRule MergeMulMulArithmetic() {
  560. return [](IRContext* context, Instruction* inst,
  561. const std::vector<const analysis::Constant*>& constants) {
  562. assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
  563. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  564. const analysis::Type* type =
  565. context->get_type_mgr()->GetType(inst->type_id());
  566. if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
  567. return false;
  568. uint32_t width = ElementWidth(type);
  569. if (width != 32 && width != 64) return false;
  570. // Determine the constant input and the variable input in |inst|.
  571. const analysis::Constant* const_input1 = ConstInput(constants);
  572. if (!const_input1) return false;
  573. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  574. if (HasFloatingPoint(type) && !other_inst->IsFloatingPointFoldingAllowed())
  575. return false;
  576. if (other_inst->opcode() == inst->opcode()) {
  577. std::vector<const analysis::Constant*> other_constants =
  578. const_mgr->GetOperandConstants(other_inst);
  579. const analysis::Constant* const_input2 = ConstInput(other_constants);
  580. if (!const_input2) return false;
  581. bool other_first_is_variable = other_constants[0] == nullptr;
  582. uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
  583. const_input1, const_input2);
  584. if (merged_id == 0) return false;
  585. uint32_t non_const_id = other_first_is_variable
  586. ? other_inst->GetSingleWordInOperand(0u)
  587. : other_inst->GetSingleWordInOperand(1u);
  588. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
  589. {SPV_OPERAND_TYPE_ID, {merged_id}}});
  590. return true;
  591. }
  592. return false;
  593. };
  594. }
  595. // Merges divides into subsequent multiplies if each instruction contains one
  596. // constant operand. Does not support integer operations.
  597. // Cases:
  598. // 2 * (x / 2) = x * 1
  599. // 2 * (2 / x) = 4 / x
  600. // (x / 2) * 2 = x * 1
  601. // (2 / x) * 2 = 4 / x
  602. // (y / x) * x = y
  603. // x * (y / x) = y
  604. FoldingRule MergeMulDivArithmetic() {
  605. return [](IRContext* context, Instruction* inst,
  606. const std::vector<const analysis::Constant*>& constants) {
  607. assert(inst->opcode() == SpvOpFMul);
  608. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  609. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  610. const analysis::Type* type =
  611. context->get_type_mgr()->GetType(inst->type_id());
  612. if (!inst->IsFloatingPointFoldingAllowed()) return false;
  613. uint32_t width = ElementWidth(type);
  614. if (width != 32 && width != 64) return false;
  615. for (uint32_t i = 0; i < 2; i++) {
  616. uint32_t op_id = inst->GetSingleWordInOperand(i);
  617. Instruction* op_inst = def_use_mgr->GetDef(op_id);
  618. if (op_inst->opcode() == SpvOpFDiv) {
  619. if (op_inst->GetSingleWordInOperand(1) ==
  620. inst->GetSingleWordInOperand(1 - i)) {
  621. inst->SetOpcode(SpvOpCopyObject);
  622. inst->SetInOperands(
  623. {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0)}}});
  624. return true;
  625. }
  626. }
  627. }
  628. const analysis::Constant* const_input1 = ConstInput(constants);
  629. if (!const_input1) return false;
  630. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  631. if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
  632. if (other_inst->opcode() == SpvOpFDiv) {
  633. std::vector<const analysis::Constant*> other_constants =
  634. const_mgr->GetOperandConstants(other_inst);
  635. const analysis::Constant* const_input2 = ConstInput(other_constants);
  636. if (!const_input2 || HasZero(const_input2)) return false;
  637. bool other_first_is_variable = other_constants[0] == nullptr;
  638. // If the variable value is the second operand of the divide, multiply
  639. // the constants together. Otherwise divide the constants.
  640. uint32_t merged_id = PerformOperation(
  641. const_mgr,
  642. other_first_is_variable ? other_inst->opcode() : inst->opcode(),
  643. const_input1, const_input2);
  644. if (merged_id == 0) return false;
  645. uint32_t non_const_id = other_first_is_variable
  646. ? other_inst->GetSingleWordInOperand(0u)
  647. : other_inst->GetSingleWordInOperand(1u);
  648. // If the variable value is on the second operand of the div, then this
  649. // operation is a div. Otherwise it should be a multiply.
  650. inst->SetOpcode(other_first_is_variable ? inst->opcode()
  651. : other_inst->opcode());
  652. if (other_first_is_variable) {
  653. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
  654. {SPV_OPERAND_TYPE_ID, {merged_id}}});
  655. } else {
  656. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {merged_id}},
  657. {SPV_OPERAND_TYPE_ID, {non_const_id}}});
  658. }
  659. return true;
  660. }
  661. return false;
  662. };
  663. }
  664. // Merges multiply of constant and negation.
  665. // Cases:
  666. // (-x) * 2 = x * -2
  667. // 2 * (-x) = x * -2
  668. FoldingRule MergeMulNegateArithmetic() {
  669. return [](IRContext* context, Instruction* inst,
  670. const std::vector<const analysis::Constant*>& constants) {
  671. assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
  672. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  673. const analysis::Type* type =
  674. context->get_type_mgr()->GetType(inst->type_id());
  675. bool uses_float = HasFloatingPoint(type);
  676. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  677. uint32_t width = ElementWidth(type);
  678. if (width != 32 && width != 64) return false;
  679. const analysis::Constant* const_input1 = ConstInput(constants);
  680. if (!const_input1) return false;
  681. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  682. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  683. return false;
  684. if (other_inst->opcode() == SpvOpFNegate ||
  685. other_inst->opcode() == SpvOpSNegate) {
  686. uint32_t neg_id = NegateConstant(const_mgr, const_input1);
  687. inst->SetInOperands(
  688. {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
  689. {SPV_OPERAND_TYPE_ID, {neg_id}}});
  690. return true;
  691. }
  692. return false;
  693. };
  694. }
  695. // Merges consecutive divides if each instruction contains one constant operand.
  696. // Does not support integer division.
  697. // Cases:
  698. // 2 / (x / 2) = 4 / x
  699. // 4 / (2 / x) = 2 * x
  700. // (4 / x) / 2 = 2 / x
  701. // (x / 2) / 2 = x / 4
  702. FoldingRule MergeDivDivArithmetic() {
  703. return [](IRContext* context, Instruction* inst,
  704. const std::vector<const analysis::Constant*>& constants) {
  705. assert(inst->opcode() == SpvOpFDiv);
  706. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  707. const analysis::Type* type =
  708. context->get_type_mgr()->GetType(inst->type_id());
  709. if (!inst->IsFloatingPointFoldingAllowed()) return false;
  710. uint32_t width = ElementWidth(type);
  711. if (width != 32 && width != 64) return false;
  712. const analysis::Constant* const_input1 = ConstInput(constants);
  713. if (!const_input1 || HasZero(const_input1)) return false;
  714. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  715. if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
  716. bool first_is_variable = constants[0] == nullptr;
  717. if (other_inst->opcode() == inst->opcode()) {
  718. std::vector<const analysis::Constant*> other_constants =
  719. const_mgr->GetOperandConstants(other_inst);
  720. const analysis::Constant* const_input2 = ConstInput(other_constants);
  721. if (!const_input2 || HasZero(const_input2)) return false;
  722. bool other_first_is_variable = other_constants[0] == nullptr;
  723. SpvOp merge_op = inst->opcode();
  724. if (other_first_is_variable) {
  725. // Constants magnify.
  726. merge_op = SpvOpFMul;
  727. }
  728. // This is an x / (*) case. Swap the inputs. Doesn't harm multiply
  729. // because it is commutative.
  730. if (first_is_variable) std::swap(const_input1, const_input2);
  731. uint32_t merged_id =
  732. PerformOperation(const_mgr, merge_op, const_input1, const_input2);
  733. if (merged_id == 0) return false;
  734. uint32_t non_const_id = other_first_is_variable
  735. ? other_inst->GetSingleWordInOperand(0u)
  736. : other_inst->GetSingleWordInOperand(1u);
  737. SpvOp op = inst->opcode();
  738. if (!first_is_variable && !other_first_is_variable) {
  739. // Effectively div of 1/x, so change to multiply.
  740. op = SpvOpFMul;
  741. }
  742. uint32_t op1 = merged_id;
  743. uint32_t op2 = non_const_id;
  744. if (first_is_variable && other_first_is_variable) std::swap(op1, op2);
  745. inst->SetOpcode(op);
  746. inst->SetInOperands(
  747. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  748. return true;
  749. }
  750. return false;
  751. };
  752. }
  753. // Fold multiplies succeeded by divides where each instruction contains a
  754. // constant operand. Does not support integer divide.
  755. // Cases:
  756. // 4 / (x * 2) = 2 / x
  757. // 4 / (2 * x) = 2 / x
  758. // (x * 4) / 2 = x * 2
  759. // (4 * x) / 2 = x * 2
  760. // (x * y) / x = y
  761. // (y * x) / x = y
  762. FoldingRule MergeDivMulArithmetic() {
  763. return [](IRContext* context, Instruction* inst,
  764. const std::vector<const analysis::Constant*>& constants) {
  765. assert(inst->opcode() == SpvOpFDiv);
  766. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  767. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  768. const analysis::Type* type =
  769. context->get_type_mgr()->GetType(inst->type_id());
  770. if (!inst->IsFloatingPointFoldingAllowed()) return false;
  771. uint32_t width = ElementWidth(type);
  772. if (width != 32 && width != 64) return false;
  773. uint32_t op_id = inst->GetSingleWordInOperand(0);
  774. Instruction* op_inst = def_use_mgr->GetDef(op_id);
  775. if (op_inst->opcode() == SpvOpFMul) {
  776. for (uint32_t i = 0; i < 2; i++) {
  777. if (op_inst->GetSingleWordInOperand(i) ==
  778. inst->GetSingleWordInOperand(1)) {
  779. inst->SetOpcode(SpvOpCopyObject);
  780. inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
  781. {op_inst->GetSingleWordInOperand(1 - i)}}});
  782. return true;
  783. }
  784. }
  785. }
  786. const analysis::Constant* const_input1 = ConstInput(constants);
  787. if (!const_input1 || HasZero(const_input1)) return false;
  788. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  789. if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
  790. bool first_is_variable = constants[0] == nullptr;
  791. if (other_inst->opcode() == SpvOpFMul) {
  792. std::vector<const analysis::Constant*> other_constants =
  793. const_mgr->GetOperandConstants(other_inst);
  794. const analysis::Constant* const_input2 = ConstInput(other_constants);
  795. if (!const_input2) return false;
  796. bool other_first_is_variable = other_constants[0] == nullptr;
  797. // This is an x / (*) case. Swap the inputs.
  798. if (first_is_variable) std::swap(const_input1, const_input2);
  799. uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
  800. const_input1, const_input2);
  801. if (merged_id == 0) return false;
  802. uint32_t non_const_id = other_first_is_variable
  803. ? other_inst->GetSingleWordInOperand(0u)
  804. : other_inst->GetSingleWordInOperand(1u);
  805. uint32_t op1 = merged_id;
  806. uint32_t op2 = non_const_id;
  807. if (first_is_variable) std::swap(op1, op2);
  808. // Convert to multiply
  809. if (first_is_variable) inst->SetOpcode(other_inst->opcode());
  810. inst->SetInOperands(
  811. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  812. return true;
  813. }
  814. return false;
  815. };
  816. }
  817. // Fold divides of a constant and a negation.
  818. // Cases:
  819. // (-x) / 2 = x / -2
  820. // 2 / (-x) = 2 / -x
  821. FoldingRule MergeDivNegateArithmetic() {
  822. return [](IRContext* context, Instruction* inst,
  823. const std::vector<const analysis::Constant*>& constants) {
  824. assert(inst->opcode() == SpvOpFDiv || inst->opcode() == SpvOpSDiv ||
  825. inst->opcode() == SpvOpUDiv);
  826. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  827. const analysis::Type* type =
  828. context->get_type_mgr()->GetType(inst->type_id());
  829. bool uses_float = HasFloatingPoint(type);
  830. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  831. uint32_t width = ElementWidth(type);
  832. if (width != 32 && width != 64) return false;
  833. const analysis::Constant* const_input1 = ConstInput(constants);
  834. if (!const_input1) return false;
  835. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  836. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  837. return false;
  838. bool first_is_variable = constants[0] == nullptr;
  839. if (other_inst->opcode() == SpvOpFNegate ||
  840. other_inst->opcode() == SpvOpSNegate) {
  841. uint32_t neg_id = NegateConstant(const_mgr, const_input1);
  842. if (first_is_variable) {
  843. inst->SetInOperands(
  844. {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
  845. {SPV_OPERAND_TYPE_ID, {neg_id}}});
  846. } else {
  847. inst->SetInOperands(
  848. {{SPV_OPERAND_TYPE_ID, {neg_id}},
  849. {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
  850. }
  851. return true;
  852. }
  853. return false;
  854. };
  855. }
  856. // Folds addition of a constant and a negation.
  857. // Cases:
  858. // (-x) + 2 = 2 - x
  859. // 2 + (-x) = 2 - x
  860. FoldingRule MergeAddNegateArithmetic() {
  861. return [](IRContext* context, Instruction* inst,
  862. const std::vector<const analysis::Constant*>& constants) {
  863. assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
  864. const analysis::Type* type =
  865. context->get_type_mgr()->GetType(inst->type_id());
  866. bool uses_float = HasFloatingPoint(type);
  867. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  868. const analysis::Constant* const_input1 = ConstInput(constants);
  869. if (!const_input1) return false;
  870. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  871. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  872. return false;
  873. if (other_inst->opcode() == SpvOpSNegate ||
  874. other_inst->opcode() == SpvOpFNegate) {
  875. inst->SetOpcode(HasFloatingPoint(type) ? SpvOpFSub : SpvOpISub);
  876. uint32_t const_id = constants[0] ? inst->GetSingleWordInOperand(0u)
  877. : inst->GetSingleWordInOperand(1u);
  878. inst->SetInOperands(
  879. {{SPV_OPERAND_TYPE_ID, {const_id}},
  880. {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
  881. return true;
  882. }
  883. return false;
  884. };
  885. }
  886. // Folds subtraction of a constant and a negation.
  887. // Cases:
  888. // (-x) - 2 = -2 - x
  889. // 2 - (-x) = x + 2
  890. FoldingRule MergeSubNegateArithmetic() {
  891. return [](IRContext* context, Instruction* inst,
  892. const std::vector<const analysis::Constant*>& constants) {
  893. assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
  894. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  895. const analysis::Type* type =
  896. context->get_type_mgr()->GetType(inst->type_id());
  897. bool uses_float = HasFloatingPoint(type);
  898. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  899. uint32_t width = ElementWidth(type);
  900. if (width != 32 && width != 64) return false;
  901. const analysis::Constant* const_input1 = ConstInput(constants);
  902. if (!const_input1) return false;
  903. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  904. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  905. return false;
  906. if (other_inst->opcode() == SpvOpSNegate ||
  907. other_inst->opcode() == SpvOpFNegate) {
  908. uint32_t op1 = 0;
  909. uint32_t op2 = 0;
  910. SpvOp opcode = inst->opcode();
  911. if (constants[0] != nullptr) {
  912. op1 = other_inst->GetSingleWordInOperand(0u);
  913. op2 = inst->GetSingleWordInOperand(0u);
  914. opcode = HasFloatingPoint(type) ? SpvOpFAdd : SpvOpIAdd;
  915. } else {
  916. op1 = NegateConstant(const_mgr, const_input1);
  917. op2 = other_inst->GetSingleWordInOperand(0u);
  918. }
  919. inst->SetOpcode(opcode);
  920. inst->SetInOperands(
  921. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  922. return true;
  923. }
  924. return false;
  925. };
  926. }
  927. // Folds addition of an addition where each operation has a constant operand.
  928. // Cases:
  929. // (x + 2) + 2 = x + 4
  930. // (2 + x) + 2 = x + 4
  931. // 2 + (x + 2) = x + 4
  932. // 2 + (2 + x) = x + 4
  933. FoldingRule MergeAddAddArithmetic() {
  934. return [](IRContext* context, Instruction* inst,
  935. const std::vector<const analysis::Constant*>& constants) {
  936. assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
  937. const analysis::Type* type =
  938. context->get_type_mgr()->GetType(inst->type_id());
  939. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  940. bool uses_float = HasFloatingPoint(type);
  941. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  942. uint32_t width = ElementWidth(type);
  943. if (width != 32 && width != 64) return false;
  944. const analysis::Constant* const_input1 = ConstInput(constants);
  945. if (!const_input1) return false;
  946. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  947. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  948. return false;
  949. if (other_inst->opcode() == SpvOpFAdd ||
  950. other_inst->opcode() == SpvOpIAdd) {
  951. std::vector<const analysis::Constant*> other_constants =
  952. const_mgr->GetOperandConstants(other_inst);
  953. const analysis::Constant* const_input2 = ConstInput(other_constants);
  954. if (!const_input2) return false;
  955. Instruction* non_const_input =
  956. NonConstInput(context, other_constants[0], other_inst);
  957. uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
  958. const_input1, const_input2);
  959. if (merged_id == 0) return false;
  960. inst->SetInOperands(
  961. {{SPV_OPERAND_TYPE_ID, {non_const_input->result_id()}},
  962. {SPV_OPERAND_TYPE_ID, {merged_id}}});
  963. return true;
  964. }
  965. return false;
  966. };
  967. }
  968. // Folds addition of a subtraction where each operation has a constant operand.
  969. // Cases:
  970. // (x - 2) + 2 = x + 0
  971. // (2 - x) + 2 = 4 - x
  972. // 2 + (x - 2) = x + 0
  973. // 2 + (2 - x) = 4 - x
  974. FoldingRule MergeAddSubArithmetic() {
  975. return [](IRContext* context, Instruction* inst,
  976. const std::vector<const analysis::Constant*>& constants) {
  977. assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
  978. const analysis::Type* type =
  979. context->get_type_mgr()->GetType(inst->type_id());
  980. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  981. bool uses_float = HasFloatingPoint(type);
  982. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  983. uint32_t width = ElementWidth(type);
  984. if (width != 32 && width != 64) return false;
  985. const analysis::Constant* const_input1 = ConstInput(constants);
  986. if (!const_input1) return false;
  987. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  988. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  989. return false;
  990. if (other_inst->opcode() == SpvOpFSub ||
  991. other_inst->opcode() == SpvOpISub) {
  992. std::vector<const analysis::Constant*> other_constants =
  993. const_mgr->GetOperandConstants(other_inst);
  994. const analysis::Constant* const_input2 = ConstInput(other_constants);
  995. if (!const_input2) return false;
  996. bool first_is_variable = other_constants[0] == nullptr;
  997. SpvOp op = inst->opcode();
  998. uint32_t op1 = 0;
  999. uint32_t op2 = 0;
  1000. if (first_is_variable) {
  1001. // Subtract constants. Non-constant operand is first.
  1002. op1 = other_inst->GetSingleWordInOperand(0u);
  1003. op2 = PerformOperation(const_mgr, other_inst->opcode(), const_input1,
  1004. const_input2);
  1005. } else {
  1006. // Add constants. Constant operand is first. Change the opcode.
  1007. op1 = PerformOperation(const_mgr, inst->opcode(), const_input1,
  1008. const_input2);
  1009. op2 = other_inst->GetSingleWordInOperand(1u);
  1010. op = other_inst->opcode();
  1011. }
  1012. if (op1 == 0 || op2 == 0) return false;
  1013. inst->SetOpcode(op);
  1014. inst->SetInOperands(
  1015. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  1016. return true;
  1017. }
  1018. return false;
  1019. };
  1020. }
  1021. // Folds subtraction of an addition where each operand has a constant operand.
  1022. // Cases:
  1023. // (x + 2) - 2 = x + 0
  1024. // (2 + x) - 2 = x + 0
  1025. // 2 - (x + 2) = 0 - x
  1026. // 2 - (2 + x) = 0 - x
  1027. FoldingRule MergeSubAddArithmetic() {
  1028. return [](IRContext* context, Instruction* inst,
  1029. const std::vector<const analysis::Constant*>& constants) {
  1030. assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
  1031. const analysis::Type* type =
  1032. context->get_type_mgr()->GetType(inst->type_id());
  1033. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  1034. bool uses_float = HasFloatingPoint(type);
  1035. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  1036. uint32_t width = ElementWidth(type);
  1037. if (width != 32 && width != 64) return false;
  1038. const analysis::Constant* const_input1 = ConstInput(constants);
  1039. if (!const_input1) return false;
  1040. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  1041. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  1042. return false;
  1043. if (other_inst->opcode() == SpvOpFAdd ||
  1044. other_inst->opcode() == SpvOpIAdd) {
  1045. std::vector<const analysis::Constant*> other_constants =
  1046. const_mgr->GetOperandConstants(other_inst);
  1047. const analysis::Constant* const_input2 = ConstInput(other_constants);
  1048. if (!const_input2) return false;
  1049. Instruction* non_const_input =
  1050. NonConstInput(context, other_constants[0], other_inst);
  1051. // If the first operand of the sub is not a constant, swap the constants
  1052. // so the subtraction has the correct operands.
  1053. if (constants[0] == nullptr) std::swap(const_input1, const_input2);
  1054. // Subtract the constants.
  1055. uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
  1056. const_input1, const_input2);
  1057. SpvOp op = inst->opcode();
  1058. uint32_t op1 = 0;
  1059. uint32_t op2 = 0;
  1060. if (constants[0] == nullptr) {
  1061. // Non-constant operand is first. Change the opcode.
  1062. op1 = non_const_input->result_id();
  1063. op2 = merged_id;
  1064. op = other_inst->opcode();
  1065. } else {
  1066. // Constant operand is first.
  1067. op1 = merged_id;
  1068. op2 = non_const_input->result_id();
  1069. }
  1070. if (op1 == 0 || op2 == 0) return false;
  1071. inst->SetOpcode(op);
  1072. inst->SetInOperands(
  1073. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  1074. return true;
  1075. }
  1076. return false;
  1077. };
  1078. }
  1079. // Folds subtraction of a subtraction where each operand has a constant operand.
  1080. // Cases:
  1081. // (x - 2) - 2 = x - 4
  1082. // (2 - x) - 2 = 0 - x
  1083. // 2 - (x - 2) = 4 - x
  1084. // 2 - (2 - x) = x + 0
  1085. FoldingRule MergeSubSubArithmetic() {
  1086. return [](IRContext* context, Instruction* inst,
  1087. const std::vector<const analysis::Constant*>& constants) {
  1088. assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
  1089. const analysis::Type* type =
  1090. context->get_type_mgr()->GetType(inst->type_id());
  1091. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  1092. bool uses_float = HasFloatingPoint(type);
  1093. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  1094. uint32_t width = ElementWidth(type);
  1095. if (width != 32 && width != 64) return false;
  1096. const analysis::Constant* const_input1 = ConstInput(constants);
  1097. if (!const_input1) return false;
  1098. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  1099. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  1100. return false;
  1101. if (other_inst->opcode() == SpvOpFSub ||
  1102. other_inst->opcode() == SpvOpISub) {
  1103. std::vector<const analysis::Constant*> other_constants =
  1104. const_mgr->GetOperandConstants(other_inst);
  1105. const analysis::Constant* const_input2 = ConstInput(other_constants);
  1106. if (!const_input2) return false;
  1107. Instruction* non_const_input =
  1108. NonConstInput(context, other_constants[0], other_inst);
  1109. // Merge the constants.
  1110. uint32_t merged_id = 0;
  1111. SpvOp merge_op = inst->opcode();
  1112. if (other_constants[0] == nullptr) {
  1113. merge_op = uses_float ? SpvOpFAdd : SpvOpIAdd;
  1114. } else if (constants[0] == nullptr) {
  1115. std::swap(const_input1, const_input2);
  1116. }
  1117. merged_id =
  1118. PerformOperation(const_mgr, merge_op, const_input1, const_input2);
  1119. if (merged_id == 0) return false;
  1120. SpvOp op = inst->opcode();
  1121. if (constants[0] != nullptr && other_constants[0] != nullptr) {
  1122. // Change the operation.
  1123. op = uses_float ? SpvOpFAdd : SpvOpIAdd;
  1124. }
  1125. uint32_t op1 = 0;
  1126. uint32_t op2 = 0;
  1127. if ((constants[0] == nullptr) ^ (other_constants[0] == nullptr)) {
  1128. op1 = merged_id;
  1129. op2 = non_const_input->result_id();
  1130. } else {
  1131. op1 = non_const_input->result_id();
  1132. op2 = merged_id;
  1133. }
  1134. inst->SetOpcode(op);
  1135. inst->SetInOperands(
  1136. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  1137. return true;
  1138. }
  1139. return false;
  1140. };
  1141. }
  1142. // Helper function for MergeGenericAddSubArithmetic. If |addend| and
  1143. // subtrahend of |sub| is the same, merge to copy of minuend of |sub|.
  1144. bool MergeGenericAddendSub(uint32_t addend, uint32_t sub, Instruction* inst) {
  1145. IRContext* context = inst->context();
  1146. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1147. Instruction* sub_inst = def_use_mgr->GetDef(sub);
  1148. if (sub_inst->opcode() != SpvOpFSub && sub_inst->opcode() != SpvOpISub)
  1149. return false;
  1150. if (sub_inst->opcode() == SpvOpFSub &&
  1151. !sub_inst->IsFloatingPointFoldingAllowed())
  1152. return false;
  1153. if (addend != sub_inst->GetSingleWordInOperand(1)) return false;
  1154. inst->SetOpcode(SpvOpCopyObject);
  1155. inst->SetInOperands(
  1156. {{SPV_OPERAND_TYPE_ID, {sub_inst->GetSingleWordInOperand(0)}}});
  1157. context->UpdateDefUse(inst);
  1158. return true;
  1159. }
  1160. // Folds addition of a subtraction where the subtrahend is equal to the
  1161. // other addend. Return a copy of the minuend. Accepts generic (const and
  1162. // non-const) operands.
  1163. // Cases:
  1164. // (a - b) + b = a
  1165. // b + (a - b) = a
  1166. FoldingRule MergeGenericAddSubArithmetic() {
  1167. return [](IRContext* context, Instruction* inst,
  1168. const std::vector<const analysis::Constant*>&) {
  1169. assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
  1170. const analysis::Type* type =
  1171. context->get_type_mgr()->GetType(inst->type_id());
  1172. bool uses_float = HasFloatingPoint(type);
  1173. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  1174. uint32_t width = ElementWidth(type);
  1175. if (width != 32 && width != 64) return false;
  1176. uint32_t add_op0 = inst->GetSingleWordInOperand(0);
  1177. uint32_t add_op1 = inst->GetSingleWordInOperand(1);
  1178. if (MergeGenericAddendSub(add_op0, add_op1, inst)) return true;
  1179. return MergeGenericAddendSub(add_op1, add_op0, inst);
  1180. };
  1181. }
  1182. // Helper function for FactorAddMuls. If |factor0_0| is the same as |factor1_0|,
  1183. // generate |factor0_0| * (|factor0_1| + |factor1_1|).
  1184. bool FactorAddMulsOpnds(uint32_t factor0_0, uint32_t factor0_1,
  1185. uint32_t factor1_0, uint32_t factor1_1,
  1186. Instruction* inst) {
  1187. IRContext* context = inst->context();
  1188. if (factor0_0 != factor1_0) return false;
  1189. InstructionBuilder ir_builder(
  1190. context, inst,
  1191. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  1192. Instruction* new_add_inst = ir_builder.AddBinaryOp(
  1193. inst->type_id(), inst->opcode(), factor0_1, factor1_1);
  1194. inst->SetOpcode(inst->opcode() == SpvOpFAdd ? SpvOpFMul : SpvOpIMul);
  1195. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {factor0_0}},
  1196. {SPV_OPERAND_TYPE_ID, {new_add_inst->result_id()}}});
  1197. context->UpdateDefUse(inst);
  1198. return true;
  1199. }
  1200. // Perform the following factoring identity, handling all operand order
  1201. // combinations: (a * b) + (a * c) = a * (b + c)
  1202. FoldingRule FactorAddMuls() {
  1203. return [](IRContext* context, Instruction* inst,
  1204. const std::vector<const analysis::Constant*>&) {
  1205. assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
  1206. const analysis::Type* type =
  1207. context->get_type_mgr()->GetType(inst->type_id());
  1208. bool uses_float = HasFloatingPoint(type);
  1209. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  1210. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1211. uint32_t add_op0 = inst->GetSingleWordInOperand(0);
  1212. Instruction* add_op0_inst = def_use_mgr->GetDef(add_op0);
  1213. if (add_op0_inst->opcode() != SpvOpFMul &&
  1214. add_op0_inst->opcode() != SpvOpIMul)
  1215. return false;
  1216. uint32_t add_op1 = inst->GetSingleWordInOperand(1);
  1217. Instruction* add_op1_inst = def_use_mgr->GetDef(add_op1);
  1218. if (add_op1_inst->opcode() != SpvOpFMul &&
  1219. add_op1_inst->opcode() != SpvOpIMul)
  1220. return false;
  1221. // Only perform this optimization if both of the muls only have one use.
  1222. // Otherwise this is a deoptimization in size and performance.
  1223. if (def_use_mgr->NumUses(add_op0_inst) > 1) return false;
  1224. if (def_use_mgr->NumUses(add_op1_inst) > 1) return false;
  1225. if (add_op0_inst->opcode() == SpvOpFMul &&
  1226. (!add_op0_inst->IsFloatingPointFoldingAllowed() ||
  1227. !add_op1_inst->IsFloatingPointFoldingAllowed()))
  1228. return false;
  1229. for (int i = 0; i < 2; i++) {
  1230. for (int j = 0; j < 2; j++) {
  1231. // Check if operand i in add_op0_inst matches operand j in add_op1_inst.
  1232. if (FactorAddMulsOpnds(add_op0_inst->GetSingleWordInOperand(i),
  1233. add_op0_inst->GetSingleWordInOperand(1 - i),
  1234. add_op1_inst->GetSingleWordInOperand(j),
  1235. add_op1_inst->GetSingleWordInOperand(1 - j),
  1236. inst))
  1237. return true;
  1238. }
  1239. }
  1240. return false;
  1241. };
  1242. }
  1243. FoldingRule IntMultipleBy1() {
  1244. return [](IRContext*, Instruction* inst,
  1245. const std::vector<const analysis::Constant*>& constants) {
  1246. assert(inst->opcode() == SpvOpIMul && "Wrong opcode. Should be OpIMul.");
  1247. for (uint32_t i = 0; i < 2; i++) {
  1248. if (constants[i] == nullptr) {
  1249. continue;
  1250. }
  1251. const analysis::IntConstant* int_constant = constants[i]->AsIntConstant();
  1252. if (int_constant) {
  1253. uint32_t width = ElementWidth(int_constant->type());
  1254. if (width != 32 && width != 64) return false;
  1255. bool is_one = (width == 32) ? int_constant->GetU32BitValue() == 1u
  1256. : int_constant->GetU64BitValue() == 1ull;
  1257. if (is_one) {
  1258. inst->SetOpcode(SpvOpCopyObject);
  1259. inst->SetInOperands(
  1260. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1 - i)}}});
  1261. return true;
  1262. }
  1263. }
  1264. }
  1265. return false;
  1266. };
  1267. }
  1268. FoldingRule CompositeConstructFeedingExtract() {
  1269. return [](IRContext* context, Instruction* inst,
  1270. const std::vector<const analysis::Constant*>&) {
  1271. // If the input to an OpCompositeExtract is an OpCompositeConstruct,
  1272. // then we can simply use the appropriate element in the construction.
  1273. assert(inst->opcode() == SpvOpCompositeExtract &&
  1274. "Wrong opcode. Should be OpCompositeExtract.");
  1275. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1276. analysis::TypeManager* type_mgr = context->get_type_mgr();
  1277. // If there are no index operands, then this rule cannot do anything.
  1278. if (inst->NumInOperands() <= 1) {
  1279. return false;
  1280. }
  1281. uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1282. Instruction* cinst = def_use_mgr->GetDef(cid);
  1283. if (cinst->opcode() != SpvOpCompositeConstruct) {
  1284. return false;
  1285. }
  1286. std::vector<Operand> operands;
  1287. analysis::Type* composite_type = type_mgr->GetType(cinst->type_id());
  1288. if (composite_type->AsVector() == nullptr) {
  1289. // Get the element being extracted from the OpCompositeConstruct
  1290. // Since it is not a vector, it is simple to extract the single element.
  1291. uint32_t element_index = inst->GetSingleWordInOperand(1);
  1292. uint32_t element_id = cinst->GetSingleWordInOperand(element_index);
  1293. operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
  1294. // Add the remaining indices for extraction.
  1295. for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
  1296. operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
  1297. {inst->GetSingleWordInOperand(i)}});
  1298. }
  1299. } else {
  1300. // With vectors we have to handle the case where it is concatenating
  1301. // vectors.
  1302. assert(inst->NumInOperands() == 2 &&
  1303. "Expecting a vector of scalar values.");
  1304. uint32_t element_index = inst->GetSingleWordInOperand(1);
  1305. for (uint32_t construct_index = 0;
  1306. construct_index < cinst->NumInOperands(); ++construct_index) {
  1307. uint32_t element_id = cinst->GetSingleWordInOperand(construct_index);
  1308. Instruction* element_def = def_use_mgr->GetDef(element_id);
  1309. analysis::Vector* element_type =
  1310. type_mgr->GetType(element_def->type_id())->AsVector();
  1311. if (element_type) {
  1312. uint32_t vector_size = element_type->element_count();
  1313. if (vector_size < element_index) {
  1314. // The element we want comes after this vector.
  1315. element_index -= vector_size;
  1316. } else {
  1317. // We want an element of this vector.
  1318. operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
  1319. operands.push_back(
  1320. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {element_index}});
  1321. break;
  1322. }
  1323. } else {
  1324. if (element_index == 0) {
  1325. // This is a scalar, and we this is the element we are extracting.
  1326. operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
  1327. break;
  1328. } else {
  1329. // Skip over this scalar value.
  1330. --element_index;
  1331. }
  1332. }
  1333. }
  1334. }
  1335. // If there were no extra indices, then we have the final object. No need
  1336. // to extract even more.
  1337. if (operands.size() == 1) {
  1338. inst->SetOpcode(SpvOpCopyObject);
  1339. }
  1340. inst->SetInOperands(std::move(operands));
  1341. return true;
  1342. };
  1343. }
  1344. FoldingRule CompositeExtractFeedingConstruct() {
  1345. // If the OpCompositeConstruct is simply putting back together elements that
  1346. // where extracted from the same souce, we can simlpy reuse the source.
  1347. //
  1348. // This is a common code pattern because of the way that scalar replacement
  1349. // works.
  1350. return [](IRContext* context, Instruction* inst,
  1351. const std::vector<const analysis::Constant*>&) {
  1352. assert(inst->opcode() == SpvOpCompositeConstruct &&
  1353. "Wrong opcode. Should be OpCompositeConstruct.");
  1354. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1355. uint32_t original_id = 0;
  1356. // Check each element to make sure they are:
  1357. // - extractions
  1358. // - extracting the same position they are inserting
  1359. // - all extract from the same id.
  1360. for (uint32_t i = 0; i < inst->NumInOperands(); ++i) {
  1361. uint32_t element_id = inst->GetSingleWordInOperand(i);
  1362. Instruction* element_inst = def_use_mgr->GetDef(element_id);
  1363. if (element_inst->opcode() != SpvOpCompositeExtract) {
  1364. return false;
  1365. }
  1366. if (element_inst->NumInOperands() != 2) {
  1367. return false;
  1368. }
  1369. if (element_inst->GetSingleWordInOperand(1) != i) {
  1370. return false;
  1371. }
  1372. if (i == 0) {
  1373. original_id =
  1374. element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1375. } else if (original_id != element_inst->GetSingleWordInOperand(
  1376. kExtractCompositeIdInIdx)) {
  1377. return false;
  1378. }
  1379. }
  1380. // The last check it to see that the object being extracted from is the
  1381. // correct type.
  1382. Instruction* original_inst = def_use_mgr->GetDef(original_id);
  1383. if (original_inst->type_id() != inst->type_id()) {
  1384. return false;
  1385. }
  1386. // Simplify by using the original object.
  1387. inst->SetOpcode(SpvOpCopyObject);
  1388. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}});
  1389. return true;
  1390. };
  1391. }
  1392. FoldingRule InsertFeedingExtract() {
  1393. return [](IRContext* context, Instruction* inst,
  1394. const std::vector<const analysis::Constant*>&) {
  1395. assert(inst->opcode() == SpvOpCompositeExtract &&
  1396. "Wrong opcode. Should be OpCompositeExtract.");
  1397. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1398. uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1399. Instruction* cinst = def_use_mgr->GetDef(cid);
  1400. if (cinst->opcode() != SpvOpCompositeInsert) {
  1401. return false;
  1402. }
  1403. // Find the first position where the list of insert and extract indicies
  1404. // differ, if at all.
  1405. uint32_t i;
  1406. for (i = 1; i < inst->NumInOperands(); ++i) {
  1407. if (i + 1 >= cinst->NumInOperands()) {
  1408. break;
  1409. }
  1410. if (inst->GetSingleWordInOperand(i) !=
  1411. cinst->GetSingleWordInOperand(i + 1)) {
  1412. break;
  1413. }
  1414. }
  1415. // We are extracting the element that was inserted.
  1416. if (i == inst->NumInOperands() && i + 1 == cinst->NumInOperands()) {
  1417. inst->SetOpcode(SpvOpCopyObject);
  1418. inst->SetInOperands(
  1419. {{SPV_OPERAND_TYPE_ID,
  1420. {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}}});
  1421. return true;
  1422. }
  1423. // Extracting the value that was inserted along with values for the base
  1424. // composite. Cannot do anything.
  1425. if (i == inst->NumInOperands()) {
  1426. return false;
  1427. }
  1428. // Extracting an element of the value that was inserted. Extract from
  1429. // that value directly.
  1430. if (i + 1 == cinst->NumInOperands()) {
  1431. std::vector<Operand> operands;
  1432. operands.push_back(
  1433. {SPV_OPERAND_TYPE_ID,
  1434. {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}});
  1435. for (; i < inst->NumInOperands(); ++i) {
  1436. operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
  1437. {inst->GetSingleWordInOperand(i)}});
  1438. }
  1439. inst->SetInOperands(std::move(operands));
  1440. return true;
  1441. }
  1442. // Extracting a value that is disjoint from the element being inserted.
  1443. // Rewrite the extract to use the composite input to the insert.
  1444. std::vector<Operand> operands;
  1445. operands.push_back(
  1446. {SPV_OPERAND_TYPE_ID,
  1447. {cinst->GetSingleWordInOperand(kInsertCompositeIdInIdx)}});
  1448. for (i = 1; i < inst->NumInOperands(); ++i) {
  1449. operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
  1450. {inst->GetSingleWordInOperand(i)}});
  1451. }
  1452. inst->SetInOperands(std::move(operands));
  1453. return true;
  1454. };
  1455. }
  1456. // When a VectorShuffle is feeding an Extract, we can extract from one of the
  1457. // operands of the VectorShuffle. We just need to adjust the index in the
  1458. // extract instruction.
  1459. FoldingRule VectorShuffleFeedingExtract() {
  1460. return [](IRContext* context, Instruction* inst,
  1461. const std::vector<const analysis::Constant*>&) {
  1462. assert(inst->opcode() == SpvOpCompositeExtract &&
  1463. "Wrong opcode. Should be OpCompositeExtract.");
  1464. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1465. analysis::TypeManager* type_mgr = context->get_type_mgr();
  1466. uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1467. Instruction* cinst = def_use_mgr->GetDef(cid);
  1468. if (cinst->opcode() != SpvOpVectorShuffle) {
  1469. return false;
  1470. }
  1471. // Find the size of the first vector operand of the VectorShuffle
  1472. Instruction* first_input =
  1473. def_use_mgr->GetDef(cinst->GetSingleWordInOperand(0));
  1474. analysis::Type* first_input_type =
  1475. type_mgr->GetType(first_input->type_id());
  1476. assert(first_input_type->AsVector() &&
  1477. "Input to vector shuffle should be vectors.");
  1478. uint32_t first_input_size = first_input_type->AsVector()->element_count();
  1479. // Get index of the element the vector shuffle is placing in the position
  1480. // being extracted.
  1481. uint32_t new_index =
  1482. cinst->GetSingleWordInOperand(2 + inst->GetSingleWordInOperand(1));
  1483. // Extracting an undefined value so fold this extract into an undef.
  1484. const uint32_t undef_literal_value = 0xffffffff;
  1485. if (new_index == undef_literal_value) {
  1486. inst->SetOpcode(SpvOpUndef);
  1487. inst->SetInOperands({});
  1488. return true;
  1489. }
  1490. // Get the id of the of the vector the elemtent comes from, and update the
  1491. // index if needed.
  1492. uint32_t new_vector = 0;
  1493. if (new_index < first_input_size) {
  1494. new_vector = cinst->GetSingleWordInOperand(0);
  1495. } else {
  1496. new_vector = cinst->GetSingleWordInOperand(1);
  1497. new_index -= first_input_size;
  1498. }
  1499. // Update the extract instruction.
  1500. inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
  1501. inst->SetInOperand(1, {new_index});
  1502. return true;
  1503. };
  1504. }
  1505. // When an FMix with is feeding an Extract that extracts an element whose
  1506. // corresponding |a| in the FMix is 0 or 1, we can extract from one of the
  1507. // operands of the FMix.
  1508. FoldingRule FMixFeedingExtract() {
  1509. return [](IRContext* context, Instruction* inst,
  1510. const std::vector<const analysis::Constant*>&) {
  1511. assert(inst->opcode() == SpvOpCompositeExtract &&
  1512. "Wrong opcode. Should be OpCompositeExtract.");
  1513. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1514. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  1515. uint32_t composite_id =
  1516. inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1517. Instruction* composite_inst = def_use_mgr->GetDef(composite_id);
  1518. if (composite_inst->opcode() != SpvOpExtInst) {
  1519. return false;
  1520. }
  1521. uint32_t inst_set_id =
  1522. context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  1523. if (composite_inst->GetSingleWordInOperand(kExtInstSetIdInIdx) !=
  1524. inst_set_id ||
  1525. composite_inst->GetSingleWordInOperand(kExtInstInstructionInIdx) !=
  1526. GLSLstd450FMix) {
  1527. return false;
  1528. }
  1529. // Get the |a| for the FMix instruction.
  1530. uint32_t a_id = composite_inst->GetSingleWordInOperand(kFMixAIdInIdx);
  1531. std::unique_ptr<Instruction> a(inst->Clone(context));
  1532. a->SetInOperand(kExtractCompositeIdInIdx, {a_id});
  1533. context->get_instruction_folder().FoldInstruction(a.get());
  1534. if (a->opcode() != SpvOpCopyObject) {
  1535. return false;
  1536. }
  1537. const analysis::Constant* a_const =
  1538. const_mgr->FindDeclaredConstant(a->GetSingleWordInOperand(0));
  1539. if (!a_const) {
  1540. return false;
  1541. }
  1542. bool use_x = false;
  1543. assert(a_const->type()->AsFloat());
  1544. double element_value = a_const->GetValueAsDouble();
  1545. if (element_value == 0.0) {
  1546. use_x = true;
  1547. } else if (element_value == 1.0) {
  1548. use_x = false;
  1549. } else {
  1550. return false;
  1551. }
  1552. // Get the id of the of the vector the element comes from.
  1553. uint32_t new_vector = 0;
  1554. if (use_x) {
  1555. new_vector = composite_inst->GetSingleWordInOperand(kFMixXIdInIdx);
  1556. } else {
  1557. new_vector = composite_inst->GetSingleWordInOperand(kFMixYIdInIdx);
  1558. }
  1559. // Update the extract instruction.
  1560. inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
  1561. return true;
  1562. };
  1563. }
  1564. FoldingRule RedundantPhi() {
  1565. // An OpPhi instruction where all values are the same or the result of the phi
  1566. // itself, can be replaced by the value itself.
  1567. return [](IRContext*, Instruction* inst,
  1568. const std::vector<const analysis::Constant*>&) {
  1569. assert(inst->opcode() == SpvOpPhi && "Wrong opcode. Should be OpPhi.");
  1570. uint32_t incoming_value = 0;
  1571. for (uint32_t i = 0; i < inst->NumInOperands(); i += 2) {
  1572. uint32_t op_id = inst->GetSingleWordInOperand(i);
  1573. if (op_id == inst->result_id()) {
  1574. continue;
  1575. }
  1576. if (incoming_value == 0) {
  1577. incoming_value = op_id;
  1578. } else if (op_id != incoming_value) {
  1579. // Found two possible value. Can't simplify.
  1580. return false;
  1581. }
  1582. }
  1583. if (incoming_value == 0) {
  1584. // Code looks invalid. Don't do anything.
  1585. return false;
  1586. }
  1587. // We have a single incoming value. Simplify using that value.
  1588. inst->SetOpcode(SpvOpCopyObject);
  1589. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {incoming_value}}});
  1590. return true;
  1591. };
  1592. }
  1593. FoldingRule RedundantSelect() {
  1594. // An OpSelect instruction where both values are the same or the condition is
  1595. // constant can be replaced by one of the values
  1596. return [](IRContext*, Instruction* inst,
  1597. const std::vector<const analysis::Constant*>& constants) {
  1598. assert(inst->opcode() == SpvOpSelect &&
  1599. "Wrong opcode. Should be OpSelect.");
  1600. assert(inst->NumInOperands() == 3);
  1601. assert(constants.size() == 3);
  1602. uint32_t true_id = inst->GetSingleWordInOperand(1);
  1603. uint32_t false_id = inst->GetSingleWordInOperand(2);
  1604. if (true_id == false_id) {
  1605. // Both results are the same, condition doesn't matter
  1606. inst->SetOpcode(SpvOpCopyObject);
  1607. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
  1608. return true;
  1609. } else if (constants[0]) {
  1610. const analysis::Type* type = constants[0]->type();
  1611. if (type->AsBool()) {
  1612. // Scalar constant value, select the corresponding value.
  1613. inst->SetOpcode(SpvOpCopyObject);
  1614. if (constants[0]->AsNullConstant() ||
  1615. !constants[0]->AsBoolConstant()->value()) {
  1616. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
  1617. } else {
  1618. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
  1619. }
  1620. return true;
  1621. } else {
  1622. assert(type->AsVector());
  1623. if (constants[0]->AsNullConstant()) {
  1624. // All values come from false id.
  1625. inst->SetOpcode(SpvOpCopyObject);
  1626. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
  1627. return true;
  1628. } else {
  1629. // Convert to a vector shuffle.
  1630. std::vector<Operand> ops;
  1631. ops.push_back({SPV_OPERAND_TYPE_ID, {true_id}});
  1632. ops.push_back({SPV_OPERAND_TYPE_ID, {false_id}});
  1633. const analysis::VectorConstant* vector_const =
  1634. constants[0]->AsVectorConstant();
  1635. uint32_t size =
  1636. static_cast<uint32_t>(vector_const->GetComponents().size());
  1637. for (uint32_t i = 0; i != size; ++i) {
  1638. const analysis::Constant* component =
  1639. vector_const->GetComponents()[i];
  1640. if (component->AsNullConstant() ||
  1641. !component->AsBoolConstant()->value()) {
  1642. // Selecting from the false vector which is the second input
  1643. // vector to the shuffle. Offset the index by |size|.
  1644. ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i + size}});
  1645. } else {
  1646. // Selecting from true vector which is the first input vector to
  1647. // the shuffle.
  1648. ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}});
  1649. }
  1650. }
  1651. inst->SetOpcode(SpvOpVectorShuffle);
  1652. inst->SetInOperands(std::move(ops));
  1653. return true;
  1654. }
  1655. }
  1656. }
  1657. return false;
  1658. };
  1659. }
  // Classification of a floating-point constant as produced by
  // |getFloatConstantKind|.
  enum class FloatConstantKind {
    Unknown,  // Not a constant, or neither all zeros nor all ones.
    Zero,     // The constant is 0, or every component is 0.
    One       // The constant is 1, or every component is 1.
  };
  1661. FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) {
  1662. if (constant == nullptr) {
  1663. return FloatConstantKind::Unknown;
  1664. }
  1665. assert(HasFloatingPoint(constant->type()) && "Unexpected constant type");
  1666. if (constant->AsNullConstant()) {
  1667. return FloatConstantKind::Zero;
  1668. } else if (const analysis::VectorConstant* vc =
  1669. constant->AsVectorConstant()) {
  1670. const std::vector<const analysis::Constant*>& components =
  1671. vc->GetComponents();
  1672. assert(!components.empty());
  1673. FloatConstantKind kind = getFloatConstantKind(components[0]);
  1674. for (size_t i = 1; i < components.size(); ++i) {
  1675. if (getFloatConstantKind(components[i]) != kind) {
  1676. return FloatConstantKind::Unknown;
  1677. }
  1678. }
  1679. return kind;
  1680. } else if (const analysis::FloatConstant* fc = constant->AsFloatConstant()) {
  1681. if (fc->IsZero()) return FloatConstantKind::Zero;
  1682. uint32_t width = fc->type()->AsFloat()->width();
  1683. if (width != 32 && width != 64) return FloatConstantKind::Unknown;
  1684. double value = (width == 64) ? fc->GetDoubleValue() : fc->GetFloatValue();
  1685. if (value == 0.0) {
  1686. return FloatConstantKind::Zero;
  1687. } else if (value == 1.0) {
  1688. return FloatConstantKind::One;
  1689. } else {
  1690. return FloatConstantKind::Unknown;
  1691. }
  1692. } else {
  1693. return FloatConstantKind::Unknown;
  1694. }
  1695. }
  1696. FoldingRule RedundantFAdd() {
  1697. return [](IRContext*, Instruction* inst,
  1698. const std::vector<const analysis::Constant*>& constants) {
  1699. assert(inst->opcode() == SpvOpFAdd && "Wrong opcode. Should be OpFAdd.");
  1700. assert(constants.size() == 2);
  1701. if (!inst->IsFloatingPointFoldingAllowed()) {
  1702. return false;
  1703. }
  1704. FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
  1705. FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
  1706. if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
  1707. inst->SetOpcode(SpvOpCopyObject);
  1708. inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
  1709. {inst->GetSingleWordInOperand(
  1710. kind0 == FloatConstantKind::Zero ? 1 : 0)}}});
  1711. return true;
  1712. }
  1713. return false;
  1714. };
  1715. }
  1716. FoldingRule RedundantFSub() {
  1717. return [](IRContext*, Instruction* inst,
  1718. const std::vector<const analysis::Constant*>& constants) {
  1719. assert(inst->opcode() == SpvOpFSub && "Wrong opcode. Should be OpFSub.");
  1720. assert(constants.size() == 2);
  1721. if (!inst->IsFloatingPointFoldingAllowed()) {
  1722. return false;
  1723. }
  1724. FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
  1725. FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
  1726. if (kind0 == FloatConstantKind::Zero) {
  1727. inst->SetOpcode(SpvOpFNegate);
  1728. inst->SetInOperands(
  1729. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1)}}});
  1730. return true;
  1731. }
  1732. if (kind1 == FloatConstantKind::Zero) {
  1733. inst->SetOpcode(SpvOpCopyObject);
  1734. inst->SetInOperands(
  1735. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
  1736. return true;
  1737. }
  1738. return false;
  1739. };
  1740. }
  1741. FoldingRule RedundantFMul() {
  1742. return [](IRContext*, Instruction* inst,
  1743. const std::vector<const analysis::Constant*>& constants) {
  1744. assert(inst->opcode() == SpvOpFMul && "Wrong opcode. Should be OpFMul.");
  1745. assert(constants.size() == 2);
  1746. if (!inst->IsFloatingPointFoldingAllowed()) {
  1747. return false;
  1748. }
  1749. FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
  1750. FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
  1751. if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
  1752. inst->SetOpcode(SpvOpCopyObject);
  1753. inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
  1754. {inst->GetSingleWordInOperand(
  1755. kind0 == FloatConstantKind::Zero ? 0 : 1)}}});
  1756. return true;
  1757. }
  1758. if (kind0 == FloatConstantKind::One || kind1 == FloatConstantKind::One) {
  1759. inst->SetOpcode(SpvOpCopyObject);
  1760. inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
  1761. {inst->GetSingleWordInOperand(
  1762. kind0 == FloatConstantKind::One ? 1 : 0)}}});
  1763. return true;
  1764. }
  1765. return false;
  1766. };
  1767. }
  1768. FoldingRule RedundantFDiv() {
  1769. return [](IRContext*, Instruction* inst,
  1770. const std::vector<const analysis::Constant*>& constants) {
  1771. assert(inst->opcode() == SpvOpFDiv && "Wrong opcode. Should be OpFDiv.");
  1772. assert(constants.size() == 2);
  1773. if (!inst->IsFloatingPointFoldingAllowed()) {
  1774. return false;
  1775. }
  1776. FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
  1777. FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
  1778. if (kind0 == FloatConstantKind::Zero) {
  1779. inst->SetOpcode(SpvOpCopyObject);
  1780. inst->SetInOperands(
  1781. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
  1782. return true;
  1783. }
  1784. if (kind1 == FloatConstantKind::One) {
  1785. inst->SetOpcode(SpvOpCopyObject);
  1786. inst->SetInOperands(
  1787. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
  1788. return true;
  1789. }
  1790. return false;
  1791. };
  1792. }
  1793. FoldingRule RedundantFMix() {
  1794. return [](IRContext* context, Instruction* inst,
  1795. const std::vector<const analysis::Constant*>& constants) {
  1796. assert(inst->opcode() == SpvOpExtInst &&
  1797. "Wrong opcode. Should be OpExtInst.");
  1798. if (!inst->IsFloatingPointFoldingAllowed()) {
  1799. return false;
  1800. }
  1801. uint32_t instSetId =
  1802. context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  1803. if (inst->GetSingleWordInOperand(kExtInstSetIdInIdx) == instSetId &&
  1804. inst->GetSingleWordInOperand(kExtInstInstructionInIdx) ==
  1805. GLSLstd450FMix) {
  1806. assert(constants.size() == 5);
  1807. FloatConstantKind kind4 = getFloatConstantKind(constants[4]);
  1808. if (kind4 == FloatConstantKind::Zero || kind4 == FloatConstantKind::One) {
  1809. inst->SetOpcode(SpvOpCopyObject);
  1810. inst->SetInOperands(
  1811. {{SPV_OPERAND_TYPE_ID,
  1812. {inst->GetSingleWordInOperand(kind4 == FloatConstantKind::Zero
  1813. ? kFMixXIdInIdx
  1814. : kFMixYIdInIdx)}}});
  1815. return true;
  1816. }
  1817. }
  1818. return false;
  1819. };
  1820. }
  1821. // This rule handles addition of zero for integers.
  1822. FoldingRule RedundantIAdd() {
  1823. return [](IRContext* context, Instruction* inst,
  1824. const std::vector<const analysis::Constant*>& constants) {
  1825. assert(inst->opcode() == SpvOpIAdd && "Wrong opcode. Should be OpIAdd.");
  1826. uint32_t operand = std::numeric_limits<uint32_t>::max();
  1827. const analysis::Type* operand_type = nullptr;
  1828. if (constants[0] && constants[0]->IsZero()) {
  1829. operand = inst->GetSingleWordInOperand(1);
  1830. operand_type = constants[0]->type();
  1831. } else if (constants[1] && constants[1]->IsZero()) {
  1832. operand = inst->GetSingleWordInOperand(0);
  1833. operand_type = constants[1]->type();
  1834. }
  1835. if (operand != std::numeric_limits<uint32_t>::max()) {
  1836. const analysis::Type* inst_type =
  1837. context->get_type_mgr()->GetType(inst->type_id());
  1838. if (inst_type->IsSame(operand_type)) {
  1839. inst->SetOpcode(SpvOpCopyObject);
  1840. } else {
  1841. inst->SetOpcode(SpvOpBitcast);
  1842. }
  1843. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {operand}}});
  1844. return true;
  1845. }
  1846. return false;
  1847. };
  1848. }
  1849. // This rule look for a dot with a constant vector containing a single 1 and
  1850. // the rest 0s. This is the same as doing an extract.
  1851. FoldingRule DotProductDoingExtract() {
  1852. return [](IRContext* context, Instruction* inst,
  1853. const std::vector<const analysis::Constant*>& constants) {
  1854. assert(inst->opcode() == SpvOpDot && "Wrong opcode. Should be OpDot.");
  1855. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  1856. if (!inst->IsFloatingPointFoldingAllowed()) {
  1857. return false;
  1858. }
  1859. for (int i = 0; i < 2; ++i) {
  1860. if (!constants[i]) {
  1861. continue;
  1862. }
  1863. const analysis::Vector* vector_type = constants[i]->type()->AsVector();
  1864. assert(vector_type && "Inputs to OpDot must be vectors.");
  1865. const analysis::Float* element_type =
  1866. vector_type->element_type()->AsFloat();
  1867. assert(element_type && "Inputs to OpDot must be vectors of floats.");
  1868. uint32_t element_width = element_type->width();
  1869. if (element_width != 32 && element_width != 64) {
  1870. return false;
  1871. }
  1872. std::vector<const analysis::Constant*> components;
  1873. components = constants[i]->GetVectorComponents(const_mgr);
  1874. const uint32_t kNotFound = std::numeric_limits<uint32_t>::max();
  1875. uint32_t component_with_one = kNotFound;
  1876. bool all_others_zero = true;
  1877. for (uint32_t j = 0; j < components.size(); ++j) {
  1878. const analysis::Constant* element = components[j];
  1879. double value =
  1880. (element_width == 32 ? element->GetFloat() : element->GetDouble());
  1881. if (value == 0.0) {
  1882. continue;
  1883. } else if (value == 1.0) {
  1884. if (component_with_one == kNotFound) {
  1885. component_with_one = j;
  1886. } else {
  1887. component_with_one = kNotFound;
  1888. break;
  1889. }
  1890. } else {
  1891. all_others_zero = false;
  1892. break;
  1893. }
  1894. }
  1895. if (!all_others_zero || component_with_one == kNotFound) {
  1896. continue;
  1897. }
  1898. std::vector<Operand> operands;
  1899. operands.push_back(
  1900. {SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1u - i)}});
  1901. operands.push_back(
  1902. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_with_one}});
  1903. inst->SetOpcode(SpvOpCompositeExtract);
  1904. inst->SetInOperands(std::move(operands));
  1905. return true;
  1906. }
  1907. return false;
  1908. };
  1909. }
  1910. // If we are storing an undef, then we can remove the store.
  1911. //
  1912. // TODO: We can do something similar for OpImageWrite, but checking for volatile
  1913. // is complicated. Waiting to see if it is needed.
  1914. FoldingRule StoringUndef() {
  1915. return [](IRContext* context, Instruction* inst,
  1916. const std::vector<const analysis::Constant*>&) {
  1917. assert(inst->opcode() == SpvOpStore && "Wrong opcode. Should be OpStore.");
  1918. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1919. // If this is a volatile store, the store cannot be removed.
  1920. if (inst->NumInOperands() == 3) {
  1921. if (inst->GetSingleWordInOperand(2) & SpvMemoryAccessVolatileMask) {
  1922. return false;
  1923. }
  1924. }
  1925. uint32_t object_id = inst->GetSingleWordInOperand(kStoreObjectInIdx);
  1926. Instruction* object_inst = def_use_mgr->GetDef(object_id);
  1927. if (object_inst->opcode() == SpvOpUndef) {
  1928. inst->ToNop();
  1929. return true;
  1930. }
  1931. return false;
  1932. };
  1933. }
  1934. FoldingRule VectorShuffleFeedingShuffle() {
  1935. return [](IRContext* context, Instruction* inst,
  1936. const std::vector<const analysis::Constant*>&) {
  1937. assert(inst->opcode() == SpvOpVectorShuffle &&
  1938. "Wrong opcode. Should be OpVectorShuffle.");
  1939. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1940. analysis::TypeManager* type_mgr = context->get_type_mgr();
  1941. Instruction* feeding_shuffle_inst =
  1942. def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
  1943. analysis::Vector* op0_type =
  1944. type_mgr->GetType(feeding_shuffle_inst->type_id())->AsVector();
  1945. uint32_t op0_length = op0_type->element_count();
  1946. bool feeder_is_op0 = true;
  1947. if (feeding_shuffle_inst->opcode() != SpvOpVectorShuffle) {
  1948. feeding_shuffle_inst =
  1949. def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
  1950. feeder_is_op0 = false;
  1951. }
  1952. if (feeding_shuffle_inst->opcode() != SpvOpVectorShuffle) {
  1953. return false;
  1954. }
  1955. Instruction* feeder2 =
  1956. def_use_mgr->GetDef(feeding_shuffle_inst->GetSingleWordInOperand(0));
  1957. analysis::Vector* feeder_op0_type =
  1958. type_mgr->GetType(feeder2->type_id())->AsVector();
  1959. uint32_t feeder_op0_length = feeder_op0_type->element_count();
  1960. uint32_t new_feeder_id = 0;
  1961. std::vector<Operand> new_operands;
  1962. new_operands.resize(
  1963. 2, {SPV_OPERAND_TYPE_ID, {0}}); // Place holders for vector operands.
  1964. const uint32_t undef_literal = 0xffffffff;
  1965. for (uint32_t op = 2; op < inst->NumInOperands(); ++op) {
  1966. uint32_t component_index = inst->GetSingleWordInOperand(op);
  1967. // Do not interpret the undefined value literal as coming from operand 1.
  1968. if (component_index != undef_literal &&
  1969. feeder_is_op0 == (component_index < op0_length)) {
  1970. // This component comes from the feeding_shuffle_inst. Update
  1971. // |component_index| to be the index into the operand of the feeder.
  1972. // Adjust component_index to get the index into the operands of the
  1973. // feeding_shuffle_inst.
  1974. if (component_index >= op0_length) {
  1975. component_index -= op0_length;
  1976. }
  1977. component_index =
  1978. feeding_shuffle_inst->GetSingleWordInOperand(component_index + 2);
  1979. // Check if we are using a component from the first or second operand of
  1980. // the feeding instruction.
  1981. if (component_index < feeder_op0_length) {
  1982. if (new_feeder_id == 0) {
  1983. // First time through, save the id of the operand the element comes
  1984. // from.
  1985. new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(0);
  1986. } else if (new_feeder_id !=
  1987. feeding_shuffle_inst->GetSingleWordInOperand(0)) {
  1988. // We need both elements of the feeding_shuffle_inst, so we cannot
  1989. // fold.
  1990. return false;
  1991. }
  1992. } else {
  1993. if (new_feeder_id == 0) {
  1994. // First time through, save the id of the operand the element comes
  1995. // from.
  1996. new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(1);
  1997. } else if (new_feeder_id !=
  1998. feeding_shuffle_inst->GetSingleWordInOperand(1)) {
  1999. // We need both elements of the feeding_shuffle_inst, so we cannot
  2000. // fold.
  2001. return false;
  2002. }
  2003. component_index -= feeder_op0_length;
  2004. }
  2005. if (!feeder_is_op0) {
  2006. component_index += op0_length;
  2007. }
  2008. }
  2009. new_operands.push_back(
  2010. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_index}});
  2011. }
  2012. if (new_feeder_id == 0) {
  2013. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  2014. const analysis::Type* type =
  2015. type_mgr->GetType(feeding_shuffle_inst->type_id());
  2016. const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
  2017. new_feeder_id =
  2018. const_mgr->GetDefiningInstruction(null_const, 0)->result_id();
  2019. }
  2020. if (feeder_is_op0) {
  2021. // If the size of the first vector operand changed then the indices
  2022. // referring to the second operand need to be adjusted.
  2023. Instruction* new_feeder_inst = def_use_mgr->GetDef(new_feeder_id);
  2024. analysis::Type* new_feeder_type =
  2025. type_mgr->GetType(new_feeder_inst->type_id());
  2026. uint32_t new_op0_size = new_feeder_type->AsVector()->element_count();
  2027. int32_t adjustment = op0_length - new_op0_size;
  2028. if (adjustment != 0) {
  2029. for (uint32_t i = 2; i < new_operands.size(); i++) {
  2030. if (inst->GetSingleWordInOperand(i) >= op0_length) {
  2031. new_operands[i].words[0] -= adjustment;
  2032. }
  2033. }
  2034. }
  2035. new_operands[0].words[0] = new_feeder_id;
  2036. new_operands[1] = inst->GetInOperand(1);
  2037. } else {
  2038. new_operands[1].words[0] = new_feeder_id;
  2039. new_operands[0] = inst->GetInOperand(0);
  2040. }
  2041. inst->SetInOperands(std::move(new_operands));
  2042. return true;
  2043. };
  2044. }
  2045. // Removes duplicate ids from the interface list of an OpEntryPoint
  2046. // instruction.
  2047. FoldingRule RemoveRedundantOperands() {
  2048. return [](IRContext*, Instruction* inst,
  2049. const std::vector<const analysis::Constant*>&) {
  2050. assert(inst->opcode() == SpvOpEntryPoint &&
  2051. "Wrong opcode. Should be OpEntryPoint.");
  2052. bool has_redundant_operand = false;
  2053. std::unordered_set<uint32_t> seen_operands;
  2054. std::vector<Operand> new_operands;
  2055. new_operands.emplace_back(inst->GetOperand(0));
  2056. new_operands.emplace_back(inst->GetOperand(1));
  2057. new_operands.emplace_back(inst->GetOperand(2));
  2058. for (uint32_t i = 3; i < inst->NumOperands(); ++i) {
  2059. if (seen_operands.insert(inst->GetSingleWordOperand(i)).second) {
  2060. new_operands.emplace_back(inst->GetOperand(i));
  2061. } else {
  2062. has_redundant_operand = true;
  2063. }
  2064. }
  2065. if (!has_redundant_operand) {
  2066. return false;
  2067. }
  2068. inst->SetInOperands(std::move(new_operands));
  2069. return true;
  2070. };
  2071. }
  2072. // If an image instruction's operand is a constant, updates the image operand
  2073. // flag from Offset to ConstOffset.
  2074. FoldingRule UpdateImageOperands() {
  2075. return [](IRContext*, Instruction* inst,
  2076. const std::vector<const analysis::Constant*>& constants) {
  2077. const auto opcode = inst->opcode();
  2078. (void)opcode;
  2079. assert((opcode == SpvOpImageSampleImplicitLod ||
  2080. opcode == SpvOpImageSampleExplicitLod ||
  2081. opcode == SpvOpImageSampleDrefImplicitLod ||
  2082. opcode == SpvOpImageSampleDrefExplicitLod ||
  2083. opcode == SpvOpImageSampleProjImplicitLod ||
  2084. opcode == SpvOpImageSampleProjExplicitLod ||
  2085. opcode == SpvOpImageSampleProjDrefImplicitLod ||
  2086. opcode == SpvOpImageSampleProjDrefExplicitLod ||
  2087. opcode == SpvOpImageFetch || opcode == SpvOpImageGather ||
  2088. opcode == SpvOpImageDrefGather || opcode == SpvOpImageRead ||
  2089. opcode == SpvOpImageWrite ||
  2090. opcode == SpvOpImageSparseSampleImplicitLod ||
  2091. opcode == SpvOpImageSparseSampleExplicitLod ||
  2092. opcode == SpvOpImageSparseSampleDrefImplicitLod ||
  2093. opcode == SpvOpImageSparseSampleDrefExplicitLod ||
  2094. opcode == SpvOpImageSparseSampleProjImplicitLod ||
  2095. opcode == SpvOpImageSparseSampleProjExplicitLod ||
  2096. opcode == SpvOpImageSparseSampleProjDrefImplicitLod ||
  2097. opcode == SpvOpImageSparseSampleProjDrefExplicitLod ||
  2098. opcode == SpvOpImageSparseFetch ||
  2099. opcode == SpvOpImageSparseGather ||
  2100. opcode == SpvOpImageSparseDrefGather ||
  2101. opcode == SpvOpImageSparseRead) &&
  2102. "Wrong opcode. Should be an image instruction.");
  2103. int32_t operand_index = ImageOperandsMaskInOperandIndex(inst);
  2104. if (operand_index >= 0) {
  2105. auto image_operands = inst->GetSingleWordInOperand(operand_index);
  2106. if (image_operands & SpvImageOperandsOffsetMask) {
  2107. uint32_t offset_operand_index = operand_index + 1;
  2108. if (image_operands & SpvImageOperandsBiasMask) offset_operand_index++;
  2109. if (image_operands & SpvImageOperandsLodMask) offset_operand_index++;
  2110. if (image_operands & SpvImageOperandsGradMask)
  2111. offset_operand_index += 2;
  2112. assert(((image_operands & SpvImageOperandsConstOffsetMask) == 0) &&
  2113. "Offset and ConstOffset may not be used together");
  2114. if (offset_operand_index < inst->NumOperands()) {
  2115. if (constants[offset_operand_index]) {
  2116. image_operands = image_operands | SpvImageOperandsConstOffsetMask;
  2117. image_operands = image_operands & ~SpvImageOperandsOffsetMask;
  2118. inst->SetInOperand(operand_index, {image_operands});
  2119. return true;
  2120. }
  2121. }
  2122. }
  2123. }
  2124. return false;
  2125. };
  2126. }
  2127. } // namespace
  2128. void FoldingRules::AddFoldingRules() {
  2129. // Add all folding rules to the list for the opcodes to which they apply.
  2130. // Note that the order in which rules are added to the list matters. If a rule
  2131. // applies to the instruction, the rest of the rules will not be attempted.
  2132. // Take that into consideration.
  2133. rules_[SpvOpCompositeConstruct].push_back(CompositeExtractFeedingConstruct());
  2134. rules_[SpvOpCompositeExtract].push_back(InsertFeedingExtract());
  2135. rules_[SpvOpCompositeExtract].push_back(CompositeConstructFeedingExtract());
  2136. rules_[SpvOpCompositeExtract].push_back(VectorShuffleFeedingExtract());
  2137. rules_[SpvOpCompositeExtract].push_back(FMixFeedingExtract());
  2138. rules_[SpvOpDot].push_back(DotProductDoingExtract());
  2139. rules_[SpvOpEntryPoint].push_back(RemoveRedundantOperands());
  2140. rules_[SpvOpFAdd].push_back(RedundantFAdd());
  2141. rules_[SpvOpFAdd].push_back(MergeAddNegateArithmetic());
  2142. rules_[SpvOpFAdd].push_back(MergeAddAddArithmetic());
  2143. rules_[SpvOpFAdd].push_back(MergeAddSubArithmetic());
  2144. rules_[SpvOpFAdd].push_back(MergeGenericAddSubArithmetic());
  2145. rules_[SpvOpFAdd].push_back(FactorAddMuls());
  2146. rules_[SpvOpFDiv].push_back(RedundantFDiv());
  2147. rules_[SpvOpFDiv].push_back(ReciprocalFDiv());
  2148. rules_[SpvOpFDiv].push_back(MergeDivDivArithmetic());
  2149. rules_[SpvOpFDiv].push_back(MergeDivMulArithmetic());
  2150. rules_[SpvOpFDiv].push_back(MergeDivNegateArithmetic());
  2151. rules_[SpvOpFMul].push_back(RedundantFMul());
  2152. rules_[SpvOpFMul].push_back(MergeMulMulArithmetic());
  2153. rules_[SpvOpFMul].push_back(MergeMulDivArithmetic());
  2154. rules_[SpvOpFMul].push_back(MergeMulNegateArithmetic());
  2155. rules_[SpvOpFNegate].push_back(MergeNegateArithmetic());
  2156. rules_[SpvOpFNegate].push_back(MergeNegateAddSubArithmetic());
  2157. rules_[SpvOpFNegate].push_back(MergeNegateMulDivArithmetic());
  2158. rules_[SpvOpFSub].push_back(RedundantFSub());
  2159. rules_[SpvOpFSub].push_back(MergeSubNegateArithmetic());
  2160. rules_[SpvOpFSub].push_back(MergeSubAddArithmetic());
  2161. rules_[SpvOpFSub].push_back(MergeSubSubArithmetic());
  2162. rules_[SpvOpIAdd].push_back(RedundantIAdd());
  2163. rules_[SpvOpIAdd].push_back(MergeAddNegateArithmetic());
  2164. rules_[SpvOpIAdd].push_back(MergeAddAddArithmetic());
  2165. rules_[SpvOpIAdd].push_back(MergeAddSubArithmetic());
  2166. rules_[SpvOpIAdd].push_back(MergeGenericAddSubArithmetic());
  2167. rules_[SpvOpIAdd].push_back(FactorAddMuls());
  2168. rules_[SpvOpIMul].push_back(IntMultipleBy1());
  2169. rules_[SpvOpIMul].push_back(MergeMulMulArithmetic());
  2170. rules_[SpvOpIMul].push_back(MergeMulNegateArithmetic());
  2171. rules_[SpvOpISub].push_back(MergeSubNegateArithmetic());
  2172. rules_[SpvOpISub].push_back(MergeSubAddArithmetic());
  2173. rules_[SpvOpISub].push_back(MergeSubSubArithmetic());
  2174. rules_[SpvOpPhi].push_back(RedundantPhi());
  2175. rules_[SpvOpSDiv].push_back(MergeDivNegateArithmetic());
  2176. rules_[SpvOpSNegate].push_back(MergeNegateArithmetic());
  2177. rules_[SpvOpSNegate].push_back(MergeNegateMulDivArithmetic());
  2178. rules_[SpvOpSNegate].push_back(MergeNegateAddSubArithmetic());
  2179. rules_[SpvOpSelect].push_back(RedundantSelect());
  2180. rules_[SpvOpStore].push_back(StoringUndef());
  2181. rules_[SpvOpUDiv].push_back(MergeDivNegateArithmetic());
  2182. rules_[SpvOpVectorShuffle].push_back(VectorShuffleFeedingShuffle());
  2183. rules_[SpvOpImageSampleImplicitLod].push_back(UpdateImageOperands());
  2184. rules_[SpvOpImageSampleExplicitLod].push_back(UpdateImageOperands());
  2185. rules_[SpvOpImageSampleDrefImplicitLod].push_back(UpdateImageOperands());
  2186. rules_[SpvOpImageSampleDrefExplicitLod].push_back(UpdateImageOperands());
  2187. rules_[SpvOpImageSampleProjImplicitLod].push_back(UpdateImageOperands());
  2188. rules_[SpvOpImageSampleProjExplicitLod].push_back(UpdateImageOperands());
  2189. rules_[SpvOpImageSampleProjDrefImplicitLod].push_back(UpdateImageOperands());
  2190. rules_[SpvOpImageSampleProjDrefExplicitLod].push_back(UpdateImageOperands());
  2191. rules_[SpvOpImageFetch].push_back(UpdateImageOperands());
  2192. rules_[SpvOpImageGather].push_back(UpdateImageOperands());
  2193. rules_[SpvOpImageDrefGather].push_back(UpdateImageOperands());
  2194. rules_[SpvOpImageRead].push_back(UpdateImageOperands());
  2195. rules_[SpvOpImageWrite].push_back(UpdateImageOperands());
  2196. rules_[SpvOpImageSparseSampleImplicitLod].push_back(UpdateImageOperands());
  2197. rules_[SpvOpImageSparseSampleExplicitLod].push_back(UpdateImageOperands());
  2198. rules_[SpvOpImageSparseSampleDrefImplicitLod].push_back(
  2199. UpdateImageOperands());
  2200. rules_[SpvOpImageSparseSampleDrefExplicitLod].push_back(
  2201. UpdateImageOperands());
  2202. rules_[SpvOpImageSparseSampleProjImplicitLod].push_back(
  2203. UpdateImageOperands());
  2204. rules_[SpvOpImageSparseSampleProjExplicitLod].push_back(
  2205. UpdateImageOperands());
  2206. rules_[SpvOpImageSparseSampleProjDrefImplicitLod].push_back(
  2207. UpdateImageOperands());
  2208. rules_[SpvOpImageSparseSampleProjDrefExplicitLod].push_back(
  2209. UpdateImageOperands());
  2210. rules_[SpvOpImageSparseFetch].push_back(UpdateImageOperands());
  2211. rules_[SpvOpImageSparseGather].push_back(UpdateImageOperands());
  2212. rules_[SpvOpImageSparseDrefGather].push_back(UpdateImageOperands());
  2213. rules_[SpvOpImageSparseRead].push_back(UpdateImageOperands());
  2214. FeatureManager* feature_manager = context_->get_feature_mgr();
  2215. // Add rules for GLSLstd450
  2216. uint32_t ext_inst_glslstd450_id =
  2217. feature_manager->GetExtInstImportId_GLSLstd450();
  2218. if (ext_inst_glslstd450_id != 0) {
  2219. ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(
  2220. RedundantFMix());
  2221. }
  2222. }
  2223. } // namespace opt
  2224. } // namespace spvtools