//===- ScalarReplAggregatesHLSL.cpp - Scalar Replacement of Aggregates ----===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
//
// Based on ScalarReplAggregates.cpp. The difference is that the HLSL version
// keeps arrays so it can break up all structures.
//
//===----------------------------------------------------------------------===//
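
// For illustration of the note above: "keep arrays, break up structures"
// means an aggregate such as [2 x { i32, float }] is scalar-replaced into one
// alloca per field while preserving the array dimension, e.g.
//   %a   = alloca [2 x { i32, float }]
// becomes
//   %a.0 = alloca [2 x i32]
//   %a.1 = alloca [2 x float]
// rather than into individual scalar allocas. (The names here are
// illustrative only.)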

#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/IR/CallSite.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DIBuilder.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/PromoteMemToReg.h"
#include "llvm/Transforms/Utils/SSAUpdater.h"
#include "dxc/HLSL/HLOperations.h"
#include "dxc/DXIL/DxilConstants.h"
#include "dxc/HLSL/HLModule.h"
#include "dxc/DXIL/DxilUtil.h"
#include "dxc/DXIL/DxilModule.h"
#include "dxc/HlslIntrinsicOp.h"
#include "dxc/DXIL/DxilTypeSystem.h"
#include "dxc/HLSL/HLMatrixLowerHelper.h"
#include "dxc/HLSL/HLMatrixType.h"
#include "dxc/DXIL/DxilOperations.h"
#include "dxc/HLSL/HLLowerUDT.h"
#include "dxc/HLSL/HLUtil.h"
#include <deque>
#include <unordered_map>
#include <unordered_set>
#include <queue>

using namespace llvm;
using namespace hlsl;

#define DEBUG_TYPE "scalarreplhlsl"

STATISTIC(NumReplaced, "Number of allocas broken up");

namespace {

class SROA_Helper {
public:
  // Split V into AllocaInsts with Builder and save the new AllocaInsts into
  // Elts. Then do SROA on V.
  static bool DoScalarReplacement(Value *V, std::vector<Value *> &Elts,
                                  Type *&BrokenUpTy, uint64_t &NumInstances,
                                  IRBuilder<> &Builder, bool bFlatVector,
                                  bool hasPrecise, DxilTypeSystem &typeSys,
                                  const DataLayout &DL,
                                  SmallVector<Value *, 32> &DeadInsts,
                                  DominatorTree *DT);

  static bool DoScalarReplacement(GlobalVariable *GV,
                                  std::vector<Value *> &Elts,
                                  IRBuilder<> &Builder, bool bFlatVector,
                                  bool hasPrecise, DxilTypeSystem &typeSys,
                                  const DataLayout &DL,
                                  SmallVector<Value *, 32> &DeadInsts,
                                  DominatorTree *DT);
  static unsigned GetEltAlign(unsigned ValueAlign, const DataLayout &DL,
                              Type *EltTy, unsigned Offset);
  // Lower memcpy related to V.
  static bool LowerMemcpy(Value *V, DxilFieldAnnotation *annotation,
                          DxilTypeSystem &typeSys, const DataLayout &DL,
                          DominatorTree *DT, bool bAllowReplace);
  static void MarkEmptyStructUsers(Value *V,
                                   SmallVector<Value *, 32> &DeadInsts);
  static bool IsEmptyStructType(Type *Ty, DxilTypeSystem &typeSys);

private:
  SROA_Helper(Value *V, ArrayRef<Value *> Elts,
              SmallVector<Value *, 32> &DeadInsts, DxilTypeSystem &ts,
              const DataLayout &dl, DominatorTree *dt)
      : OldVal(V), NewElts(Elts), DeadInsts(DeadInsts), typeSys(ts), DL(dl),
        DT(dt) {}
  void RewriteForScalarRepl(Value *V, IRBuilder<> &Builder);

private:
  // Must be a pointer type val.
  Value *OldVal;
  // Flattened elements for OldVal.
  ArrayRef<Value *> NewElts;
  SmallVector<Value *, 32> &DeadInsts;

  DxilTypeSystem &typeSys;
  const DataLayout &DL;
  DominatorTree *DT;

  void RewriteForConstExpr(ConstantExpr *user, IRBuilder<> &Builder);
  void RewriteForGEP(GEPOperator *GEP, IRBuilder<> &Builder);
  void RewriteForAddrSpaceCast(Value *user, IRBuilder<> &Builder);
  void RewriteForLoad(LoadInst *loadInst);
  void RewriteForStore(StoreInst *storeInst);
  void RewriteMemIntrin(MemIntrinsic *MI, Value *OldV);
  void RewriteCall(CallInst *CI);
  void RewriteBitCast(BitCastInst *BCI);
  void RewriteCallArg(CallInst *CI, unsigned ArgIdx, bool bIn, bool bOut);
};

} // namespace

static unsigned getNestedLevelInStruct(const Type *ty) {
  unsigned lvl = 0;
  while (ty->isStructTy()) {
    if (ty->getStructNumElements() != 1)
      break;
    ty = ty->getStructElementType(0);
    lvl++;
  }
  return lvl;
}
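
// For illustration: getNestedLevelInStruct counts single-element struct
// wrappers, e.g. it returns 2 for %outer = type { %inner } with
// %inner = type { float }, and 0 for any non-struct type or for a struct
// whose outermost level has more than one member.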

// After SROA'ing a given value into a series of elements,
// creates the debug info for the storage of the individual elements.
static void addDebugInfoForElements(Value *ParentVal, Type *BrokenUpTy,
                                    uint64_t NumInstances,
                                    ArrayRef<Value *> Elems,
                                    const DataLayout &DatLayout,
                                    DIBuilder *DbgBuilder) {
  // Extract the data we need from the parent value,
  // depending on whether it is an alloca, argument or global variable.
  Type *ParentTy;
  unsigned ParentBitPieceOffset = 0;
  std::vector<DxilDIArrayDim> DIArrayDims;
  DIVariable *ParentDbgVariable;
  DIExpression *ParentDbgExpr;
  DILocation *ParentDbgLocation;
  Instruction *DbgDeclareInsertPt = nullptr;

  if (isa<GlobalVariable>(ParentVal)) {
    llvm_unreachable(
        "Not implemented: sroa debug info propagation for global vars.");
  } else {
    if (AllocaInst *ParentAlloca = dyn_cast<AllocaInst>(ParentVal))
      ParentTy = ParentAlloca->getAllocatedType();
    else
      ParentTy = cast<Argument>(ParentVal)->getType();

    DbgDeclareInst *ParentDbgDeclare = llvm::FindAllocaDbgDeclare(ParentVal);
    if (ParentDbgDeclare == nullptr)
      return;

    // Get the bit piece offset
    if ((ParentDbgExpr = ParentDbgDeclare->getExpression())) {
      if (ParentDbgExpr->isBitPiece()) {
        ParentBitPieceOffset = ParentDbgExpr->getBitPieceOffset();
      }
    }

    ParentDbgVariable = ParentDbgDeclare->getVariable();
    ParentDbgLocation = ParentDbgDeclare->getDebugLoc();
    DbgDeclareInsertPt = ParentDbgDeclare;

    // Read the extra layout metadata, if any
    unsigned ParentBitPieceOffsetFromMD = 0;
    if (DxilMDHelper::GetVariableDebugLayout(
            ParentDbgDeclare, ParentBitPieceOffsetFromMD, DIArrayDims)) {
      // The offset is redundant for local variables and only necessary for
      // global variables.
      DXASSERT(ParentBitPieceOffsetFromMD == ParentBitPieceOffset,
               "Bit piece offset mismatch between llvm.dbg.declare and DXIL "
               "metadata.");
    }
  }

  // If the type that was broken up is nested in arrays,
  // then each element will also be an array,
  // but the continuity between successive elements of the original aggregate
  // will have been broken, such that we must store the stride to rebuild it.
  // For example: [2 x {i32, float}] => [2 x i32], [2 x float],
  // each with stride 64 bits.
  if (NumInstances > 1 && Elems.size() > 1) {
    // Existing dimensions already account for part of the stride
    uint64_t NewDimNumElements = NumInstances;
    for (const DxilDIArrayDim &ArrayDim : DIArrayDims) {
      DXASSERT(NewDimNumElements % ArrayDim.NumElements == 0,
               "Debug array stride is inconsistent with the number of "
               "elements.");
      NewDimNumElements /= ArrayDim.NumElements;
    }

    // Add a stride dimension
    DxilDIArrayDim NewDIArrayDim = {};
    NewDIArrayDim.StrideInBits =
        (unsigned)DatLayout.getTypeAllocSizeInBits(BrokenUpTy);
    NewDIArrayDim.NumElements = (unsigned)NewDimNumElements;
    DIArrayDims.emplace_back(NewDIArrayDim);
  } else {
    DIArrayDims.clear();
  }

  // Create the debug info for each element
  for (unsigned ElemIdx = 0; ElemIdx < Elems.size(); ++ElemIdx) {
    // Figure out the offset of the element in the broken up type
    unsigned ElemBitPieceOffset = ParentBitPieceOffset;
    if (StructType *ParentStructTy = dyn_cast<StructType>(BrokenUpTy)) {
      DXASSERT_NOMSG(Elems.size() == ParentStructTy->getNumElements());
      ElemBitPieceOffset +=
          (unsigned)DatLayout.getStructLayout(ParentStructTy)
              ->getElementOffsetInBits(ElemIdx);
    } else if (VectorType *ParentVecTy = dyn_cast<VectorType>(BrokenUpTy)) {
      DXASSERT_NOMSG(Elems.size() == ParentVecTy->getNumElements());
      ElemBitPieceOffset +=
          (unsigned)DatLayout.getTypeStoreSizeInBits(
              ParentVecTy->getElementType()) * ElemIdx;
    } else if (ArrayType *ParentArrayTy = dyn_cast<ArrayType>(BrokenUpTy)) {
      DXASSERT_NOMSG(Elems.size() == ParentArrayTy->getNumElements());
      ElemBitPieceOffset +=
          (unsigned)DatLayout.getTypeStoreSizeInBits(
              ParentArrayTy->getElementType()) * ElemIdx;
    }

    // The bit_piece can only represent the leading contiguous bytes.
    // If strides are involved, we'll need additional metadata.
    Type *ElemTy = Elems[ElemIdx]->getType()->getPointerElementType();
    unsigned ElemBitPieceSize =
        (unsigned)DatLayout.getTypeStoreSizeInBits(ElemTy);
    for (const DxilDIArrayDim &ArrayDim : DIArrayDims)
      ElemBitPieceSize /= ArrayDim.NumElements;

    if (AllocaInst *ElemAlloca = dyn_cast<AllocaInst>(Elems[ElemIdx])) {
      // Local variables get an @llvm.dbg.declare plus optional metadata for
      // layout stride information.
      DIExpression *ElemDbgExpr = nullptr;
      if (ElemBitPieceOffset == 0 &&
          DatLayout.getTypeAllocSizeInBits(ParentTy) == ElemBitPieceSize) {
        ElemDbgExpr = DbgBuilder->createExpression();
      } else {
        ElemDbgExpr = DbgBuilder->createBitPieceExpression(ElemBitPieceOffset,
                                                           ElemBitPieceSize);
      }
      DXASSERT_NOMSG(DbgBuilder != nullptr);
      DbgDeclareInst *EltDDI = cast<DbgDeclareInst>(DbgBuilder->insertDeclare(
          ElemAlloca, cast<DILocalVariable>(ParentDbgVariable), ElemDbgExpr,
          ParentDbgLocation, DbgDeclareInsertPt));
      if (!DIArrayDims.empty())
        DxilMDHelper::SetVariableDebugLayout(EltDDI, ElemBitPieceOffset,
                                             DIArrayDims);
    } else {
      llvm_unreachable("Non-AllocaInst SROA'd elements.");
    }
  }
}
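
// Worked example for addDebugInfoForElements (illustrative): breaking up an
// alloca of [2 x { i32, float }] yields Elems = { [2 x i32], [2 x float] }
// with BrokenUpTy = { i32, float } and NumInstances = 2. The stride dimension
// added above is 64 bits (the alloc size of { i32, float }); the i32 element
// gets bit piece offset 0 and the float element offset 32, and each element's
// bit piece size is 32 bits (its 64-bit store size divided by the 2-element
// dimension).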

/// Returns first GEP index that indexes a struct member, or 0 otherwise.
/// Ignores initial ptr index.
static unsigned FindFirstStructMemberIdxInGEP(GEPOperator *GEP) {
  StructType *ST = dyn_cast<StructType>(
      GEP->getPointerOperandType()->getPointerElementType());
  int index = 1;
  for (auto it = gep_type_begin(GEP), E = gep_type_end(GEP); it != E;
       ++it, ++index) {
    if (ST) {
      DXASSERT(!HLMatrixType::isa(ST) && !dxilutil::IsHLSLObjectType(ST),
               "otherwise, indexing into hlsl object");
      return index;
    }
    ST = dyn_cast<StructType>(it->getPointerElementType());
  }
  return 0;
}

/// Return true when ptr should not be SROA'd or copied, but used directly
/// by a function in its lowered form. Also collect uses for translation.
/// What is meant by directly here:
/// Possibly accessed through GEP array index or address space cast, but
/// not under another struct member (always allow SROA of outer struct).
typedef SmallMapVector<CallInst *, unsigned, 4> FunctionUseMap;
static unsigned IsPtrUsedByLoweredFn(Value *V,
                                     FunctionUseMap &CollectedUses) {
  bool bFound = false;
  for (Use &U : V->uses()) {
    User *user = U.getUser();

    if (CallInst *CI = dyn_cast<CallInst>(user)) {
      unsigned foundIdx = (unsigned)-1;
      Function *F = CI->getCalledFunction();
      Type *Ty = V->getType();
      if (F->isDeclaration() && !F->isIntrinsic() && Ty->isPointerTy()) {
        HLOpcodeGroup group = hlsl::GetHLOpcodeGroupByName(F);
        if (group == HLOpcodeGroup::HLIntrinsic) {
          unsigned opIdx = U.getOperandNo();
          switch ((IntrinsicOp)hlsl::GetHLOpcode(CI)) {
          // TODO: Lower these as well, along with function parameter types
          // case IntrinsicOp::IOP_TraceRay:
          //   if (opIdx != HLOperandIndex::kTraceRayPayLoadOpIdx)
          //     continue;
          //   break;
          // case IntrinsicOp::IOP_ReportHit:
          //   if (opIdx != HLOperandIndex::kReportIntersectionAttributeOpIdx)
          //     continue;
          //   break;
          // case IntrinsicOp::IOP_CallShader:
          //   if (opIdx != HLOperandIndex::kCallShaderPayloadOpIdx)
          //     continue;
          //   break;
          case IntrinsicOp::IOP_DispatchMesh:
            if (opIdx != HLOperandIndex::kDispatchMeshOpPayload)
              continue;
            break;
          default:
            continue;
          }
          foundIdx = opIdx;

          // TODO: Lower these as well, along with function parameter types
          //} else if (group == HLOpcodeGroup::NotHL) {
          //  foundIdx = U.getOperandNo();
        }
      }
      if (foundIdx != (unsigned)-1) {
        bFound = true;
        auto insRes = CollectedUses.insert(std::make_pair(CI, foundIdx));
        DXASSERT_LOCALVAR(insRes, insRes.second,
                          "otherwise, multiple uses in single call");
      }
    } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(user)) {
      // Not what we are looking for if GEP result is not [array of] struct.
      // If use is under struct member, we can still SROA the outer struct.
      if (!dxilutil::StripArrayTypes(GEP->getType()->getPointerElementType())
               ->isStructTy() ||
          FindFirstStructMemberIdxInGEP(cast<GEPOperator>(GEP)))
        continue;
      if (IsPtrUsedByLoweredFn(user, CollectedUses))
        bFound = true;
    } else if (AddrSpaceCastInst *AC = dyn_cast<AddrSpaceCastInst>(user)) {
      if (IsPtrUsedByLoweredFn(user, CollectedUses))
        bFound = true;
    } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(user)) {
      unsigned opcode = CE->getOpcode();
      if (opcode == Instruction::AddrSpaceCast ||
          opcode == Instruction::GetElementPtr)
        if (IsPtrUsedByLoweredFn(user, CollectedUses))
          bFound = true;
    }
  }
  return bFound;
}

/// Rewrite call to natively use an argument with addrspace cast/bitcast
static CallInst *RewriteIntrinsicCallForCastedArg(CallInst *CI,
                                                  unsigned argIdx) {
  Function *F = CI->getCalledFunction();
  HLOpcodeGroup group = GetHLOpcodeGroupByName(F);
  DXASSERT_NOMSG(group == HLOpcodeGroup::HLIntrinsic);
  unsigned opcode = GetHLOpcode(CI);
  SmallVector<Type *, 8> newArgTypes(CI->getFunctionType()->param_begin(),
                                     CI->getFunctionType()->param_end());
  SmallVector<Value *, 8> newArgs(CI->arg_operands());

  Value *newArg = CI->getOperand(argIdx)->stripPointerCasts();
  newArgTypes[argIdx] = newArg->getType();
  newArgs[argIdx] = newArg;

  FunctionType *newFuncTy =
      FunctionType::get(CI->getType(), newArgTypes, false);
  Function *newF =
      GetOrCreateHLFunction(*F->getParent(), newFuncTy, group, opcode,
                            F->getAttributes().getFnAttributes());

  IRBuilder<> Builder(CI);
  return Builder.CreateCall(newF, newArgs);
}
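
// For illustration: when the payload pointer reaching an HL intrinsic such as
// DispatchMesh is wrapped in an addrspacecast or bitcast, the helper above
// creates a new overload of the HL function whose parameter type is the
// stripped pointer type, and re-emits the call with the original, uncast
// pointer in that operand slot.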

/// Translate pointer for cases where intrinsics use UDT pointers directly.
/// Returns the existing or a new pointer if it needs preserving,
/// otherwise nullptr to proceed with existing checks and SROA.
static Value *TranslatePtrIfUsedByLoweredFn(Value *Ptr,
                                            DxilTypeSystem &TypeSys) {
  if (!Ptr->getType()->isPointerTy())
    return nullptr;
  Type *Ty = Ptr->getType()->getPointerElementType();
  SmallVector<unsigned, 4> outerToInnerLengths;
  Ty = dxilutil::StripArrayTypes(Ty, &outerToInnerLengths);
  if (!Ty->isStructTy())
    return nullptr;
  if (HLMatrixType::isa(Ty) || dxilutil::IsHLSLObjectType(Ty))
    return nullptr;
  unsigned AddrSpace = Ptr->getType()->getPointerAddressSpace();
  FunctionUseMap FunctionUses;
  if (!IsPtrUsedByLoweredFn(Ptr, FunctionUses))
    return nullptr;

  // Translate vectors to arrays in type, but don't SROA
  Type *NewTy = GetLoweredUDT(cast<StructType>(Ty), &TypeSys);

  // No work to do here, but prevent SROA.
  if (Ty == NewTy && AddrSpace != DXIL::kTGSMAddrSpace)
    return Ptr;

  // If type changed, replace value, otherwise casting may still
  // require a rewrite of the calls.
  Value *NewPtr = Ptr;
  if (Ty != NewTy) {
    NewTy = dxilutil::WrapInArrayTypes(NewTy, outerToInnerLengths);
    if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Ptr)) {
      Module &M = *GV->getParent();
      // Rewrite init expression for arrays instead of vectors
      Constant *Init = GV->hasInitializer() ? GV->getInitializer()
                                            : UndefValue::get(Ptr->getType());
      Constant *NewInit = TranslateInitForLoweredUDT(Init, NewTy, &TypeSys);
      // Replace with new GV, and rewrite vector load/store users
      GlobalVariable *NewGV = new GlobalVariable(
          M, NewTy, GV->isConstant(), GV->getLinkage(), NewInit, GV->getName(),
          /*InsertBefore*/ GV, GV->getThreadLocalMode(), AddrSpace);
      NewPtr = NewGV;
    } else if (AllocaInst *AI = dyn_cast<AllocaInst>(Ptr)) {
      IRBuilder<> Builder(AI);
      AllocaInst *NewAI = Builder.CreateAlloca(NewTy, nullptr, AI->getName());
      NewPtr = NewAI;
    } else {
      DXASSERT(false, "Ptr must be global or alloca");
    }
    // This will rewrite vector load/store users
    // and insert bitcasts for CallInst users
    ReplaceUsesForLoweredUDT(Ptr, NewPtr);
  }

  // Rewrite the HLIntrinsic calls
  for (auto it : FunctionUses) {
    CallInst *CI = it.first;
    HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
    if (group == HLOpcodeGroup::NotHL)
      continue;
    CallInst *newCI = RewriteIntrinsicCallForCastedArg(CI, it.second);
    CI->replaceAllUsesWith(newCI);
    CI->eraseFromParent();
  }
  return NewPtr;
}
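
// For illustration (interpreting the checks above): a groupshared payload
// struct passed directly to DispatchMesh is preserved rather than SROA'd;
// if GetLoweredUDT changes the type (e.g. vector members become array
// members), a replacement global or alloca of the lowered type is created,
// existing users are rewritten, and the DispatchMesh call itself is recreated
// against the new pointer via RewriteIntrinsicCallForCastedArg.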

/// isHomogeneousAggregate - Check if type T is a struct or array containing
/// elements of the same type (which is always true for arrays). If so,
/// return true with NumElts and EltTy set to the number of elements and the
/// element type, respectively.
static bool isHomogeneousAggregate(Type *T, unsigned &NumElts, Type *&EltTy) {
  if (ArrayType *AT = dyn_cast<ArrayType>(T)) {
    NumElts = AT->getNumElements();
    EltTy = (NumElts == 0 ? nullptr : AT->getElementType());
    return true;
  }
  if (StructType *ST = dyn_cast<StructType>(T)) {
    NumElts = ST->getNumContainedTypes();
    EltTy = (NumElts == 0 ? nullptr : ST->getContainedType(0));
    for (unsigned n = 1; n < NumElts; ++n) {
      if (ST->getContainedType(n) != EltTy)
        return false;
    }
    return true;
  }
  return false;
}

/// isCompatibleAggregate - Check if T1 and T2 are either the same type or are
/// "homogeneous" aggregates with the same element type and number of elements.
static bool isCompatibleAggregate(Type *T1, Type *T2) {
  if (T1 == T2)
    return true;
  unsigned NumElts1, NumElts2;
  Type *EltTy1, *EltTy2;
  if (isHomogeneousAggregate(T1, NumElts1, EltTy1) &&
      isHomogeneousAggregate(T2, NumElts2, EltTy2) && NumElts1 == NumElts2 &&
      EltTy1 == EltTy2)
    return true;
  return false;
}
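
// For illustration: isCompatibleAggregate treats [4 x float] and
// { float, float, float, float } as compatible (both homogeneous with four
// float elements), while { float, i32 } is not homogeneous and is only
// compatible with itself.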

/// LoadVectorOrStructArray - Load a vector array like [2 x <4 x float>] from
/// four scalar arrays like [2 x float], or load a struct array like
/// [2 x { <4 x float>, <4 x uint> }] from vector arrays like
/// [2 x <4 x float>], [2 x <4 x uint>].
static Value *LoadVectorOrStructArray(ArrayType *AT, ArrayRef<Value *> NewElts,
                                      SmallVector<Value *, 8> &idxList,
                                      IRBuilder<> &Builder) {
  Type *EltTy = AT->getElementType();
  Value *retVal = llvm::UndefValue::get(AT);
  Type *i32Ty = Type::getInt32Ty(EltTy->getContext());

  uint32_t arraySize = AT->getNumElements();
  for (uint32_t i = 0; i < arraySize; i++) {
    Constant *idx = ConstantInt::get(i32Ty, i);
    idxList.emplace_back(idx);

    if (ArrayType *EltAT = dyn_cast<ArrayType>(EltTy)) {
      Value *EltVal = LoadVectorOrStructArray(EltAT, NewElts, idxList, Builder);
      retVal = Builder.CreateInsertValue(retVal, EltVal, i);
    } else {
      assert((EltTy->isVectorTy() || EltTy->isStructTy()) &&
             "must be a vector or struct type");
      bool isVectorTy = EltTy->isVectorTy();
      Value *retVec = llvm::UndefValue::get(EltTy);

      if (isVectorTy) {
        for (uint32_t c = 0; c < EltTy->getVectorNumElements(); c++) {
          Value *GEP = Builder.CreateInBoundsGEP(NewElts[c], idxList);
          Value *elt = Builder.CreateLoad(GEP);
          retVec = Builder.CreateInsertElement(retVec, elt, c);
        }
      } else {
        for (uint32_t c = 0; c < EltTy->getStructNumElements(); c++) {
          Value *GEP = Builder.CreateInBoundsGEP(NewElts[c], idxList);
          Value *elt = Builder.CreateLoad(GEP);
          retVec = Builder.CreateInsertValue(retVec, elt, c);
        }
      }

      retVal = Builder.CreateInsertValue(retVal, retVec, i);
    }
    idxList.pop_back();
  }
  return retVal;
}

/// StoreVectorOrStructArray - Store a vector array like [2 x <4 x float>] to
/// four scalar arrays like [2 x float], or store a struct array like
/// [2 x { <4 x float>, <4 x uint> }] to vector arrays like
/// [2 x <4 x float>], [2 x <4 x uint>].
static void StoreVectorOrStructArray(ArrayType *AT, Value *val,
                                     ArrayRef<Value *> NewElts,
                                     SmallVector<Value *, 8> &idxList,
                                     IRBuilder<> &Builder) {
  Type *EltTy = AT->getElementType();
  Type *i32Ty = Type::getInt32Ty(EltTy->getContext());

  uint32_t arraySize = AT->getNumElements();
  for (uint32_t i = 0; i < arraySize; i++) {
    Value *elt = Builder.CreateExtractValue(val, i);

    Constant *idx = ConstantInt::get(i32Ty, i);
    idxList.emplace_back(idx);

    if (ArrayType *EltAT = dyn_cast<ArrayType>(EltTy)) {
      StoreVectorOrStructArray(EltAT, elt, NewElts, idxList, Builder);
    } else {
      assert((EltTy->isVectorTy() || EltTy->isStructTy()) &&
             "must be a vector or struct type");
      bool isVectorTy = EltTy->isVectorTy();

      if (isVectorTy) {
        for (uint32_t c = 0; c < EltTy->getVectorNumElements(); c++) {
          Value *component = Builder.CreateExtractElement(elt, c);
          Value *GEP = Builder.CreateInBoundsGEP(NewElts[c], idxList);
          Builder.CreateStore(component, GEP);
        }
      } else {
        for (uint32_t c = 0; c < EltTy->getStructNumElements(); c++) {
          Value *field = Builder.CreateExtractValue(elt, c);
          Value *GEP = Builder.CreateInBoundsGEP(NewElts[c], idxList);
          Builder.CreateStore(field, GEP);
        }
      }
    }
    idxList.pop_back();
  }
}
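
// For illustration of the two helpers above: with AT = [2 x <2 x float>] and
// NewElts = { [2 x float]* %x, [2 x float]* %y } (one array per vector
// component), and assuming the caller seeded idxList with the usual leading
// zero index, element [i] of the loaded value takes component c from a load
// of gep(NewElts[c], 0, i); the store helper writes each extracted component
// back through the same per-component GEPs.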

namespace {

// Simple struct to split memcpy into ld/st
struct MemcpySplitter {
  llvm::LLVMContext &m_context;
  DxilTypeSystem &m_typeSys;

public:
  MemcpySplitter(llvm::LLVMContext &context, DxilTypeSystem &typeSys)
      : m_context(context), m_typeSys(typeSys) {}
  void Split(llvm::Function &F);

  static void PatchMemCpyWithZeroIdxGEP(Module &M);
  static void PatchMemCpyWithZeroIdxGEP(MemCpyInst *MI, const DataLayout &DL);
  static void SplitMemCpy(MemCpyInst *MI, const DataLayout &DL,
                          DxilFieldAnnotation *fieldAnnotation,
                          DxilTypeSystem &typeSys,
                          const bool bEltMemCpy = true);
};

// Copy data from SrcPtr to DestPtr.
void SimplePtrCopy(Value *DestPtr, Value *SrcPtr,
                   llvm::SmallVector<llvm::Value *, 16> &idxList,
                   IRBuilder<> &Builder) {
  if (idxList.size() > 1) {
    DestPtr = Builder.CreateInBoundsGEP(DestPtr, idxList);
    SrcPtr = Builder.CreateInBoundsGEP(SrcPtr, idxList);
  }
  llvm::LoadInst *ld = Builder.CreateLoad(SrcPtr);
  Builder.CreateStore(ld, DestPtr);
}

// Copy SrcVal to DestPtr.
void SimpleValCopy(Value *DestPtr, Value *SrcVal,
                   llvm::SmallVector<llvm::Value *, 16> &idxList,
                   IRBuilder<> &Builder) {
  Value *DestGEP = Builder.CreateInBoundsGEP(DestPtr, idxList);
  Value *Val = SrcVal;
  // Skip beginning pointer type.
  for (unsigned i = 1; i < idxList.size(); i++) {
    ConstantInt *idx = cast<ConstantInt>(idxList[i]);
    Type *Ty = Val->getType();
    if (Ty->isAggregateType()) {
      Val = Builder.CreateExtractValue(Val, idx->getLimitedValue());
    }
  }

  Builder.CreateStore(Val, DestGEP);
}

void SimpleCopy(Value *Dest, Value *Src,
                llvm::SmallVector<llvm::Value *, 16> &idxList,
                IRBuilder<> &Builder) {
  if (Src->getType()->isPointerTy())
    SimplePtrCopy(Dest, Src, idxList, Builder);
  else
    SimpleValCopy(Dest, Src, idxList, Builder);
}

Value *CreateMergedGEP(Value *Ptr, SmallVector<Value *, 16> &idxList,
                       IRBuilder<> &Builder) {
  if (GEPOperator *GEPPtr = dyn_cast<GEPOperator>(Ptr)) {
    SmallVector<Value *, 2> IdxList(GEPPtr->idx_begin(), GEPPtr->idx_end());
    // Skip idxList.begin() because it is already included in the GEPPtr
    // indices.
    IdxList.append(idxList.begin() + 1, idxList.end());
    return Builder.CreateInBoundsGEP(GEPPtr->getPointerOperand(), IdxList);
  } else {
    return Builder.CreateInBoundsGEP(Ptr, idxList);
  }
}

void EltMemCpy(Type *Ty, Value *Dest, Value *Src,
               SmallVector<Value *, 16> &idxList, IRBuilder<> &Builder,
               const DataLayout &DL) {
  Value *DestGEP = CreateMergedGEP(Dest, idxList, Builder);
  Value *SrcGEP = CreateMergedGEP(Src, idxList, Builder);
  unsigned size = DL.getTypeAllocSize(Ty);
  Builder.CreateMemCpy(DestGEP, SrcGEP, size, /* Align */ 1);
}

bool IsMemCpyTy(Type *Ty, DxilTypeSystem &typeSys) {
  if (!Ty->isAggregateType())
    return false;
  if (HLMatrixType::isa(Ty))
    return false;
  if (dxilutil::IsHLSLObjectType(Ty))
    return false;
  if (StructType *ST = dyn_cast<StructType>(Ty)) {
    DxilStructAnnotation *STA = typeSys.GetStructAnnotation(ST);
    DXASSERT(STA, "require annotation here");
    if (STA->IsEmptyStruct())
      return false;
    // Skip single-element structs whose element is a basic type, because
    // creating a memcpy would create a GEP on the struct; memcpy the basic
    // type only.
    if (ST->getNumElements() == 1)
      return IsMemCpyTy(ST->getElementType(0), typeSys);
  }
  return true;
}

// Split copy into ld/st.
void SplitCpy(Type *Ty, Value *Dest, Value *Src,
              SmallVector<Value *, 16> &idxList, IRBuilder<> &Builder,
              const DataLayout &DL, DxilTypeSystem &typeSys,
              const DxilFieldAnnotation *fieldAnnotation,
              const bool bEltMemCpy = true) {
  if (PointerType *PT = dyn_cast<PointerType>(Ty)) {
    Constant *idx = Constant::getIntegerValue(
        IntegerType::get(Ty->getContext(), 32), APInt(32, 0));
    idxList.emplace_back(idx);

    SplitCpy(PT->getElementType(), Dest, Src, idxList, Builder, DL, typeSys,
             fieldAnnotation, bEltMemCpy);

    idxList.pop_back();
  } else if (HLMatrixType::isa(Ty)) {
    // If there is no fieldAnnotation, default to row major.
    // Since this is only a load immediately followed by a store, either
    // orientation is fine.
    bool bRowMajor = true;
    if (fieldAnnotation) {
      DXASSERT(fieldAnnotation->HasMatrixAnnotation(),
               "must have matrix annotation");
      bRowMajor = fieldAnnotation->GetMatrixAnnotation().Orientation ==
                  MatrixOrientation::RowMajor;
    }

    Module *M = Builder.GetInsertPoint()->getModule();
    Value *DestMatPtr;
    Value *SrcMatPtr;
    if (idxList.size() == 1 &&
        idxList[0] == ConstantInt::get(IntegerType::get(Ty->getContext(), 32),
                                       APInt(32, 0))) {
      // Avoid creating GEP(0)
      DestMatPtr = Dest;
      SrcMatPtr = Src;
    } else {
      DestMatPtr = Builder.CreateInBoundsGEP(Dest, idxList);
      SrcMatPtr = Builder.CreateInBoundsGEP(Src, idxList);
    }

    HLMatLoadStoreOpcode loadOp = bRowMajor ? HLMatLoadStoreOpcode::RowMatLoad
                                            : HLMatLoadStoreOpcode::ColMatLoad;
    HLMatLoadStoreOpcode storeOp = bRowMajor
                                       ? HLMatLoadStoreOpcode::RowMatStore
                                       : HLMatLoadStoreOpcode::ColMatStore;

    Value *Load = HLModule::EmitHLOperationCall(
        Builder, HLOpcodeGroup::HLMatLoadStore, static_cast<unsigned>(loadOp),
        Ty, {SrcMatPtr}, *M);
    HLModule::EmitHLOperationCall(Builder, HLOpcodeGroup::HLMatLoadStore,
                                  static_cast<unsigned>(storeOp), Ty,
                                  {DestMatPtr, Load}, *M);
  } else if (StructType *ST = dyn_cast<StructType>(Ty)) {
    if (dxilutil::IsHLSLObjectType(ST)) {
      // Avoid splitting HLSL objects.
      SimpleCopy(Dest, Src, idxList, Builder);
      return;
    }
    // Built-in structs have no type annotation
    DxilStructAnnotation *STA = typeSys.GetStructAnnotation(ST);
    if (STA && STA->IsEmptyStruct())
      return;
    for (uint32_t i = 0; i < ST->getNumElements(); i++) {
      llvm::Type *ET = ST->getElementType(i);
      Constant *idx = llvm::Constant::getIntegerValue(
          IntegerType::get(Ty->getContext(), 32), APInt(32, i));
      idxList.emplace_back(idx);
      if (bEltMemCpy && IsMemCpyTy(ET, typeSys)) {
        EltMemCpy(ET, Dest, Src, idxList, Builder, DL);
      } else {
        DxilFieldAnnotation *EltAnnotation =
            STA ? &STA->GetFieldAnnotation(i) : nullptr;
        SplitCpy(ET, Dest, Src, idxList, Builder, DL, typeSys, EltAnnotation,
                 bEltMemCpy);
      }
      idxList.pop_back();
    }
  } else if (ArrayType *AT = dyn_cast<ArrayType>(Ty)) {
    Type *ET = AT->getElementType();

    for (uint32_t i = 0; i < AT->getNumElements(); i++) {
      Constant *idx = Constant::getIntegerValue(
          IntegerType::get(Ty->getContext(), 32), APInt(32, i));
      idxList.emplace_back(idx);
      if (bEltMemCpy && IsMemCpyTy(ET, typeSys)) {
        EltMemCpy(ET, Dest, Src, idxList, Builder, DL);
      } else {
        SplitCpy(ET, Dest, Src, idxList, Builder, DL, typeSys, fieldAnnotation,
                 bEltMemCpy);
      }
      idxList.pop_back();
    }
  } else {
    SimpleCopy(Dest, Src, idxList, Builder);
  }
}
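
// For illustration: splitting a copy between two allocas of
//   { float, [2 x i32], <matrix type> }
// emits a plain load/store for the float member, an element memcpy (or
// per-element copies when bEltMemCpy is false) for the [2 x i32] member, and
// a RowMatLoad/RowMatStore pair (or the ColMat variants, per the field
// annotation) for the matrix member.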
  704. // Given a pointer to a value, produces a list of pointers to
  705. // all scalar elements of that value and their field annotations, at any nesting
  706. // level.
  707. void SplitPtr(
  708. Value *Ptr, // The root value pointer
  709. SmallVectorImpl<Value *> &IdxList, // GEP indices stack during recursion
  710. Type *Ty, // Type at the current GEP indirection level
  711. const DxilFieldAnnotation
  712. &Annotation, // Annotation at the current GEP indirection level
  713. SmallVectorImpl<Value *>
  714. &EltPtrList, // Accumulates pointers to each element found
  715. SmallVectorImpl<const DxilFieldAnnotation *>
  716. &EltAnnotationList, // Accumulates field annotations for each element
  717. // found
  718. DxilTypeSystem &TypeSys, IRBuilder<> &Builder) {
  719. if (PointerType *PT = dyn_cast<PointerType>(Ty)) {
  720. Constant *idx = Constant::getIntegerValue(
  721. IntegerType::get(Ty->getContext(), 32), APInt(32, 0));
  722. IdxList.emplace_back(idx);
  723. SplitPtr(Ptr, IdxList, PT->getElementType(), Annotation, EltPtrList,
  724. EltAnnotationList, TypeSys, Builder);
  725. IdxList.pop_back();
  726. return;
  727. }
  728. if (StructType *ST = dyn_cast<StructType>(Ty)) {
  729. if (!HLMatrixType::isa(Ty) && !dxilutil::IsHLSLObjectType(ST)) {
  730. const DxilStructAnnotation *SA = TypeSys.GetStructAnnotation(ST);
  731. for (uint32_t i = 0; i < ST->getNumElements(); i++) {
  732. llvm::Type *EltTy = ST->getElementType(i);
  733. Constant *idx = llvm::Constant::getIntegerValue(
  734. IntegerType::get(Ty->getContext(), 32), APInt(32, i));
  735. IdxList.emplace_back(idx);
  736. SplitPtr(Ptr, IdxList, EltTy, SA->GetFieldAnnotation(i), EltPtrList,
  737. EltAnnotationList, TypeSys, Builder);
  738. IdxList.pop_back();
  739. }
  740. return;
  741. }
  742. }
  743. if (ArrayType *AT = dyn_cast<ArrayType>(Ty)) {
  744. if (AT->getArrayNumElements() == 0) {
  745. // Skip cases like [0 x %struct], nothing to copy
  746. return;
  747. }
  748. Type *ElTy = AT->getElementType();
  749. SmallVector<ArrayType *, 4> nestArrayTys;
  750. nestArrayTys.emplace_back(AT);
// Support multiple levels of nested arrays.
  752. while (ElTy->isArrayTy()) {
  753. ArrayType *ElAT = cast<ArrayType>(ElTy);
  754. nestArrayTys.emplace_back(ElAT);
  755. ElTy = ElAT->getElementType();
  756. }
  757. if (ElTy->isStructTy() && !HLMatrixType::isa(ElTy)) {
DXASSERT(0, "Arrays of structs are not supported when splitting pointers.");
  759. return;
  760. }
  761. }
  762. // Return a pointer to the current element and its annotation
  763. Value *GEP = Builder.CreateInBoundsGEP(Ptr, IdxList);
  764. EltPtrList.emplace_back(GEP);
  765. EltAnnotationList.emplace_back(&Annotation);
  766. }
  767. // Support case when bitcast (gep ptr, 0,0) is transformed into bitcast ptr.
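// For example (assuming 4-byte floats): for a pointer whose element type is
// struct { struct { float x; float y; } s; } and a memcpy size of 4 bytes,
// this is expected to return 2, meaning two extra zero indices
// (gep ptr, 0, 0, 0) reach a float whose size matches; a return of 0 means no
// match was found.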
  768. unsigned MatchSizeByCheckElementType(Type *Ty, const DataLayout &DL,
  769. unsigned size, unsigned level) {
  770. unsigned ptrSize = DL.getTypeAllocSize(Ty);
// Sizes match: return the current level.
if (ptrSize == size) {
// Do not go deeper for matrices or objects.
if (HLMatrixType::isa(Ty) || dxilutil::IsHLSLObjectType(Ty))
return level;
// For a single-element struct, go deeper while the size stays the same.
// This leaves the memcpy for a deeper level when flattening.
  778. if (StructType *ST = dyn_cast<StructType>(Ty)) {
  779. if (ST->getNumElements() == 1) {
  780. return MatchSizeByCheckElementType(ST->getElementType(0), DL, size,
  781. level + 1);
  782. }
  783. }
// Don't do this for arrays.
// An array will be flattened as a struct of arrays.
return level;
}
// Adding zero indices can never make ptrSize larger.
if (ptrSize < size)
return 0;
// ptrSize > size.
// Try the element type to make the sizes match.
  793. if (StructType *ST = dyn_cast<StructType>(Ty)) {
  794. return MatchSizeByCheckElementType(ST->getElementType(0), DL, size,
  795. level + 1);
  796. } else if (ArrayType *AT = dyn_cast<ArrayType>(Ty)) {
  797. return MatchSizeByCheckElementType(AT->getElementType(), DL, size,
  798. level + 1);
  799. } else {
  800. return 0;
  801. }
  802. }
  803. void PatchZeroIdxGEP(Value *Ptr, Value *RawPtr, MemCpyInst *MI, unsigned level,
  804. IRBuilder<> &Builder) {
  805. Value *zeroIdx = Builder.getInt32(0);
  806. Value *GEP = nullptr;
  807. if (GEPOperator *GEPPtr = dyn_cast<GEPOperator>(Ptr)) {
  808. SmallVector<Value *, 2> IdxList(GEPPtr->idx_begin(), GEPPtr->idx_end());
// No +1 on level here because the leading zero index is already part of GEPPtr's indices.
  810. IdxList.append(level, zeroIdx);
  811. GEP = Builder.CreateInBoundsGEP(GEPPtr->getPointerOperand(), IdxList);
  812. } else {
  813. SmallVector<Value *, 2> IdxList(level + 1, zeroIdx);
  814. GEP = Builder.CreateInBoundsGEP(Ptr, IdxList);
  815. }
  816. // Use BitCastInst::Create to prevent idxList from being optimized.
  817. CastInst *Cast =
  818. BitCastInst::Create(Instruction::BitCast, GEP, RawPtr->getType());
  819. Builder.Insert(Cast);
  820. MI->replaceUsesOfWith(RawPtr, Cast);
  821. // Remove RawPtr if possible.
  822. if (RawPtr->user_empty()) {
  823. if (Instruction *I = dyn_cast<Instruction>(RawPtr)) {
  824. I->eraseFromParent();
  825. }
  826. }
  827. }
  828. void MemcpySplitter::PatchMemCpyWithZeroIdxGEP(MemCpyInst *MI,
  829. const DataLayout &DL) {
  830. Value *Dest = MI->getRawDest();
  831. Value *Src = MI->getRawSource();
// Only strip one level of bitcast, which is generated by inlining.
  833. if (BitCastOperator *BC = dyn_cast<BitCastOperator>(Dest))
  834. Dest = BC->getOperand(0);
  835. if (BitCastOperator *BC = dyn_cast<BitCastOperator>(Src))
  836. Src = BC->getOperand(0);
  837. IRBuilder<> Builder(MI);
  838. ConstantInt *zero = Builder.getInt32(0);
  839. Type *DestTy = Dest->getType()->getPointerElementType();
  840. Type *SrcTy = Src->getType()->getPointerElementType();
  841. // Support case when bitcast (gep ptr, 0,0) is transformed into
  842. // bitcast ptr.
  843. // Also replace (gep ptr, 0) with ptr.
  844. ConstantInt *Length = cast<ConstantInt>(MI->getLength());
  845. unsigned size = Length->getLimitedValue();
  846. if (unsigned level = MatchSizeByCheckElementType(DestTy, DL, size, 0)) {
  847. PatchZeroIdxGEP(Dest, MI->getRawDest(), MI, level, Builder);
  848. } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(Dest)) {
  849. if (GEP->getNumIndices() == 1) {
  850. Value *idx = *GEP->idx_begin();
  851. if (idx == zero) {
  852. GEP->replaceAllUsesWith(GEP->getPointerOperand());
  853. }
  854. }
  855. }
  856. if (unsigned level = MatchSizeByCheckElementType(SrcTy, DL, size, 0)) {
  857. PatchZeroIdxGEP(Src, MI->getRawSource(), MI, level, Builder);
  858. } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(Src)) {
  859. if (GEP->getNumIndices() == 1) {
  860. Value *idx = *GEP->idx_begin();
  861. if (idx == zero) {
  862. GEP->replaceAllUsesWith(GEP->getPointerOperand());
  863. }
  864. }
  865. }
  866. }
  867. void MemcpySplitter::PatchMemCpyWithZeroIdxGEP(Module &M) {
  868. const DataLayout &DL = M.getDataLayout();
  869. for (Function &F : M.functions()) {
  870. for (Function::iterator BB = F.begin(), BBE = F.end(); BB != BBE; ++BB) {
  871. for (BasicBlock::iterator BI = BB->begin(), BE = BB->end(); BI != BE;) {
  872. // Avoid invalidating the iterator.
  873. Instruction *I = BI++;
  874. if (MemCpyInst *MI = dyn_cast<MemCpyInst>(I)) {
  875. PatchMemCpyWithZeroIdxGEP(MI, DL);
  876. }
  877. }
  878. }
  879. }
  880. }
  881. void DeleteMemcpy(MemCpyInst *MI) {
  882. Value *Op0 = MI->getOperand(0);
  883. Value *Op1 = MI->getOperand(1);
  884. // delete memcpy
  885. MI->eraseFromParent();
  886. if (Instruction *op0 = dyn_cast<Instruction>(Op0)) {
  887. if (op0->user_empty())
  888. op0->eraseFromParent();
  889. }
  890. if (Instruction *op1 = dyn_cast<Instruction>(Op1)) {
  891. if (op1->user_empty())
  892. op1->eraseFromParent();
  893. }
  894. }
// If a user is a function call, return the parameter annotation so the matrix majority can be determined.
  896. DxilFieldAnnotation *FindAnnotationFromMatUser(Value *Mat,
  897. DxilTypeSystem &typeSys) {
  898. for (User *U : Mat->users()) {
  899. if (CallInst *CI = dyn_cast<CallInst>(U)) {
  900. Function *F = CI->getCalledFunction();
  901. if (DxilFunctionAnnotation *Anno = typeSys.GetFunctionAnnotation(F)) {
  902. for (unsigned i = 0; i < CI->getNumArgOperands(); i++) {
  903. if (CI->getArgOperand(i) == Mat) {
  904. return &Anno->GetParameterAnnotation(i);
  905. }
  906. }
  907. }
  908. }
  909. }
  910. return nullptr;
  911. }
  912. namespace {
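// isCBVec4ArrayToScalarArray - Returns true when the source of a memcpy is a
// cbuffer subscript whose type is an array of 16-byte vectors (e.g. float4[N])
// while the destination is an array of the matching scalar element type, so
// the copy can be rewritten as per-component loads and stores.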
  913. bool isCBVec4ArrayToScalarArray(Type *TyV, Value *Src, Type *TySrc, const DataLayout &DL) {
  914. Value *SrcPtr = Src;
  915. while (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(SrcPtr)) {
  916. SrcPtr = GEP->getPointerOperand();
  917. }
  918. CallInst *CI = dyn_cast<CallInst>(SrcPtr);
  919. if (!CI)
  920. return false;
  921. Function *F = CI->getCalledFunction();
  922. if (hlsl::GetHLOpcodeGroupByName(F) != HLOpcodeGroup::HLSubscript)
  923. return false;
  924. if (hlsl::GetHLOpcode(CI) != (unsigned)HLSubscriptOpcode::CBufferSubscript)
  925. return false;
  926. ArrayType *AT = dyn_cast<ArrayType>(TySrc);
  927. if (!AT)
  928. return false;
  929. VectorType *VT = dyn_cast<VectorType>(AT->getElementType());
  930. if (!VT)
  931. return false;
  932. if (DL.getTypeSizeInBits(VT) != 128)
  933. return false;
  934. ArrayType *DstAT = dyn_cast<ArrayType>(TyV);
  935. if (!DstAT)
  936. return false;
  937. if (VT->getElementType() != DstAT->getElementType())
  938. return false;
  939. unsigned sizeInBits = DL.getTypeSizeInBits(VT->getElementType());
  940. if (sizeInBits < 32)
  941. return false;
  942. return true;
  943. }
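// trySplitCBVec4ArrayToScalarArray - Expands the qualifying cbuffer copy
// element by element. Roughly, a copy of float4 src[2] into float dst[8]
// becomes, per vector element:
//   %v = load <4 x float>, <4 x float>* (gep src, 0, a)
//   %e = extractelement <4 x float> %v, v
//   store float %e, float* (gep dst, 0, a*4 + v)
// for a in [0,2) and v in [0,4).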
  944. bool trySplitCBVec4ArrayToScalarArray(Value *Dest, Type *TyV, Value *Src,
  945. Type *TySrc, const DataLayout &DL,
  946. IRBuilder<> &B) {
  947. if (!isCBVec4ArrayToScalarArray(TyV, Src, TySrc, DL))
  948. return false;
  949. ArrayType *AT = cast<ArrayType>(TyV);
  950. Type *EltTy = AT->getElementType();
  951. unsigned sizeInBits = DL.getTypeSizeInBits(EltTy);
  952. unsigned vecSize = 4;
  953. if (sizeInBits == 64)
  954. vecSize = 2;
  955. unsigned arraySize = AT->getNumElements();
  956. unsigned vecArraySize = arraySize / vecSize;
  957. Value *zeroIdx = B.getInt32(0);
  958. for (unsigned a = 0; a < vecArraySize; a++) {
  959. Value *SrcGEP = B.CreateGEP(Src, {zeroIdx, B.getInt32(a)});
  960. Value *Ld = B.CreateLoad(SrcGEP);
  961. for (unsigned v = 0; v < vecSize; v++) {
  962. Value *Elt = B.CreateExtractElement(Ld, v);
  963. Value *DestGEP =
  964. B.CreateGEP(Dest, {zeroIdx, B.getInt32(a * vecSize + v)});
  965. B.CreateStore(Elt, DestGEP);
  966. }
  967. }
  968. return true;
  969. }
  970. }
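// SplitMemCpy - Replace a memcpy between identically-typed aggregates with
// per-element loads and stores (via SplitCpy), then delete the memcpy. Self
// copies are simply deleted, and cbuffer float4-array to scalar-array copies
// are handled as a special case.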
  971. void MemcpySplitter::SplitMemCpy(MemCpyInst *MI, const DataLayout &DL,
  972. DxilFieldAnnotation *fieldAnnotation,
  973. DxilTypeSystem &typeSys,
  974. const bool bEltMemCpy) {
  975. Value *Dest = MI->getRawDest();
  976. Value *Src = MI->getRawSource();
// Only strip one level of bitcast, which is generated by inlining.
  978. if (BitCastOperator *BC = dyn_cast<BitCastOperator>(Dest))
  979. Dest = BC->getOperand(0);
  980. if (BitCastOperator *BC = dyn_cast<BitCastOperator>(Src))
  981. Src = BC->getOperand(0);
  982. if (Dest == Src) {
  983. // delete self copy.
  984. DeleteMemcpy(MI);
  985. return;
  986. }
  987. IRBuilder<> Builder(MI);
  988. Type *DestTy = Dest->getType()->getPointerElementType();
  989. Type *SrcTy = Src->getType()->getPointerElementType();
// Allow copies between different address spaces.
  991. if (DestTy != SrcTy) {
  992. if (trySplitCBVec4ArrayToScalarArray(Dest, DestTy, Src, SrcTy, DL,
  993. Builder)) {
  994. // delete memcpy
  995. DeleteMemcpy(MI);
  996. }
  997. return;
  998. }
  999. // Try to find fieldAnnotation from user of Dest/Src.
  1000. if (!fieldAnnotation) {
  1001. Type *EltTy = dxilutil::GetArrayEltTy(DestTy);
  1002. if (HLMatrixType::isa(EltTy)) {
  1003. fieldAnnotation = FindAnnotationFromMatUser(Dest, typeSys);
  1004. }
  1005. }
  1006. llvm::SmallVector<llvm::Value *, 16> idxList;
// Split the copy element by element.
// Matrices are treated as scalar types and are never copied via memcpy, so a
// null fieldAnnotation is safe here.
  1010. SplitCpy(Dest->getType(), Dest, Src, idxList, Builder, DL, typeSys,
  1011. fieldAnnotation, bEltMemCpy);
  1012. // delete memcpy
  1013. DeleteMemcpy(MI);
  1014. }
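// Split - Walk every user of the memcpy intrinsics in the module and split
// each memcpy that lives in function F.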
  1015. void MemcpySplitter::Split(llvm::Function &F) {
  1016. const DataLayout &DL = F.getParent()->getDataLayout();
  1017. SmallVector<Function *, 2> memcpys;
  1018. for (Function &Fn : F.getParent()->functions()) {
  1019. if (Fn.getIntrinsicID() == Intrinsic::memcpy) {
  1020. memcpys.emplace_back(&Fn);
  1021. }
  1022. }
  1023. for (Function *memcpy : memcpys) {
  1024. for (auto U = memcpy->user_begin(); U != memcpy->user_end();) {
  1025. MemCpyInst *MI = cast<MemCpyInst>(*(U++));
  1026. if (MI->getParent()->getParent() != &F)
  1027. continue;
// Matrices are treated as scalar types and are never copied via memcpy, so a
// null fieldAnnotation is safe here.
  1030. SplitMemCpy(MI, DL, /*fieldAnnotation*/ nullptr, m_typeSys,
  1031. /*bEltMemCpy*/ false);
  1032. }
  1033. }
  1034. }
  1035. } // namespace
  1036. namespace {
  1037. /// DeleteDeadInstructions - Erase instructions on the DeadInstrs list,
  1038. /// recursively including all their operands that become trivially dead.
  1039. void DeleteDeadInstructions(SmallVector<Value *, 32> &DeadInsts) {
  1040. while (!DeadInsts.empty()) {
  1041. Instruction *I = cast<Instruction>(DeadInsts.pop_back_val());
  1042. for (User::op_iterator OI = I->op_begin(), E = I->op_end(); OI != E; ++OI)
  1043. if (Instruction *U = dyn_cast<Instruction>(*OI)) {
  1044. // Zero out the operand and see if it becomes trivially dead.
  1045. // (But, don't add allocas to the dead instruction list -- they are
  1046. // already on the worklist and will be deleted separately.)
  1047. *OI = nullptr;
  1048. if (isInstructionTriviallyDead(U) && !isa<AllocaInst>(U))
  1049. DeadInsts.push_back(U);
  1050. }
  1051. I->eraseFromParent();
  1052. }
  1053. }
// markPrecise - The precise attribute on an alloca might be removed by
// promotion, so preserve it by marking the alloca pointer with a function
// call.
  1057. bool markPrecise(Function &F) {
  1058. bool Changed = false;
  1059. BasicBlock &BB = F.getEntryBlock();
  1060. for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ++I)
  1061. if (AllocaInst *A = dyn_cast<AllocaInst>(I)) {
  1062. // TODO: Only do this on basic types.
  1063. if (HLModule::HasPreciseAttributeWithMetadata(A)) {
  1064. HLModule::MarkPreciseAttributeOnPtrWithFunctionCall(A,
  1065. *(F.getParent()));
  1066. Changed = true;
  1067. }
  1068. }
  1069. return Changed;
  1070. }
  1071. bool Cleanup(Function &F, DxilTypeSystem &typeSys) {
// Lower any remaining memcpys into loads/stores.
  1073. MemcpySplitter splitter(F.getContext(), typeSys);
  1074. splitter.Split(F);
  1075. return markPrecise(F);
  1076. }
  1077. } // namespace
  1078. namespace {
/// ShouldAttemptScalarRepl - Decide if an alloca is a good candidate for
/// SROA. For HLSL, any struct or array type qualifies.
  1081. bool ShouldAttemptScalarRepl(AllocaInst *AI) {
  1082. Type *T = AI->getAllocatedType();
  1083. // promote every struct.
  1084. if (dyn_cast<StructType>(T))
  1085. return true;
  1086. // promote every array.
  1087. if (dyn_cast<ArrayType>(T))
  1088. return true;
  1089. return false;
  1090. }
  1091. /// AllocaInfo - When analyzing uses of an alloca instruction, this captures
  1092. /// information about the uses. All these fields are initialized to false
  1093. /// and set to true when something is learned.
  1094. struct AllocaInfo {
  1095. /// The alloca to promote.
  1096. AllocaInst *AI;
  1097. /// CheckedPHIs - This is a set of verified PHI nodes, to prevent infinite
  1098. /// looping and avoid redundant work.
  1099. SmallPtrSet<PHINode *, 8> CheckedPHIs;
  1100. /// isUnsafe - This is set to true if the alloca cannot be SROA'd.
  1101. bool isUnsafe : 1;
  1102. /// isMemCpySrc - This is true if this aggregate is memcpy'd from.
  1103. bool isMemCpySrc : 1;
  1104. /// isMemCpyDst - This is true if this aggregate is memcpy'd into.
  1105. bool isMemCpyDst : 1;
  1106. /// hasSubelementAccess - This is true if a subelement of the alloca is
  1107. /// ever accessed, or false if the alloca is only accessed with mem
  1108. /// intrinsics or load/store that only access the entire alloca at once.
  1109. bool hasSubelementAccess : 1;
  1110. /// hasALoadOrStore - This is true if there are any loads or stores to it.
  1111. /// The alloca may just be accessed with memcpy, for example, which would
  1112. /// not set this.
  1113. bool hasALoadOrStore : 1;
/// hasArrayIndexing - This is true if there is any dynamic array indexing
/// into it.
bool hasArrayIndexing : 1;
/// hasVectorIndexing - This is true if there is any dynamic vector indexing
/// into it.
bool hasVectorIndexing : 1;
  1120. explicit AllocaInfo(AllocaInst *ai)
  1121. : AI(ai), isUnsafe(false), isMemCpySrc(false), isMemCpyDst(false),
  1122. hasSubelementAccess(false), hasALoadOrStore(false),
  1123. hasArrayIndexing(false), hasVectorIndexing(false) {}
  1124. };
  1125. /// TypeHasComponent - Return true if T has a component type with the
  1126. /// specified offset and size. If Size is zero, do not check the size.
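/// For example, assuming 4-byte floats, for T = struct { float a; float b[2]; }
/// an (Offset, Size) of (4, 4) or (8, 4) lands on a float component and is
/// expected to return true, while (2, 4) straddles elements and returns false.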
  1127. bool TypeHasComponent(Type *T, uint64_t Offset, uint64_t Size,
  1128. const DataLayout &DL) {
  1129. Type *EltTy;
  1130. uint64_t EltSize;
  1131. if (StructType *ST = dyn_cast<StructType>(T)) {
  1132. const StructLayout *Layout = DL.getStructLayout(ST);
  1133. unsigned EltIdx = Layout->getElementContainingOffset(Offset);
  1134. EltTy = ST->getContainedType(EltIdx);
  1135. EltSize = DL.getTypeAllocSize(EltTy);
  1136. Offset -= Layout->getElementOffset(EltIdx);
  1137. } else if (ArrayType *AT = dyn_cast<ArrayType>(T)) {
  1138. EltTy = AT->getElementType();
  1139. EltSize = DL.getTypeAllocSize(EltTy);
  1140. if (Offset >= AT->getNumElements() * EltSize)
  1141. return false;
  1142. Offset %= EltSize;
  1143. } else if (VectorType *VT = dyn_cast<VectorType>(T)) {
  1144. EltTy = VT->getElementType();
  1145. EltSize = DL.getTypeAllocSize(EltTy);
  1146. if (Offset >= VT->getNumElements() * EltSize)
  1147. return false;
  1148. Offset %= EltSize;
  1149. } else {
  1150. return false;
  1151. }
  1152. if (Offset == 0 && (Size == 0 || EltSize == Size))
  1153. return true;
  1154. // Check if the component spans multiple elements.
  1155. if (Offset + Size > EltSize)
  1156. return false;
  1157. return TypeHasComponent(EltTy, Offset, Size, DL);
  1158. }
  1159. void MarkUnsafe(AllocaInfo &I, Instruction *User) {
  1160. I.isUnsafe = true;
  1161. DEBUG(dbgs() << " Transformation preventing inst: " << *User << '\n');
  1162. }
  1163. /// isSafeGEP - Check if a GEP instruction can be handled for scalar
  1164. /// replacement. It is safe when all the indices are constant, in-bounds
  1165. /// references, and when the resulting offset corresponds to an element within
  1166. /// the alloca type. The results are flagged in the Info parameter. Upon
  1167. /// return, Offset is adjusted as specified by the GEP indices.
  1168. void isSafeGEP(GetElementPtrInst *GEPI, uint64_t &Offset, AllocaInfo &Info) {
  1169. gep_type_iterator GEPIt = gep_type_begin(GEPI), E = gep_type_end(GEPI);
  1170. if (GEPIt == E)
  1171. return;
  1172. bool NonConstant = false;
  1173. unsigned NonConstantIdxSize = 0;
  1174. // Compute the offset due to this GEP and check if the alloca has a
  1175. // component element at that offset.
  1176. SmallVector<Value *, 8> Indices(GEPI->op_begin() + 1, GEPI->op_end());
  1177. auto indicesIt = Indices.begin();
  1178. // Walk through the GEP type indices, checking the types that this indexes
  1179. // into.
  1180. uint32_t arraySize = 0;
  1181. bool isArrayIndexing = false;
  1182. for (; GEPIt != E; ++GEPIt) {
  1183. Type *Ty = *GEPIt;
  1184. if (Ty->isStructTy() && !HLMatrixType::isa(Ty)) {
// Don't walk into a struct when marking hasArrayIndexing and
// hasVectorIndexing; deeper levels won't affect scalar replacement of the
// struct.
  1188. break;
  1189. }
  1190. if (GEPIt->isArrayTy()) {
  1191. arraySize = GEPIt->getArrayNumElements();
  1192. isArrayIndexing = true;
  1193. }
  1194. if (GEPIt->isVectorTy()) {
  1195. arraySize = GEPIt->getVectorNumElements();
  1196. isArrayIndexing = false;
  1197. }
  1198. // Allow dynamic indexing
  1199. ConstantInt *IdxVal = dyn_cast<ConstantInt>(GEPIt.getOperand());
  1200. if (!IdxVal) {
  1201. // for dynamic index, use array size - 1 to check the offset
  1202. *indicesIt = Constant::getIntegerValue(
  1203. Type::getInt32Ty(GEPI->getContext()), APInt(32, arraySize - 1));
  1204. if (isArrayIndexing)
  1205. Info.hasArrayIndexing = true;
  1206. else
  1207. Info.hasVectorIndexing = true;
  1208. NonConstant = true;
  1209. }
  1210. indicesIt++;
  1211. }
// Continue iterating only to track non-constant indices.
  1213. for (; GEPIt != E; ++GEPIt) {
  1214. Type *Ty = *GEPIt;
  1215. if (Ty->isArrayTy()) {
  1216. arraySize = GEPIt->getArrayNumElements();
  1217. }
  1218. if (Ty->isVectorTy()) {
  1219. arraySize = GEPIt->getVectorNumElements();
  1220. }
  1221. // Allow dynamic indexing
  1222. ConstantInt *IdxVal = dyn_cast<ConstantInt>(GEPIt.getOperand());
  1223. if (!IdxVal) {
  1224. // for dynamic index, use array size - 1 to check the offset
  1225. *indicesIt = Constant::getIntegerValue(
  1226. Type::getInt32Ty(GEPI->getContext()), APInt(32, arraySize - 1));
  1227. NonConstant = true;
  1228. }
  1229. indicesIt++;
  1230. }
  1231. // If this GEP is non-constant then the last operand must have been a
  1232. // dynamic index into a vector. Pop this now as it has no impact on the
  1233. // constant part of the offset.
  1234. if (NonConstant)
  1235. Indices.pop_back();
  1236. const DataLayout &DL = GEPI->getModule()->getDataLayout();
  1237. Offset += DL.getIndexedOffset(GEPI->getPointerOperandType(), Indices);
  1238. if (!TypeHasComponent(Info.AI->getAllocatedType(), Offset, NonConstantIdxSize,
  1239. DL))
  1240. MarkUnsafe(Info, GEPI);
  1241. }
  1242. /// isSafeMemAccess - Check if a load/store/memcpy operates on the entire AI
  1243. /// alloca or has an offset and size that corresponds to a component element
  1244. /// within it. The offset checked here may have been formed from a GEP with a
  1245. /// pointer bitcasted to a different type.
  1246. ///
  1247. /// If AllowWholeAccess is true, then this allows uses of the entire alloca as a
  1248. /// unit. If false, it only allows accesses known to be in a single element.
  1249. void isSafeMemAccess(uint64_t Offset, uint64_t MemSize, Type *MemOpType,
  1250. bool isStore, AllocaInfo &Info, Instruction *TheAccess,
  1251. bool AllowWholeAccess) {
// All HLSL cares about is Info.hasVectorIndexing, so there is nothing to do
// here.
  1254. }
  1255. /// isSafePHIUseForScalarRepl - If we see a PHI node or select using a pointer
  1256. /// derived from the alloca, we can often still split the alloca into elements.
  1257. /// This is useful if we have a large alloca where one element is phi'd
  1258. /// together somewhere: we can SRoA and promote all the other elements even if
  1259. /// we end up not being able to promote this one.
  1260. ///
  1261. /// All we require is that the uses of the PHI do not index into other parts of
  1262. /// the alloca. The most important use case for this is single load and stores
  1263. /// that are PHI'd together, which can happen due to code sinking.
  1264. void isSafePHISelectUseForScalarRepl(Instruction *I, uint64_t Offset,
  1265. AllocaInfo &Info) {
  1266. // If we've already checked this PHI, don't do it again.
  1267. if (PHINode *PN = dyn_cast<PHINode>(I))
  1268. if (!Info.CheckedPHIs.insert(PN).second)
  1269. return;
  1270. const DataLayout &DL = I->getModule()->getDataLayout();
  1271. for (User *U : I->users()) {
  1272. Instruction *UI = cast<Instruction>(U);
  1273. if (BitCastInst *BC = dyn_cast<BitCastInst>(UI)) {
  1274. isSafePHISelectUseForScalarRepl(BC, Offset, Info);
  1275. } else if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(UI)) {
  1276. // Only allow "bitcast" GEPs for simplicity. We could generalize this,
  1277. // but would have to prove that we're staying inside of an element being
  1278. // promoted.
  1279. if (!GEPI->hasAllZeroIndices())
  1280. return MarkUnsafe(Info, UI);
  1281. isSafePHISelectUseForScalarRepl(GEPI, Offset, Info);
  1282. } else if (LoadInst *LI = dyn_cast<LoadInst>(UI)) {
  1283. if (!LI->isSimple())
  1284. return MarkUnsafe(Info, UI);
  1285. Type *LIType = LI->getType();
  1286. isSafeMemAccess(Offset, DL.getTypeAllocSize(LIType), LIType, false, Info,
  1287. LI, false /*AllowWholeAccess*/);
  1288. Info.hasALoadOrStore = true;
  1289. } else if (StoreInst *SI = dyn_cast<StoreInst>(UI)) {
  1290. // Store is ok if storing INTO the pointer, not storing the pointer
  1291. if (!SI->isSimple() || SI->getOperand(0) == I)
  1292. return MarkUnsafe(Info, UI);
  1293. Type *SIType = SI->getOperand(0)->getType();
  1294. isSafeMemAccess(Offset, DL.getTypeAllocSize(SIType), SIType, true, Info,
  1295. SI, false /*AllowWholeAccess*/);
  1296. Info.hasALoadOrStore = true;
  1297. } else if (isa<PHINode>(UI) || isa<SelectInst>(UI)) {
  1298. isSafePHISelectUseForScalarRepl(UI, Offset, Info);
  1299. } else {
  1300. return MarkUnsafe(Info, UI);
  1301. }
  1302. if (Info.isUnsafe)
  1303. return;
  1304. }
  1305. }
  1306. /// isSafeForScalarRepl - Check if instruction I is a safe use with regard to
  1307. /// performing scalar replacement of alloca AI. The results are flagged in
  1308. /// the Info parameter. Offset indicates the position within AI that is
  1309. /// referenced by this instruction.
  1310. void isSafeForScalarRepl(Instruction *I, uint64_t Offset, AllocaInfo &Info) {
  1311. if (I->getType()->isPointerTy()) {
  1312. // Don't check object pointers.
  1313. if (dxilutil::IsHLSLObjectType(I->getType()->getPointerElementType()))
  1314. return;
  1315. }
  1316. const DataLayout &DL = I->getModule()->getDataLayout();
  1317. for (Use &U : I->uses()) {
  1318. Instruction *User = cast<Instruction>(U.getUser());
  1319. if (BitCastInst *BC = dyn_cast<BitCastInst>(User)) {
  1320. isSafeForScalarRepl(BC, Offset, Info);
  1321. } else if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(User)) {
  1322. uint64_t GEPOffset = Offset;
  1323. isSafeGEP(GEPI, GEPOffset, Info);
  1324. if (!Info.isUnsafe)
  1325. isSafeForScalarRepl(GEPI, GEPOffset, Info);
  1326. } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(User)) {
  1327. ConstantInt *Length = dyn_cast<ConstantInt>(MI->getLength());
  1328. if (!Length || Length->isNegative())
  1329. return MarkUnsafe(Info, User);
  1330. isSafeMemAccess(Offset, Length->getZExtValue(), nullptr,
  1331. U.getOperandNo() == 0, Info, MI,
  1332. true /*AllowWholeAccess*/);
  1333. } else if (LoadInst *LI = dyn_cast<LoadInst>(User)) {
  1334. if (!LI->isSimple())
  1335. return MarkUnsafe(Info, User);
  1336. Type *LIType = LI->getType();
  1337. isSafeMemAccess(Offset, DL.getTypeAllocSize(LIType), LIType, false, Info,
  1338. LI, true /*AllowWholeAccess*/);
  1339. Info.hasALoadOrStore = true;
  1340. } else if (StoreInst *SI = dyn_cast<StoreInst>(User)) {
  1341. // Store is ok if storing INTO the pointer, not storing the pointer
  1342. if (!SI->isSimple() || SI->getOperand(0) == I)
  1343. return MarkUnsafe(Info, User);
  1344. Type *SIType = SI->getOperand(0)->getType();
  1345. isSafeMemAccess(Offset, DL.getTypeAllocSize(SIType), SIType, true, Info,
  1346. SI, true /*AllowWholeAccess*/);
  1347. Info.hasALoadOrStore = true;
  1348. } else if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(User)) {
  1349. if (II->getIntrinsicID() != Intrinsic::lifetime_start &&
  1350. II->getIntrinsicID() != Intrinsic::lifetime_end)
  1351. return MarkUnsafe(Info, User);
  1352. } else if (isa<PHINode>(User) || isa<SelectInst>(User)) {
  1353. isSafePHISelectUseForScalarRepl(User, Offset, Info);
  1354. } else if (CallInst *CI = dyn_cast<CallInst>(User)) {
  1355. HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
  1356. // Most HL functions are safe for scalar repl.
  1357. if (HLOpcodeGroup::NotHL == group)
  1358. return MarkUnsafe(Info, User);
  1359. else if (HLOpcodeGroup::HLIntrinsic == group) {
// TODO: should we check the HL parameter type for UDT overloads instead of
// relying on the IOP?
  1362. IntrinsicOp opcode = static_cast<IntrinsicOp>(GetHLOpcode(CI));
  1363. if (IntrinsicOp::IOP_TraceRay == opcode ||
  1364. IntrinsicOp::IOP_ReportHit == opcode ||
  1365. IntrinsicOp::IOP_CallShader == opcode) {
  1366. return MarkUnsafe(Info, User);
  1367. }
  1368. }
  1369. } else {
  1370. return MarkUnsafe(Info, User);
  1371. }
  1372. if (Info.isUnsafe)
  1373. return;
  1374. }
  1375. }
  1376. /// HasPadding - Return true if the specified type has any structure or
  1377. /// alignment padding in between the elements that would be split apart
  1378. /// by SROA; return false otherwise.
  1379. static bool HasPadding(Type *Ty, const DataLayout &DL) {
  1380. if (ArrayType *ATy = dyn_cast<ArrayType>(Ty)) {
  1381. Ty = ATy->getElementType();
  1382. return DL.getTypeSizeInBits(Ty) != DL.getTypeAllocSizeInBits(Ty);
  1383. }
  1384. // SROA currently handles only Arrays and Structs.
  1385. StructType *STy = cast<StructType>(Ty);
  1386. const StructLayout *SL = DL.getStructLayout(STy);
  1387. unsigned PrevFieldBitOffset = 0;
  1388. for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
  1389. unsigned FieldBitOffset = SL->getElementOffsetInBits(i);
  1390. // Check to see if there is any padding between this element and the
  1391. // previous one.
  1392. if (i) {
  1393. unsigned PrevFieldEnd =
  1394. PrevFieldBitOffset + DL.getTypeSizeInBits(STy->getElementType(i - 1));
  1395. if (PrevFieldEnd < FieldBitOffset)
  1396. return true;
  1397. }
  1398. PrevFieldBitOffset = FieldBitOffset;
  1399. }
  1400. // Check for tail padding.
  1401. if (unsigned EltCount = STy->getNumElements()) {
  1402. unsigned PrevFieldEnd =
  1403. PrevFieldBitOffset +
  1404. DL.getTypeSizeInBits(STy->getElementType(EltCount - 1));
  1405. if (PrevFieldEnd < SL->getSizeInBits())
  1406. return true;
  1407. }
  1408. return false;
  1409. }
/// isSafeAllocaToScalarRepl - Check to see if the specified allocation of
/// an aggregate can be broken down into elements. Returns true if it is safe
/// to do so, false otherwise.
bool isSafeAllocaToScalarRepl(AllocaInst *AI) {
  1414. // Loop over the use list of the alloca. We can only transform it if all of
  1415. // the users are safe to transform.
  1416. AllocaInfo Info(AI);
  1417. isSafeForScalarRepl(AI, 0, Info);
  1418. if (Info.isUnsafe) {
  1419. DEBUG(dbgs() << "Cannot transform: " << *AI << '\n');
  1420. return false;
  1421. }
// Dynamic vector indexing requires translating the vector into an array first.
  1423. if (Info.hasVectorIndexing)
  1424. return false;
  1425. const DataLayout &DL = AI->getModule()->getDataLayout();
  1426. // Okay, we know all the users are promotable. If the aggregate is a memcpy
  1427. // source and destination, we have to be careful. In particular, the memcpy
  1428. // could be moving around elements that live in structure padding of the LLVM
  1429. // types, but may actually be used. In these cases, we refuse to promote the
  1430. // struct.
  1431. if (Info.isMemCpySrc && Info.isMemCpyDst &&
  1432. HasPadding(AI->getAllocatedType(), DL))
  1433. return false;
  1434. return true;
  1435. }
  1436. } // namespace
  1437. namespace {
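// GVDbgOffset - Tracks, for a flattened global, the original global it was
// split from and the element's offset (in bits) within it, so per-element
// debug info can still point back at the original variable.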
  1438. struct GVDbgOffset {
  1439. GlobalVariable *base;
  1440. unsigned debugOffset;
  1441. };
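// hasDynamicVectorIndexing - Returns true if any GEP user of V indexes into a
// vector with a non-constant index.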
  1442. bool hasDynamicVectorIndexing(Value *V) {
  1443. for (User *U : V->users()) {
  1444. if (!U->getType()->isPointerTy())
  1445. continue;
  1446. if (dyn_cast<GEPOperator>(U)) {
  1447. gep_type_iterator GEPIt = gep_type_begin(U), E = gep_type_end(U);
  1448. for (; GEPIt != E; ++GEPIt) {
  1449. if (isa<VectorType>(*GEPIt)) {
  1450. Value *VecIdx = GEPIt.getOperand();
  1451. if (!isa<ConstantInt>(VecIdx))
  1452. return true;
  1453. }
  1454. }
  1455. }
  1456. }
  1457. return false;
  1458. }
  1459. } // namespace
  1460. namespace {
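// RemoveUnusedInternalGlobalVariable - Delete static/groupshared globals that
// are only ever written (plain stores, matrix stores, or otherwise unused
// constant expressions); their stores are dead, so the globals can go.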
  1461. void RemoveUnusedInternalGlobalVariable(Module &M) {
  1462. std::vector<GlobalVariable *> staticGVs;
  1463. for (GlobalVariable &GV : M.globals()) {
  1464. if (dxilutil::IsStaticGlobal(&GV) || dxilutil::IsSharedMemoryGlobal(&GV)) {
  1465. staticGVs.emplace_back(&GV);
  1466. }
  1467. }
  1468. for (GlobalVariable *GV : staticGVs) {
  1469. bool onlyStoreUse = true;
  1470. for (User *user : GV->users()) {
  1471. if (isa<StoreInst>(user))
  1472. continue;
  1473. if (isa<ConstantExpr>(user) && user->user_empty())
  1474. continue;
  1475. // Check matrix store.
  1476. if (HLMatrixType::isa(GV->getType()->getPointerElementType())) {
  1477. if (CallInst *CI = dyn_cast<CallInst>(user)) {
  1478. if (GetHLOpcodeGroupByName(CI->getCalledFunction()) ==
  1479. HLOpcodeGroup::HLMatLoadStore) {
  1480. HLMatLoadStoreOpcode opcode =
  1481. static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(CI));
  1482. if (opcode == HLMatLoadStoreOpcode::ColMatStore ||
  1483. opcode == HLMatLoadStoreOpcode::RowMatStore)
  1484. continue;
  1485. }
  1486. }
  1487. }
  1488. onlyStoreUse = false;
  1489. break;
  1490. }
  1491. if (onlyStoreUse) {
  1492. for (auto UserIt = GV->user_begin(); UserIt != GV->user_end();) {
  1493. Value *User = *(UserIt++);
  1494. if (Instruction *I = dyn_cast<Instruction>(User)) {
  1495. I->eraseFromParent();
  1496. } else {
  1497. ConstantExpr *CE = cast<ConstantExpr>(User);
  1498. CE->dropAllReferences();
  1499. }
  1500. }
  1501. GV->eraseFromParent();
  1502. }
  1503. }
  1504. }
  1505. bool isGroupShareOrConstStaticArray(GlobalVariable *GV) {
  1506. // Disable scalarization of groupshared/const_static vector arrays
  1507. if (!(GV->getType()->getAddressSpace() == DXIL::kTGSMAddrSpace ||
  1508. (GV->isConstant() && GV->hasInitializer() &&
  1509. GV->getLinkage() == GlobalValue::LinkageTypes::InternalLinkage)))
  1510. return false;
  1511. Type *Ty = GV->getType()->getPointerElementType();
  1512. return Ty->isArrayTy();
  1513. }
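// SROAGlobalAndAllocas - Driver for SROA over both internal globals and
// entry-block allocas. Work items are processed largest-first off a priority
// queue; each item is either lowered as a memcpy, replaced via
// TranslatePtrIfUsedByLoweredFn, or scalar-replaced, and any new elements are
// pushed back onto the queue until nothing is left to split.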
  1514. bool SROAGlobalAndAllocas(HLModule &HLM, bool bHasDbgInfo) {
  1515. Module &M = *HLM.GetModule();
  1516. DxilTypeSystem &typeSys = HLM.GetTypeSystem();
  1517. const DataLayout &DL = M.getDataLayout();
// Make sure big allocas are split first.
// This simplifies the memcpy check between part of a big alloca and a small
// alloca: the big alloca is split into smaller pieces first, so when the
// small alloca is processed it sees an alloca flattened from the big one
// rather than a GEP into the big alloca.
  1523. auto size_cmp = [&DL](const Value *a0, const Value *a1) -> bool {
  1524. Type *a0ty = a0->getType()->getPointerElementType();
  1525. Type *a1ty = a1->getType()->getPointerElementType();
  1526. bool isUnitSzStruct0 =
  1527. a0ty->isStructTy() && a0ty->getStructNumElements() == 1;
  1528. bool isUnitSzStruct1 =
  1529. a1ty->isStructTy() && a1ty->getStructNumElements() == 1;
  1530. auto sz0 = DL.getTypeAllocSize(a0ty);
  1531. auto sz1 = DL.getTypeAllocSize(a1ty);
  1532. if (sz0 == sz1 && (isUnitSzStruct0 || isUnitSzStruct1))
  1533. return getNestedLevelInStruct(a0ty) < getNestedLevelInStruct(a1ty);
  1534. return sz0 < sz1;
  1535. };
  1536. std::priority_queue<Value *, std::vector<Value *>,
  1537. std::function<bool(Value *, Value *)>>
  1538. WorkList(size_cmp);
  1539. // Flatten internal global.
  1540. llvm::SetVector<GlobalVariable *> staticGVs;
  1541. DenseMap<GlobalVariable *, GVDbgOffset> GVDbgOffsetMap;
  1542. for (GlobalVariable &GV : M.globals()) {
  1543. if (dxilutil::IsStaticGlobal(&GV) || dxilutil::IsSharedMemoryGlobal(&GV)) {
  1544. staticGVs.insert(&GV);
  1545. GVDbgOffset &dbgOffset = GVDbgOffsetMap[&GV];
  1546. dbgOffset.base = &GV;
  1547. dbgOffset.debugOffset = 0;
  1548. } else {
  1549. // merge GEP use for global.
  1550. HLModule::MergeGepUse(&GV);
  1551. }
  1552. }
  1553. // Add static GVs to work list.
  1554. for (GlobalVariable *GV : staticGVs)
  1555. WorkList.push(GV);
  1556. DenseMap<Function *, DominatorTree> domTreeMap;
  1557. for (Function &F : M) {
  1558. if (F.isDeclaration())
  1559. continue;
  1560. // Collect domTree.
  1561. domTreeMap[&F].recalculate(F);
  1562. // Scan the entry basic block, adding allocas to the worklist.
  1563. BasicBlock &BB = F.getEntryBlock();
  1564. for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ++I)
  1565. if (AllocaInst *A = dyn_cast<AllocaInst>(I)) {
  1566. if (!A->user_empty()) {
  1567. WorkList.push(A);
// Merge GEP uses for the alloca.
  1569. HLModule::MergeGepUse(A);
  1570. }
  1571. }
  1572. }
  1573. // Establish debug metadata layout name in the context in advance so the name
  1574. // is serialized in both debug and non-debug compilations.
  1575. (void)M.getContext().getMDKindID(
  1576. DxilMDHelper::kDxilVariableDebugLayoutMDName);
  1577. DIBuilder DIB(M, /*AllowUnresolved*/ false);
  1578. /// DeadInsts - Keep track of instructions we have made dead, so that
  1579. /// we can remove them after we are done working.
  1580. SmallVector<Value *, 32> DeadInsts;
  1581. // Only used to create ConstantExpr.
  1582. IRBuilder<> Builder(M.getContext());
  1583. std::unordered_map<Value *, StringRef> EltNameMap;
  1584. bool Changed = false;
  1585. while (!WorkList.empty()) {
  1586. Value *V = WorkList.top();
  1587. WorkList.pop();
  1588. if (AllocaInst *AI = dyn_cast<AllocaInst>(V)) {
  1589. // Handle dead allocas trivially. These can be formed by SROA'ing arrays
  1590. // with unused elements.
  1591. if (AI->use_empty()) {
  1592. AI->eraseFromParent();
  1593. Changed = true;
  1594. continue;
  1595. }
  1596. Function *F = AI->getParent()->getParent();
  1597. const bool bAllowReplace = true;
  1598. DominatorTree &DT = domTreeMap[F];
  1599. if (SROA_Helper::LowerMemcpy(AI, /*annotation*/ nullptr, typeSys, DL, &DT,
  1600. bAllowReplace)) {
  1601. Changed = true;
  1602. continue;
  1603. }
  1604. // If this alloca is impossible for us to promote, reject it early.
  1605. if (AI->isArrayAllocation() || !AI->getAllocatedType()->isSized())
  1606. continue;
  1607. // Check to see if we can perform the core SROA transformation. We cannot
  1608. // transform the allocation instruction if it is an array allocation
  1609. // (allocations OF arrays are ok though), and an allocation of a scalar
  1610. // value cannot be decomposed at all.
  1611. uint64_t AllocaSize = DL.getTypeAllocSize(AI->getAllocatedType());
  1612. // Do not promote [0 x %struct].
  1613. if (AllocaSize == 0)
  1614. continue;
  1615. Type *Ty = AI->getAllocatedType();
  1616. // Skip empty struct type.
  1617. if (SROA_Helper::IsEmptyStructType(Ty, typeSys)) {
  1618. SROA_Helper::MarkEmptyStructUsers(AI, DeadInsts);
  1619. DeleteDeadInstructions(DeadInsts);
  1620. continue;
  1621. }
  1622. if (Value *NewV = TranslatePtrIfUsedByLoweredFn(AI, typeSys)) {
  1623. if (NewV != AI) {
  1624. DXASSERT(AI->getNumUses() == 0, "must have zero users.");
  1625. // Update debug declare.
  1626. if (DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(AI)) {
  1627. DDI->setArgOperand(0, MetadataAsValue::get(NewV->getContext(), ValueAsMetadata::get(NewV)));
  1628. }
  1629. AI->eraseFromParent();
  1630. Changed = true;
  1631. }
  1632. continue;
  1633. }
// If the alloca looks like a good candidate for scalar replacement, and if
// all its users can be transformed, then split up the aggregate into its
// separate elements.
  1638. if (ShouldAttemptScalarRepl(AI) && isSafeAllocaToScalarRepl(AI)) {
  1639. std::vector<Value *> Elts;
  1640. IRBuilder<> Builder(dxilutil::FindAllocaInsertionPt(AI));
  1641. bool hasPrecise = HLModule::HasPreciseAttributeWithMetadata(AI);
  1642. Type *BrokenUpTy = nullptr;
  1643. uint64_t NumInstances = 1;
  1644. bool SROAed = SROA_Helper::DoScalarReplacement(
  1645. AI, Elts, BrokenUpTy, NumInstances, Builder,
  1646. /*bFlatVector*/ true, hasPrecise, typeSys, DL, DeadInsts, &DT);
  1647. if (SROAed) {
  1648. Type *Ty = AI->getAllocatedType();
  1649. // Skip empty struct parameters.
  1650. if (StructType *ST = dyn_cast<StructType>(Ty)) {
  1651. if (!HLMatrixType::isa(Ty)) {
  1652. DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ST);
  1653. if (SA && SA->IsEmptyStruct()) {
  1654. for (User *U : AI->users()) {
  1655. if (StoreInst *SI = dyn_cast<StoreInst>(U))
  1656. DeadInsts.emplace_back(SI);
  1657. }
  1658. DeleteDeadInstructions(DeadInsts);
  1659. AI->replaceAllUsesWith(UndefValue::get(AI->getType()));
  1660. AI->eraseFromParent();
  1661. continue;
  1662. }
  1663. }
  1664. }
  1665. addDebugInfoForElements(AI, BrokenUpTy, NumInstances, Elts, DL, &DIB);
  1666. // Push Elts into workList.
  1667. for (unsigned EltIdx = 0; EltIdx < Elts.size(); ++EltIdx) {
  1668. AllocaInst *EltAlloca = cast<AllocaInst>(Elts[EltIdx]);
  1669. WorkList.push(EltAlloca);
  1670. }
  1671. // Now erase any instructions that were made dead while rewriting the
  1672. // alloca.
  1673. DeleteDeadInstructions(DeadInsts);
  1674. ++NumReplaced;
  1675. DXASSERT(AI->getNumUses() == 0, "must have zero users.");
  1676. AI->eraseFromParent();
  1677. Changed = true;
  1678. continue;
  1679. }
  1680. }
  1681. } else {
  1682. GlobalVariable *GV = cast<GlobalVariable>(V);
  1683. if (staticGVs.count(GV)) {
  1684. Type *Ty = GV->getType()->getPointerElementType();
  1685. // Skip basic types.
  1686. if (!Ty->isAggregateType() && !Ty->isVectorTy())
  1687. continue;
  1688. // merge GEP use for global.
  1689. HLModule::MergeGepUse(GV);
  1690. }
  1691. const bool bAllowReplace = true;
// SROA_Parameter_HLSL has no access to a domtree; if one is needed, it will
// be generated.
  1694. if (SROA_Helper::LowerMemcpy(GV, /*annotation*/ nullptr, typeSys, DL,
  1695. nullptr /*DT */, bAllowReplace)) {
  1696. continue;
  1697. }
// Flatten global vectors when there is no dynamic vector indexing.
  1699. bool bFlatVector = !hasDynamicVectorIndexing(GV);
  1700. if (bFlatVector) {
  1701. GVDbgOffset &dbgOffset = GVDbgOffsetMap[GV];
  1702. GlobalVariable *baseGV = dbgOffset.base;
  1703. // Disable scalarization of groupshared/const_static vector arrays
  1704. if (isGroupShareOrConstStaticArray(baseGV))
  1705. bFlatVector = false;
  1706. }
  1707. std::vector<Value *> Elts;
  1708. bool SROAed = false;
  1709. if (GlobalVariable *NewEltGV = dyn_cast_or_null<GlobalVariable>(
  1710. TranslatePtrIfUsedByLoweredFn(GV, typeSys))) {
  1711. GVDbgOffset dbgOffset = GVDbgOffsetMap[GV];
  1712. // Don't need to update when skip SROA on base GV.
  1713. if (NewEltGV == dbgOffset.base)
  1714. continue;
  1715. if (GV != NewEltGV) {
  1716. GVDbgOffsetMap[NewEltGV] = dbgOffset;
  1717. // Remove GV from GVDbgOffsetMap.
  1718. GVDbgOffsetMap.erase(GV);
  1719. if (GV != dbgOffset.base) {
  1720. // Remove GV when it is replaced by NewEltGV and is not a base GV.
  1721. GV->removeDeadConstantUsers();
  1722. GV->eraseFromParent();
  1723. }
  1724. GV = NewEltGV;
  1725. }
  1726. } else {
// SROA_Parameter_HLSL has no access to a domtree; if one is needed, it will
// be generated.
  1729. SROAed = SROA_Helper::DoScalarReplacement(
  1730. GV, Elts, Builder, bFlatVector,
  1731. // TODO: set precise.
  1732. /*hasPrecise*/ false, typeSys, DL, DeadInsts, /*DT*/ nullptr);
  1733. }
  1734. if (SROAed) {
  1735. GVDbgOffset dbgOffset = GVDbgOffsetMap[GV];
  1736. unsigned offset = 0;
  1737. // Push Elts into workList.
  1738. for (auto iter = Elts.begin(); iter != Elts.end(); iter++) {
  1739. WorkList.push(*iter);
  1740. GlobalVariable *EltGV = cast<GlobalVariable>(*iter);
  1741. if (bHasDbgInfo) {
  1742. StringRef OriginEltName = EltGV->getName();
  1743. StringRef OriginName = dbgOffset.base->getName();
  1744. StringRef EltName = OriginEltName.substr(OriginName.size());
  1745. StringRef EltParentName = OriginEltName.substr(0, OriginName.size());
  1746. DXASSERT_LOCALVAR(EltParentName, EltParentName == OriginName, "parent name mismatch");
  1747. EltNameMap[EltGV] = EltName;
  1748. }
  1749. GVDbgOffset &EltDbgOffset = GVDbgOffsetMap[EltGV];
  1750. EltDbgOffset.base = dbgOffset.base;
  1751. EltDbgOffset.debugOffset = dbgOffset.debugOffset + offset;
  1752. unsigned size =
  1753. DL.getTypeAllocSizeInBits(EltGV->getType()->getElementType());
  1754. offset += size;
  1755. }
  1756. GV->removeDeadConstantUsers();
  1757. // Now erase any instructions that were made dead while rewriting the
  1758. // alloca.
  1759. DeleteDeadInstructions(DeadInsts);
  1760. ++NumReplaced;
  1761. } else {
  1762. // Add debug info for flattened globals.
  1763. if (bHasDbgInfo && staticGVs.count(GV) == 0) {
  1764. GVDbgOffset &dbgOffset = GVDbgOffsetMap[GV];
  1765. DebugInfoFinder &Finder = HLM.GetOrCreateDebugInfoFinder();
  1766. Type *Ty = GV->getType()->getElementType();
  1767. unsigned size = DL.getTypeAllocSizeInBits(Ty);
  1768. unsigned align = DL.getPrefTypeAlignment(Ty);
  1769. HLModule::CreateElementGlobalVariableDebugInfo(
  1770. dbgOffset.base, Finder, GV, size, align, dbgOffset.debugOffset,
  1771. EltNameMap[GV]);
  1772. }
  1773. }
  1774. // Remove GV from GVDbgOffsetMap.
  1775. GVDbgOffsetMap.erase(GV);
  1776. }
  1777. }
  1778. // Remove unused internal global.
  1779. RemoveUnusedInternalGlobalVariable(M);
  1780. // Cleanup memcpy for allocas and mark precise.
  1781. for (Function &F : M) {
  1782. if (F.isDeclaration())
  1783. continue;
  1784. Cleanup(F, typeSys);
  1785. }
  1786. return true;
  1787. }
  1788. } // namespace
  1789. //===----------------------------------------------------------------------===//
  1790. // SRoA Helper
  1791. //===----------------------------------------------------------------------===//
/// RewriteForGEP - Rewrite the GEP to be relative to a new element when a new
/// element corresponding to a struct field can be found. Otherwise, create
/// GEPs into each new element and try to rewrite the GEP in terms of those.
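/// For example, after an alloca %s of type { float, [4 x float] } has been
/// split into %s.0 and %s.1, a use such as
///   gep %s, 0, 1, 3
/// is expected to be rewritten as
///   gep %s.1, 0, 3
/// while GEPs that stop at an array or vector level are instead replayed on
/// every new element.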
  1795. void SROA_Helper::RewriteForGEP(GEPOperator *GEP, IRBuilder<> &Builder) {
  1796. assert(OldVal == GEP->getPointerOperand() && "");
  1797. Value *NewPointer = nullptr;
  1798. SmallVector<Value *, 8> NewArgs;
  1799. gep_type_iterator GEPIt = gep_type_begin(GEP), E = gep_type_end(GEP);
  1800. for (; GEPIt != E; ++GEPIt) {
  1801. if (GEPIt->isStructTy()) {
  1802. // must be const
  1803. ConstantInt *IdxVal = dyn_cast<ConstantInt>(GEPIt.getOperand());
  1804. assert(IdxVal->getLimitedValue() < NewElts.size() && "");
  1805. NewPointer = NewElts[IdxVal->getLimitedValue()];
// This index selected NewPointer; it is not part of the new GEP's indices.
  1807. GEPIt++;
  1808. break;
  1809. } else if (GEPIt->isArrayTy()) {
  1810. // Add array idx.
  1811. NewArgs.push_back(GEPIt.getOperand());
  1812. } else if (GEPIt->isPointerTy()) {
  1813. // Add pointer idx.
  1814. NewArgs.push_back(GEPIt.getOperand());
  1815. } else if (GEPIt->isVectorTy()) {
  1816. // Add vector idx.
  1817. NewArgs.push_back(GEPIt.getOperand());
  1818. } else {
  1819. llvm_unreachable("should break from structTy");
  1820. }
  1821. }
  1822. if (NewPointer) {
  1823. // Struct split.
  1824. // Add rest of idx.
  1825. for (; GEPIt != E; ++GEPIt) {
  1826. NewArgs.push_back(GEPIt.getOperand());
  1827. }
  1828. // If only 1 level struct, just use the new pointer.
  1829. Value *NewGEP = NewPointer;
  1830. if (NewArgs.size() > 1) {
  1831. NewGEP = Builder.CreateInBoundsGEP(NewPointer, NewArgs);
  1832. NewGEP->takeName(GEP);
  1833. }
  1834. assert(NewGEP->getType() == GEP->getType() && "type mismatch");
  1835. GEP->replaceAllUsesWith(NewGEP);
  1836. } else {
  1837. // End at array of basic type.
  1838. Type *Ty = GEP->getType()->getPointerElementType();
  1839. if (Ty->isVectorTy() ||
  1840. (Ty->isStructTy() && !dxilutil::IsHLSLObjectType(Ty)) ||
  1841. Ty->isArrayTy()) {
  1842. SmallVector<Value *, 8> NewArgs;
  1843. NewArgs.append(GEP->idx_begin(), GEP->idx_end());
  1844. SmallVector<Value *, 8> NewGEPs;
  1845. // create new geps
  1846. for (unsigned i = 0, e = NewElts.size(); i != e; ++i) {
  1847. Value *NewGEP = Builder.CreateGEP(nullptr, NewElts[i], NewArgs);
  1848. NewGEPs.emplace_back(NewGEP);
  1849. }
  1850. const bool bAllowReplace = isa<AllocaInst>(OldVal);
  1851. if (!SROA_Helper::LowerMemcpy(GEP, /*annotation*/ nullptr, typeSys, DL, DT, bAllowReplace)) {
  1852. SROA_Helper helper(GEP, NewGEPs, DeadInsts, typeSys, DL, DT);
  1853. helper.RewriteForScalarRepl(GEP, Builder);
  1854. for (Value *NewGEP : NewGEPs) {
  1855. if (NewGEP->user_empty() && isa<Instruction>(NewGEP)) {
  1856. // Delete unused newGEP.
  1857. cast<Instruction>(NewGEP)->eraseFromParent();
  1858. }
  1859. }
  1860. }
  1861. } else {
  1862. Value *vecIdx = NewArgs.back();
  1863. if (ConstantInt *immVecIdx = dyn_cast<ConstantInt>(vecIdx)) {
  1864. // Replace vecArray[arrayIdx][immVecIdx]
  1865. // with scalarArray_immVecIdx[arrayIdx]
  1866. // Pop the vecIdx.
  1867. NewArgs.pop_back();
  1868. Value *NewGEP = NewElts[immVecIdx->getLimitedValue()];
  1869. if (NewArgs.size() > 1) {
  1870. NewGEP = Builder.CreateInBoundsGEP(NewGEP, NewArgs);
  1871. NewGEP->takeName(GEP);
  1872. }
  1873. assert(NewGEP->getType() == GEP->getType() && "type mismatch");
  1874. GEP->replaceAllUsesWith(NewGEP);
  1875. } else {
  1876. // dynamic vector indexing.
  1877. assert(0 && "should not reach here");
  1878. }
  1879. }
  1880. }
  1881. // Remove the use so that the caller can keep iterating over its other users
  1882. DXASSERT(GEP->user_empty(), "All uses of the GEP should have been eliminated");
  1883. if (isa<Instruction>(GEP)) {
  1884. GEP->setOperand(GEP->getPointerOperandIndex(), UndefValue::get(GEP->getPointerOperand()->getType()));
  1885. DeadInsts.push_back(GEP);
  1886. }
  1887. else {
  1888. cast<Constant>(GEP)->destroyConstant();
  1889. }
  1890. }
/// isVectorOrStructArray - Check if T is an array of vectors or structs.
  1892. static bool isVectorOrStructArray(Type *T) {
  1893. if (!T->isArrayTy())
  1894. return false;
  1895. T = dxilutil::GetArrayEltTy(T);
  1896. return T->isStructTy() || T->isVectorTy();
  1897. }
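// SimplifyStructValUsage - Given a struct value rebuilt from flattened
// elements Elts, forward extractvalue uses directly to the matching element
// and rewrite insertvalue uses against the element list, queueing redundant
// extracts for deletion.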
  1898. static void SimplifyStructValUsage(Value *StructVal, std::vector<Value *> Elts,
  1899. SmallVectorImpl<Value *> &DeadInsts) {
  1900. for (User *user : StructVal->users()) {
  1901. if (ExtractValueInst *Extract = dyn_cast<ExtractValueInst>(user)) {
  1902. DXASSERT(Extract->getNumIndices() == 1, "only support 1 index case");
  1903. unsigned index = Extract->getIndices()[0];
  1904. Value *Elt = Elts[index];
  1905. Extract->replaceAllUsesWith(Elt);
  1906. DeadInsts.emplace_back(Extract);
  1907. } else if (InsertValueInst *Insert = dyn_cast<InsertValueInst>(user)) {
  1908. DXASSERT(Insert->getNumIndices() == 1, "only support 1 index case");
  1909. unsigned index = Insert->getIndices()[0];
  1910. if (Insert->getAggregateOperand() == StructVal) {
  1911. // Update field.
  1912. std::vector<Value *> NewElts = Elts;
  1913. NewElts[index] = Insert->getInsertedValueOperand();
  1914. SimplifyStructValUsage(Insert, NewElts, DeadInsts);
  1915. } else {
  1916. // Insert to another bigger struct.
  1917. IRBuilder<> Builder(Insert);
  1918. Value *TmpStructVal = UndefValue::get(StructVal->getType());
  1919. for (unsigned i = 0; i < Elts.size(); i++) {
  1920. TmpStructVal =
  1921. Builder.CreateInsertValue(TmpStructVal, Elts[i], {i});
  1922. }
  1923. Insert->replaceUsesOfWith(StructVal, TmpStructVal);
  1924. }
  1925. }
  1926. }
  1927. }
  1928. /// RewriteForLoad - Replace OldVal with flattened NewElts in LoadInst.
  1929. void SROA_Helper::RewriteForLoad(LoadInst *LI) {
  1930. Type *LIType = LI->getType();
  1931. Type *ValTy = OldVal->getType()->getPointerElementType();
  1932. IRBuilder<> Builder(LI);
  1933. if (LIType->isVectorTy()) {
  1934. // Replace:
  1935. // %res = load { 2 x i32 }* %alloc
  1936. // with:
  1937. // %load.0 = load i32* %alloc.0
  1938. // %insert.0 insertvalue { 2 x i32 } zeroinitializer, i32 %load.0, 0
  1939. // %load.1 = load i32* %alloc.1
  1940. // %insert = insertvalue { 2 x i32 } %insert.0, i32 %load.1, 1
  1941. Value *Insert = UndefValue::get(LIType);
  1942. for (unsigned i = 0, e = NewElts.size(); i != e; ++i) {
  1943. Value *Load = Builder.CreateLoad(NewElts[i], "load");
  1944. Insert = Builder.CreateInsertElement(Insert, Load, i, "insert");
  1945. }
  1946. LI->replaceAllUsesWith(Insert);
  1947. } else if (isCompatibleAggregate(LIType, ValTy)) {
  1948. if (isVectorOrStructArray(LIType)) {
  1949. // Replace:
  1950. // %res = load [2 x <2 x float>] * %alloc
  1951. // with:
  1952. // %load.0 = load [4 x float]* %alloc.0
  1953. // %insert.0 insertvalue [4 x float] zeroinitializer,i32 %load.0,0
  1954. // %load.1 = load [4 x float]* %alloc.1
  1955. // %insert = insertvalue [4 x float] %insert.0, i32 %load.1, 1
  1956. // ...
  1957. Type *i32Ty = Type::getInt32Ty(LIType->getContext());
  1958. Value *zero = ConstantInt::get(i32Ty, 0);
  1959. SmallVector<Value *, 8> idxList;
  1960. idxList.emplace_back(zero);
  1961. Value *newLd =
  1962. LoadVectorOrStructArray(cast<ArrayType>(LIType), NewElts, idxList, Builder);
  1963. LI->replaceAllUsesWith(newLd);
  1964. } else {
  1965. // Replace:
  1966. // %res = load { i32, i32 }* %alloc
  1967. // with:
  1968. // %load.0 = load i32* %alloc.0
  1969. // %insert.0 insertvalue { i32, i32 } zeroinitializer, i32 %load.0,
  1970. // 0
  1971. // %load.1 = load i32* %alloc.1
  1972. // %insert = insertvalue { i32, i32 } %insert.0, i32 %load.1, 1
  1973. // (Also works for arrays instead of structs)
  1974. Module *M = LI->getModule();
  1975. Value *Insert = UndefValue::get(LIType);
  1976. std::vector<Value *> LdElts(NewElts.size());
  1977. for (unsigned i = 0, e = NewElts.size(); i != e; ++i) {
  1978. Value *Ptr = NewElts[i];
  1979. Type *Ty = Ptr->getType()->getPointerElementType();
  1980. Value *Load = nullptr;
  1981. if (!HLMatrixType::isa(Ty))
  1982. Load = Builder.CreateLoad(Ptr, "load");
  1983. else {
  1984. // Generate Matrix Load.
  1985. Load = HLModule::EmitHLOperationCall(
  1986. Builder, HLOpcodeGroup::HLMatLoadStore,
  1987. static_cast<unsigned>(HLMatLoadStoreOpcode::RowMatLoad), Ty,
  1988. {Ptr}, *M);
  1989. }
  1990. LdElts[i] = Load;
  1991. Insert = Builder.CreateInsertValue(Insert, Load, i, "insert");
  1992. }
  1993. LI->replaceAllUsesWith(Insert);
  1994. if (LIType->isStructTy()) {
  1995. SimplifyStructValUsage(Insert, LdElts, DeadInsts);
  1996. }
  1997. }
  1998. } else {
  1999. llvm_unreachable("other type don't need rewrite");
  2000. }
  2001. // Remove the use so that the caller can keep iterating over its other users
  2002. LI->setOperand(LI->getPointerOperandIndex(), UndefValue::get(LI->getPointerOperand()->getType()));
  2003. DeadInsts.push_back(LI);
  2004. }
  2005. /// RewriteForStore - Replace OldVal with flattened NewElts in StoreInst.
  2006. void SROA_Helper::RewriteForStore(StoreInst *SI) {
  2007. Value *Val = SI->getOperand(0);
  2008. Type *SIType = Val->getType();
  2009. IRBuilder<> Builder(SI);
  2010. Type *ValTy = OldVal->getType()->getPointerElementType();
  2011. if (SIType->isVectorTy()) {
  2012. // Replace:
  2013. // store <2 x float> %val, <2 x float>* %alloc
  2014. // with:
  2015. // %val.0 = extractelement { 2 x float } %val, 0
  2016. // store i32 %val.0, i32* %alloc.0
  2017. // %val.1 = extractelement { 2 x float } %val, 1
  2018. // store i32 %val.1, i32* %alloc.1
  2019. for (unsigned i = 0, e = NewElts.size(); i != e; ++i) {
  2020. Value *Extract = Builder.CreateExtractElement(Val, i, Val->getName());
  2021. Builder.CreateStore(Extract, NewElts[i]);
  2022. }
  2023. } else if (isCompatibleAggregate(SIType, ValTy)) {
  2024. if (isVectorOrStructArray(SIType)) {
  2025. // Replace:
  2026. // store [2 x <2 x i32>] %val, [2 x <2 x i32>]* %alloc, align 16
  2027. // with:
  2028. // %val.0 = extractvalue [2 x <2 x i32>] %val, 0
  2029. // %all0c.0.0 = getelementptr inbounds [2 x i32], [2 x i32]* %alloc.0,
  2030. // i32 0, i32 0
  2031. // %val.0.0 = extractelement <2 x i32> %243, i64 0
  2032. // store i32 %val.0.0, i32* %all0c.0.0
  2033. // %alloc.1.0 = getelementptr inbounds [2 x i32], [2 x i32]* %alloc.1,
  2034. // i32 0, i32 0
  2035. // %val.0.1 = extractelement <2 x i32> %243, i64 1
  2036. // store i32 %val.0.1, i32* %alloc.1.0
  2037. // %val.1 = extractvalue [2 x <2 x i32>] %val, 1
  2038. // %alloc.0.0 = getelementptr inbounds [2 x i32], [2 x i32]* %alloc.0,
  2039. // i32 0, i32 1
  2040. // %val.1.0 = extractelement <2 x i32> %248, i64 0
  2041. // store i32 %val.1.0, i32* %alloc.0.0
  2042. // %all0c.1.1 = getelementptr inbounds [2 x i32], [2 x i32]* %alloc.1,
  2043. // i32 0, i32 1
  2044. // %val.1.1 = extractelement <2 x i32> %248, i64 1
  2045. // store i32 %val.1.1, i32* %all0c.1.1
  2046. ArrayType *AT = cast<ArrayType>(SIType);
  2047. Type *i32Ty = Type::getInt32Ty(SIType->getContext());
  2048. Value *zero = ConstantInt::get(i32Ty, 0);
  2049. SmallVector<Value *, 8> idxList;
  2050. idxList.emplace_back(zero);
  2051. StoreVectorOrStructArray(AT, Val, NewElts, idxList, Builder);
  2052. } else {
  2053. // Replace:
  2054. // store { i32, i32 } %val, { i32, i32 }* %alloc
  2055. // with:
  2056. // %val.0 = extractvalue { i32, i32 } %val, 0
  2057. // store i32 %val.0, i32* %alloc.0
  2058. // %val.1 = extractvalue { i32, i32 } %val, 1
  2059. // store i32 %val.1, i32* %alloc.1
  2060. // (Also works for arrays instead of structs)
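// (Illustrative note: when an element is an HLSL matrix, the plain store
// below is replaced with an HLMatLoadStore RowMatStore call instead of a
// StoreInst.)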
  2061. Module *M = SI->getModule();
  2062. for (unsigned i = 0, e = NewElts.size(); i != e; ++i) {
  2063. Value *Extract = Builder.CreateExtractValue(Val, i, Val->getName());
  2064. if (!HLMatrixType::isa(Extract->getType())) {
  2065. Builder.CreateStore(Extract, NewElts[i]);
  2066. } else {
  2067. // Generate Matrix Store.
  2068. HLModule::EmitHLOperationCall(
  2069. Builder, HLOpcodeGroup::HLMatLoadStore,
  2070. static_cast<unsigned>(HLMatLoadStoreOpcode::RowMatStore),
  2071. Extract->getType(), {NewElts[i], Extract}, *M);
  2072. }
  2073. }
  2074. }
  2075. } else {
llvm_unreachable("other types don't need rewrite");
  2077. }
  2078. // Remove the use so that the caller can keep iterating over its other users
  2079. SI->setOperand(SI->getPointerOperandIndex(), UndefValue::get(SI->getPointerOperand()->getType()));
  2080. DeadInsts.push_back(SI);
  2081. }
  2082. /// RewriteMemIntrin - MI is a memcpy/memset/memmove from or to AI.
  2083. /// Rewrite it to copy or set the elements of the scalarized memory.
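/// Illustrative sketch (hypothetical IR, not generated verbatim): a memcpy
/// covering a { i32, float } alloca that was split into %a.0 and %a.1 becomes
/// one transfer per element, e.g.
///   %other.0 = getelementptr inbounds { i32, float }, { i32, float }* %other, i32 0, i32 0
///   %tmp0 = load i32, i32* %other.0
///   store i32 %tmp0, i32* %a.0
/// and similarly for %a.1; non-scalar elements get a smaller per-element
/// memcpy/memset instead of a load/store pair.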
  2084. void SROA_Helper::RewriteMemIntrin(MemIntrinsic *MI, Value *OldV) {
  2085. // If this is a memcpy/memmove, construct the other pointer as the
  2086. // appropriate type. The "Other" pointer is the pointer that goes to memory
  2087. // that doesn't have anything to do with the alloca that we are promoting. For
  2088. // memset, this Value* stays null.
  2089. Value *OtherPtr = nullptr;
  2090. unsigned MemAlignment = MI->getAlignment();
  2091. if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) { // memmove/memcopy
  2092. if (OldV == MTI->getRawDest())
  2093. OtherPtr = MTI->getRawSource();
  2094. else {
  2095. assert(OldV == MTI->getRawSource());
  2096. OtherPtr = MTI->getRawDest();
  2097. }
  2098. }
  2099. // If there is an other pointer, we want to convert it to the same pointer
  2100. // type as AI has, so we can GEP through it safely.
  2101. if (OtherPtr) {
  2102. unsigned AddrSpace =
  2103. cast<PointerType>(OtherPtr->getType())->getAddressSpace();
  2104. // Remove bitcasts and all-zero GEPs from OtherPtr. This is an
  2105. // optimization, but it's also required to detect the corner case where
  2106. // both pointer operands are referencing the same memory, and where
// OtherPtr may be a bitcast or GEP that is currently being rewritten. (This
  2108. // function is only called for mem intrinsics that access the whole
  2109. // aggregate, so non-zero GEPs are not an issue here.)
  2110. OtherPtr = OtherPtr->stripPointerCasts();
  2111. // Copying the alloca to itself is a no-op: just delete it.
  2112. if (OtherPtr == OldVal || OtherPtr == NewElts[0]) {
  2113. // This code will run twice for a no-op memcpy -- once for each operand.
  2114. // Put only one reference to MI on the DeadInsts list.
  2115. for (SmallVectorImpl<Value *>::const_iterator I = DeadInsts.begin(),
  2116. E = DeadInsts.end();
  2117. I != E; ++I)
  2118. if (*I == MI)
  2119. return;
  2120. // Remove the uses so that the caller can keep iterating over its other users
  2121. MI->setOperand(0, UndefValue::get(MI->getOperand(0)->getType()));
  2122. MI->setOperand(1, UndefValue::get(MI->getOperand(1)->getType()));
  2123. DeadInsts.push_back(MI);
  2124. return;
  2125. }
  2126. // If the pointer is not the right type, insert a bitcast to the right
  2127. // type.
  2128. Type *NewTy =
  2129. PointerType::get(OldVal->getType()->getPointerElementType(), AddrSpace);
  2130. if (OtherPtr->getType() != NewTy)
  2131. OtherPtr = new BitCastInst(OtherPtr, NewTy, OtherPtr->getName(), MI);
  2132. }
  2133. // Process each element of the aggregate.
  2134. bool SROADest = MI->getRawDest() == OldV;
  2135. Constant *Zero = Constant::getNullValue(Type::getInt32Ty(MI->getContext()));
  2136. const DataLayout &DL = MI->getModule()->getDataLayout();
  2137. for (unsigned i = 0, e = NewElts.size(); i != e; ++i) {
  2138. // If this is a memcpy/memmove, emit a GEP of the other element address.
  2139. Value *OtherElt = nullptr;
  2140. unsigned OtherEltAlign = MemAlignment;
  2141. if (OtherPtr) {
  2142. Value *Idx[2] = {Zero,
  2143. ConstantInt::get(Type::getInt32Ty(MI->getContext()), i)};
  2144. OtherElt = GetElementPtrInst::CreateInBounds(
  2145. OtherPtr, Idx, OtherPtr->getName() + "." + Twine(i), MI);
  2146. uint64_t EltOffset;
  2147. PointerType *OtherPtrTy = cast<PointerType>(OtherPtr->getType());
  2148. Type *OtherTy = OtherPtrTy->getElementType();
  2149. if (StructType *ST = dyn_cast<StructType>(OtherTy)) {
  2150. EltOffset = DL.getStructLayout(ST)->getElementOffset(i);
  2151. } else {
  2152. Type *EltTy = cast<SequentialType>(OtherTy)->getElementType();
  2153. EltOffset = DL.getTypeAllocSize(EltTy) * i;
  2154. }
  2155. // The alignment of the other pointer is the guaranteed alignment of the
  2156. // element, which is affected by both the known alignment of the whole
  2157. // mem intrinsic and the alignment of the element. If the alignment of
// the memcpy, for example, is 32 but the element is at a 4-byte offset, then the
  2159. // known alignment is just 4 bytes.
  2160. OtherEltAlign = (unsigned)MinAlign(OtherEltAlign, EltOffset);
  2161. }
  2162. Value *EltPtr = NewElts[i];
  2163. Type *EltTy = cast<PointerType>(EltPtr->getType())->getElementType();
  2164. // If we got down to a scalar, insert a load or store as appropriate.
  2165. if (EltTy->isSingleValueType()) {
  2166. if (isa<MemTransferInst>(MI)) {
  2167. if (SROADest) {
  2168. // From Other to Alloca.
  2169. Value *Elt = new LoadInst(OtherElt, "tmp", false, OtherEltAlign, MI);
  2170. new StoreInst(Elt, EltPtr, MI);
  2171. } else {
  2172. // From Alloca to Other.
  2173. Value *Elt = new LoadInst(EltPtr, "tmp", MI);
  2174. new StoreInst(Elt, OtherElt, false, OtherEltAlign, MI);
  2175. }
  2176. continue;
  2177. }
  2178. assert(isa<MemSetInst>(MI));
  2179. // If the stored element is zero (common case), just store a null
  2180. // constant.
  2181. Constant *StoreVal;
  2182. if (ConstantInt *CI = dyn_cast<ConstantInt>(MI->getArgOperand(1))) {
  2183. if (CI->isZero()) {
  2184. StoreVal = Constant::getNullValue(EltTy); // 0.0, null, 0, <0,0>
  2185. } else {
  2186. // If EltTy is a vector type, get the element type.
  2187. Type *ValTy = EltTy->getScalarType();
  2188. // Construct an integer with the right value.
  2189. unsigned EltSize = DL.getTypeSizeInBits(ValTy);
  2190. APInt OneVal(EltSize, CI->getZExtValue());
  2191. APInt TotalVal(OneVal);
  2192. // Set each byte.
  2193. for (unsigned i = 0; 8 * i < EltSize; ++i) {
  2194. TotalVal = TotalVal.shl(8);
  2195. TotalVal |= OneVal;
  2196. }
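// For example, a memset byte value of 0x01 over an i32 element produces the
// splat constant 0x01010101 at this point (illustrative).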
  2197. // Convert the integer value to the appropriate type.
  2198. StoreVal = ConstantInt::get(CI->getContext(), TotalVal);
  2199. if (ValTy->isPointerTy())
  2200. StoreVal = ConstantExpr::getIntToPtr(StoreVal, ValTy);
  2201. else if (ValTy->isFloatingPointTy())
  2202. StoreVal = ConstantExpr::getBitCast(StoreVal, ValTy);
  2203. assert(StoreVal->getType() == ValTy && "Type mismatch!");
  2204. // If the requested value was a vector constant, create it.
  2205. if (EltTy->isVectorTy()) {
  2206. unsigned NumElts = cast<VectorType>(EltTy)->getNumElements();
  2207. StoreVal = ConstantVector::getSplat(NumElts, StoreVal);
  2208. }
  2209. }
  2210. new StoreInst(StoreVal, EltPtr, MI);
  2211. continue;
  2212. }
  2213. // Otherwise, if we're storing a byte variable, use a memset call for
  2214. // this element.
  2215. }
  2216. unsigned EltSize = DL.getTypeAllocSize(EltTy);
  2217. if (!EltSize)
  2218. continue;
  2219. IRBuilder<> Builder(MI);
  2220. // Finally, insert the meminst for this element.
  2221. if (isa<MemSetInst>(MI)) {
  2222. Builder.CreateMemSet(EltPtr, MI->getArgOperand(1), EltSize,
  2223. MI->isVolatile());
  2224. } else {
  2225. assert(isa<MemTransferInst>(MI));
  2226. Value *Dst = SROADest ? EltPtr : OtherElt; // Dest ptr
  2227. Value *Src = SROADest ? OtherElt : EltPtr; // Src ptr
  2228. if (isa<MemCpyInst>(MI))
  2229. Builder.CreateMemCpy(Dst, Src, EltSize, OtherEltAlign,
  2230. MI->isVolatile());
  2231. else
  2232. Builder.CreateMemMove(Dst, Src, EltSize, OtherEltAlign,
  2233. MI->isVolatile());
  2234. }
  2235. }
  2236. // Remove the use so that the caller can keep iterating over its other users
  2237. MI->setOperand(0, UndefValue::get(MI->getOperand(0)->getType()));
  2238. if (isa<MemTransferInst>(MI))
  2239. MI->setOperand(1, UndefValue::get(MI->getOperand(1)->getType()));
  2240. DeadInsts.push_back(MI);
  2241. }
  2242. void SROA_Helper::RewriteBitCast(BitCastInst *BCI) {
  2243. // Unused bitcast may be leftover from temporary memcpy
  2244. if (BCI->use_empty()) {
  2245. BCI->eraseFromParent();
  2246. return;
  2247. }
  2248. Type *DstTy = BCI->getType();
  2249. Value *Val = BCI->getOperand(0);
  2250. Type *SrcTy = Val->getType();
  2251. if (!DstTy->isPointerTy()) {
  2252. assert(0 && "Type mismatch.");
  2253. return;
  2254. }
  2255. if (!SrcTy->isPointerTy()) {
  2256. assert(0 && "Type mismatch.");
  2257. return;
  2258. }
  2259. DstTy = DstTy->getPointerElementType();
  2260. SrcTy = SrcTy->getPointerElementType();
  2261. if (!DstTy->isStructTy()) {
// This bitcast is only used by llvm.lifetime.* intrinsics. Replace it with a bitcast of each element.
  2263. SmallVector<IntrinsicInst*, 16> ToReplace;
  2264. DXASSERT(onlyUsedByLifetimeMarkers(BCI),
  2265. "expected struct bitcast to only be used by lifetime intrinsics");
  2266. for (User *User : BCI->users()) {
  2267. IntrinsicInst *Intrin = cast<IntrinsicInst>(User);
  2268. ToReplace.push_back(Intrin);
  2269. }
  2270. const DataLayout &DL = BCI->getModule()->getDataLayout();
  2271. for (IntrinsicInst *Intrin : ToReplace) {
  2272. IRBuilder<> Builder(Intrin);
  2273. for (Value *Elt : NewElts) {
  2274. assert(Elt->getType()->isPointerTy());
  2275. Type *ElPtrTy = Elt->getType();
  2276. Type *ElTy = ElPtrTy->getPointerElementType();
  2277. Value *SizeV = Builder.getInt64( DL.getTypeAllocSize(ElTy) );
  2278. Value *Ptr = Builder.CreateBitCast(Elt, Builder.getInt8PtrTy());
  2279. Value *Args[] = {SizeV, Ptr};
  2280. CallInst *C = Builder.CreateCall(Intrin->getCalledFunction(), Args);
  2281. C->setDoesNotThrow();
  2282. }
  2283. assert(Intrin->use_empty());
  2284. Intrin->eraseFromParent();
  2285. }
  2286. assert(BCI->use_empty());
  2287. BCI->eraseFromParent();
  2288. return;
  2289. }
  2290. if (!SrcTy->isStructTy()) {
  2291. assert(0 && "Type mismatch.");
  2292. return;
  2293. }
  2294. // Only support bitcast to parent struct type.
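// Illustrative example (hypothetical types): a bitcast from %struct.Derived*
// to %struct.Base*, where Base is reachable by repeatedly taking element 0 of
// Derived, is rewritten below as an all-zero-index GEP into that Base
// sub-object.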
  2295. StructType *DstST = cast<StructType>(DstTy);
  2296. StructType *SrcST = cast<StructType>(SrcTy);
  2297. bool bTypeMatch = false;
  2298. unsigned level = 0;
  2299. while (SrcST) {
  2300. level++;
  2301. Type *EltTy = SrcST->getElementType(0);
  2302. if (EltTy == DstST) {
  2303. bTypeMatch = true;
  2304. break;
  2305. }
  2306. SrcST = dyn_cast<StructType>(EltTy);
  2307. }
  2308. if (!bTypeMatch) {
  2309. // If the layouts match, just replace the type
  2310. SrcST = cast<StructType>(SrcTy);
  2311. if (SrcST->isLayoutIdentical(DstST)) {
  2312. BCI->mutateType(Val->getType());
  2313. BCI->replaceAllUsesWith(Val);
  2314. BCI->eraseFromParent();
  2315. return;
  2316. }
  2317. assert(0 && "Type mismatch.");
  2318. return;
  2319. }
  2320. std::vector<Value*> idxList(level+1);
  2321. ConstantInt *zeroIdx = ConstantInt::get(Type::getInt32Ty(Val->getContext()), 0);
  2322. for (unsigned i=0;i<(level+1);i++)
  2323. idxList[i] = zeroIdx;
  2324. IRBuilder<> Builder(BCI);
  2325. Builder.AllowFolding = false; // We need an Instruction, so make sure we don't get a constant
  2326. Instruction *GEP = cast<Instruction>(Builder.CreateInBoundsGEP(Val, idxList));
  2327. BCI->replaceAllUsesWith(GEP);
  2328. BCI->eraseFromParent();
  2329. IRBuilder<> GEPBuilder(GEP);
  2330. RewriteForGEP(cast<GEPOperator>(GEP), GEPBuilder);
  2331. }
/// RewriteCallArg - For functions that are not flattened,
/// replace OldVal with an alloca and
/// copy data in and out between the alloca and the flattened NewElts
/// around the CallInst.
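/// Illustrative shape of the rewrite (hypothetical names):
///   %tmp = alloca %T
///   memcpy(%tmp, %arg)        ; copy-in when bIn
///   call @f(..., %T* %tmp, ...)
///   memcpy(%arg, %tmp)        ; copy-out when bOut
/// Both memcpys are then expanded over the flattened NewElts via RewriteMemIntrin.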
  2336. void SROA_Helper::RewriteCallArg(CallInst *CI, unsigned ArgIdx, bool bIn,
  2337. bool bOut) {
  2338. Function *F = CI->getParent()->getParent();
  2339. IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(F));
  2340. const DataLayout &DL = F->getParent()->getDataLayout();
  2341. Value *userTyV = CI->getArgOperand(ArgIdx);
  2342. PointerType *userTy = cast<PointerType>(userTyV->getType());
  2343. Type *userTyElt = userTy->getElementType();
  2344. Value *Alloca = AllocaBuilder.CreateAlloca(userTyElt);
  2345. IRBuilder<> Builder(CI);
  2346. if (bIn) {
  2347. MemCpyInst *cpy = cast<MemCpyInst>(Builder.CreateMemCpy(
  2348. Alloca, userTyV, DL.getTypeAllocSize(userTyElt), false));
  2349. RewriteMemIntrin(cpy, cpy->getRawSource());
  2350. }
  2351. CI->setArgOperand(ArgIdx, Alloca);
  2352. if (bOut) {
  2353. Builder.SetInsertPoint(CI->getNextNode());
  2354. MemCpyInst *cpy = cast<MemCpyInst>(Builder.CreateMemCpy(
  2355. userTyV, Alloca, DL.getTypeAllocSize(userTyElt), false));
  2356. RewriteMemIntrin(cpy, cpy->getRawSource());
  2357. }
  2358. }
  2359. // Flatten matching OldVal arg to NewElts, optionally loading values (loadElts).
  2360. // Does not replace or clean up old CallInst.
  2361. static CallInst *CreateFlattenedHLIntrinsicCall(
  2362. CallInst *CI, Value* OldVal, ArrayRef<Value*> NewElts, bool loadElts) {
  2363. HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
  2364. Function *F = CI->getCalledFunction();
  2365. DXASSERT_NOMSG(group == HLOpcodeGroup::HLIntrinsic);
  2366. unsigned opcode = GetHLOpcode(CI);
  2367. IRBuilder<> Builder(CI);
  2368. SmallVector<Value *, 4> flatArgs;
  2369. for (Value *arg : CI->arg_operands()) {
  2370. if (arg == OldVal) {
  2371. for (Value *Elt : NewElts) {
  2372. if (loadElts && Elt->getType()->isPointerTy())
  2373. Elt = Builder.CreateLoad(Elt);
  2374. flatArgs.emplace_back(Elt);
  2375. }
  2376. } else
  2377. flatArgs.emplace_back(arg);
  2378. }
  2379. SmallVector<Type *, 4> flatParamTys;
  2380. for (Value *arg : flatArgs)
  2381. flatParamTys.emplace_back(arg->getType());
  2382. FunctionType *flatFuncTy =
  2383. FunctionType::get(CI->getType(), flatParamTys, false);
  2384. Function *flatF =
  2385. GetOrCreateHLFunction(*F->getParent(), flatFuncTy, group, opcode,
  2386. F->getAttributes().getFnAttributes());
  2387. return Builder.CreateCall(flatF, flatArgs);
  2388. }
  2389. static CallInst *RewriteWithFlattenedHLIntrinsicCall(
  2390. CallInst *CI, Value* OldVal, ArrayRef<Value*> NewElts, bool loadElts) {
  2391. CallInst *flatCI = CreateFlattenedHLIntrinsicCall(
  2392. CI, OldVal, NewElts, /*loadElts*/loadElts);
  2393. CI->replaceAllUsesWith(flatCI);
  2394. // Clear CI operands so we don't try to translate old call again
  2395. for (auto& opit : CI->operands())
  2396. opit.set(UndefValue::get(opit->getType()));
  2397. return flatCI;
  2398. }
  2399. /// RewriteCall - Replace OldVal with flattened NewElts in CallInst.
  2400. void SROA_Helper::RewriteCall(CallInst *CI) {
  2401. HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
  2402. if (group != HLOpcodeGroup::NotHL) {
  2403. unsigned opcode = GetHLOpcode(CI);
  2404. if (group == HLOpcodeGroup::HLIntrinsic) {
  2405. IntrinsicOp IOP = static_cast<IntrinsicOp>(opcode);
  2406. switch (IOP) {
  2407. case IntrinsicOp::MOP_Append: {
// Buffer Append is already expanded in code gen.
  2409. // Must be OutputStream Append here.
  2410. // Every Elt has a pointer type.
  2411. // For Append, this is desired, so don't load.
  2412. RewriteWithFlattenedHLIntrinsicCall(CI, OldVal, NewElts, /*loadElts*/false);
  2413. DeadInsts.push_back(CI);
  2414. } break;
  2415. case IntrinsicOp::IOP_TraceRay: {
  2416. if (OldVal ==
  2417. CI->getArgOperand(HLOperandIndex::kTraceRayRayDescOpIdx)) {
  2418. RewriteCallArg(CI, HLOperandIndex::kTraceRayRayDescOpIdx,
  2419. /*bIn*/ true, /*bOut*/ false);
  2420. } else {
  2421. DXASSERT(OldVal ==
  2422. CI->getArgOperand(HLOperandIndex::kTraceRayPayLoadOpIdx),
  2423. "else invalid TraceRay");
  2424. RewriteCallArg(CI, HLOperandIndex::kTraceRayPayLoadOpIdx,
  2425. /*bIn*/ true, /*bOut*/ true);
  2426. }
  2427. } break;
  2428. case IntrinsicOp::IOP_ReportHit: {
  2429. RewriteCallArg(CI, HLOperandIndex::kReportIntersectionAttributeOpIdx,
  2430. /*bIn*/ true, /*bOut*/ false);
  2431. } break;
  2432. case IntrinsicOp::IOP_CallShader: {
  2433. RewriteCallArg(CI, HLOperandIndex::kCallShaderPayloadOpIdx,
  2434. /*bIn*/ true, /*bOut*/ true);
  2435. } break;
  2436. case IntrinsicOp::MOP_TraceRayInline: {
  2437. if (OldVal ==
  2438. CI->getArgOperand(HLOperandIndex::kTraceRayInlineRayDescOpIdx)) {
  2439. RewriteWithFlattenedHLIntrinsicCall(CI, OldVal, NewElts, /*loadElts*/true);
  2440. DeadInsts.push_back(CI);
  2441. break;
  2442. }
  2443. }
  2444. __fallthrough;
  2445. default:
  2446. // RayQuery this pointer replacement.
  2447. if (OldVal->getType()->isPointerTy() &&
  2448. CI->getNumArgOperands() >= HLOperandIndex::kHandleOpIdx &&
  2449. OldVal == CI->getArgOperand(HLOperandIndex::kHandleOpIdx) &&
  2450. dxilutil::IsHLSLRayQueryType(
  2451. OldVal->getType()->getPointerElementType())) {
  2452. // For RayQuery methods, we want to replace the RayQuery this pointer
  2453. // with a load and use of the underlying handle value.
  2454. // This will allow elimination of RayQuery types earlier.
  2455. RewriteWithFlattenedHLIntrinsicCall(CI, OldVal, NewElts, /*loadElts*/true);
  2456. DeadInsts.push_back(CI);
  2457. break;
  2458. }
  2459. DXASSERT(0, "cannot flatten hlsl intrinsic.");
  2460. }
  2461. }
// TODO: check other high-level dx operations if needed.
  2463. } else {
DXASSERT(0, "should have been done at inlining");
  2465. }
  2466. }
  2467. /// RewriteForAddrSpaceCast - Rewrite the AddrSpaceCast, either ConstExpr or Inst.
  2468. void SROA_Helper::RewriteForAddrSpaceCast(Value *CE,
  2469. IRBuilder<> &Builder) {
  2470. SmallVector<Value *, 8> NewCasts;
  2471. // create new AddrSpaceCast.
  2472. for (unsigned i = 0, e = NewElts.size(); i != e; ++i) {
  2473. Value *NewCast = Builder.CreateAddrSpaceCast(
  2474. NewElts[i],
  2475. PointerType::get(NewElts[i]->getType()->getPointerElementType(),
  2476. CE->getType()->getPointerAddressSpace()));
  2477. NewCasts.emplace_back(NewCast);
  2478. }
  2479. SROA_Helper helper(CE, NewCasts, DeadInsts, typeSys, DL, DT);
  2480. helper.RewriteForScalarRepl(CE, Builder);
  2481. // Remove the use so that the caller can keep iterating over its other users
  2482. DXASSERT(CE->user_empty(), "All uses of the addrspacecast should have been eliminated");
  2483. if (Instruction *I = dyn_cast<Instruction>(CE))
  2484. I->eraseFromParent();
  2485. else
  2486. cast<Constant>(CE)->destroyConstant();
  2487. }
  2488. /// RewriteForConstExpr - Rewrite the GEP which is ConstantExpr.
  2489. void SROA_Helper::RewriteForConstExpr(ConstantExpr *CE, IRBuilder<> &Builder) {
  2490. if (GEPOperator *GEP = dyn_cast<GEPOperator>(CE)) {
  2491. if (OldVal == GEP->getPointerOperand()) {
  2492. // Flatten GEP.
  2493. RewriteForGEP(GEP, Builder);
  2494. return;
  2495. }
  2496. }
  2497. if (CE->getOpcode() == Instruction::AddrSpaceCast) {
  2498. if (OldVal == CE->getOperand(0)) {
  2499. // Flatten AddrSpaceCast.
  2500. RewriteForAddrSpaceCast(CE, Builder);
  2501. return;
  2502. }
  2503. }
  2504. for (Value::use_iterator UI = CE->use_begin(), E = CE->use_end(); UI != E;) {
  2505. Use &TheUse = *UI++;
  2506. if (Instruction *I = dyn_cast<Instruction>(TheUse.getUser())) {
  2507. IRBuilder<> tmpBuilder(I);
  2508. // Replace CE with constInst.
  2509. Instruction *tmpInst = CE->getAsInstruction();
  2510. tmpBuilder.Insert(tmpInst);
  2511. TheUse.set(tmpInst);
  2512. }
  2513. else {
  2514. RewriteForConstExpr(cast<ConstantExpr>(TheUse.getUser()), Builder);
  2515. }
  2516. }
  2517. // Remove the use so that the caller can keep iterating over its other users
  2518. DXASSERT(CE->user_empty(), "All uses of the constantexpr should have been eliminated");
  2519. CE->destroyConstant();
  2520. }
  2521. /// RewriteForScalarRepl - OldVal is being split into NewElts, so rewrite
  2522. /// users of V, which references it, to use the separate elements.
  2523. void SROA_Helper::RewriteForScalarRepl(Value *V, IRBuilder<> &Builder) {
  2524. // Don't iterate upon the uses explicitly because we'll be removing them,
  2525. // and potentially adding new ones (if expanding memcpys) during the iteration.
  2526. Use* PrevUse = nullptr;
  2527. while (!V->use_empty()) {
  2528. Use &TheUse = *V->use_begin();
  2529. DXASSERT_LOCALVAR(PrevUse, &TheUse != PrevUse,
  2530. "Infinite loop while SROA'ing value, use isn't getting eliminated.");
  2531. PrevUse = &TheUse;
  2532. // Each of these must either call ->eraseFromParent()
  2533. // or null out the use of V so that we make progress.
  2534. if (ConstantExpr *CE = dyn_cast<ConstantExpr>(TheUse.getUser())) {
  2535. RewriteForConstExpr(CE, Builder);
  2536. }
  2537. else {
  2538. Instruction *User = cast<Instruction>(TheUse.getUser());
  2539. if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(User)) {
  2540. IRBuilder<> Builder(GEP);
  2541. RewriteForGEP(cast<GEPOperator>(GEP), Builder);
  2542. } else if (LoadInst *ldInst = dyn_cast<LoadInst>(User))
  2543. RewriteForLoad(ldInst);
  2544. else if (StoreInst *stInst = dyn_cast<StoreInst>(User))
  2545. RewriteForStore(stInst);
  2546. else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(User))
  2547. RewriteMemIntrin(MI, V);
  2548. else if (CallInst *CI = dyn_cast<CallInst>(User))
  2549. RewriteCall(CI);
  2550. else if (BitCastInst *BCI = dyn_cast<BitCastInst>(User))
  2551. RewriteBitCast(BCI);
  2552. else if (AddrSpaceCastInst *CI = dyn_cast<AddrSpaceCastInst>(User)) {
  2553. RewriteForAddrSpaceCast(CI, Builder);
  2554. } else {
assert(0 && "not supported.");
  2556. }
  2557. }
  2558. }
  2559. }
  2560. static ArrayType *CreateNestArrayTy(Type *FinalEltTy,
  2561. ArrayRef<ArrayType *> nestArrayTys) {
  2562. Type *newAT = FinalEltTy;
  2563. for (auto ArrayTy = nestArrayTys.rbegin(), E=nestArrayTys.rend(); ArrayTy != E;
  2564. ++ArrayTy)
  2565. newAT = ArrayType::get(newAT, (*ArrayTy)->getNumElements());
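// For example (illustrative), nestArrayTys = { [4 x [2 x T]], [2 x T] } with
// FinalEltTy == float yields [4 x [2 x float]].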
  2566. return cast<ArrayType>(newAT);
  2567. }
  2568. /// DoScalarReplacement - Split V into AllocaInsts with Builder and save the new AllocaInsts into Elts.
  2569. /// Then do SROA on V.
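/// For example (illustrative): an alloca of a struct { float; <2 x float>; }
/// is split into one alloca per field, and an alloca of [4 x <2 x float>]
/// (with bFlatVector) is split into two [4 x float] allocas, one per vector
/// component.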
  2570. bool SROA_Helper::DoScalarReplacement(Value *V, std::vector<Value *> &Elts,
  2571. Type *&BrokenUpTy, uint64_t &NumInstances,
  2572. IRBuilder<> &Builder, bool bFlatVector,
  2573. bool hasPrecise, DxilTypeSystem &typeSys,
  2574. const DataLayout &DL,
  2575. SmallVector<Value *, 32> &DeadInsts,
  2576. DominatorTree *DT) {
  2577. DEBUG(dbgs() << "Found inst to SROA: " << *V << '\n');
  2578. Type *Ty = V->getType();
// Skip non-pointer types.
  2580. if (!Ty->isPointerTy())
  2581. return false;
  2582. Ty = Ty->getPointerElementType();
// Skip non-aggregate types.
  2584. if (!Ty->isAggregateType())
  2585. return false;
  2586. // Skip matrix types.
  2587. if (HLMatrixType::isa(Ty))
  2588. return false;
  2589. IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(Builder.GetInsertPoint()));
  2590. if (StructType *ST = dyn_cast<StructType>(Ty)) {
  2591. // Skip HLSL object types and RayQuery.
  2592. if (dxilutil::IsHLSLObjectType(ST)) {
  2593. return false;
  2594. }
  2595. BrokenUpTy = ST;
  2596. NumInstances = 1;
  2597. unsigned numTypes = ST->getNumContainedTypes();
  2598. Elts.reserve(numTypes);
  2599. DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ST);
  2600. // Skip empty struct.
  2601. if (SA && SA->IsEmptyStruct())
  2602. return true;
  2603. for (int i = 0, e = numTypes; i != e; ++i) {
  2604. AllocaInst *NA = AllocaBuilder.CreateAlloca(ST->getContainedType(i), nullptr, V->getName() + "." + Twine(i));
  2605. bool markPrecise = hasPrecise;
  2606. if (SA) {
  2607. DxilFieldAnnotation &FA = SA->GetFieldAnnotation(i);
  2608. markPrecise |= FA.IsPrecise();
  2609. }
  2610. if (markPrecise)
  2611. HLModule::MarkPreciseAttributeWithMetadata(NA);
  2612. Elts.push_back(NA);
  2613. }
  2614. } else {
  2615. ArrayType *AT = cast<ArrayType>(Ty);
  2616. if (AT->getNumContainedTypes() == 0) {
  2617. // Skip case like [0 x %struct].
  2618. return false;
  2619. }
  2620. Type *ElTy = AT->getElementType();
  2621. SmallVector<ArrayType *, 4> nestArrayTys;
  2622. nestArrayTys.emplace_back(AT);
  2623. NumInstances = AT->getNumElements();
// Support multiple levels of arrays.
  2625. while (ElTy->isArrayTy()) {
  2626. ArrayType *ElAT = cast<ArrayType>(ElTy);
  2627. nestArrayTys.emplace_back(ElAT);
  2628. NumInstances *= ElAT->getNumElements();
  2629. ElTy = ElAT->getElementType();
  2630. }
  2631. BrokenUpTy = ElTy;
  2632. if (ElTy->isStructTy() &&
  2633. // Skip Matrix type.
  2634. !HLMatrixType::isa(ElTy)) {
  2635. if (!dxilutil::IsHLSLObjectType(ElTy)) {
  2636. // for array of struct
  2637. // split into arrays of struct elements
  2638. StructType *ElST = cast<StructType>(ElTy);
  2639. unsigned numTypes = ElST->getNumContainedTypes();
  2640. Elts.reserve(numTypes);
  2641. DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ElST);
  2642. // Skip empty struct.
  2643. if (SA && SA->IsEmptyStruct())
  2644. return true;
  2645. for (int i = 0, e = numTypes; i != e; ++i) {
  2646. AllocaInst *NA = AllocaBuilder.CreateAlloca(
  2647. CreateNestArrayTy(ElST->getContainedType(i), nestArrayTys),
  2648. nullptr, V->getName() + "." + Twine(i));
  2649. bool markPrecise = hasPrecise;
  2650. if (SA) {
  2651. DxilFieldAnnotation &FA = SA->GetFieldAnnotation(i);
  2652. markPrecise |= FA.IsPrecise();
  2653. }
  2654. if (markPrecise)
  2655. HLModule::MarkPreciseAttributeWithMetadata(NA);
  2656. Elts.push_back(NA);
  2657. }
  2658. } else {
// For a local resource array that has no dynamic indexing,
// split it.
  2661. if (dxilutil::HasDynamicIndexing(V) ||
// Only a 1-dimensional split is supported.
  2663. nestArrayTys.size() > 1)
  2664. return false;
  2665. BrokenUpTy = AT;
  2666. NumInstances = 1;
  2667. for (int i = 0, e = AT->getNumElements(); i != e; ++i) {
  2668. AllocaInst *NA = AllocaBuilder.CreateAlloca(ElTy, nullptr,
  2669. V->getName() + "." + Twine(i));
  2670. Elts.push_back(NA);
  2671. }
  2672. }
  2673. } else if (ElTy->isVectorTy()) {
  2674. // Skip vector if required.
  2675. if (!bFlatVector)
  2676. return false;
  2677. // for array of vector
  2678. // split into arrays of scalar
  2679. VectorType *ElVT = cast<VectorType>(ElTy);
  2680. BrokenUpTy = ElVT;
  2681. Elts.reserve(ElVT->getNumElements());
  2682. ArrayType *scalarArrayTy = CreateNestArrayTy(ElVT->getElementType(), nestArrayTys);
  2683. for (int i = 0, e = ElVT->getNumElements(); i != e; ++i) {
  2684. AllocaInst *NA = AllocaBuilder.CreateAlloca(scalarArrayTy, nullptr,
  2685. V->getName() + "." + Twine(i));
  2686. if (hasPrecise)
  2687. HLModule::MarkPreciseAttributeWithMetadata(NA);
  2688. Elts.push_back(NA);
  2689. }
  2690. } else
  2691. // Skip array of basic types.
  2692. return false;
  2693. }
  2694. // Now that we have created the new alloca instructions, rewrite all the
  2695. // uses of the old alloca.
  2696. SROA_Helper helper(V, Elts, DeadInsts, typeSys, DL, DT);
  2697. helper.RewriteForScalarRepl(V, Builder);
  2698. return true;
  2699. }
  2700. static Constant *GetEltInit(Type *Ty, Constant *Init, unsigned idx,
  2701. Type *EltTy) {
  2702. if (isa<UndefValue>(Init))
  2703. return UndefValue::get(EltTy);
if (isa<StructType>(Ty) || isa<VectorType>(Ty)) {
return Init->getAggregateElement(idx);
} else {
  2709. ArrayType *AT = cast<ArrayType>(Ty);
  2710. ArrayType *EltArrayTy = cast<ArrayType>(EltTy);
  2711. std::vector<Constant *> Elts;
  2712. if (!AT->getElementType()->isArrayTy()) {
  2713. for (unsigned i = 0; i < AT->getNumElements(); i++) {
  2714. // Get Array[i]
  2715. Constant *InitArrayElt = Init->getAggregateElement(i);
  2716. // Get Array[i].idx
  2717. InitArrayElt = InitArrayElt->getAggregateElement(idx);
  2718. Elts.emplace_back(InitArrayElt);
  2719. }
  2720. return ConstantArray::get(EltArrayTy, Elts);
  2721. } else {
  2722. Type *EltTy = AT->getElementType();
  2723. ArrayType *NestEltArrayTy = cast<ArrayType>(EltArrayTy->getElementType());
  2724. // Nested array.
  2725. for (unsigned i = 0; i < AT->getNumElements(); i++) {
  2726. // Get Array[i]
  2727. Constant *InitArrayElt = Init->getAggregateElement(i);
  2728. // Get Array[i].idx
  2729. InitArrayElt = GetEltInit(EltTy, InitArrayElt, idx, NestEltArrayTy);
  2730. Elts.emplace_back(InitArrayElt);
  2731. }
  2732. return ConstantArray::get(EltArrayTy, Elts);
  2733. }
  2734. }
  2735. }
  2736. unsigned SROA_Helper::GetEltAlign(unsigned ValueAlign, const DataLayout &DL,
  2737. Type *EltTy, unsigned Offset) {
  2738. unsigned Alignment = ValueAlign;
  2739. if (ValueAlign == 0) {
  2740. // The minimum alignment which users can rely on when the explicit
  2741. // alignment is omitted or zero is that required by the ABI for this
  2742. // type.
  2743. Alignment = DL.getABITypeAlignment(EltTy);
  2744. }
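// For example (illustrative), MinAlign(16, 4) == 4: a 16-byte-aligned value
// only guarantees 4-byte alignment for an element at offset 4.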
  2745. return MinAlign(Alignment, Offset);
  2746. }
/// DoScalarReplacement - Split GV into new GlobalVariables and save them into Elts.
/// Then do SROA on GV.
  2749. bool SROA_Helper::DoScalarReplacement(GlobalVariable *GV,
  2750. std::vector<Value *> &Elts,
  2751. IRBuilder<> &Builder, bool bFlatVector,
  2752. bool hasPrecise, DxilTypeSystem &typeSys,
  2753. const DataLayout &DL,
  2754. SmallVector<Value *, 32> &DeadInsts,
  2755. DominatorTree *DT) {
  2756. DEBUG(dbgs() << "Found inst to SROA: " << *GV << '\n');
  2757. Type *Ty = GV->getType();
// Skip non-pointer types.
  2759. if (!Ty->isPointerTy())
  2760. return false;
  2761. Ty = Ty->getPointerElementType();
// Skip non-aggregate types.
  2763. if (!Ty->isAggregateType() && !bFlatVector)
  2764. return false;
  2765. // Skip basic types.
  2766. if (Ty->isSingleValueType() && !Ty->isVectorTy())
  2767. return false;
  2768. // Skip matrix types.
  2769. if (HLMatrixType::isa(Ty))
  2770. return false;
  2771. Module *M = GV->getParent();
  2772. Constant *Init = GV->hasInitializer() ? GV->getInitializer() : UndefValue::get(Ty);
  2773. bool isConst = GV->isConstant();
  2774. GlobalVariable::ThreadLocalMode TLMode = GV->getThreadLocalMode();
  2775. unsigned AddressSpace = GV->getType()->getAddressSpace();
  2776. GlobalValue::LinkageTypes linkage = GV->getLinkage();
  2777. const unsigned Alignment = GV->getAlignment();
  2778. if (StructType *ST = dyn_cast<StructType>(Ty)) {
  2779. // Skip HLSL object types.
  2780. if (dxilutil::IsHLSLObjectType(ST))
  2781. return false;
  2782. unsigned numTypes = ST->getNumContainedTypes();
  2783. Elts.reserve(numTypes);
  2784. unsigned Offset = 0;
  2785. //DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ST);
  2786. for (int i = 0, e = numTypes; i != e; ++i) {
  2787. Type *EltTy = ST->getElementType(i);
  2788. Constant *EltInit = GetEltInit(Ty, Init, i, EltTy);
  2789. GlobalVariable *EltGV = new llvm::GlobalVariable(
  2790. *M, ST->getContainedType(i), /*IsConstant*/ isConst, linkage,
  2791. /*InitVal*/ EltInit, GV->getName() + "." + Twine(i),
  2792. /*InsertBefore*/ nullptr, TLMode, AddressSpace);
  2793. EltGV->setAlignment(GetEltAlign(Alignment, DL, EltTy, Offset));
  2794. Offset += DL.getTypeAllocSize(EltTy);
  2795. //DxilFieldAnnotation &FA = SA->GetFieldAnnotation(i);
  2796. // TODO: set precise.
  2797. // if (hasPrecise || FA.IsPrecise())
  2798. // HLModule::MarkPreciseAttributeWithMetadata(NA);
  2799. Elts.push_back(EltGV);
  2800. }
  2801. } else if (VectorType *VT = dyn_cast<VectorType>(Ty)) {
// TODO: support dynamic indexing on vectors by changing them to arrays.
  2803. unsigned numElts = VT->getNumElements();
  2804. Elts.reserve(numElts);
  2805. Type *EltTy = VT->getElementType();
  2806. unsigned Offset = 0;
  2807. //DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ST);
  2808. for (int i = 0, e = numElts; i != e; ++i) {
  2809. Constant *EltInit = GetEltInit(Ty, Init, i, EltTy);
  2810. GlobalVariable *EltGV = new llvm::GlobalVariable(
  2811. *M, EltTy, /*IsConstant*/ isConst, linkage,
  2812. /*InitVal*/ EltInit, GV->getName() + "." + Twine(i),
  2813. /*InsertBefore*/ nullptr, TLMode, AddressSpace);
  2814. EltGV->setAlignment(GetEltAlign(Alignment, DL, EltTy, Offset));
  2815. Offset += DL.getTypeAllocSize(EltTy);
  2816. //DxilFieldAnnotation &FA = SA->GetFieldAnnotation(i);
  2817. // TODO: set precise.
  2818. // if (hasPrecise || FA.IsPrecise())
  2819. // HLModule::MarkPreciseAttributeWithMetadata(NA);
  2820. Elts.push_back(EltGV);
  2821. }
  2822. } else {
  2823. ArrayType *AT = cast<ArrayType>(Ty);
  2824. if (AT->getNumContainedTypes() == 0) {
  2825. // Skip case like [0 x %struct].
  2826. return false;
  2827. }
  2828. Type *ElTy = AT->getElementType();
  2829. SmallVector<ArrayType *, 4> nestArrayTys;
  2830. nestArrayTys.emplace_back(AT);
// Support multiple levels of arrays.
  2832. while (ElTy->isArrayTy()) {
  2833. ArrayType *ElAT = cast<ArrayType>(ElTy);
  2834. nestArrayTys.emplace_back(ElAT);
  2835. ElTy = ElAT->getElementType();
  2836. }
  2837. if (ElTy->isStructTy() &&
  2838. // Skip Matrix and Resource type.
  2839. !HLMatrixType::isa(ElTy) &&
  2840. !dxilutil::IsHLSLResourceType(ElTy)) {
  2841. // for array of struct
  2842. // split into arrays of struct elements
  2843. StructType *ElST = cast<StructType>(ElTy);
  2844. unsigned numTypes = ElST->getNumContainedTypes();
  2845. Elts.reserve(numTypes);
  2846. unsigned Offset = 0;
  2847. //DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ElST);
  2848. for (int i = 0, e = numTypes; i != e; ++i) {
  2849. Type *EltTy =
  2850. CreateNestArrayTy(ElST->getContainedType(i), nestArrayTys);
  2851. Constant *EltInit = GetEltInit(Ty, Init, i, EltTy);
  2852. GlobalVariable *EltGV = new llvm::GlobalVariable(
  2853. *M, EltTy, /*IsConstant*/ isConst, linkage,
  2854. /*InitVal*/ EltInit, GV->getName() + "." + Twine(i),
  2855. /*InsertBefore*/ nullptr, TLMode, AddressSpace);
  2856. EltGV->setAlignment(GetEltAlign(Alignment, DL, EltTy, Offset));
  2857. Offset += DL.getTypeAllocSize(EltTy);
  2858. //DxilFieldAnnotation &FA = SA->GetFieldAnnotation(i);
  2859. // TODO: set precise.
  2860. // if (hasPrecise || FA.IsPrecise())
  2861. // HLModule::MarkPreciseAttributeWithMetadata(NA);
  2862. Elts.push_back(EltGV);
  2863. }
  2864. } else if (ElTy->isVectorTy()) {
  2865. // Skip vector if required.
  2866. if (!bFlatVector)
  2867. return false;
  2868. // for array of vector
  2869. // split into arrays of scalar
  2870. VectorType *ElVT = cast<VectorType>(ElTy);
  2871. Elts.reserve(ElVT->getNumElements());
  2872. ArrayType *scalarArrayTy =
  2873. CreateNestArrayTy(ElVT->getElementType(), nestArrayTys);
  2874. unsigned Offset = 0;
  2875. for (int i = 0, e = ElVT->getNumElements(); i != e; ++i) {
  2876. Constant *EltInit = GetEltInit(Ty, Init, i, scalarArrayTy);
  2877. GlobalVariable *EltGV = new llvm::GlobalVariable(
  2878. *M, scalarArrayTy, /*IsConstant*/ isConst, linkage,
  2879. /*InitVal*/ EltInit, GV->getName() + "." + Twine(i),
  2880. /*InsertBefore*/ nullptr, TLMode, AddressSpace);
  2881. // TODO: set precise.
  2882. // if (hasPrecise)
  2883. // HLModule::MarkPreciseAttributeWithMetadata(NA);
  2884. EltGV->setAlignment(GetEltAlign(Alignment, DL, scalarArrayTy, Offset));
  2885. Offset += DL.getTypeAllocSize(scalarArrayTy);
  2886. Elts.push_back(EltGV);
  2887. }
  2888. } else
  2889. // Skip array of basic types.
  2890. return false;
  2891. }
  2892. // Now that we have created the new alloca instructions, rewrite all the
  2893. // uses of the old alloca.
  2894. SROA_Helper helper(GV, Elts, DeadInsts, typeSys, DL, DT);
  2895. helper.RewriteForScalarRepl(GV, Builder);
  2896. return true;
  2897. }
  2898. static void ReplaceConstantWithInst(Constant *C, Value *V, IRBuilder<> &Builder) {
  2899. for (auto it = C->user_begin(); it != C->user_end(); ) {
  2900. User *U = *(it++);
  2901. if (Instruction *I = dyn_cast<Instruction>(U)) {
  2902. I->replaceUsesOfWith(C, V);
  2903. } else {
  2904. // Skip unused ConstantExpr.
  2905. if (U->user_empty())
  2906. continue;
  2907. ConstantExpr *CE = cast<ConstantExpr>(U);
  2908. Instruction *Inst = CE->getAsInstruction();
  2909. Builder.Insert(Inst);
  2910. Inst->replaceUsesOfWith(C, V);
  2911. ReplaceConstantWithInst(CE, Inst, Builder);
  2912. }
  2913. }
  2914. C->removeDeadConstantUsers();
  2915. }
  2916. static void ReplaceUnboundedArrayUses(Value *V, Value *Src) {
  2917. for (auto it = V->user_begin(); it != V->user_end(); ) {
  2918. User *U = *(it++);
  2919. if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
  2920. SmallVector<Value *, 4> idxList(GEP->idx_begin(), GEP->idx_end());
  2921. // Must set the insert point to the GEP itself (instead of the memcpy),
  2922. // because the indices might not dominate the memcpy.
  2923. IRBuilder<> Builder(GEP);
  2924. Value *NewGEP = Builder.CreateGEP(Src, idxList);
  2925. GEP->replaceAllUsesWith(NewGEP);
  2926. } else if (BitCastInst *BC = dyn_cast<BitCastInst>(U)) {
  2927. BC->setOperand(0, Src);
  2928. } else {
  2929. DXASSERT(false, "otherwise unbounded array used in unexpected instruction");
  2930. }
  2931. }
  2932. }
  2933. static bool IsUnboundedArrayMemcpy(Type *destTy, Type *srcTy) {
  2934. return (destTy->isArrayTy() && srcTy->isArrayTy()) &&
  2935. (destTy->getArrayNumElements() == 0 || srcTy->getArrayNumElements() == 0);
  2936. }
  2937. static bool ArePointersToStructsOfIdenticalLayouts(Type *DstTy, Type *SrcTy) {
  2938. if (!SrcTy->isPointerTy() || !DstTy->isPointerTy())
  2939. return false;
  2940. DstTy = DstTy->getPointerElementType();
  2941. SrcTy = SrcTy->getPointerElementType();
  2942. if (!SrcTy->isStructTy() || !DstTy->isStructTy())
  2943. return false;
  2944. StructType *DstST = cast<StructType>(DstTy);
  2945. StructType *SrcST = cast<StructType>(SrcTy);
  2946. return SrcST->isLayoutIdentical(DstST);
  2947. }
  2948. static std::vector<Value *> GetConstValueIdxList(IRBuilder<> &builder,
  2949. std::vector<unsigned> idxlist) {
  2950. std::vector<Value *> idxConstList;
  2951. for (unsigned idx : idxlist) {
  2952. idxConstList.push_back(ConstantInt::get(builder.getInt32Ty(), idx));
  2953. }
  2954. return idxConstList;
  2955. }
  2956. static void CopyElementsOfStructsWithIdenticalLayout(
  2957. IRBuilder<> &builder, Value *destPtr, Value *srcPtr, Type *ty,
  2958. std::vector<unsigned>& idxlist) {
  2959. if (ty->isStructTy()) {
  2960. for (unsigned i = 0; i < ty->getStructNumElements(); i++) {
  2961. idxlist.push_back(i);
  2962. CopyElementsOfStructsWithIdenticalLayout(
  2963. builder, destPtr, srcPtr, ty->getStructElementType(i), idxlist);
  2964. idxlist.pop_back();
  2965. }
  2966. }
  2967. else if (ty->isArrayTy()) {
  2968. for (unsigned i = 0; i < ty->getArrayNumElements(); i++) {
  2969. idxlist.push_back(i);
  2970. CopyElementsOfStructsWithIdenticalLayout(
  2971. builder, destPtr, srcPtr, ty->getArrayElementType(), idxlist);
  2972. idxlist.pop_back();
  2973. }
  2974. }
  2975. else if (ty->isIntegerTy() || ty->isFloatTy() || ty->isDoubleTy() ||
  2976. ty->isHalfTy() || ty->isVectorTy()) {
  2977. Value *srcGEP =
  2978. builder.CreateInBoundsGEP(srcPtr, GetConstValueIdxList(builder, idxlist));
  2979. Value *destGEP =
  2980. builder.CreateInBoundsGEP(destPtr, GetConstValueIdxList(builder, idxlist));
  2981. LoadInst *LI = builder.CreateLoad(srcGEP);
  2982. builder.CreateStore(LI, destGEP);
  2983. }
  2984. else {
  2985. DXASSERT(0, "encountered unsupported type when copying elements of identical structs.");
  2986. }
  2987. }
  2988. static void removeLifetimeUsers(Value *V) {
  2989. std::set<Value*> users(V->users().begin(), V->users().end());
  2990. for (Value *U : users) {
  2991. if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(U)) {
  2992. if (II->getIntrinsicID() == Intrinsic::lifetime_start ||
  2993. II->getIntrinsicID() == Intrinsic::lifetime_end) {
  2994. II->eraseFromParent();
  2995. }
  2996. } else if (isa<BitCastInst>(U) ||
  2997. isa<AddrSpaceCastInst>(U) ||
  2998. isa<GetElementPtrInst>(U)) {
  2999. // Recurse into bitcast, addrspacecast, GEP.
  3000. removeLifetimeUsers(U);
  3001. if (U->use_empty())
  3002. cast<Instruction>(U)->eraseFromParent();
  3003. }
  3004. }
  3005. }
  3006. // Conservatively remove all lifetime users of both source and target.
  3007. // Otherwise, wrong lifetimes could be inherited either way.
  3008. // TODO: We should be merging the lifetimes. For convenience, just remove them
  3009. // for now to be safe.
  3010. static void updateLifetimeForReplacement(Value *From, Value *To)
  3011. {
  3012. removeLifetimeUsers(From);
  3013. removeLifetimeUsers(To);
  3014. }
  3015. static bool DominateAllUsers(Instruction *I, Value *V, DominatorTree *DT);
  3016. namespace {
  3017. void replaceScalarArrayGEPWithVectorArrayGEP(User *GEP, Value *VectorArray,
  3018. IRBuilder<> &Builder,
  3019. unsigned sizeInDwords) {
  3020. gep_type_iterator GEPIt = gep_type_begin(GEP), E = gep_type_end(GEP);
  3021. Value *PtrOffset = GEPIt.getOperand();
  3022. ++GEPIt;
  3023. Value *ArrayIdx = GEPIt.getOperand();
  3024. ++GEPIt;
  3025. ArrayIdx = Builder.CreateAdd(PtrOffset, ArrayIdx);
  3026. DXASSERT_LOCALVAR(E, GEPIt == E, "invalid gep on scalar array");
  3027. unsigned shift = 2;
  3028. unsigned mask = 0x3;
  3029. switch (sizeInDwords) {
  3030. case 2:
  3031. shift = 1;
  3032. mask = 1;
  3033. break;
  3034. case 1:
  3035. shift = 2;
  3036. mask = 0x3;
  3037. break;
  3038. default:
  3039. DXASSERT(0, "invalid scalar size");
  3040. break;
  3041. }
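// For example (illustrative), with sizeInDwords == 1 (32-bit scalars), a
// flattened scalar index of 9 maps to vector element 9 >> 2 == 2,
// component 9 & 3 == 1.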
  3042. Value *VecIdx = Builder.CreateLShr(ArrayIdx, shift);
  3043. Value *VecPtr = Builder.CreateGEP(
  3044. VectorArray, {ConstantInt::get(VecIdx->getType(), 0), VecIdx});
  3045. Value *CompIdx = Builder.CreateAnd(ArrayIdx, mask);
  3046. Value *NewGEP = Builder.CreateGEP(
  3047. VecPtr, {ConstantInt::get(CompIdx->getType(), 0), CompIdx});
  3048. GEP->replaceAllUsesWith(NewGEP);
  3049. }
  3050. void replaceScalarArrayWithVectorArray(Value *ScalarArray, Value *VectorArray,
  3051. MemCpyInst *MC, unsigned sizeInDwords) {
  3052. LLVMContext &Context = ScalarArray->getContext();
// All users should access the array elements.
  3054. // Replace users of AI or GV.
  3055. for (auto it = ScalarArray->user_begin(); it != ScalarArray->user_end();) {
  3056. User *U = *(it++);
  3057. if (U->user_empty())
  3058. continue;
  3059. if (BitCastInst *BCI = dyn_cast<BitCastInst>(U)) {
  3060. BCI->setOperand(0, VectorArray);
  3061. continue;
  3062. }
  3063. if (ConstantExpr *CE = dyn_cast<ConstantExpr>(U)) {
  3064. IRBuilder<> Builder(Context);
  3065. if (GEPOperator *GEP = dyn_cast<GEPOperator>(U)) {
  3066. // NewGEP must be GEPOperator too.
// No instruction will be built.
  3068. replaceScalarArrayGEPWithVectorArrayGEP(U, VectorArray, Builder,
  3069. sizeInDwords);
  3070. } else if (CE->getOpcode() == Instruction::AddrSpaceCast) {
  3071. Value *NewAddrSpaceCast = Builder.CreateAddrSpaceCast(
  3072. VectorArray,
  3073. PointerType::get(VectorArray->getType()->getPointerElementType(),
  3074. CE->getType()->getPointerAddressSpace()));
  3075. replaceScalarArrayWithVectorArray(CE, NewAddrSpaceCast, MC,
  3076. sizeInDwords);
  3077. } else if (CE->hasOneUse() && CE->user_back() == MC) {
  3078. continue;
  3079. } else {
  3080. DXASSERT(0, "not implemented");
  3081. }
  3082. } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
  3083. IRBuilder<> Builder(GEP);
  3084. replaceScalarArrayGEPWithVectorArrayGEP(U, VectorArray, Builder,
  3085. sizeInDwords);
  3086. GEP->eraseFromParent();
  3087. } else {
  3088. DXASSERT(0, "not implemented");
  3089. }
  3090. }
  3091. }
  3092. // For pattern like
  3093. // float4 cb[16];
  3094. // float v[64] = cb;
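// (Illustrative: the copy itself is elided and v's scalar indexing is
// redirected into cb's float4 elements using the shift/mask mapping above.)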
  3095. bool tryToReplaceCBVec4ArrayToScalarArray(Value *V, Type *TyV, Value *Src,
  3096. Type *TySrc, MemCpyInst *MC,
  3097. const DataLayout &DL) {
  3098. if (!isCBVec4ArrayToScalarArray(TyV, Src, TySrc, DL))
  3099. return false;
  3100. ArrayType *AT = cast<ArrayType>(TyV);
  3101. Type *EltTy = AT->getElementType();
  3102. unsigned sizeInBits = DL.getTypeSizeInBits(EltTy);
// Redirect uses of the scalar float array to the float4 array.
  3104. replaceScalarArrayWithVectorArray(V, Src, MC, sizeInBits >> 5);
  3105. return true;
  3106. }
  3107. } // namespace
  3108. static bool ReplaceMemcpy(Value *V, Value *Src, MemCpyInst *MC,
  3109. DxilFieldAnnotation *annotation, DxilTypeSystem &typeSys,
  3110. const DataLayout &DL, DominatorTree *DT) {
// If the only user of the src and dst is the memcpy,
// this memcpy was probably produced by splitting another.
// Regardless, the goal here is to replace, not remove, the memcpy;
// we won't have enough information to determine whether we can do that before mem2reg.
  3115. if (V != Src && V->hasOneUse() && Src->hasOneUse())
  3116. return false;
// If the memcpy doesn't dominate all its users,
// full replacement isn't possible without complicated PHI insertion.
// The memcpy will likely be lowered to ld/st instead, which mem2reg will then clean up.
  3120. if (Instruction *SrcI = dyn_cast<Instruction>(Src))
  3121. if (!DominateAllUsers(SrcI, V, DT))
  3122. return false;
  3123. Type *TyV = V->getType()->getPointerElementType();
  3124. Type *TySrc = Src->getType()->getPointerElementType();
  3125. if (Constant *C = dyn_cast<Constant>(V)) {
  3126. updateLifetimeForReplacement(V, Src);
  3127. if (TyV == TySrc) {
  3128. if (isa<Constant>(Src)) {
  3129. V->replaceAllUsesWith(Src);
  3130. } else {
  3131. // Replace Constant with a non-Constant.
  3132. IRBuilder<> Builder(MC);
  3133. ReplaceConstantWithInst(C, Src, Builder);
  3134. }
  3135. } else {
// Try to convert the special cbuffer pattern that copies an array of float4 to
// an array of float.
  3138. if (!tryToReplaceCBVec4ArrayToScalarArray(V, TyV, Src, TySrc, MC, DL)) {
  3139. IRBuilder<> Builder(MC);
  3140. Src = Builder.CreateBitCast(Src, V->getType());
  3141. ReplaceConstantWithInst(C, Src, Builder);
  3142. }
  3143. }
  3144. } else {
  3145. if (TyV == TySrc) {
  3146. if (V != Src) {
  3147. updateLifetimeForReplacement(V, Src);
  3148. V->replaceAllUsesWith(Src);
  3149. }
  3150. } else if (!IsUnboundedArrayMemcpy(TyV, TySrc)) {
  3151. Value* DestVal = MC->getRawDest();
  3152. Value* SrcVal = MC->getRawSource();
  3153. if (!isa<BitCastInst>(SrcVal) || !isa<BitCastInst>(DestVal)) {
  3154. DXASSERT(0, "Encountered unexpected instruction sequence");
  3155. return false;
  3156. }
  3157. BitCastInst *DestBCI = cast<BitCastInst>(DestVal);
  3158. BitCastInst *SrcBCI = cast<BitCastInst>(SrcVal);
  3159. Type* DstTy = DestBCI->getSrcTy();
  3160. Type *SrcTy = SrcBCI->getSrcTy();
  3161. if (ArePointersToStructsOfIdenticalLayouts(DstTy, SrcTy)) {
  3162. const DataLayout &DL = SrcBCI->getModule()->getDataLayout();
  3163. unsigned SrcSize = DL.getTypeAllocSize(
  3164. SrcBCI->getOperand(0)->getType()->getPointerElementType());
  3165. unsigned MemcpySize = cast<ConstantInt>(MC->getLength())->getZExtValue();
  3166. if (SrcSize != MemcpySize) {
  3167. DXASSERT(0, "Cannot handle partial memcpy");
  3168. return false;
  3169. }
  3170. if (DestBCI->hasOneUse() && SrcBCI->hasOneUse()) {
  3171. IRBuilder<> Builder(MC);
  3172. StructType *srcStTy = cast<StructType>(
  3173. SrcBCI->getOperand(0)->getType()->getPointerElementType());
  3174. std::vector<unsigned> idxlist = {0};
  3175. CopyElementsOfStructsWithIdenticalLayout(
  3176. Builder, DestBCI->getOperand(0), SrcBCI->getOperand(0), srcStTy,
  3177. idxlist);
  3178. }
  3179. } else {
  3180. if (DstTy == SrcTy) {
  3181. Value *DstPtr = DestBCI->getOperand(0);
  3182. Value *SrcPtr = SrcBCI->getOperand(0);
  3183. if (isa<GEPOperator>(DstPtr) || isa<GEPOperator>(SrcPtr)) {
  3184. MemcpySplitter::SplitMemCpy(MC, DL, annotation, typeSys);
  3185. return true;
  3186. } else {
  3187. updateLifetimeForReplacement(V, Src);
  3188. DstPtr->replaceAllUsesWith(SrcPtr);
  3189. }
  3190. } else {
  3191. DXASSERT(0, "Can't handle structs of different layouts");
  3192. return false;
  3193. }
  3194. }
  3195. } else {
  3196. updateLifetimeForReplacement(V, Src);
  3197. DXASSERT(IsUnboundedArrayMemcpy(TyV, TySrc), "otherwise mismatched types in memcpy are not unbounded array");
  3198. ReplaceUnboundedArrayUses(V, Src);
  3199. }
  3200. }
  3201. if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Src)) {
// For a const GV, if it has been stored to, mark it as non-constant.
  3203. if (GV->isConstant()) {
  3204. hlutil::PointerStatus PS(GV, 0, /*bLdStOnly*/ true);
  3205. PS.analyze(typeSys, /*bStructElt*/ false);
  3206. if (PS.HasStored())
  3207. GV->setConstant(false);
  3208. }
  3209. }
  3210. Value *RawDest = MC->getOperand(0);
  3211. Value *RawSrc = MC->getOperand(1);
  3212. MC->eraseFromParent();
  3213. if (Instruction *I = dyn_cast<Instruction>(RawDest)) {
  3214. if (I->user_empty())
  3215. I->eraseFromParent();
  3216. }
  3217. if (Instruction *I = dyn_cast<Instruction>(RawSrc)) {
  3218. if (I->user_empty())
  3219. I->eraseFromParent();
  3220. }
  3221. return true;
  3222. }
  3223. static bool ReplaceUseOfZeroInitEntry(Instruction *I, Value *V) {
  3224. BasicBlock *BB = I->getParent();
  3225. Function *F = I->getParent()->getParent();
  3226. for (auto U = V->user_begin(); U != V->user_end(); ) {
  3227. Instruction *UI = dyn_cast<Instruction>(*(U++));
  3228. if (!UI)
  3229. continue;
  3230. if (UI->getParent()->getParent() != F)
  3231. continue;
  3232. if (isa<GetElementPtrInst>(UI) || isa<BitCastInst>(UI)) {
  3233. if (!ReplaceUseOfZeroInitEntry(I, UI))
  3234. return false;
  3235. else
  3236. continue;
  3237. }
  3238. if (BB != UI->getParent() || UI == I)
  3239. continue;
// I is the last inst in the block after the split,
// so any inst in the current block comes before I.
  3242. if (LoadInst *LI = dyn_cast<LoadInst>(UI)) {
  3243. LI->replaceAllUsesWith(ConstantAggregateZero::get(LI->getType()));
  3244. LI->eraseFromParent();
  3245. continue;
  3246. }
  3247. return false;
  3248. }
  3249. return true;
  3250. }
  3251. static bool ReplaceUseOfZeroInitPostDom(Instruction *I, Value *V,
  3252. PostDominatorTree &PDT) {
  3253. BasicBlock *BB = I->getParent();
  3254. Function *F = I->getParent()->getParent();
  3255. for (auto U = V->user_begin(); U != V->user_end(); ) {
  3256. Instruction *UI = dyn_cast<Instruction>(*(U++));
  3257. if (!UI)
  3258. continue;
  3259. if (UI->getParent()->getParent() != F)
  3260. continue;
  3261. if (!PDT.dominates(BB, UI->getParent()))
  3262. return false;
  3263. if (isa<GetElementPtrInst>(UI) || isa<BitCastInst>(UI)) {
  3264. if (!ReplaceUseOfZeroInitPostDom(I, UI, PDT))
  3265. return false;
  3266. else
  3267. continue;
  3268. }
  3269. if (BB != UI->getParent() || UI == I)
  3270. continue;
// I is the last inst in the block after the split,
// so any inst in the current block comes before I.
  3273. if (LoadInst *LI = dyn_cast<LoadInst>(UI)) {
  3274. LI->replaceAllUsesWith(ConstantAggregateZero::get(LI->getType()));
  3275. LI->eraseFromParent();
  3276. continue;
  3277. }
  3278. return false;
  3279. }
  3280. return true;
  3281. }
// When a zero-initialized GV has only one def, all uses before the def should
// use zero.
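// For example (illustrative): if a zero-initialized static array's only def
// is a memcpy, loads of it that are known to execute before that memcpy are
// replaced with zeroinitializer.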
  3284. static bool ReplaceUseOfZeroInitBeforeDef(Instruction *I, GlobalVariable *GV) {
  3285. BasicBlock *BB = I->getParent();
  3286. Function *F = I->getParent()->getParent();
  3287. // Make sure I is the last inst for BB.
  3288. if (I != BB->getTerminator())
  3289. BB->splitBasicBlock(I->getNextNode());
  3290. if (&F->getEntryBlock() == I->getParent()) {
  3291. return ReplaceUseOfZeroInitEntry(I, GV);
  3292. } else {
  3293. // Post dominator tree.
  3294. PostDominatorTree PDT;
  3295. PDT.runOnFunction(*F);
  3296. return ReplaceUseOfZeroInitPostDom(I, GV, PDT);
  3297. }
  3298. }
  3299. // Use `DT` to trace all users and make sure `I`'s BB dominates them all
  3300. static bool DominateAllUsersDom(Instruction *I, Value *V, DominatorTree *DT) {
  3301. BasicBlock *BB = I->getParent();
  3302. Function *F = I->getParent()->getParent();
  3303. for (auto U = V->user_begin(); U != V->user_end(); ) {
  3304. Instruction *UI = dyn_cast<Instruction>(*(U++));
// If not an instruction or from a different function, nothing to check; move along.
  3306. if (!UI || UI->getParent()->getParent() != F)
  3307. continue;
  3308. if (!DT->dominates(BB, UI->getParent()))
  3309. return false;
  3310. if (isa<GetElementPtrInst>(UI) || isa<BitCastInst>(UI)) {
  3311. if (!DominateAllUsersDom(I, UI, DT))
  3312. return false;
  3313. }
  3314. }
  3315. return true;
  3316. }
  3317. // Determine if `I` dominates all the users of `V`
  3318. static bool DominateAllUsers(Instruction *I, Value *V, DominatorTree *DT) {
  3319. Function *F = I->getParent()->getParent();
  3320. // The Entry Block dominates everything, trivially true
  3321. if (&F->getEntryBlock() == I->getParent())
  3322. return true;
  3323. if (!DT) {
  3324. DominatorTree TempDT;
  3325. TempDT.recalculate(*F);
  3326. return DominateAllUsersDom(I, V, &TempDT);
  3327. } else {
  3328. return DominateAllUsersDom(I, V, DT);
  3329. }
  3330. }
  3331. static bool isReadOnlyPtr(CallInst *PtrCI) {
  3332. HLSubscriptOpcode opcode =
  3333. static_cast<HLSubscriptOpcode>(hlsl::GetHLOpcode(PtrCI));
  3334. if (opcode == HLSubscriptOpcode::CBufferSubscript) {
  3335. // Ptr from CBuffer is readonly.
  3336. return true;
  3337. } else if (opcode == HLSubscriptOpcode::DefaultSubscript) {
  3338. Value *ptr = PtrCI->getArgOperand(HLOperandIndex::kSubscriptObjectOpIdx);
  3339. // Resource ptr.
  3340. if (CallInst *handleCI = dyn_cast<CallInst>(ptr)) {
  3341. hlsl::HLOpcodeGroup group =
  3342. hlsl::GetHLOpcodeGroup(handleCI->getCalledFunction());
  3343. if (group == HLOpcodeGroup::HLAnnotateHandle) {
  3344. Constant *Props = cast<Constant>(handleCI->getArgOperand(
  3345. HLOperandIndex::kAnnotateHandleResourcePropertiesOpIdx));
  3346. DxilResourceProperties RP = resource_helper::loadPropsFromConstant(*Props);
  3347. if (RP.getResourceClass() == DXIL::ResourceClass::SRV) {
  3348. // Ptr from SRV is readonly.
  3349. return true;
  3350. }
  3351. }
  3352. }
  3353. }
  3354. return false;
  3355. }
  3356. bool SROA_Helper::LowerMemcpy(Value *V, DxilFieldAnnotation *annotation,
  3357. DxilTypeSystem &typeSys, const DataLayout &DL,
  3358. DominatorTree *DT, bool bAllowReplace) {
  3359. Type *Ty = V->getType();
  3360. if (!Ty->isPointerTy()) {
  3361. return false;
  3362. }
  3363. // Get access status and collect memcpy uses.
// If MemcpyOnce, replace dest with src when dest is not an out param;
// otherwise, flatten the memcpy.
  3366. unsigned size = DL.getTypeAllocSize(Ty->getPointerElementType());
  3367. hlutil::PointerStatus PS(V, size, /*bLdStOnly*/ false);
  3368. const bool bStructElt = false;
  3369. PS.analyze(typeSys, bStructElt);
  3370. if (GlobalVariable *GV = dyn_cast<GlobalVariable>(V)) {
  3371. if (GV->hasInitializer() && !isa<UndefValue>(GV->getInitializer())) {
  3372. if (PS.storedType == hlutil::PointerStatus::StoredType::NotStored) {
  3373. PS.storedType = hlutil::PointerStatus::StoredType::InitializerStored;
  3374. } else if (PS.storedType ==
  3375. hlutil::PointerStatus::StoredType::MemcopyDestOnce) {
// For a single mem store, if the store does not dominate all users,
// mark it as Stored.
  3378. // In cases like:
  3379. // struct A { float4 x[25]; };
  3380. // A a;
  3381. // static A a2;
  3382. // void set(A aa) { aa = a; }
  3383. // call set inside entry function then use a2.
  3384. if (isa<ConstantAggregateZero>(GV->getInitializer())) {
  3385. Instruction * Memcpy = PS.StoringMemcpy;
  3386. if (!ReplaceUseOfZeroInitBeforeDef(Memcpy, GV)) {
  3387. PS.storedType = hlutil::PointerStatus::StoredType::Stored;
  3388. }
  3389. }
  3390. } else {
  3391. PS.storedType = hlutil::PointerStatus::StoredType::Stored;
  3392. }
  3393. }
  3394. }
  3395. if (bAllowReplace && !PS.HasMultipleAccessingFunctions) {
  3396. if (PS.storedType == hlutil::PointerStatus::StoredType::MemcopyDestOnce &&
3397. // Skip arguments: an input argument already carries an input value, so it is not dest-once anymore.
  3398. !isa<Argument>(V)) {
  3399. // Replace with src of memcpy.
  3400. MemCpyInst *MC = PS.StoringMemcpy;
  3401. if (MC->getSourceAddressSpace() == MC->getDestAddressSpace()) {
  3402. Value *Src = MC->getOperand(1);
3403. // Only remove one level of bitcast generated from inlining.
  3404. if (BitCastOperator *BC = dyn_cast<BitCastOperator>(Src))
  3405. Src = BC->getOperand(0);
  3406. if (GEPOperator *GEP = dyn_cast<GEPOperator>(Src)) {
3407. // For a GEP, the base ptr could have other GEP reads/writes,
3408. // so scanning only one GEP is not enough.
  3409. Value *Ptr = GEP->getPointerOperand();
  3410. while (GEPOperator *NestedGEP = dyn_cast<GEPOperator>(Ptr))
  3411. Ptr = NestedGEP->getPointerOperand();
  3412. if (CallInst *PtrCI = dyn_cast<CallInst>(Ptr)) {
  3413. hlsl::HLOpcodeGroup group =
  3414. hlsl::GetHLOpcodeGroup(PtrCI->getCalledFunction());
  3415. if (group == HLOpcodeGroup::HLSubscript) {
  3416. if (isReadOnlyPtr(PtrCI)) {
  3417. // Ptr from CBuffer/SRV is safe.
  3418. if (ReplaceMemcpy(V, Src, MC, annotation, typeSys, DL, DT)) {
  3419. if (V->user_empty())
  3420. return true;
  3421. return LowerMemcpy(V, annotation, typeSys, DL, DT, bAllowReplace);
  3422. }
  3423. }
  3424. }
  3425. }
  3426. } else if (!isa<CallInst>(Src)) {
  3427. // Resource ptr should not be replaced.
3428. // Need to make sure Src is not updated after the current memcpy.
3429. // Check that Src only has one store for now.
  3430. hlutil::PointerStatus SrcPS(Src, size, /*bLdStOnly*/ false);
  3431. SrcPS.analyze(typeSys, bStructElt);
  3432. if (SrcPS.storedType != hlutil::PointerStatus::StoredType::Stored) {
  3433. if (ReplaceMemcpy(V, Src, MC, annotation, typeSys, DL, DT)) {
  3434. if (V->user_empty())
  3435. return true;
  3436. return LowerMemcpy(V, annotation, typeSys, DL, DT, bAllowReplace);
  3437. }
  3438. }
  3439. }
  3440. }
  3441. } else if (PS.loadedType ==
  3442. hlutil::PointerStatus::LoadedType::MemcopySrcOnce) {
  3443. // Replace dst of memcpy.
  3444. MemCpyInst *MC = PS.LoadingMemcpy;
  3445. if (MC->getSourceAddressSpace() == MC->getDestAddressSpace()) {
  3446. Value *Dest = MC->getOperand(0);
3447. // Only remove one level of bitcast generated from inlining.
  3448. if (BitCastOperator *BC = dyn_cast<BitCastOperator>(Dest))
  3449. Dest = BC->getOperand(0);
3450. // For a GEP, the base ptr could have other GEP reads/writes,
3451. // so scanning only one GEP is not enough.
3452. // Also, resource ptrs should not be replaced.
  3453. if (!isa<GEPOperator>(Dest) && !isa<CallInst>(Dest) &&
  3454. !isa<BitCastOperator>(Dest)) {
3455. // Need to make sure Dest is not updated after the current memcpy.
3456. // Check that Dest only has one store for now.
  3457. hlutil::PointerStatus DestPS(Dest, size, /*bLdStOnly*/ false);
  3458. DestPS.analyze(typeSys, bStructElt);
  3459. if (DestPS.storedType != hlutil::PointerStatus::StoredType::Stored) {
  3460. if (ReplaceMemcpy(Dest, V, MC, annotation, typeSys, DL, DT)) {
  3461. // V still needs to be flattened.
3462. // Lower memcpys coming from Dest.
  3463. return LowerMemcpy(V, annotation, typeSys, DL, DT, bAllowReplace);
  3464. }
  3465. }
  3466. }
  3467. }
  3468. }
  3469. }
  3470. for (MemCpyInst *MC : PS.memcpySet) {
  3471. MemcpySplitter::SplitMemCpy(MC, DL, annotation, typeSys);
  3472. }
  3473. return false;
  3474. }
3475. /// MarkEmptyStructUsers - Add instructions related to an empty struct to DeadInsts.
  3476. void SROA_Helper::MarkEmptyStructUsers(Value *V, SmallVector<Value *, 32> &DeadInsts) {
  3477. UndefValue *undef = UndefValue::get(V->getType());
  3478. for (auto itU = V->user_begin(), E = V->user_end(); itU != E;) {
  3479. Value *U = *(itU++);
  3480. // Kill memcpy, set operands to undef for call and ret, and recurse
  3481. if (MemCpyInst *MC = dyn_cast<MemCpyInst>(U)) {
  3482. DeadInsts.emplace_back(MC);
  3483. } else if (CallInst *CI = dyn_cast<CallInst>(U)) {
  3484. for (auto &operand : CI->operands()) {
  3485. if (operand == V)
  3486. operand.set(undef);
  3487. }
  3488. } else if (ReturnInst *Ret = dyn_cast<ReturnInst>(U)) {
  3489. Ret->setOperand(0, undef);
  3490. } else if (isa<Constant>(U) || isa<GetElementPtrInst>(U) ||
  3491. isa<BitCastInst>(U) || isa<LoadInst>(U) || isa<StoreInst>(U)) {
  3492. // Recurse users
  3493. MarkEmptyStructUsers(U, DeadInsts);
  3494. } else {
  3495. DXASSERT(false, "otherwise, recursing unexpected empty struct user");
  3496. }
  3497. }
  3498. if (Instruction *I = dyn_cast<Instruction>(V)) {
3499. // Only need to add instructions with no uses here.
  3500. // DeleteDeadInst will delete everything.
  3501. if (I->user_empty())
  3502. DeadInsts.emplace_back(I);
  3503. }
  3504. }
  3505. bool SROA_Helper::IsEmptyStructType(Type *Ty, DxilTypeSystem &typeSys) {
  3506. if (isa<ArrayType>(Ty))
  3507. Ty = Ty->getArrayElementType();
  3508. if (StructType *ST = dyn_cast<StructType>(Ty)) {
  3509. if (!HLMatrixType::isa(Ty)) {
  3510. DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ST);
  3511. if (SA && SA->IsEmptyStruct())
  3512. return true;
  3513. }
  3514. }
  3515. return false;
  3516. }
  3517. //===----------------------------------------------------------------------===//
  3518. // SROA on function parameters.
  3519. //===----------------------------------------------------------------------===//
  3520. static void LegalizeDxilInputOutputs(Function *F,
  3521. DxilFunctionAnnotation *EntryAnnotation,
  3522. const DataLayout &DL,
  3523. DxilTypeSystem &typeSys);
  3524. static void InjectReturnAfterNoReturnPreserveOutput(HLModule &HLM);
  3525. namespace {
  3526. class SROA_Parameter_HLSL : public ModulePass {
  3527. HLModule *m_pHLModule;
  3528. public:
  3529. static char ID; // Pass identification, replacement for typeid
  3530. explicit SROA_Parameter_HLSL() : ModulePass(ID) {}
  3531. const char *getPassName() const override { return "SROA Parameter HLSL"; }
  3532. static void RewriteBitcastWithIdenticalStructs(Function *F);
  3533. static void RewriteBitcastWithIdenticalStructs(BitCastInst *BCI);
  3534. bool runOnModule(Module &M) override {
3535. // Patch memcpy to cover the case where bitcast (gep ptr, 0,0) is transformed into
3536. // bitcast ptr.
  3537. MemcpySplitter::PatchMemCpyWithZeroIdxGEP(M);
  3538. m_pHLModule = &M.GetOrCreateHLModule();
  3539. const DataLayout &DL = M.getDataLayout();
  3540. // Load up debug information, to cross-reference values and the instructions
3541. // used to load them.
  3542. m_HasDbgInfo = nullptr != M.getNamedMetadata("llvm.dbg.cu");
  3543. InjectReturnAfterNoReturnPreserveOutput(*m_pHLModule);
  3544. std::deque<Function *> WorkList;
  3545. std::vector<Function *> DeadHLFunctions;
  3546. for (Function &F : M.functions()) {
  3547. HLOpcodeGroup group = GetHLOpcodeGroup(&F);
  3548. // Skip HL operations.
  3549. if (group != HLOpcodeGroup::NotHL ||
  3550. group == HLOpcodeGroup::HLExtIntrinsic) {
  3551. if (F.user_empty())
  3552. DeadHLFunctions.emplace_back(&F);
  3553. continue;
  3554. }
  3555. if (F.isDeclaration()) {
  3556. // Skip llvm intrinsic.
  3557. if (F.isIntrinsic())
  3558. continue;
  3559. // Skip unused external function.
  3560. if (F.user_empty())
  3561. continue;
  3562. }
  3563. // Skip void(void) functions.
  3564. if (F.getReturnType()->isVoidTy() && F.arg_size() == 0)
  3565. continue;
3566. // Skip library functions, except to run LegalizeDxilInputOutputs.
  3567. if (&F != m_pHLModule->GetEntryFunction() &&
  3568. !m_pHLModule->IsEntryThatUsesSignatures(&F)) {
  3569. if (!F.isDeclaration())
  3570. LegalizeDxilInputOutputs(&F, m_pHLModule->GetFunctionAnnotation(&F),
  3571. DL, m_pHLModule->GetTypeSystem());
  3572. continue;
  3573. }
  3574. WorkList.emplace_back(&F);
  3575. }
3576. // Remove dead HL functions here.
3577. // This is for HL functions which have a body and are always inlined.
  3578. for (Function *F : DeadHLFunctions) {
  3579. F->eraseFromParent();
  3580. }
  3581. // Preprocess aggregate function param used as function call arg.
  3582. for (Function *F : WorkList) {
  3583. preprocessArgUsedInCall(F);
  3584. }
  3585. // Process the worklist
  3586. while (!WorkList.empty()) {
  3587. Function *F = WorkList.front();
  3588. WorkList.pop_front();
  3589. RewriteBitcastWithIdenticalStructs(F);
  3590. createFlattenedFunction(F);
  3591. }
3592. // Replace functions with their flattened versions once all functions have been flattened.
  3593. for (auto Iter : funcMap)
  3594. replaceCall(Iter.first, Iter.second);
  3595. // Update patch constant function.
  3596. for (Function &F : M.functions()) {
  3597. if (F.isDeclaration())
  3598. continue;
  3599. if (!m_pHLModule->HasDxilFunctionProps(&F))
  3600. continue;
  3601. DxilFunctionProps &funcProps = m_pHLModule->GetDxilFunctionProps(&F);
  3602. if (funcProps.shaderKind == DXIL::ShaderKind::Hull) {
  3603. Function *oldPatchConstantFunc =
  3604. funcProps.ShaderProps.HS.patchConstantFunc;
  3605. if (funcMap.count(oldPatchConstantFunc))
  3606. m_pHLModule->SetPatchConstantFunctionForHS(&F, funcMap[oldPatchConstantFunc]);
  3607. }
  3608. }
3609. // Remove the original functions that were replaced by their flattened versions.
  3610. for (auto Iter : funcMap) {
  3611. Function *F = Iter.first;
  3612. Function *flatF = Iter.second;
  3613. flatF->takeName(F);
  3614. F->eraseFromParent();
  3615. }
  3616. // SROA globals and allocas.
  3617. SROAGlobalAndAllocas(*m_pHLModule, m_HasDbgInfo);
  3618. // Move up allocas that might have been pushed down by instruction inserts
  3619. for (Function &F : M) {
  3620. if (F.isDeclaration())
  3621. continue;
  3622. Instruction *insertPt = nullptr;
  3623. // SROA only potentially "incorrectly" inserts non-allocas into the entry block.
  3624. for (llvm::Instruction &I : F.getEntryBlock()) {
  3625. if (!insertPt) {
  3626. // Find the first non-alloca to move the allocas above
  3627. if (!isa<AllocaInst>(I) && !isa<DbgInfoIntrinsic>(I))
  3628. insertPt = &I;
  3629. } else if (isa<AllocaInst>(I)) {
  3630. // Move any alloca to before the first non-alloca
  3631. I.moveBefore(insertPt);
  3632. }
  3633. }
  3634. }
  3635. return true;
  3636. }
  3637. private:
  3638. void DeleteDeadInstructions();
  3639. void preprocessArgUsedInCall(Function *F);
  3640. void moveFunctionBody(Function *F, Function *flatF);
  3641. void replaceCall(Function *F, Function *flatF);
  3642. void createFlattenedFunction(Function *F);
  3643. void
  3644. flattenArgument(Function *F, Value *Arg, bool bForParam,
  3645. DxilParameterAnnotation &paramAnnotation,
  3646. std::vector<Value *> &FlatParamList,
  3647. std::vector<DxilParameterAnnotation> &FlatRetAnnotationList,
  3648. BasicBlock *EntryBlock, DbgDeclareInst *DDI);
  3649. Value *castResourceArgIfRequired(Value *V, Type *Ty, bool bOut,
  3650. DxilParamInputQual inputQual,
  3651. IRBuilder<> &Builder);
  3652. Value *castArgumentIfRequired(Value *V, Type *Ty, bool bOut,
  3653. DxilParamInputQual inputQual,
  3654. DxilFieldAnnotation &annotation,
  3655. IRBuilder<> &Builder,
  3656. DxilTypeSystem &TypeSys);
3657. // Replace uses of a parameter whose type changed during flattening.
  3658. // Also add information to Arg if required.
  3659. void replaceCastParameter(Value *NewParam, Value *OldParam, Function &F,
  3660. Argument *Arg, const DxilParamInputQual inputQual,
  3661. IRBuilder<> &Builder);
  3662. void allocateSemanticIndex(
  3663. std::vector<DxilParameterAnnotation> &FlatAnnotationList,
  3664. unsigned startArgIndex, llvm::StringMap<Type *> &semanticTypeMap);
  3665. //static std::vector<Value*> GetConstValueIdxList(IRBuilder<>& builder, std::vector<unsigned> idxlist);
  3666. /// DeadInsts - Keep track of instructions we have made dead, so that
  3667. /// we can remove them after we are done working.
  3668. SmallVector<Value *, 32> DeadInsts;
3669. // Map from the original function to the flattened version.
  3670. MapVector<Function *, Function *> funcMap; // Need deterministic order of iteration
3671. // Map from the original arg/param to the flattened/cast version.
  3672. std::unordered_map<Value *, std::pair<Value*, DxilParamInputQual>> castParamMap;
3673. // Map from the first element of a vector to the list of all elements of the vector.
  3674. std::unordered_map<Value *, SmallVector<Value*, 4> > vectorEltsMap;
3675. // Set of row-major matrix parameters.
  3676. std::unordered_set<Value *> castRowMajorParamMap;
  3677. bool m_HasDbgInfo;
  3678. };
3679. // When replacing an aggregate with its scalar elements,
  3680. // the first element will preserve the original semantic,
  3681. // and the subsequent ones will temporarily use this value.
  3682. // We then run a pass to fix the semantics and properly renumber them
  3683. // once the aggregate has been fully expanded.
  3684. //
  3685. // For example:
  3686. // struct Foo { float a; float b; };
  3687. // void main(Foo foo : TEXCOORD0, float bar : TEXCOORD0)
  3688. //
  3689. // Will be expanded to
  3690. // void main(float a : TEXCOORD0, float b : *, float bar : TEXCOORD0)
  3691. //
  3692. // And then fixed up to
  3693. // void main(float a : TEXCOORD0, float b : TEXCOORD1, float bar : TEXCOORD0)
  3694. //
  3695. // (which will later on fail validation due to duplicate semantics).
  3696. constexpr const char *ContinuedPseudoSemantic = "*";
  3697. }
  3698. char SROA_Parameter_HLSL::ID = 0;
  3699. INITIALIZE_PASS(SROA_Parameter_HLSL, "scalarrepl-param-hlsl",
  3700. "Scalar Replacement of Aggregates HLSL (parameters)", false,
  3701. false)
  3702. void SROA_Parameter_HLSL::RewriteBitcastWithIdenticalStructs(Function *F) {
  3703. if (F->isDeclaration())
  3704. return;
  3705. // Gather list of bitcast involving src and dest structs with identical layout
  3706. std::vector<BitCastInst*> worklist;
  3707. for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) {
  3708. if (BitCastInst *BCI = dyn_cast<BitCastInst>(&*I)) {
  3709. Type *DstTy = BCI->getDestTy();
  3710. Type *SrcTy = BCI->getSrcTy();
  3711. if(ArePointersToStructsOfIdenticalLayouts(DstTy, SrcTy))
  3712. worklist.push_back(BCI);
  3713. }
  3714. }
  3715. // Replace bitcast involving src and dest structs with identical layout
  3716. while (!worklist.empty()) {
  3717. BitCastInst *BCI = worklist.back();
  3718. worklist.pop_back();
  3719. RewriteBitcastWithIdenticalStructs(BCI);
  3720. }
  3721. }
  3722. void SROA_Parameter_HLSL::RewriteBitcastWithIdenticalStructs(BitCastInst *BCI) {
  3723. StructType *srcStTy = cast<StructType>(BCI->getSrcTy()->getPointerElementType());
  3724. StructType *destStTy = cast<StructType>(BCI->getDestTy()->getPointerElementType());
  3725. Value* srcPtr = BCI->getOperand(0);
  3726. IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(BCI->getParent()->getParent()));
  3727. AllocaInst *destPtr = AllocaBuilder.CreateAlloca(destStTy);
  3728. IRBuilder<> InstBuilder(BCI);
  3729. std::vector<unsigned> idxlist = { 0 };
  3730. CopyElementsOfStructsWithIdenticalLayout(InstBuilder, destPtr, srcPtr, srcStTy, idxlist);
  3731. BCI->replaceAllUsesWith(destPtr);
  3732. BCI->eraseFromParent();
  3733. }
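// Illustrative sketch of the rewrite above (added note, not from the original source;
// the IR names are hypothetical): a cast between structurally identical structs such as
//   %0 = bitcast %struct.A* %src to %struct.B*
// is replaced by
//   %tmp = alloca %struct.B        ; at the function's alloca insertion point
//   ... element-wise copies of %src into %tmp ...
//   ; all uses of %0 now use %tmp instead
// so no pointer-type punning between the two struct types survives into SROA.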
  3734. /// DeleteDeadInstructions - Erase instructions on the DeadInstrs list,
3735. /// recursively including all their operands that become trivially dead.
  3736. void SROA_Parameter_HLSL::DeleteDeadInstructions() {
  3737. while (!DeadInsts.empty()) {
  3738. Instruction *I = cast<Instruction>(DeadInsts.pop_back_val());
  3739. for (User::op_iterator OI = I->op_begin(), E = I->op_end(); OI != E; ++OI)
  3740. if (Instruction *U = dyn_cast<Instruction>(*OI)) {
  3741. // Zero out the operand and see if it becomes trivially dead.
  3742. // (But, don't add allocas to the dead instruction list -- they are
3743. // already on the worklist and will be deleted separately.)
  3744. *OI = nullptr;
  3745. if (isInstructionTriviallyDead(U) && !isa<AllocaInst>(U))
  3746. DeadInsts.push_back(U);
  3747. }
  3748. I->eraseFromParent();
  3749. }
  3750. }
  3751. static DxilFieldAnnotation &GetEltAnnotation(Type *Ty, unsigned idx, DxilFieldAnnotation &annotation, DxilTypeSystem &dxilTypeSys) {
  3752. while (Ty->isArrayTy())
  3753. Ty = Ty->getArrayElementType();
  3754. if (StructType *ST = dyn_cast<StructType>(Ty)) {
  3755. if (HLMatrixType::isa(Ty))
  3756. return annotation;
  3757. DxilStructAnnotation *SA = dxilTypeSys.GetStructAnnotation(ST);
  3758. if (SA) {
  3759. DxilFieldAnnotation &FA = SA->GetFieldAnnotation(idx);
  3760. return FA;
  3761. }
  3762. }
  3763. return annotation;
  3764. }
  3765. // Note: Semantic index allocation.
3766. // Semantic indices are allocated based on linear layout.
3767. // For the following code:
  3768. /*
  3769. struct S {
  3770. float4 m;
  3771. float4 m2;
  3772. };
  3773. S s[2] : semantic;
  3774. struct S2 {
  3775. float4 m[2];
  3776. float4 m2[2];
  3777. };
  3778. S2 s2 : semantic;
  3779. */
  3780. // The semantic index is like this:
  3781. // s[0].m : semantic0
  3782. // s[0].m2 : semantic1
  3783. // s[1].m : semantic2
  3784. // s[1].m2 : semantic3
  3785. // s2.m[0] : semantic0
  3786. // s2.m[1] : semantic1
  3787. // s2.m2[0] : semantic2
  3788. // s2.m2[1] : semantic3
3789. // But when the argument is flattened, the result is like this:
  3790. // float4 s_m[2], float4 s_m2[2].
  3791. // float4 s2_m[2], float4 s2_m2[2].
  3792. // To do the allocation, need to map from each element to its flattened argument.
  3793. // Say arg index of float4 s_m[2] is 0, float4 s_m2[2] is 1.
  3794. // Need to get 0 from s[0].m and s[1].m, get 1 from s[0].m2 and s[1].m2.
3795. // Allocate the arguments with the same semantic string starting from the type where the
3796. // semantic starts (S2 for s2.m[2] and s2.m2[2]).
3797. // Iterate over each element of the type, save the semantic index, and update it.
3798. // The map from element to arg (s[0].m2 -> s.m2[2]) is done by argIdx.
3799. // ArgIdx only increments by 1 when a struct field is finished.
  3800. static unsigned AllocateSemanticIndex(
  3801. Type *Ty, unsigned &semIndex, unsigned argIdx, unsigned endArgIdx,
  3802. std::vector<DxilParameterAnnotation> &FlatAnnotationList) {
  3803. if (Ty->isPointerTy()) {
  3804. return AllocateSemanticIndex(Ty->getPointerElementType(), semIndex, argIdx,
  3805. endArgIdx, FlatAnnotationList);
  3806. } else if (Ty->isArrayTy()) {
  3807. unsigned arraySize = Ty->getArrayNumElements();
  3808. unsigned updatedArgIdx = argIdx;
  3809. Type *EltTy = Ty->getArrayElementType();
  3810. for (unsigned i = 0; i < arraySize; i++) {
  3811. updatedArgIdx = AllocateSemanticIndex(EltTy, semIndex, argIdx, endArgIdx,
  3812. FlatAnnotationList);
  3813. }
  3814. return updatedArgIdx;
  3815. } else if (Ty->isStructTy() && !HLMatrixType::isa(Ty)) {
  3816. unsigned fieldsCount = Ty->getStructNumElements();
  3817. for (unsigned i = 0; i < fieldsCount; i++) {
  3818. Type *EltTy = Ty->getStructElementType(i);
  3819. argIdx = AllocateSemanticIndex(EltTy, semIndex, argIdx, endArgIdx,
  3820. FlatAnnotationList);
  3821. if (!(EltTy->isStructTy() && !HLMatrixType::isa(EltTy))) {
  3822. // Update argIdx only when it is a leaf node.
  3823. argIdx++;
  3824. }
  3825. }
  3826. return argIdx;
  3827. } else {
  3828. DXASSERT(argIdx < endArgIdx, "arg index out of bound");
  3829. DxilParameterAnnotation &paramAnnotation = FlatAnnotationList[argIdx];
  3830. // Get element size.
  3831. unsigned rows = 1;
  3832. if (paramAnnotation.HasMatrixAnnotation()) {
  3833. const DxilMatrixAnnotation &matrix =
  3834. paramAnnotation.GetMatrixAnnotation();
  3835. if (matrix.Orientation == MatrixOrientation::RowMajor) {
  3836. rows = matrix.Rows;
  3837. } else {
  3838. DXASSERT_NOMSG(matrix.Orientation == MatrixOrientation::ColumnMajor);
  3839. rows = matrix.Cols;
  3840. }
  3841. }
  3842. // Save semIndex.
  3843. for (unsigned i = 0; i < rows; i++)
  3844. paramAnnotation.AppendSemanticIndex(semIndex + i);
  3845. // Update semIndex.
  3846. semIndex += rows;
  3847. return argIdx;
  3848. }
  3849. }
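// Worked example (added note, not from the original source; the semantic name is
// hypothetical): for a flattened argument annotated "float3x4 m : TEXCOORD2" with
// row-major orientation, rows == 3, so semantic indices 2, 3 and 4 are appended to its
// annotation and semIndex advances to 5; the next element in the same semantic group
// then starts at TEXCOORD5. Scalars and vectors behave like a single row.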
  3850. void SROA_Parameter_HLSL::allocateSemanticIndex(
  3851. std::vector<DxilParameterAnnotation> &FlatAnnotationList,
  3852. unsigned startArgIndex, llvm::StringMap<Type *> &semanticTypeMap) {
  3853. unsigned endArgIndex = FlatAnnotationList.size();
  3854. // Allocate semantic index.
  3855. for (unsigned i = startArgIndex; i < endArgIndex; ++i) {
  3856. // Group by semantic names.
  3857. DxilParameterAnnotation &flatParamAnnotation = FlatAnnotationList[i];
  3858. const std::string &semantic = flatParamAnnotation.GetSemanticString();
  3859. // If semantic is undefined, an error will be emitted elsewhere. For now,
  3860. // we should avoid asserting.
  3861. if (semantic.empty())
  3862. continue;
  3863. StringRef baseSemName; // The 'FOO' in 'FOO1'.
  3864. uint32_t semIndex; // The '1' in 'FOO1'
  3865. // Split semName and index.
  3866. Semantic::DecomposeNameAndIndex(semantic, &baseSemName, &semIndex);
  3867. unsigned semGroupEnd = i + 1;
  3868. while (semGroupEnd < endArgIndex &&
  3869. FlatAnnotationList[semGroupEnd].GetSemanticString() == ContinuedPseudoSemantic) {
  3870. FlatAnnotationList[semGroupEnd].SetSemanticString(baseSemName);
  3871. ++semGroupEnd;
  3872. }
  3873. DXASSERT(semanticTypeMap.count(semantic) > 0, "Must has semantic type");
  3874. Type *semanticTy = semanticTypeMap[semantic];
  3875. AllocateSemanticIndex(semanticTy, semIndex, /*argIdx*/ i,
  3876. /*endArgIdx*/ semGroupEnd, FlatAnnotationList);
  3877. // Update i.
  3878. i = semGroupEnd - 1;
  3879. }
  3880. }
  3881. //
  3882. // Cast parameters.
  3883. //
  3884. static void CopyHandleToResourcePtr(Value *Handle, Value *ResPtr, HLModule &HLM,
  3885. IRBuilder<> &Builder) {
  3886. // Cast it to resource.
  3887. Type *ResTy = ResPtr->getType()->getPointerElementType();
  3888. Value *Res = HLM.EmitHLOperationCall(Builder, HLOpcodeGroup::HLCast,
  3889. (unsigned)HLCastOpcode::HandleToResCast,
  3890. ResTy, {Handle}, *HLM.GetModule());
  3891. // Store casted resource to OldArg.
  3892. Builder.CreateStore(Res, ResPtr);
  3893. }
  3894. static void CopyHandlePtrToResourcePtr(Value *HandlePtr, Value *ResPtr,
  3895. HLModule &HLM, IRBuilder<> &Builder) {
  3896. // Load the handle.
  3897. Value *Handle = Builder.CreateLoad(HandlePtr);
  3898. CopyHandleToResourcePtr(Handle, ResPtr, HLM, Builder);
  3899. }
  3900. static Value *CastResourcePtrToHandle(Value *Res, Type *HandleTy, HLModule &HLM,
  3901. IRBuilder<> &Builder) {
  3902. // Load OldArg.
  3903. Value *LdRes = Builder.CreateLoad(Res);
  3904. Value *Handle = HLM.EmitHLOperationCall(
  3905. Builder, HLOpcodeGroup::HLCreateHandle,
  3906. /*opcode*/ 0, HandleTy, {LdRes}, *HLM.GetModule());
  3907. return Handle;
  3908. }
  3909. static void CopyResourcePtrToHandlePtr(Value *Res, Value *HandlePtr,
  3910. HLModule &HLM, IRBuilder<> &Builder) {
  3911. Type *HandleTy = HandlePtr->getType()->getPointerElementType();
  3912. Value *Handle = CastResourcePtrToHandle(Res, HandleTy, HLM, Builder);
  3913. Builder.CreateStore(Handle, HandlePtr);
  3914. }
  3915. static void CopyVectorPtrToEltsPtr(Value *VecPtr, ArrayRef<Value *> elts,
  3916. unsigned vecSize, IRBuilder<> &Builder) {
  3917. Value *Vec = Builder.CreateLoad(VecPtr);
  3918. for (unsigned i = 0; i < vecSize; i++) {
  3919. Value *Elt = Builder.CreateExtractElement(Vec, i);
  3920. Builder.CreateStore(Elt, elts[i]);
  3921. }
  3922. }
  3923. static void CopyEltsPtrToVectorPtr(ArrayRef<Value *> elts, Value *VecPtr,
  3924. Type *VecTy, unsigned vecSize,
  3925. IRBuilder<> &Builder) {
  3926. Value *Vec = UndefValue::get(VecTy);
  3927. for (unsigned i = 0; i < vecSize; i++) {
  3928. Value *Elt = Builder.CreateLoad(elts[i]);
  3929. Vec = Builder.CreateInsertElement(Vec, Elt, i);
  3930. }
  3931. Builder.CreateStore(Vec, VecPtr);
  3932. }
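// Usage note (added, not from the original source): these two helpers are the inverse
// halves of vector-parameter flattening in replaceCastParameter below. When, for
// example, an inout float4 parameter has been split into four scalar pointers,
// CopyEltsPtrToVectorPtr packs the scalars back into the float4 at function entry, and
// CopyVectorPtrToEltsPtr scatters the float4 back out to the scalar pointers in front
// of every return.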
  3933. static void CopyMatToArrayPtr(Value *Mat, Value *ArrayPtr,
  3934. unsigned arrayBaseIdx, HLModule &HLM,
  3935. IRBuilder<> &Builder, bool bRowMajor) {
  3936. // Mat val is row major.
  3937. HLMatrixType MatTy = HLMatrixType::cast(Mat->getType());
  3938. Type *VecTy = MatTy.getLoweredVectorTypeForReg();
  3939. Value *Vec =
  3940. HLM.EmitHLOperationCall(Builder, HLOpcodeGroup::HLCast,
  3941. (unsigned)HLCastOpcode::RowMatrixToVecCast, VecTy,
  3942. {Mat}, *HLM.GetModule());
  3943. Value *zero = Builder.getInt32(0);
  3944. for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
  3945. for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
  3946. unsigned matIdx = MatTy.getColumnMajorIndex(r, c);
  3947. Value *Elt = Builder.CreateExtractElement(Vec, matIdx);
  3948. Value *Ptr = Builder.CreateInBoundsGEP(
  3949. ArrayPtr, {zero, Builder.getInt32(arrayBaseIdx + matIdx)});
  3950. Builder.CreateStore(Elt, Ptr);
  3951. }
  3952. }
  3953. }
  3954. static void CopyMatPtrToArrayPtr(Value *MatPtr, Value *ArrayPtr,
  3955. unsigned arrayBaseIdx, HLModule &HLM,
  3956. IRBuilder<> &Builder, bool bRowMajor) {
  3957. Type *Ty = MatPtr->getType()->getPointerElementType();
  3958. Value *Mat = nullptr;
  3959. if (bRowMajor) {
  3960. Mat = HLM.EmitHLOperationCall(Builder, HLOpcodeGroup::HLMatLoadStore,
  3961. (unsigned)HLMatLoadStoreOpcode::RowMatLoad,
  3962. Ty, {MatPtr}, *HLM.GetModule());
  3963. } else {
  3964. Mat = HLM.EmitHLOperationCall(Builder, HLOpcodeGroup::HLMatLoadStore,
  3965. (unsigned)HLMatLoadStoreOpcode::ColMatLoad,
  3966. Ty, {MatPtr}, *HLM.GetModule());
  3967. // Matrix value should be row major.
  3968. Mat = HLM.EmitHLOperationCall(Builder, HLOpcodeGroup::HLCast,
  3969. (unsigned)HLCastOpcode::ColMatrixToRowMatrix,
  3970. Ty, {Mat}, *HLM.GetModule());
  3971. }
  3972. CopyMatToArrayPtr(Mat, ArrayPtr, arrayBaseIdx, HLM, Builder, bRowMajor);
  3973. }
  3974. static Value *LoadArrayPtrToMat(Value *ArrayPtr, unsigned arrayBaseIdx,
  3975. Type *Ty, HLModule &HLM, IRBuilder<> &Builder,
  3976. bool bRowMajor) {
  3977. HLMatrixType MatTy = HLMatrixType::cast(Ty);
  3978. // HLInit operands are in row major.
  3979. SmallVector<Value *, 16> Elts;
  3980. Value *zero = Builder.getInt32(0);
  3981. for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
  3982. for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
  3983. unsigned matIdx = bRowMajor
  3984. ? MatTy.getRowMajorIndex(r, c)
  3985. : MatTy.getColumnMajorIndex(r, c);
  3986. Value *Ptr = Builder.CreateInBoundsGEP(
  3987. ArrayPtr, {zero, Builder.getInt32(arrayBaseIdx + matIdx)});
  3988. Value *Elt = Builder.CreateLoad(Ptr);
  3989. Elts.emplace_back(Elt);
  3990. }
  3991. }
  3992. return HLM.EmitHLOperationCall(Builder, HLOpcodeGroup::HLInit,
  3993. /*opcode*/ 0, Ty, {Elts}, *HLM.GetModule());
  3994. }
  3995. static void CopyArrayPtrToMatPtr(Value *ArrayPtr, unsigned arrayBaseIdx,
  3996. Value *MatPtr, HLModule &HLM,
  3997. IRBuilder<> &Builder, bool bRowMajor) {
  3998. Type *Ty = MatPtr->getType()->getPointerElementType();
  3999. Value *Mat =
  4000. LoadArrayPtrToMat(ArrayPtr, arrayBaseIdx, Ty, HLM, Builder, bRowMajor);
  4001. if (bRowMajor) {
  4002. HLM.EmitHLOperationCall(Builder, HLOpcodeGroup::HLMatLoadStore,
  4003. (unsigned)HLMatLoadStoreOpcode::RowMatStore, Ty,
  4004. {MatPtr, Mat}, *HLM.GetModule());
  4005. } else {
  4006. // Mat is row major.
  4007. // Cast it to col major before store.
  4008. Mat = HLM.EmitHLOperationCall(Builder, HLOpcodeGroup::HLCast,
  4009. (unsigned)HLCastOpcode::RowMatrixToColMatrix,
  4010. Ty, {Mat}, *HLM.GetModule());
  4011. HLM.EmitHLOperationCall(Builder, HLOpcodeGroup::HLMatLoadStore,
  4012. (unsigned)HLMatLoadStoreOpcode::ColMatStore, Ty,
  4013. {MatPtr, Mat}, *HLM.GetModule());
  4014. }
  4015. }
  4016. using CopyFunctionTy = void(Value *FromPtr, Value *ToPtr, HLModule &HLM,
  4017. Type *HandleTy, IRBuilder<> &Builder,
  4018. bool bRowMajor);
  4019. static void
  4020. CastCopyArrayMultiDimTo1Dim(Value *FromArray, Value *ToArray, Type *CurFromTy,
  4021. std::vector<Value *> &idxList, unsigned calcIdx,
  4022. Type *HandleTy, HLModule &HLM, IRBuilder<> &Builder,
  4023. CopyFunctionTy CastCopyFn, bool bRowMajor) {
  4024. if (CurFromTy->isVectorTy()) {
  4025. // Copy vector to array.
  4026. Value *FromPtr = Builder.CreateInBoundsGEP(FromArray, idxList);
  4027. Value *V = Builder.CreateLoad(FromPtr);
  4028. unsigned vecSize = CurFromTy->getVectorNumElements();
  4029. Value *zeroIdx = Builder.getInt32(0);
  4030. for (unsigned i = 0; i < vecSize; i++) {
  4031. Value *ToPtr = Builder.CreateInBoundsGEP(
  4032. ToArray, {zeroIdx, Builder.getInt32(calcIdx++)});
  4033. Value *Elt = Builder.CreateExtractElement(V, i);
  4034. Builder.CreateStore(Elt, ToPtr);
  4035. }
  4036. } else if (HLMatrixType MatTy = HLMatrixType::dyn_cast(CurFromTy)) {
  4037. // Copy matrix to array.
  4038. // Calculate the offset.
  4039. unsigned offset = calcIdx * MatTy.getNumElements();
  4040. Value *FromPtr = Builder.CreateInBoundsGEP(FromArray, idxList);
  4041. CopyMatPtrToArrayPtr(FromPtr, ToArray, offset, HLM, Builder, bRowMajor);
  4042. } else if (!CurFromTy->isArrayTy()) {
  4043. Value *FromPtr = Builder.CreateInBoundsGEP(FromArray, idxList);
  4044. Value *ToPtr = Builder.CreateInBoundsGEP(
  4045. ToArray, {Builder.getInt32(0), Builder.getInt32(calcIdx)});
  4046. CastCopyFn(FromPtr, ToPtr, HLM, HandleTy, Builder, bRowMajor);
  4047. } else {
  4048. unsigned size = CurFromTy->getArrayNumElements();
  4049. Type *FromEltTy = CurFromTy->getArrayElementType();
  4050. for (unsigned i = 0; i < size; i++) {
  4051. idxList.push_back(Builder.getInt32(i));
  4052. unsigned idx = calcIdx * size + i;
  4053. CastCopyArrayMultiDimTo1Dim(FromArray, ToArray, FromEltTy, idxList, idx,
  4054. HandleTy, HLM, Builder, CastCopyFn,
  4055. bRowMajor);
  4056. idxList.pop_back();
  4057. }
  4058. }
  4059. }
  4060. static void
  4061. CastCopyArray1DimToMultiDim(Value *FromArray, Value *ToArray, Type *CurToTy,
  4062. std::vector<Value *> &idxList, unsigned calcIdx,
  4063. Type *HandleTy, HLModule &HLM, IRBuilder<> &Builder,
  4064. CopyFunctionTy CastCopyFn, bool bRowMajor) {
  4065. if (CurToTy->isVectorTy()) {
  4066. // Copy array to vector.
  4067. Value *V = UndefValue::get(CurToTy);
  4068. unsigned vecSize = CurToTy->getVectorNumElements();
  4069. // Calculate the offset.
  4070. unsigned offset = calcIdx * vecSize;
  4071. Value *zeroIdx = Builder.getInt32(0);
  4072. Value *ToPtr = Builder.CreateInBoundsGEP(ToArray, idxList);
  4073. for (unsigned i = 0; i < vecSize; i++) {
  4074. Value *FromPtr = Builder.CreateInBoundsGEP(
  4075. FromArray, {zeroIdx, Builder.getInt32(offset++)});
  4076. Value *Elt = Builder.CreateLoad(FromPtr);
  4077. V = Builder.CreateInsertElement(V, Elt, i);
  4078. }
  4079. Builder.CreateStore(V, ToPtr);
  4080. } else if (HLMatrixType MatTy = HLMatrixType::cast(CurToTy)) {
  4081. // Copy array to matrix.
  4082. // Calculate the offset.
  4083. unsigned offset = calcIdx * MatTy.getNumElements();
  4084. Value *ToPtr = Builder.CreateInBoundsGEP(ToArray, idxList);
  4085. CopyArrayPtrToMatPtr(FromArray, offset, ToPtr, HLM, Builder, bRowMajor);
  4086. } else if (!CurToTy->isArrayTy()) {
  4087. Value *FromPtr = Builder.CreateInBoundsGEP(
  4088. FromArray, {Builder.getInt32(0), Builder.getInt32(calcIdx)});
  4089. Value *ToPtr = Builder.CreateInBoundsGEP(ToArray, idxList);
  4090. CastCopyFn(FromPtr, ToPtr, HLM, HandleTy, Builder, bRowMajor);
  4091. } else {
  4092. unsigned size = CurToTy->getArrayNumElements();
  4093. Type *ToEltTy = CurToTy->getArrayElementType();
  4094. for (unsigned i = 0; i < size; i++) {
  4095. idxList.push_back(Builder.getInt32(i));
  4096. unsigned idx = calcIdx * size + i;
  4097. CastCopyArray1DimToMultiDim(FromArray, ToArray, ToEltTy, idxList, idx,
  4098. HandleTy, HLM, Builder, CastCopyFn,
  4099. bRowMajor);
  4100. idxList.pop_back();
  4101. }
  4102. }
  4103. }
  4104. static void CastCopyOldPtrToNewPtr(Value *OldPtr, Value *NewPtr, HLModule &HLM,
  4105. Type *HandleTy, IRBuilder<> &Builder,
  4106. bool bRowMajor) {
  4107. Type *NewTy = NewPtr->getType()->getPointerElementType();
  4108. Type *OldTy = OldPtr->getType()->getPointerElementType();
  4109. if (NewTy == HandleTy) {
  4110. CopyResourcePtrToHandlePtr(OldPtr, NewPtr, HLM, Builder);
  4111. } else if (OldTy->isVectorTy()) {
  4112. // Copy vector to array.
  4113. Value *V = Builder.CreateLoad(OldPtr);
  4114. unsigned vecSize = OldTy->getVectorNumElements();
  4115. Value *zeroIdx = Builder.getInt32(0);
  4116. for (unsigned i = 0; i < vecSize; i++) {
  4117. Value *EltPtr = Builder.CreateGEP(NewPtr, {zeroIdx, Builder.getInt32(i)});
  4118. Value *Elt = Builder.CreateExtractElement(V, i);
  4119. Builder.CreateStore(Elt, EltPtr);
  4120. }
  4121. } else if (HLMatrixType::isa(OldTy)) {
  4122. CopyMatPtrToArrayPtr(OldPtr, NewPtr, /*arrayBaseIdx*/ 0, HLM, Builder,
  4123. bRowMajor);
  4124. } else if (OldTy->isArrayTy()) {
  4125. std::vector<Value *> idxList;
  4126. idxList.emplace_back(Builder.getInt32(0));
  4127. CastCopyArrayMultiDimTo1Dim(OldPtr, NewPtr, OldTy, idxList, /*calcIdx*/ 0,
  4128. HandleTy, HLM, Builder, CastCopyOldPtrToNewPtr,
  4129. bRowMajor);
  4130. }
  4131. }
  4132. static void CastCopyNewPtrToOldPtr(Value *NewPtr, Value *OldPtr, HLModule &HLM,
  4133. Type *HandleTy, IRBuilder<> &Builder,
  4134. bool bRowMajor) {
  4135. Type *NewTy = NewPtr->getType()->getPointerElementType();
  4136. Type *OldTy = OldPtr->getType()->getPointerElementType();
  4137. if (NewTy == HandleTy) {
  4138. CopyHandlePtrToResourcePtr(NewPtr, OldPtr, HLM, Builder);
  4139. } else if (OldTy->isVectorTy()) {
  4140. // Copy array to vector.
  4141. Value *V = UndefValue::get(OldTy);
  4142. unsigned vecSize = OldTy->getVectorNumElements();
  4143. Value *zeroIdx = Builder.getInt32(0);
  4144. for (unsigned i = 0; i < vecSize; i++) {
  4145. Value *EltPtr = Builder.CreateGEP(NewPtr, {zeroIdx, Builder.getInt32(i)});
  4146. Value *Elt = Builder.CreateLoad(EltPtr);
  4147. V = Builder.CreateInsertElement(V, Elt, i);
  4148. }
  4149. Builder.CreateStore(V, OldPtr);
  4150. } else if (HLMatrixType::isa(OldTy)) {
  4151. CopyArrayPtrToMatPtr(NewPtr, /*arrayBaseIdx*/ 0, OldPtr, HLM, Builder,
  4152. bRowMajor);
  4153. } else if (OldTy->isArrayTy()) {
  4154. std::vector<Value *> idxList;
  4155. idxList.emplace_back(Builder.getInt32(0));
  4156. CastCopyArray1DimToMultiDim(NewPtr, OldPtr, OldTy, idxList, /*calcIdx*/ 0,
  4157. HandleTy, HLM, Builder, CastCopyNewPtrToOldPtr,
  4158. bRowMajor);
  4159. }
  4160. }
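// Illustrative sketch (added note, not from the original source): for an inout parameter
// whose old type is float3x3, the new parameter is assumed here to be an array of its
// nine elements. CastCopyNewPtrToOldPtr rebuilds the matrix from that array and stores
// it to the old local at entry, while CastCopyOldPtrToNewPtr loads the matrix and writes
// it back element by element into the array in front of every return; resources take the
// handle<->resource path instead.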
  4161. void SROA_Parameter_HLSL::replaceCastParameter(
  4162. Value *NewParam, Value *OldParam, Function &F, Argument *Arg,
  4163. const DxilParamInputQual inputQual, IRBuilder<> &Builder) {
  4164. Type *HandleTy = m_pHLModule->GetOP()->GetHandleType();
  4165. Type *NewTy = NewParam->getType();
  4166. Type *OldTy = OldParam->getType();
  4167. bool bIn = inputQual == DxilParamInputQual::Inout ||
  4168. inputQual == DxilParamInputQual::In;
  4169. bool bOut = inputQual == DxilParamInputQual::Inout ||
  4170. inputQual == DxilParamInputQual::Out;
4171. // Make sure the InsertPoint is after the OldParam instruction.
  4172. if (Instruction *I = dyn_cast<Instruction>(OldParam)) {
  4173. Builder.SetInsertPoint(I->getNextNode());
  4174. }
  4175. if (DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(OldParam)) {
  4176. // Add debug info to new param.
  4177. DIBuilder DIB(*F.getParent(), /*AllowUnresolved*/ false);
  4178. DIExpression *DDIExp = DDI->getExpression();
  4179. DIB.insertDeclare(NewParam, DDI->getVariable(), DDIExp, DDI->getDebugLoc(),
  4180. Builder.GetInsertPoint());
  4181. }
  4182. if (isa<Argument>(OldParam) && OldTy->isPointerTy()) {
  4183. // OldParam will be removed with Old function.
  4184. // Create alloca to replace it.
  4185. IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(&F));
  4186. Value *AllocParam = AllocaBuilder.CreateAlloca(OldTy->getPointerElementType());
  4187. OldParam->replaceAllUsesWith(AllocParam);
  4188. OldParam = AllocParam;
  4189. }
  4190. if (NewTy == HandleTy) {
  4191. CopyHandleToResourcePtr(NewParam, OldParam, *m_pHLModule, Builder);
  4192. } else if (vectorEltsMap.count(NewParam)) {
  4193. // Vector is flattened to scalars.
  4194. Type *VecTy = OldTy;
  4195. if (VecTy->isPointerTy())
  4196. VecTy = VecTy->getPointerElementType();
  4197. // Flattened vector.
  4198. SmallVector<Value *, 4> &elts = vectorEltsMap[NewParam];
  4199. unsigned vecSize = elts.size();
  4200. if (NewTy->isPointerTy()) {
  4201. if (bIn) {
  4202. // Copy NewParam to OldParam at entry.
  4203. CopyEltsPtrToVectorPtr(elts, OldParam, VecTy, vecSize, Builder);
  4204. }
  4205. // bOut must be true here.
  4206. // Store the OldParam to NewParam before every return.
  4207. for (auto &BB : F.getBasicBlockList()) {
  4208. if (ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator())) {
  4209. IRBuilder<> RetBuilder(RI);
  4210. CopyVectorPtrToEltsPtr(OldParam, elts, vecSize, RetBuilder);
  4211. }
  4212. }
  4213. } else {
  4214. // Must be in parameter.
  4215. // Copy NewParam to OldParam at entry.
  4216. Value *Vec = UndefValue::get(VecTy);
  4217. for (unsigned i = 0; i < vecSize; i++) {
  4218. Vec = Builder.CreateInsertElement(Vec, elts[i], i);
  4219. }
  4220. if (OldTy->isPointerTy()) {
  4221. Builder.CreateStore(Vec, OldParam);
  4222. } else {
  4223. OldParam->replaceAllUsesWith(Vec);
  4224. }
  4225. }
  4226. // Don't need elts anymore.
  4227. vectorEltsMap.erase(NewParam);
  4228. } else if (!NewTy->isPointerTy()) {
  4229. // Ptr param is cast to non-ptr param.
  4230. // Must be in param.
  4231. // Store NewParam to OldParam at entry.
  4232. Builder.CreateStore(NewParam, OldParam);
  4233. } else if (HLMatrixType::isa(OldTy)) {
  4234. bool bRowMajor = castRowMajorParamMap.count(NewParam);
  4235. Value *Mat = LoadArrayPtrToMat(NewParam, /*arrayBaseIdx*/ 0, OldTy,
  4236. *m_pHLModule, Builder, bRowMajor);
  4237. OldParam->replaceAllUsesWith(Mat);
  4238. } else {
  4239. bool bRowMajor = castRowMajorParamMap.count(NewParam);
  4240. // NewTy is pointer type.
  4241. if (bIn) {
  4242. // Copy NewParam to OldParam at entry.
  4243. CastCopyNewPtrToOldPtr(NewParam, OldParam, *m_pHLModule, HandleTy,
  4244. Builder, bRowMajor);
  4245. }
  4246. if (bOut) {
  4247. // Store the OldParam to NewParam before every return.
  4248. for (auto &BB : F.getBasicBlockList()) {
  4249. if (ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator())) {
  4250. IRBuilder<> RetBuilder(RI);
  4251. CastCopyOldPtrToNewPtr(OldParam, NewParam, *m_pHLModule, HandleTy,
  4252. RetBuilder, bRowMajor);
  4253. }
  4254. }
  4255. }
  4256. }
  4257. }
  4258. Value *SROA_Parameter_HLSL::castResourceArgIfRequired(
  4259. Value *V, Type *Ty, bool bOut,
  4260. DxilParamInputQual inputQual,
  4261. IRBuilder<> &Builder) {
  4262. Type *HandleTy = m_pHLModule->GetOP()->GetHandleType();
  4263. Module &M = *m_pHLModule->GetModule();
  4264. IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(Builder.GetInsertPoint()));
  4265. // Lower resource type to handle ty.
  4266. if (dxilutil::IsHLSLResourceType(Ty)) {
  4267. Value *Res = V;
  4268. if (!bOut) {
  4269. Value *LdRes = Builder.CreateLoad(Res);
  4270. V = m_pHLModule->EmitHLOperationCall(Builder,
  4271. HLOpcodeGroup::HLCreateHandle,
  4272. /*opcode*/ 0, HandleTy, { LdRes }, M);
  4273. }
  4274. else {
  4275. V = AllocaBuilder.CreateAlloca(HandleTy);
  4276. }
  4277. castParamMap[V] = std::make_pair(Res, inputQual);
  4278. }
  4279. else if (Ty->isArrayTy()) {
  4280. unsigned arraySize = 1;
  4281. Type *AT = Ty;
  4282. while (AT->isArrayTy()) {
  4283. arraySize *= AT->getArrayNumElements();
  4284. AT = AT->getArrayElementType();
  4285. }
  4286. if (dxilutil::IsHLSLResourceType(AT)) {
  4287. Value *Res = V;
  4288. Type *Ty = ArrayType::get(HandleTy, arraySize);
  4289. V = AllocaBuilder.CreateAlloca(Ty);
  4290. castParamMap[V] = std::make_pair(Res, inputQual);
  4291. }
  4292. }
  4293. return V;
  4294. }
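// Illustrative sketch (added note, not from the original source; the HLSL below is a
// hypothetical example): an input "Texture2D<float4> t" parameter is rewritten as a
// handle value (dx.types.Handle) created from the loaded resource via HLCreateHandle,
// while an out resource, or an array of resources, gets a fresh handle-typed alloca
// instead; castParamMap remembers the original pointer so replaceCastParameter can wire
// the two together later.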
  4295. Value *SROA_Parameter_HLSL::castArgumentIfRequired(
  4296. Value *V, Type *Ty, bool bOut,
  4297. DxilParamInputQual inputQual, DxilFieldAnnotation &annotation,
  4298. IRBuilder<> &Builder,
  4299. DxilTypeSystem &TypeSys) {
  4300. Module &M = *m_pHLModule->GetModule();
  4301. IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(Builder.GetInsertPoint()));
  4302. if (inputQual == DxilParamInputQual::InPayload) {
  4303. DXASSERT_NOMSG(isa<StructType>(Ty));
  4304. // Lower payload type here
  4305. StructType *LoweredTy = GetLoweredUDT(cast<StructType>(Ty), &TypeSys);
  4306. if (LoweredTy != Ty) {
  4307. Value *Ptr = AllocaBuilder.CreateAlloca(LoweredTy);
  4308. ReplaceUsesForLoweredUDT(V, Ptr);
  4309. castParamMap[V] = std::make_pair(Ptr, inputQual);
  4310. V = Ptr;
  4311. }
  4312. return V;
  4313. }
4314. // Remove the pointer for vector/scalar parameters which are not out.
  4315. if (V->getType()->isPointerTy() && !Ty->isAggregateType() && !bOut) {
  4316. Value *Ptr = AllocaBuilder.CreateAlloca(Ty);
  4317. V->replaceAllUsesWith(Ptr);
4318. // Create a load here to produce the correct type.
4319. // The Ptr will be stored with the correct value in replaceCastParameter.
  4320. if (Ptr->hasOneUse()) {
4321. // Load after the existing user for call-arg replacement.
4322. // If not, the call arg will load undef.
4323. // This will not hurt the parameter; the new load is only after the first load,
4324. // so it is still before all the load users.
  4325. Instruction *User = cast<Instruction>(*(Ptr->user_begin()));
  4326. IRBuilder<> CallBuilder(User->getNextNode());
  4327. V = CallBuilder.CreateLoad(Ptr);
  4328. } else {
  4329. V = Builder.CreateLoad(Ptr);
  4330. }
  4331. castParamMap[V] = std::make_pair(Ptr, inputQual);
  4332. }
  4333. V = castResourceArgIfRequired(V, Ty, bOut, inputQual, Builder);
4334. // An entry function matrix value parameter carries a matrix orientation (major).
4335. // Make sure its users use row-major matrix values.
  4336. bool updateToColMajor = annotation.HasMatrixAnnotation() &&
  4337. annotation.GetMatrixAnnotation().Orientation ==
  4338. MatrixOrientation::ColumnMajor;
  4339. if (updateToColMajor) {
  4340. if (V->getType()->isPointerTy()) {
  4341. for (User *user : V->users()) {
  4342. CallInst *CI = dyn_cast<CallInst>(user);
  4343. if (!CI)
  4344. continue;
  4345. HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
  4346. if (group != HLOpcodeGroup::HLMatLoadStore)
  4347. continue;
  4348. HLMatLoadStoreOpcode opcode =
  4349. static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(CI));
  4350. Type *opcodeTy = Builder.getInt32Ty();
  4351. switch (opcode) {
  4352. case HLMatLoadStoreOpcode::RowMatLoad: {
  4353. // Update matrix function opcode to col major version.
  4354. Value *rowOpArg = ConstantInt::get(
  4355. opcodeTy,
  4356. static_cast<unsigned>(HLMatLoadStoreOpcode::ColMatLoad));
  4357. CI->setOperand(HLOperandIndex::kOpcodeIdx, rowOpArg);
  4358. // Cast it to row major.
  4359. CallInst *RowMat = HLModule::EmitHLOperationCall(
  4360. Builder, HLOpcodeGroup::HLCast,
  4361. (unsigned)HLCastOpcode::ColMatrixToRowMatrix, Ty, {CI}, M);
  4362. CI->replaceAllUsesWith(RowMat);
  4363. // Set arg to CI again.
  4364. RowMat->setArgOperand(HLOperandIndex::kUnaryOpSrc0Idx, CI);
  4365. } break;
  4366. case HLMatLoadStoreOpcode::RowMatStore:
  4367. // Update matrix function opcode to col major version.
  4368. Value *rowOpArg = ConstantInt::get(
  4369. opcodeTy,
  4370. static_cast<unsigned>(HLMatLoadStoreOpcode::ColMatStore));
  4371. CI->setOperand(HLOperandIndex::kOpcodeIdx, rowOpArg);
  4372. Value *Mat = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
  4373. // Cast it to col major.
  4374. CallInst *RowMat = HLModule::EmitHLOperationCall(
  4375. Builder, HLOpcodeGroup::HLCast,
  4376. (unsigned)HLCastOpcode::RowMatrixToColMatrix, Ty, {Mat}, M);
  4377. CI->setArgOperand(HLOperandIndex::kMatStoreValOpIdx, RowMat);
  4378. break;
  4379. }
  4380. }
  4381. } else {
  4382. CallInst *RowMat = HLModule::EmitHLOperationCall(
  4383. Builder, HLOpcodeGroup::HLCast,
  4384. (unsigned)HLCastOpcode::ColMatrixToRowMatrix, Ty, {V}, M);
  4385. V->replaceAllUsesWith(RowMat);
  4386. // Set arg to V again.
  4387. RowMat->setArgOperand(HLOperandIndex::kUnaryOpSrc0Idx, V);
  4388. }
  4389. }
  4390. return V;
  4391. }
  4392. struct AnnotatedValue {
  4393. llvm::Value *Value;
  4394. DxilFieldAnnotation Annotation;
  4395. };
  4396. void SROA_Parameter_HLSL::flattenArgument(
  4397. Function *F, Value *Arg, bool bForParam,
  4398. DxilParameterAnnotation &paramAnnotation,
  4399. std::vector<Value *> &FlatParamList,
  4400. std::vector<DxilParameterAnnotation> &FlatAnnotationList,
  4401. BasicBlock *EntryBlock, DbgDeclareInst *DDI) {
  4402. std::deque<AnnotatedValue> WorkList;
  4403. WorkList.push_back({ Arg, paramAnnotation });
  4404. unsigned startArgIndex = FlatAnnotationList.size();
  4405. DxilTypeSystem &dxilTypeSys = m_pHLModule->GetTypeSystem();
  4406. const std::string &semantic = paramAnnotation.GetSemanticString();
  4407. DxilParamInputQual inputQual = paramAnnotation.GetParamInputQual();
  4408. bool bOut = inputQual == DxilParamInputQual::Out ||
  4409. inputQual == DxilParamInputQual::Inout ||
  4410. inputQual == DxilParamInputQual::OutStream0 ||
  4411. inputQual == DxilParamInputQual::OutStream1 ||
  4412. inputQual == DxilParamInputQual::OutStream2 ||
  4413. inputQual == DxilParamInputQual::OutStream3;
  4414. // Map from semantic string to type.
  4415. llvm::StringMap<Type *> semanticTypeMap;
  4416. // Original semantic type.
  4417. if (!semantic.empty()) {
  4418. // Unwrap top-level array if primitive
  4419. if (inputQual == DxilParamInputQual::InputPatch ||
  4420. inputQual == DxilParamInputQual::OutputPatch ||
  4421. inputQual == DxilParamInputQual::InputPrimitive) {
  4422. Type *Ty = Arg->getType();
  4423. if (Ty->isPointerTy())
  4424. Ty = Ty->getPointerElementType();
  4425. if (Ty->isArrayTy())
  4426. semanticTypeMap[semantic] = Ty->getArrayElementType();
  4427. } else {
  4428. semanticTypeMap[semantic] = Arg->getType();
  4429. }
  4430. }
  4431. std::vector<Instruction*> deadAllocas;
  4432. DIBuilder DIB(*F->getParent(), /*AllowUnresolved*/ false);
  4433. unsigned debugOffset = 0;
  4434. const DataLayout &DL = F->getParent()->getDataLayout();
  4435. // Process the worklist
  4436. while (!WorkList.empty()) {
  4437. AnnotatedValue AV = WorkList.front();
  4438. WorkList.pop_front();
  4439. // Do not skip unused parameter.
  4440. Value *V = AV.Value;
  4441. DxilFieldAnnotation &annotation = AV.Annotation;
  4442. // We can never replace memcpy for arguments because they have an implicit
4443. // first memcpy that happens from argument passing, and pointer analysis
4444. // will not reveal that, especially if we've done a first SROA pass on V.
  4445. // No DomTree needed for that reason
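// For example (added note, not from the original source; foo/gA/A are hypothetical):
//   void foo(inout A a) { a = gA; ... }
// The visible IR contains a single memcpy into a, yet a was already written by argument
// passing; treating it as MemcopyDestOnce and folding a into gA would lose that implicit
// first store.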
  4446. const bool bAllowReplace = false;
  4447. SROA_Helper::LowerMemcpy(V, &annotation, dxilTypeSys, DL, nullptr /*DT */, bAllowReplace);
4448. // Now it is safe to create the IRBuilder.
4449. // If we create it before LowerMemcpy, the insertion point instruction may get deleted
  4450. IRBuilder<> Builder(dxilutil::FindAllocaInsertionPt(EntryBlock));
  4451. std::vector<Value *> Elts;
4452. // Do not flatten vectors for the entry function currently.
  4453. bool SROAed = false;
  4454. Type *BrokenUpTy = nullptr;
  4455. uint64_t NumInstances = 1;
  4456. if (inputQual != DxilParamInputQual::InPayload) {
  4457. // DomTree isn't used by arguments
  4458. SROAed = SROA_Helper::DoScalarReplacement(
  4459. V, Elts, BrokenUpTy, NumInstances, Builder,
  4460. /*bFlatVector*/ false, annotation.IsPrecise(),
  4461. dxilTypeSys, DL, DeadInsts, /*DT*/ nullptr);
  4462. }
  4463. if (SROAed) {
  4464. Type *Ty = V->getType()->getPointerElementType();
  4465. // Skip empty struct parameters.
  4466. if (SROA_Helper::IsEmptyStructType(Ty, dxilTypeSys)) {
  4467. SROA_Helper::MarkEmptyStructUsers(V, DeadInsts);
  4468. DeleteDeadInstructions();
  4469. continue;
  4470. }
  4471. bool precise = annotation.IsPrecise();
  4472. const std::string &semantic = annotation.GetSemanticString();
  4473. hlsl::InterpolationMode interpMode = annotation.GetInterpolationMode();
  4474. // Push Elts into workList from right to left to preserve the order.
  4475. for (unsigned ri=0;ri<Elts.size();ri++) {
  4476. unsigned i = Elts.size() - ri - 1;
  4477. DxilFieldAnnotation EltAnnotation = GetEltAnnotation(Ty, i, annotation, dxilTypeSys);
  4478. const std::string &eltSem = EltAnnotation.GetSemanticString();
  4479. if (!semantic.empty()) {
  4480. if (!eltSem.empty()) {
  4481. // It doesn't look like we can provide source location information from here
  4482. F->getContext().emitWarning(
  4483. Twine("semantic '") + eltSem + "' on field overridden by function or enclosing type");
  4484. }
  4485. // Inherit semantic from parent, but only preserve it for the first element.
4486. // Subsequent elements are noted with a special value that gets resolved
4487. // once the argument is completely flattened.
  4488. EltAnnotation.SetSemanticString(i == 0 ? semantic : ContinuedPseudoSemantic);
  4489. } else if (!eltSem.empty() &&
  4490. semanticTypeMap.count(eltSem) == 0) {
  4491. Type *EltTy = dxilutil::GetArrayEltTy(Ty);
  4492. DXASSERT(EltTy->isStructTy(), "must be a struct type to has semantic.");
  4493. semanticTypeMap[eltSem] = EltTy->getStructElementType(i);
  4494. }
  4495. if (precise)
  4496. EltAnnotation.SetPrecise();
  4497. if (EltAnnotation.GetInterpolationMode().GetKind() == DXIL::InterpolationMode::Undefined)
  4498. EltAnnotation.SetInterpolationMode(interpMode);
  4499. WorkList.push_front({ Elts[i], EltAnnotation });
  4500. }
  4501. ++NumReplaced;
  4502. if (Instruction *I = dyn_cast<Instruction>(V))
  4503. deadAllocas.emplace_back(I);
  4504. } else {
  4505. Type *Ty = V->getType();
  4506. if (Ty->isPointerTy())
  4507. Ty = Ty->getPointerElementType();
  4508. // Flatten array of SV_Target.
  4509. StringRef semanticStr = annotation.GetSemanticString();
  4510. if (semanticStr.upper().find("SV_TARGET") == 0 &&
  4511. Ty->isArrayTy()) {
  4512. Type *Ty = cast<ArrayType>(V->getType()->getPointerElementType());
  4513. StringRef targetStr;
  4514. unsigned targetIndex;
  4515. Semantic::DecomposeNameAndIndex(semanticStr, &targetStr, &targetIndex);
  4516. // Replace target parameter with local target.
  4517. AllocaInst *localTarget = Builder.CreateAlloca(Ty);
  4518. V->replaceAllUsesWith(localTarget);
  4519. unsigned arraySize = 1;
  4520. std::vector<unsigned> arraySizeList;
  4521. while (Ty->isArrayTy()) {
  4522. unsigned size = Ty->getArrayNumElements();
  4523. arraySizeList.emplace_back(size);
  4524. arraySize *= size;
  4525. Ty = Ty->getArrayElementType();
  4526. }
  4527. unsigned arrayLevel = arraySizeList.size();
  4528. std::vector<unsigned> arrayIdxList(arrayLevel, 0);
  4529. // Create flattened target.
  4530. DxilFieldAnnotation EltAnnotation = annotation;
  4531. for (unsigned i=0;i<arraySize;i++) {
  4532. Value *Elt = Builder.CreateAlloca(Ty);
  4533. EltAnnotation.SetSemanticString(targetStr.str()+std::to_string(targetIndex+i));
  4534. // Add semantic type.
  4535. semanticTypeMap[EltAnnotation.GetSemanticString()] = Ty;
  4536. WorkList.push_front({ Elt, EltAnnotation });
  4537. // Copy local target to flattened target.
  4538. std::vector<Value*> idxList(arrayLevel+1);
  4539. idxList[0] = Builder.getInt32(0);
  4540. for (unsigned idx=0;idx<arrayLevel; idx++) {
  4541. idxList[idx+1] = Builder.getInt32(arrayIdxList[idx]);
  4542. }
  4543. if (bForParam) {
  4544. // If Argument, copy before each return.
  4545. for (auto &BB : F->getBasicBlockList()) {
  4546. TerminatorInst *TI = BB.getTerminator();
  4547. if (isa<ReturnInst>(TI)) {
  4548. IRBuilder<> RetBuilder(TI);
  4549. Value *Ptr = RetBuilder.CreateGEP(localTarget, idxList);
  4550. Value *V = RetBuilder.CreateLoad(Ptr);
  4551. RetBuilder.CreateStore(V, Elt);
  4552. }
  4553. }
  4554. } else {
  4555. // Else, copy with Builder.
  4556. Value *Ptr = Builder.CreateGEP(localTarget, idxList);
  4557. Value *V = Builder.CreateLoad(Ptr);
  4558. Builder.CreateStore(V, Elt);
  4559. }
  4560. // Update arrayIdxList.
  4561. for (unsigned idx=arrayLevel;idx>0;idx--) {
  4562. arrayIdxList[idx-1]++;
  4563. if (arrayIdxList[idx-1] < arraySizeList[idx-1])
  4564. break;
  4565. arrayIdxList[idx-1] = 0;
  4566. }
  4567. }
  4568. continue;
  4569. }
  4570. // Cast vector/matrix/resource parameter.
  4571. V = castArgumentIfRequired(V, Ty, bOut, inputQual,
  4572. annotation, Builder, dxilTypeSys);
  4573. // Cannot SROA, save it to final parameter list.
  4574. FlatParamList.emplace_back(V);
  4575. // Create ParamAnnotation for V.
  4576. FlatAnnotationList.emplace_back(DxilParameterAnnotation());
  4577. DxilParameterAnnotation &flatParamAnnotation = FlatAnnotationList.back();
  4578. flatParamAnnotation.SetParamInputQual(paramAnnotation.GetParamInputQual());
  4579. flatParamAnnotation.SetInterpolationMode(annotation.GetInterpolationMode());
  4580. flatParamAnnotation.SetSemanticString(annotation.GetSemanticString());
  4581. flatParamAnnotation.SetCompType(annotation.GetCompType().GetKind());
  4582. flatParamAnnotation.SetMatrixAnnotation(annotation.GetMatrixAnnotation());
  4583. flatParamAnnotation.SetPrecise(annotation.IsPrecise());
  4584. flatParamAnnotation.SetResourceAttribute(annotation.GetResourceAttribute());
  4585. // Add debug info.
  4586. if (DDI && V != Arg) {
  4587. Value *TmpV = V;
4588. // If V was cast, add debug info to the original V.
  4589. if (castParamMap.count(V)) {
  4590. TmpV = castParamMap[V].first;
  4591. // One more level for ptr of input vector.
4592. // It is cast from ptr to non-ptr, then cast to scalars.
  4593. if (castParamMap.count(TmpV)) {
  4594. TmpV = castParamMap[TmpV].first;
  4595. }
  4596. }
  4597. Type *Ty = TmpV->getType();
  4598. if (Ty->isPointerTy())
  4599. Ty = Ty->getPointerElementType();
  4600. unsigned size = DL.getTypeAllocSize(Ty);
  4601. #if 0 // HLSL Change
  4602. DIExpression *DDIExp = DIB.createBitPieceExpression(debugOffset, size);
  4603. #else // HLSL Change
  4604. Type *argTy = Arg->getType();
  4605. if (argTy->isPointerTy())
  4606. argTy = argTy->getPointerElementType();
  4607. DIExpression *DDIExp = nullptr;
  4608. if (debugOffset == 0 && DL.getTypeAllocSize(argTy) == size) {
  4609. DDIExp = DIB.createExpression();
  4610. }
  4611. else {
  4612. DDIExp = DIB.createBitPieceExpression(debugOffset * 8, size * 8);
  4613. }
  4614. #endif // HLSL Change
  4615. debugOffset += size;
  4616. DIB.insertDeclare(TmpV, DDI->getVariable(), DDIExp, DDI->getDebugLoc(),
  4617. Builder.GetInsertPoint());
  4618. }
  4619. // Flatten stream out.
  4620. if (HLModule::IsStreamOutputPtrType(V->getType())) {
  4621. // For stream output objects.
  4622. // Create a value as output value.
  4623. Type *outputType = V->getType()->getPointerElementType()->getStructElementType(0);
  4624. Value *outputVal = Builder.CreateAlloca(outputType);
  4625. // For each stream.Append(data)
  4626. // transform into
  4627. // d = load data
  4628. // store outputVal, d
  4629. // stream.Append(outputVal)
  4630. for (User *user : V->users()) {
  4631. if (CallInst *CI = dyn_cast<CallInst>(user)) {
  4632. unsigned opcode = GetHLOpcode(CI);
  4633. if (opcode == static_cast<unsigned>(IntrinsicOp::MOP_Append)) {
4634. // At this point, the stream append data argument might or might not have been SROA'd
  4635. Value *firstDataPtr = CI->getArgOperand(HLOperandIndex::kStreamAppendDataOpIndex);
  4636. DXASSERT(firstDataPtr->getType()->isPointerTy(), "Append value must be a pointer.");
  4637. if (firstDataPtr->getType()->getPointerElementType() == outputType) {
  4638. // The data has not been SROA'd
  4639. DXASSERT(CI->getNumArgOperands() == (HLOperandIndex::kStreamAppendDataOpIndex + 1),
  4640. "Unexpected number of arguments for non-SROA'd StreamOutput.Append");
  4641. IRBuilder<> Builder(CI);
  4642. llvm::SmallVector<llvm::Value *, 16> idxList;
  4643. SplitCpy(firstDataPtr->getType(), outputVal, firstDataPtr, idxList, Builder, DL,
  4644. dxilTypeSys, &flatParamAnnotation);
  4645. CI->setArgOperand(HLOperandIndex::kStreamAppendDataOpIndex, outputVal);
  4646. }
  4647. else {
4648. // Append has been SROA'd; we might be operating on multiple values
4649. // with types differing from the stream output type.
  4650. // Flatten store outputVal.
  4651. // Must be struct to be flatten.
  4652. IRBuilder<> Builder(CI);
  4653. llvm::SmallVector<llvm::Value *, 16> IdxList;
  4654. llvm::SmallVector<llvm::Value *, 16> EltPtrList;
  4655. llvm::SmallVector<const DxilFieldAnnotation*, 16> EltAnnotationList;
  4656. // split
  4657. SplitPtr(outputVal, IdxList, outputVal->getType(), flatParamAnnotation,
  4658. EltPtrList, EltAnnotationList, dxilTypeSys, Builder);
  4659. unsigned eltCount = CI->getNumArgOperands()-2;
  4660. DXASSERT_LOCALVAR(eltCount, eltCount == EltPtrList.size(), "invalid element count");
  4661. for (unsigned i = HLOperandIndex::kStreamAppendDataOpIndex; i < CI->getNumArgOperands(); i++) {
  4662. Value *DataPtr = CI->getArgOperand(i);
  4663. Value *EltPtr = EltPtrList[i - HLOperandIndex::kStreamAppendDataOpIndex];
  4664. const DxilFieldAnnotation *EltAnnotation = EltAnnotationList[i - HLOperandIndex::kStreamAppendDataOpIndex];
  4665. llvm::SmallVector<llvm::Value *, 16> IdxList;
  4666. SplitCpy(DataPtr->getType(), EltPtr, DataPtr, IdxList,
  4667. Builder, DL, dxilTypeSys, EltAnnotation);
  4668. CI->setArgOperand(i, EltPtr);
  4669. }
  4670. }
  4671. }
  4672. }
  4673. }
  4674. // Then split output value to generate ParamQual.
  4675. WorkList.push_front({ outputVal, annotation });
  4676. }
  4677. }
  4678. }
  4679. // Now erase any instructions that were made dead while rewriting the
  4680. // alloca.
  4681. DeleteDeadInstructions();
  4682. // Erase dead allocas after all uses deleted.
  4683. for (Instruction *I : deadAllocas)
  4684. I->eraseFromParent();
  4685. unsigned endArgIndex = FlatAnnotationList.size();
  4686. if (bForParam && startArgIndex < endArgIndex) {
  4687. DxilParamInputQual inputQual = paramAnnotation.GetParamInputQual();
  4688. if (inputQual == DxilParamInputQual::OutStream0 ||
  4689. inputQual == DxilParamInputQual::OutStream1 ||
  4690. inputQual == DxilParamInputQual::OutStream2 ||
  4691. inputQual == DxilParamInputQual::OutStream3)
  4692. startArgIndex++;
  4693. DxilParameterAnnotation &flatParamAnnotation =
  4694. FlatAnnotationList[startArgIndex];
  4695. const std::string &semantic = flatParamAnnotation.GetSemanticString();
  4696. if (!semantic.empty())
  4697. allocateSemanticIndex(FlatAnnotationList, startArgIndex,
  4698. semanticTypeMap);
  4699. }
  4700. }
  4701. static bool IsUsedAsCallArg(Value *V) {
  4702. for (User *U : V->users()) {
  4703. if (CallInst *CI = dyn_cast<CallInst>(U)) {
  4704. Function *CalledF = CI->getCalledFunction();
  4705. HLOpcodeGroup group = GetHLOpcodeGroup(CalledF);
  4706. // Skip HL operations.
  4707. if (group != HLOpcodeGroup::NotHL ||
  4708. group == HLOpcodeGroup::HLExtIntrinsic) {
  4709. continue;
  4710. }
  4711. // Skip llvm intrinsic.
  4712. if (CalledF->isIntrinsic())
  4713. continue;
  4714. return true;
  4715. }
  4716. if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
  4717. if (IsUsedAsCallArg(GEP))
  4718. return true;
  4719. }
  4720. }
  4721. return false;
  4722. }
  4723. // For function parameter which used in function call and need to be flattened.
  4724. // Replace with tmp alloca.
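// E.g. a struct parameter "inout PSInput p" that is also passed to a user
// function cannot simply be split, so the body is rewritten roughly as:
//   %p.tmp = alloca %PSInput
//   split-memcpy %p -> %p.tmp        ; for In / Inout
//   ... every old use of %p now uses %p.tmp (including the call) ...
//   split-memcpy %p.tmp -> %p        ; for Out / Inout, before each ret
// The names above are illustrative, not identifiers produced by this pass.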
void SROA_Parameter_HLSL::preprocessArgUsedInCall(Function *F) {
  if (F->isDeclaration())
    return;
  const DataLayout &DL = m_pHLModule->GetModule()->getDataLayout();
  DxilTypeSystem &typeSys = m_pHLModule->GetTypeSystem();
  DxilFunctionAnnotation *pFuncAnnot = typeSys.GetFunctionAnnotation(F);
  DXASSERT(pFuncAnnot, "else invalid function");
  IRBuilder<> Builder(dxilutil::FindAllocaInsertionPt(F));
  SmallVector<ReturnInst *, 2> retList;
  for (BasicBlock &bb : F->getBasicBlockList()) {
    if (ReturnInst *RI = dyn_cast<ReturnInst>(bb.getTerminator())) {
      retList.emplace_back(RI);
    }
  }
  for (Argument &arg : F->args()) {
    Type *Ty = arg.getType();
    // Only check pointer types.
    if (!Ty->isPointerTy())
      continue;
    Ty = Ty->getPointerElementType();
    // Skip scalar types.
    if (!Ty->isAggregateType() &&
        Ty->getScalarType() == Ty)
      continue;
    bool bUsedInCall = IsUsedAsCallArg(&arg);
    if (bUsedInCall) {
      // Create tmp.
      Value *TmpArg = Builder.CreateAlloca(Ty);
      // Replace arg with tmp.
      arg.replaceAllUsesWith(TmpArg);
      DxilParameterAnnotation &paramAnnot =
          pFuncAnnot->GetParameterAnnotation(arg.getArgNo());
      DxilParamInputQual inputQual = paramAnnot.GetParamInputQual();
      unsigned size = DL.getTypeAllocSize(Ty);
      // Copy between arg and tmp.
      if (inputQual == DxilParamInputQual::In ||
          inputQual == DxilParamInputQual::Inout) {
        // Copy arg to tmp.
        CallInst *argToTmp = Builder.CreateMemCpy(TmpArg, &arg, size, 0);
        // Split the memcpy.
        MemcpySplitter::SplitMemCpy(cast<MemCpyInst>(argToTmp), DL, nullptr,
                                    typeSys);
      }
      if (inputQual == DxilParamInputQual::Out ||
          inputQual == DxilParamInputQual::Inout) {
        for (ReturnInst *RI : retList) {
          IRBuilder<> RetBuilder(RI);
          // Copy tmp to arg.
          CallInst *tmpToArg =
              RetBuilder.CreateMemCpy(&arg, TmpArg, size, 0);
          // Split the memcpy.
          MemcpySplitter::SplitMemCpy(cast<MemCpyInst>(tmpToArg), DL, nullptr,
                                      typeSys);
        }
      }
      // TODO: support other DxilParamInputQual.
    }
  }
}

/// moveFunctionBody - Move the body of F to flatF.
void SROA_Parameter_HLSL::moveFunctionBody(Function *F, Function *flatF) {
  bool updateRetType = F->getReturnType() != flatF->getReturnType();
  // Splice the body of the old function right into the new function.
  flatF->getBasicBlockList().splice(flatF->begin(), F->getBasicBlockList());
  // Update block uses.
  if (updateRetType) {
    for (BasicBlock &BB : flatF->getBasicBlockList()) {
      if (updateRetType) {
        // Replace ret with ret void.
        if (ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator())) {
          // Create store for return.
          IRBuilder<> Builder(RI);
          Builder.CreateRetVoid();
          RI->eraseFromParent();
        }
      }
    }
  }
}

static void SplitArrayCopy(Value *V, const DataLayout &DL,
                           DxilTypeSystem &typeSys,
                           DxilFieldAnnotation *fieldAnnotation) {
  for (auto U = V->user_begin(); U != V->user_end();) {
    User *user = *(U++);
    if (StoreInst *ST = dyn_cast<StoreInst>(user)) {
      Value *ptr = ST->getPointerOperand();
      Value *val = ST->getValueOperand();
      IRBuilder<> Builder(ST);
      SmallVector<Value *, 16> idxList;
      SplitCpy(ptr->getType(), ptr, val, idxList, Builder, DL, typeSys,
               fieldAnnotation);
      ST->eraseFromParent();
    }
  }
}

static void CheckArgUsage(Value *V, bool &bLoad, bool &bStore) {
  if (bLoad && bStore)
    return;
  for (User *user : V->users()) {
    if (dyn_cast<LoadInst>(user)) {
      bLoad = true;
    } else if (dyn_cast<StoreInst>(user)) {
      bStore = true;
    } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(user)) {
      CheckArgUsage(GEP, bLoad, bStore);
    } else if (CallInst *CI = dyn_cast<CallInst>(user)) {
      if (CI->getType()->isPointerTy())
        CheckArgUsage(CI, bLoad, bStore);
      else {
        HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
        if (group == HLOpcodeGroup::HLMatLoadStore) {
          HLMatLoadStoreOpcode opcode =
              static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(CI));
          switch (opcode) {
          case HLMatLoadStoreOpcode::ColMatLoad:
          case HLMatLoadStoreOpcode::RowMatLoad:
            bLoad = true;
            break;
          case HLMatLoadStoreOpcode::ColMatStore:
          case HLMatLoadStoreOpcode::RowMatStore:
            bStore = true;
            break;
          }
        }
      }
    }
  }
}

// AcceptHitAndEndSearch and IgnoreHit both never return, but they require
// outputs to have been written before the call. Do this by:
// - injecting a return immediately after the call if one is not there already
// - LegalizeDxilInputOutputs will inject writes from temp allocas to
//   outputs before each return.
// - in HLOperationLower, after lowering the intrinsic, move the intrinsic
//   to just before the return.
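// Illustrative shape of the rewrite (names are made up): a block
//   call IgnoreHit(); <rest of block>
// becomes
//   call IgnoreHit(); ret void        (or ret undef for non-void returns)
// with <rest of block> split off into an unreachable successor block. The
// injected return gives LegalizeDxilInputOutputs a point before which the
// pending output writes are inserted.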
static void InjectReturnAfterNoReturnPreserveOutput(HLModule &HLM) {
  for (Function &F : HLM.GetModule()->functions()) {
    if (GetHLOpcodeGroup(&F) == HLOpcodeGroup::HLIntrinsic) {
      for (auto U : F.users()) {
        if (CallInst *CI = dyn_cast<CallInst>(U)) {
          unsigned OpCode = GetHLOpcode(CI);
          if (OpCode == (unsigned)IntrinsicOp::IOP_AcceptHitAndEndSearch ||
              OpCode == (unsigned)IntrinsicOp::IOP_IgnoreHit) {
            Instruction *pNextI = CI->getNextNode();
            // Skip if a return already immediately follows the call.
            if (isa<ReturnInst>(pNextI))
              continue;
            // Split the block and add a return:
            BasicBlock *BB = CI->getParent();
            BB->splitBasicBlock(pNextI);
            TerminatorInst *Term = BB->getTerminator();
            Term->eraseFromParent();
            IRBuilder<> Builder(BB);
            llvm::Type *RetTy = CI->getParent()->getParent()->getReturnType();
            if (RetTy->isVoidTy())
              Builder.CreateRetVoid();
            else
              Builder.CreateRet(UndefValue::get(RetTy));
          }
        }
      }
    }
  }
}

// Support store to input and load from output.
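// E.g. an entry like
//   float4 main(float4 a : A) : SV_Target { a.x = 0; ... }
// stores to its input, which is not legal on a signature element. A temp
// alloca is created, the input is copied into it on entry, and every use of
// the parameter is redirected to the temp; symmetrically, outputs that are
// read go through a temp that is copied back to the output before each return.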
static void LegalizeDxilInputOutputs(Function *F,
                                     DxilFunctionAnnotation *EntryAnnotation,
                                     const DataLayout &DL,
                                     DxilTypeSystem &typeSys) {
  BasicBlock &EntryBlk = F->getEntryBlock();
  Module *M = F->getParent();
  // Map from output to the temp created for it.
  MapVector<Argument *, Value *> outputTempMap; // Need deterministic order of iteration.
  for (Argument &arg : F->args()) {
    Type *Ty = arg.getType();
    DxilParameterAnnotation &paramAnnotation =
        EntryAnnotation->GetParameterAnnotation(arg.getArgNo());
    DxilParamInputQual qual = paramAnnotation.GetParamInputQual();
    bool isColMajor = false;
    // Skip args which are not pointers.
    if (!Ty->isPointerTy()) {
      if (HLMatrixType::isa(Ty)) {
        // Replace matrix arg with cast to vec. It will be lowered in
        // DxilGenerationPass.
        isColMajor = paramAnnotation.GetMatrixAnnotation().Orientation ==
                     MatrixOrientation::ColumnMajor;
        IRBuilder<> Builder(dxilutil::FindAllocaInsertionPt(F));
        HLCastOpcode opcode = isColMajor ? HLCastOpcode::ColMatrixToVecCast
                                         : HLCastOpcode::RowMatrixToVecCast;
        Value *undefVal = UndefValue::get(Ty);
        Value *Cast = HLModule::EmitHLOperationCall(
            Builder, HLOpcodeGroup::HLCast, static_cast<unsigned>(opcode), Ty,
            {undefVal}, *M);
        arg.replaceAllUsesWith(Cast);
        // Set arg as the operand.
        CallInst *CI = cast<CallInst>(Cast);
        CI->setArgOperand(HLOperandIndex::kUnaryOpSrc0Idx, &arg);
      }
      continue;
    }
    Ty = Ty->getPointerElementType();
    bool bLoad = false;
    bool bStore = false;
    CheckArgUsage(&arg, bLoad, bStore);
    bool bStoreInputToTemp = false;
    bool bLoadOutputFromTemp = false;
    if (qual == DxilParamInputQual::In && bStore) {
      bStoreInputToTemp = true;
    } else if (qual == DxilParamInputQual::Out && bLoad) {
      bLoadOutputFromTemp = true;
    } else if (bLoad && bStore) {
      switch (qual) {
      case DxilParamInputQual::InPayload:
      case DxilParamInputQual::InputPrimitive:
      case DxilParamInputQual::InputPatch:
      case DxilParamInputQual::OutputPatch: {
        bStoreInputToTemp = true;
      } break;
      case DxilParamInputQual::Inout:
        break;
      default:
        DXASSERT(0, "invalid input qual here");
      }
    } else if (qual == DxilParamInputQual::Inout) {
      // Only replace inout when (bLoad && bStore) == false.
      bLoadOutputFromTemp = true;
      bStoreInputToTemp = true;
    }
    if (HLMatrixType::isa(Ty)) {
      if (qual == DxilParamInputQual::In)
        bStoreInputToTemp = bLoad;
      else if (qual == DxilParamInputQual::Out)
        bLoadOutputFromTemp = bStore;
      else if (qual == DxilParamInputQual::Inout) {
        bStoreInputToTemp = true;
        bLoadOutputFromTemp = true;
      }
    }
    if (bStoreInputToTemp || bLoadOutputFromTemp) {
      IRBuilder<> Builder(EntryBlk.getFirstInsertionPt());
      AllocaInst *temp = Builder.CreateAlloca(Ty);
      // Replace all uses with temp.
      arg.replaceAllUsesWith(temp);
      // Copy input to temp.
      if (bStoreInputToTemp) {
        llvm::SmallVector<llvm::Value *, 16> idxList;
        // Split copy.
        SplitCpy(temp->getType(), temp, &arg, idxList, Builder, DL, typeSys,
                 &paramAnnotation);
      }
      // Generate the store from temp to output later.
      if (bLoadOutputFromTemp) {
        outputTempMap[&arg] = temp;
      }
    }
  }
  for (BasicBlock &BB : F->getBasicBlockList()) {
    if (ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator())) {
      IRBuilder<> Builder(RI);
      // Copy temp to output.
      for (auto It : outputTempMap) {
        Argument *output = It.first;
        Value *temp = It.second;
        llvm::SmallVector<llvm::Value *, 16> idxList;
        DxilParameterAnnotation &paramAnnotation =
            EntryAnnotation->GetParameterAnnotation(output->getArgNo());
        auto Iter = Builder.GetInsertPoint();
        if (RI != BB.begin())
          Iter--;
        // Split copy.
        SplitCpy(output->getType(), output, temp, idxList, Builder, DL,
                 typeSys, &paramAnnotation);
      }
      // Clone the return.
      Builder.CreateRet(RI->getReturnValue());
      RI->eraseFromParent();
    }
  }
}

void SROA_Parameter_HLSL::createFlattenedFunction(Function *F) {
  DxilTypeSystem &typeSys = m_pHLModule->GetTypeSystem();
  DXASSERT(F == m_pHLModule->GetEntryFunction() ||
               m_pHLModule->IsEntryThatUsesSignatures(F),
           "otherwise, createFlattenedFunction called on library function "
           "that should not be flattened.");
  const DataLayout &DL = m_pHLModule->GetModule()->getDataLayout();
  // Skip void(void) functions.
  if (F->getReturnType()->isVoidTy() && F->getArgumentList().empty()) {
    return;
  }
  // Clear maps for cast.
  castParamMap.clear();
  vectorEltsMap.clear();
  DxilFunctionAnnotation *funcAnnotation =
      m_pHLModule->GetFunctionAnnotation(F);
  DXASSERT(funcAnnotation, "must find annotation for function");
  std::deque<Value *> WorkList;
  LLVMContext &Ctx = m_pHLModule->GetCtx();
  std::unique_ptr<BasicBlock> TmpBlockForFuncDecl;
  BasicBlock *EntryBlock;
  if (F->isDeclaration()) {
    // We still want to SROA the parameters, so create a dummy
    // function body block to avoid special cases.
    TmpBlockForFuncDecl.reset(BasicBlock::Create(Ctx));
    // Create return as terminator.
    IRBuilder<> RetBuilder(TmpBlockForFuncDecl.get());
    RetBuilder.CreateRetVoid();
    EntryBlock = TmpBlockForFuncDecl.get();
  } else {
    EntryBlock = &F->getEntryBlock();
  }
  std::vector<Value *> FlatParamList;
  std::vector<DxilParameterAnnotation> FlatParamAnnotationList;
  std::vector<int> FlatParamOriArgNoList;
  const bool bForParamTrue = true;
  // Add all arguments to the worklist.
  for (Argument &Arg : F->args()) {
    // Merge GEP use for arg.
    HLModule::MergeGepUse(&Arg);
    unsigned prevFlatParamCount = FlatParamList.size();
    DxilParameterAnnotation &paramAnnotation =
        funcAnnotation->GetParameterAnnotation(Arg.getArgNo());
    DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(&Arg);
    flattenArgument(F, &Arg, bForParamTrue, paramAnnotation, FlatParamList,
                    FlatParamAnnotationList, EntryBlock, DDI);
    unsigned newFlatParamCount = FlatParamList.size() - prevFlatParamCount;
    for (unsigned i = 0; i < newFlatParamCount; i++) {
      FlatParamOriArgNoList.emplace_back(Arg.getArgNo());
    }
  }
  Type *retType = F->getReturnType();
  std::vector<Value *> FlatRetList;
  std::vector<DxilParameterAnnotation> FlatRetAnnotationList;
  // Split the return value and change it into an out parameter.
  if (!retType->isVoidTy()) {
    IRBuilder<> Builder(dxilutil::FindAllocaInsertionPt(EntryBlock));
    Value *retValAddr = Builder.CreateAlloca(retType);
    DxilParameterAnnotation &retAnnotation =
        funcAnnotation->GetRetTypeAnnotation();
    Module &M = *m_pHLModule->GetModule();
    Type *voidTy = Type::getVoidTy(m_pHLModule->GetCtx());
#if 0 // We don't really want this to show up in debug info.
    // Create DbgDecl for the ret value.
    if (DISubprogram *funcDI = getDISubprogram(F)) {
      DITypeRef RetDITyRef = funcDI->getType()->getTypeArray()[0];
      DITypeIdentifierMap EmptyMap;
      DIType *RetDIType = RetDITyRef.resolve(EmptyMap);
      DIBuilder DIB(*F->getParent(), /*AllowUnresolved*/ false);
      DILocalVariable *RetVar = DIB.createLocalVariable(
          llvm::dwarf::Tag::DW_TAG_arg_variable, funcDI,
          F->getName().str() + ".Ret", funcDI->getFile(), funcDI->getLine(),
          RetDIType);
      DIExpression *Expr = DIB.createExpression();
      // TODO: how to get col?
      DILocation *DL =
          DILocation::get(F->getContext(), funcDI->getLine(), 0, funcDI);
      DIB.insertDeclare(retValAddr, RetVar, Expr, DL, Builder.GetInsertPoint());
    }
#endif
    for (BasicBlock &BB : F->getBasicBlockList()) {
      if (ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator())) {
        // Create store for return.
        IRBuilder<> RetBuilder(RI);
        if (!retAnnotation.HasMatrixAnnotation()) {
          RetBuilder.CreateStore(RI->getReturnValue(), retValAddr);
        } else {
          bool isRowMajor = retAnnotation.GetMatrixAnnotation().Orientation ==
                            MatrixOrientation::RowMajor;
          Value *RetVal = RI->getReturnValue();
          if (!isRowMajor) {
            // Matrix value is row major. ColMatStore requires col major.
            // Cast before store.
            RetVal = HLModule::EmitHLOperationCall(
                RetBuilder, HLOpcodeGroup::HLCast,
                static_cast<unsigned>(HLCastOpcode::RowMatrixToColMatrix),
                RetVal->getType(), {RetVal}, M);
          }
          unsigned opcode = static_cast<unsigned>(
              isRowMajor ? HLMatLoadStoreOpcode::RowMatStore
                         : HLMatLoadStoreOpcode::ColMatStore);
          HLModule::EmitHLOperationCall(RetBuilder,
                                        HLOpcodeGroup::HLMatLoadStore, opcode,
                                        voidTy, {retValAddr, RetVal}, M);
        }
      }
    }
    // Create a fake store to keep retValAddr so it can be flattened.
    if (retValAddr->user_empty()) {
      Builder.CreateStore(UndefValue::get(retType), retValAddr);
    }
    DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(retValAddr);
    flattenArgument(F, retValAddr, bForParamTrue,
                    funcAnnotation->GetRetTypeAnnotation(), FlatRetList,
                    FlatRetAnnotationList, EntryBlock, DDI);
    const int kRetArgNo = -1;
    for (unsigned i = 0; i < FlatRetList.size(); i++) {
      FlatParamOriArgNoList.insert(FlatParamOriArgNoList.begin(), kRetArgNo);
    }
  }
  // Always change the return type into an out parameter.
  // This way there is no need to special-case the return value when
  // generating storeOutput.
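  // E.g. a "float4 main(...) : SV_Target" entry is flattened into a function
  // that returns void and carries an extra out parameter for the old return
  // value (the stores created above feed that parameter).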
  if (FlatRetList.size() ||
      // For empty struct return type.
      !retType->isVoidTy()) {
    // Return value is flattened.
    // Change return value into out parameter.
    retType = Type::getVoidTy(retType->getContext());
    // Merge return data into param data.
    FlatParamList.insert(FlatParamList.begin(), FlatRetList.begin(),
                         FlatRetList.end());
    FlatParamAnnotationList.insert(FlatParamAnnotationList.begin(),
                                   FlatRetAnnotationList.begin(),
                                   FlatRetAnnotationList.end());
  }
  std::vector<Type *> FinalTypeList;
  for (Value *arg : FlatParamList) {
    FinalTypeList.emplace_back(arg->getType());
  }
  unsigned extraParamSize = 0;
  if (m_pHLModule->HasDxilFunctionProps(F)) {
    DxilFunctionProps &funcProps = m_pHLModule->GetDxilFunctionProps(F);
    if (funcProps.shaderKind == ShaderModel::Kind::Vertex) {
      auto &VS = funcProps.ShaderProps.VS;
      Type *outFloatTy = Type::getFloatPtrTy(F->getContext());
      // Add an out float parameter for each clip plane.
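      // E.g. [clipplanes(p0, p1)] on the entry (names illustrative) adds two
      // float* out parameters here; they are given SV_ClipDistance0/1
      // semantics when the flat function's annotations are filled in below.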
      unsigned i = 0;
      for (; i < DXIL::kNumClipPlanes; i++) {
        if (!VS.clipPlanes[i])
          break;
        FinalTypeList.emplace_back(outFloatTy);
      }
      extraParamSize = i;
    }
  }
  FunctionType *flatFuncTy = FunctionType::get(retType, FinalTypeList, false);
  // Return if nothing changed.
  if (flatFuncTy == F->getFunctionType()) {
    // Copy semantic allocation.
    if (!FlatParamAnnotationList.empty()) {
      if (!FlatParamAnnotationList[0].GetSemanticString().empty()) {
        for (unsigned i = 0; i < FlatParamAnnotationList.size(); i++) {
          DxilParameterAnnotation &paramAnnotation =
              funcAnnotation->GetParameterAnnotation(i);
          DxilParameterAnnotation &flatParamAnnotation =
              FlatParamAnnotationList[i];
          paramAnnotation.SetSemanticIndexVec(
              flatParamAnnotation.GetSemanticIndexVec());
          paramAnnotation.SetSemanticString(
              flatParamAnnotation.GetSemanticString());
        }
      }
    }
    if (!F->isDeclaration()) {
      // Support store to input and load from output.
      LegalizeDxilInputOutputs(F, funcAnnotation, DL, typeSys);
    }
    return;
  }
  std::string flatName = F->getName().str() + ".flat";
  DXASSERT(nullptr == F->getParent()->getFunction(flatName),
           "else overwriting existing function");
  Function *flatF =
      cast<Function>(F->getParent()->getOrInsertFunction(flatName, flatFuncTy));
  funcMap[F] = flatF;
  // Update function debug info.
  if (DISubprogram *funcDI = getDISubprogram(F))
    funcDI->replaceFunction(flatF);
  // Create FunctionAnnotation for flatF.
  DxilFunctionAnnotation *flatFuncAnnotation =
      m_pHLModule->AddFunctionAnnotation(flatF);
  // No need to set Ret Info; flatF always returns void now.
  // Param Info.
  for (unsigned ArgNo = 0; ArgNo < FlatParamAnnotationList.size(); ++ArgNo) {
    DxilParameterAnnotation &paramAnnotation =
        flatFuncAnnotation->GetParameterAnnotation(ArgNo);
    paramAnnotation = FlatParamAnnotationList[ArgNo];
  }
  // Function attributes and parameter attributes.
  // Remove sret first.
  if (F->hasStructRetAttr())
    F->removeFnAttr(Attribute::StructRet);
  for (Argument &arg : F->args()) {
    if (arg.hasStructRetAttr()) {
      Attribute::AttrKind SRet[] = {Attribute::StructRet};
      AttributeSet SRetAS = AttributeSet::get(Ctx, arg.getArgNo() + 1, SRet);
      arg.removeAttr(SRetAS);
    }
  }
  AttributeSet AS = F->getAttributes();
  AttrBuilder FnAttrs(AS.getFnAttributes(), AttributeSet::FunctionIndex);
  AttributeSet flatAS;
  flatAS = flatAS.addAttributes(
      Ctx, AttributeSet::FunctionIndex,
      AttributeSet::get(Ctx, AttributeSet::FunctionIndex, FnAttrs));
  if (!F->isDeclaration()) {
    // Only set parameter attributes for functions that have a body.
    for (unsigned ArgNo = 0; ArgNo < FlatParamAnnotationList.size(); ++ArgNo) {
      unsigned oriArgNo = FlatParamOriArgNoList[ArgNo] + 1;
      AttrBuilder paramAttr(AS, oriArgNo);
      if (oriArgNo == AttributeSet::ReturnIndex)
        paramAttr.addAttribute(Attribute::AttrKind::NoAlias);
      flatAS = flatAS.addAttributes(
          Ctx, ArgNo + 1, AttributeSet::get(Ctx, ArgNo + 1, paramAttr));
    }
  }
  flatF->setAttributes(flatAS);
  DXASSERT_LOCALVAR(extraParamSize,
                    flatF->arg_size() ==
                        (extraParamSize + FlatParamAnnotationList.size()),
                    "parameter count mismatch");
  // ShaderProps.
  if (m_pHLModule->HasDxilFunctionProps(F)) {
    DxilFunctionProps &funcProps = m_pHLModule->GetDxilFunctionProps(F);
    std::unique_ptr<DxilFunctionProps> flatFuncProps =
        llvm::make_unique<DxilFunctionProps>();
    flatFuncProps->shaderKind = funcProps.shaderKind;
    flatFuncProps->ShaderProps = funcProps.ShaderProps;
    m_pHLModule->AddDxilFunctionProps(flatF, flatFuncProps);
    if (funcProps.shaderKind == ShaderModel::Kind::Vertex) {
      auto &VS = funcProps.ShaderProps.VS;
      unsigned clipArgIndex = FlatParamAnnotationList.size();
      // Add an out float SV_ClipDistance for each clip plane.
      for (unsigned i = 0; i < DXIL::kNumClipPlanes; i++) {
        if (!VS.clipPlanes[i])
          break;
        DxilParameterAnnotation &paramAnnotation =
            flatFuncAnnotation->GetParameterAnnotation(clipArgIndex + i);
        paramAnnotation.SetParamInputQual(DxilParamInputQual::Out);
        Twine semName = Twine("SV_ClipDistance") + Twine(i);
        paramAnnotation.SetSemanticString(semName.str());
        paramAnnotation.SetCompType(DXIL::ComponentType::F32);
        paramAnnotation.AppendSemanticIndex(i);
      }
    }
  }
  if (!F->isDeclaration()) {
    // Move the function body into flatF.
    moveFunctionBody(F, flatF);
    // Replace old parameters with flatF Arguments.
    auto argIter = flatF->arg_begin();
    auto flatArgIter = FlatParamList.begin();
    LLVMContext &Context = F->getContext();
    // Parameter casts come from the beginning of the entry block.
    IRBuilder<> Builder(dxilutil::FindAllocaInsertionPt(flatF));
    while (argIter != flatF->arg_end()) {
      Argument *Arg = argIter++;
      if (flatArgIter == FlatParamList.end()) {
        DXASSERT(extraParamSize > 0, "parameter count mismatch");
        break;
      }
      Value *flatArg = *(flatArgIter++);
      if (castParamMap.count(flatArg)) {
        replaceCastParameter(flatArg, castParamMap[flatArg].first, *flatF, Arg,
                             castParamMap[flatArg].second, Builder);
      }
      // Update arg debug info.
      DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(flatArg);
      if (DDI) {
        if (!flatArg->getType()->isPointerTy()) {
          // Create an alloca to hold the debug info.
          Value *allocaArg = nullptr;
          if (flatArg->hasOneUse() && isa<StoreInst>(*flatArg->user_begin())) {
            StoreInst *SI = cast<StoreInst>(*flatArg->user_begin());
            allocaArg = SI->getPointerOperand();
          } else {
            allocaArg = Builder.CreateAlloca(flatArg->getType());
            StoreInst *initArg = Builder.CreateStore(flatArg, allocaArg);
            Value *ldArg = Builder.CreateLoad(allocaArg);
            flatArg->replaceAllUsesWith(ldArg);
            initArg->setOperand(0, flatArg);
          }
          Value *VMD =
              MetadataAsValue::get(Context, ValueAsMetadata::get(allocaArg));
          DDI->setArgOperand(0, VMD);
        } else {
          Value *VMD = MetadataAsValue::get(Context, ValueAsMetadata::get(Arg));
          DDI->setArgOperand(0, VMD);
        }
      }
      flatArg->replaceAllUsesWith(Arg);
      if (isa<Instruction>(flatArg))
        DeadInsts.emplace_back(flatArg);
      HLModule::MergeGepUse(Arg);
      // Flatten store of array parameter.
      if (Arg->getType()->isPointerTy()) {
        Type *Ty = Arg->getType()->getPointerElementType();
        if (Ty->isArrayTy())
          SplitArrayCopy(
              Arg, DL, typeSys,
              &flatFuncAnnotation->GetParameterAnnotation(Arg->getArgNo()));
      }
    }
    // Support store to input and load from output.
    LegalizeDxilInputOutputs(flatF, flatFuncAnnotation, DL, typeSys);
  }
}

void SROA_Parameter_HLSL::replaceCall(Function *F, Function *flatF) {
  // Update entry function.
  if (F == m_pHLModule->GetEntryFunction()) {
    m_pHLModule->SetEntryFunction(flatF);
  }
  DXASSERT(F->user_empty(), "otherwise we flattened a library function.");
}

// Public interface to the SROA_Parameter_HLSL pass
ModulePass *llvm::createSROA_Parameter_HLSL() {
  return new SROA_Parameter_HLSL();
}

//===----------------------------------------------------------------------===//
// Lower static global into Alloca.
//===----------------------------------------------------------------------===//
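// A non-constant static global, e.g.
//   static float4 g_acc;   // illustrative
// has module-local storage. When it is only used from the entry functions
// (plus any global initializer functions), it can be replaced with an alloca
// in each of those functions; debug info is re-pointed at the alloca below.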
namespace {
class LowerStaticGlobalIntoAlloca : public ModulePass {
  DebugInfoFinder m_DbgFinder;

public:
  static char ID; // Pass identification, replacement for typeid
  explicit LowerStaticGlobalIntoAlloca() : ModulePass(ID) {}
  const char *getPassName() const override {
    return "Lower static global into Alloca";
  }

  bool runOnModule(Module &M) override {
    m_DbgFinder.processModule(M);
    Type *handleTy = nullptr;
    DxilTypeSystem *pTypeSys = nullptr;
    SetVector<Function *> entryAndInitFunctionSet;
    if (M.HasHLModule()) {
      auto &HLM = M.GetHLModule();
      pTypeSys = &HLM.GetTypeSystem();
      handleTy = HLM.GetOP()->GetHandleType();
      if (!HLM.GetShaderModel()->IsLib()) {
        entryAndInitFunctionSet.insert(HLM.GetEntryFunction());
        if (HLM.GetShaderModel()->IsHS()) {
          entryAndInitFunctionSet.insert(HLM.GetPatchConstantFunction());
        }
      } else {
        for (Function &F : M) {
          if (!HLM.IsEntry(&F)) {
            continue;
          }
          entryAndInitFunctionSet.insert(&F);
        }
      }
    } else {
      DXASSERT(M.HasDxilModule(), "must have DxilModule or HLModule");
      auto &DM = M.GetDxilModule();
      pTypeSys = &DM.GetTypeSystem();
      handleTy = DM.GetOP()->GetHandleType();
      if (!DM.GetShaderModel()->IsLib()) {
        entryAndInitFunctionSet.insert(DM.GetEntryFunction());
        if (DM.GetShaderModel()->IsHS()) {
          entryAndInitFunctionSet.insert(DM.GetPatchConstantFunction());
        }
      } else {
        for (Function &F : M) {
          if (!DM.IsEntry(&F))
            continue;
          entryAndInitFunctionSet.insert(&F);
        }
      }
    }
    // Collect init functions for static globals.
    if (GlobalVariable *Ctors = M.getGlobalVariable("llvm.global_ctors")) {
      if (ConstantArray *CA =
              dyn_cast<ConstantArray>(Ctors->getInitializer())) {
        for (User::op_iterator i = CA->op_begin(), e = CA->op_end(); i != e;
             ++i) {
          if (isa<ConstantAggregateZero>(*i))
            continue;
          ConstantStruct *CS = cast<ConstantStruct>(*i);
          if (isa<ConstantPointerNull>(CS->getOperand(1)))
            continue;
          // Must have a function or null ptr.
          if (!isa<Function>(CS->getOperand(1)))
            continue;
          Function *Ctor = cast<Function>(CS->getOperand(1));
          assert(Ctor->getReturnType()->isVoidTy() && Ctor->arg_size() == 0 &&
                 "function type must be void (void)");
          // Add Ctor.
          entryAndInitFunctionSet.insert(Ctor);
        }
      }
    }
    // Lower static globals into allocas.
    std::vector<GlobalVariable *> staticGVs;
    for (GlobalVariable &GV : M.globals()) {
      // Only for non-constant static globals.
      if (!dxilutil::IsStaticGlobal(&GV) || GV.isConstant())
        continue;
      // Skip dx.ishelper.
      if (GV.getName().compare(DXIL::kDxIsHelperGlobalName) == 0)
        continue;
      // Skip if GV is used in functions other than the entries.
      if (!usedOnlyInEntry(&GV, entryAndInitFunctionSet))
        continue;
      Type *EltTy = GV.getType()->getElementType();
      if (!EltTy->isAggregateType()) {
        staticGVs.emplace_back(&GV);
      } else {
        EltTy = dxilutil::GetArrayEltTy(EltTy);
        // Lower static [array of] resources.
        if (dxilutil::IsHLSLObjectType(EltTy) || EltTy == handleTy) {
          staticGVs.emplace_back(&GV);
        }
      }
    }
    bool bUpdated = false;
    const DataLayout &DL = M.getDataLayout();
    // Create an AI for each GV in each entry.
    // Replace all users of GV with AI.
    // Collect all users of GV within each entry.
    // Remove unused AIs in the end.
    for (GlobalVariable *GV : staticGVs) {
      bUpdated |= lowerStaticGlobalIntoAlloca(GV, DL, *pTypeSys,
                                              entryAndInitFunctionSet);
    }
    return bUpdated;
  }

private:
  bool lowerStaticGlobalIntoAlloca(GlobalVariable *GV, const DataLayout &DL,
                                   DxilTypeSystem &typeSys,
                                   SetVector<Function *> &entryAndInitFunctionSet);
  bool usedOnlyInEntry(Value *V,
                       SetVector<Function *> &entryAndInitFunctionSet);
};
}

// Go through the base type chain of TyA and see if
// we eventually get to TyB.
//
// Note: Not necessarily about inheritance. Could be
// typedef, const type, ref type, MEMBER type (TyA
// being a member of TyB).
//
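// E.g. for "typedef const MyStruct T;" (names illustrative) the debug-type
// chain for T is T -> const MyStruct -> MyStruct, so
// IsDerivedTypeOf(T, MyStruct) walks the DIDerivedType links and returns true.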
static bool IsDerivedTypeOf(DIType *TyA, DIType *TyB) {
  DITypeIdentifierMap EmptyMap;
  while (TyA) {
    if (DIDerivedType *Derived = dyn_cast<DIDerivedType>(TyA)) {
      if (Derived->getBaseType() == TyB)
        return true;
      else
        TyA = Derived->getBaseType().resolve(EmptyMap);
    } else {
      break;
    }
  }
  return false;
}

// See if 'DGV' is a member type of some other variable, and return that
// variable and the offset and size DGV is into it.
//
// If DGV is not a member, just return nullptr.
//
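// E.g. a debug global named "s.m" whose type is a DW_TAG_member of the type
// of another debug global named "s" is treated as a fragment of "s"; its
// offset and size in bits are returned so a bit-piece expression can be built.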
static DIGlobalVariable *
FindGlobalVariableFragment(const DebugInfoFinder &DbgFinder,
                           DIGlobalVariable *DGV, unsigned *Out_OffsetInBits,
                           unsigned *Out_SizeInBits) {
  DITypeIdentifierMap EmptyMap;
  StringRef FullName = DGV->getName();
  size_t FirstDot = FullName.find_first_of('.');
  if (FirstDot == StringRef::npos)
    return nullptr;
  StringRef BaseName = FullName.substr(0, FirstDot);
  assert(BaseName.size());
  DIType *Ty = DGV->getType().resolve(EmptyMap);
  assert(isa<DIDerivedType>(Ty) && Ty->getTag() == dwarf::DW_TAG_member);
  DIGlobalVariable *FinalResult = nullptr;
  for (DIGlobalVariable *DGV_It : DbgFinder.global_variables()) {
    if (DGV_It->getName() == BaseName &&
        IsDerivedTypeOf(Ty, DGV_It->getType().resolve(EmptyMap))) {
      FinalResult = DGV_It;
      break;
    }
  }
  if (FinalResult) {
    *Out_OffsetInBits = Ty->getOffsetInBits();
    *Out_SizeInBits = Ty->getSizeInBits();
  }
  return FinalResult;
}

// Create a fake local variable for the GlobalVariable GV that has just been
// lowered to a local Alloca.
//
static void PatchDebugInfo(DebugInfoFinder &DbgFinder, Function *F,
                           GlobalVariable *GV, AllocaInst *AI) {
  if (!DbgFinder.compile_unit_count())
    return;
  // Find the subprogram for the function.
  DISubprogram *Subprogram = nullptr;
  for (DISubprogram *SP : DbgFinder.subprograms()) {
    if (SP->getFunction() == F) {
      Subprogram = SP;
      break;
    }
  }
  DIGlobalVariable *DGV = dxilutil::FindGlobalVariableDebugInfo(GV, DbgFinder);
  if (!DGV)
    return;
  DITypeIdentifierMap EmptyMap;
  DIBuilder DIB(*GV->getParent());
  DIScope *Scope = Subprogram;
  DebugLoc Loc = DebugLoc::get(0, 0, Scope);
  // If the variable is a member of another variable, find the offset and size.
  bool IsFragment = false;
  unsigned OffsetInBits = 0, SizeInBits = 0;
  if (DIGlobalVariable *UnsplitDGV = FindGlobalVariableFragment(
          DbgFinder, DGV, &OffsetInBits, &SizeInBits)) {
    DGV = UnsplitDGV;
    IsFragment = true;
  }
  std::string Name = "global.";
  Name += DGV->getName();
  // Using arg_variable instead of auto_variable because arg variables can use
  // Subprogram as their scope, so we don't have to make one up for it.
  llvm::dwarf::Tag Tag = llvm::dwarf::Tag::DW_TAG_arg_variable;
  DIType *Ty = DGV->getType().resolve(EmptyMap);
  DXASSERT(Ty->getTag() != dwarf::DW_TAG_member,
           "Member type is not allowed for variables.");
  DILocalVariable *ConvertedLocalVar = DIB.createLocalVariable(
      Tag, Scope, Name, DGV->getFile(), DGV->getLine(), Ty);
  DIExpression *Expr = nullptr;
  if (IsFragment) {
    Expr = DIB.createBitPieceExpression(OffsetInBits, SizeInBits);
  } else {
    Expr = DIB.createExpression(ArrayRef<int64_t>());
  }
  DIB.insertDeclare(AI, ConvertedLocalVar, Expr, Loc, AI->getNextNode());
}

// Collect instructions using GV and the value used by each instruction.
// For a direct use, the value == GV.
// For a constant operator like GEP/Bitcast, the value is the operator used by
// the instruction. This requires recursion to unwrap nested constant operators
// using the GV.
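// E.g. for a use such as
//   store float 1.0, float* getelementptr(..., @g, i32 0, i32 2)
// the map records { store -> constant GEP on @g }; replaceGVUseWithAI below
// rebuilds that GEP as an instruction on the per-entry alloca so the store
// can be re-pointed at it.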
static void collectGVInstUsers(Value *V,
                               DenseMap<Instruction *, Value *> &InstUserMap) {
  for (User *U : V->users()) {
    if (Instruction *I = dyn_cast<Instruction>(U)) {
      InstUserMap[I] = V;
    } else {
      collectGVInstUsers(U, InstUserMap);
    }
  }
}

static Instruction *replaceGVUseWithAI(GlobalVariable *GV, AllocaInst *AI,
                                       Value *U, IRBuilder<> &B) {
  if (U == GV)
    return AI;
  if (GEPOperator *GEP = dyn_cast<GEPOperator>(U)) {
    Instruction *PtrInst =
        replaceGVUseWithAI(GV, AI, GEP->getPointerOperand(), B);
    SmallVector<Value *, 2> Index(GEP->idx_begin(), GEP->idx_end());
    return cast<Instruction>(B.CreateGEP(PtrInst, Index));
  }
  if (BitCastOperator *BCO = dyn_cast<BitCastOperator>(U)) {
    Instruction *SrcInst = replaceGVUseWithAI(GV, AI, BCO->getOperand(0), B);
    return cast<Instruction>(B.CreateBitCast(SrcInst, BCO->getType()));
  }
  DXASSERT(false, "unsupported user of static global");
  return nullptr;
}

bool LowerStaticGlobalIntoAlloca::lowerStaticGlobalIntoAlloca(
    GlobalVariable *GV, const DataLayout &DL, DxilTypeSystem &typeSys,
    SetVector<Function *> &entryAndInitFunctionSet) {
  GV->removeDeadConstantUsers();
  // Create an alloca for each entry.
  DenseMap<Function *, AllocaInst *> allocaMap;
  for (Function *F : entryAndInitFunctionSet) {
    IRBuilder<> Builder(dxilutil::FindAllocaInsertionPt(F));
    AllocaInst *AI = Builder.CreateAlloca(GV->getType()->getElementType());
    allocaMap[F] = AI;
    // Store the initializer if one exists. The store targets GV here; the
    // user-replacement loop below redirects it to the per-entry alloca.
    if (GV->hasInitializer() && !isa<UndefValue>(GV->getInitializer())) {
      Builder.CreateStore(GV->getInitializer(), GV);
    }
  }
  DenseMap<Instruction *, Value *> InstUserMap;
  collectGVInstUsers(GV, InstUserMap);
  for (auto it : InstUserMap) {
    Instruction *I = it.first;
    Value *U = it.second;
    Function *F = I->getParent()->getParent();
    AllocaInst *AI = allocaMap[F];
    IRBuilder<> B(I);
    Instruction *UI = replaceGVUseWithAI(GV, AI, U, B);
    I->replaceUsesOfWith(U, UI);
  }
  for (Function *F : entryAndInitFunctionSet) {
    AllocaInst *AI = allocaMap[F];
    if (AI->user_empty())
      AI->eraseFromParent();
    else
      PatchDebugInfo(m_DbgFinder, F, GV, AI);
  }
  GV->removeDeadConstantUsers();
  if (GV->user_empty())
    GV->eraseFromParent();
  return true;
}

bool LowerStaticGlobalIntoAlloca::usedOnlyInEntry(
    Value *V, SetVector<Function *> &entryAndInitFunctionSet) {
  bool bResult = true;
  for (User *U : V->users()) {
    if (Instruction *I = dyn_cast<Instruction>(U)) {
      Function *F = I->getParent()->getParent();
      if (entryAndInitFunctionSet.count(F) == 0) {
        bResult = false;
        break;
      }
    } else {
      bResult = usedOnlyInEntry(U, entryAndInitFunctionSet);
      if (!bResult)
        break;
    }
  }
  return bResult;
}

char LowerStaticGlobalIntoAlloca::ID = 0;

INITIALIZE_PASS(LowerStaticGlobalIntoAlloca, "static-global-to-alloca",
                "Lower static global into Alloca", false, false)

// Public interface to the LowerStaticGlobalIntoAlloca pass
ModulePass *llvm::createLowerStaticGlobalIntoAlloca() {
  return new LowerStaticGlobalIntoAlloca();
}