folding_rules.cpp 94 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751175217531754175517561757175817591760176117621763176417651766176717681769177017711772177317741775177617771778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537
  1. // Copyright (c) 2018 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opt/folding_rules.h"
  15. #include <limits>
  16. #include <memory>
  17. #include <utility>
  18. #include "ir_builder.h"
  19. #include "source/latest_version_glsl_std_450_header.h"
  20. #include "source/opt/ir_context.h"
  21. namespace spvtools {
  22. namespace opt {
  23. namespace {
  24. const uint32_t kExtractCompositeIdInIdx = 0;
  25. const uint32_t kInsertObjectIdInIdx = 0;
  26. const uint32_t kInsertCompositeIdInIdx = 1;
  27. const uint32_t kExtInstSetIdInIdx = 0;
  28. const uint32_t kExtInstInstructionInIdx = 1;
  29. const uint32_t kFMixXIdInIdx = 2;
  30. const uint32_t kFMixYIdInIdx = 3;
  31. const uint32_t kFMixAIdInIdx = 4;
  32. const uint32_t kStoreObjectInIdx = 1;
  33. // Some image instructions may contain an "image operands" argument.
  34. // Returns the operand index for the "image operands".
  35. // Returns -1 if the instruction does not have image operands.
  36. int32_t ImageOperandsMaskInOperandIndex(Instruction* inst) {
  37. const auto opcode = inst->opcode();
  38. switch (opcode) {
  39. case SpvOpImageSampleImplicitLod:
  40. case SpvOpImageSampleExplicitLod:
  41. case SpvOpImageSampleProjImplicitLod:
  42. case SpvOpImageSampleProjExplicitLod:
  43. case SpvOpImageFetch:
  44. case SpvOpImageRead:
  45. case SpvOpImageSparseSampleImplicitLod:
  46. case SpvOpImageSparseSampleExplicitLod:
  47. case SpvOpImageSparseSampleProjImplicitLod:
  48. case SpvOpImageSparseSampleProjExplicitLod:
  49. case SpvOpImageSparseFetch:
  50. case SpvOpImageSparseRead:
  51. return inst->NumOperands() > 4 ? 2 : -1;
  52. case SpvOpImageSampleDrefImplicitLod:
  53. case SpvOpImageSampleDrefExplicitLod:
  54. case SpvOpImageSampleProjDrefImplicitLod:
  55. case SpvOpImageSampleProjDrefExplicitLod:
  56. case SpvOpImageGather:
  57. case SpvOpImageDrefGather:
  58. case SpvOpImageSparseSampleDrefImplicitLod:
  59. case SpvOpImageSparseSampleDrefExplicitLod:
  60. case SpvOpImageSparseSampleProjDrefImplicitLod:
  61. case SpvOpImageSparseSampleProjDrefExplicitLod:
  62. case SpvOpImageSparseGather:
  63. case SpvOpImageSparseDrefGather:
  64. return inst->NumOperands() > 5 ? 3 : -1;
  65. case SpvOpImageWrite:
  66. return inst->NumOperands() > 3 ? 3 : -1;
  67. default:
  68. return -1;
  69. }
  70. }
  71. // Returns the element width of |type|.
  72. uint32_t ElementWidth(const analysis::Type* type) {
  73. if (const analysis::Vector* vec_type = type->AsVector()) {
  74. return ElementWidth(vec_type->element_type());
  75. } else if (const analysis::Float* float_type = type->AsFloat()) {
  76. return float_type->width();
  77. } else {
  78. assert(type->AsInteger());
  79. return type->AsInteger()->width();
  80. }
  81. }
  82. // Returns true if |type| is Float or a vector of Float.
  83. bool HasFloatingPoint(const analysis::Type* type) {
  84. if (type->AsFloat()) {
  85. return true;
  86. } else if (const analysis::Vector* vec_type = type->AsVector()) {
  87. return vec_type->element_type()->AsFloat() != nullptr;
  88. }
  89. return false;
  90. }
  91. // Returns false if |val| is NaN, infinite or subnormal.
  92. template <typename T>
  93. bool IsValidResult(T val) {
  94. int classified = std::fpclassify(val);
  95. switch (classified) {
  96. case FP_NAN:
  97. case FP_INFINITE:
  98. case FP_SUBNORMAL:
  99. return false;
  100. default:
  101. return true;
  102. }
  103. }
  104. const analysis::Constant* ConstInput(
  105. const std::vector<const analysis::Constant*>& constants) {
  106. return constants[0] ? constants[0] : constants[1];
  107. }
  108. Instruction* NonConstInput(IRContext* context, const analysis::Constant* c,
  109. Instruction* inst) {
  110. uint32_t in_op = c ? 1u : 0u;
  111. return context->get_def_use_mgr()->GetDef(
  112. inst->GetSingleWordInOperand(in_op));
  113. }
  114. // Returns the negation of |c|. |c| must be a 32 or 64 bit floating point
  115. // constant.
  116. uint32_t NegateFloatingPointConstant(analysis::ConstantManager* const_mgr,
  117. const analysis::Constant* c) {
  118. assert(c);
  119. assert(c->type()->AsFloat());
  120. uint32_t width = c->type()->AsFloat()->width();
  121. assert(width == 32 || width == 64);
  122. std::vector<uint32_t> words;
  123. if (width == 64) {
  124. utils::FloatProxy<double> result(c->GetDouble() * -1.0);
  125. words = result.GetWords();
  126. } else {
  127. utils::FloatProxy<float> result(c->GetFloat() * -1.0f);
  128. words = result.GetWords();
  129. }
  130. const analysis::Constant* negated_const =
  131. const_mgr->GetConstant(c->type(), std::move(words));
  132. return const_mgr->GetDefiningInstruction(negated_const)->result_id();
  133. }
  134. std::vector<uint32_t> ExtractInts(uint64_t val) {
  135. std::vector<uint32_t> words;
  136. words.push_back(static_cast<uint32_t>(val));
  137. words.push_back(static_cast<uint32_t>(val >> 32));
  138. return words;
  139. }
  140. // Negates the integer constant |c|. Returns the id of the defining instruction.
  141. uint32_t NegateIntegerConstant(analysis::ConstantManager* const_mgr,
  142. const analysis::Constant* c) {
  143. assert(c);
  144. assert(c->type()->AsInteger());
  145. uint32_t width = c->type()->AsInteger()->width();
  146. assert(width == 32 || width == 64);
  147. std::vector<uint32_t> words;
  148. if (width == 64) {
  149. uint64_t uval = static_cast<uint64_t>(0 - c->GetU64());
  150. words = ExtractInts(uval);
  151. } else {
  152. words.push_back(static_cast<uint32_t>(0 - c->GetU32()));
  153. }
  154. const analysis::Constant* negated_const =
  155. const_mgr->GetConstant(c->type(), std::move(words));
  156. return const_mgr->GetDefiningInstruction(negated_const)->result_id();
  157. }
  158. // Negates the vector constant |c|. Returns the id of the defining instruction.
  159. uint32_t NegateVectorConstant(analysis::ConstantManager* const_mgr,
  160. const analysis::Constant* c) {
  161. assert(const_mgr && c);
  162. assert(c->type()->AsVector());
  163. if (c->AsNullConstant()) {
  164. // 0.0 vs -0.0 shouldn't matter.
  165. return const_mgr->GetDefiningInstruction(c)->result_id();
  166. } else {
  167. const analysis::Type* component_type =
  168. c->AsVectorConstant()->component_type();
  169. std::vector<uint32_t> words;
  170. for (auto& comp : c->AsVectorConstant()->GetComponents()) {
  171. if (component_type->AsFloat()) {
  172. words.push_back(NegateFloatingPointConstant(const_mgr, comp));
  173. } else {
  174. assert(component_type->AsInteger());
  175. words.push_back(NegateIntegerConstant(const_mgr, comp));
  176. }
  177. }
  178. const analysis::Constant* negated_const =
  179. const_mgr->GetConstant(c->type(), std::move(words));
  180. return const_mgr->GetDefiningInstruction(negated_const)->result_id();
  181. }
  182. }
  183. // Negates |c|. Returns the id of the defining instruction.
  184. uint32_t NegateConstant(analysis::ConstantManager* const_mgr,
  185. const analysis::Constant* c) {
  186. if (c->type()->AsVector()) {
  187. return NegateVectorConstant(const_mgr, c);
  188. } else if (c->type()->AsFloat()) {
  189. return NegateFloatingPointConstant(const_mgr, c);
  190. } else {
  191. assert(c->type()->AsInteger());
  192. return NegateIntegerConstant(const_mgr, c);
  193. }
  194. }
  195. // Takes the reciprocal of |c|. |c|'s type must be Float or a vector of Float.
  196. // Returns 0 if the reciprocal is NaN, infinite or subnormal.
  197. uint32_t Reciprocal(analysis::ConstantManager* const_mgr,
  198. const analysis::Constant* c) {
  199. assert(const_mgr && c);
  200. assert(c->type()->AsFloat());
  201. uint32_t width = c->type()->AsFloat()->width();
  202. assert(width == 32 || width == 64);
  203. std::vector<uint32_t> words;
  204. if (width == 64) {
  205. spvtools::utils::FloatProxy<double> result(1.0 / c->GetDouble());
  206. if (!IsValidResult(result.getAsFloat())) return 0;
  207. words = result.GetWords();
  208. } else {
  209. spvtools::utils::FloatProxy<float> result(1.0f / c->GetFloat());
  210. if (!IsValidResult(result.getAsFloat())) return 0;
  211. words = result.GetWords();
  212. }
  213. const analysis::Constant* negated_const =
  214. const_mgr->GetConstant(c->type(), std::move(words));
  215. return const_mgr->GetDefiningInstruction(negated_const)->result_id();
  216. }
  217. // Replaces fdiv where second operand is constant with fmul.
  218. FoldingRule ReciprocalFDiv() {
  219. return [](IRContext* context, Instruction* inst,
  220. const std::vector<const analysis::Constant*>& constants) {
  221. assert(inst->opcode() == SpvOpFDiv);
  222. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  223. const analysis::Type* type =
  224. context->get_type_mgr()->GetType(inst->type_id());
  225. if (!inst->IsFloatingPointFoldingAllowed()) return false;
  226. uint32_t width = ElementWidth(type);
  227. if (width != 32 && width != 64) return false;
  228. if (constants[1] != nullptr) {
  229. uint32_t id = 0;
  230. if (const analysis::VectorConstant* vector_const =
  231. constants[1]->AsVectorConstant()) {
  232. std::vector<uint32_t> neg_ids;
  233. for (auto& comp : vector_const->GetComponents()) {
  234. id = Reciprocal(const_mgr, comp);
  235. if (id == 0) return false;
  236. neg_ids.push_back(id);
  237. }
  238. const analysis::Constant* negated_const =
  239. const_mgr->GetConstant(constants[1]->type(), std::move(neg_ids));
  240. id = const_mgr->GetDefiningInstruction(negated_const)->result_id();
  241. } else if (constants[1]->AsFloatConstant()) {
  242. id = Reciprocal(const_mgr, constants[1]);
  243. if (id == 0) return false;
  244. } else {
  245. // Don't fold a null constant.
  246. return false;
  247. }
  248. inst->SetOpcode(SpvOpFMul);
  249. inst->SetInOperands(
  250. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0u)}},
  251. {SPV_OPERAND_TYPE_ID, {id}}});
  252. return true;
  253. }
  254. return false;
  255. };
  256. }
  257. // Elides consecutive negate instructions.
  258. FoldingRule MergeNegateArithmetic() {
  259. return [](IRContext* context, Instruction* inst,
  260. const std::vector<const analysis::Constant*>& constants) {
  261. assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
  262. (void)constants;
  263. const analysis::Type* type =
  264. context->get_type_mgr()->GetType(inst->type_id());
  265. if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
  266. return false;
  267. Instruction* op_inst =
  268. context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
  269. if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
  270. return false;
  271. if (op_inst->opcode() == inst->opcode()) {
  272. // Elide negates.
  273. inst->SetOpcode(SpvOpCopyObject);
  274. inst->SetInOperands(
  275. {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0u)}}});
  276. return true;
  277. }
  278. return false;
  279. };
  280. }
  281. // Merges negate into a mul or div operation if that operation contains a
  282. // constant operand.
  283. // Cases:
  284. // -(x * 2) = x * -2
  285. // -(2 * x) = x * -2
  286. // -(x / 2) = x / -2
  287. // -(2 / x) = -2 / x
  288. FoldingRule MergeNegateMulDivArithmetic() {
  289. return [](IRContext* context, Instruction* inst,
  290. const std::vector<const analysis::Constant*>& constants) {
  291. assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
  292. (void)constants;
  293. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  294. const analysis::Type* type =
  295. context->get_type_mgr()->GetType(inst->type_id());
  296. if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
  297. return false;
  298. Instruction* op_inst =
  299. context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
  300. if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
  301. return false;
  302. uint32_t width = ElementWidth(type);
  303. if (width != 32 && width != 64) return false;
  304. SpvOp opcode = op_inst->opcode();
  305. if (opcode == SpvOpFMul || opcode == SpvOpFDiv || opcode == SpvOpIMul ||
  306. opcode == SpvOpSDiv || opcode == SpvOpUDiv) {
  307. std::vector<const analysis::Constant*> op_constants =
  308. const_mgr->GetOperandConstants(op_inst);
  309. // Merge negate into mul or div if one operand is constant.
  310. if (op_constants[0] || op_constants[1]) {
  311. bool zero_is_variable = op_constants[0] == nullptr;
  312. const analysis::Constant* c = ConstInput(op_constants);
  313. uint32_t neg_id = NegateConstant(const_mgr, c);
  314. uint32_t non_const_id = zero_is_variable
  315. ? op_inst->GetSingleWordInOperand(0u)
  316. : op_inst->GetSingleWordInOperand(1u);
  317. // Change this instruction to a mul/div.
  318. inst->SetOpcode(op_inst->opcode());
  319. if (opcode == SpvOpFDiv || opcode == SpvOpUDiv || opcode == SpvOpSDiv) {
  320. uint32_t op0 = zero_is_variable ? non_const_id : neg_id;
  321. uint32_t op1 = zero_is_variable ? neg_id : non_const_id;
  322. inst->SetInOperands(
  323. {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
  324. } else {
  325. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
  326. {SPV_OPERAND_TYPE_ID, {neg_id}}});
  327. }
  328. return true;
  329. }
  330. }
  331. return false;
  332. };
  333. }
  334. // Merges negate into a add or sub operation if that operation contains a
  335. // constant operand.
  336. // Cases:
  337. // -(x + 2) = -2 - x
  338. // -(2 + x) = -2 - x
  339. // -(x - 2) = 2 - x
  340. // -(2 - x) = x - 2
  341. FoldingRule MergeNegateAddSubArithmetic() {
  342. return [](IRContext* context, Instruction* inst,
  343. const std::vector<const analysis::Constant*>& constants) {
  344. assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
  345. (void)constants;
  346. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  347. const analysis::Type* type =
  348. context->get_type_mgr()->GetType(inst->type_id());
  349. if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
  350. return false;
  351. Instruction* op_inst =
  352. context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
  353. if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
  354. return false;
  355. uint32_t width = ElementWidth(type);
  356. if (width != 32 && width != 64) return false;
  357. if (op_inst->opcode() == SpvOpFAdd || op_inst->opcode() == SpvOpFSub ||
  358. op_inst->opcode() == SpvOpIAdd || op_inst->opcode() == SpvOpISub) {
  359. std::vector<const analysis::Constant*> op_constants =
  360. const_mgr->GetOperandConstants(op_inst);
  361. if (op_constants[0] || op_constants[1]) {
  362. bool zero_is_variable = op_constants[0] == nullptr;
  363. bool is_add = (op_inst->opcode() == SpvOpFAdd) ||
  364. (op_inst->opcode() == SpvOpIAdd);
  365. bool swap_operands = !is_add || zero_is_variable;
  366. bool negate_const = is_add;
  367. const analysis::Constant* c = ConstInput(op_constants);
  368. uint32_t const_id = 0;
  369. if (negate_const) {
  370. const_id = NegateConstant(const_mgr, c);
  371. } else {
  372. const_id = zero_is_variable ? op_inst->GetSingleWordInOperand(1u)
  373. : op_inst->GetSingleWordInOperand(0u);
  374. }
  375. // Swap operands if necessary and make the instruction a subtraction.
  376. uint32_t op0 =
  377. zero_is_variable ? op_inst->GetSingleWordInOperand(0u) : const_id;
  378. uint32_t op1 =
  379. zero_is_variable ? const_id : op_inst->GetSingleWordInOperand(1u);
  380. if (swap_operands) std::swap(op0, op1);
  381. inst->SetOpcode(HasFloatingPoint(type) ? SpvOpFSub : SpvOpISub);
  382. inst->SetInOperands(
  383. {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
  384. return true;
  385. }
  386. }
  387. return false;
  388. };
  389. }
  390. // Returns true if |c| has a zero element.
  391. bool HasZero(const analysis::Constant* c) {
  392. if (c->AsNullConstant()) {
  393. return true;
  394. }
  395. if (const analysis::VectorConstant* vec_const = c->AsVectorConstant()) {
  396. for (auto& comp : vec_const->GetComponents())
  397. if (HasZero(comp)) return true;
  398. } else {
  399. assert(c->AsScalarConstant());
  400. return c->AsScalarConstant()->IsZero();
  401. }
  402. return false;
  403. }
  404. // Performs |input1| |opcode| |input2| and returns the merged constant result
  405. // id. Returns 0 if the result is not a valid value. The input types must be
  406. // Float.
  407. uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr,
  408. SpvOp opcode,
  409. const analysis::Constant* input1,
  410. const analysis::Constant* input2) {
  411. const analysis::Type* type = input1->type();
  412. assert(type->AsFloat());
  413. uint32_t width = type->AsFloat()->width();
  414. assert(width == 32 || width == 64);
  415. std::vector<uint32_t> words;
  416. #define FOLD_OP(op) \
  417. if (width == 64) { \
  418. utils::FloatProxy<double> val = \
  419. input1->GetDouble() op input2->GetDouble(); \
  420. double dval = val.getAsFloat(); \
  421. if (!IsValidResult(dval)) return 0; \
  422. words = val.GetWords(); \
  423. } else { \
  424. utils::FloatProxy<float> val = input1->GetFloat() op input2->GetFloat(); \
  425. float fval = val.getAsFloat(); \
  426. if (!IsValidResult(fval)) return 0; \
  427. words = val.GetWords(); \
  428. }
  429. switch (opcode) {
  430. case SpvOpFMul:
  431. FOLD_OP(*);
  432. break;
  433. case SpvOpFDiv:
  434. if (HasZero(input2)) return 0;
  435. FOLD_OP(/);
  436. break;
  437. case SpvOpFAdd:
  438. FOLD_OP(+);
  439. break;
  440. case SpvOpFSub:
  441. FOLD_OP(-);
  442. break;
  443. default:
  444. assert(false && "Unexpected operation");
  445. break;
  446. }
  447. #undef FOLD_OP
  448. const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
  449. return const_mgr->GetDefiningInstruction(merged_const)->result_id();
  450. }
  451. // Performs |input1| |opcode| |input2| and returns the merged constant result
  452. // id. Returns 0 if the result is not a valid value. The input types must be
  453. // Integers.
  454. uint32_t PerformIntegerOperation(analysis::ConstantManager* const_mgr,
  455. SpvOp opcode, const analysis::Constant* input1,
  456. const analysis::Constant* input2) {
  457. assert(input1->type()->AsInteger());
  458. const analysis::Integer* type = input1->type()->AsInteger();
  459. uint32_t width = type->AsInteger()->width();
  460. assert(width == 32 || width == 64);
  461. std::vector<uint32_t> words;
  462. #define FOLD_OP(op) \
  463. if (width == 64) { \
  464. if (type->IsSigned()) { \
  465. int64_t val = input1->GetS64() op input2->GetS64(); \
  466. words = ExtractInts(static_cast<uint64_t>(val)); \
  467. } else { \
  468. uint64_t val = input1->GetU64() op input2->GetU64(); \
  469. words = ExtractInts(val); \
  470. } \
  471. } else { \
  472. if (type->IsSigned()) { \
  473. int32_t val = input1->GetS32() op input2->GetS32(); \
  474. words.push_back(static_cast<uint32_t>(val)); \
  475. } else { \
  476. uint32_t val = input1->GetU32() op input2->GetU32(); \
  477. words.push_back(val); \
  478. } \
  479. }
  480. switch (opcode) {
  481. case SpvOpIMul:
  482. FOLD_OP(*);
  483. break;
  484. case SpvOpSDiv:
  485. case SpvOpUDiv:
  486. assert(false && "Should not merge integer division");
  487. break;
  488. case SpvOpIAdd:
  489. FOLD_OP(+);
  490. break;
  491. case SpvOpISub:
  492. FOLD_OP(-);
  493. break;
  494. default:
  495. assert(false && "Unexpected operation");
  496. break;
  497. }
  498. #undef FOLD_OP
  499. const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
  500. return const_mgr->GetDefiningInstruction(merged_const)->result_id();
  501. }
  502. // Performs |input1| |opcode| |input2| and returns the merged constant result
  503. // id. Returns 0 if the result is not a valid value. The input types must be
  504. // Integers, Floats or Vectors of such.
  505. uint32_t PerformOperation(analysis::ConstantManager* const_mgr, SpvOp opcode,
  506. const analysis::Constant* input1,
  507. const analysis::Constant* input2) {
  508. assert(input1 && input2);
  509. const analysis::Type* type = input1->type();
  510. std::vector<uint32_t> words;
  511. if (const analysis::Vector* vector_type = type->AsVector()) {
  512. const analysis::Type* ele_type = vector_type->element_type();
  513. for (uint32_t i = 0; i != vector_type->element_count(); ++i) {
  514. uint32_t id = 0;
  515. const analysis::Constant* input1_comp = nullptr;
  516. if (const analysis::VectorConstant* input1_vector =
  517. input1->AsVectorConstant()) {
  518. input1_comp = input1_vector->GetComponents()[i];
  519. } else {
  520. assert(input1->AsNullConstant());
  521. input1_comp = const_mgr->GetConstant(ele_type, {});
  522. }
  523. const analysis::Constant* input2_comp = nullptr;
  524. if (const analysis::VectorConstant* input2_vector =
  525. input2->AsVectorConstant()) {
  526. input2_comp = input2_vector->GetComponents()[i];
  527. } else {
  528. assert(input2->AsNullConstant());
  529. input2_comp = const_mgr->GetConstant(ele_type, {});
  530. }
  531. if (ele_type->AsFloat()) {
  532. id = PerformFloatingPointOperation(const_mgr, opcode, input1_comp,
  533. input2_comp);
  534. } else {
  535. assert(ele_type->AsInteger());
  536. id = PerformIntegerOperation(const_mgr, opcode, input1_comp,
  537. input2_comp);
  538. }
  539. if (id == 0) return 0;
  540. words.push_back(id);
  541. }
  542. const analysis::Constant* merged_const =
  543. const_mgr->GetConstant(type, words);
  544. return const_mgr->GetDefiningInstruction(merged_const)->result_id();
  545. } else if (type->AsFloat()) {
  546. return PerformFloatingPointOperation(const_mgr, opcode, input1, input2);
  547. } else {
  548. assert(type->AsInteger());
  549. return PerformIntegerOperation(const_mgr, opcode, input1, input2);
  550. }
  551. }
  552. // Merges consecutive multiplies where each contains one constant operand.
  553. // Cases:
  554. // 2 * (x * 2) = x * 4
  555. // 2 * (2 * x) = x * 4
  556. // (x * 2) * 2 = x * 4
  557. // (2 * x) * 2 = x * 4
  558. FoldingRule MergeMulMulArithmetic() {
  559. return [](IRContext* context, Instruction* inst,
  560. const std::vector<const analysis::Constant*>& constants) {
  561. assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
  562. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  563. const analysis::Type* type =
  564. context->get_type_mgr()->GetType(inst->type_id());
  565. if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
  566. return false;
  567. uint32_t width = ElementWidth(type);
  568. if (width != 32 && width != 64) return false;
  569. // Determine the constant input and the variable input in |inst|.
  570. const analysis::Constant* const_input1 = ConstInput(constants);
  571. if (!const_input1) return false;
  572. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  573. if (HasFloatingPoint(type) && !other_inst->IsFloatingPointFoldingAllowed())
  574. return false;
  575. if (other_inst->opcode() == inst->opcode()) {
  576. std::vector<const analysis::Constant*> other_constants =
  577. const_mgr->GetOperandConstants(other_inst);
  578. const analysis::Constant* const_input2 = ConstInput(other_constants);
  579. if (!const_input2) return false;
  580. bool other_first_is_variable = other_constants[0] == nullptr;
  581. uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
  582. const_input1, const_input2);
  583. if (merged_id == 0) return false;
  584. uint32_t non_const_id = other_first_is_variable
  585. ? other_inst->GetSingleWordInOperand(0u)
  586. : other_inst->GetSingleWordInOperand(1u);
  587. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
  588. {SPV_OPERAND_TYPE_ID, {merged_id}}});
  589. return true;
  590. }
  591. return false;
  592. };
  593. }
  594. // Merges divides into subsequent multiplies if each instruction contains one
  595. // constant operand. Does not support integer operations.
  596. // Cases:
  597. // 2 * (x / 2) = x * 1
  598. // 2 * (2 / x) = 4 / x
  599. // (x / 2) * 2 = x * 1
  600. // (2 / x) * 2 = 4 / x
  601. // (y / x) * x = y
  602. // x * (y / x) = y
  603. FoldingRule MergeMulDivArithmetic() {
  604. return [](IRContext* context, Instruction* inst,
  605. const std::vector<const analysis::Constant*>& constants) {
  606. assert(inst->opcode() == SpvOpFMul);
  607. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  608. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  609. const analysis::Type* type =
  610. context->get_type_mgr()->GetType(inst->type_id());
  611. if (!inst->IsFloatingPointFoldingAllowed()) return false;
  612. uint32_t width = ElementWidth(type);
  613. if (width != 32 && width != 64) return false;
  614. for (uint32_t i = 0; i < 2; i++) {
  615. uint32_t op_id = inst->GetSingleWordInOperand(i);
  616. Instruction* op_inst = def_use_mgr->GetDef(op_id);
  617. if (op_inst->opcode() == SpvOpFDiv) {
  618. if (op_inst->GetSingleWordInOperand(1) ==
  619. inst->GetSingleWordInOperand(1 - i)) {
  620. inst->SetOpcode(SpvOpCopyObject);
  621. inst->SetInOperands(
  622. {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0)}}});
  623. return true;
  624. }
  625. }
  626. }
  627. const analysis::Constant* const_input1 = ConstInput(constants);
  628. if (!const_input1) return false;
  629. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  630. if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
  631. if (other_inst->opcode() == SpvOpFDiv) {
  632. std::vector<const analysis::Constant*> other_constants =
  633. const_mgr->GetOperandConstants(other_inst);
  634. const analysis::Constant* const_input2 = ConstInput(other_constants);
  635. if (!const_input2 || HasZero(const_input2)) return false;
  636. bool other_first_is_variable = other_constants[0] == nullptr;
  637. // If the variable value is the second operand of the divide, multiply
  638. // the constants together. Otherwise divide the constants.
  639. uint32_t merged_id = PerformOperation(
  640. const_mgr,
  641. other_first_is_variable ? other_inst->opcode() : inst->opcode(),
  642. const_input1, const_input2);
  643. if (merged_id == 0) return false;
  644. uint32_t non_const_id = other_first_is_variable
  645. ? other_inst->GetSingleWordInOperand(0u)
  646. : other_inst->GetSingleWordInOperand(1u);
  647. // If the variable value is on the second operand of the div, then this
  648. // operation is a div. Otherwise it should be a multiply.
  649. inst->SetOpcode(other_first_is_variable ? inst->opcode()
  650. : other_inst->opcode());
  651. if (other_first_is_variable) {
  652. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
  653. {SPV_OPERAND_TYPE_ID, {merged_id}}});
  654. } else {
  655. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {merged_id}},
  656. {SPV_OPERAND_TYPE_ID, {non_const_id}}});
  657. }
  658. return true;
  659. }
  660. return false;
  661. };
  662. }
  663. // Merges multiply of constant and negation.
  664. // Cases:
  665. // (-x) * 2 = x * -2
  666. // 2 * (-x) = x * -2
  667. FoldingRule MergeMulNegateArithmetic() {
  668. return [](IRContext* context, Instruction* inst,
  669. const std::vector<const analysis::Constant*>& constants) {
  670. assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
  671. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  672. const analysis::Type* type =
  673. context->get_type_mgr()->GetType(inst->type_id());
  674. bool uses_float = HasFloatingPoint(type);
  675. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  676. uint32_t width = ElementWidth(type);
  677. if (width != 32 && width != 64) return false;
  678. const analysis::Constant* const_input1 = ConstInput(constants);
  679. if (!const_input1) return false;
  680. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  681. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  682. return false;
  683. if (other_inst->opcode() == SpvOpFNegate ||
  684. other_inst->opcode() == SpvOpSNegate) {
  685. uint32_t neg_id = NegateConstant(const_mgr, const_input1);
  686. inst->SetInOperands(
  687. {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
  688. {SPV_OPERAND_TYPE_ID, {neg_id}}});
  689. return true;
  690. }
  691. return false;
  692. };
  693. }
  694. // Merges consecutive divides if each instruction contains one constant operand.
  695. // Does not support integer division.
  696. // Cases:
  697. // 2 / (x / 2) = 4 / x
  698. // 4 / (2 / x) = 2 * x
  699. // (4 / x) / 2 = 2 / x
  700. // (x / 2) / 2 = x / 4
  701. FoldingRule MergeDivDivArithmetic() {
  702. return [](IRContext* context, Instruction* inst,
  703. const std::vector<const analysis::Constant*>& constants) {
  704. assert(inst->opcode() == SpvOpFDiv);
  705. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  706. const analysis::Type* type =
  707. context->get_type_mgr()->GetType(inst->type_id());
  708. if (!inst->IsFloatingPointFoldingAllowed()) return false;
  709. uint32_t width = ElementWidth(type);
  710. if (width != 32 && width != 64) return false;
  711. const analysis::Constant* const_input1 = ConstInput(constants);
  712. if (!const_input1 || HasZero(const_input1)) return false;
  713. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  714. if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
  715. bool first_is_variable = constants[0] == nullptr;
  716. if (other_inst->opcode() == inst->opcode()) {
  717. std::vector<const analysis::Constant*> other_constants =
  718. const_mgr->GetOperandConstants(other_inst);
  719. const analysis::Constant* const_input2 = ConstInput(other_constants);
  720. if (!const_input2 || HasZero(const_input2)) return false;
  721. bool other_first_is_variable = other_constants[0] == nullptr;
  722. SpvOp merge_op = inst->opcode();
  723. if (other_first_is_variable) {
  724. // Constants magnify.
  725. merge_op = SpvOpFMul;
  726. }
  727. // This is an x / (*) case. Swap the inputs. Doesn't harm multiply
  728. // because it is commutative.
  729. if (first_is_variable) std::swap(const_input1, const_input2);
  730. uint32_t merged_id =
  731. PerformOperation(const_mgr, merge_op, const_input1, const_input2);
  732. if (merged_id == 0) return false;
  733. uint32_t non_const_id = other_first_is_variable
  734. ? other_inst->GetSingleWordInOperand(0u)
  735. : other_inst->GetSingleWordInOperand(1u);
  736. SpvOp op = inst->opcode();
  737. if (!first_is_variable && !other_first_is_variable) {
  738. // Effectively div of 1/x, so change to multiply.
  739. op = SpvOpFMul;
  740. }
  741. uint32_t op1 = merged_id;
  742. uint32_t op2 = non_const_id;
  743. if (first_is_variable && other_first_is_variable) std::swap(op1, op2);
  744. inst->SetOpcode(op);
  745. inst->SetInOperands(
  746. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  747. return true;
  748. }
  749. return false;
  750. };
  751. }
  752. // Fold multiplies succeeded by divides where each instruction contains a
  753. // constant operand. Does not support integer divide.
  754. // Cases:
  755. // 4 / (x * 2) = 2 / x
  756. // 4 / (2 * x) = 2 / x
  757. // (x * 4) / 2 = x * 2
  758. // (4 * x) / 2 = x * 2
  759. // (x * y) / x = y
  760. // (y * x) / x = y
  761. FoldingRule MergeDivMulArithmetic() {
  762. return [](IRContext* context, Instruction* inst,
  763. const std::vector<const analysis::Constant*>& constants) {
  764. assert(inst->opcode() == SpvOpFDiv);
  765. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  766. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  767. const analysis::Type* type =
  768. context->get_type_mgr()->GetType(inst->type_id());
  769. if (!inst->IsFloatingPointFoldingAllowed()) return false;
  770. uint32_t width = ElementWidth(type);
  771. if (width != 32 && width != 64) return false;
  772. uint32_t op_id = inst->GetSingleWordInOperand(0);
  773. Instruction* op_inst = def_use_mgr->GetDef(op_id);
  774. if (op_inst->opcode() == SpvOpFMul) {
  775. for (uint32_t i = 0; i < 2; i++) {
  776. if (op_inst->GetSingleWordInOperand(i) ==
  777. inst->GetSingleWordInOperand(1)) {
  778. inst->SetOpcode(SpvOpCopyObject);
  779. inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
  780. {op_inst->GetSingleWordInOperand(1 - i)}}});
  781. return true;
  782. }
  783. }
  784. }
  785. const analysis::Constant* const_input1 = ConstInput(constants);
  786. if (!const_input1 || HasZero(const_input1)) return false;
  787. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  788. if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
  789. bool first_is_variable = constants[0] == nullptr;
  790. if (other_inst->opcode() == SpvOpFMul) {
  791. std::vector<const analysis::Constant*> other_constants =
  792. const_mgr->GetOperandConstants(other_inst);
  793. const analysis::Constant* const_input2 = ConstInput(other_constants);
  794. if (!const_input2) return false;
  795. bool other_first_is_variable = other_constants[0] == nullptr;
  796. // This is an x / (*) case. Swap the inputs.
  797. if (first_is_variable) std::swap(const_input1, const_input2);
  798. uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
  799. const_input1, const_input2);
  800. if (merged_id == 0) return false;
  801. uint32_t non_const_id = other_first_is_variable
  802. ? other_inst->GetSingleWordInOperand(0u)
  803. : other_inst->GetSingleWordInOperand(1u);
  804. uint32_t op1 = merged_id;
  805. uint32_t op2 = non_const_id;
  806. if (first_is_variable) std::swap(op1, op2);
  807. // Convert to multiply
  808. if (first_is_variable) inst->SetOpcode(other_inst->opcode());
  809. inst->SetInOperands(
  810. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  811. return true;
  812. }
  813. return false;
  814. };
  815. }
  816. // Fold divides of a constant and a negation.
  817. // Cases:
  818. // (-x) / 2 = x / -2
  819. // 2 / (-x) = 2 / -x
  820. FoldingRule MergeDivNegateArithmetic() {
  821. return [](IRContext* context, Instruction* inst,
  822. const std::vector<const analysis::Constant*>& constants) {
  823. assert(inst->opcode() == SpvOpFDiv || inst->opcode() == SpvOpSDiv ||
  824. inst->opcode() == SpvOpUDiv);
  825. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  826. const analysis::Type* type =
  827. context->get_type_mgr()->GetType(inst->type_id());
  828. bool uses_float = HasFloatingPoint(type);
  829. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  830. uint32_t width = ElementWidth(type);
  831. if (width != 32 && width != 64) return false;
  832. const analysis::Constant* const_input1 = ConstInput(constants);
  833. if (!const_input1) return false;
  834. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  835. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  836. return false;
  837. bool first_is_variable = constants[0] == nullptr;
  838. if (other_inst->opcode() == SpvOpFNegate ||
  839. other_inst->opcode() == SpvOpSNegate) {
  840. uint32_t neg_id = NegateConstant(const_mgr, const_input1);
  841. if (first_is_variable) {
  842. inst->SetInOperands(
  843. {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
  844. {SPV_OPERAND_TYPE_ID, {neg_id}}});
  845. } else {
  846. inst->SetInOperands(
  847. {{SPV_OPERAND_TYPE_ID, {neg_id}},
  848. {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
  849. }
  850. return true;
  851. }
  852. return false;
  853. };
  854. }
  855. // Folds addition of a constant and a negation.
  856. // Cases:
  857. // (-x) + 2 = 2 - x
  858. // 2 + (-x) = 2 - x
  859. FoldingRule MergeAddNegateArithmetic() {
  860. return [](IRContext* context, Instruction* inst,
  861. const std::vector<const analysis::Constant*>& constants) {
  862. assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
  863. const analysis::Type* type =
  864. context->get_type_mgr()->GetType(inst->type_id());
  865. bool uses_float = HasFloatingPoint(type);
  866. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  867. const analysis::Constant* const_input1 = ConstInput(constants);
  868. if (!const_input1) return false;
  869. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  870. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  871. return false;
  872. if (other_inst->opcode() == SpvOpSNegate ||
  873. other_inst->opcode() == SpvOpFNegate) {
  874. inst->SetOpcode(HasFloatingPoint(type) ? SpvOpFSub : SpvOpISub);
  875. uint32_t const_id = constants[0] ? inst->GetSingleWordInOperand(0u)
  876. : inst->GetSingleWordInOperand(1u);
  877. inst->SetInOperands(
  878. {{SPV_OPERAND_TYPE_ID, {const_id}},
  879. {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
  880. return true;
  881. }
  882. return false;
  883. };
  884. }
  885. // Folds subtraction of a constant and a negation.
  886. // Cases:
  887. // (-x) - 2 = -2 - x
  888. // 2 - (-x) = x + 2
  889. FoldingRule MergeSubNegateArithmetic() {
  890. return [](IRContext* context, Instruction* inst,
  891. const std::vector<const analysis::Constant*>& constants) {
  892. assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
  893. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  894. const analysis::Type* type =
  895. context->get_type_mgr()->GetType(inst->type_id());
  896. bool uses_float = HasFloatingPoint(type);
  897. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  898. uint32_t width = ElementWidth(type);
  899. if (width != 32 && width != 64) return false;
  900. const analysis::Constant* const_input1 = ConstInput(constants);
  901. if (!const_input1) return false;
  902. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  903. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  904. return false;
  905. if (other_inst->opcode() == SpvOpSNegate ||
  906. other_inst->opcode() == SpvOpFNegate) {
  907. uint32_t op1 = 0;
  908. uint32_t op2 = 0;
  909. SpvOp opcode = inst->opcode();
  910. if (constants[0] != nullptr) {
  911. op1 = other_inst->GetSingleWordInOperand(0u);
  912. op2 = inst->GetSingleWordInOperand(0u);
  913. opcode = HasFloatingPoint(type) ? SpvOpFAdd : SpvOpIAdd;
  914. } else {
  915. op1 = NegateConstant(const_mgr, const_input1);
  916. op2 = other_inst->GetSingleWordInOperand(0u);
  917. }
  918. inst->SetOpcode(opcode);
  919. inst->SetInOperands(
  920. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  921. return true;
  922. }
  923. return false;
  924. };
  925. }
  926. // Folds addition of an addition where each operation has a constant operand.
  927. // Cases:
  928. // (x + 2) + 2 = x + 4
  929. // (2 + x) + 2 = x + 4
  930. // 2 + (x + 2) = x + 4
  931. // 2 + (2 + x) = x + 4
  932. FoldingRule MergeAddAddArithmetic() {
  933. return [](IRContext* context, Instruction* inst,
  934. const std::vector<const analysis::Constant*>& constants) {
  935. assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
  936. const analysis::Type* type =
  937. context->get_type_mgr()->GetType(inst->type_id());
  938. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  939. bool uses_float = HasFloatingPoint(type);
  940. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  941. uint32_t width = ElementWidth(type);
  942. if (width != 32 && width != 64) return false;
  943. const analysis::Constant* const_input1 = ConstInput(constants);
  944. if (!const_input1) return false;
  945. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  946. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  947. return false;
  948. if (other_inst->opcode() == SpvOpFAdd ||
  949. other_inst->opcode() == SpvOpIAdd) {
  950. std::vector<const analysis::Constant*> other_constants =
  951. const_mgr->GetOperandConstants(other_inst);
  952. const analysis::Constant* const_input2 = ConstInput(other_constants);
  953. if (!const_input2) return false;
  954. Instruction* non_const_input =
  955. NonConstInput(context, other_constants[0], other_inst);
  956. uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
  957. const_input1, const_input2);
  958. if (merged_id == 0) return false;
  959. inst->SetInOperands(
  960. {{SPV_OPERAND_TYPE_ID, {non_const_input->result_id()}},
  961. {SPV_OPERAND_TYPE_ID, {merged_id}}});
  962. return true;
  963. }
  964. return false;
  965. };
  966. }
  967. // Folds addition of a subtraction where each operation has a constant operand.
  968. // Cases:
  969. // (x - 2) + 2 = x + 0
  970. // (2 - x) + 2 = 4 - x
  971. // 2 + (x - 2) = x + 0
  972. // 2 + (2 - x) = 4 - x
  973. FoldingRule MergeAddSubArithmetic() {
  974. return [](IRContext* context, Instruction* inst,
  975. const std::vector<const analysis::Constant*>& constants) {
  976. assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
  977. const analysis::Type* type =
  978. context->get_type_mgr()->GetType(inst->type_id());
  979. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  980. bool uses_float = HasFloatingPoint(type);
  981. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  982. uint32_t width = ElementWidth(type);
  983. if (width != 32 && width != 64) return false;
  984. const analysis::Constant* const_input1 = ConstInput(constants);
  985. if (!const_input1) return false;
  986. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  987. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  988. return false;
  989. if (other_inst->opcode() == SpvOpFSub ||
  990. other_inst->opcode() == SpvOpISub) {
  991. std::vector<const analysis::Constant*> other_constants =
  992. const_mgr->GetOperandConstants(other_inst);
  993. const analysis::Constant* const_input2 = ConstInput(other_constants);
  994. if (!const_input2) return false;
  995. bool first_is_variable = other_constants[0] == nullptr;
  996. SpvOp op = inst->opcode();
  997. uint32_t op1 = 0;
  998. uint32_t op2 = 0;
  999. if (first_is_variable) {
  1000. // Subtract constants. Non-constant operand is first.
  1001. op1 = other_inst->GetSingleWordInOperand(0u);
  1002. op2 = PerformOperation(const_mgr, other_inst->opcode(), const_input1,
  1003. const_input2);
  1004. } else {
  1005. // Add constants. Constant operand is first. Change the opcode.
  1006. op1 = PerformOperation(const_mgr, inst->opcode(), const_input1,
  1007. const_input2);
  1008. op2 = other_inst->GetSingleWordInOperand(1u);
  1009. op = other_inst->opcode();
  1010. }
  1011. if (op1 == 0 || op2 == 0) return false;
  1012. inst->SetOpcode(op);
  1013. inst->SetInOperands(
  1014. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  1015. return true;
  1016. }
  1017. return false;
  1018. };
  1019. }
  1020. // Folds subtraction of an addition where each operand has a constant operand.
  1021. // Cases:
  1022. // (x + 2) - 2 = x + 0
  1023. // (2 + x) - 2 = x + 0
  1024. // 2 - (x + 2) = 0 - x
  1025. // 2 - (2 + x) = 0 - x
  1026. FoldingRule MergeSubAddArithmetic() {
  1027. return [](IRContext* context, Instruction* inst,
  1028. const std::vector<const analysis::Constant*>& constants) {
  1029. assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
  1030. const analysis::Type* type =
  1031. context->get_type_mgr()->GetType(inst->type_id());
  1032. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  1033. bool uses_float = HasFloatingPoint(type);
  1034. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  1035. uint32_t width = ElementWidth(type);
  1036. if (width != 32 && width != 64) return false;
  1037. const analysis::Constant* const_input1 = ConstInput(constants);
  1038. if (!const_input1) return false;
  1039. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  1040. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  1041. return false;
  1042. if (other_inst->opcode() == SpvOpFAdd ||
  1043. other_inst->opcode() == SpvOpIAdd) {
  1044. std::vector<const analysis::Constant*> other_constants =
  1045. const_mgr->GetOperandConstants(other_inst);
  1046. const analysis::Constant* const_input2 = ConstInput(other_constants);
  1047. if (!const_input2) return false;
  1048. Instruction* non_const_input =
  1049. NonConstInput(context, other_constants[0], other_inst);
  1050. // If the first operand of the sub is not a constant, swap the constants
  1051. // so the subtraction has the correct operands.
  1052. if (constants[0] == nullptr) std::swap(const_input1, const_input2);
  1053. // Subtract the constants.
  1054. uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
  1055. const_input1, const_input2);
  1056. SpvOp op = inst->opcode();
  1057. uint32_t op1 = 0;
  1058. uint32_t op2 = 0;
  1059. if (constants[0] == nullptr) {
  1060. // Non-constant operand is first. Change the opcode.
  1061. op1 = non_const_input->result_id();
  1062. op2 = merged_id;
  1063. op = other_inst->opcode();
  1064. } else {
  1065. // Constant operand is first.
  1066. op1 = merged_id;
  1067. op2 = non_const_input->result_id();
  1068. }
  1069. if (op1 == 0 || op2 == 0) return false;
  1070. inst->SetOpcode(op);
  1071. inst->SetInOperands(
  1072. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  1073. return true;
  1074. }
  1075. return false;
  1076. };
  1077. }
  1078. // Folds subtraction of a subtraction where each operand has a constant operand.
  1079. // Cases:
  1080. // (x - 2) - 2 = x - 4
  1081. // (2 - x) - 2 = 0 - x
  1082. // 2 - (x - 2) = 4 - x
  1083. // 2 - (2 - x) = x + 0
  1084. FoldingRule MergeSubSubArithmetic() {
  1085. return [](IRContext* context, Instruction* inst,
  1086. const std::vector<const analysis::Constant*>& constants) {
  1087. assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
  1088. const analysis::Type* type =
  1089. context->get_type_mgr()->GetType(inst->type_id());
  1090. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  1091. bool uses_float = HasFloatingPoint(type);
  1092. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  1093. uint32_t width = ElementWidth(type);
  1094. if (width != 32 && width != 64) return false;
  1095. const analysis::Constant* const_input1 = ConstInput(constants);
  1096. if (!const_input1) return false;
  1097. Instruction* other_inst = NonConstInput(context, constants[0], inst);
  1098. if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
  1099. return false;
  1100. if (other_inst->opcode() == SpvOpFSub ||
  1101. other_inst->opcode() == SpvOpISub) {
  1102. std::vector<const analysis::Constant*> other_constants =
  1103. const_mgr->GetOperandConstants(other_inst);
  1104. const analysis::Constant* const_input2 = ConstInput(other_constants);
  1105. if (!const_input2) return false;
  1106. Instruction* non_const_input =
  1107. NonConstInput(context, other_constants[0], other_inst);
  1108. // Merge the constants.
  1109. uint32_t merged_id = 0;
  1110. SpvOp merge_op = inst->opcode();
  1111. if (other_constants[0] == nullptr) {
  1112. merge_op = uses_float ? SpvOpFAdd : SpvOpIAdd;
  1113. } else if (constants[0] == nullptr) {
  1114. std::swap(const_input1, const_input2);
  1115. }
  1116. merged_id =
  1117. PerformOperation(const_mgr, merge_op, const_input1, const_input2);
  1118. if (merged_id == 0) return false;
  1119. SpvOp op = inst->opcode();
  1120. if (constants[0] != nullptr && other_constants[0] != nullptr) {
  1121. // Change the operation.
  1122. op = uses_float ? SpvOpFAdd : SpvOpIAdd;
  1123. }
  1124. uint32_t op1 = 0;
  1125. uint32_t op2 = 0;
  1126. if ((constants[0] == nullptr) ^ (other_constants[0] == nullptr)) {
  1127. op1 = merged_id;
  1128. op2 = non_const_input->result_id();
  1129. } else {
  1130. op1 = non_const_input->result_id();
  1131. op2 = merged_id;
  1132. }
  1133. inst->SetOpcode(op);
  1134. inst->SetInOperands(
  1135. {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
  1136. return true;
  1137. }
  1138. return false;
  1139. };
  1140. }
  1141. // Helper function for MergeGenericAddSubArithmetic. If |addend| and
  1142. // subtrahend of |sub| is the same, merge to copy of minuend of |sub|.
  1143. bool MergeGenericAddendSub(uint32_t addend, uint32_t sub, Instruction* inst) {
  1144. IRContext* context = inst->context();
  1145. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1146. Instruction* sub_inst = def_use_mgr->GetDef(sub);
  1147. if (sub_inst->opcode() != SpvOpFSub && sub_inst->opcode() != SpvOpISub)
  1148. return false;
  1149. if (sub_inst->opcode() == SpvOpFSub &&
  1150. !sub_inst->IsFloatingPointFoldingAllowed())
  1151. return false;
  1152. if (addend != sub_inst->GetSingleWordInOperand(1)) return false;
  1153. inst->SetOpcode(SpvOpCopyObject);
  1154. inst->SetInOperands(
  1155. {{SPV_OPERAND_TYPE_ID, {sub_inst->GetSingleWordInOperand(0)}}});
  1156. context->UpdateDefUse(inst);
  1157. return true;
  1158. }
  1159. // Folds addition of a subtraction where the subtrahend is equal to the
  1160. // other addend. Return a copy of the minuend. Accepts generic (const and
  1161. // non-const) operands.
  1162. // Cases:
  1163. // (a - b) + b = a
  1164. // b + (a - b) = a
  1165. FoldingRule MergeGenericAddSubArithmetic() {
  1166. return [](IRContext* context, Instruction* inst,
  1167. const std::vector<const analysis::Constant*>&) {
  1168. assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
  1169. const analysis::Type* type =
  1170. context->get_type_mgr()->GetType(inst->type_id());
  1171. bool uses_float = HasFloatingPoint(type);
  1172. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  1173. uint32_t width = ElementWidth(type);
  1174. if (width != 32 && width != 64) return false;
  1175. uint32_t add_op0 = inst->GetSingleWordInOperand(0);
  1176. uint32_t add_op1 = inst->GetSingleWordInOperand(1);
  1177. if (MergeGenericAddendSub(add_op0, add_op1, inst)) return true;
  1178. return MergeGenericAddendSub(add_op1, add_op0, inst);
  1179. };
  1180. }
  1181. // Helper function for FactorAddMuls. If |factor0_0| is the same as |factor1_0|,
  1182. // generate |factor0_0| * (|factor0_1| + |factor1_1|).
  1183. bool FactorAddMulsOpnds(uint32_t factor0_0, uint32_t factor0_1,
  1184. uint32_t factor1_0, uint32_t factor1_1,
  1185. Instruction* inst) {
  1186. IRContext* context = inst->context();
  1187. if (factor0_0 != factor1_0) return false;
  1188. InstructionBuilder ir_builder(
  1189. context, inst,
  1190. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  1191. Instruction* new_add_inst = ir_builder.AddBinaryOp(
  1192. inst->type_id(), inst->opcode(), factor0_1, factor1_1);
  1193. inst->SetOpcode(inst->opcode() == SpvOpFAdd ? SpvOpFMul : SpvOpIMul);
  1194. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {factor0_0}},
  1195. {SPV_OPERAND_TYPE_ID, {new_add_inst->result_id()}}});
  1196. context->UpdateDefUse(inst);
  1197. return true;
  1198. }
  1199. // Perform the following factoring identity, handling all operand order
  1200. // combinations: (a * b) + (a * c) = a * (b + c)
  1201. FoldingRule FactorAddMuls() {
  1202. return [](IRContext* context, Instruction* inst,
  1203. const std::vector<const analysis::Constant*>&) {
  1204. assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
  1205. const analysis::Type* type =
  1206. context->get_type_mgr()->GetType(inst->type_id());
  1207. bool uses_float = HasFloatingPoint(type);
  1208. if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
  1209. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1210. uint32_t add_op0 = inst->GetSingleWordInOperand(0);
  1211. Instruction* add_op0_inst = def_use_mgr->GetDef(add_op0);
  1212. if (add_op0_inst->opcode() != SpvOpFMul &&
  1213. add_op0_inst->opcode() != SpvOpIMul)
  1214. return false;
  1215. uint32_t add_op1 = inst->GetSingleWordInOperand(1);
  1216. Instruction* add_op1_inst = def_use_mgr->GetDef(add_op1);
  1217. if (add_op1_inst->opcode() != SpvOpFMul &&
  1218. add_op1_inst->opcode() != SpvOpIMul)
  1219. return false;
  1220. // Only perform this optimization if both of the muls only have one use.
  1221. // Otherwise this is a deoptimization in size and performance.
  1222. if (def_use_mgr->NumUses(add_op0_inst) > 1) return false;
  1223. if (def_use_mgr->NumUses(add_op1_inst) > 1) return false;
  1224. if (add_op0_inst->opcode() == SpvOpFMul &&
  1225. (!add_op0_inst->IsFloatingPointFoldingAllowed() ||
  1226. !add_op1_inst->IsFloatingPointFoldingAllowed()))
  1227. return false;
  1228. for (int i = 0; i < 2; i++) {
  1229. for (int j = 0; j < 2; j++) {
  1230. // Check if operand i in add_op0_inst matches operand j in add_op1_inst.
  1231. if (FactorAddMulsOpnds(add_op0_inst->GetSingleWordInOperand(i),
  1232. add_op0_inst->GetSingleWordInOperand(1 - i),
  1233. add_op1_inst->GetSingleWordInOperand(j),
  1234. add_op1_inst->GetSingleWordInOperand(1 - j),
  1235. inst))
  1236. return true;
  1237. }
  1238. }
  1239. return false;
  1240. };
  1241. }
  1242. FoldingRule IntMultipleBy1() {
  1243. return [](IRContext*, Instruction* inst,
  1244. const std::vector<const analysis::Constant*>& constants) {
  1245. assert(inst->opcode() == SpvOpIMul && "Wrong opcode. Should be OpIMul.");
  1246. for (uint32_t i = 0; i < 2; i++) {
  1247. if (constants[i] == nullptr) {
  1248. continue;
  1249. }
  1250. const analysis::IntConstant* int_constant = constants[i]->AsIntConstant();
  1251. if (int_constant) {
  1252. uint32_t width = ElementWidth(int_constant->type());
  1253. if (width != 32 && width != 64) return false;
  1254. bool is_one = (width == 32) ? int_constant->GetU32BitValue() == 1u
  1255. : int_constant->GetU64BitValue() == 1ull;
  1256. if (is_one) {
  1257. inst->SetOpcode(SpvOpCopyObject);
  1258. inst->SetInOperands(
  1259. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1 - i)}}});
  1260. return true;
  1261. }
  1262. }
  1263. }
  1264. return false;
  1265. };
  1266. }
  1267. FoldingRule CompositeConstructFeedingExtract() {
  1268. return [](IRContext* context, Instruction* inst,
  1269. const std::vector<const analysis::Constant*>&) {
  1270. // If the input to an OpCompositeExtract is an OpCompositeConstruct,
  1271. // then we can simply use the appropriate element in the construction.
  1272. assert(inst->opcode() == SpvOpCompositeExtract &&
  1273. "Wrong opcode. Should be OpCompositeExtract.");
  1274. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1275. analysis::TypeManager* type_mgr = context->get_type_mgr();
  1276. // If there are no index operands, then this rule cannot do anything.
  1277. if (inst->NumInOperands() <= 1) {
  1278. return false;
  1279. }
  1280. uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1281. Instruction* cinst = def_use_mgr->GetDef(cid);
  1282. if (cinst->opcode() != SpvOpCompositeConstruct) {
  1283. return false;
  1284. }
  1285. std::vector<Operand> operands;
  1286. analysis::Type* composite_type = type_mgr->GetType(cinst->type_id());
  1287. if (composite_type->AsVector() == nullptr) {
  1288. // Get the element being extracted from the OpCompositeConstruct
  1289. // Since it is not a vector, it is simple to extract the single element.
  1290. uint32_t element_index = inst->GetSingleWordInOperand(1);
  1291. uint32_t element_id = cinst->GetSingleWordInOperand(element_index);
  1292. operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
  1293. // Add the remaining indices for extraction.
  1294. for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
  1295. operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
  1296. {inst->GetSingleWordInOperand(i)}});
  1297. }
  1298. } else {
  1299. // With vectors we have to handle the case where it is concatenating
  1300. // vectors.
  1301. assert(inst->NumInOperands() == 2 &&
  1302. "Expecting a vector of scalar values.");
  1303. uint32_t element_index = inst->GetSingleWordInOperand(1);
  1304. for (uint32_t construct_index = 0;
  1305. construct_index < cinst->NumInOperands(); ++construct_index) {
  1306. uint32_t element_id = cinst->GetSingleWordInOperand(construct_index);
  1307. Instruction* element_def = def_use_mgr->GetDef(element_id);
  1308. analysis::Vector* element_type =
  1309. type_mgr->GetType(element_def->type_id())->AsVector();
  1310. if (element_type) {
  1311. uint32_t vector_size = element_type->element_count();
  1312. if (vector_size < element_index) {
  1313. // The element we want comes after this vector.
  1314. element_index -= vector_size;
  1315. } else {
  1316. // We want an element of this vector.
  1317. operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
  1318. operands.push_back(
  1319. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {element_index}});
  1320. break;
  1321. }
  1322. } else {
  1323. if (element_index == 0) {
  1324. // This is a scalar, and we this is the element we are extracting.
  1325. operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
  1326. break;
  1327. } else {
  1328. // Skip over this scalar value.
  1329. --element_index;
  1330. }
  1331. }
  1332. }
  1333. }
  1334. // If there were no extra indices, then we have the final object. No need
  1335. // to extract even more.
  1336. if (operands.size() == 1) {
  1337. inst->SetOpcode(SpvOpCopyObject);
  1338. }
  1339. inst->SetInOperands(std::move(operands));
  1340. return true;
  1341. };
  1342. }
  1343. // If the OpCompositeConstruct is simply putting back together elements that
  1344. // where extracted from the same source, we can simply reuse the source.
  1345. //
  1346. // This is a common code pattern because of the way that scalar replacement
  1347. // works.
  1348. bool CompositeExtractFeedingConstruct(
  1349. IRContext* context, Instruction* inst,
  1350. const std::vector<const analysis::Constant*>&) {
  1351. assert(inst->opcode() == SpvOpCompositeConstruct &&
  1352. "Wrong opcode. Should be OpCompositeConstruct.");
  1353. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1354. uint32_t original_id = 0;
  1355. if (inst->NumInOperands() == 0) {
  1356. // The struct being constructed has no members.
  1357. return false;
  1358. }
  1359. // Check each element to make sure they are:
  1360. // - extractions
  1361. // - extracting the same position they are inserting
  1362. // - all extract from the same id.
  1363. for (uint32_t i = 0; i < inst->NumInOperands(); ++i) {
  1364. const uint32_t element_id = inst->GetSingleWordInOperand(i);
  1365. Instruction* element_inst = def_use_mgr->GetDef(element_id);
  1366. if (element_inst->opcode() != SpvOpCompositeExtract) {
  1367. return false;
  1368. }
  1369. if (element_inst->NumInOperands() != 2) {
  1370. return false;
  1371. }
  1372. if (element_inst->GetSingleWordInOperand(1) != i) {
  1373. return false;
  1374. }
  1375. if (i == 0) {
  1376. original_id =
  1377. element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1378. } else if (original_id !=
  1379. element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx)) {
  1380. return false;
  1381. }
  1382. }
  1383. // The last check it to see that the object being extracted from is the
  1384. // correct type.
  1385. Instruction* original_inst = def_use_mgr->GetDef(original_id);
  1386. if (original_inst->type_id() != inst->type_id()) {
  1387. return false;
  1388. }
  1389. // Simplify by using the original object.
  1390. inst->SetOpcode(SpvOpCopyObject);
  1391. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}});
  1392. return true;
  1393. }
  1394. FoldingRule InsertFeedingExtract() {
  1395. return [](IRContext* context, Instruction* inst,
  1396. const std::vector<const analysis::Constant*>&) {
  1397. assert(inst->opcode() == SpvOpCompositeExtract &&
  1398. "Wrong opcode. Should be OpCompositeExtract.");
  1399. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1400. uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1401. Instruction* cinst = def_use_mgr->GetDef(cid);
  1402. if (cinst->opcode() != SpvOpCompositeInsert) {
  1403. return false;
  1404. }
  1405. // Find the first position where the list of insert and extract indicies
  1406. // differ, if at all.
  1407. uint32_t i;
  1408. for (i = 1; i < inst->NumInOperands(); ++i) {
  1409. if (i + 1 >= cinst->NumInOperands()) {
  1410. break;
  1411. }
  1412. if (inst->GetSingleWordInOperand(i) !=
  1413. cinst->GetSingleWordInOperand(i + 1)) {
  1414. break;
  1415. }
  1416. }
  1417. // We are extracting the element that was inserted.
  1418. if (i == inst->NumInOperands() && i + 1 == cinst->NumInOperands()) {
  1419. inst->SetOpcode(SpvOpCopyObject);
  1420. inst->SetInOperands(
  1421. {{SPV_OPERAND_TYPE_ID,
  1422. {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}}});
  1423. return true;
  1424. }
  1425. // Extracting the value that was inserted along with values for the base
  1426. // composite. Cannot do anything.
  1427. if (i == inst->NumInOperands()) {
  1428. return false;
  1429. }
  1430. // Extracting an element of the value that was inserted. Extract from
  1431. // that value directly.
  1432. if (i + 1 == cinst->NumInOperands()) {
  1433. std::vector<Operand> operands;
  1434. operands.push_back(
  1435. {SPV_OPERAND_TYPE_ID,
  1436. {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}});
  1437. for (; i < inst->NumInOperands(); ++i) {
  1438. operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
  1439. {inst->GetSingleWordInOperand(i)}});
  1440. }
  1441. inst->SetInOperands(std::move(operands));
  1442. return true;
  1443. }
  1444. // Extracting a value that is disjoint from the element being inserted.
  1445. // Rewrite the extract to use the composite input to the insert.
  1446. std::vector<Operand> operands;
  1447. operands.push_back(
  1448. {SPV_OPERAND_TYPE_ID,
  1449. {cinst->GetSingleWordInOperand(kInsertCompositeIdInIdx)}});
  1450. for (i = 1; i < inst->NumInOperands(); ++i) {
  1451. operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
  1452. {inst->GetSingleWordInOperand(i)}});
  1453. }
  1454. inst->SetInOperands(std::move(operands));
  1455. return true;
  1456. };
  1457. }
  1458. // When a VectorShuffle is feeding an Extract, we can extract from one of the
  1459. // operands of the VectorShuffle. We just need to adjust the index in the
  1460. // extract instruction.
  1461. FoldingRule VectorShuffleFeedingExtract() {
  1462. return [](IRContext* context, Instruction* inst,
  1463. const std::vector<const analysis::Constant*>&) {
  1464. assert(inst->opcode() == SpvOpCompositeExtract &&
  1465. "Wrong opcode. Should be OpCompositeExtract.");
  1466. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1467. analysis::TypeManager* type_mgr = context->get_type_mgr();
  1468. uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1469. Instruction* cinst = def_use_mgr->GetDef(cid);
  1470. if (cinst->opcode() != SpvOpVectorShuffle) {
  1471. return false;
  1472. }
  1473. // Find the size of the first vector operand of the VectorShuffle
  1474. Instruction* first_input =
  1475. def_use_mgr->GetDef(cinst->GetSingleWordInOperand(0));
  1476. analysis::Type* first_input_type =
  1477. type_mgr->GetType(first_input->type_id());
  1478. assert(first_input_type->AsVector() &&
  1479. "Input to vector shuffle should be vectors.");
  1480. uint32_t first_input_size = first_input_type->AsVector()->element_count();
  1481. // Get index of the element the vector shuffle is placing in the position
  1482. // being extracted.
  1483. uint32_t new_index =
  1484. cinst->GetSingleWordInOperand(2 + inst->GetSingleWordInOperand(1));
  1485. // Extracting an undefined value so fold this extract into an undef.
  1486. const uint32_t undef_literal_value = 0xffffffff;
  1487. if (new_index == undef_literal_value) {
  1488. inst->SetOpcode(SpvOpUndef);
  1489. inst->SetInOperands({});
  1490. return true;
  1491. }
  1492. // Get the id of the of the vector the elemtent comes from, and update the
  1493. // index if needed.
  1494. uint32_t new_vector = 0;
  1495. if (new_index < first_input_size) {
  1496. new_vector = cinst->GetSingleWordInOperand(0);
  1497. } else {
  1498. new_vector = cinst->GetSingleWordInOperand(1);
  1499. new_index -= first_input_size;
  1500. }
  1501. // Update the extract instruction.
  1502. inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
  1503. inst->SetInOperand(1, {new_index});
  1504. return true;
  1505. };
  1506. }
  1507. // When an FMix with is feeding an Extract that extracts an element whose
  1508. // corresponding |a| in the FMix is 0 or 1, we can extract from one of the
  1509. // operands of the FMix.
  1510. FoldingRule FMixFeedingExtract() {
  1511. return [](IRContext* context, Instruction* inst,
  1512. const std::vector<const analysis::Constant*>&) {
  1513. assert(inst->opcode() == SpvOpCompositeExtract &&
  1514. "Wrong opcode. Should be OpCompositeExtract.");
  1515. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1516. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  1517. uint32_t composite_id =
  1518. inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
  1519. Instruction* composite_inst = def_use_mgr->GetDef(composite_id);
  1520. if (composite_inst->opcode() != SpvOpExtInst) {
  1521. return false;
  1522. }
  1523. uint32_t inst_set_id =
  1524. context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  1525. if (composite_inst->GetSingleWordInOperand(kExtInstSetIdInIdx) !=
  1526. inst_set_id ||
  1527. composite_inst->GetSingleWordInOperand(kExtInstInstructionInIdx) !=
  1528. GLSLstd450FMix) {
  1529. return false;
  1530. }
  1531. // Get the |a| for the FMix instruction.
  1532. uint32_t a_id = composite_inst->GetSingleWordInOperand(kFMixAIdInIdx);
  1533. std::unique_ptr<Instruction> a(inst->Clone(context));
  1534. a->SetInOperand(kExtractCompositeIdInIdx, {a_id});
  1535. context->get_instruction_folder().FoldInstruction(a.get());
  1536. if (a->opcode() != SpvOpCopyObject) {
  1537. return false;
  1538. }
  1539. const analysis::Constant* a_const =
  1540. const_mgr->FindDeclaredConstant(a->GetSingleWordInOperand(0));
  1541. if (!a_const) {
  1542. return false;
  1543. }
  1544. bool use_x = false;
  1545. assert(a_const->type()->AsFloat());
  1546. double element_value = a_const->GetValueAsDouble();
  1547. if (element_value == 0.0) {
  1548. use_x = true;
  1549. } else if (element_value == 1.0) {
  1550. use_x = false;
  1551. } else {
  1552. return false;
  1553. }
  1554. // Get the id of the of the vector the element comes from.
  1555. uint32_t new_vector = 0;
  1556. if (use_x) {
  1557. new_vector = composite_inst->GetSingleWordInOperand(kFMixXIdInIdx);
  1558. } else {
  1559. new_vector = composite_inst->GetSingleWordInOperand(kFMixYIdInIdx);
  1560. }
  1561. // Update the extract instruction.
  1562. inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
  1563. return true;
  1564. };
  1565. }
  1566. FoldingRule RedundantPhi() {
  1567. // An OpPhi instruction where all values are the same or the result of the phi
  1568. // itself, can be replaced by the value itself.
  1569. return [](IRContext*, Instruction* inst,
  1570. const std::vector<const analysis::Constant*>&) {
  1571. assert(inst->opcode() == SpvOpPhi && "Wrong opcode. Should be OpPhi.");
  1572. uint32_t incoming_value = 0;
  1573. for (uint32_t i = 0; i < inst->NumInOperands(); i += 2) {
  1574. uint32_t op_id = inst->GetSingleWordInOperand(i);
  1575. if (op_id == inst->result_id()) {
  1576. continue;
  1577. }
  1578. if (incoming_value == 0) {
  1579. incoming_value = op_id;
  1580. } else if (op_id != incoming_value) {
  1581. // Found two possible value. Can't simplify.
  1582. return false;
  1583. }
  1584. }
  1585. if (incoming_value == 0) {
  1586. // Code looks invalid. Don't do anything.
  1587. return false;
  1588. }
  1589. // We have a single incoming value. Simplify using that value.
  1590. inst->SetOpcode(SpvOpCopyObject);
  1591. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {incoming_value}}});
  1592. return true;
  1593. };
  1594. }
  1595. FoldingRule RedundantSelect() {
  1596. // An OpSelect instruction where both values are the same or the condition is
  1597. // constant can be replaced by one of the values
  1598. return [](IRContext*, Instruction* inst,
  1599. const std::vector<const analysis::Constant*>& constants) {
  1600. assert(inst->opcode() == SpvOpSelect &&
  1601. "Wrong opcode. Should be OpSelect.");
  1602. assert(inst->NumInOperands() == 3);
  1603. assert(constants.size() == 3);
  1604. uint32_t true_id = inst->GetSingleWordInOperand(1);
  1605. uint32_t false_id = inst->GetSingleWordInOperand(2);
  1606. if (true_id == false_id) {
  1607. // Both results are the same, condition doesn't matter
  1608. inst->SetOpcode(SpvOpCopyObject);
  1609. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
  1610. return true;
  1611. } else if (constants[0]) {
  1612. const analysis::Type* type = constants[0]->type();
  1613. if (type->AsBool()) {
  1614. // Scalar constant value, select the corresponding value.
  1615. inst->SetOpcode(SpvOpCopyObject);
  1616. if (constants[0]->AsNullConstant() ||
  1617. !constants[0]->AsBoolConstant()->value()) {
  1618. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
  1619. } else {
  1620. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
  1621. }
  1622. return true;
  1623. } else {
  1624. assert(type->AsVector());
  1625. if (constants[0]->AsNullConstant()) {
  1626. // All values come from false id.
  1627. inst->SetOpcode(SpvOpCopyObject);
  1628. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
  1629. return true;
  1630. } else {
  1631. // Convert to a vector shuffle.
  1632. std::vector<Operand> ops;
  1633. ops.push_back({SPV_OPERAND_TYPE_ID, {true_id}});
  1634. ops.push_back({SPV_OPERAND_TYPE_ID, {false_id}});
  1635. const analysis::VectorConstant* vector_const =
  1636. constants[0]->AsVectorConstant();
  1637. uint32_t size =
  1638. static_cast<uint32_t>(vector_const->GetComponents().size());
  1639. for (uint32_t i = 0; i != size; ++i) {
  1640. const analysis::Constant* component =
  1641. vector_const->GetComponents()[i];
  1642. if (component->AsNullConstant() ||
  1643. !component->AsBoolConstant()->value()) {
  1644. // Selecting from the false vector which is the second input
  1645. // vector to the shuffle. Offset the index by |size|.
  1646. ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i + size}});
  1647. } else {
  1648. // Selecting from true vector which is the first input vector to
  1649. // the shuffle.
  1650. ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}});
  1651. }
  1652. }
  1653. inst->SetOpcode(SpvOpVectorShuffle);
  1654. inst->SetInOperands(std::move(ops));
  1655. return true;
  1656. }
  1657. }
  1658. }
  1659. return false;
  1660. };
  1661. }
  1662. enum class FloatConstantKind { Unknown, Zero, One };
  1663. FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) {
  1664. if (constant == nullptr) {
  1665. return FloatConstantKind::Unknown;
  1666. }
  1667. assert(HasFloatingPoint(constant->type()) && "Unexpected constant type");
  1668. if (constant->AsNullConstant()) {
  1669. return FloatConstantKind::Zero;
  1670. } else if (const analysis::VectorConstant* vc =
  1671. constant->AsVectorConstant()) {
  1672. const std::vector<const analysis::Constant*>& components =
  1673. vc->GetComponents();
  1674. assert(!components.empty());
  1675. FloatConstantKind kind = getFloatConstantKind(components[0]);
  1676. for (size_t i = 1; i < components.size(); ++i) {
  1677. if (getFloatConstantKind(components[i]) != kind) {
  1678. return FloatConstantKind::Unknown;
  1679. }
  1680. }
  1681. return kind;
  1682. } else if (const analysis::FloatConstant* fc = constant->AsFloatConstant()) {
  1683. if (fc->IsZero()) return FloatConstantKind::Zero;
  1684. uint32_t width = fc->type()->AsFloat()->width();
  1685. if (width != 32 && width != 64) return FloatConstantKind::Unknown;
  1686. double value = (width == 64) ? fc->GetDoubleValue() : fc->GetFloatValue();
  1687. if (value == 0.0) {
  1688. return FloatConstantKind::Zero;
  1689. } else if (value == 1.0) {
  1690. return FloatConstantKind::One;
  1691. } else {
  1692. return FloatConstantKind::Unknown;
  1693. }
  1694. } else {
  1695. return FloatConstantKind::Unknown;
  1696. }
  1697. }
  1698. FoldingRule RedundantFAdd() {
  1699. return [](IRContext*, Instruction* inst,
  1700. const std::vector<const analysis::Constant*>& constants) {
  1701. assert(inst->opcode() == SpvOpFAdd && "Wrong opcode. Should be OpFAdd.");
  1702. assert(constants.size() == 2);
  1703. if (!inst->IsFloatingPointFoldingAllowed()) {
  1704. return false;
  1705. }
  1706. FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
  1707. FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
  1708. if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
  1709. inst->SetOpcode(SpvOpCopyObject);
  1710. inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
  1711. {inst->GetSingleWordInOperand(
  1712. kind0 == FloatConstantKind::Zero ? 1 : 0)}}});
  1713. return true;
  1714. }
  1715. return false;
  1716. };
  1717. }
  1718. FoldingRule RedundantFSub() {
  1719. return [](IRContext*, Instruction* inst,
  1720. const std::vector<const analysis::Constant*>& constants) {
  1721. assert(inst->opcode() == SpvOpFSub && "Wrong opcode. Should be OpFSub.");
  1722. assert(constants.size() == 2);
  1723. if (!inst->IsFloatingPointFoldingAllowed()) {
  1724. return false;
  1725. }
  1726. FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
  1727. FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
  1728. if (kind0 == FloatConstantKind::Zero) {
  1729. inst->SetOpcode(SpvOpFNegate);
  1730. inst->SetInOperands(
  1731. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1)}}});
  1732. return true;
  1733. }
  1734. if (kind1 == FloatConstantKind::Zero) {
  1735. inst->SetOpcode(SpvOpCopyObject);
  1736. inst->SetInOperands(
  1737. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
  1738. return true;
  1739. }
  1740. return false;
  1741. };
  1742. }
  1743. FoldingRule RedundantFMul() {
  1744. return [](IRContext*, Instruction* inst,
  1745. const std::vector<const analysis::Constant*>& constants) {
  1746. assert(inst->opcode() == SpvOpFMul && "Wrong opcode. Should be OpFMul.");
  1747. assert(constants.size() == 2);
  1748. if (!inst->IsFloatingPointFoldingAllowed()) {
  1749. return false;
  1750. }
  1751. FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
  1752. FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
  1753. if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
  1754. inst->SetOpcode(SpvOpCopyObject);
  1755. inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
  1756. {inst->GetSingleWordInOperand(
  1757. kind0 == FloatConstantKind::Zero ? 0 : 1)}}});
  1758. return true;
  1759. }
  1760. if (kind0 == FloatConstantKind::One || kind1 == FloatConstantKind::One) {
  1761. inst->SetOpcode(SpvOpCopyObject);
  1762. inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
  1763. {inst->GetSingleWordInOperand(
  1764. kind0 == FloatConstantKind::One ? 1 : 0)}}});
  1765. return true;
  1766. }
  1767. return false;
  1768. };
  1769. }
  1770. FoldingRule RedundantFDiv() {
  1771. return [](IRContext*, Instruction* inst,
  1772. const std::vector<const analysis::Constant*>& constants) {
  1773. assert(inst->opcode() == SpvOpFDiv && "Wrong opcode. Should be OpFDiv.");
  1774. assert(constants.size() == 2);
  1775. if (!inst->IsFloatingPointFoldingAllowed()) {
  1776. return false;
  1777. }
  1778. FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
  1779. FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
  1780. if (kind0 == FloatConstantKind::Zero) {
  1781. inst->SetOpcode(SpvOpCopyObject);
  1782. inst->SetInOperands(
  1783. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
  1784. return true;
  1785. }
  1786. if (kind1 == FloatConstantKind::One) {
  1787. inst->SetOpcode(SpvOpCopyObject);
  1788. inst->SetInOperands(
  1789. {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
  1790. return true;
  1791. }
  1792. return false;
  1793. };
  1794. }
  1795. FoldingRule RedundantFMix() {
  1796. return [](IRContext* context, Instruction* inst,
  1797. const std::vector<const analysis::Constant*>& constants) {
  1798. assert(inst->opcode() == SpvOpExtInst &&
  1799. "Wrong opcode. Should be OpExtInst.");
  1800. if (!inst->IsFloatingPointFoldingAllowed()) {
  1801. return false;
  1802. }
  1803. uint32_t instSetId =
  1804. context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  1805. if (inst->GetSingleWordInOperand(kExtInstSetIdInIdx) == instSetId &&
  1806. inst->GetSingleWordInOperand(kExtInstInstructionInIdx) ==
  1807. GLSLstd450FMix) {
  1808. assert(constants.size() == 5);
  1809. FloatConstantKind kind4 = getFloatConstantKind(constants[4]);
  1810. if (kind4 == FloatConstantKind::Zero || kind4 == FloatConstantKind::One) {
  1811. inst->SetOpcode(SpvOpCopyObject);
  1812. inst->SetInOperands(
  1813. {{SPV_OPERAND_TYPE_ID,
  1814. {inst->GetSingleWordInOperand(kind4 == FloatConstantKind::Zero
  1815. ? kFMixXIdInIdx
  1816. : kFMixYIdInIdx)}}});
  1817. return true;
  1818. }
  1819. }
  1820. return false;
  1821. };
  1822. }
  1823. // This rule handles addition of zero for integers.
  1824. FoldingRule RedundantIAdd() {
  1825. return [](IRContext* context, Instruction* inst,
  1826. const std::vector<const analysis::Constant*>& constants) {
  1827. assert(inst->opcode() == SpvOpIAdd && "Wrong opcode. Should be OpIAdd.");
  1828. uint32_t operand = std::numeric_limits<uint32_t>::max();
  1829. const analysis::Type* operand_type = nullptr;
  1830. if (constants[0] && constants[0]->IsZero()) {
  1831. operand = inst->GetSingleWordInOperand(1);
  1832. operand_type = constants[0]->type();
  1833. } else if (constants[1] && constants[1]->IsZero()) {
  1834. operand = inst->GetSingleWordInOperand(0);
  1835. operand_type = constants[1]->type();
  1836. }
  1837. if (operand != std::numeric_limits<uint32_t>::max()) {
  1838. const analysis::Type* inst_type =
  1839. context->get_type_mgr()->GetType(inst->type_id());
  1840. if (inst_type->IsSame(operand_type)) {
  1841. inst->SetOpcode(SpvOpCopyObject);
  1842. } else {
  1843. inst->SetOpcode(SpvOpBitcast);
  1844. }
  1845. inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {operand}}});
  1846. return true;
  1847. }
  1848. return false;
  1849. };
  1850. }
  1851. // This rule look for a dot with a constant vector containing a single 1 and
  1852. // the rest 0s. This is the same as doing an extract.
  1853. FoldingRule DotProductDoingExtract() {
  1854. return [](IRContext* context, Instruction* inst,
  1855. const std::vector<const analysis::Constant*>& constants) {
  1856. assert(inst->opcode() == SpvOpDot && "Wrong opcode. Should be OpDot.");
  1857. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  1858. if (!inst->IsFloatingPointFoldingAllowed()) {
  1859. return false;
  1860. }
  1861. for (int i = 0; i < 2; ++i) {
  1862. if (!constants[i]) {
  1863. continue;
  1864. }
  1865. const analysis::Vector* vector_type = constants[i]->type()->AsVector();
  1866. assert(vector_type && "Inputs to OpDot must be vectors.");
  1867. const analysis::Float* element_type =
  1868. vector_type->element_type()->AsFloat();
  1869. assert(element_type && "Inputs to OpDot must be vectors of floats.");
  1870. uint32_t element_width = element_type->width();
  1871. if (element_width != 32 && element_width != 64) {
  1872. return false;
  1873. }
  1874. std::vector<const analysis::Constant*> components;
  1875. components = constants[i]->GetVectorComponents(const_mgr);
  1876. const uint32_t kNotFound = std::numeric_limits<uint32_t>::max();
  1877. uint32_t component_with_one = kNotFound;
  1878. bool all_others_zero = true;
  1879. for (uint32_t j = 0; j < components.size(); ++j) {
  1880. const analysis::Constant* element = components[j];
  1881. double value =
  1882. (element_width == 32 ? element->GetFloat() : element->GetDouble());
  1883. if (value == 0.0) {
  1884. continue;
  1885. } else if (value == 1.0) {
  1886. if (component_with_one == kNotFound) {
  1887. component_with_one = j;
  1888. } else {
  1889. component_with_one = kNotFound;
  1890. break;
  1891. }
  1892. } else {
  1893. all_others_zero = false;
  1894. break;
  1895. }
  1896. }
  1897. if (!all_others_zero || component_with_one == kNotFound) {
  1898. continue;
  1899. }
  1900. std::vector<Operand> operands;
  1901. operands.push_back(
  1902. {SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1u - i)}});
  1903. operands.push_back(
  1904. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_with_one}});
  1905. inst->SetOpcode(SpvOpCompositeExtract);
  1906. inst->SetInOperands(std::move(operands));
  1907. return true;
  1908. }
  1909. return false;
  1910. };
  1911. }
  1912. // If we are storing an undef, then we can remove the store.
  1913. //
  1914. // TODO: We can do something similar for OpImageWrite, but checking for volatile
  1915. // is complicated. Waiting to see if it is needed.
  1916. FoldingRule StoringUndef() {
  1917. return [](IRContext* context, Instruction* inst,
  1918. const std::vector<const analysis::Constant*>&) {
  1919. assert(inst->opcode() == SpvOpStore && "Wrong opcode. Should be OpStore.");
  1920. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1921. // If this is a volatile store, the store cannot be removed.
  1922. if (inst->NumInOperands() == 3) {
  1923. if (inst->GetSingleWordInOperand(2) & SpvMemoryAccessVolatileMask) {
  1924. return false;
  1925. }
  1926. }
  1927. uint32_t object_id = inst->GetSingleWordInOperand(kStoreObjectInIdx);
  1928. Instruction* object_inst = def_use_mgr->GetDef(object_id);
  1929. if (object_inst->opcode() == SpvOpUndef) {
  1930. inst->ToNop();
  1931. return true;
  1932. }
  1933. return false;
  1934. };
  1935. }
  1936. FoldingRule VectorShuffleFeedingShuffle() {
  1937. return [](IRContext* context, Instruction* inst,
  1938. const std::vector<const analysis::Constant*>&) {
  1939. assert(inst->opcode() == SpvOpVectorShuffle &&
  1940. "Wrong opcode. Should be OpVectorShuffle.");
  1941. analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
  1942. analysis::TypeManager* type_mgr = context->get_type_mgr();
  1943. Instruction* feeding_shuffle_inst =
  1944. def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
  1945. analysis::Vector* op0_type =
  1946. type_mgr->GetType(feeding_shuffle_inst->type_id())->AsVector();
  1947. uint32_t op0_length = op0_type->element_count();
  1948. bool feeder_is_op0 = true;
  1949. if (feeding_shuffle_inst->opcode() != SpvOpVectorShuffle) {
  1950. feeding_shuffle_inst =
  1951. def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
  1952. feeder_is_op0 = false;
  1953. }
  1954. if (feeding_shuffle_inst->opcode() != SpvOpVectorShuffle) {
  1955. return false;
  1956. }
  1957. Instruction* feeder2 =
  1958. def_use_mgr->GetDef(feeding_shuffle_inst->GetSingleWordInOperand(0));
  1959. analysis::Vector* feeder_op0_type =
  1960. type_mgr->GetType(feeder2->type_id())->AsVector();
  1961. uint32_t feeder_op0_length = feeder_op0_type->element_count();
  1962. uint32_t new_feeder_id = 0;
  1963. std::vector<Operand> new_operands;
  1964. new_operands.resize(
  1965. 2, {SPV_OPERAND_TYPE_ID, {0}}); // Place holders for vector operands.
  1966. const uint32_t undef_literal = 0xffffffff;
  1967. for (uint32_t op = 2; op < inst->NumInOperands(); ++op) {
  1968. uint32_t component_index = inst->GetSingleWordInOperand(op);
  1969. // Do not interpret the undefined value literal as coming from operand 1.
  1970. if (component_index != undef_literal &&
  1971. feeder_is_op0 == (component_index < op0_length)) {
  1972. // This component comes from the feeding_shuffle_inst. Update
  1973. // |component_index| to be the index into the operand of the feeder.
  1974. // Adjust component_index to get the index into the operands of the
  1975. // feeding_shuffle_inst.
  1976. if (component_index >= op0_length) {
  1977. component_index -= op0_length;
  1978. }
  1979. component_index =
  1980. feeding_shuffle_inst->GetSingleWordInOperand(component_index + 2);
  1981. // Check if we are using a component from the first or second operand of
  1982. // the feeding instruction.
  1983. if (component_index < feeder_op0_length) {
  1984. if (new_feeder_id == 0) {
  1985. // First time through, save the id of the operand the element comes
  1986. // from.
  1987. new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(0);
  1988. } else if (new_feeder_id !=
  1989. feeding_shuffle_inst->GetSingleWordInOperand(0)) {
  1990. // We need both elements of the feeding_shuffle_inst, so we cannot
  1991. // fold.
  1992. return false;
  1993. }
  1994. } else {
  1995. if (new_feeder_id == 0) {
  1996. // First time through, save the id of the operand the element comes
  1997. // from.
  1998. new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(1);
  1999. } else if (new_feeder_id !=
  2000. feeding_shuffle_inst->GetSingleWordInOperand(1)) {
  2001. // We need both elements of the feeding_shuffle_inst, so we cannot
  2002. // fold.
  2003. return false;
  2004. }
  2005. component_index -= feeder_op0_length;
  2006. }
  2007. if (!feeder_is_op0) {
  2008. component_index += op0_length;
  2009. }
  2010. }
  2011. new_operands.push_back(
  2012. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_index}});
  2013. }
  2014. if (new_feeder_id == 0) {
  2015. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  2016. const analysis::Type* type =
  2017. type_mgr->GetType(feeding_shuffle_inst->type_id());
  2018. const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
  2019. new_feeder_id =
  2020. const_mgr->GetDefiningInstruction(null_const, 0)->result_id();
  2021. }
  2022. if (feeder_is_op0) {
  2023. // If the size of the first vector operand changed then the indices
  2024. // referring to the second operand need to be adjusted.
  2025. Instruction* new_feeder_inst = def_use_mgr->GetDef(new_feeder_id);
  2026. analysis::Type* new_feeder_type =
  2027. type_mgr->GetType(new_feeder_inst->type_id());
  2028. uint32_t new_op0_size = new_feeder_type->AsVector()->element_count();
  2029. int32_t adjustment = op0_length - new_op0_size;
  2030. if (adjustment != 0) {
  2031. for (uint32_t i = 2; i < new_operands.size(); i++) {
  2032. if (inst->GetSingleWordInOperand(i) >= op0_length) {
  2033. new_operands[i].words[0] -= adjustment;
  2034. }
  2035. }
  2036. }
  2037. new_operands[0].words[0] = new_feeder_id;
  2038. new_operands[1] = inst->GetInOperand(1);
  2039. } else {
  2040. new_operands[1].words[0] = new_feeder_id;
  2041. new_operands[0] = inst->GetInOperand(0);
  2042. }
  2043. inst->SetInOperands(std::move(new_operands));
  2044. return true;
  2045. };
  2046. }
  2047. // Removes duplicate ids from the interface list of an OpEntryPoint
  2048. // instruction.
  2049. FoldingRule RemoveRedundantOperands() {
  2050. return [](IRContext*, Instruction* inst,
  2051. const std::vector<const analysis::Constant*>&) {
  2052. assert(inst->opcode() == SpvOpEntryPoint &&
  2053. "Wrong opcode. Should be OpEntryPoint.");
  2054. bool has_redundant_operand = false;
  2055. std::unordered_set<uint32_t> seen_operands;
  2056. std::vector<Operand> new_operands;
  2057. new_operands.emplace_back(inst->GetOperand(0));
  2058. new_operands.emplace_back(inst->GetOperand(1));
  2059. new_operands.emplace_back(inst->GetOperand(2));
  2060. for (uint32_t i = 3; i < inst->NumOperands(); ++i) {
  2061. if (seen_operands.insert(inst->GetSingleWordOperand(i)).second) {
  2062. new_operands.emplace_back(inst->GetOperand(i));
  2063. } else {
  2064. has_redundant_operand = true;
  2065. }
  2066. }
  2067. if (!has_redundant_operand) {
  2068. return false;
  2069. }
  2070. inst->SetInOperands(std::move(new_operands));
  2071. return true;
  2072. };
  2073. }
  2074. // If an image instruction's operand is a constant, updates the image operand
  2075. // flag from Offset to ConstOffset.
  2076. FoldingRule UpdateImageOperands() {
  2077. return [](IRContext*, Instruction* inst,
  2078. const std::vector<const analysis::Constant*>& constants) {
  2079. const auto opcode = inst->opcode();
  2080. (void)opcode;
  2081. assert((opcode == SpvOpImageSampleImplicitLod ||
  2082. opcode == SpvOpImageSampleExplicitLod ||
  2083. opcode == SpvOpImageSampleDrefImplicitLod ||
  2084. opcode == SpvOpImageSampleDrefExplicitLod ||
  2085. opcode == SpvOpImageSampleProjImplicitLod ||
  2086. opcode == SpvOpImageSampleProjExplicitLod ||
  2087. opcode == SpvOpImageSampleProjDrefImplicitLod ||
  2088. opcode == SpvOpImageSampleProjDrefExplicitLod ||
  2089. opcode == SpvOpImageFetch || opcode == SpvOpImageGather ||
  2090. opcode == SpvOpImageDrefGather || opcode == SpvOpImageRead ||
  2091. opcode == SpvOpImageWrite ||
  2092. opcode == SpvOpImageSparseSampleImplicitLod ||
  2093. opcode == SpvOpImageSparseSampleExplicitLod ||
  2094. opcode == SpvOpImageSparseSampleDrefImplicitLod ||
  2095. opcode == SpvOpImageSparseSampleDrefExplicitLod ||
  2096. opcode == SpvOpImageSparseSampleProjImplicitLod ||
  2097. opcode == SpvOpImageSparseSampleProjExplicitLod ||
  2098. opcode == SpvOpImageSparseSampleProjDrefImplicitLod ||
  2099. opcode == SpvOpImageSparseSampleProjDrefExplicitLod ||
  2100. opcode == SpvOpImageSparseFetch ||
  2101. opcode == SpvOpImageSparseGather ||
  2102. opcode == SpvOpImageSparseDrefGather ||
  2103. opcode == SpvOpImageSparseRead) &&
  2104. "Wrong opcode. Should be an image instruction.");
  2105. int32_t operand_index = ImageOperandsMaskInOperandIndex(inst);
  2106. if (operand_index >= 0) {
  2107. auto image_operands = inst->GetSingleWordInOperand(operand_index);
  2108. if (image_operands & SpvImageOperandsOffsetMask) {
  2109. uint32_t offset_operand_index = operand_index + 1;
  2110. if (image_operands & SpvImageOperandsBiasMask) offset_operand_index++;
  2111. if (image_operands & SpvImageOperandsLodMask) offset_operand_index++;
  2112. if (image_operands & SpvImageOperandsGradMask)
  2113. offset_operand_index += 2;
  2114. assert(((image_operands & SpvImageOperandsConstOffsetMask) == 0) &&
  2115. "Offset and ConstOffset may not be used together");
  2116. if (offset_operand_index < inst->NumOperands()) {
  2117. if (constants[offset_operand_index]) {
  2118. image_operands = image_operands | SpvImageOperandsConstOffsetMask;
  2119. image_operands = image_operands & ~SpvImageOperandsOffsetMask;
  2120. inst->SetInOperand(operand_index, {image_operands});
  2121. return true;
  2122. }
  2123. }
  2124. }
  2125. }
  2126. return false;
  2127. };
  2128. }
  2129. } // namespace
  2130. void FoldingRules::AddFoldingRules() {
  2131. // Add all folding rules to the list for the opcodes to which they apply.
  2132. // Note that the order in which rules are added to the list matters. If a rule
  2133. // applies to the instruction, the rest of the rules will not be attempted.
  2134. // Take that into consideration.
  2135. rules_[SpvOpCompositeConstruct].push_back(CompositeExtractFeedingConstruct);
  2136. rules_[SpvOpCompositeExtract].push_back(InsertFeedingExtract());
  2137. rules_[SpvOpCompositeExtract].push_back(CompositeConstructFeedingExtract());
  2138. rules_[SpvOpCompositeExtract].push_back(VectorShuffleFeedingExtract());
  2139. rules_[SpvOpCompositeExtract].push_back(FMixFeedingExtract());
  2140. rules_[SpvOpDot].push_back(DotProductDoingExtract());
  2141. rules_[SpvOpEntryPoint].push_back(RemoveRedundantOperands());
  2142. rules_[SpvOpFAdd].push_back(RedundantFAdd());
  2143. rules_[SpvOpFAdd].push_back(MergeAddNegateArithmetic());
  2144. rules_[SpvOpFAdd].push_back(MergeAddAddArithmetic());
  2145. rules_[SpvOpFAdd].push_back(MergeAddSubArithmetic());
  2146. rules_[SpvOpFAdd].push_back(MergeGenericAddSubArithmetic());
  2147. rules_[SpvOpFAdd].push_back(FactorAddMuls());
  2148. rules_[SpvOpFDiv].push_back(RedundantFDiv());
  2149. rules_[SpvOpFDiv].push_back(ReciprocalFDiv());
  2150. rules_[SpvOpFDiv].push_back(MergeDivDivArithmetic());
  2151. rules_[SpvOpFDiv].push_back(MergeDivMulArithmetic());
  2152. rules_[SpvOpFDiv].push_back(MergeDivNegateArithmetic());
  2153. rules_[SpvOpFMul].push_back(RedundantFMul());
  2154. rules_[SpvOpFMul].push_back(MergeMulMulArithmetic());
  2155. rules_[SpvOpFMul].push_back(MergeMulDivArithmetic());
  2156. rules_[SpvOpFMul].push_back(MergeMulNegateArithmetic());
  2157. rules_[SpvOpFNegate].push_back(MergeNegateArithmetic());
  2158. rules_[SpvOpFNegate].push_back(MergeNegateAddSubArithmetic());
  2159. rules_[SpvOpFNegate].push_back(MergeNegateMulDivArithmetic());
  2160. rules_[SpvOpFSub].push_back(RedundantFSub());
  2161. rules_[SpvOpFSub].push_back(MergeSubNegateArithmetic());
  2162. rules_[SpvOpFSub].push_back(MergeSubAddArithmetic());
  2163. rules_[SpvOpFSub].push_back(MergeSubSubArithmetic());
  2164. rules_[SpvOpIAdd].push_back(RedundantIAdd());
  2165. rules_[SpvOpIAdd].push_back(MergeAddNegateArithmetic());
  2166. rules_[SpvOpIAdd].push_back(MergeAddAddArithmetic());
  2167. rules_[SpvOpIAdd].push_back(MergeAddSubArithmetic());
  2168. rules_[SpvOpIAdd].push_back(MergeGenericAddSubArithmetic());
  2169. rules_[SpvOpIAdd].push_back(FactorAddMuls());
  2170. rules_[SpvOpIMul].push_back(IntMultipleBy1());
  2171. rules_[SpvOpIMul].push_back(MergeMulMulArithmetic());
  2172. rules_[SpvOpIMul].push_back(MergeMulNegateArithmetic());
  2173. rules_[SpvOpISub].push_back(MergeSubNegateArithmetic());
  2174. rules_[SpvOpISub].push_back(MergeSubAddArithmetic());
  2175. rules_[SpvOpISub].push_back(MergeSubSubArithmetic());
  2176. rules_[SpvOpPhi].push_back(RedundantPhi());
  2177. rules_[SpvOpSDiv].push_back(MergeDivNegateArithmetic());
  2178. rules_[SpvOpSNegate].push_back(MergeNegateArithmetic());
  2179. rules_[SpvOpSNegate].push_back(MergeNegateMulDivArithmetic());
  2180. rules_[SpvOpSNegate].push_back(MergeNegateAddSubArithmetic());
  2181. rules_[SpvOpSelect].push_back(RedundantSelect());
  2182. rules_[SpvOpStore].push_back(StoringUndef());
  2183. rules_[SpvOpUDiv].push_back(MergeDivNegateArithmetic());
  2184. rules_[SpvOpVectorShuffle].push_back(VectorShuffleFeedingShuffle());
  2185. rules_[SpvOpImageSampleImplicitLod].push_back(UpdateImageOperands());
  2186. rules_[SpvOpImageSampleExplicitLod].push_back(UpdateImageOperands());
  2187. rules_[SpvOpImageSampleDrefImplicitLod].push_back(UpdateImageOperands());
  2188. rules_[SpvOpImageSampleDrefExplicitLod].push_back(UpdateImageOperands());
  2189. rules_[SpvOpImageSampleProjImplicitLod].push_back(UpdateImageOperands());
  2190. rules_[SpvOpImageSampleProjExplicitLod].push_back(UpdateImageOperands());
  2191. rules_[SpvOpImageSampleProjDrefImplicitLod].push_back(UpdateImageOperands());
  2192. rules_[SpvOpImageSampleProjDrefExplicitLod].push_back(UpdateImageOperands());
  2193. rules_[SpvOpImageFetch].push_back(UpdateImageOperands());
  2194. rules_[SpvOpImageGather].push_back(UpdateImageOperands());
  2195. rules_[SpvOpImageDrefGather].push_back(UpdateImageOperands());
  2196. rules_[SpvOpImageRead].push_back(UpdateImageOperands());
  2197. rules_[SpvOpImageWrite].push_back(UpdateImageOperands());
  2198. rules_[SpvOpImageSparseSampleImplicitLod].push_back(UpdateImageOperands());
  2199. rules_[SpvOpImageSparseSampleExplicitLod].push_back(UpdateImageOperands());
  2200. rules_[SpvOpImageSparseSampleDrefImplicitLod].push_back(
  2201. UpdateImageOperands());
  2202. rules_[SpvOpImageSparseSampleDrefExplicitLod].push_back(
  2203. UpdateImageOperands());
  2204. rules_[SpvOpImageSparseSampleProjImplicitLod].push_back(
  2205. UpdateImageOperands());
  2206. rules_[SpvOpImageSparseSampleProjExplicitLod].push_back(
  2207. UpdateImageOperands());
  2208. rules_[SpvOpImageSparseSampleProjDrefImplicitLod].push_back(
  2209. UpdateImageOperands());
  2210. rules_[SpvOpImageSparseSampleProjDrefExplicitLod].push_back(
  2211. UpdateImageOperands());
  2212. rules_[SpvOpImageSparseFetch].push_back(UpdateImageOperands());
  2213. rules_[SpvOpImageSparseGather].push_back(UpdateImageOperands());
  2214. rules_[SpvOpImageSparseDrefGather].push_back(UpdateImageOperands());
  2215. rules_[SpvOpImageSparseRead].push_back(UpdateImageOperands());
  2216. FeatureManager* feature_manager = context_->get_feature_mgr();
  2217. // Add rules for GLSLstd450
  2218. uint32_t ext_inst_glslstd450_id =
  2219. feature_manager->GetExtInstImportId_GLSLstd450();
  2220. if (ext_inst_glslstd450_id != 0) {
  2221. ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(
  2222. RedundantFMix());
  2223. }
  2224. }
  2225. } // namespace opt
  2226. } // namespace spvtools